Compare commits
144 Commits
main
...
provider-e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4a32d2e871 | ||
|
|
65c22bd356 | ||
|
|
d601cce315 | ||
|
|
3c2207b3a0 | ||
|
|
9d23e5733c | ||
|
|
3a6e91abcb | ||
|
|
4f22272b0d | ||
|
|
20d7513c73 | ||
|
|
493f8d59e6 | ||
|
|
65a395fa9a | ||
|
|
ca8279ca79 | ||
|
|
19833f0132 | ||
|
|
ad0687a987 | ||
|
|
a51b99216d | ||
|
|
3de07eaf0c | ||
|
|
5fa97e8da8 | ||
|
|
6acc4cc038 | ||
|
|
6a07fe4e99 | ||
|
|
6f05a4b6df | ||
|
|
78f9f4a768 | ||
|
|
46dedb3e13 | ||
|
|
ea5800b322 | ||
|
|
b652196356 | ||
|
|
155a2d2a1e | ||
|
|
f182aa43bb | ||
|
|
f783f22e33 | ||
|
|
6811c57550 | ||
|
|
5739fce607 | ||
|
|
f0fc578fe6 | ||
|
|
7cbc6fb337 | ||
|
|
55c9113177 | ||
|
|
98248d5a7a | ||
|
|
8d5b12a6be | ||
|
|
aa69a52685 | ||
|
|
d5e2a2a00c | ||
|
|
094f514414 | ||
|
|
5abf968748 | ||
|
|
dd455306b2 | ||
|
|
dd4d5b5b0c | ||
|
|
cc7799af38 | ||
|
|
13776b7898 | ||
|
|
67f3b0987a | ||
|
|
6a6b556143 | ||
|
|
3debec1393 | ||
|
|
bde75bb11a | ||
|
|
eff0105c04 | ||
|
|
4bbc53b0ee | ||
|
|
00a62555ec | ||
|
|
d1f085c063 | ||
|
|
73341e51ac | ||
|
|
ed111bf528 | ||
|
|
64966bbecc | ||
|
|
fe895c7c97 | ||
|
|
9c2c9ea949 | ||
|
|
f46b94635d | ||
|
|
b9c8f8b79e | ||
|
|
6ac42dde0d | ||
|
|
7f51ca3dbb | ||
|
|
c050b4225a | ||
|
|
b2073af63a | ||
|
|
a52e4af96d | ||
|
|
35aa3f2207 | ||
|
|
1a808c4642 | ||
|
|
fda2688165 | ||
|
|
7881047432 | ||
|
|
da9281c4a4 | ||
|
|
9cc517e0dd | ||
|
|
d1390a5b78 | ||
|
|
ee4faede38 | ||
|
|
8d96a699b3 | ||
|
|
8cfb7471db | ||
|
|
def9c87837 | ||
|
|
0313ab6d41 | ||
|
|
c5329fdff2 | ||
|
|
a676a6895b | ||
|
|
3b5d7d7d89 | ||
|
|
91f01131b1 | ||
|
|
5fa5226286 | ||
|
|
ae94007227 | ||
|
|
8f425a1bd5 | ||
|
|
743c414e7b | ||
|
|
0fe335efc5 | ||
|
|
36b95aac4b | ||
|
|
b2df70ab58 | ||
|
|
36293d7dd9 | ||
|
|
3ae3e1fce8 | ||
|
|
e5f1fc7478 | ||
|
|
a4f6076da7 | ||
|
|
43726b2620 | ||
|
|
94980ffb49 | ||
|
|
22cc731450 | ||
|
|
d9396373e3 | ||
|
|
48002be135 | ||
|
|
58db83f8f5 | ||
|
|
0243d5b542 | ||
|
|
06230327fa | ||
|
|
ca5c8992f9 | ||
|
|
1038e1c2ef | ||
|
|
e1fe0b3287 | ||
|
|
a0e10a91bf | ||
|
|
272b1aa4bc | ||
|
|
9ef0537b44 | ||
|
|
77f1de742b | ||
|
|
e054cabd41 | ||
|
|
3b95cb5682 | ||
|
|
c89653bd07 | ||
|
|
b90ac2dc07 | ||
|
|
c9998541f0 | ||
|
|
e2b49b3cd3 | ||
|
|
d1e77397c6 | ||
|
|
cc5f5e35e4 | ||
|
|
7183b8a1cd | ||
|
|
b1934fb712 | ||
|
|
a198b6c0d1 | ||
|
|
8b5b2712c8 | ||
|
|
4464392e8e | ||
|
|
a0d3bc31e9 | ||
|
|
ccd6672d1a | ||
|
|
21de6d35dd | ||
|
|
2031ca17e5 | ||
|
|
8b1ce75a57 | ||
|
|
5559726fd7 | ||
|
|
e1a9269921 | ||
|
|
3b6b3ff504 | ||
|
|
aabed94970 | ||
|
|
2d3a3521ba | ||
|
|
a48bd10da0 | ||
|
|
fec9525be4 | ||
|
|
bf2b8e999e | ||
|
|
63c35d2b00 | ||
|
|
1396c68010 | ||
|
|
fcb3d3dec6 | ||
|
|
f54e7f8c9d | ||
|
|
2a89529d7f | ||
|
|
58207325e2 | ||
|
|
e08ab99e8d | ||
|
|
a95f3f33a4 | ||
|
|
b0767c1b1f | ||
|
|
b200e10bc4 | ||
|
|
948905d916 | ||
|
|
04de456373 | ||
|
|
e5ce32e936 | ||
|
|
d7caae30de | ||
|
|
c7e77674a1 |
11
Cargo.lock
generated
11
Cargo.lock
generated
@@ -5924,9 +5924,11 @@ dependencies = [
|
||||
"async-trait",
|
||||
"client",
|
||||
"collections",
|
||||
"credentials_provider",
|
||||
"criterion",
|
||||
"ctor",
|
||||
"dap",
|
||||
"dirs 4.0.0",
|
||||
"extension",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
@@ -5935,8 +5937,11 @@ dependencies = [
|
||||
"http_client",
|
||||
"language",
|
||||
"language_extension",
|
||||
"language_model",
|
||||
"log",
|
||||
"lsp",
|
||||
"markdown",
|
||||
"menu",
|
||||
"moka",
|
||||
"node_runtime",
|
||||
"parking_lot",
|
||||
@@ -5951,17 +5956,21 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_json_lenient",
|
||||
"settings",
|
||||
"smol",
|
||||
"task",
|
||||
"telemetry",
|
||||
"tempfile",
|
||||
"theme",
|
||||
"theme_extension",
|
||||
"toml 0.8.23",
|
||||
"ui",
|
||||
"ui_input",
|
||||
"url",
|
||||
"util",
|
||||
"wasmparser 0.221.3",
|
||||
"wasmtime",
|
||||
"wasmtime-wasi",
|
||||
"workspace",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
@@ -14873,6 +14882,7 @@ dependencies = [
|
||||
"copilot",
|
||||
"edit_prediction",
|
||||
"editor",
|
||||
"extension_host",
|
||||
"feature_flags",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
@@ -14880,6 +14890,7 @@ dependencies = [
|
||||
"gpui",
|
||||
"heck 0.5.0",
|
||||
"language",
|
||||
"language_model",
|
||||
"language_models",
|
||||
"log",
|
||||
"menu",
|
||||
|
||||
@@ -3,11 +3,11 @@ use agent_client_protocol::{self as acp};
|
||||
use anyhow::Result;
|
||||
use collections::IndexMap;
|
||||
use gpui::{Entity, SharedString, Task};
|
||||
use language_model::LanguageModelProviderId;
|
||||
use language_model::{IconOrSvg, LanguageModelProviderId};
|
||||
use project::Project;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{any::Any, error::Error, fmt, path::Path, rc::Rc, sync::Arc};
|
||||
use ui::{App, IconName};
|
||||
use ui::App;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Hash)]
|
||||
@@ -210,21 +210,12 @@ pub trait AgentModelSelector: 'static {
|
||||
}
|
||||
}
|
||||
|
||||
/// Icon for a model in the model selector.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum AgentModelIcon {
|
||||
/// A built-in icon from Zed's icon set.
|
||||
Named(IconName),
|
||||
/// Path to a custom SVG icon file.
|
||||
Path(SharedString),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct AgentModelInfo {
|
||||
pub id: acp::ModelId,
|
||||
pub name: SharedString,
|
||||
pub description: Option<SharedString>,
|
||||
pub icon: Option<AgentModelIcon>,
|
||||
pub icon: Option<IconOrSvg>,
|
||||
}
|
||||
|
||||
impl From<acp::ModelInfo> for AgentModelInfo {
|
||||
|
||||
@@ -739,7 +739,7 @@ impl ActivityIndicator {
|
||||
extension_store.outstanding_operations().iter().next()
|
||||
{
|
||||
let (message, icon, rotate) = match operation {
|
||||
ExtensionOperation::Install => (
|
||||
ExtensionOperation::Install | ExtensionOperation::AutoInstall => (
|
||||
format!("Installing {extension_id} extension…"),
|
||||
IconName::LoadCircle,
|
||||
true,
|
||||
|
||||
@@ -30,7 +30,7 @@ use futures::{StreamExt, future};
|
||||
use gpui::{
|
||||
App, AppContext, AsyncApp, Context, Entity, SharedString, Subscription, Task, WeakEntity,
|
||||
};
|
||||
use language_model::{IconOrSvg, LanguageModel, LanguageModelProvider, LanguageModelRegistry};
|
||||
use language_model::{LanguageModel, LanguageModelProvider, LanguageModelRegistry};
|
||||
use project::{Project, ProjectItem, ProjectPath, Worktree};
|
||||
use prompt_store::{
|
||||
ProjectContext, PromptStore, RULES_FILE_NAMES, RulesFileContext, UserRulesContext,
|
||||
@@ -153,10 +153,7 @@ impl LanguageModels {
|
||||
id: Self::model_id(model),
|
||||
name: model.name().0,
|
||||
description: None,
|
||||
icon: Some(match provider.icon() {
|
||||
IconOrSvg::Svg(path) => acp_thread::AgentModelIcon::Path(path),
|
||||
IconOrSvg::Icon(name) => acp_thread::AgentModelIcon::Named(name),
|
||||
}),
|
||||
icon: Some(provider.icon()),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1633,9 +1630,7 @@ mod internal_tests {
|
||||
id: acp::ModelId::new("fake/fake"),
|
||||
name: "Fake".into(),
|
||||
description: None,
|
||||
icon: Some(acp_thread::AgentModelIcon::Named(
|
||||
ui::IconName::ZedAssistant
|
||||
)),
|
||||
icon: Some(language_model::IconOrSvg::Icon(ui::IconName::ZedAssistant)),
|
||||
}]
|
||||
)])
|
||||
);
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
|
||||
use acp_thread::AgentConnection;
|
||||
use anyhow::{Context as _, Result};
|
||||
use gpui::{App, SharedString, Task};
|
||||
use language_models::provider::google::GoogleLanguageModelProvider;
|
||||
use language_models::api_key_for_gemini_cli;
|
||||
use project::agent_server_store::GEMINI_NAME;
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -37,11 +37,7 @@ impl AgentServer for Gemini {
|
||||
cx.spawn(async move |cx| {
|
||||
extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
|
||||
|
||||
if let Some(api_key) = cx
|
||||
.update(GoogleLanguageModelProvider::api_key_for_gemini_cli)?
|
||||
.await
|
||||
.ok()
|
||||
{
|
||||
if let Some(api_key) = cx.update(api_key_for_gemini_cli)?.await.ok() {
|
||||
extra_env.insert("GEMINI_API_KEY".into(), api_key);
|
||||
}
|
||||
let (command, root_dir, login) = store
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use std::{cmp::Reverse, rc::Rc, sync::Arc};
|
||||
|
||||
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
|
||||
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
|
||||
use agent_client_protocol::ModelId;
|
||||
use agent_servers::AgentServer;
|
||||
use agent_settings::AgentSettings;
|
||||
@@ -13,6 +13,7 @@ use gpui::{
|
||||
Action, AsyncWindowContext, BackgroundExecutor, DismissEvent, FocusHandle, Task, WeakEntity,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use language_model::IconOrSvg;
|
||||
use ordered_float::OrderedFloat;
|
||||
use picker::{Picker, PickerDelegate};
|
||||
use settings::Settings;
|
||||
@@ -351,8 +352,8 @@ impl PickerDelegate for AcpModelPickerDelegate {
|
||||
.child(
|
||||
ModelSelectorListItem::new(ix, model_info.name.clone())
|
||||
.map(|this| match &model_info.icon {
|
||||
Some(AgentModelIcon::Path(path)) => this.icon_path(path.clone()),
|
||||
Some(AgentModelIcon::Named(icon)) => this.icon(*icon),
|
||||
Some(IconOrSvg::Svg(path)) => this.icon_path(path.clone()),
|
||||
Some(IconOrSvg::Icon(icon)) => this.icon(*icon),
|
||||
None => this,
|
||||
})
|
||||
.is_selected(is_selected)
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
use std::rc::Rc;
|
||||
use std::sync::Arc;
|
||||
|
||||
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelSelector};
|
||||
use acp_thread::{AgentModelInfo, AgentModelSelector};
|
||||
use agent_servers::AgentServer;
|
||||
use agent_settings::AgentSettings;
|
||||
use fs::Fs;
|
||||
use gpui::{Entity, FocusHandle};
|
||||
use language_model::IconOrSvg;
|
||||
use picker::popover_menu::PickerPopoverMenu;
|
||||
use settings::Settings as _;
|
||||
use ui::{ButtonLike, KeyBinding, PopoverMenuHandle, TintColor, Tooltip, prelude::*};
|
||||
@@ -127,8 +128,8 @@ impl Render for AcpModelSelectorPopover {
|
||||
.when_some(model_icon, |this, icon| {
|
||||
this.child(
|
||||
match icon {
|
||||
AgentModelIcon::Path(path) => Icon::from_external_svg(path),
|
||||
AgentModelIcon::Named(icon_name) => Icon::new(icon_name),
|
||||
IconOrSvg::Svg(path) => Icon::from_external_svg(path),
|
||||
IconOrSvg::Icon(icon_name) => Icon::new(icon_name),
|
||||
}
|
||||
.color(color)
|
||||
.size(IconSize::XSmall),
|
||||
|
||||
@@ -368,26 +368,49 @@ fn update_active_language_model_from_settings(cx: &mut App) {
|
||||
}
|
||||
}
|
||||
|
||||
let default = settings.default_model.as_ref().map(to_selected_model);
|
||||
// Filter out models from providers that are not authenticated
|
||||
fn is_provider_authenticated(
|
||||
selection: &LanguageModelSelection,
|
||||
registry: &LanguageModelRegistry,
|
||||
cx: &App,
|
||||
) -> bool {
|
||||
let provider_id = LanguageModelProviderId::from(selection.provider.0.clone());
|
||||
registry
|
||||
.provider(&provider_id)
|
||||
.map_or(false, |provider| provider.is_authenticated(cx))
|
||||
}
|
||||
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
let registry_ref = registry.read(cx);
|
||||
|
||||
let default = settings
|
||||
.default_model
|
||||
.as_ref()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model);
|
||||
let inline_assistant = settings
|
||||
.inline_assistant_model
|
||||
.as_ref()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model);
|
||||
let commit_message = settings
|
||||
.commit_message_model
|
||||
.as_ref()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model);
|
||||
let thread_summary = settings
|
||||
.thread_summary_model
|
||||
.as_ref()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model);
|
||||
let inline_alternatives = settings
|
||||
.inline_alternatives
|
||||
.iter()
|
||||
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
|
||||
.map(to_selected_model)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.select_default_model(default.as_ref(), cx);
|
||||
registry.select_inline_assistant_model(inline_assistant.as_ref(), cx);
|
||||
registry.select_commit_message_model(commit_message.as_ref(), cx);
|
||||
|
||||
@@ -2,10 +2,9 @@ use std::{cmp::Reverse, sync::Arc};
|
||||
|
||||
use agent_settings::AgentSettings;
|
||||
use collections::{HashMap, HashSet, IndexMap};
|
||||
use futures::{StreamExt, channel::mpsc};
|
||||
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
|
||||
use gpui::{
|
||||
Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
|
||||
};
|
||||
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
|
||||
use language_model::{
|
||||
AuthenticateError, ConfiguredModel, IconOrSvg, LanguageModel, LanguageModelId,
|
||||
LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
|
||||
@@ -76,7 +75,7 @@ fn all_models(cx: &App) -> GroupedModels {
|
||||
})
|
||||
.collect();
|
||||
|
||||
let all = providers
|
||||
let all: Vec<ModelInfo> = providers
|
||||
.iter()
|
||||
.flat_map(|provider| {
|
||||
provider
|
||||
@@ -124,7 +123,7 @@ pub struct LanguageModelPickerDelegate {
|
||||
filtered_entries: Vec<LanguageModelPickerEntry>,
|
||||
selected_index: usize,
|
||||
_authenticate_all_providers_task: Task<()>,
|
||||
_subscriptions: Vec<Subscription>,
|
||||
_refresh_models_task: Task<()>,
|
||||
popover_styles: bool,
|
||||
focus_handle: FocusHandle,
|
||||
}
|
||||
@@ -151,24 +150,42 @@ impl LanguageModelPickerDelegate {
|
||||
get_active_model: Arc::new(get_active_model),
|
||||
on_toggle_favorite: Arc::new(on_toggle_favorite),
|
||||
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
|
||||
_subscriptions: vec![cx.subscribe_in(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
window,
|
||||
|picker, _, event, window, cx| {
|
||||
match event {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_) => {
|
||||
let query = picker.query(cx);
|
||||
picker.delegate.all_models = Arc::new(all_models(cx));
|
||||
// Update matches will automatically drop the previous task
|
||||
// if we get a provider event again
|
||||
picker.update_matches(query, window, cx)
|
||||
}
|
||||
_ => {}
|
||||
_refresh_models_task: {
|
||||
// Create a channel to signal when models need refreshing
|
||||
let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
|
||||
|
||||
// Subscribe to registry events and send refresh signals through the channel
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
cx.subscribe(®istry, move |_picker, _, event, _cx| match event {
|
||||
language_model::Event::ProviderStateChanged(_)
|
||||
| language_model::Event::AddedProvider(_)
|
||||
| language_model::Event::RemovedProvider(_)
|
||||
| language_model::Event::ProvidersChanged => {
|
||||
refresh_tx.unbounded_send(()).ok();
|
||||
}
|
||||
},
|
||||
)],
|
||||
language_model::Event::DefaultModelChanged
|
||||
| language_model::Event::InlineAssistantModelChanged
|
||||
| language_model::Event::CommitMessageModelChanged
|
||||
| language_model::Event::ThreadSummaryModelChanged => {}
|
||||
})
|
||||
.detach();
|
||||
|
||||
// Spawn a task that listens for refresh signals and updates the picker
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
while let Some(()) = refresh_rx.next().await {
|
||||
if this
|
||||
.update_in(cx, |picker, window, cx| {
|
||||
picker.delegate.all_models = Arc::new(all_models(cx));
|
||||
picker.refresh(window, cx);
|
||||
})
|
||||
.is_err()
|
||||
{
|
||||
// Picker was dropped, exit the loop
|
||||
break;
|
||||
}
|
||||
}
|
||||
})
|
||||
},
|
||||
popover_styles,
|
||||
focus_handle,
|
||||
}
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
use gpui::{Action, FocusHandle, prelude::*};
|
||||
use ui::{ElevationIndex, KeyBinding, ListItem, ListItemSpacing, Tooltip, prelude::*};
|
||||
|
||||
enum ModelIcon {
|
||||
Name(IconName),
|
||||
Path(SharedString),
|
||||
}
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct ModelSelectorHeader {
|
||||
title: SharedString,
|
||||
@@ -40,6 +35,11 @@ impl RenderOnce for ModelSelectorHeader {
|
||||
}
|
||||
}
|
||||
|
||||
enum ModelIcon {
|
||||
Name(IconName),
|
||||
Path(SharedString),
|
||||
}
|
||||
|
||||
#[derive(IntoElement)]
|
||||
pub struct ModelSelectorListItem {
|
||||
index: usize,
|
||||
|
||||
@@ -8,7 +8,7 @@ use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::B
|
||||
use http_client::http::{self, HeaderMap, HeaderValue};
|
||||
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
|
||||
use serde::{Deserialize, Serialize};
|
||||
pub use settings::{AnthropicAvailableModel as AvailableModel, ModelMode};
|
||||
pub use settings::ModelMode;
|
||||
use strum::{EnumIter, EnumString};
|
||||
use thiserror::Error;
|
||||
|
||||
|
||||
@@ -389,6 +389,49 @@ pub trait ExtensionContextServerProxy: Send + Sync + 'static {
|
||||
fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App);
|
||||
}
|
||||
|
||||
/// A function that registers a language model provider with the registry.
|
||||
/// This allows extension_host to create the provider (which requires WasmExtension)
|
||||
/// and pass a registration closure to the language_models crate.
|
||||
pub type LanguageModelProviderRegistration = Box<dyn FnOnce(&mut App) + Send + Sync + 'static>;
|
||||
|
||||
pub trait ExtensionLanguageModelProviderProxy: Send + Sync + 'static {
|
||||
/// Register an LLM provider from an extension.
|
||||
/// The `register_fn` closure will be called with the App context and should
|
||||
/// register the provider with the LanguageModelRegistry.
|
||||
fn register_language_model_provider(
|
||||
&self,
|
||||
provider_id: Arc<str>,
|
||||
register_fn: LanguageModelProviderRegistration,
|
||||
cx: &mut App,
|
||||
);
|
||||
|
||||
/// Unregister an LLM provider when an extension is unloaded.
|
||||
fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App);
|
||||
}
|
||||
|
||||
impl ExtensionLanguageModelProviderProxy for ExtensionHostProxy {
|
||||
fn register_language_model_provider(
|
||||
&self,
|
||||
provider_id: Arc<str>,
|
||||
register_fn: LanguageModelProviderRegistration,
|
||||
cx: &mut App,
|
||||
) {
|
||||
let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
|
||||
return;
|
||||
};
|
||||
|
||||
proxy.register_language_model_provider(provider_id, register_fn, cx)
|
||||
}
|
||||
|
||||
fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
|
||||
let Some(proxy) = self.language_model_provider_proxy.read().clone() else {
|
||||
return;
|
||||
};
|
||||
|
||||
proxy.unregister_language_model_provider(provider_id, cx)
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtensionContextServerProxy for ExtensionHostProxy {
|
||||
fn register_context_server(
|
||||
&self,
|
||||
|
||||
@@ -298,6 +298,58 @@ pub struct LanguageModelProviderManifestEntry {
|
||||
/// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg").
|
||||
#[serde(default)]
|
||||
pub icon: Option<String>,
|
||||
/// Hardcoded models to always show (as opposed to a model list loaded over the network).
|
||||
#[serde(default)]
|
||||
pub models: Vec<LanguageModelManifestEntry>,
|
||||
/// Authentication configuration.
|
||||
#[serde(default)]
|
||||
pub auth: Option<LanguageModelAuthConfig>,
|
||||
}
|
||||
|
||||
/// Manifest entry for a language model.
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct LanguageModelManifestEntry {
|
||||
/// Unique identifier for the model.
|
||||
pub id: String,
|
||||
/// Display name for the model.
|
||||
pub name: String,
|
||||
/// Maximum input token count.
|
||||
pub max_token_count: u64,
|
||||
/// Maximum output tokens (optional).
|
||||
pub max_output_tokens: Option<u64>,
|
||||
/// Whether the model supports image inputs.
|
||||
pub supports_images: bool,
|
||||
/// Whether the model supports tool/function calling.
|
||||
pub supports_tools: bool,
|
||||
/// Whether the model supports extended thinking/reasoning.
|
||||
pub supports_thinking: bool,
|
||||
}
|
||||
|
||||
/// Authentication configuration for a language model provider.
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct LanguageModelAuthConfig {
|
||||
/// Human-readable name for the credential shown in the UI input field (e.g. "API Key", "Access Token").
|
||||
pub credential_label: Option<String>,
|
||||
/// Environment variable names for the API key (if env var auth supported).
|
||||
/// Multiple env vars can be specified; they will be checked in order.
|
||||
#[serde(default)]
|
||||
pub env_vars: Option<Vec<String>>,
|
||||
/// OAuth configuration for web-based authentication flows.
|
||||
#[serde(default)]
|
||||
pub oauth: Option<OAuthConfig>,
|
||||
}
|
||||
|
||||
/// OAuth configuration for web-based authentication.
|
||||
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
|
||||
pub struct OAuthConfig {
|
||||
/// The text to display on the sign-in button (e.g. "Sign in with GitHub").
|
||||
pub sign_in_button_label: Option<String>,
|
||||
/// The Zed icon path to display on the sign-in button (e.g. "github").
|
||||
#[serde(default)]
|
||||
pub sign_in_button_icon: Option<String>,
|
||||
/// The description text shown next to the sign-in button in edit prediction settings.
|
||||
#[serde(default)]
|
||||
pub sign_in_description: Option<String>,
|
||||
}
|
||||
|
||||
impl ExtensionManifest {
|
||||
|
||||
@@ -29,6 +29,26 @@ pub use wit::{
|
||||
GithubRelease, GithubReleaseAsset, GithubReleaseOptions, github_release_by_tag_name,
|
||||
latest_github_release,
|
||||
},
|
||||
zed::extension::llm_provider::{
|
||||
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
|
||||
CompletionRequest as LlmCompletionRequest, CustomModelConfig as LlmCustomModelConfig,
|
||||
DeviceFlowPromptInfo as LlmDeviceFlowPromptInfo, ImageData as LlmImageData,
|
||||
MessageContent as LlmMessageContent, MessageRole as LlmMessageRole,
|
||||
ModelCapabilities as LlmModelCapabilities, ModelInfo as LlmModelInfo,
|
||||
OauthWebAuthConfig as LlmOauthWebAuthConfig, OauthWebAuthResult as LlmOauthWebAuthResult,
|
||||
ProviderInfo as LlmProviderInfo, ProviderSettings as LlmProviderSettings,
|
||||
RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
|
||||
ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
|
||||
ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
|
||||
ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
|
||||
ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
|
||||
ToolUseJsonParseError as LlmToolUseJsonParseError,
|
||||
delete_credential as llm_delete_credential, get_credential as llm_get_credential,
|
||||
get_env_var as llm_get_env_var, get_provider_settings as llm_get_provider_settings,
|
||||
oauth_open_browser as llm_oauth_open_browser,
|
||||
oauth_send_http_request as llm_oauth_send_http_request,
|
||||
oauth_start_web_auth as llm_oauth_start_web_auth, store_credential as llm_store_credential,
|
||||
},
|
||||
zed::extension::nodejs::{
|
||||
node_binary_path, npm_install_package, npm_package_installed_version,
|
||||
npm_package_latest_version,
|
||||
@@ -259,6 +279,93 @@ pub trait Extension: Send + Sync {
|
||||
) -> Result<DebugRequest, String> {
|
||||
Err("`run_dap_locator` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Returns information about language model providers offered by this extension.
|
||||
fn llm_providers(&self) -> Vec<LlmProviderInfo> {
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
/// Returns the models available for a provider.
|
||||
fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
/// Returns markdown content to display in the provider's settings UI.
|
||||
/// This can include setup instructions, links to documentation, etc.
|
||||
fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if the provider is authenticated.
|
||||
fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
/// Start an OAuth device flow sign-in.
|
||||
/// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
|
||||
/// Returns information needed to display the device flow prompt modal to the user.
|
||||
fn llm_provider_start_device_flow_sign_in(
|
||||
&mut self,
|
||||
_provider_id: &str,
|
||||
) -> Result<LlmDeviceFlowPromptInfo, String> {
|
||||
Err("`llm_provider_start_device_flow_sign_in` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Poll for device flow sign-in completion.
|
||||
/// This is called after llm_provider_start_device_flow_sign_in returns the user code.
|
||||
/// The extension should poll the OAuth provider until the user authorizes or the flow times out.
|
||||
fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
|
||||
Err("`llm_provider_poll_device_flow_sign_in` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Reset credentials for the provider.
|
||||
fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
|
||||
Err("`llm_provider_reset_credentials` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Count tokens for a request.
|
||||
fn llm_count_tokens(
|
||||
&self,
|
||||
_provider_id: &str,
|
||||
_model_id: &str,
|
||||
_request: &LlmCompletionRequest,
|
||||
) -> Result<u64, String> {
|
||||
Err("`llm_count_tokens` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Start streaming a completion from the model.
|
||||
/// Returns a stream ID that can be used with `llm_stream_completion_next` and `llm_stream_completion_close`.
|
||||
fn llm_stream_completion_start(
|
||||
&mut self,
|
||||
_provider_id: &str,
|
||||
_model_id: &str,
|
||||
_request: &LlmCompletionRequest,
|
||||
) -> Result<String, String> {
|
||||
Err("`llm_stream_completion_start` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Get the next event from a completion stream.
|
||||
/// Returns `Ok(None)` when the stream is complete.
|
||||
fn llm_stream_completion_next(
|
||||
&mut self,
|
||||
_stream_id: &str,
|
||||
) -> Result<Option<LlmCompletionEvent>, String> {
|
||||
Err("`llm_stream_completion_next` not implemented".to_string())
|
||||
}
|
||||
|
||||
/// Close a completion stream and release its resources.
|
||||
fn llm_stream_completion_close(&mut self, _stream_id: &str) {
|
||||
// Default implementation does nothing
|
||||
}
|
||||
|
||||
/// Get cache configuration for a model (if prompt caching is supported).
|
||||
fn llm_cache_configuration(
|
||||
&self,
|
||||
_provider_id: &str,
|
||||
_model_id: &str,
|
||||
) -> Option<LlmCacheConfiguration> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Registers the provided type as a Zed extension.
|
||||
@@ -517,6 +624,67 @@ impl wit::Guest for Component {
|
||||
) -> Result<DebugRequest, String> {
|
||||
extension().run_dap_locator(locator_name, build_task)
|
||||
}
|
||||
|
||||
fn llm_providers() -> Vec<LlmProviderInfo> {
|
||||
extension().llm_providers()
|
||||
}
|
||||
|
||||
fn llm_provider_models(provider_id: String) -> Result<Vec<LlmModelInfo>, String> {
|
||||
extension().llm_provider_models(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_settings_markdown(provider_id: String) -> Option<String> {
|
||||
extension().llm_provider_settings_markdown(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_is_authenticated(provider_id: String) -> bool {
|
||||
extension().llm_provider_is_authenticated(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_start_device_flow_sign_in(
|
||||
provider_id: String,
|
||||
) -> Result<LlmDeviceFlowPromptInfo, String> {
|
||||
extension().llm_provider_start_device_flow_sign_in(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_poll_device_flow_sign_in(provider_id: String) -> Result<(), String> {
|
||||
extension().llm_provider_poll_device_flow_sign_in(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_provider_reset_credentials(provider_id: String) -> Result<(), String> {
|
||||
extension().llm_provider_reset_credentials(&provider_id)
|
||||
}
|
||||
|
||||
fn llm_count_tokens(
|
||||
provider_id: String,
|
||||
model_id: String,
|
||||
request: LlmCompletionRequest,
|
||||
) -> Result<u64, String> {
|
||||
extension().llm_count_tokens(&provider_id, &model_id, &request)
|
||||
}
|
||||
|
||||
fn llm_stream_completion_start(
|
||||
provider_id: String,
|
||||
model_id: String,
|
||||
request: LlmCompletionRequest,
|
||||
) -> Result<String, String> {
|
||||
extension().llm_stream_completion_start(&provider_id, &model_id, &request)
|
||||
}
|
||||
|
||||
fn llm_stream_completion_next(stream_id: String) -> Result<Option<LlmCompletionEvent>, String> {
|
||||
extension().llm_stream_completion_next(&stream_id)
|
||||
}
|
||||
|
||||
fn llm_stream_completion_close(stream_id: String) {
|
||||
extension().llm_stream_completion_close(&stream_id)
|
||||
}
|
||||
|
||||
fn llm_cache_configuration(
|
||||
provider_id: String,
|
||||
model_id: String,
|
||||
) -> Option<LlmCacheConfiguration> {
|
||||
extension().llm_cache_configuration(&provider_id, &model_id)
|
||||
}
|
||||
}
|
||||
|
||||
/// The ID of a language server.
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
//! An HTTP client.
|
||||
|
||||
pub use crate::wit::zed::extension::http_client::{
|
||||
HttpMethod, HttpRequest, HttpResponse, HttpResponseStream, RedirectPolicy, fetch, fetch_stream,
|
||||
HttpMethod, HttpRequest, HttpResponse, HttpResponseStream, HttpResponseWithStatus,
|
||||
RedirectPolicy, fetch, fetch_fallible, fetch_stream,
|
||||
};
|
||||
|
||||
impl HttpRequest {
|
||||
@@ -15,6 +16,11 @@ impl HttpRequest {
|
||||
fetch(self)
|
||||
}
|
||||
|
||||
/// Like [`fetch`], except it doesn't treat any status codes as errors.
|
||||
pub fn fetch_fallible(&self) -> Result<HttpResponseWithStatus, String> {
|
||||
fetch_fallible(self)
|
||||
}
|
||||
|
||||
/// Executes the [`HttpRequest`] with [`fetch_stream`].
|
||||
pub fn fetch_stream(&self) -> Result<HttpResponseStream, String> {
|
||||
fetch_stream(self)
|
||||
|
||||
@@ -8,6 +8,7 @@ world extension {
|
||||
import platform;
|
||||
import process;
|
||||
import nodejs;
|
||||
import llm-provider;
|
||||
|
||||
use common.{env-vars, range};
|
||||
use context-server.{context-server-configuration};
|
||||
@@ -15,6 +16,11 @@ world extension {
|
||||
use lsp.{completion, symbol};
|
||||
use process.{command};
|
||||
use slash-command.{slash-command, slash-command-argument-completion, slash-command-output};
|
||||
use llm-provider.{
|
||||
provider-info, model-info, completion-request,
|
||||
cache-configuration, completion-event, token-usage,
|
||||
device-flow-prompt-info
|
||||
};
|
||||
|
||||
/// Initializes the extension.
|
||||
export init-extension: func();
|
||||
@@ -164,4 +170,73 @@ world extension {
|
||||
export dap-config-to-scenario: func(config: debug-config) -> result<debug-scenario, string>;
|
||||
export dap-locator-create-scenario: func(locator-name: string, build-config-template: build-task-template, resolved-label: string, debug-adapter-name: string) -> option<debug-scenario>;
|
||||
export run-dap-locator: func(locator-name: string, config: resolved-task) -> result<debug-request, string>;
|
||||
|
||||
/// Returns information about language model providers offered by this extension.
|
||||
export llm-providers: func() -> list<provider-info>;
|
||||
|
||||
/// Returns the models available for a provider.
|
||||
export llm-provider-models: func(provider-id: string) -> result<list<model-info>, string>;
|
||||
|
||||
/// Returns markdown content to display in the provider's settings UI.
|
||||
/// This can include setup instructions, links to documentation, etc.
|
||||
export llm-provider-settings-markdown: func(provider-id: string) -> option<string>;
|
||||
|
||||
/// Check if the provider is authenticated.
|
||||
export llm-provider-is-authenticated: func(provider-id: string) -> bool;
|
||||
|
||||
/// Start an OAuth device flow sign-in.
|
||||
/// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
|
||||
///
|
||||
/// The device flow works as follows:
|
||||
/// 1. Extension requests a device code from the OAuth provider
|
||||
/// 2. Extension returns prompt info including user code and verification URL
|
||||
/// 3. Host displays a modal with the prompt info
|
||||
/// 4. Host calls llm-provider-poll-device-flow-sign-in
|
||||
/// 5. Extension polls for the access token while user authorizes in browser
|
||||
/// 6. Once authorized, extension stores the credential and returns success
|
||||
///
|
||||
/// Returns information needed to display the device flow prompt modal.
|
||||
export llm-provider-start-device-flow-sign-in: func(provider-id: string) -> result<device-flow-prompt-info, string>;
|
||||
|
||||
/// Poll for device flow sign-in completion.
|
||||
/// This is called after llm-provider-start-device-flow-sign-in returns the user code.
|
||||
/// The extension should poll the OAuth provider until the user authorizes or the flow times out.
|
||||
/// Returns Ok(()) on successful authentication, or an error message on failure.
|
||||
export llm-provider-poll-device-flow-sign-in: func(provider-id: string) -> result<_, string>;
|
||||
|
||||
/// Reset credentials for the provider.
|
||||
export llm-provider-reset-credentials: func(provider-id: string) -> result<_, string>;
|
||||
|
||||
/// Count tokens for a request.
|
||||
export llm-count-tokens: func(
|
||||
provider-id: string,
|
||||
model-id: string,
|
||||
request: completion-request
|
||||
) -> result<u64, string>;
|
||||
|
||||
/// Start streaming a completion from the model.
|
||||
/// Returns a stream ID that can be used with llm-stream-next and llm-stream-close.
|
||||
export llm-stream-completion-start: func(
|
||||
provider-id: string,
|
||||
model-id: string,
|
||||
request: completion-request
|
||||
) -> result<string, string>;
|
||||
|
||||
/// Get the next event from a completion stream.
|
||||
/// Returns None when the stream is complete.
|
||||
export llm-stream-completion-next: func(
|
||||
stream-id: string
|
||||
) -> result<option<completion-event>, string>;
|
||||
|
||||
/// Close a completion stream and release its resources.
|
||||
export llm-stream-completion-close: func(
|
||||
stream-id: string
|
||||
);
|
||||
|
||||
/// Get cache configuration for a model (if prompt caching is supported).
|
||||
export llm-cache-configuration: func(
|
||||
provider-id: string,
|
||||
model-id: string
|
||||
) -> option<cache-configuration>;
|
||||
|
||||
}
|
||||
|
||||
@@ -51,9 +51,26 @@ interface http-client {
|
||||
body: list<u8>,
|
||||
}
|
||||
|
||||
/// An HTTP response that includes the status code.
|
||||
///
|
||||
/// Used by `fetch-fallible` which returns responses for all status codes
|
||||
/// rather than treating some status codes as errors.
|
||||
record http-response-with-status {
|
||||
/// The HTTP status code.
|
||||
status: u16,
|
||||
/// The response headers.
|
||||
headers: list<tuple<string, string>>,
|
||||
/// The response body.
|
||||
body: list<u8>,
|
||||
}
|
||||
|
||||
/// Performs an HTTP request and returns the response.
|
||||
/// Returns an error if the response status is 4xx or 5xx.
|
||||
fetch: func(req: http-request) -> result<http-response, string>;
|
||||
|
||||
/// Performs an HTTP request and returns the response regardless of its status code.
|
||||
fetch-fallible: func(req: http-request) -> result<http-response-with-status, string>;
|
||||
|
||||
/// An HTTP response stream.
|
||||
resource http-response-stream {
|
||||
/// Retrieves the next chunk of data from the response stream.
|
||||
|
||||
362
crates/extension_api/wit/since_v0.8.0/llm-provider.wit
Normal file
362
crates/extension_api/wit/since_v0.8.0/llm-provider.wit
Normal file
@@ -0,0 +1,362 @@
|
||||
interface llm-provider {
|
||||
use http-client.{http-request, http-response-with-status};
|
||||
|
||||
/// Information about a language model provider.
|
||||
record provider-info {
|
||||
/// Unique identifier for the provider (e.g. "my-extension.my-provider").
|
||||
id: string,
|
||||
/// Display name for the provider.
|
||||
name: string,
|
||||
/// Path to an SVG icon file relative to the extension root (e.g. "icons/provider.svg").
|
||||
icon: option<string>,
|
||||
}
|
||||
|
||||
/// Capabilities of a language model.
|
||||
record model-capabilities {
|
||||
/// Whether the model supports image inputs.
|
||||
supports-images: bool,
|
||||
/// Whether the model supports tool/function calling.
|
||||
supports-tools: bool,
|
||||
/// Whether the model supports the "auto" tool choice.
|
||||
supports-tool-choice-auto: bool,
|
||||
/// Whether the model supports the "any" tool choice.
|
||||
supports-tool-choice-any: bool,
|
||||
/// Whether the model supports the "none" tool choice.
|
||||
supports-tool-choice-none: bool,
|
||||
/// Whether the model supports extended thinking/reasoning.
|
||||
supports-thinking: bool,
|
||||
/// The format for tool input schemas.
|
||||
tool-input-format: tool-input-format,
|
||||
}
|
||||
|
||||
/// Format for tool input schemas.
|
||||
enum tool-input-format {
|
||||
/// Standard JSON Schema format.
|
||||
json-schema,
|
||||
/// A subset of JSON Schema supported by Google AI.
|
||||
/// See https://ai.google.dev/api/caching#Schema
|
||||
json-schema-subset,
|
||||
/// Simplified schema format for certain providers.
|
||||
simplified,
|
||||
}
|
||||
|
||||
/// Information about a specific model.
|
||||
record model-info {
|
||||
/// Unique identifier for the model.
|
||||
id: string,
|
||||
/// Display name for the model.
|
||||
name: string,
|
||||
/// Maximum input token count.
|
||||
max-token-count: u64,
|
||||
/// Maximum output tokens (optional).
|
||||
max-output-tokens: option<u64>,
|
||||
/// Model capabilities.
|
||||
capabilities: model-capabilities,
|
||||
/// Whether this is the default model for the provider.
|
||||
is-default: bool,
|
||||
/// Whether this is the default fast model.
|
||||
is-default-fast: bool,
|
||||
}
|
||||
|
||||
/// The role of a message participant.
|
||||
enum message-role {
|
||||
/// User message.
|
||||
user,
|
||||
/// Assistant message.
|
||||
assistant,
|
||||
/// System message.
|
||||
system,
|
||||
}
|
||||
|
||||
/// A message in a completion request.
|
||||
record request-message {
|
||||
/// The role of the message sender.
|
||||
role: message-role,
|
||||
/// The content of the message.
|
||||
content: list<message-content>,
|
||||
/// Whether to cache this message for prompt caching.
|
||||
cache: bool,
|
||||
}
|
||||
|
||||
/// Content within a message.
|
||||
variant message-content {
|
||||
/// Plain text content.
|
||||
text(string),
|
||||
/// Image content.
|
||||
image(image-data),
|
||||
/// A tool use request from the assistant.
|
||||
tool-use(tool-use),
|
||||
/// A tool result from the user.
|
||||
tool-result(tool-result),
|
||||
/// Thinking/reasoning content.
|
||||
thinking(thinking-content),
|
||||
/// Redacted/encrypted thinking content.
|
||||
redacted-thinking(string),
|
||||
}
|
||||
|
||||
/// Image data for vision models.
|
||||
record image-data {
|
||||
/// Base64-encoded image data.
|
||||
source: string,
|
||||
/// Image width in pixels (optional).
|
||||
width: option<u32>,
|
||||
/// Image height in pixels (optional).
|
||||
height: option<u32>,
|
||||
}
|
||||
|
||||
/// A tool use request from the model.
|
||||
record tool-use {
|
||||
/// Unique identifier for this tool use.
|
||||
id: string,
|
||||
/// The name of the tool being used.
|
||||
name: string,
|
||||
/// JSON string of the tool input arguments.
|
||||
input: string,
|
||||
/// Whether the input JSON is complete (false while streaming, true when done).
|
||||
is-input-complete: bool,
|
||||
/// Thought signature for providers that support it (e.g., Anthropic).
|
||||
thought-signature: option<string>,
|
||||
}
|
||||
|
||||
/// A tool result to send back to the model.
|
||||
record tool-result {
|
||||
/// The ID of the tool use this is a result for.
|
||||
tool-use-id: string,
|
||||
/// The name of the tool.
|
||||
tool-name: string,
|
||||
/// Whether this result represents an error.
|
||||
is-error: bool,
|
||||
/// The content of the result.
|
||||
content: tool-result-content,
|
||||
}
|
||||
|
||||
/// Content of a tool result.
|
||||
variant tool-result-content {
|
||||
/// Text result.
|
||||
text(string),
|
||||
/// Image result.
|
||||
image(image-data),
|
||||
}
|
||||
|
||||
/// Thinking/reasoning content from models that support extended thinking.
|
||||
record thinking-content {
|
||||
/// The thinking text.
|
||||
text: string,
|
||||
/// Signature for the thinking block (provider-specific).
|
||||
signature: option<string>,
|
||||
}
|
||||
|
||||
/// A tool definition for function calling.
|
||||
record tool-definition {
|
||||
/// The name of the tool.
|
||||
name: string,
|
||||
/// Description of what the tool does.
|
||||
description: string,
|
||||
/// JSON Schema for input parameters.
|
||||
input-schema: string,
|
||||
}
|
||||
|
||||
/// Tool choice preference for the model.
|
||||
enum tool-choice {
|
||||
/// Let the model decide whether to use tools.
|
||||
auto,
|
||||
/// Force the model to use at least one tool.
|
||||
any,
|
||||
/// Prevent the model from using tools.
|
||||
none,
|
||||
}
|
||||
|
||||
/// A completion request to send to the model.
|
||||
record completion-request {
|
||||
/// The messages in the conversation.
|
||||
messages: list<request-message>,
|
||||
/// Available tools for the model to use.
|
||||
tools: list<tool-definition>,
|
||||
/// Tool choice preference.
|
||||
tool-choice: option<tool-choice>,
|
||||
/// Stop sequences to end generation.
|
||||
stop-sequences: list<string>,
|
||||
/// Temperature for sampling (0.0-1.0).
|
||||
temperature: option<f32>,
|
||||
/// Whether thinking/reasoning is allowed.
|
||||
thinking-allowed: bool,
|
||||
/// Maximum tokens to generate.
|
||||
max-tokens: option<u64>,
|
||||
}
|
||||
|
||||
/// Events emitted during completion streaming.
|
||||
variant completion-event {
|
||||
/// Completion has started.
|
||||
started,
|
||||
/// Text content chunk.
|
||||
text(string),
|
||||
/// Thinking/reasoning content chunk.
|
||||
thinking(thinking-content),
|
||||
/// Redacted thinking (encrypted) chunk.
|
||||
redacted-thinking(string),
|
||||
/// Tool use request from the model.
|
||||
tool-use(tool-use),
|
||||
/// JSON parse error when parsing tool input.
|
||||
tool-use-json-parse-error(tool-use-json-parse-error),
|
||||
/// Completion stopped.
|
||||
stop(stop-reason),
|
||||
/// Token usage update.
|
||||
usage(token-usage),
|
||||
/// Reasoning details (provider-specific JSON).
|
||||
reasoning-details(string),
|
||||
}
|
||||
|
||||
/// Error information when tool use JSON parsing fails.
|
||||
record tool-use-json-parse-error {
|
||||
/// The tool use ID.
|
||||
id: string,
|
||||
/// The tool name.
|
||||
tool-name: string,
|
||||
/// The raw input that failed to parse.
|
||||
raw-input: string,
|
||||
/// The parse error message.
|
||||
error: string,
|
||||
}
|
||||
|
||||
/// Reason the completion stopped.
|
||||
enum stop-reason {
|
||||
/// The model finished generating.
|
||||
end-turn,
|
||||
/// Maximum tokens reached.
|
||||
max-tokens,
|
||||
/// The model wants to use a tool.
|
||||
tool-use,
|
||||
/// The model refused to respond.
|
||||
refusal,
|
||||
}
|
||||
|
||||
/// Token usage statistics.
|
||||
record token-usage {
|
||||
/// Number of input tokens used.
|
||||
input-tokens: u64,
|
||||
/// Number of output tokens generated.
|
||||
output-tokens: u64,
|
||||
/// Tokens used for cache creation (if supported).
|
||||
cache-creation-input-tokens: option<u64>,
|
||||
/// Tokens read from cache (if supported).
|
||||
cache-read-input-tokens: option<u64>,
|
||||
}
|
||||
|
||||
/// Cache configuration for prompt caching.
|
||||
record cache-configuration {
|
||||
/// Maximum number of cache anchors.
|
||||
max-cache-anchors: u32,
|
||||
/// Whether caching should be applied to tool definitions.
|
||||
should-cache-tool-definitions: bool,
|
||||
/// Minimum token count for a message to be cached.
|
||||
min-total-token-count: u64,
|
||||
}
|
||||
|
||||
/// Configuration for starting an OAuth web authentication flow.
|
||||
record oauth-web-auth-config {
|
||||
/// The URL to open in the user's browser to start authentication.
|
||||
/// This should include client_id, redirect_uri, scope, state, etc.
|
||||
/// Use `{port}` as a placeholder in the URL - it will be replaced with
|
||||
/// the actual localhost port before opening the browser.
|
||||
/// Example: "https://example.com/oauth?redirect_uri=http://127.0.0.1:{port}/callback"
|
||||
auth-url: string,
|
||||
/// The path to listen on for the OAuth callback (e.g., "/callback").
|
||||
/// A localhost server will be started to receive the redirect.
|
||||
callback-path: string,
|
||||
/// Timeout in seconds to wait for the callback (default: 300 = 5 minutes).
|
||||
timeout-secs: option<u32>,
|
||||
}
|
||||
|
||||
/// Result of an OAuth web authentication flow.
|
||||
record oauth-web-auth-result {
|
||||
/// The full callback URL that was received, including query parameters.
|
||||
/// The extension is responsible for parsing the code, state, etc.
|
||||
callback-url: string,
|
||||
/// The port that was used for the localhost callback server.
|
||||
port: u32,
|
||||
}
|
||||
|
||||
/// Get a stored credential for this provider.
|
||||
get-credential: func(provider-id: string) -> option<string>;
|
||||
|
||||
/// Store a credential for this provider.
|
||||
store-credential: func(provider-id: string, value: string) -> result<_, string>;
|
||||
|
||||
/// Delete a stored credential for this provider.
|
||||
delete-credential: func(provider-id: string) -> result<_, string>;
|
||||
|
||||
/// Read an environment variable.
|
||||
get-env-var: func(name: string) -> option<string>;
|
||||
|
||||
/// Start an OAuth web authentication flow.
|
||||
///
|
||||
/// This will:
|
||||
/// 1. Start a localhost server to receive the OAuth callback
|
||||
/// 2. Open the auth URL in the user's default browser
|
||||
/// 3. Wait for the callback (up to the timeout)
|
||||
/// 4. Return the callback URL with query parameters
|
||||
///
|
||||
/// The extension is responsible for:
|
||||
/// - Constructing the auth URL with client_id, redirect_uri, scope, state, etc.
|
||||
/// - Parsing the callback URL to extract the authorization code
|
||||
/// - Exchanging the code for tokens using fetch-fallible from http-client
|
||||
oauth-start-web-auth: func(config: oauth-web-auth-config) -> result<oauth-web-auth-result, string>;
|
||||
|
||||
/// Make an HTTP request for OAuth token exchange.
|
||||
///
|
||||
/// This is a convenience wrapper around http-client's fetch-fallible for OAuth flows.
|
||||
/// Unlike the standard fetch, this does not treat non-2xx responses as errors,
|
||||
/// allowing proper handling of OAuth error responses.
|
||||
oauth-send-http-request: func(request: http-request) -> result<http-response-with-status, string>;
|
||||
|
||||
/// Open a URL in the user's default browser.
|
||||
///
|
||||
/// Useful for OAuth flows that need to open a browser but handle the
|
||||
/// callback differently (e.g., polling-based flows).
|
||||
oauth-open-browser: func(url: string) -> result<_, string>;
|
||||
|
||||
/// Provider settings from user configuration.
|
||||
/// Extensions can use this to allow custom API URLs, custom models, etc.
|
||||
record provider-settings {
|
||||
/// Custom API URL override (if configured by the user).
|
||||
api-url: option<string>,
|
||||
/// Custom models configured by the user.
|
||||
available-models: list<custom-model-config>,
|
||||
}
|
||||
|
||||
/// Configuration for a custom model defined by the user.
|
||||
record custom-model-config {
|
||||
/// The model's API identifier.
|
||||
name: string,
|
||||
/// Display name for the UI.
|
||||
display-name: option<string>,
|
||||
/// Maximum input token count.
|
||||
max-tokens: u64,
|
||||
/// Maximum output tokens (optional).
|
||||
max-output-tokens: option<u64>,
|
||||
/// Thinking budget for models that support extended thinking (None = auto).
|
||||
thinking-budget: option<u32>,
|
||||
}
|
||||
|
||||
/// Get provider-specific settings configured by the user.
|
||||
/// Returns settings like custom API URLs and custom model configurations.
|
||||
get-provider-settings: func(provider-id: string) -> option<provider-settings>;
|
||||
|
||||
/// Information needed to display the device flow prompt modal to the user.
|
||||
record device-flow-prompt-info {
|
||||
/// The user code to display (e.g., "ABC-123").
|
||||
user-code: string,
|
||||
/// The URL the user needs to visit to authorize (for the "Connect" button).
|
||||
verification-url: string,
|
||||
/// The headline text for the modal (e.g., "Use GitHub Copilot in Zed.").
|
||||
headline: string,
|
||||
/// A description to show below the headline (e.g., "Using Copilot requires an active subscription on GitHub.").
|
||||
description: string,
|
||||
/// Label for the connect button (e.g., "Connect to GitHub").
|
||||
connect-button-label: string,
|
||||
/// Success headline shown when authorization completes.
|
||||
success-headline: string,
|
||||
/// Success message shown when authorization completes.
|
||||
success-message: string,
|
||||
}
|
||||
}
|
||||
@@ -255,6 +255,21 @@ async fn copy_extension_resources(
|
||||
}
|
||||
}
|
||||
|
||||
for (_, provider_entry) in &manifest.language_model_providers {
|
||||
if let Some(icon_path) = &provider_entry.icon {
|
||||
let source_icon = extension_path.join(icon_path);
|
||||
let dest_icon = output_dir.join(icon_path);
|
||||
|
||||
// Create parent directory if needed
|
||||
if let Some(parent) = dest_icon.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
fs::copy(&source_icon, &dest_icon)
|
||||
.with_context(|| format!("failed to copy LLM provider icon '{}'", icon_path))?;
|
||||
}
|
||||
}
|
||||
|
||||
if !manifest.languages.is_empty() {
|
||||
let output_languages_dir = output_dir.join("languages");
|
||||
fs::create_dir_all(&output_languages_dir)?;
|
||||
|
||||
@@ -22,7 +22,9 @@ async-tar.workspace = true
|
||||
async-trait.workspace = true
|
||||
client.workspace = true
|
||||
collections.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
dap.workspace = true
|
||||
dirs.workspace = true
|
||||
extension.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
@@ -30,8 +32,11 @@ gpui.workspace = true
|
||||
gpui_tokio.workspace = true
|
||||
http_client.workspace = true
|
||||
language.workspace = true
|
||||
language_model.workspace = true
|
||||
log.workspace = true
|
||||
markdown.workspace = true
|
||||
lsp.workspace = true
|
||||
menu.workspace = true
|
||||
moka.workspace = true
|
||||
node_runtime.workspace = true
|
||||
paths.workspace = true
|
||||
@@ -43,11 +48,16 @@ serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
serde_json_lenient.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
task.workspace = true
|
||||
telemetry.workspace = true
|
||||
tempfile.workspace = true
|
||||
theme.workspace = true
|
||||
toml.workspace = true
|
||||
ui.workspace = true
|
||||
ui_input.workspace = true
|
||||
url.workspace = true
|
||||
workspace.workspace = true
|
||||
util.workspace = true
|
||||
wasmparser.workspace = true
|
||||
wasmtime-wasi.workspace = true
|
||||
|
||||
124
crates/extension_host/src/anthropic_migration.rs
Normal file
124
crates/extension_host/src/anthropic_migration.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
|
||||
const ANTHROPIC_EXTENSION_ID: &str = "anthropic";
|
||||
const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
|
||||
const ANTHROPIC_DEFAULT_API_URL: &str = "https://api.anthropic.com";
|
||||
|
||||
/// Migrates Anthropic API credentials from the old built-in provider location
|
||||
/// to the new extension-based location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_anthropic_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != ANTHROPIC_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let extension_credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
ANTHROPIC_EXTENSION_ID, ANTHROPIC_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Read from old location
|
||||
let old_credential = credentials_provider
|
||||
.read_credentials(ANTHROPIC_DEFAULT_API_URL, &cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let api_key = match old_credential {
|
||||
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
|
||||
Ok(key) if !key.is_empty() => key,
|
||||
Ok(_) => {
|
||||
log::debug!("Existing Anthropic API key is empty, nothing to migrate");
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to decode Anthropic API key as UTF-8");
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
log::debug!("No existing Anthropic API key found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing Anthropic API key to Anthropic extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated Anthropic API key to extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate Anthropic API key: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-ant-test-key-12345";
|
||||
|
||||
cx.write_credentials(ANTHROPIC_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_anthropic_credentials_if_needed(ANTHROPIC_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let migrated = cx.read_credentials("extension-llm-anthropic:anthropic");
|
||||
assert!(migrated.is_some(), "Credentials should have been migrated");
|
||||
let (username, password) = migrated.unwrap();
|
||||
assert_eq!(username, "Bearer");
|
||||
assert_eq!(String::from_utf8(password).unwrap(), api_key);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_anthropic_credentials_if_needed(ANTHROPIC_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-anthropic:anthropic");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials if none existed"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-ant-test-key";
|
||||
|
||||
cx.write_credentials(ANTHROPIC_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_anthropic_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-anthropic:anthropic");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not migrate for other extensions"
|
||||
);
|
||||
}
|
||||
}
|
||||
216
crates/extension_host/src/copilot_migration.rs
Normal file
216
crates/extension_host/src/copilot_migration.rs
Normal file
@@ -0,0 +1,216 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
use std::path::PathBuf;
|
||||
|
||||
const COPILOT_CHAT_EXTENSION_ID: &str = "copilot-chat";
|
||||
const COPILOT_CHAT_PROVIDER_ID: &str = "copilot-chat";
|
||||
|
||||
/// Migrates Copilot OAuth credentials from the GitHub Copilot config files
|
||||
/// to the new extension-based credential location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_copilot_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != COPILOT_CHAT_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
COPILOT_CHAT_EXTENSION_ID, COPILOT_CHAT_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |_cx| {
|
||||
// Read from copilot config files
|
||||
let oauth_token = match read_copilot_oauth_token().await {
|
||||
Some(token) if !token.is_empty() => token,
|
||||
_ => {
|
||||
log::debug!("No existing Copilot OAuth token found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing Copilot OAuth token to Copilot Chat extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&credential_key, "api_key", oauth_token.as_bytes(), &_cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated Copilot OAuth token to Copilot Chat extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate Copilot OAuth token: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
async fn read_copilot_oauth_token() -> Option<String> {
|
||||
let config_paths = copilot_config_paths();
|
||||
|
||||
for path in config_paths {
|
||||
if let Some(token) = read_oauth_token_from_file(&path).await {
|
||||
return Some(token);
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
fn copilot_config_paths() -> Vec<PathBuf> {
|
||||
let config_dir = if cfg!(target_os = "windows") {
|
||||
dirs::data_local_dir()
|
||||
} else {
|
||||
std::env::var("XDG_CONFIG_HOME")
|
||||
.map(PathBuf::from)
|
||||
.ok()
|
||||
.or_else(|| dirs::home_dir().map(|h| h.join(".config")))
|
||||
};
|
||||
|
||||
let Some(config_dir) = config_dir else {
|
||||
return Vec::new();
|
||||
};
|
||||
|
||||
let copilot_dir = config_dir.join("github-copilot");
|
||||
|
||||
vec![
|
||||
copilot_dir.join("hosts.json"),
|
||||
copilot_dir.join("apps.json"),
|
||||
]
|
||||
}
|
||||
|
||||
async fn read_oauth_token_from_file(path: &PathBuf) -> Option<String> {
|
||||
let contents = match smol::fs::read_to_string(path).await {
|
||||
Ok(contents) => contents,
|
||||
Err(_) => return None,
|
||||
};
|
||||
|
||||
extract_oauth_token(&contents, "github.com")
|
||||
}
|
||||
|
||||
fn extract_oauth_token(contents: &str, domain: &str) -> Option<String> {
|
||||
let value: serde_json::Value = serde_json::from_str(contents).ok()?;
|
||||
let obj = value.as_object()?;
|
||||
|
||||
for (key, value) in obj.iter() {
|
||||
if key.starts_with(domain) {
|
||||
if let Some(token) = value.get("oauth_token").and_then(|v| v.as_str()) {
|
||||
return Some(token.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_from_hosts_json() {
|
||||
let contents = r#"{
|
||||
"github.com": {
|
||||
"oauth_token": "ghu_test_token_12345"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, Some("ghu_test_token_12345".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_with_user_suffix() {
|
||||
let contents = r#"{
|
||||
"github.com:user": {
|
||||
"oauth_token": "ghu_another_token"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, Some("ghu_another_token".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_wrong_domain() {
|
||||
let contents = r#"{
|
||||
"gitlab.com": {
|
||||
"oauth_token": "some_token"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_invalid_json() {
|
||||
let contents = "not valid json";
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_missing_oauth_token_field() {
|
||||
let contents = r#"{
|
||||
"github.com": {
|
||||
"user": "testuser"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_oauth_token_multiple_entries_picks_first_match() {
|
||||
let contents = r#"{
|
||||
"gitlab.com": {
|
||||
"oauth_token": "gitlab_token"
|
||||
},
|
||||
"github.com": {
|
||||
"oauth_token": "github_token"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let token = extract_oauth_token(contents, "github.com");
|
||||
assert_eq!(token, Some("github_token".to_string()));
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_copilot_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-copilot-chat:copilot-chat");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials for other extensions"
|
||||
);
|
||||
}
|
||||
|
||||
// Note: Unlike the other migrations, copilot migration reads from the filesystem
|
||||
// (copilot config files), not from the credentials provider. In tests, these files
|
||||
// don't exist, so no migration occurs.
|
||||
#[gpui::test]
|
||||
async fn test_no_credentials_when_no_copilot_config_exists(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_copilot_credentials_if_needed(COPILOT_CHAT_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-copilot-chat:copilot-chat");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"No credentials should be written when copilot config doesn't exist"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,11 @@
|
||||
mod anthropic_migration;
|
||||
mod capability_granter;
|
||||
mod copilot_migration;
|
||||
pub mod extension_settings;
|
||||
mod google_ai_migration;
|
||||
pub mod headless_host;
|
||||
mod open_router_migration;
|
||||
mod openai_migration;
|
||||
pub mod wasm_host;
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -12,13 +17,14 @@ use async_tar::Archive;
|
||||
use client::ExtensionProvides;
|
||||
use client::{Client, ExtensionMetadata, GetExtensionsResponse, proto, telemetry::Telemetry};
|
||||
use collections::{BTreeMap, BTreeSet, HashSet, btree_map};
|
||||
|
||||
pub use extension::ExtensionManifest;
|
||||
use extension::extension_builder::{CompileExtensionOptions, ExtensionBuilder};
|
||||
use extension::{
|
||||
ExtensionContextServerProxy, ExtensionDebugAdapterProviderProxy, ExtensionEvents,
|
||||
ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageProxy,
|
||||
ExtensionLanguageServerProxy, ExtensionSlashCommandProxy, ExtensionSnippetProxy,
|
||||
ExtensionThemeProxy,
|
||||
ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageModelProviderProxy,
|
||||
ExtensionLanguageProxy, ExtensionLanguageServerProxy, ExtensionSlashCommandProxy,
|
||||
ExtensionSnippetProxy, ExtensionThemeProxy,
|
||||
};
|
||||
use fs::{Fs, RemoveOptions};
|
||||
use futures::future::join_all;
|
||||
@@ -32,8 +38,8 @@ use futures::{
|
||||
select_biased,
|
||||
};
|
||||
use gpui::{
|
||||
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task, WeakEntity,
|
||||
actions,
|
||||
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, SharedString, Task,
|
||||
WeakEntity, actions,
|
||||
};
|
||||
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
|
||||
use language::{
|
||||
@@ -53,15 +59,28 @@ use std::{
|
||||
cmp::Ordering,
|
||||
path::{self, Path, PathBuf},
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
time::Duration,
|
||||
};
|
||||
use url::Url;
|
||||
use util::{ResultExt, paths::RemotePathBuf};
|
||||
use wasm_host::llm_provider::ExtensionLanguageModelProvider;
|
||||
use wasm_host::{
|
||||
WasmExtension, WasmHost,
|
||||
wit::{is_supported_wasm_api_version, wasm_api_version_range},
|
||||
wit::{
|
||||
LlmCacheConfiguration, LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version,
|
||||
wasm_api_version_range,
|
||||
},
|
||||
};
|
||||
|
||||
struct LlmProviderWithModels {
|
||||
provider_info: LlmProviderInfo,
|
||||
models: Vec<LlmModelInfo>,
|
||||
cache_configs: collections::HashMap<String, LlmCacheConfiguration>,
|
||||
is_authenticated: bool,
|
||||
icon_path: Option<SharedString>,
|
||||
auth_config: Option<extension::LanguageModelAuthConfig>,
|
||||
}
|
||||
|
||||
pub use extension::{
|
||||
ExtensionLibraryKind, GrammarManifestEntry, OldExtensionManifest, SchemaVersion,
|
||||
};
|
||||
@@ -70,6 +89,82 @@ pub use extension_settings::ExtensionSettings;
|
||||
pub const RELOAD_DEBOUNCE_DURATION: Duration = Duration::from_millis(200);
|
||||
const FS_WATCH_LATENCY: Duration = Duration::from_millis(100);
|
||||
|
||||
/// Extension IDs that are being migrated from hardcoded LLM providers.
|
||||
/// For backwards compatibility, if the user has the corresponding env var set,
|
||||
/// we automatically enable env var reading for these extensions on first install.
|
||||
pub const LEGACY_LLM_EXTENSION_IDS: &[&str] = &[
|
||||
"anthropic",
|
||||
"copilot-chat",
|
||||
"google-ai",
|
||||
"openrouter",
|
||||
"openai",
|
||||
];
|
||||
|
||||
/// Migrates legacy LLM provider extensions by auto-enabling env var reading
|
||||
/// if the env var is currently present in the environment.
|
||||
///
|
||||
/// This is idempotent: if the env var is already in `allowed_env_vars`,
|
||||
/// we skip. This means if a user explicitly removes it, it will be re-added on
|
||||
/// next launch if the env var is still set - but that's predictable behavior.
|
||||
fn migrate_legacy_llm_provider_env_var(manifest: &ExtensionManifest, cx: &mut App) {
|
||||
// Only apply migration to known legacy LLM extensions
|
||||
if !LEGACY_LLM_EXTENSION_IDS.contains(&manifest.id.as_ref()) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check each provider in the manifest
|
||||
for (provider_id, provider_entry) in &manifest.language_model_providers {
|
||||
let Some(auth_config) = &provider_entry.auth else {
|
||||
continue;
|
||||
};
|
||||
let Some(env_vars) = &auth_config.env_vars else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let full_provider_id = format!("{}:{}", manifest.id, provider_id);
|
||||
|
||||
// For each env var, check if it's set and enable it if so
|
||||
for env_var_name in env_vars {
|
||||
let env_var_is_set = std::env::var(env_var_name)
|
||||
.map(|v| !v.is_empty())
|
||||
.unwrap_or(false);
|
||||
|
||||
if !env_var_is_set {
|
||||
continue;
|
||||
}
|
||||
|
||||
let settings_key: Arc<str> = format!("{}:{}", full_provider_id, env_var_name).into();
|
||||
|
||||
// Check if already enabled in settings
|
||||
let already_enabled = ExtensionSettings::get_global(cx)
|
||||
.allowed_env_var_providers
|
||||
.contains(settings_key.as_ref());
|
||||
|
||||
if already_enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Enable env var reading since the env var is set
|
||||
settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
|
||||
let settings_key = settings_key.clone();
|
||||
move |settings, _| {
|
||||
let allowed = settings
|
||||
.extension
|
||||
.allowed_env_var_providers
|
||||
.get_or_insert_with(Vec::new);
|
||||
|
||||
if !allowed
|
||||
.iter()
|
||||
.any(|id| id.as_ref() == settings_key.as_ref())
|
||||
{
|
||||
allowed.push(settings_key);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The current extension [`SchemaVersion`] supported by Zed.
|
||||
const CURRENT_SCHEMA_VERSION: SchemaVersion = SchemaVersion(1);
|
||||
|
||||
@@ -131,6 +226,8 @@ pub struct ExtensionStore {
|
||||
pub enum ExtensionOperation {
|
||||
Upgrade,
|
||||
Install,
|
||||
/// Auto-install from settings - triggers legacy LLM provider migrations
|
||||
AutoInstall,
|
||||
Remove,
|
||||
}
|
||||
|
||||
@@ -613,8 +710,60 @@ impl ExtensionStore {
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
for extension_id in extensions_to_install {
|
||||
// When enabled, this checks if an extension exists locally in the repo's extensions/
|
||||
// directory and installs it as a dev extension instead of fetching from the registry.
|
||||
// This is useful for testing auto-installed extensions before they've been published.
|
||||
// Set to `true` only during local development/testing of new auto-install extensions.
|
||||
#[cfg(debug_assertions)]
|
||||
const DEBUG_ALLOW_UNPUBLISHED_AUTO_EXTENSIONS: bool = false;
|
||||
|
||||
#[cfg(debug_assertions)]
|
||||
if DEBUG_ALLOW_UNPUBLISHED_AUTO_EXTENSIONS {
|
||||
let local_extension_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||
.parent()
|
||||
.unwrap()
|
||||
.parent()
|
||||
.unwrap()
|
||||
.join("extensions")
|
||||
.join(extension_id.as_ref());
|
||||
|
||||
if local_extension_path.exists() {
|
||||
// Force-remove existing extension directory if it exists and isn't a symlink
|
||||
// This handles the case where the extension was previously installed from the registry
|
||||
if let Some(installed_dir) = this
|
||||
.update(cx, |this, _cx| this.installed_dir.clone())
|
||||
.ok()
|
||||
{
|
||||
let existing_path = installed_dir.join(extension_id.as_ref());
|
||||
if existing_path.exists() {
|
||||
let metadata = std::fs::symlink_metadata(&existing_path);
|
||||
let is_symlink = metadata.map(|m| m.is_symlink()).unwrap_or(false);
|
||||
if !is_symlink {
|
||||
if let Err(e) = std::fs::remove_dir_all(&existing_path) {
|
||||
log::error!(
|
||||
"Failed to remove existing extension directory {:?}: {}",
|
||||
existing_path,
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(task) = this
|
||||
.update(cx, |this, cx| {
|
||||
this.install_dev_extension(local_extension_path, cx)
|
||||
})
|
||||
.ok()
|
||||
{
|
||||
task.await.log_err();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.install_latest_extension(extension_id.clone(), cx);
|
||||
this.auto_install_latest_extension(extension_id.clone(), cx);
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
@@ -769,7 +918,10 @@ impl ExtensionStore {
|
||||
this.update(cx, |this, cx| this.reload(Some(extension_id.clone()), cx))?
|
||||
.await;
|
||||
|
||||
if let ExtensionOperation::Install = operation {
|
||||
if matches!(
|
||||
operation,
|
||||
ExtensionOperation::Install | ExtensionOperation::AutoInstall
|
||||
) {
|
||||
this.update(cx, |this, cx| {
|
||||
cx.emit(Event::ExtensionInstalled(extension_id.clone()));
|
||||
if let Some(events) = ExtensionEvents::try_global(cx)
|
||||
@@ -779,6 +931,27 @@ impl ExtensionStore {
|
||||
this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx)
|
||||
});
|
||||
}
|
||||
|
||||
// Run legacy LLM provider migrations only for auto-installed extensions
|
||||
if matches!(operation, ExtensionOperation::AutoInstall) {
|
||||
if let Some(manifest) = this.extension_manifest_for_id(&extension_id) {
|
||||
migrate_legacy_llm_provider_env_var(&manifest, cx);
|
||||
}
|
||||
copilot_migration::migrate_copilot_credentials_if_needed(&extension_id, cx);
|
||||
anthropic_migration::migrate_anthropic_credentials_if_needed(
|
||||
&extension_id,
|
||||
cx,
|
||||
);
|
||||
google_ai_migration::migrate_google_ai_credentials_if_needed(
|
||||
&extension_id,
|
||||
cx,
|
||||
);
|
||||
openai_migration::migrate_openai_credentials_if_needed(&extension_id, cx);
|
||||
open_router_migration::migrate_open_router_credentials_if_needed(
|
||||
&extension_id,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
})
|
||||
.ok();
|
||||
}
|
||||
@@ -788,8 +961,24 @@ impl ExtensionStore {
|
||||
}
|
||||
|
||||
pub fn install_latest_extension(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
|
||||
log::info!("installing extension {extension_id} latest version");
|
||||
self.install_latest_extension_with_operation(extension_id, ExtensionOperation::Install, cx);
|
||||
}
|
||||
|
||||
/// Auto-install an extension, triggering legacy LLM provider migrations.
|
||||
fn auto_install_latest_extension(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
|
||||
self.install_latest_extension_with_operation(
|
||||
extension_id,
|
||||
ExtensionOperation::AutoInstall,
|
||||
cx,
|
||||
);
|
||||
}
|
||||
|
||||
fn install_latest_extension_with_operation(
|
||||
&mut self,
|
||||
extension_id: Arc<str>,
|
||||
operation: ExtensionOperation,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
let schema_versions = schema_version_range();
|
||||
let wasm_api_versions = wasm_api_version_range(ReleaseChannel::global(cx));
|
||||
|
||||
@@ -812,13 +1001,8 @@ impl ExtensionStore {
|
||||
return;
|
||||
};
|
||||
|
||||
self.install_or_upgrade_extension_at_endpoint(
|
||||
extension_id,
|
||||
url,
|
||||
ExtensionOperation::Install,
|
||||
cx,
|
||||
)
|
||||
.detach_and_log_err(cx);
|
||||
self.install_or_upgrade_extension_at_endpoint(extension_id, url, operation, cx)
|
||||
.detach_and_log_err(cx);
|
||||
}
|
||||
|
||||
pub fn upgrade_extension(
|
||||
@@ -837,7 +1021,6 @@ impl ExtensionStore {
|
||||
operation: ExtensionOperation,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Task<Result<()>> {
|
||||
log::info!("installing extension {extension_id} {version}");
|
||||
let Some(url) = self
|
||||
.http_client
|
||||
.build_zed_api_url(
|
||||
@@ -1013,9 +1196,37 @@ impl ExtensionStore {
|
||||
}
|
||||
}
|
||||
|
||||
fs.create_symlink(output_path, extension_source_path)
|
||||
fs.create_symlink(output_path, extension_source_path.clone())
|
||||
.await?;
|
||||
|
||||
// Re-load manifest and run migrations before reload so settings are updated before providers are registered
|
||||
let manifest_for_migration =
|
||||
ExtensionManifest::load(fs.clone(), &extension_source_path).await?;
|
||||
this.update(cx, |_this, cx| {
|
||||
migrate_legacy_llm_provider_env_var(&manifest_for_migration, cx);
|
||||
// Also run credential migrations for dev extensions
|
||||
copilot_migration::migrate_copilot_credentials_if_needed(
|
||||
manifest_for_migration.id.as_ref(),
|
||||
cx,
|
||||
);
|
||||
anthropic_migration::migrate_anthropic_credentials_if_needed(
|
||||
manifest_for_migration.id.as_ref(),
|
||||
cx,
|
||||
);
|
||||
google_ai_migration::migrate_google_ai_credentials_if_needed(
|
||||
manifest_for_migration.id.as_ref(),
|
||||
cx,
|
||||
);
|
||||
openai_migration::migrate_openai_credentials_if_needed(
|
||||
manifest_for_migration.id.as_ref(),
|
||||
cx,
|
||||
);
|
||||
open_router_migration::migrate_open_router_credentials_if_needed(
|
||||
manifest_for_migration.id.as_ref(),
|
||||
cx,
|
||||
);
|
||||
})?;
|
||||
|
||||
this.update(cx, |this, cx| this.reload(None, cx))?.await;
|
||||
this.update(cx, |this, cx| {
|
||||
cx.emit(Event::ExtensionInstalled(extension_id.clone()));
|
||||
@@ -1134,18 +1345,6 @@ impl ExtensionStore {
|
||||
return Task::ready(());
|
||||
}
|
||||
|
||||
let reload_count = extensions_to_unload
|
||||
.iter()
|
||||
.filter(|id| extensions_to_load.contains(id))
|
||||
.count();
|
||||
|
||||
log::info!(
|
||||
"extensions updated. loading {}, reloading {}, unloading {}",
|
||||
extensions_to_load.len() - reload_count,
|
||||
reload_count,
|
||||
extensions_to_unload.len() - reload_count
|
||||
);
|
||||
|
||||
let extension_ids = extensions_to_load
|
||||
.iter()
|
||||
.filter_map(|id| {
|
||||
@@ -1220,6 +1419,11 @@ impl ExtensionStore {
|
||||
for command_name in extension.manifest.slash_commands.keys() {
|
||||
self.proxy.unregister_slash_command(command_name.clone());
|
||||
}
|
||||
for provider_id in extension.manifest.language_model_providers.keys() {
|
||||
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
|
||||
self.proxy
|
||||
.unregister_language_model_provider(full_provider_id, cx);
|
||||
}
|
||||
}
|
||||
|
||||
self.wasm_extensions
|
||||
@@ -1358,7 +1562,11 @@ impl ExtensionStore {
|
||||
})
|
||||
.await;
|
||||
|
||||
let mut wasm_extensions = Vec::new();
|
||||
let mut wasm_extensions: Vec<(
|
||||
Arc<ExtensionManifest>,
|
||||
WasmExtension,
|
||||
Vec<LlmProviderWithModels>,
|
||||
)> = Vec::new();
|
||||
for extension in extension_entries {
|
||||
if extension.manifest.lib.kind.is_none() {
|
||||
continue;
|
||||
@@ -1376,7 +1584,149 @@ impl ExtensionStore {
|
||||
|
||||
match wasm_extension {
|
||||
Ok(wasm_extension) => {
|
||||
wasm_extensions.push((extension.manifest.clone(), wasm_extension))
|
||||
// Query for LLM providers if the manifest declares any
|
||||
let mut llm_providers_with_models = Vec::new();
|
||||
if !extension.manifest.language_model_providers.is_empty() {
|
||||
let providers_result = wasm_extension
|
||||
.call(|ext, store| {
|
||||
async move { ext.call_llm_providers(store).await }.boxed()
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Ok(Ok(providers)) = providers_result {
|
||||
for provider_info in providers {
|
||||
let models_result = wasm_extension
|
||||
.call({
|
||||
let provider_id = provider_info.id.clone();
|
||||
|ext, store| {
|
||||
async move {
|
||||
ext.call_llm_provider_models(store, &provider_id)
|
||||
.await
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
let models: Vec<LlmModelInfo> = match models_result {
|
||||
Ok(Ok(Ok(models))) => models,
|
||||
Ok(Ok(Err(e))) => {
|
||||
log::error!(
|
||||
"Failed to get models for LLM provider {} in extension {}: {}",
|
||||
provider_info.id,
|
||||
extension.manifest.id,
|
||||
e
|
||||
);
|
||||
Vec::new()
|
||||
}
|
||||
Ok(Err(e)) => {
|
||||
log::error!(
|
||||
"Wasm error calling llm_provider_models for {} in extension {}: {:?}",
|
||||
provider_info.id,
|
||||
extension.manifest.id,
|
||||
e
|
||||
);
|
||||
Vec::new()
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
"Extension call failed for llm_provider_models {} in extension {}: {:?}",
|
||||
provider_info.id,
|
||||
extension.manifest.id,
|
||||
e
|
||||
);
|
||||
Vec::new()
|
||||
}
|
||||
};
|
||||
|
||||
// Query cache configurations for each model
|
||||
let mut cache_configs = collections::HashMap::default();
|
||||
for model in &models {
|
||||
let cache_config_result = wasm_extension
|
||||
.call({
|
||||
let provider_id = provider_info.id.clone();
|
||||
let model_id = model.id.clone();
|
||||
|ext, store| {
|
||||
async move {
|
||||
ext.call_llm_cache_configuration(
|
||||
store,
|
||||
&provider_id,
|
||||
&model_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Ok(Ok(Some(config))) = cache_config_result {
|
||||
cache_configs.insert(model.id.clone(), config);
|
||||
}
|
||||
}
|
||||
|
||||
// Query initial authentication state
|
||||
let is_authenticated = wasm_extension
|
||||
.call({
|
||||
let provider_id = provider_info.id.clone();
|
||||
|ext, store| {
|
||||
async move {
|
||||
ext.call_llm_provider_is_authenticated(
|
||||
store,
|
||||
&provider_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap_or(Ok(false))
|
||||
.unwrap_or(false);
|
||||
|
||||
// Resolve icon path if provided
|
||||
let icon_path = provider_info.icon.as_ref().map(|icon| {
|
||||
let icon_file_path = extension_path.join(icon);
|
||||
// Canonicalize to resolve symlinks (dev extensions are symlinked)
|
||||
let absolute_icon_path = icon_file_path
|
||||
.canonicalize()
|
||||
.unwrap_or(icon_file_path)
|
||||
.to_string_lossy()
|
||||
.to_string();
|
||||
SharedString::from(absolute_icon_path)
|
||||
});
|
||||
|
||||
let provider_id_arc: Arc<str> =
|
||||
provider_info.id.as_str().into();
|
||||
let auth_config = extension
|
||||
.manifest
|
||||
.language_model_providers
|
||||
.get(&provider_id_arc)
|
||||
.and_then(|entry| entry.auth.clone());
|
||||
|
||||
llm_providers_with_models.push(LlmProviderWithModels {
|
||||
provider_info,
|
||||
models,
|
||||
cache_configs,
|
||||
is_authenticated,
|
||||
icon_path,
|
||||
auth_config,
|
||||
});
|
||||
}
|
||||
} else {
|
||||
log::error!(
|
||||
"Failed to get LLM providers from extension {}: {:?}",
|
||||
extension.manifest.id,
|
||||
providers_result
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
wasm_extensions.push((
|
||||
extension.manifest.clone(),
|
||||
wasm_extension,
|
||||
llm_providers_with_models,
|
||||
))
|
||||
}
|
||||
Err(e) => {
|
||||
log::error!(
|
||||
@@ -1395,7 +1745,7 @@ impl ExtensionStore {
|
||||
this.update(cx, |this, cx| {
|
||||
this.reload_complete_senders.clear();
|
||||
|
||||
for (manifest, wasm_extension) in &wasm_extensions {
|
||||
for (manifest, wasm_extension, llm_providers_with_models) in &wasm_extensions {
|
||||
let extension = Arc::new(wasm_extension.clone());
|
||||
|
||||
for (language_server_id, language_server_config) in &manifest.language_servers {
|
||||
@@ -1449,9 +1799,42 @@ impl ExtensionStore {
|
||||
this.proxy
|
||||
.register_debug_locator(extension.clone(), debug_adapter.clone());
|
||||
}
|
||||
|
||||
// Register LLM providers
|
||||
for llm_provider in llm_providers_with_models {
|
||||
let provider_id: Arc<str> =
|
||||
format!("{}:{}", manifest.id, llm_provider.provider_info.id).into();
|
||||
let wasm_ext = extension.as_ref().clone();
|
||||
let pinfo = llm_provider.provider_info.clone();
|
||||
let mods = llm_provider.models.clone();
|
||||
let cache_cfgs = llm_provider.cache_configs.clone();
|
||||
let auth = llm_provider.is_authenticated;
|
||||
let icon = llm_provider.icon_path.clone();
|
||||
let auth_config = llm_provider.auth_config.clone();
|
||||
|
||||
this.proxy.register_language_model_provider(
|
||||
provider_id.clone(),
|
||||
Box::new(move |cx: &mut App| {
|
||||
let provider = Arc::new(ExtensionLanguageModelProvider::new(
|
||||
wasm_ext, pinfo, mods, cache_cfgs, auth, icon, auth_config, cx,
|
||||
));
|
||||
language_model::LanguageModelRegistry::global(cx).update(
|
||||
cx,
|
||||
|registry, cx| {
|
||||
registry.register_provider(provider, cx);
|
||||
},
|
||||
);
|
||||
}),
|
||||
cx,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
this.wasm_extensions.extend(wasm_extensions);
|
||||
let wasm_extensions_without_llm: Vec<_> = wasm_extensions
|
||||
.into_iter()
|
||||
.map(|(manifest, ext, _)| (manifest, ext))
|
||||
.collect();
|
||||
this.wasm_extensions.extend(wasm_extensions_without_llm);
|
||||
this.proxy.set_extensions_loaded();
|
||||
this.proxy.reload_current_theme(cx);
|
||||
this.proxy.reload_current_icon_theme(cx);
|
||||
@@ -1473,7 +1856,6 @@ impl ExtensionStore {
|
||||
let index_path = self.index_path.clone();
|
||||
let proxy = self.proxy.clone();
|
||||
cx.background_spawn(async move {
|
||||
let start_time = Instant::now();
|
||||
let mut index = ExtensionIndex::default();
|
||||
|
||||
fs.create_dir(&work_dir).await.log_err();
|
||||
@@ -1511,7 +1893,6 @@ impl ExtensionStore {
|
||||
.log_err();
|
||||
}
|
||||
|
||||
log::info!("rebuilt extension index in {:?}", start_time.elapsed());
|
||||
index
|
||||
})
|
||||
}
|
||||
@@ -1785,11 +2166,6 @@ impl ExtensionStore {
|
||||
})?,
|
||||
path_style,
|
||||
);
|
||||
log::info!(
|
||||
"Uploading extension {} to {:?}",
|
||||
missing_extension.clone().id,
|
||||
dest_dir
|
||||
);
|
||||
|
||||
client
|
||||
.update(cx, |client, cx| {
|
||||
@@ -1797,11 +2173,6 @@ impl ExtensionStore {
|
||||
})?
|
||||
.await?;
|
||||
|
||||
log::info!(
|
||||
"Finished uploading extension {}",
|
||||
missing_extension.clone().id
|
||||
);
|
||||
|
||||
let result = client
|
||||
.update(cx, |client, _cx| {
|
||||
client.proto_client().request(proto::InstallExtension {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use collections::HashMap;
|
||||
use collections::{HashMap, HashSet};
|
||||
use extension::{
|
||||
DownloadFileCapability, ExtensionCapability, NpmInstallPackageCapability, ProcessExecCapability,
|
||||
};
|
||||
@@ -16,6 +16,10 @@ pub struct ExtensionSettings {
|
||||
pub auto_install_extensions: HashMap<Arc<str>, bool>,
|
||||
pub auto_update_extensions: HashMap<Arc<str>, bool>,
|
||||
pub granted_capabilities: Vec<ExtensionCapability>,
|
||||
/// The extension language model providers that are allowed to read API keys
|
||||
/// from environment variables. Each entry is in the format
|
||||
/// "extension_id:provider_id:ENV_VAR_NAME".
|
||||
pub allowed_env_var_providers: HashSet<Arc<str>>,
|
||||
}
|
||||
|
||||
impl ExtensionSettings {
|
||||
@@ -60,6 +64,13 @@ impl Settings for ExtensionSettings {
|
||||
}
|
||||
})
|
||||
.collect(),
|
||||
allowed_env_var_providers: content
|
||||
.extension
|
||||
.allowed_env_var_providers
|
||||
.clone()
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
124
crates/extension_host/src/google_ai_migration.rs
Normal file
124
crates/extension_host/src/google_ai_migration.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
|
||||
const GOOGLE_AI_EXTENSION_ID: &str = "google-ai";
|
||||
const GOOGLE_AI_PROVIDER_ID: &str = "google-ai";
|
||||
const GOOGLE_AI_DEFAULT_API_URL: &str = "https://generativelanguage.googleapis.com";
|
||||
|
||||
/// Migrates Google AI API credentials from the old built-in provider location
|
||||
/// to the new extension-based location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_google_ai_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != GOOGLE_AI_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let extension_credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
GOOGLE_AI_EXTENSION_ID, GOOGLE_AI_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Read from old location
|
||||
let old_credential = credentials_provider
|
||||
.read_credentials(GOOGLE_AI_DEFAULT_API_URL, &cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let api_key = match old_credential {
|
||||
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
|
||||
Ok(key) if !key.is_empty() => key,
|
||||
Ok(_) => {
|
||||
log::debug!("Existing Google AI API key is empty, nothing to migrate");
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to decode Google AI API key as UTF-8");
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
log::debug!("No existing Google AI API key found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing Google AI API key to Google AI extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated Google AI API key to extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate Google AI API key: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
|
||||
let api_key = "AIzaSy-test-key-12345";
|
||||
|
||||
cx.write_credentials(GOOGLE_AI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_google_ai_credentials_if_needed(GOOGLE_AI_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let migrated = cx.read_credentials("extension-llm-google-ai:google-ai");
|
||||
assert!(migrated.is_some(), "Credentials should have been migrated");
|
||||
let (username, password) = migrated.unwrap();
|
||||
assert_eq!(username, "Bearer");
|
||||
assert_eq!(String::from_utf8(password).unwrap(), api_key);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_google_ai_credentials_if_needed(GOOGLE_AI_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-google-ai:google-ai");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials if none existed"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
let api_key = "AIzaSy-test-key";
|
||||
|
||||
cx.write_credentials(GOOGLE_AI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_google_ai_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-google-ai:google-ai");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not migrate for other extensions"
|
||||
);
|
||||
}
|
||||
}
|
||||
124
crates/extension_host/src/open_router_migration.rs
Normal file
124
crates/extension_host/src/open_router_migration.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
|
||||
const OPEN_ROUTER_EXTENSION_ID: &str = "openrouter";
|
||||
const OPEN_ROUTER_PROVIDER_ID: &str = "openrouter";
|
||||
const OPEN_ROUTER_DEFAULT_API_URL: &str = "https://openrouter.ai/api/v1";
|
||||
|
||||
/// Migrates OpenRouter API credentials from the old built-in provider location
|
||||
/// to the new extension-based location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_open_router_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != OPEN_ROUTER_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let extension_credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
OPEN_ROUTER_EXTENSION_ID, OPEN_ROUTER_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Read from old location
|
||||
let old_credential = credentials_provider
|
||||
.read_credentials(OPEN_ROUTER_DEFAULT_API_URL, &cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let api_key = match old_credential {
|
||||
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
|
||||
Ok(key) if !key.is_empty() => key,
|
||||
Ok(_) => {
|
||||
log::debug!("Existing OpenRouter API key is empty, nothing to migrate");
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to decode OpenRouter API key as UTF-8");
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
log::debug!("No existing OpenRouter API key found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing OpenRouter API key to OpenRouter extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated OpenRouter API key to extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate OpenRouter API key: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-or-test-key-12345";
|
||||
|
||||
cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let migrated = cx.read_credentials("extension-llm-openrouter:openrouter");
|
||||
assert!(migrated.is_some(), "Credentials should have been migrated");
|
||||
let (username, password) = migrated.unwrap();
|
||||
assert_eq!(username, "Bearer");
|
||||
assert_eq!(String::from_utf8(password).unwrap(), api_key);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-openrouter:openrouter");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials if none existed"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-or-test-key";
|
||||
|
||||
cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_open_router_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-openrouter:openrouter");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not migrate for other extensions"
|
||||
);
|
||||
}
|
||||
}
|
||||
124
crates/extension_host/src/openai_migration.rs
Normal file
124
crates/extension_host/src/openai_migration.rs
Normal file
@@ -0,0 +1,124 @@
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::App;
|
||||
|
||||
const OPENAI_EXTENSION_ID: &str = "openai";
|
||||
const OPENAI_PROVIDER_ID: &str = "openai";
|
||||
const OPENAI_DEFAULT_API_URL: &str = "https://api.openai.com/v1";
|
||||
|
||||
/// Migrates OpenAI API credentials from the old built-in provider location
|
||||
/// to the new extension-based location.
|
||||
///
|
||||
/// This should only be called during auto-install of the extension.
|
||||
pub fn migrate_openai_credentials_if_needed(extension_id: &str, cx: &mut App) {
|
||||
if extension_id != OPENAI_EXTENSION_ID {
|
||||
return;
|
||||
}
|
||||
|
||||
let extension_credential_key = format!(
|
||||
"extension-llm-{}:{}",
|
||||
OPENAI_EXTENSION_ID, OPENAI_PROVIDER_ID
|
||||
);
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
// Read from old location
|
||||
let old_credential = credentials_provider
|
||||
.read_credentials(OPENAI_DEFAULT_API_URL, &cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
|
||||
let api_key = match old_credential {
|
||||
Some((_, key_bytes)) => match String::from_utf8(key_bytes) {
|
||||
Ok(key) if !key.is_empty() => key,
|
||||
Ok(_) => {
|
||||
log::debug!("Existing OpenAI API key is empty, nothing to migrate");
|
||||
return;
|
||||
}
|
||||
Err(_) => {
|
||||
log::error!("Failed to decode OpenAI API key as UTF-8");
|
||||
return;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
log::debug!("No existing OpenAI API key found to migrate");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
log::info!("Migrating existing OpenAI API key to OpenAI extension");
|
||||
|
||||
match credentials_provider
|
||||
.write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
|
||||
.await
|
||||
{
|
||||
Ok(()) => {
|
||||
log::info!("Successfully migrated OpenAI API key to extension");
|
||||
}
|
||||
Err(err) => {
|
||||
log::error!("Failed to migrate OpenAI API key: {}", err);
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-test-key-12345";
|
||||
|
||||
cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let migrated = cx.read_credentials("extension-llm-openai:openai");
|
||||
assert!(migrated.is_some(), "Credentials should have been migrated");
|
||||
let (username, password) = migrated.unwrap();
|
||||
assert_eq!(username, "Bearer");
|
||||
assert_eq!(String::from_utf8(password).unwrap(), api_key);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-openai:openai");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not create credentials if none existed"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
|
||||
let api_key = "sk-test-key";
|
||||
|
||||
cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
|
||||
|
||||
cx.update(|cx| {
|
||||
migrate_openai_credentials_if_needed("some-other-extension", cx);
|
||||
});
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
let credentials = cx.read_credentials("extension-llm-openai:openai");
|
||||
assert!(
|
||||
credentials.is_none(),
|
||||
"Should not migrate for other extensions"
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,11 @@
|
||||
pub mod llm_provider;
|
||||
pub mod wit;
|
||||
|
||||
use crate::capability_granter::CapabilityGranter;
|
||||
use crate::{ExtensionManifest, ExtensionSettings};
|
||||
use anyhow::{Context as _, Result, anyhow, bail};
|
||||
use async_trait::async_trait;
|
||||
|
||||
use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest};
|
||||
use extension::{
|
||||
CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary,
|
||||
@@ -64,7 +66,7 @@ pub struct WasmHost {
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct WasmExtension {
|
||||
tx: UnboundedSender<ExtensionCall>,
|
||||
tx: Arc<UnboundedSender<ExtensionCall>>,
|
||||
pub manifest: Arc<ExtensionManifest>,
|
||||
pub work_dir: Arc<Path>,
|
||||
#[allow(unused)]
|
||||
@@ -74,7 +76,10 @@ pub struct WasmExtension {
|
||||
|
||||
impl Drop for WasmExtension {
|
||||
fn drop(&mut self) {
|
||||
self.tx.close_channel();
|
||||
// Only close the channel when this is the last clone holding the sender
|
||||
if Arc::strong_count(&self.tx) == 1 {
|
||||
self.tx.close_channel();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -671,7 +676,7 @@ impl WasmHost {
|
||||
Ok(WasmExtension {
|
||||
manifest,
|
||||
work_dir,
|
||||
tx,
|
||||
tx: Arc::new(tx),
|
||||
zed_api_version,
|
||||
_task: task,
|
||||
})
|
||||
|
||||
1958
crates/extension_host/src/wasm_host/llm_provider.rs
Normal file
1958
crates/extension_host/src/wasm_host/llm_provider.rs
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,12 @@
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Authentication Complete</title>
|
||||
</head>
|
||||
<body style="font-family: system-ui, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0;">
|
||||
<div style="text-align: center;">
|
||||
<h1>Authentication Complete</h1>
|
||||
<p>You can close this window and return to Zed.</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
@@ -16,7 +16,7 @@ use lsp::LanguageServerName;
|
||||
use release_channel::ReleaseChannel;
|
||||
use task::{DebugScenario, SpawnInTerminal, TaskTemplate, ZedDebugConfig};
|
||||
|
||||
use crate::wasm_host::wit::since_v0_6_0::dap::StartDebuggingRequestArgumentsRequest;
|
||||
use crate::wasm_host::wit::since_v0_8_0::dap::StartDebuggingRequestArgumentsRequest;
|
||||
|
||||
use super::{WasmState, wasm_engine};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
@@ -33,6 +33,19 @@ pub use latest::CodeLabelSpanLiteral;
|
||||
pub use latest::{
|
||||
CodeLabel, CodeLabelSpan, Command, DebugAdapterBinary, ExtensionProject, Range, SlashCommand,
|
||||
zed::extension::context_server::ContextServerConfiguration,
|
||||
zed::extension::llm_provider::{
|
||||
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
|
||||
CompletionRequest as LlmCompletionRequest, DeviceFlowPromptInfo as LlmDeviceFlowPromptInfo,
|
||||
ImageData as LlmImageData, MessageContent as LlmMessageContent,
|
||||
MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
|
||||
ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo,
|
||||
RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
|
||||
ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
|
||||
ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
|
||||
ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
|
||||
ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
|
||||
ToolUseJsonParseError as LlmToolUseJsonParseError,
|
||||
},
|
||||
zed::extension::lsp::{
|
||||
Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind,
|
||||
},
|
||||
@@ -1007,6 +1020,20 @@ impl Extension {
|
||||
resource: Resource<Arc<dyn WorktreeDelegate>>,
|
||||
) -> Result<Result<DebugAdapterBinary, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let dap_binary = ext
|
||||
.call_get_dap_binary(
|
||||
store,
|
||||
&adapter_name,
|
||||
&task.try_into()?,
|
||||
user_installed_path.as_ref().and_then(|p| p.to_str()),
|
||||
resource,
|
||||
)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
Ok(Ok(dap_binary))
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let dap_binary = ext
|
||||
.call_get_dap_binary(
|
||||
@@ -1032,6 +1059,16 @@ impl Extension {
|
||||
config: serde_json::Value,
|
||||
) -> Result<Result<StartDebuggingRequestArgumentsRequest, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let config =
|
||||
serde_json::to_string(&config).context("Adapter config is not a valid JSON")?;
|
||||
let result = ext
|
||||
.call_dap_request_kind(store, &adapter_name, &config)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
Ok(Ok(result))
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let config =
|
||||
serde_json::to_string(&config).context("Adapter config is not a valid JSON")?;
|
||||
@@ -1052,6 +1089,15 @@ impl Extension {
|
||||
config: ZedDebugConfig,
|
||||
) -> Result<Result<DebugScenario, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let config = config.into();
|
||||
let result = ext
|
||||
.call_dap_config_to_scenario(store, &config)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
Ok(Ok(result.try_into()?))
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let config = config.into();
|
||||
let dap_binary = ext
|
||||
@@ -1074,6 +1120,20 @@ impl Extension {
|
||||
debug_adapter_name: String,
|
||||
) -> Result<Option<DebugScenario>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let build_config_template = build_config_template.into();
|
||||
let result = ext
|
||||
.call_dap_locator_create_scenario(
|
||||
store,
|
||||
&locator_name,
|
||||
&build_config_template,
|
||||
&resolved_label,
|
||||
&debug_adapter_name,
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(result.map(TryInto::try_into).transpose()?)
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let build_config_template = build_config_template.into();
|
||||
let dap_binary = ext
|
||||
@@ -1099,6 +1159,15 @@ impl Extension {
|
||||
resolved_build_task: SpawnInTerminal,
|
||||
) -> Result<Result<DebugRequest, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
let build_config_template = resolved_build_task.try_into()?;
|
||||
let dap_request = ext
|
||||
.call_run_dap_locator(store, &locator_name, &build_config_template)
|
||||
.await?
|
||||
.map_err(|e| anyhow!("{e:?}"))?;
|
||||
|
||||
Ok(Ok(dap_request.into()))
|
||||
}
|
||||
Extension::V0_6_0(ext) => {
|
||||
let build_config_template = resolved_build_task.try_into()?;
|
||||
let dap_request = ext
|
||||
@@ -1111,6 +1180,174 @@ impl Extension {
|
||||
_ => anyhow::bail!("`dap_locator_create_scenario` not available prior to v0.6.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_providers(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
) -> Result<Vec<latest::llm_provider::ProviderInfo>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => ext.call_llm_providers(store).await,
|
||||
_ => Ok(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_models(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Result<Vec<latest::llm_provider::ModelInfo>, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => ext.call_llm_provider_models(store, provider_id).await,
|
||||
_ => anyhow::bail!("`llm_provider_models` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_settings_markdown(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Option<String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_settings_markdown(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_is_authenticated(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<bool> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_is_authenticated(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => Ok(false),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_start_device_flow_sign_in(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Result<LlmDeviceFlowPromptInfo, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_start_device_flow_sign_in(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
anyhow::bail!(
|
||||
"`llm_provider_start_device_flow_sign_in` not available prior to v0.8.0"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_poll_device_flow_sign_in(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Result<(), String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_poll_device_flow_sign_in(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => {
|
||||
anyhow::bail!(
|
||||
"`llm_provider_poll_device_flow_sign_in` not available prior to v0.8.0"
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_provider_reset_credentials(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
) -> Result<Result<(), String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_provider_reset_credentials(store, provider_id)
|
||||
.await
|
||||
}
|
||||
_ => anyhow::bail!("`llm_provider_reset_credentials` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_count_tokens(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
request: &latest::llm_provider::CompletionRequest,
|
||||
) -> Result<Result<u64, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_count_tokens(store, provider_id, model_id, request)
|
||||
.await
|
||||
}
|
||||
_ => anyhow::bail!("`llm_count_tokens` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_stream_completion_start(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
request: &latest::llm_provider::CompletionRequest,
|
||||
) -> Result<Result<String, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_stream_completion_start(store, provider_id, model_id, request)
|
||||
.await
|
||||
}
|
||||
_ => anyhow::bail!("`llm_stream_completion_start` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_stream_completion_next(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
stream_id: &str,
|
||||
) -> Result<Result<Option<latest::llm_provider::CompletionEvent>, String>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => ext.call_llm_stream_completion_next(store, stream_id).await,
|
||||
_ => anyhow::bail!("`llm_stream_completion_next` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_stream_completion_close(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
stream_id: &str,
|
||||
) -> Result<()> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => ext.call_llm_stream_completion_close(store, stream_id).await,
|
||||
_ => anyhow::bail!("`llm_stream_completion_close` not available prior to v0.8.0"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn call_llm_cache_configuration(
|
||||
&self,
|
||||
store: &mut Store<WasmState>,
|
||||
provider_id: &str,
|
||||
model_id: &str,
|
||||
) -> Result<Option<latest::llm_provider::CacheConfiguration>> {
|
||||
match self {
|
||||
Extension::V0_8_0(ext) => {
|
||||
ext.call_llm_cache_configuration(store, provider_id, model_id)
|
||||
.await
|
||||
}
|
||||
_ => Ok(None),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
trait ToWasmtimeResult<T> {
|
||||
|
||||
@@ -32,8 +32,6 @@ wasmtime::component::bindgen!({
|
||||
},
|
||||
});
|
||||
|
||||
pub use self::zed::extension::*;
|
||||
|
||||
mod settings {
|
||||
#![allow(dead_code)]
|
||||
include!(concat!(env!("OUT_DIR"), "/since_v0.6.0/settings.rs"));
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
use crate::wasm_host::wit::since_v0_6_0::{
|
||||
use crate::wasm_host::wit::since_v0_8_0::{
|
||||
dap::{
|
||||
BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, StartDebuggingRequestArguments,
|
||||
TcpArguments, TcpArgumentsTemplate,
|
||||
},
|
||||
lsp::{CompletionKind, CompletionLabelDetails, InsertTextFormat, SymbolKind},
|
||||
slash_command::SlashCommandOutputSection,
|
||||
};
|
||||
use crate::wasm_host::wit::{CompletionKind, CompletionLabelDetails, InsertTextFormat, SymbolKind};
|
||||
use crate::wasm_host::{WasmState, wit::ToWasmtimeResult};
|
||||
use ::http_client::{AsyncBody, HttpRequestExt};
|
||||
use ::settings::{Settings, WorktreeId};
|
||||
use ::settings::{ModelMode, Settings, SettingsStore, WorktreeId};
|
||||
use anyhow::{Context as _, Result, bail};
|
||||
use async_compression::futures::bufread::GzipDecoder;
|
||||
use async_tar::Archive;
|
||||
use async_trait::async_trait;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use extension::{
|
||||
ExtensionLanguageServerProxy, KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate,
|
||||
};
|
||||
@@ -22,12 +23,14 @@ use gpui::{BackgroundExecutor, SharedString};
|
||||
use language::{BinaryStatus, LanguageName, language_settings::AllLanguageSettings};
|
||||
use project::project_settings::ProjectSettings;
|
||||
use semver::Version;
|
||||
use smol::net::TcpListener;
|
||||
use std::{
|
||||
env,
|
||||
net::Ipv4Addr,
|
||||
path::{Path, PathBuf},
|
||||
str::FromStr,
|
||||
sync::{Arc, OnceLock},
|
||||
time::Duration,
|
||||
};
|
||||
use task::{SpawnInTerminal, ZedDebugConfig};
|
||||
use url::Url;
|
||||
@@ -615,6 +618,19 @@ impl http_client::Host for WasmState {
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn fetch_fallible(
|
||||
&mut self,
|
||||
request: http_client::HttpRequest,
|
||||
) -> wasmtime::Result<Result<http_client::HttpResponseWithStatus, String>> {
|
||||
maybe!(async {
|
||||
let request = convert_request(&request)?;
|
||||
let mut response = self.host.http_client.send(request).await?;
|
||||
convert_response_with_status(&mut response).await
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn fetch_stream(
|
||||
&mut self,
|
||||
request: http_client::HttpRequest,
|
||||
@@ -718,6 +734,26 @@ async fn convert_response(
|
||||
Ok(extension_response)
|
||||
}
|
||||
|
||||
async fn convert_response_with_status(
|
||||
response: &mut ::http_client::Response<AsyncBody>,
|
||||
) -> anyhow::Result<http_client::HttpResponseWithStatus> {
|
||||
let status = response.status().as_u16();
|
||||
let headers: Vec<(String, String)> = response
|
||||
.headers()
|
||||
.iter()
|
||||
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
|
||||
.collect();
|
||||
|
||||
let mut body = Vec::new();
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
|
||||
Ok(http_client::HttpResponseWithStatus {
|
||||
status,
|
||||
headers,
|
||||
body,
|
||||
})
|
||||
}
|
||||
|
||||
impl nodejs::Host for WasmState {
|
||||
async fn node_binary_path(&mut self) -> wasmtime::Result<Result<String, String>> {
|
||||
self.host
|
||||
@@ -1109,3 +1145,369 @@ impl ExtensionImports for WasmState {
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
}
|
||||
|
||||
impl llm_provider::Host for WasmState {
|
||||
async fn get_credential(&mut self, provider_id: String) -> wasmtime::Result<Option<String>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
let is_legacy_extension = crate::LEGACY_LLM_EXTENSION_IDS.contains(&extension_id.as_ref());
|
||||
|
||||
// Check if this provider has env vars configured and if the user has allowed any of them
|
||||
let env_vars = self
|
||||
.manifest
|
||||
.language_model_providers
|
||||
.get(&Arc::<str>::from(provider_id.as_str()))
|
||||
.and_then(|entry| entry.auth.as_ref())
|
||||
.and_then(|auth| auth.env_vars.clone());
|
||||
|
||||
if let Some(env_vars) = env_vars {
|
||||
let full_provider_id = format!("{}:{}", extension_id, provider_id);
|
||||
|
||||
// Check each env var to see if it's allowed and set
|
||||
for env_var_name in &env_vars {
|
||||
let settings_key: Arc<str> =
|
||||
format!("{}:{}", full_provider_id, env_var_name).into();
|
||||
|
||||
// For legacy extensions, auto-allow if env var is set
|
||||
let env_var_is_set = env::var(env_var_name)
|
||||
.map(|v| !v.is_empty())
|
||||
.unwrap_or(false);
|
||||
|
||||
let is_allowed = self
|
||||
.on_main_thread({
|
||||
let settings_key = settings_key.clone();
|
||||
move |cx| {
|
||||
async move {
|
||||
cx.update(|cx| {
|
||||
crate::extension_settings::ExtensionSettings::get_global(cx)
|
||||
.allowed_env_var_providers
|
||||
.contains(&settings_key)
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
|
||||
if is_allowed || (is_legacy_extension && env_var_is_set) {
|
||||
if let Ok(value) = env::var(env_var_name) {
|
||||
if !value.is_empty() {
|
||||
return Ok(Some(value));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to credential store
|
||||
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
|
||||
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
|
||||
let result = credentials_provider
|
||||
.read_credentials(&credential_key, cx)
|
||||
.await
|
||||
.ok()
|
||||
.flatten();
|
||||
Ok(result.map(|(_, password)| String::from_utf8_lossy(&password).to_string()))
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
}
|
||||
|
||||
async fn store_credential(
|
||||
&mut self,
|
||||
provider_id: String,
|
||||
value: String,
|
||||
) -> wasmtime::Result<Result<(), String>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
|
||||
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
|
||||
credentials_provider
|
||||
.write_credentials(&credential_key, "api_key", value.as_bytes(), cx)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn delete_credential(
|
||||
&mut self,
|
||||
provider_id: String,
|
||||
) -> wasmtime::Result<Result<(), String>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
|
||||
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
|
||||
credentials_provider
|
||||
.delete_credentials(&credential_key, cx)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("{}", e))
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn get_env_var(&mut self, name: String) -> wasmtime::Result<Option<String>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
|
||||
// Find which provider (if any) declares this env var in its auth config
|
||||
let mut allowed_provider_id: Option<Arc<str>> = None;
|
||||
for (provider_id, provider_entry) in &self.manifest.language_model_providers {
|
||||
if let Some(auth_config) = &provider_entry.auth {
|
||||
if let Some(env_vars) = &auth_config.env_vars {
|
||||
if env_vars.iter().any(|v| v == &name) {
|
||||
allowed_provider_id = Some(provider_id.clone());
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no provider declares this env var, deny access
|
||||
let Some(provider_id) = allowed_provider_id else {
|
||||
log::warn!(
|
||||
"Extension {} attempted to read env var {} which is not declared in any provider auth config",
|
||||
extension_id,
|
||||
name
|
||||
);
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
// Check if the user has allowed this specific env var
|
||||
let settings_key: Arc<str> = format!("{}:{}:{}", extension_id, provider_id, name).into();
|
||||
let is_legacy_extension = crate::LEGACY_LLM_EXTENSION_IDS.contains(&extension_id.as_ref());
|
||||
|
||||
// For legacy extensions, auto-allow if env var is set
|
||||
let env_var_is_set = env::var(&name).map(|v| !v.is_empty()).unwrap_or(false);
|
||||
|
||||
let is_allowed = self
|
||||
.on_main_thread({
|
||||
let settings_key = settings_key.clone();
|
||||
move |cx| {
|
||||
async move {
|
||||
cx.update(|cx| {
|
||||
crate::extension_settings::ExtensionSettings::get_global(cx)
|
||||
.allowed_env_var_providers
|
||||
.contains(&settings_key)
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
}
|
||||
})
|
||||
.await
|
||||
.unwrap_or(false);
|
||||
|
||||
if !is_allowed && !(is_legacy_extension && env_var_is_set) {
|
||||
log::debug!(
|
||||
"Extension {} provider {} is not allowed to read env var {}",
|
||||
extension_id,
|
||||
provider_id,
|
||||
name
|
||||
);
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
Ok(env::var(&name).ok())
|
||||
}
|
||||
|
||||
async fn oauth_start_web_auth(
|
||||
&mut self,
|
||||
config: llm_provider::OauthWebAuthConfig,
|
||||
) -> wasmtime::Result<Result<llm_provider::OauthWebAuthResult, String>> {
|
||||
let auth_url = config.auth_url;
|
||||
let callback_path = config.callback_path;
|
||||
let timeout_secs = config.timeout_secs.unwrap_or(300);
|
||||
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
// Bind to port 0 to let the OS assign an available port, then substitute
|
||||
// it into the auth URL's {port} placeholder for the OAuth callback.
|
||||
let listener = TcpListener::bind("127.0.0.1:0")
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to bind localhost server: {}", e))?;
|
||||
let port = listener
|
||||
.local_addr()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to get local address: {}", e))?
|
||||
.port();
|
||||
|
||||
let auth_url_with_port = auth_url.replace("{port}", &port.to_string());
|
||||
cx.update(|cx| {
|
||||
cx.open_url(&auth_url_with_port);
|
||||
})?;
|
||||
|
||||
let accept_future = async {
|
||||
let (mut stream, _) = listener
|
||||
.accept()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?;
|
||||
|
||||
let mut request_line = String::new();
|
||||
{
|
||||
let mut reader = smol::io::BufReader::new(&mut stream);
|
||||
smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?;
|
||||
}
|
||||
|
||||
let path = request_line
|
||||
.split_whitespace()
|
||||
.nth(1)
|
||||
.ok_or_else(|| anyhow::anyhow!("Malformed HTTP request"))?;
|
||||
|
||||
let callback_url = if path.starts_with(&callback_path)
|
||||
|| path.starts_with(&format!("/{}", callback_path.trim_start_matches('/')))
|
||||
{
|
||||
format!("http://localhost:{}{}", port, path)
|
||||
} else {
|
||||
return Err(anyhow::anyhow!("Unexpected callback path: {}", path));
|
||||
};
|
||||
|
||||
let response = format!(
|
||||
"HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nConnection: close\r\n\r\n{}",
|
||||
include_str!("../oauth_callback_response.html")
|
||||
);
|
||||
|
||||
smol::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes())
|
||||
.await
|
||||
.ok();
|
||||
smol::io::AsyncWriteExt::flush(&mut stream).await.ok();
|
||||
|
||||
Ok(callback_url)
|
||||
};
|
||||
|
||||
let timeout_duration = Duration::from_secs(timeout_secs as u64);
|
||||
let callback_url = smol::future::or(accept_future, async {
|
||||
smol::Timer::after(timeout_duration).await;
|
||||
Err(anyhow::anyhow!(
|
||||
"OAuth callback timed out after {} seconds",
|
||||
timeout_secs
|
||||
))
|
||||
})
|
||||
.await?;
|
||||
|
||||
Ok(llm_provider::OauthWebAuthResult {
|
||||
callback_url,
|
||||
port: port as u32,
|
||||
})
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn oauth_send_http_request(
|
||||
&mut self,
|
||||
request: http_client::HttpRequest,
|
||||
) -> wasmtime::Result<Result<http_client::HttpResponseWithStatus, String>> {
|
||||
maybe!(async {
|
||||
let request = convert_request(&request)?;
|
||||
let mut response = self.host.http_client.send(request).await?;
|
||||
convert_response_with_status(&mut response).await
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn oauth_open_browser(&mut self, url: String) -> wasmtime::Result<Result<(), String>> {
|
||||
self.on_main_thread(move |cx| {
|
||||
async move {
|
||||
cx.update(|cx| {
|
||||
cx.open_url(&url);
|
||||
})?;
|
||||
Ok(())
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await
|
||||
.to_wasmtime_result()
|
||||
}
|
||||
|
||||
async fn get_provider_settings(
|
||||
&mut self,
|
||||
provider_id: String,
|
||||
) -> wasmtime::Result<Option<llm_provider::ProviderSettings>> {
|
||||
let extension_id = self.manifest.id.clone();
|
||||
|
||||
let result = self
|
||||
.on_main_thread(move |cx| {
|
||||
async move {
|
||||
cx.update(|cx| {
|
||||
let settings_store = cx.global::<SettingsStore>();
|
||||
let user_settings = settings_store.raw_user_settings();
|
||||
let language_models =
|
||||
user_settings.and_then(|s| s.content.language_models.as_ref());
|
||||
|
||||
// Map provider IDs to their settings
|
||||
// The provider_id from the extension is just the provider part (e.g., "google-ai")
|
||||
// We need to match this to the appropriate settings
|
||||
match provider_id.as_str() {
|
||||
"google-ai" => {
|
||||
let google = language_models.and_then(|lm| lm.google.as_ref());
|
||||
let google = google?;
|
||||
|
||||
let api_url = google.api_url.clone().filter(|s| !s.is_empty());
|
||||
|
||||
let available_models = google
|
||||
.available_models
|
||||
.as_ref()
|
||||
.map(|models| {
|
||||
models
|
||||
.iter()
|
||||
.map(|m| {
|
||||
let thinking_budget = match &m.mode {
|
||||
Some(ModelMode::Thinking { budget_tokens }) => {
|
||||
*budget_tokens
|
||||
}
|
||||
_ => None,
|
||||
};
|
||||
llm_provider::CustomModelConfig {
|
||||
name: m.name.clone(),
|
||||
display_name: m.display_name.clone(),
|
||||
max_tokens: m.max_tokens,
|
||||
max_output_tokens: None,
|
||||
thinking_budget,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Some(llm_provider::ProviderSettings {
|
||||
api_url,
|
||||
available_models,
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
log::debug!(
|
||||
"Extension {} requested settings for unknown provider: {}",
|
||||
extension_id,
|
||||
provider_id
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.ok()
|
||||
.flatten()
|
||||
}
|
||||
.boxed_local()
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -442,7 +442,9 @@ impl ExtensionsPage {
|
||||
let extension_store = ExtensionStore::global(cx).read(cx);
|
||||
|
||||
match extension_store.outstanding_operations().get(extension_id) {
|
||||
Some(ExtensionOperation::Install) => ExtensionStatus::Installing,
|
||||
Some(ExtensionOperation::Install) | Some(ExtensionOperation::AutoInstall) => {
|
||||
ExtensionStatus::Installing
|
||||
}
|
||||
Some(ExtensionOperation::Remove) => ExtensionStatus::Removing,
|
||||
Some(ExtensionOperation::Upgrade) => ExtensionStatus::Upgrading,
|
||||
None => match extension_store.installed_extensions().get(extension_id) {
|
||||
|
||||
@@ -296,6 +296,20 @@ impl TestAppContext {
|
||||
&self.text_system
|
||||
}
|
||||
|
||||
/// Simulates writing credentials to the platform keychain.
|
||||
pub fn write_credentials(&self, url: &str, username: &str, password: &[u8]) {
|
||||
let _ = self
|
||||
.test_platform
|
||||
.write_credentials(url, username, password);
|
||||
}
|
||||
|
||||
/// Simulates reading credentials from the platform keychain.
|
||||
pub fn read_credentials(&self, url: &str) -> Option<(String, Vec<u8>)> {
|
||||
smol::block_on(self.test_platform.read_credentials(url))
|
||||
.ok()
|
||||
.flatten()
|
||||
}
|
||||
|
||||
/// Simulates writing to the platform clipboard
|
||||
pub fn write_to_clipboard(&self, item: ClipboardItem) {
|
||||
self.test_platform.write_to_clipboard(item)
|
||||
|
||||
@@ -6,7 +6,7 @@ use crate::{
|
||||
TestDisplay, TestWindow, WindowAppearance, WindowParams, size,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use collections::VecDeque;
|
||||
use collections::{HashMap, VecDeque};
|
||||
use futures::channel::oneshot;
|
||||
use parking_lot::Mutex;
|
||||
use std::{
|
||||
@@ -32,6 +32,7 @@ pub(crate) struct TestPlatform {
|
||||
current_clipboard_item: Mutex<Option<ClipboardItem>>,
|
||||
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
|
||||
current_primary_item: Mutex<Option<ClipboardItem>>,
|
||||
credentials: Mutex<HashMap<String, (String, Vec<u8>)>>,
|
||||
#[cfg(target_os = "macos")]
|
||||
current_find_pasteboard_item: Mutex<Option<ClipboardItem>>,
|
||||
pub(crate) prompts: RefCell<TestPrompts>,
|
||||
@@ -119,6 +120,7 @@ impl TestPlatform {
|
||||
current_clipboard_item: Mutex::new(None),
|
||||
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
|
||||
current_primary_item: Mutex::new(None),
|
||||
credentials: Mutex::new(HashMap::default()),
|
||||
#[cfg(target_os = "macos")]
|
||||
current_find_pasteboard_item: Mutex::new(None),
|
||||
weak: weak.clone(),
|
||||
@@ -430,15 +432,20 @@ impl Platform for TestPlatform {
|
||||
*self.current_find_pasteboard_item.lock() = Some(item);
|
||||
}
|
||||
|
||||
fn write_credentials(&self, _url: &str, _username: &str, _password: &[u8]) -> Task<Result<()>> {
|
||||
fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Task<Result<()>> {
|
||||
self.credentials
|
||||
.lock()
|
||||
.insert(url.to_string(), (username.to_string(), password.to_vec()));
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
fn read_credentials(&self, _url: &str) -> Task<Result<Option<(String, Vec<u8>)>>> {
|
||||
Task::ready(Ok(None))
|
||||
fn read_credentials(&self, url: &str) -> Task<Result<Option<(String, Vec<u8>)>>> {
|
||||
let result = self.credentials.lock().get(url).cloned();
|
||||
Task::ready(Ok(result))
|
||||
}
|
||||
|
||||
fn delete_credentials(&self, _url: &str) -> Task<Result<()>> {
|
||||
fn delete_credentials(&self, url: &str) -> Task<Result<()>> {
|
||||
self.credentials.lock().remove(url);
|
||||
Task::ready(Ok(()))
|
||||
}
|
||||
|
||||
|
||||
@@ -228,10 +228,6 @@ impl ApiKeyState {
|
||||
}
|
||||
|
||||
impl ApiKey {
|
||||
pub fn key(&self) -> &str {
|
||||
&self.key
|
||||
}
|
||||
|
||||
pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
|
||||
Self {
|
||||
source: ApiKeySource::EnvVar(env_var_name),
|
||||
@@ -239,16 +235,6 @@ impl ApiKey {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn load_from_system_keychain(
|
||||
url: &str,
|
||||
credentials_provider: &dyn CredentialsProvider,
|
||||
cx: &AsyncApp,
|
||||
) -> Result<Self, AuthenticateError> {
|
||||
Self::load_from_system_keychain_impl(url, credentials_provider, cx)
|
||||
.await
|
||||
.into_authenticate_result()
|
||||
}
|
||||
|
||||
async fn load_from_system_keychain_impl(
|
||||
url: &str,
|
||||
credentials_provider: &dyn CredentialsProvider,
|
||||
|
||||
@@ -818,6 +818,11 @@ pub trait LanguageModelProvider: 'static {
|
||||
fn icon(&self) -> IconOrSvg {
|
||||
IconOrSvg::default()
|
||||
}
|
||||
/// Returns the path to an external SVG icon for this provider, if any.
|
||||
/// When present, this takes precedence over `icon()`.
|
||||
fn icon_path(&self) -> Option<SharedString> {
|
||||
None
|
||||
}
|
||||
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
|
||||
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
|
||||
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;
|
||||
@@ -839,6 +844,7 @@ pub trait LanguageModelProvider: 'static {
|
||||
pub enum ConfigurationViewTargetAgent {
|
||||
#[default]
|
||||
ZedAgent,
|
||||
EditPrediction,
|
||||
Other(SharedString),
|
||||
}
|
||||
|
||||
|
||||
@@ -492,6 +492,7 @@ mod tests {
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.register_provider(provider.clone(), cx);
|
||||
|
||||
// Set up a hiding function that hides the fake provider when "fake-extension" is installed
|
||||
registry.set_builtin_provider_hiding_fn(Box::new(|id| {
|
||||
if id == "fake" {
|
||||
Some("fake-extension")
|
||||
@@ -501,17 +502,21 @@ mod tests {
|
||||
}));
|
||||
});
|
||||
|
||||
// Provider should be visible initially
|
||||
let visible = registry.read(cx).visible_providers();
|
||||
assert_eq!(visible.len(), 1);
|
||||
assert_eq!(visible[0].id(), provider_id);
|
||||
|
||||
// Install the extension
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.extension_installed("fake-extension".into(), cx);
|
||||
});
|
||||
|
||||
// Provider should now be hidden
|
||||
let visible = registry.read(cx).visible_providers();
|
||||
assert!(visible.is_empty());
|
||||
|
||||
// But still in providers()
|
||||
let all = registry.read(cx).providers();
|
||||
assert_eq!(all.len(), 1);
|
||||
}
|
||||
@@ -526,6 +531,7 @@ mod tests {
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.register_provider(provider.clone(), cx);
|
||||
|
||||
// Set up hiding function
|
||||
registry.set_builtin_provider_hiding_fn(Box::new(|id| {
|
||||
if id == "fake" {
|
||||
Some("fake-extension")
|
||||
@@ -534,16 +540,20 @@ mod tests {
|
||||
}
|
||||
}));
|
||||
|
||||
// Start with extension installed
|
||||
registry.extension_installed("fake-extension".into(), cx);
|
||||
});
|
||||
|
||||
// Provider should be hidden
|
||||
let visible = registry.read(cx).visible_providers();
|
||||
assert!(visible.is_empty());
|
||||
|
||||
// Uninstall the extension
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.extension_uninstalled("fake-extension", cx);
|
||||
});
|
||||
|
||||
// Provider should now be visible again
|
||||
let visible = registry.read(cx).visible_providers();
|
||||
assert_eq!(visible.len(), 1);
|
||||
assert_eq!(visible[0].id(), provider_id);
|
||||
@@ -554,6 +564,7 @@ mod tests {
|
||||
let registry = cx.new(|_| LanguageModelRegistry::default());
|
||||
|
||||
registry.update(cx, |registry, cx| {
|
||||
// Set up hiding function
|
||||
registry.set_builtin_provider_hiding_fn(Box::new(|id| {
|
||||
if id == "anthropic" {
|
||||
Some("anthropic")
|
||||
@@ -564,15 +575,19 @@ mod tests {
|
||||
}
|
||||
}));
|
||||
|
||||
// Install only anthropic extension
|
||||
registry.extension_installed("anthropic".into(), cx);
|
||||
});
|
||||
|
||||
let registry_read = registry.read(cx);
|
||||
|
||||
// Anthropic should be hidden
|
||||
assert!(registry_read.should_hide_provider(&LanguageModelProviderId("anthropic".into())));
|
||||
|
||||
// OpenAI should not be hidden (extension not installed)
|
||||
assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("openai".into())));
|
||||
|
||||
// Unknown provider should not be hidden
|
||||
assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into())));
|
||||
}
|
||||
|
||||
@@ -594,6 +609,7 @@ mod tests {
|
||||
}));
|
||||
});
|
||||
|
||||
// Sync with a set containing the extension
|
||||
let mut extension_ids = HashSet::default();
|
||||
extension_ids.insert(Arc::from("fake-extension"));
|
||||
|
||||
@@ -601,12 +617,15 @@ mod tests {
|
||||
registry.sync_installed_llm_extensions(extension_ids, cx);
|
||||
});
|
||||
|
||||
// Provider should be hidden
|
||||
assert!(registry.read(cx).visible_providers().is_empty());
|
||||
|
||||
// Sync with empty set
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.sync_installed_llm_extensions(HashSet::default(), cx);
|
||||
});
|
||||
|
||||
// Provider should be visible again
|
||||
assert_eq!(registry.read(cx).visible_providers().len(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use collections::HashMap;
|
||||
use extension::{
|
||||
use ::extension::{
|
||||
ExtensionHostProxy, ExtensionLanguageModelProviderProxy, LanguageModelProviderRegistration,
|
||||
};
|
||||
use collections::HashMap;
|
||||
use gpui::{App, Entity};
|
||||
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
|
||||
use std::sync::{Arc, LazyLock};
|
||||
@@ -59,6 +59,7 @@ pub fn init_proxy(cx: &mut App) {
|
||||
let proxy = ExtensionHostProxy::default_global(cx);
|
||||
let registry = LanguageModelRegistry::global(cx);
|
||||
|
||||
// Set the function that determines which built-in providers should be hidden
|
||||
registry.update(cx, |registry, _cx| {
|
||||
registry.set_builtin_provider_hiding_fn(Box::new(extension_for_builtin_provider));
|
||||
});
|
||||
|
||||
43
crates/language_models/src/google_ai_api_key.rs
Normal file
43
crates/language_models/src/google_ai_api_key.rs
Normal file
@@ -0,0 +1,43 @@
|
||||
use anyhow::Result;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use gpui::{App, Task};
|
||||
|
||||
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
|
||||
const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";
|
||||
const GOOGLE_AI_EXTENSION_CREDENTIAL_KEY: &str = "extension-llm-google-ai:google-ai";
|
||||
|
||||
/// Returns the Google AI API key for use by the Gemini CLI.
|
||||
///
|
||||
/// This function checks the following sources in order:
|
||||
/// 1. `GEMINI_API_KEY` environment variable
|
||||
/// 2. `GOOGLE_AI_API_KEY` environment variable
|
||||
/// 3. Extension credential store (`extension-llm-google-ai:google-ai`)
|
||||
pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
|
||||
if let Ok(key) = std::env::var(GEMINI_API_KEY_VAR_NAME) {
|
||||
if !key.is_empty() {
|
||||
return Task::ready(Ok(key));
|
||||
}
|
||||
}
|
||||
|
||||
if let Ok(key) = std::env::var(GOOGLE_AI_API_KEY_VAR_NAME) {
|
||||
if !key.is_empty() {
|
||||
return Task::ready(Ok(key));
|
||||
}
|
||||
}
|
||||
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let credential = credentials_provider
|
||||
.read_credentials(GOOGLE_AI_EXTENSION_CREDENTIAL_KEY, &cx)
|
||||
.await?;
|
||||
|
||||
match credential {
|
||||
Some((_, key_bytes)) => {
|
||||
let key = String::from_utf8(key_bytes)?;
|
||||
Ok(key)
|
||||
}
|
||||
None => Err(anyhow::anyhow!("No Google AI API key found")),
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -8,11 +8,12 @@ use language_model::{LanguageModelProviderId, LanguageModelRegistry};
|
||||
use provider::deepseek::DeepSeekLanguageModelProvider;
|
||||
|
||||
pub mod extension;
|
||||
mod google_ai_api_key;
|
||||
pub mod provider;
|
||||
mod settings;
|
||||
|
||||
pub use crate::extension::init_proxy as init_extension_proxy;
|
||||
|
||||
pub use crate::google_ai_api_key::api_key_for_gemini_cli;
|
||||
use crate::provider::anthropic::AnthropicLanguageModelProvider;
|
||||
use crate::provider::bedrock::BedrockLanguageModelProvider;
|
||||
use crate::provider::cloud::CloudLanguageModelProvider;
|
||||
@@ -38,36 +39,41 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
|
||||
if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
|
||||
cx.subscribe(&extension_store, {
|
||||
let registry = registry.clone();
|
||||
move |extension_store, event, cx| match event {
|
||||
extension_host::Event::ExtensionInstalled(extension_id) => {
|
||||
if let Some(manifest) = extension_store
|
||||
.read(cx)
|
||||
.extension_manifest_for_id(extension_id)
|
||||
{
|
||||
if !manifest.language_model_providers.is_empty() {
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.extension_installed(extension_id.clone(), cx);
|
||||
});
|
||||
move |extension_store, event, cx| {
|
||||
match event {
|
||||
extension_host::Event::ExtensionInstalled(extension_id) => {
|
||||
// Check if this extension has language_model_providers
|
||||
if let Some(manifest) = extension_store
|
||||
.read(cx)
|
||||
.extension_manifest_for_id(extension_id)
|
||||
{
|
||||
if !manifest.language_model_providers.is_empty() {
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.extension_installed(extension_id.clone(), cx);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
extension_host::Event::ExtensionUninstalled(extension_id) => {
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.extension_uninstalled(extension_id, cx);
|
||||
});
|
||||
}
|
||||
extension_host::Event::ExtensionsUpdated => {
|
||||
let mut new_ids = HashSet::default();
|
||||
for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
|
||||
if !entry.manifest.language_model_providers.is_empty() {
|
||||
new_ids.insert(extension_id.clone());
|
||||
}
|
||||
extension_host::Event::ExtensionUninstalled(extension_id) => {
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.extension_uninstalled(extension_id, cx);
|
||||
});
|
||||
}
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.sync_installed_llm_extensions(new_ids, cx);
|
||||
});
|
||||
extension_host::Event::ExtensionsUpdated => {
|
||||
// Re-sync installed extensions on bulk updates
|
||||
let mut new_ids = HashSet::default();
|
||||
for (extension_id, entry) in extension_store.read(cx).installed_extensions()
|
||||
{
|
||||
if !entry.manifest.language_model_providers.is_empty() {
|
||||
new_ids.insert(extension_id.clone());
|
||||
}
|
||||
}
|
||||
registry.update(cx, |registry, cx| {
|
||||
registry.sync_installed_llm_extensions(new_ids, cx);
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
@@ -1104,6 +1104,7 @@ impl Render for ConfigurationView {
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
|
||||
ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Anthropic".into(),
|
||||
ConfigurationViewTargetAgent::EditPrediction => "Anthropic for edit predictions".into(),
|
||||
ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
|
||||
})))
|
||||
.child(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::BTreeMap;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
|
||||
use google_ai::{
|
||||
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
|
||||
@@ -32,7 +31,7 @@ use ui::{ButtonLink, ConfiguredApiCard, List, ListBulletItem, prelude::*};
|
||||
use ui_input::InputField;
|
||||
use util::ResultExt;
|
||||
|
||||
use language_model::{ApiKey, ApiKeyState};
|
||||
use language_model::ApiKeyState;
|
||||
|
||||
const PROVIDER_ID: LanguageModelProviderId = language_model::GOOGLE_PROVIDER_ID;
|
||||
const PROVIDER_NAME: LanguageModelProviderName = language_model::GOOGLE_PROVIDER_NAME;
|
||||
@@ -117,22 +116,6 @@ impl GoogleLanguageModelProvider {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
|
||||
if let Some(key) = API_KEY_ENV_VAR.value.clone() {
|
||||
return Task::ready(Ok(key));
|
||||
}
|
||||
let credentials_provider = <dyn CredentialsProvider>::global(cx);
|
||||
let api_url = Self::api_url(cx).to_string();
|
||||
cx.spawn(async move |cx| {
|
||||
Ok(
|
||||
ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
|
||||
.await?
|
||||
.key()
|
||||
.to_string(),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
fn settings(cx: &App) -> &GoogleSettings {
|
||||
&crate::AllLanguageModelSettings::get_global(cx).google
|
||||
}
|
||||
@@ -707,7 +690,7 @@ pub fn count_google_tokens(
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Tiktoken doesn't yet support these models, so we manually use the
|
||||
// Tiktoken doesn't support these models, so we manually use the
|
||||
// same tokenizer as GPT-4.
|
||||
tiktoken_rs::num_tokens_from_messages("gpt-4", &messages).map(|tokens| tokens as u64)
|
||||
})
|
||||
@@ -858,6 +841,7 @@ impl Render for ConfigurationView {
|
||||
.on_action(cx.listener(Self::save_api_key))
|
||||
.child(Label::new(format!("To use {}, you need to add an API key. Follow these steps:", match &self.target_agent {
|
||||
ConfigurationViewTargetAgent::ZedAgent => "Zed's agent with Google AI".into(),
|
||||
ConfigurationViewTargetAgent::EditPrediction => "Google AI for edit predictions".into(),
|
||||
ConfigurationViewTargetAgent::Other(agent) => agent.clone(),
|
||||
})))
|
||||
.child(
|
||||
|
||||
@@ -281,7 +281,6 @@ impl JsonSchema for LanguageModelProviderSetting {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"amazon-bedrock",
|
||||
"anthropic",
|
||||
"copilot_chat",
|
||||
"deepseek",
|
||||
"google",
|
||||
|
||||
@@ -20,6 +20,12 @@ pub struct ExtensionSettingsContent {
|
||||
pub auto_update_extensions: HashMap<Arc<str>, bool>,
|
||||
/// The capabilities granted to extensions.
|
||||
pub granted_extension_capabilities: Option<Vec<ExtensionCapabilityContent>>,
|
||||
/// Extension language model providers that are allowed to read API keys from
|
||||
/// environment variables. Each entry is in the format
|
||||
/// "extension_id:provider_id:ENV_VAR_NAME" (e.g., "google-ai:google-ai:GEMINI_API_KEY").
|
||||
///
|
||||
/// Default: []
|
||||
pub allowed_env_var_providers: Option<Vec<Arc<str>>>,
|
||||
}
|
||||
|
||||
/// A capability for an extension.
|
||||
|
||||
@@ -20,6 +20,8 @@ anyhow.workspace = true
|
||||
bm25 = "2.3.2"
|
||||
copilot.workspace = true
|
||||
edit_prediction.workspace = true
|
||||
extension_host.workspace = true
|
||||
language_model.workspace = true
|
||||
language_models.workspace = true
|
||||
editor.workspace = true
|
||||
feature_flags.workspace = true
|
||||
|
||||
@@ -7659,8 +7659,8 @@ fn edit_prediction_language_settings_section() -> Vec<SettingsPageItem> {
|
||||
files: USER,
|
||||
render: Arc::new(|_, window, cx| {
|
||||
let settings_window = cx.entity();
|
||||
let page = window.use_state(cx, |_, _| {
|
||||
crate::pages::EditPredictionSetupPage::new(settings_window)
|
||||
let page = window.use_state(cx, |window, cx| {
|
||||
crate::pages::EditPredictionSetupPage::new(settings_window, window, cx)
|
||||
});
|
||||
page.into_any_element()
|
||||
}),
|
||||
|
||||
@@ -3,10 +3,15 @@ use edit_prediction::{
|
||||
mercury::{MERCURY_CREDENTIALS_URL, mercury_api_token},
|
||||
sweep_ai::{SWEEP_CREDENTIALS_URL, sweep_api_token},
|
||||
};
|
||||
use extension_host::ExtensionStore;
|
||||
use feature_flags::FeatureFlagAppExt as _;
|
||||
use gpui::{Entity, ScrollHandle, prelude::*};
|
||||
use gpui::{AnyView, Entity, ScrollHandle, Subscription, prelude::*};
|
||||
use language_model::{
|
||||
ConfigurationViewTargetAgent, LanguageModelProviderId, LanguageModelRegistry,
|
||||
};
|
||||
use language_models::provider::mistral::{CODESTRAL_API_URL, codestral_api_key};
|
||||
use ui::{ButtonLink, ConfiguredApiCard, WithScrollbar, prelude::*};
|
||||
use std::collections::HashMap;
|
||||
use ui::{ButtonLink, ConfiguredApiCard, Icon, WithScrollbar, prelude::*};
|
||||
|
||||
use crate::{
|
||||
SettingField, SettingItem, SettingsFieldMetadata, SettingsPageItem, SettingsWindow, USER,
|
||||
@@ -16,24 +21,133 @@ use crate::{
|
||||
pub struct EditPredictionSetupPage {
|
||||
settings_window: Entity<SettingsWindow>,
|
||||
scroll_handle: ScrollHandle,
|
||||
extension_oauth_views: HashMap<LanguageModelProviderId, ExtensionOAuthProviderView>,
|
||||
_registry_subscription: Subscription,
|
||||
}
|
||||
|
||||
struct ExtensionOAuthProviderView {
|
||||
provider_name: SharedString,
|
||||
provider_icon: IconName,
|
||||
provider_icon_path: Option<SharedString>,
|
||||
configuration_view: AnyView,
|
||||
}
|
||||
|
||||
impl EditPredictionSetupPage {
|
||||
pub fn new(settings_window: Entity<SettingsWindow>) -> Self {
|
||||
Self {
|
||||
pub fn new(
|
||||
settings_window: Entity<SettingsWindow>,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Self {
|
||||
let registry_subscription = cx.subscribe_in(
|
||||
&LanguageModelRegistry::global(cx),
|
||||
window,
|
||||
|this, _, event: &language_model::Event, window, cx| match event {
|
||||
language_model::Event::AddedProvider(provider_id) => {
|
||||
this.maybe_add_extension_oauth_view(provider_id, window, cx);
|
||||
}
|
||||
language_model::Event::RemovedProvider(provider_id) => {
|
||||
this.extension_oauth_views.remove(provider_id);
|
||||
}
|
||||
_ => {}
|
||||
},
|
||||
);
|
||||
|
||||
let mut this = Self {
|
||||
settings_window,
|
||||
scroll_handle: ScrollHandle::new(),
|
||||
extension_oauth_views: HashMap::default(),
|
||||
_registry_subscription: registry_subscription,
|
||||
};
|
||||
this.build_extension_oauth_views(window, cx);
|
||||
this
|
||||
}
|
||||
|
||||
fn build_extension_oauth_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
|
||||
let oauth_provider_ids = get_extension_oauth_provider_ids(cx);
|
||||
for provider_id in oauth_provider_ids {
|
||||
self.maybe_add_extension_oauth_view(&provider_id, window, cx);
|
||||
}
|
||||
}
|
||||
|
||||
fn maybe_add_extension_oauth_view(
|
||||
&mut self,
|
||||
provider_id: &LanguageModelProviderId,
|
||||
window: &mut Window,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
// Check if this provider has OAuth configured in the extension manifest
|
||||
if !is_extension_oauth_provider(provider_id, cx) {
|
||||
return;
|
||||
}
|
||||
|
||||
let registry = LanguageModelRegistry::global(cx).read(cx);
|
||||
let Some(provider) = registry.provider(provider_id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let provider_name = provider.name().0;
|
||||
let provider_icon = provider.icon();
|
||||
let provider_icon_path = provider.icon_path();
|
||||
let configuration_view =
|
||||
provider.configuration_view(ConfigurationViewTargetAgent::EditPrediction, window, cx);
|
||||
|
||||
self.extension_oauth_views.insert(
|
||||
provider_id.clone(),
|
||||
ExtensionOAuthProviderView {
|
||||
provider_name,
|
||||
provider_icon,
|
||||
provider_icon_path,
|
||||
configuration_view,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for EditPredictionSetupPage {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let settings_window = self.settings_window.clone();
|
||||
|
||||
let providers = [
|
||||
Some(render_github_copilot_provider(window, cx).into_any_element()),
|
||||
cx.has_flag::<Zeta2FeatureFlag>().then(|| {
|
||||
let copilot_extension_installed = ExtensionStore::global(cx)
|
||||
.read(cx)
|
||||
.installed_extensions()
|
||||
.contains_key("copilot-chat");
|
||||
|
||||
let mut providers: Vec<AnyElement> = Vec::new();
|
||||
|
||||
// Built-in Copilot (hidden if copilot-chat extension is installed)
|
||||
if !copilot_extension_installed {
|
||||
providers.push(render_github_copilot_provider(window, cx).into_any_element());
|
||||
}
|
||||
|
||||
// Extension providers with OAuth support
|
||||
for (provider_id, view) in &self.extension_oauth_views {
|
||||
let icon_element: AnyElement = if let Some(icon_path) = &view.provider_icon_path {
|
||||
Icon::from_external_svg(icon_path.clone())
|
||||
.size(ui::IconSize::Medium)
|
||||
.into_any_element()
|
||||
} else {
|
||||
Icon::new(view.provider_icon)
|
||||
.size(ui::IconSize::Medium)
|
||||
.into_any_element()
|
||||
};
|
||||
|
||||
providers.push(
|
||||
v_flex()
|
||||
.id(SharedString::from(provider_id.0.to_string()))
|
||||
.min_w_0()
|
||||
.gap_1p5()
|
||||
.child(
|
||||
h_flex().gap_2().items_center().child(icon_element).child(
|
||||
Headline::new(view.provider_name.clone()).size(HeadlineSize::Small),
|
||||
),
|
||||
)
|
||||
.child(view.configuration_view.clone())
|
||||
.into_any_element(),
|
||||
);
|
||||
}
|
||||
|
||||
if cx.has_flag::<Zeta2FeatureFlag>() {
|
||||
providers.push(
|
||||
render_api_key_provider(
|
||||
IconName::Inception,
|
||||
"Mercury",
|
||||
@@ -44,9 +158,12 @@ impl Render for EditPredictionSetupPage {
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
}),
|
||||
cx.has_flag::<Zeta2FeatureFlag>().then(|| {
|
||||
.into_any_element(),
|
||||
);
|
||||
}
|
||||
|
||||
if cx.has_flag::<Zeta2FeatureFlag>() {
|
||||
providers.push(
|
||||
render_api_key_provider(
|
||||
IconName::SweepAi,
|
||||
"Sweep",
|
||||
@@ -57,32 +174,33 @@ impl Render for EditPredictionSetupPage {
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
}),
|
||||
Some(
|
||||
render_api_key_provider(
|
||||
IconName::AiMistral,
|
||||
"Codestral",
|
||||
"https://console.mistral.ai/codestral".into(),
|
||||
codestral_api_key(cx),
|
||||
|cx| language_models::MistralLanguageModelProvider::api_url(cx),
|
||||
Some(settings_window.update(cx, |settings_window, cx| {
|
||||
let codestral_settings = codestral_settings();
|
||||
settings_window
|
||||
.render_sub_page_items_section(
|
||||
codestral_settings.iter().enumerate(),
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
})),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.into_any_element(),
|
||||
),
|
||||
];
|
||||
);
|
||||
}
|
||||
|
||||
providers.push(
|
||||
render_api_key_provider(
|
||||
IconName::AiMistral,
|
||||
"Codestral",
|
||||
"https://console.mistral.ai/codestral".into(),
|
||||
codestral_api_key(cx),
|
||||
|cx| language_models::MistralLanguageModelProvider::api_url(cx),
|
||||
Some(settings_window.update(cx, |settings_window, cx| {
|
||||
let codestral_settings = codestral_settings();
|
||||
settings_window
|
||||
.render_sub_page_items_section(
|
||||
codestral_settings.iter().enumerate(),
|
||||
None,
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.into_any_element()
|
||||
})),
|
||||
window,
|
||||
cx,
|
||||
)
|
||||
.into_any_element(),
|
||||
);
|
||||
|
||||
div()
|
||||
.size_full()
|
||||
@@ -96,11 +214,60 @@ impl Render for EditPredictionSetupPage {
|
||||
.pb_16()
|
||||
.overflow_y_scroll()
|
||||
.track_scroll(&self.scroll_handle)
|
||||
.children(providers.into_iter().flatten()),
|
||||
.children(providers),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get extension provider IDs that have OAuth configured.
|
||||
fn get_extension_oauth_provider_ids(cx: &App) -> Vec<LanguageModelProviderId> {
|
||||
let extension_store = ExtensionStore::global(cx).read(cx);
|
||||
|
||||
extension_store
|
||||
.installed_extensions()
|
||||
.iter()
|
||||
.flat_map(|(extension_id, entry)| {
|
||||
entry.manifest.language_model_providers.iter().filter_map(
|
||||
move |(provider_id, provider_entry)| {
|
||||
// Check if this provider has OAuth configured
|
||||
let has_oauth = provider_entry
|
||||
.auth
|
||||
.as_ref()
|
||||
.is_some_and(|auth| auth.oauth.is_some());
|
||||
|
||||
if has_oauth {
|
||||
Some(LanguageModelProviderId(
|
||||
format!("{}:{}", extension_id, provider_id).into(),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Check if a provider ID corresponds to an extension with OAuth configured.
|
||||
fn is_extension_oauth_provider(provider_id: &LanguageModelProviderId, cx: &App) -> bool {
|
||||
// Extension provider IDs are in the format "extension_id:provider_id"
|
||||
let Some((extension_id, local_provider_id)) = provider_id.0.split_once(':') else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let extension_store = ExtensionStore::global(cx).read(cx);
|
||||
let Some(entry) = extension_store.installed_extensions().get(extension_id) else {
|
||||
return false;
|
||||
};
|
||||
|
||||
entry
|
||||
.manifest
|
||||
.language_model_providers
|
||||
.get(local_provider_id)
|
||||
.and_then(|p| p.auth.as_ref())
|
||||
.is_some_and(|auth| auth.oauth.is_some())
|
||||
}
|
||||
|
||||
fn render_api_key_provider(
|
||||
icon: IconName,
|
||||
title: &'static str,
|
||||
|
||||
@@ -344,6 +344,7 @@ pub struct Switch {
|
||||
label: Option<SharedString>,
|
||||
label_position: Option<SwitchLabelPosition>,
|
||||
label_size: LabelSize,
|
||||
label_color: Color,
|
||||
full_width: bool,
|
||||
key_binding: Option<KeyBinding>,
|
||||
color: SwitchColor,
|
||||
@@ -361,6 +362,7 @@ impl Switch {
|
||||
label: None,
|
||||
label_position: None,
|
||||
label_size: LabelSize::Small,
|
||||
label_color: Color::Default,
|
||||
full_width: false,
|
||||
key_binding: None,
|
||||
color: SwitchColor::default(),
|
||||
@@ -408,6 +410,11 @@ impl Switch {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn label_color(mut self, color: Color) -> Self {
|
||||
self.label_color = color;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn full_width(mut self, full_width: bool) -> Self {
|
||||
self.full_width = full_width;
|
||||
self
|
||||
@@ -507,7 +514,11 @@ impl RenderOnce for Switch {
|
||||
self.label_position == Some(SwitchLabelPosition::Start),
|
||||
|this| {
|
||||
this.when_some(label.clone(), |this, label| {
|
||||
this.child(Label::new(label).size(self.label_size))
|
||||
this.child(
|
||||
Label::new(label)
|
||||
.color(self.label_color)
|
||||
.size(self.label_size),
|
||||
)
|
||||
})
|
||||
},
|
||||
)
|
||||
@@ -516,7 +527,11 @@ impl RenderOnce for Switch {
|
||||
self.label_position == Some(SwitchLabelPosition::End),
|
||||
|this| {
|
||||
this.when_some(label, |this, label| {
|
||||
this.child(Label::new(label).size(self.label_size))
|
||||
this.child(
|
||||
Label::new(label)
|
||||
.color(self.label_color)
|
||||
.size(self.label_size),
|
||||
)
|
||||
})
|
||||
},
|
||||
)
|
||||
|
||||
294
crates/workspace/src/oauth_device_flow_modal.rs
Normal file
294
crates/workspace/src/oauth_device_flow_modal.rs
Normal file
@@ -0,0 +1,294 @@
|
||||
use gpui::{
|
||||
Animation, AnimationExt, App, ClipboardItem, Context, DismissEvent, Element, Entity,
|
||||
EventEmitter, FocusHandle, Focusable, InteractiveElement, IntoElement, MouseDownEvent,
|
||||
ParentElement, Render, SharedString, Styled, Subscription, Transformation, Window, div,
|
||||
percentage, rems, svg,
|
||||
};
|
||||
use menu;
|
||||
use std::time::Duration;
|
||||
use ui::{Button, Icon, IconName, Label, Vector, VectorName, prelude::*};
|
||||
|
||||
use crate::ModalView;
|
||||
|
||||
/// Configuration for the OAuth device flow modal.
|
||||
/// This allows extensions to specify the text and appearance of the modal.
|
||||
#[derive(Clone)]
|
||||
pub struct OAuthDeviceFlowModalConfig {
|
||||
/// The user code to display (e.g., "ABC-123").
|
||||
pub user_code: String,
|
||||
/// The URL the user needs to visit to authorize (for the "Connect" button).
|
||||
pub verification_url: String,
|
||||
/// The headline text for the modal (e.g., "Use GitHub Copilot in Zed.").
|
||||
pub headline: String,
|
||||
/// A description to show below the headline.
|
||||
pub description: String,
|
||||
/// Label for the connect button (e.g., "Connect to GitHub").
|
||||
pub connect_button_label: String,
|
||||
/// Success headline shown when authorization completes.
|
||||
pub success_headline: String,
|
||||
/// Success message shown when authorization completes.
|
||||
pub success_message: String,
|
||||
/// Optional path to an SVG icon file (absolute path on disk).
|
||||
pub icon_path: Option<SharedString>,
|
||||
}
|
||||
|
||||
/// The current status of the OAuth device flow.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum OAuthDeviceFlowStatus {
|
||||
/// Waiting for user to click connect and authorize.
|
||||
Prompting,
|
||||
/// User clicked connect, waiting for authorization.
|
||||
WaitingForAuthorization,
|
||||
/// Successfully authorized.
|
||||
Authorized,
|
||||
/// Authorization failed with an error message.
|
||||
Failed(String),
|
||||
}
|
||||
|
||||
/// Shared state for the OAuth device flow that can be observed by the modal.
|
||||
pub struct OAuthDeviceFlowState {
|
||||
pub config: OAuthDeviceFlowModalConfig,
|
||||
pub status: OAuthDeviceFlowStatus,
|
||||
}
|
||||
|
||||
impl EventEmitter<()> for OAuthDeviceFlowState {}
|
||||
|
||||
impl OAuthDeviceFlowState {
|
||||
pub fn new(config: OAuthDeviceFlowModalConfig) -> Self {
|
||||
Self {
|
||||
config,
|
||||
status: OAuthDeviceFlowStatus::Prompting,
|
||||
}
|
||||
}
|
||||
|
||||
/// Update the status of the OAuth flow.
|
||||
pub fn set_status(&mut self, status: OAuthDeviceFlowStatus, cx: &mut Context<Self>) {
|
||||
self.status = status;
|
||||
cx.emit(());
|
||||
cx.notify();
|
||||
}
|
||||
}
|
||||
|
||||
/// A generic OAuth device flow modal that can be used by extensions.
|
||||
pub struct OAuthDeviceFlowModal {
|
||||
state: Entity<OAuthDeviceFlowState>,
|
||||
connect_clicked: bool,
|
||||
focus_handle: FocusHandle,
|
||||
_subscription: Subscription,
|
||||
}
|
||||
|
||||
impl Focusable for OAuthDeviceFlowModal {
|
||||
fn focus_handle(&self, _: &App) -> FocusHandle {
|
||||
self.focus_handle.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl EventEmitter<DismissEvent> for OAuthDeviceFlowModal {}
|
||||
|
||||
impl ModalView for OAuthDeviceFlowModal {}
|
||||
|
||||
impl OAuthDeviceFlowModal {
|
||||
pub fn new(state: Entity<OAuthDeviceFlowState>, cx: &mut Context<Self>) -> Self {
|
||||
let subscription = cx.observe(&state, |_, _, cx| {
|
||||
cx.notify();
|
||||
});
|
||||
|
||||
Self {
|
||||
state,
|
||||
connect_clicked: false,
|
||||
focus_handle: cx.focus_handle(),
|
||||
_subscription: subscription,
|
||||
}
|
||||
}
|
||||
|
||||
fn render_icon(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let state = self.state.read(cx);
|
||||
let icon_color = Color::Custom(cx.theme().colors().icon);
|
||||
// Match ZedXCopilot visual appearance
|
||||
let icon_size = rems(2.5);
|
||||
let plus_size = rems(0.875);
|
||||
// The "+" in ZedXCopilot SVG has fill-opacity="0.5"
|
||||
let plus_color = cx.theme().colors().icon.opacity(0.5);
|
||||
|
||||
if let Some(icon_path) = &state.config.icon_path {
|
||||
// Show "[Provider Icon] + [Zed Logo]" format to match built-in Copilot modal
|
||||
h_flex()
|
||||
.gap_2()
|
||||
.items_center()
|
||||
.child(
|
||||
Icon::from_external_svg(icon_path.clone())
|
||||
.size(ui::IconSize::Custom(icon_size))
|
||||
.color(icon_color),
|
||||
)
|
||||
.child(
|
||||
svg()
|
||||
.size(plus_size)
|
||||
.path("icons/plus.svg")
|
||||
.text_color(plus_color),
|
||||
)
|
||||
.child(Vector::new(VectorName::ZedLogo, icon_size, icon_size).color(icon_color))
|
||||
.into_any_element()
|
||||
} else {
|
||||
// Fallback to just Zed logo if no provider icon
|
||||
Vector::new(VectorName::ZedLogo, icon_size, icon_size)
|
||||
.color(icon_color)
|
||||
.into_any_element()
|
||||
}
|
||||
}
|
||||
|
||||
fn render_device_code(&self, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let state = self.state.read(cx);
|
||||
let user_code = state.config.user_code.clone();
|
||||
let copied = cx
|
||||
.read_from_clipboard()
|
||||
.map(|item| item.text().as_ref() == Some(&user_code))
|
||||
.unwrap_or(false);
|
||||
let user_code_for_click = user_code.clone();
|
||||
|
||||
h_flex()
|
||||
.w_full()
|
||||
.p_1()
|
||||
.border_1()
|
||||
.border_muted(cx)
|
||||
.rounded_sm()
|
||||
.cursor_pointer()
|
||||
.justify_between()
|
||||
.on_mouse_down(gpui::MouseButton::Left, move |_, window, cx| {
|
||||
cx.write_to_clipboard(ClipboardItem::new_string(user_code_for_click.clone()));
|
||||
window.refresh();
|
||||
})
|
||||
.child(div().flex_1().child(Label::new(user_code)))
|
||||
.child(div().flex_none().px_1().child(Label::new(if copied {
|
||||
"Copied!"
|
||||
} else {
|
||||
"Copy"
|
||||
})))
|
||||
}
|
||||
|
||||
fn render_prompting_modal(&self, cx: &mut Context<Self>) -> impl Element {
|
||||
let (connect_button_label, verification_url, headline, description) = {
|
||||
let state = self.state.read(cx);
|
||||
let label = if self.connect_clicked {
|
||||
"Waiting for connection...".to_string()
|
||||
} else {
|
||||
state.config.connect_button_label.clone()
|
||||
};
|
||||
(
|
||||
label,
|
||||
state.config.verification_url.clone(),
|
||||
state.config.headline.clone(),
|
||||
state.config.description.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.flex_1()
|
||||
.gap_2()
|
||||
.items_center()
|
||||
.child(Headline::new(headline).size(HeadlineSize::Large))
|
||||
.child(Label::new(description).color(Color::Muted))
|
||||
.child(self.render_device_code(cx))
|
||||
.child(
|
||||
Label::new("Paste this code into GitHub after clicking the button below.")
|
||||
.size(ui::LabelSize::Small),
|
||||
)
|
||||
.child(
|
||||
Button::new("connect-button", connect_button_label)
|
||||
.on_click(cx.listener(move |this, _, _window, cx| {
|
||||
cx.open_url(&verification_url);
|
||||
this.connect_clicked = true;
|
||||
}))
|
||||
.full_width()
|
||||
.style(ButtonStyle::Filled),
|
||||
)
|
||||
.child(
|
||||
Button::new("cancel-button", "Cancel")
|
||||
.full_width()
|
||||
.on_click(cx.listener(|_, _, _, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
})),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_authorized_modal(&self, cx: &mut Context<Self>) -> impl Element {
|
||||
let state = self.state.read(cx);
|
||||
let success_headline = state.config.success_headline.clone();
|
||||
let success_message = state.config.success_message.clone();
|
||||
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(Headline::new(success_headline).size(HeadlineSize::Large))
|
||||
.child(Label::new(success_message))
|
||||
.child(
|
||||
Button::new("done-button", "Done")
|
||||
.full_width()
|
||||
.on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_failed_modal(&self, error: &str, cx: &mut Context<Self>) -> impl Element {
|
||||
v_flex()
|
||||
.gap_2()
|
||||
.child(Headline::new("Authorization Failed").size(HeadlineSize::Large))
|
||||
.child(Label::new(error.to_string()).color(Color::Error))
|
||||
.child(
|
||||
Button::new("close-button", "Close")
|
||||
.full_width()
|
||||
.on_click(cx.listener(|_, _, _, cx| cx.emit(DismissEvent))),
|
||||
)
|
||||
}
|
||||
|
||||
fn render_loading(window: &mut Window, _cx: &mut Context<Self>) -> impl Element {
|
||||
let loading_icon = svg()
|
||||
.size_8()
|
||||
.path(IconName::ArrowCircle.path())
|
||||
.text_color(window.text_style().color)
|
||||
.with_animation(
|
||||
"icon_circle_arrow",
|
||||
Animation::new(Duration::from_secs(2)).repeat(),
|
||||
|svg, delta| svg.with_transformation(Transformation::rotate(percentage(delta))),
|
||||
);
|
||||
|
||||
h_flex().justify_center().child(loading_icon)
|
||||
}
|
||||
}
|
||||
|
||||
impl Render for OAuthDeviceFlowModal {
|
||||
fn render(&mut self, window: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
|
||||
let status = self.state.read(cx).status.clone();
|
||||
|
||||
let prompt = match &status {
|
||||
OAuthDeviceFlowStatus::Prompting => self.render_prompting_modal(cx).into_any_element(),
|
||||
OAuthDeviceFlowStatus::WaitingForAuthorization => {
|
||||
if self.connect_clicked {
|
||||
self.render_prompting_modal(cx).into_any_element()
|
||||
} else {
|
||||
Self::render_loading(window, cx).into_any_element()
|
||||
}
|
||||
}
|
||||
OAuthDeviceFlowStatus::Authorized => {
|
||||
self.render_authorized_modal(cx).into_any_element()
|
||||
}
|
||||
OAuthDeviceFlowStatus::Failed(error) => {
|
||||
self.render_failed_modal(error, cx).into_any_element()
|
||||
}
|
||||
};
|
||||
|
||||
v_flex()
|
||||
.id("oauth-device-flow-modal")
|
||||
.track_focus(&self.focus_handle(cx))
|
||||
.elevation_3(cx)
|
||||
.w_96()
|
||||
.items_center()
|
||||
.p_4()
|
||||
.gap_2()
|
||||
.on_action(cx.listener(|_, _: &menu::Cancel, _, cx| {
|
||||
cx.emit(DismissEvent);
|
||||
}))
|
||||
.on_any_mouse_down(cx.listener(|this, _: &MouseDownEvent, window, cx| {
|
||||
window.focus(&this.focus_handle, cx);
|
||||
}))
|
||||
.child(self.render_icon(cx))
|
||||
.child(prompt)
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ pub mod invalid_item_view;
|
||||
pub mod item;
|
||||
mod modal_layer;
|
||||
pub mod notifications;
|
||||
pub mod oauth_device_flow_modal;
|
||||
pub mod pane;
|
||||
pub mod pane_group;
|
||||
mod path_list;
|
||||
|
||||
@@ -571,6 +571,11 @@ fn main() {
|
||||
dap_adapters::init(cx);
|
||||
auto_update_ui::init(cx);
|
||||
reliability::init(client.clone(), cx);
|
||||
// Initialize the language model registry first, then set up the extension proxy
|
||||
// BEFORE extension_host::init so that extensions can register their LLM providers
|
||||
// when they load.
|
||||
language_model::init(app_state.client.clone(), cx);
|
||||
language_models::init_extension_proxy(cx);
|
||||
extension_host::init(
|
||||
extension_host_proxy.clone(),
|
||||
app_state.fs.clone(),
|
||||
@@ -596,7 +601,6 @@ fn main() {
|
||||
cx,
|
||||
);
|
||||
supermaven::init(app_state.client.clone(), cx);
|
||||
language_model::init(app_state.client.clone(), cx);
|
||||
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
|
||||
acp_tools::init(cx);
|
||||
edit_prediction_ui::init(cx);
|
||||
|
||||
@@ -2687,9 +2687,6 @@ These values take in the same options as the root-level settings with the same n
|
||||
```json [settings]
|
||||
{
|
||||
"language_models": {
|
||||
"anthropic": {
|
||||
"api_url": "https://api.anthropic.com"
|
||||
},
|
||||
"google": {
|
||||
"api_url": "https://generativelanguage.googleapis.com"
|
||||
},
|
||||
|
||||
823
extensions/google-ai/Cargo.lock
generated
Normal file
823
extensions/google-ai/Cargo.lock
generated
Normal file
@@ -0,0 +1,823 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 4
|
||||
|
||||
[[package]]
|
||||
name = "adler2"
|
||||
version = "2.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
|
||||
|
||||
[[package]]
|
||||
name = "anyhow"
|
||||
version = "1.0.100"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
|
||||
|
||||
[[package]]
|
||||
name = "auditable-serde"
|
||||
version = "0.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c7bf8143dfc3c0258df908843e169b5cc5fcf76c7718bd66135ef4a9cd558c5"
|
||||
dependencies = [
|
||||
"semver",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"topological-sort",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.10.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
|
||||
|
||||
[[package]]
|
||||
name = "crc32fast"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "displaydoc"
|
||||
version = "0.2.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "equivalent"
|
||||
version = "1.0.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
|
||||
|
||||
[[package]]
|
||||
name = "flate2"
|
||||
version = "1.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb"
|
||||
dependencies = [
|
||||
"crc32fast",
|
||||
"miniz_oxide",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "foldhash"
|
||||
version = "0.1.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
|
||||
|
||||
[[package]]
|
||||
name = "form_urlencoded"
|
||||
version = "1.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf"
|
||||
dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-executor",
|
||||
"futures-io",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-core"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
|
||||
|
||||
[[package]]
|
||||
name = "futures-executor"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7"
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
version = "0.3.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"slab",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "google-ai"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"zed_extension_api",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.15.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
|
||||
dependencies = [
|
||||
"foldhash",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hashbrown"
|
||||
version = "0.16.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
|
||||
|
||||
[[package]]
|
||||
name = "heck"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
|
||||
|
||||
[[package]]
|
||||
name = "icu_collections"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43"
|
||||
dependencies = [
|
||||
"displaydoc",
|
||||
"potential_utf",
|
||||
"yoke",
|
||||
"zerofrom",
|
||||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "icu_locale_core"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6"
|
||||
dependencies = [
|
||||
"displaydoc",
|
||||
"litemap",
|
||||
"tinystr",
|
||||
"writeable",
|
||||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "icu_normalizer"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599"
|
||||
dependencies = [
|
||||
"icu_collections",
|
||||
"icu_normalizer_data",
|
||||
"icu_properties",
|
||||
"icu_provider",
|
||||
"smallvec",
|
||||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "icu_normalizer_data"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a"
|
||||
|
||||
[[package]]
|
||||
name = "icu_properties"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99"
|
||||
dependencies = [
|
||||
"icu_collections",
|
||||
"icu_locale_core",
|
||||
"icu_properties_data",
|
||||
"icu_provider",
|
||||
"zerotrie",
|
||||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "icu_properties_data"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899"
|
||||
|
||||
[[package]]
|
||||
name = "icu_provider"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614"
|
||||
dependencies = [
|
||||
"displaydoc",
|
||||
"icu_locale_core",
|
||||
"writeable",
|
||||
"yoke",
|
||||
"zerofrom",
|
||||
"zerotrie",
|
||||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "id-arena"
|
||||
version = "2.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005"
|
||||
|
||||
[[package]]
|
||||
name = "idna"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de"
|
||||
dependencies = [
|
||||
"idna_adapter",
|
||||
"smallvec",
|
||||
"utf8_iter",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "idna_adapter"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344"
|
||||
dependencies = [
|
||||
"icu_normalizer",
|
||||
"icu_properties",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.12.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2"
|
||||
dependencies = [
|
||||
"equivalent",
|
||||
"hashbrown 0.16.1",
|
||||
"serde",
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
|
||||
|
||||
[[package]]
|
||||
name = "leb128fmt"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2"
|
||||
|
||||
[[package]]
|
||||
name = "litemap"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.29"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
|
||||
|
||||
[[package]]
|
||||
name = "miniz_oxide"
|
||||
version = "0.8.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316"
|
||||
dependencies = [
|
||||
"adler2",
|
||||
"simd-adler32",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.21.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||
|
||||
[[package]]
|
||||
name = "percent-encoding"
|
||||
version = "2.3.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220"
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
|
||||
|
||||
[[package]]
|
||||
name = "pin-utils"
|
||||
version = "0.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
|
||||
|
||||
[[package]]
|
||||
name = "potential_utf"
|
||||
version = "0.1.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77"
|
||||
dependencies = [
|
||||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "prettyplease"
|
||||
version = "0.2.37"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.103"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.42"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
|
||||
|
||||
[[package]]
|
||||
name = "semver"
|
||||
version = "1.0.27"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.228"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
|
||||
dependencies = [
|
||||
"serde_core",
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_core"
|
||||
version = "1.0.228"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.228"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.145"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"memchr",
|
||||
"ryu",
|
||||
"serde",
|
||||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "simd-adler32"
|
||||
version = "0.3.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
|
||||
|
||||
[[package]]
|
||||
name = "slab"
|
||||
version = "0.4.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589"
|
||||
|
||||
[[package]]
|
||||
name = "smallvec"
|
||||
version = "1.15.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
|
||||
|
||||
[[package]]
|
||||
name = "spdx"
|
||||
version = "0.10.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3e17e880bafaeb362a7b751ec46bdc5b61445a188f80e0606e68167cd540fa3"
|
||||
dependencies = [
|
||||
"smallvec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stable_deref_trait"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.111"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "synstructure"
|
||||
version = "0.13.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinystr"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869"
|
||||
dependencies = [
|
||||
"displaydoc",
|
||||
"zerovec",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "topological-sort"
|
||||
version = "0.2.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ea68304e134ecd095ac6c3574494fc62b909f416c4fca77e440530221e549d3d"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.22"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
|
||||
|
||||
[[package]]
|
||||
name = "unicode-xid"
|
||||
version = "0.2.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
|
||||
|
||||
[[package]]
|
||||
name = "url"
|
||||
version = "2.5.7"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b"
|
||||
dependencies = [
|
||||
"form_urlencoded",
|
||||
"idna",
|
||||
"percent-encoding",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utf8_iter"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-encoder"
|
||||
version = "0.227.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "80bb72f02e7fbf07183443b27b0f3d4144abf8c114189f2e088ed95b696a7822"
|
||||
dependencies = [
|
||||
"leb128fmt",
|
||||
"wasmparser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-metadata"
|
||||
version = "0.227.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ce1ef0faabbbba6674e97a56bee857ccddf942785a336c8b47b42373c922a91d"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"auditable-serde",
|
||||
"flate2",
|
||||
"indexmap",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"spdx",
|
||||
"url",
|
||||
"wasm-encoder",
|
||||
"wasmparser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasmparser"
|
||||
version = "0.227.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0f51cad774fb3c9461ab9bccc9c62dfb7388397b5deda31bf40e8108ccd678b2"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"hashbrown 0.15.5",
|
||||
"indexmap",
|
||||
"semver",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen"
|
||||
version = "0.41.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "10fb6648689b3929d56bbc7eb1acf70c9a42a29eb5358c67c10f54dbd5d695de"
|
||||
dependencies = [
|
||||
"wit-bindgen-rt",
|
||||
"wit-bindgen-rust-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen-core"
|
||||
version = "0.41.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "92fa781d4f2ff6d3f27f3cc9b74a73327b31ca0dc4a3ef25a0ce2983e0e5af9b"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"heck",
|
||||
"wit-parser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen-rt"
|
||||
version = "0.41.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c4db52a11d4dfb0a59f194c064055794ee6564eb1ced88c25da2cf76e50c5621"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"futures",
|
||||
"once_cell",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen-rust"
|
||||
version = "0.41.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9d0809dc5ba19e2e98661bf32fc0addc5a3ca5bf3a6a7083aa6ba484085ff3ce"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"heck",
|
||||
"indexmap",
|
||||
"prettyplease",
|
||||
"syn",
|
||||
"wasm-metadata",
|
||||
"wit-bindgen-core",
|
||||
"wit-component",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-bindgen-rust-macro"
|
||||
version = "0.41.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ad19eec017904e04c60719592a803ee5da76cb51c81e3f6fbf9457f59db49799"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"wit-bindgen-core",
|
||||
"wit-bindgen-rust",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-component"
|
||||
version = "0.227.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "635c3adc595422cbf2341a17fb73a319669cc8d33deed3a48368a841df86b676"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bitflags",
|
||||
"indexmap",
|
||||
"log",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"wasm-encoder",
|
||||
"wasm-metadata",
|
||||
"wasmparser",
|
||||
"wit-parser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wit-parser"
|
||||
version = "0.227.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ddf445ed5157046e4baf56f9138c124a0824d4d1657e7204d71886ad8ce2fc11"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"id-arena",
|
||||
"indexmap",
|
||||
"log",
|
||||
"semver",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"unicode-xid",
|
||||
"wasmparser",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "writeable"
|
||||
version = "0.6.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
|
||||
|
||||
[[package]]
|
||||
name = "yoke"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954"
|
||||
dependencies = [
|
||||
"stable_deref_trait",
|
||||
"yoke-derive",
|
||||
"zerofrom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "yoke-derive"
|
||||
version = "0.8.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zed_extension_api"
|
||||
version = "0.8.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"wit-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerofrom"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5"
|
||||
dependencies = [
|
||||
"zerofrom-derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerofrom-derive"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"synstructure",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerotrie"
|
||||
version = "0.2.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851"
|
||||
dependencies = [
|
||||
"displaydoc",
|
||||
"yoke",
|
||||
"zerofrom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerovec"
|
||||
version = "0.11.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002"
|
||||
dependencies = [
|
||||
"yoke",
|
||||
"zerofrom",
|
||||
"zerovec-derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zerovec-derive"
|
||||
version = "0.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
17
extensions/google-ai/Cargo.toml
Normal file
17
extensions/google-ai/Cargo.toml
Normal file
@@ -0,0 +1,17 @@
|
||||
[package]
|
||||
name = "google-ai"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
license = "Apache-2.0"
|
||||
|
||||
[workspace]
|
||||
|
||||
[lib]
|
||||
path = "src/google_ai.rs"
|
||||
crate-type = ["cdylib"]
|
||||
|
||||
[dependencies]
|
||||
zed_extension_api = { path = "../../crates/extension_api" }
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
1
extensions/google-ai/LICENSE-APACHE
Symbolic link
1
extensions/google-ai/LICENSE-APACHE
Symbolic link
@@ -0,0 +1 @@
|
||||
../../LICENSE-APACHE
|
||||
13
extensions/google-ai/extension.toml
Normal file
13
extensions/google-ai/extension.toml
Normal file
@@ -0,0 +1,13 @@
|
||||
id = "google-ai"
|
||||
name = "Google AI"
|
||||
description = "Google Gemini LLM provider for Zed."
|
||||
version = "0.1.0"
|
||||
schema_version = 1
|
||||
authors = ["Zed Team"]
|
||||
repository = "https://github.com/zed-industries/zed"
|
||||
|
||||
[language_model_providers.google]
|
||||
name = "Google AI"
|
||||
|
||||
[language_model_providers.google.auth]
|
||||
env_vars = ["GEMINI_API_KEY", "GOOGLE_AI_API_KEY"]
|
||||
3
extensions/google-ai/icons/google-ai.svg
Normal file
3
extensions/google-ai/icons/google-ai.svg
Normal file
@@ -0,0 +1,3 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M7.44 12.27C7.81333 13.1217 8 14.0317 8 15C8 14.0317 8.18083 13.1217 8.5425 12.27C8.91583 11.4183 9.4175 10.6775 10.0475 10.0475C10.6775 9.4175 11.4183 8.92167 12.27 8.56C13.1217 8.18667 14.0317 8 15 8C14.0317 8 13.1217 7.81917 12.27 7.4575C11.4411 7.1001 10.6871 6.5895 10.0475 5.9525C9.4105 5.31293 8.8999 4.55891 8.5425 3.73C8.18083 2.87833 8 1.96833 8 1C8 1.96833 7.81333 2.87833 7.44 3.73C7.07833 4.58167 6.5825 5.3225 5.9525 5.9525C5.31293 6.5895 4.55891 7.1001 3.73 7.4575C2.87833 7.81917 1.96833 8 1 8C1.96833 8 2.87833 8.18667 3.73 8.56C4.58167 8.92167 5.3225 9.4175 5.9525 10.0475C6.5825 10.6775 7.07833 11.4183 7.44 12.27Z" fill="black"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 762 B |
1128
extensions/google-ai/src/google_ai.rs
Normal file
1128
extensions/google-ai/src/google_ai.rs
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user