Compare commits

...

78 Commits

Author SHA1 Message Date
Richard Feldman
9cc517e0dd Fix some extension auto install bugs 2025-12-11 00:52:08 -05:00
Richard Feldman
d1390a5b78 Merge remote-tracking branch 'origin/main' into migrate-provider-extensions 2025-12-11 00:26:09 -05:00
Richard Feldman
ee4faede38 Migrate on auto-load 2025-12-11 00:22:38 -05:00
Richard Feldman
8d96a699b3 Revise migration system some more 2025-12-11 00:13:11 -05:00
Richard Feldman
8cfb7471db Minimize how we're tracking migrations 2025-12-10 23:21:14 -05:00
Richard Feldman
def9c87837 Migrate credentials without touching settings 2025-12-10 22:29:48 -05:00
Richard Feldman
0313ab6d41 Change open-router to openrouter in default.json 2025-12-10 22:10:29 -05:00
Richard Feldman
c5329fdff2 Rename extension from open-router to openrouter 2025-12-10 22:09:59 -05:00
Richard Feldman
a676a6895b Remove redundant set_builtin_provider_hiding_fn call 2025-12-10 22:05:03 -05:00
Richard Feldman
3b5d7d7d89 Minor cleanups 2025-12-10 22:04:35 -05:00
Richard Feldman
91f01131b1 Introduce DEBUG_ALLOW_UNPUBLISHED_AUTO_EXTENSIONS 2025-12-10 21:29:10 -05:00
Richard Feldman
5fa5226286 Remove llm_provider_authenticate() 2025-12-10 21:28:58 -05:00
Richard Feldman
ae94007227 Merge remote-tracking branch 'origin/main' into migrate-provider-extensions 2025-12-10 21:13:57 -05:00
Richard Feldman
8f425a1bd5 Fix unused arg 2025-12-10 13:11:30 -05:00
Richard Feldman
743c414e7b Refresh models list after successful auth 2025-12-10 13:10:55 -05:00
Richard Feldman
0fe335efc5 Revise Copilot auth 2025-12-10 13:02:38 -05:00
Richard Feldman
36b95aac4b Debugging extension loading timing and fallbacks 2025-12-10 12:55:41 -05:00
Richard Feldman
b2df70ab58 Clean up extension markdown for settings 2025-12-10 12:55:23 -05:00
Richard Feldman
36293d7dd9 Debugging 2025-12-09 17:04:58 -05:00
Richard Feldman
3ae3e1fce8 Don't use a heuristic for icon path 2025-12-09 14:55:44 -05:00
Richard Feldman
e5f1fc7478 Fix some regressions 2025-12-09 14:48:31 -05:00
Richard Feldman
a4f6076da7 Migrate to extensions with fallback to builtin 2025-12-09 14:14:56 -05:00
Richard Feldman
43726b2620 Restore ai_anthropic icon svg 2025-12-09 12:00:36 -05:00
Richard Feldman
94980ffb49 Reduce duplication in compute_configured_providers 2025-12-09 11:55:37 -05:00
Richard Feldman
22cc731450 Remove some duplication from icon logic 2025-12-09 11:54:58 -05:00
Richard Feldman
d9396373e3 Eliminate more code duplication 2025-12-09 11:54:00 -05:00
Richard Feldman
48002be135 Use | instead of code duplication 2025-12-09 11:53:18 -05:00
Richard Feldman
58db83f8f5 more icon code cleanup 2025-12-09 11:48:06 -05:00
Richard Feldman
0243d5b542 Clean up some more icon code 2025-12-09 11:44:10 -05:00
Richard Feldman
06230327fa Clean up some icon code 2025-12-09 11:44:05 -05:00
Richard Feldman
ca5c8992f9 Merge remote-tracking branch 'origin/main' into migrate-provider-extensions 2025-12-08 20:23:32 -05:00
Richard Feldman
1038e1c2ef Clean up some duplicated code 2025-12-08 16:59:49 -05:00
Richard Feldman
e1fe0b3287 Restore providers, deduplicate if extensions are present 2025-12-08 16:25:41 -05:00
Richard Feldman
a0e10a91bf Merge remote-tracking branch 'origin/main' into migrate-provider-extensions 2025-12-08 15:35:44 -05:00
Richard Feldman
272b1aa4bc Remove obsolete llm_provider_authenticate 2025-12-08 14:46:04 -05:00
Richard Feldman
9ef0537b44 Add the other extensions to auto-install 2025-12-07 23:13:52 -05:00
Richard Feldman
77f1de742b delete hardcoded AI providers in favor of extnesions 2025-12-07 21:31:00 -05:00
Richard Feldman
e054cabd41 Migrate Google AI over to the extension 2025-12-07 20:57:00 -05:00
Richard Feldman
3b95cb5682 Migrate Copilot and Anthropic to extensions 2025-12-07 20:48:42 -05:00
Richard Feldman
c89653bd07 Fix bugs around logging out from provider extensions 2025-12-05 17:07:25 -05:00
Richard Feldman
b90ac2dc07 Fix Drop impl for WasmExtension 2025-12-05 16:21:53 -05:00
Marshall Bowers
c9998541f0 Revert spurious changes to default.json 2025-12-05 13:25:03 -05:00
Marshall Bowers
e2b49b3cd3 Restore blank lines from main 2025-12-05 13:08:30 -05:00
Marshall Bowers
d1e77397c6 Don't make v0.8.0 available on Stable/Preview yet 2025-12-05 13:07:36 -05:00
Richard Feldman
cc5f5e35e4 Clean up some comments 2025-12-05 13:00:19 -05:00
Richard Feldman
7183b8a1cd Fix API key bug 2025-12-05 12:59:19 -05:00
Richard Feldman
b1934fb712 Remove builtin Anthropic provider 2025-12-05 12:11:51 -05:00
Richard Feldman
a198b6c0d1 Use icon in more places 2025-12-05 11:48:11 -05:00
Richard Feldman
8b5b2712c8 Update Cargo.lock 2025-12-05 11:32:58 -05:00
Richard Feldman
4464392e8e Use kebab-case for open-router extension too. 2025-12-05 11:19:10 -05:00
Richard Feldman
a0d3bc31e9 Rename copilot_chat to copilot-chat 2025-12-05 11:15:43 -05:00
Richard Feldman
ccd6672d1a Revert "Remove builtin extensions for now"
This reverts commit 5559726fd7.
2025-12-05 11:13:29 -05:00
Richard Feldman
21de6d35dd Revert "Revert auto-install extensions for now"
This reverts commit 2031ca17e5.
2025-12-05 11:13:22 -05:00
Richard Feldman
2031ca17e5 Revert auto-install extensions for now 2025-12-05 11:06:12 -05:00
Richard Feldman
8b1ce75a57 Move wit extensions into their own module 2025-12-05 10:30:02 -05:00
Richard Feldman
5559726fd7 Remove builtin extensions for now 2025-12-04 17:20:47 -05:00
Richard Feldman
e1a9269921 Delete example provider extension 2025-12-04 17:20:47 -05:00
Richard Feldman
3b6b3ff504 Specify env vars for the builtin extensions 2025-12-04 17:19:35 -05:00
Richard Feldman
aabed94970 Add OAuth via web authentication to llm extensions, migrate copilot 2025-12-04 17:12:55 -05:00
Richard Feldman
2d3a3521ba Add OAuth Web Flow auth option for llm provider extensions 2025-12-04 17:12:55 -05:00
Richard Feldman
a48bd10da0 Add llm extensions to auto_install_extensions 2025-12-04 17:12:55 -05:00
Richard Feldman
fec9525be4 Add env var checkbox 2025-12-04 17:12:23 -05:00
Richard Feldman
bf2b8e999e use fill=black over fill=currentColor 2025-12-04 16:51:47 -05:00
Richard Feldman
63c35d2b00 Use local icons in llm extensions 2025-12-04 16:48:25 -05:00
Richard Feldman
1396c68010 Add svg icons to llm provider extensions 2025-12-04 16:43:49 -05:00
Richard Feldman
fcb3d3dec6 Update a comment 2025-12-04 16:28:29 -05:00
Richard Feldman
f54e7f8c9d Add trailing newlines 2025-12-04 16:18:43 -05:00
Richard Feldman
2a89529d7f Use named fields 2025-12-04 16:17:50 -05:00
Richard Feldman
58207325e2 restore impl Drop for WasmExtension 2025-12-04 16:12:21 -05:00
Richard Feldman
e08ab99e8d Add extensions for LLM providers 2025-12-04 16:03:51 -05:00
Richard Feldman
a95f3f33a4 Clean up debug logging 2025-12-04 12:38:06 -05:00
Richard Feldman
b0767c1b1f Merge remote-tracking branch 'origin/main' into provider-extensions 2025-12-04 12:27:15 -05:00
Richard Feldman
b200e10bc4 Clean up debug statements 2025-12-04 11:30:44 -05:00
Richard Feldman
948905d916 Revise provider extensions for Gemini API 2025-12-03 20:22:10 -05:00
Richard Feldman
04de456373 Use extension-llm- prefix for credential keys 2025-12-03 15:55:10 -05:00
Richard Feldman
e5ce32e936 Add provider extension API key in settings 2025-12-03 14:41:39 -05:00
Richard Feldman
d7caae30de Fix auth and subscriptions for provider extensions 2025-12-03 13:00:53 -05:00
Richard Feldman
c7e77674a1 Initial Claude Opus 4.5 implementation of Provider Extensions 2025-12-02 13:50:00 -05:00
80 changed files with 12822 additions and 296 deletions

10
Cargo.lock generated
View File

@@ -5843,9 +5843,12 @@ dependencies = [
"async-trait",
"client",
"collections",
"credentials_provider",
"criterion",
"ctor",
"dap",
"dirs 4.0.0",
"editor",
"extension",
"fs",
"futures 0.3.31",
@@ -5854,8 +5857,11 @@ dependencies = [
"http_client",
"language",
"language_extension",
"language_model",
"log",
"lsp",
"markdown",
"menu",
"moka",
"node_runtime",
"parking_lot",
@@ -5870,12 +5876,14 @@ dependencies = [
"serde_json",
"serde_json_lenient",
"settings",
"smol",
"task",
"telemetry",
"tempfile",
"theme",
"theme_extension",
"toml 0.8.23",
"ui",
"url",
"util",
"wasmparser 0.221.3",
@@ -8842,6 +8850,8 @@ dependencies = [
"credentials_provider",
"deepseek",
"editor",
"extension",
"extension_host",
"fs",
"futures 0.3.31",
"google_ai",

View File

@@ -1725,7 +1725,12 @@
// If you don't want any of these extensions, add this field to your settings
// and change the value to `false`.
"auto_install_extensions": {
"html": true
"html": true,
"copilot-chat": true,
"anthropic": true,
"google-ai": true,
"openai": true,
"openrouter": true,
},
// The capabilities granted to extensions.
//

View File

@@ -204,12 +204,21 @@ pub trait AgentModelSelector: 'static {
}
}
/// Icon for a model in the model selector.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AgentModelIcon {
/// A built-in icon from Zed's icon set.
Named(IconName),
/// Path to a custom SVG icon file.
Path(SharedString),
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgentModelInfo {
pub id: acp::ModelId,
pub name: SharedString,
pub description: Option<SharedString>,
pub icon: Option<IconName>,
pub icon: Option<AgentModelIcon>,
}
impl From<acp::ModelInfo> for AgentModelInfo {

View File

@@ -739,7 +739,7 @@ impl ActivityIndicator {
extension_store.outstanding_operations().iter().next()
{
let (message, icon, rotate) = match operation {
ExtensionOperation::Install => (
ExtensionOperation::Install | ExtensionOperation::AutoInstall => (
format!("Installing {extension_id} extension…"),
IconName::LoadCircle,
true,

View File

@@ -18,7 +18,7 @@ pub use templates::*;
pub use thread::*;
pub use tools::*;
use acp_thread::{AcpThread, AgentModelSelector};
use acp_thread::{AcpThread, AgentModelIcon, AgentModelSelector};
use agent_client_protocol as acp;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
@@ -105,7 +105,7 @@ impl LanguageModels {
fn refresh_list(&mut self, cx: &App) {
let providers = LanguageModelRegistry::global(cx)
.read(cx)
.providers()
.visible_providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
@@ -161,11 +161,16 @@ impl LanguageModels {
model: &Arc<dyn LanguageModel>,
provider: &Arc<dyn LanguageModelProvider>,
) -> acp_thread::AgentModelInfo {
let icon = if let Some(path) = provider.icon_path() {
Some(AgentModelIcon::Path(path))
} else {
Some(AgentModelIcon::Named(provider.icon()))
};
acp_thread::AgentModelInfo {
id: Self::model_id(model),
name: model.name().0,
description: None,
icon: Some(provider.icon()),
icon,
}
}
@@ -1356,7 +1361,7 @@ mod internal_tests {
id: acp::ModelId::new("fake/fake"),
name: "Fake".into(),
description: None,
icon: Some(ui::IconName::ZedAssistant),
icon: Some(AgentModelIcon::Named(ui::IconName::ZedAssistant)),
}]
)])
);

View File

@@ -5,7 +5,7 @@ use crate::{AgentServer, AgentServerDelegate, load_proxy_env};
use acp_thread::AgentConnection;
use anyhow::{Context as _, Result};
use gpui::{App, SharedString, Task};
use language_models::provider::google::GoogleLanguageModelProvider;
use language_models::api_key_for_gemini_cli;
use project::agent_server_store::GEMINI_NAME;
#[derive(Clone)]
@@ -37,11 +37,7 @@ impl AgentServer for Gemini {
cx.spawn(async move |cx| {
extra_env.insert("SURFACE".to_owned(), "zed".to_owned());
if let Some(api_key) = cx
.update(GoogleLanguageModelProvider::api_key_for_gemini_cli)?
.await
.ok()
{
if let Some(api_key) = cx.update(api_key_for_gemini_cli)?.await.ok() {
extra_env.insert("GEMINI_API_KEY".into(), api_key);
}
let (command, root_dir, login) = store

View File

@@ -1,6 +1,6 @@
use std::{cmp::Reverse, rc::Rc, sync::Arc};
use acp_thread::{AgentModelInfo, AgentModelList, AgentModelSelector};
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelList, AgentModelSelector};
use agent_servers::AgentServer;
use anyhow::Result;
use collections::IndexMap;
@@ -292,12 +292,18 @@ impl PickerDelegate for AcpModelPickerDelegate {
h_flex()
.w_full()
.gap_1p5()
.when_some(model_info.icon, |this, icon| {
this.child(
Icon::new(icon)
.map(|this| match &model_info.icon {
Some(AgentModelIcon::Path(path)) => this.child(
Icon::from_external_svg(path.clone())
.color(model_icon_color)
.size(IconSize::Small)
)
.size(IconSize::Small),
),
Some(AgentModelIcon::Named(icon)) => this.child(
Icon::new(*icon)
.color(model_icon_color)
.size(IconSize::Small),
),
None => this,
})
.child(Label::new(model_info.name.clone()).truncate()),
)

View File

@@ -1,7 +1,7 @@
use std::rc::Rc;
use std::sync::Arc;
use acp_thread::{AgentModelInfo, AgentModelSelector};
use acp_thread::{AgentModelIcon, AgentModelInfo, AgentModelSelector};
use agent_servers::AgentServer;
use fs::Fs;
use gpui::{Entity, FocusHandle};
@@ -64,7 +64,7 @@ impl Render for AcpModelSelectorPopover {
.map(|model| model.name.clone())
.unwrap_or_else(|| SharedString::from("Select a Model"));
let model_icon = model.as_ref().and_then(|model| model.icon);
let model_icon = model.as_ref().and_then(|model| model.icon.clone());
let focus_handle = self.focus_handle.clone();
@@ -78,8 +78,15 @@ impl Render for AcpModelSelectorPopover {
self.selector.clone(),
ButtonLike::new("active-model")
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
.when_some(model_icon, |this, icon| {
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
.when_some(model_icon, |this, icon| match icon {
AgentModelIcon::Path(path) => this.child(
Icon::from_external_svg(path)
.color(color)
.size(IconSize::XSmall),
),
AgentModelIcon::Named(icon_name) => {
this.child(Icon::new(icon_name).color(color).size(IconSize::XSmall))
}
})
.child(
Label::new(model_name)

View File

@@ -36,7 +36,7 @@ use settings::{Settings, SettingsStore, update_settings_file};
use ui::{
Button, ButtonStyle, Chip, CommonAnimationExt, ContextMenu, ContextMenuEntry, Disclosure,
Divider, DividerColor, ElevationIndex, IconName, IconPosition, IconSize, Indicator, LabelSize,
PopoverMenu, Switch, Tooltip, WithScrollbar, prelude::*,
PopoverMenu, Switch, SwitchColor, Tooltip, WithScrollbar, prelude::*,
};
use util::ResultExt as _;
use workspace::{Workspace, create_and_open_local_file};
@@ -117,7 +117,7 @@ impl AgentConfiguration {
}
fn build_provider_configuration_views(&mut self, window: &mut Window, cx: &mut Context<Self>) {
let providers = LanguageModelRegistry::read_global(cx).providers();
let providers = LanguageModelRegistry::read_global(cx).visible_providers();
for provider in providers {
self.add_provider_configuration_view(&provider, window, cx);
}
@@ -260,11 +260,15 @@ impl AgentConfiguration {
h_flex()
.w_full()
.gap_1p5()
.child(
.child(if let Some(icon_path) = provider.icon_path() {
Icon::from_external_svg(icon_path)
.size(IconSize::Small)
.color(Color::Muted)
} else {
Icon::new(provider.icon())
.size(IconSize::Small)
.color(Color::Muted),
)
.color(Color::Muted)
})
.child(
h_flex()
.w_full()
@@ -416,7 +420,7 @@ impl AgentConfiguration {
&mut self,
cx: &mut Context<Self>,
) -> impl IntoElement {
let providers = LanguageModelRegistry::read_global(cx).providers();
let providers = LanguageModelRegistry::read_global(cx).visible_providers();
let popover_menu = PopoverMenu::new("add-provider-popover")
.trigger(
@@ -879,6 +883,7 @@ impl AgentConfiguration {
.child(context_server_configuration_menu)
.child(
Switch::new("context-server-switch", is_running.into())
.color(SwitchColor::Accent)
.on_click({
let context_server_manager = self.context_server_store.clone();
let fs = self.fs.clone();

View File

@@ -77,7 +77,8 @@ impl Render for AgentModelSelector {
.map(|model| model.model.name().0)
.unwrap_or_else(|| SharedString::from("Select a Model"));
let provider_icon = model.as_ref().map(|model| model.provider.icon());
let provider_icon_path = model.as_ref().and_then(|model| model.provider.icon_path());
let provider_icon_name = model.as_ref().map(|model| model.provider.icon());
let color = if self.menu_handle.is_deployed() {
Color::Accent
} else {
@@ -89,8 +90,17 @@ impl Render for AgentModelSelector {
PickerPopoverMenu::new(
self.selector.clone(),
ButtonLike::new("active-model")
.when_some(provider_icon, |this, icon| {
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
.when_some(provider_icon_path.clone(), |this, icon_path| {
this.child(
Icon::from_external_svg(icon_path)
.color(color)
.size(IconSize::XSmall),
)
})
.when(provider_icon_path.is_none(), |this| {
this.when_some(provider_icon_name, |this, icon| {
this.child(Icon::new(icon).color(color).size(IconSize::XSmall))
})
})
.selected_style(ButtonStyle::Tinted(TintColor::Accent))
.child(
@@ -102,7 +112,7 @@ impl Render for AgentModelSelector {
.child(
Icon::new(IconName::ChevronDown)
.color(color)
.size(IconSize::Small),
.size(IconSize::XSmall),
),
move |_window, cx| {
Tooltip::for_action_in("Change Model", &ToggleModelSelector, &focus_handle, cx)

View File

@@ -2292,7 +2292,7 @@ impl AgentPanel {
let history_is_empty = self.history_store.read(cx).is_empty(cx);
let has_configured_non_zed_providers = LanguageModelRegistry::read_global(cx)
.providers()
.visible_providers()
.iter()
.any(|provider| {
provider.is_authenticated(cx)

View File

@@ -338,7 +338,8 @@ fn init_language_model_settings(cx: &mut App) {
|_, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
| language_model::Event::RemovedProvider(_)
| language_model::Event::ProvidersChanged => {
update_active_language_model_from_settings(cx);
}
_ => {}
@@ -357,26 +358,49 @@ fn update_active_language_model_from_settings(cx: &mut App) {
}
}
let default = settings.default_model.as_ref().map(to_selected_model);
// Filter out models from providers that are not authenticated
fn is_provider_authenticated(
selection: &LanguageModelSelection,
registry: &LanguageModelRegistry,
cx: &App,
) -> bool {
let provider_id = LanguageModelProviderId::from(selection.provider.0.clone());
registry
.provider(&provider_id)
.map_or(false, |provider| provider.is_authenticated(cx))
}
let registry = LanguageModelRegistry::global(cx);
let registry_ref = registry.read(cx);
let default = settings
.default_model
.as_ref()
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
.map(to_selected_model);
let inline_assistant = settings
.inline_assistant_model
.as_ref()
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
.map(to_selected_model);
let commit_message = settings
.commit_message_model
.as_ref()
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
.map(to_selected_model);
let thread_summary = settings
.thread_summary_model
.as_ref()
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
.map(to_selected_model);
let inline_alternatives = settings
.inline_alternatives
.iter()
.filter(|s| is_provider_authenticated(s, registry_ref, cx))
.map(to_selected_model)
.collect::<Vec<_>>();
LanguageModelRegistry::global(cx).update(cx, |registry, cx| {
registry.update(cx, |registry, cx| {
registry.select_default_model(default.as_ref(), cx);
registry.select_inline_assistant_model(inline_assistant.as_ref(), cx);
registry.select_commit_message_model(commit_message.as_ref(), cx);

View File

@@ -1,13 +1,12 @@
use std::{cmp::Reverse, sync::Arc};
use collections::IndexMap;
use futures::{StreamExt, channel::mpsc};
use fuzzy::{StringMatch, StringMatchCandidate, match_strings};
use gpui::{
Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Subscription, Task,
};
use gpui::{Action, AnyElement, App, BackgroundExecutor, DismissEvent, FocusHandle, Task};
use language_model::{
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProviderId,
LanguageModelRegistry,
AuthenticateError, ConfiguredModel, LanguageModel, LanguageModelProvider,
LanguageModelProviderId, LanguageModelRegistry,
};
use ordered_float::OrderedFloat;
use picker::{Picker, PickerDelegate};
@@ -47,7 +46,9 @@ pub fn language_model_selector(
}
fn all_models(cx: &App) -> GroupedModels {
let providers = LanguageModelRegistry::global(cx).read(cx).providers();
let providers = LanguageModelRegistry::global(cx)
.read(cx)
.visible_providers();
let recommended = providers
.iter()
@@ -57,12 +58,12 @@ fn all_models(cx: &App) -> GroupedModels {
.into_iter()
.map(|model| ModelInfo {
model,
icon: provider.icon(),
icon: ProviderIcon::from_provider(provider.as_ref()),
})
})
.collect();
let all = providers
let all: Vec<ModelInfo> = providers
.iter()
.flat_map(|provider| {
provider
@@ -70,7 +71,7 @@ fn all_models(cx: &App) -> GroupedModels {
.into_iter()
.map(|model| ModelInfo {
model,
icon: provider.icon(),
icon: ProviderIcon::from_provider(provider.as_ref()),
})
})
.collect();
@@ -78,10 +79,26 @@ fn all_models(cx: &App) -> GroupedModels {
GroupedModels::new(all, recommended)
}
#[derive(Clone)]
enum ProviderIcon {
Name(IconName),
Path(SharedString),
}
impl ProviderIcon {
fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
if let Some(path) = provider.icon_path() {
Self::Path(path)
} else {
Self::Name(provider.icon())
}
}
}
#[derive(Clone)]
struct ModelInfo {
model: Arc<dyn LanguageModel>,
icon: IconName,
icon: ProviderIcon,
}
pub struct LanguageModelPickerDelegate {
@@ -91,7 +108,7 @@ pub struct LanguageModelPickerDelegate {
filtered_entries: Vec<LanguageModelPickerEntry>,
selected_index: usize,
_authenticate_all_providers_task: Task<()>,
_subscriptions: Vec<Subscription>,
_refresh_models_task: Task<()>,
popover_styles: bool,
focus_handle: FocusHandle,
}
@@ -116,24 +133,43 @@ impl LanguageModelPickerDelegate {
filtered_entries: entries,
get_active_model: Arc::new(get_active_model),
_authenticate_all_providers_task: Self::authenticate_all_providers(cx),
_subscriptions: vec![cx.subscribe_in(
&LanguageModelRegistry::global(cx),
window,
|picker, _, event, window, cx| {
match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
let query = picker.query(cx);
picker.delegate.all_models = Arc::new(all_models(cx));
// Update matches will automatically drop the previous task
// if we get a provider event again
picker.update_matches(query, window, cx)
}
_ => {}
_refresh_models_task: {
// Create a channel to signal when models need refreshing
let (refresh_tx, mut refresh_rx) = mpsc::unbounded::<()>();
// Subscribe to registry events and send refresh signals through the channel
let registry = LanguageModelRegistry::global(cx);
cx.subscribe(&registry, move |_picker, _, event, _cx| match event {
language_model::Event::ProviderStateChanged(_) => {
refresh_tx.unbounded_send(()).ok();
}
},
)],
language_model::Event::AddedProvider(_) => {
refresh_tx.unbounded_send(()).ok();
}
language_model::Event::RemovedProvider(_) => {
refresh_tx.unbounded_send(()).ok();
}
language_model::Event::ProvidersChanged => {
refresh_tx.unbounded_send(()).ok();
}
_ => {}
})
.detach();
// Spawn a task that listens for refresh signals and updates the picker
cx.spawn_in(window, async move |this, cx| {
while let Some(()) = refresh_rx.next().await {
let result = this.update_in(cx, |picker, window, cx| {
picker.delegate.all_models = Arc::new(all_models(cx));
picker.refresh(window, cx);
});
if result.is_err() {
// Picker was dropped, exit the loop
break;
}
}
})
},
popover_styles,
focus_handle,
}
@@ -392,7 +428,7 @@ impl PickerDelegate for LanguageModelPickerDelegate {
let configured_providers = language_model_registry
.read(cx)
.providers()
.visible_providers()
.into_iter()
.filter(|provider| provider.is_authenticated(cx))
.collect::<Vec<_>>();
@@ -504,11 +540,16 @@ impl PickerDelegate for LanguageModelPickerDelegate {
h_flex()
.w_full()
.gap_1p5()
.child(
Icon::new(model_info.icon)
.child(match &model_info.icon {
ProviderIcon::Name(icon_name) => Icon::new(*icon_name)
.color(model_icon_color)
.size(IconSize::Small),
)
ProviderIcon::Path(icon_path) => {
Icon::from_external_svg(icon_path.clone())
.color(model_icon_color)
.size(IconSize::Small)
}
})
.child(Label::new(model_info.model.name().0).truncate()),
)
.end_slot(div().pr_3().when(is_selected, |this| {
@@ -657,7 +698,7 @@ mod tests {
.into_iter()
.map(|(provider, name)| ModelInfo {
model: Arc::new(TestLanguageModel::new(name, provider)),
icon: IconName::Ai,
icon: ProviderIcon::Name(IconName::Ai),
})
.collect()
}

View File

@@ -1682,98 +1682,6 @@ impl TextThreadEditor {
window: &mut Window,
cx: &mut Context<Self>,
) {
let editor_clipboard_selections = cx
.read_from_clipboard()
.and_then(|item| item.entries().first().cloned())
.and_then(|entry| match entry {
ClipboardEntry::String(text) => {
text.metadata_json::<Vec<editor::ClipboardSelection>>()
}
_ => None,
});
let has_file_context = editor_clipboard_selections
.as_ref()
.is_some_and(|selections| {
selections
.iter()
.any(|sel| sel.file_path.is_some() && sel.line_range.is_some())
});
if has_file_context {
if let Some(clipboard_item) = cx.read_from_clipboard() {
if let Some(ClipboardEntry::String(clipboard_text)) =
clipboard_item.entries().first()
{
if let Some(selections) = editor_clipboard_selections {
cx.stop_propagation();
let text = clipboard_text.text();
self.editor.update(cx, |editor, cx| {
let mut current_offset = 0;
let weak_editor = cx.entity().downgrade();
for selection in selections {
if let (Some(file_path), Some(line_range)) =
(selection.file_path, selection.line_range)
{
let selected_text =
&text[current_offset..current_offset + selection.len];
let fence = assistant_slash_commands::codeblock_fence_for_path(
file_path.to_str(),
Some(line_range.clone()),
);
let formatted_text = format!("{fence}{selected_text}\n```");
let insert_point = editor
.selections
.newest::<Point>(&editor.display_snapshot(cx))
.head();
let start_row = MultiBufferRow(insert_point.row);
editor.insert(&formatted_text, window, cx);
let snapshot = editor.buffer().read(cx).snapshot(cx);
let anchor_before = snapshot.anchor_after(insert_point);
let anchor_after = editor
.selections
.newest_anchor()
.head()
.bias_left(&snapshot);
editor.insert("\n", window, cx);
let crease_text = acp_thread::selection_name(
Some(file_path.as_ref()),
&line_range,
);
let fold_placeholder = quote_selection_fold_placeholder(
crease_text,
weak_editor.clone(),
);
let crease = Crease::inline(
anchor_before..anchor_after,
fold_placeholder,
render_quote_selection_output_toggle,
|_, _, _, _| Empty.into_any(),
);
editor.insert_creases(vec![crease], cx);
editor.fold_at(start_row, window, cx);
current_offset += selection.len;
if !selection.is_entire_line && current_offset < text.len() {
current_offset += 1;
}
}
}
});
return;
}
}
}
}
cx.stop_propagation();
let mut images = if let Some(item) = cx.read_from_clipboard() {
@@ -2189,7 +2097,8 @@ impl TextThreadEditor {
.default_model()
.map(|default| default.provider);
let provider_icon = match active_provider {
let provider_icon_path = active_provider.as_ref().and_then(|p| p.icon_path());
let provider_icon_name = match &active_provider {
Some(provider) => provider.icon(),
None => IconName::Ai,
};
@@ -2201,6 +2110,16 @@ impl TextThreadEditor {
(Color::Muted, IconName::ChevronDown)
};
let provider_icon_element = if let Some(icon_path) = provider_icon_path {
Icon::from_external_svg(icon_path)
.color(color)
.size(IconSize::XSmall)
} else {
Icon::new(provider_icon_name)
.color(color)
.size(IconSize::XSmall)
};
PickerPopoverMenu::new(
self.language_model_selector.clone(),
ButtonLike::new("active-model")
@@ -2208,7 +2127,7 @@ impl TextThreadEditor {
.child(
h_flex()
.gap_0p5()
.child(Icon::new(provider_icon).color(color).size(IconSize::XSmall))
.child(provider_icon_element)
.child(
Label::new(model_name)
.color(color)

View File

@@ -1,9 +1,25 @@
use gpui::{Action, IntoElement, ParentElement, RenderOnce, point};
use language_model::{LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
use language_model::{LanguageModelProvider, LanguageModelRegistry, ZED_CLOUD_PROVIDER_ID};
use ui::{Divider, List, ListBulletItem, prelude::*};
#[derive(Clone)]
enum ProviderIcon {
Name(IconName),
Path(SharedString),
}
impl ProviderIcon {
fn from_provider(provider: &dyn LanguageModelProvider) -> Self {
if let Some(path) = provider.icon_path() {
Self::Path(path)
} else {
Self::Name(provider.icon())
}
}
}
pub struct ApiKeysWithProviders {
configured_providers: Vec<(IconName, SharedString)>,
configured_providers: Vec<(ProviderIcon, SharedString)>,
}
impl ApiKeysWithProviders {
@@ -13,7 +29,8 @@ impl ApiKeysWithProviders {
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
| language_model::Event::RemovedProvider(_)
| language_model::Event::ProvidersChanged => {
this.configured_providers = Self::compute_configured_providers(cx)
}
_ => {}
@@ -26,14 +43,19 @@ impl ApiKeysWithProviders {
}
}
fn compute_configured_providers(cx: &App) -> Vec<(IconName, SharedString)> {
fn compute_configured_providers(cx: &App) -> Vec<(ProviderIcon, SharedString)> {
LanguageModelRegistry::read_global(cx)
.providers()
.visible_providers()
.iter()
.filter(|provider| {
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
})
.map(|provider| (provider.icon(), provider.name().0))
.map(|provider| {
(
ProviderIcon::from_provider(provider.as_ref()),
provider.name().0,
)
})
.collect()
}
}
@@ -47,7 +69,14 @@ impl Render for ApiKeysWithProviders {
.map(|(icon, name)| {
h_flex()
.gap_1p5()
.child(Icon::new(icon).size(IconSize::XSmall).color(Color::Muted))
.child(match icon {
ProviderIcon::Name(icon_name) => Icon::new(icon_name)
.size(IconSize::XSmall)
.color(Color::Muted),
ProviderIcon::Path(icon_path) => Icon::from_external_svg(icon_path)
.size(IconSize::XSmall)
.color(Color::Muted),
})
.child(Label::new(name))
});
div()

View File

@@ -11,7 +11,7 @@ use crate::{AgentPanelOnboardingCard, ApiKeysWithoutProviders, ZedAiOnboarding};
pub struct AgentPanelOnboarding {
user_store: Entity<UserStore>,
client: Arc<Client>,
configured_providers: Vec<(IconName, SharedString)>,
has_configured_providers: bool,
continue_with_zed_ai: Arc<dyn Fn(&mut Window, &mut App)>,
}
@@ -27,8 +27,9 @@ impl AgentPanelOnboarding {
|this: &mut Self, _registry, event: &language_model::Event, cx| match event {
language_model::Event::ProviderStateChanged(_)
| language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
this.configured_providers = Self::compute_available_providers(cx)
| language_model::Event::RemovedProvider(_)
| language_model::Event::ProvidersChanged => {
this.has_configured_providers = Self::has_configured_providers(cx)
}
_ => {}
},
@@ -38,20 +39,16 @@ impl AgentPanelOnboarding {
Self {
user_store,
client,
configured_providers: Self::compute_available_providers(cx),
has_configured_providers: Self::has_configured_providers(cx),
continue_with_zed_ai: Arc::new(continue_with_zed_ai),
}
}
fn compute_available_providers(cx: &App) -> Vec<(IconName, SharedString)> {
fn has_configured_providers(cx: &App) -> bool {
LanguageModelRegistry::read_global(cx)
.providers()
.visible_providers()
.iter()
.filter(|provider| {
provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID
})
.map(|provider| (provider.icon(), provider.name().0))
.collect()
.any(|provider| provider.is_authenticated(cx) && provider.id() != ZED_CLOUD_PROVIDER_ID)
}
}
@@ -81,7 +78,7 @@ impl Render for AgentPanelOnboarding {
}),
)
.map(|this| {
if enrolled_in_trial || is_pro_user || !self.configured_providers.is_empty() {
if enrolled_in_trial || is_pro_user || self.has_configured_providers {
this
} else {
this.child(ApiKeysWithoutProviders::new())

View File

@@ -8,7 +8,7 @@ use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::B
use http_client::http::{self, HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest, StatusCode};
use serde::{Deserialize, Serialize};
pub use settings::{AnthropicAvailableModel as AvailableModel, ModelMode};
pub use settings::ModelMode;
use strum::{EnumIter, EnumString};
use thiserror::Error;

View File

@@ -29,6 +29,7 @@ pub struct ExtensionHostProxy {
slash_command_proxy: RwLock<Option<Arc<dyn ExtensionSlashCommandProxy>>>,
context_server_proxy: RwLock<Option<Arc<dyn ExtensionContextServerProxy>>>,
debug_adapter_provider_proxy: RwLock<Option<Arc<dyn ExtensionDebugAdapterProviderProxy>>>,
language_model_provider_proxy: RwLock<Option<Arc<dyn ExtensionLanguageModelProviderProxy>>>,
}
impl ExtensionHostProxy {
@@ -54,6 +55,7 @@ impl ExtensionHostProxy {
slash_command_proxy: RwLock::default(),
context_server_proxy: RwLock::default(),
debug_adapter_provider_proxy: RwLock::default(),
language_model_provider_proxy: RwLock::default(),
}
}
@@ -90,6 +92,15 @@ impl ExtensionHostProxy {
.write()
.replace(Arc::new(proxy));
}
pub fn register_language_model_provider_proxy(
&self,
proxy: impl ExtensionLanguageModelProviderProxy,
) {
self.language_model_provider_proxy
.write()
.replace(Arc::new(proxy));
}
}
pub trait ExtensionThemeProxy: Send + Sync + 'static {
@@ -375,6 +386,49 @@ pub trait ExtensionContextServerProxy: Send + Sync + 'static {
fn unregister_context_server(&self, server_id: Arc<str>, cx: &mut App);
}
/// A function that registers a language model provider with the registry.
/// This allows extension_host to create the provider (which requires WasmExtension)
/// and pass a registration closure to the language_models crate.
pub type LanguageModelProviderRegistration = Box<dyn FnOnce(&mut App) + Send + Sync + 'static>;
pub trait ExtensionLanguageModelProviderProxy: Send + Sync + 'static {
/// Register an LLM provider from an extension.
/// The `register_fn` closure will be called with the App context and should
/// register the provider with the LanguageModelRegistry.
fn register_language_model_provider(
&self,
provider_id: Arc<str>,
register_fn: LanguageModelProviderRegistration,
cx: &mut App,
);
/// Unregister an LLM provider when an extension is unloaded.
fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App);
}
impl ExtensionLanguageModelProviderProxy for ExtensionHostProxy {
    fn register_language_model_provider(
        &self,
        provider_id: Arc<str>,
        register_fn: LanguageModelProviderRegistration,
        cx: &mut App,
    ) {
        // Clone the registered proxy out of the lock first so the read guard
        // is released before we delegate to it.
        let proxy = self.language_model_provider_proxy.read().clone();
        if let Some(proxy) = proxy {
            proxy.register_language_model_provider(provider_id, register_fn, cx);
        }
    }

    fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
        // No-op if no language model provider proxy has been registered yet.
        let proxy = self.language_model_provider_proxy.read().clone();
        if let Some(proxy) = proxy {
            proxy.unregister_language_model_provider(provider_id, cx);
        }
    }
}
impl ExtensionContextServerProxy for ExtensionHostProxy {
fn register_context_server(
&self,

View File

@@ -93,6 +93,8 @@ pub struct ExtensionManifest {
pub debug_adapters: BTreeMap<Arc<str>, DebugAdapterManifestEntry>,
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub debug_locators: BTreeMap<Arc<str>, DebugLocatorManifestEntry>,
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub language_model_providers: BTreeMap<Arc<str>, LanguageModelProviderManifestEntry>,
}
impl ExtensionManifest {
@@ -288,6 +290,71 @@ pub struct DebugAdapterManifestEntry {
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct DebugLocatorManifestEntry {}
/// Manifest entry for a language model provider.
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct LanguageModelProviderManifestEntry {
/// Display name for the provider.
pub name: String,
/// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg").
#[serde(default)]
pub icon: Option<String>,
/// Default models to show even before API connection.
#[serde(default)]
pub models: Vec<LanguageModelManifestEntry>,
/// Authentication configuration.
#[serde(default)]
pub auth: Option<LanguageModelAuthConfig>,
}
/// Manifest entry for a language model.
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct LanguageModelManifestEntry {
    /// Unique identifier for the model.
    pub id: String,
    /// Display name for the model.
    pub name: String,
    /// Maximum input token count.
    // NOTE(review): with `#[serde(default)]` an omitted value deserializes to 0 —
    // confirm consumers treat 0 as "unknown" rather than as a hard zero-token limit.
    #[serde(default)]
    pub max_token_count: u64,
    /// Maximum output tokens (optional).
    #[serde(default)]
    pub max_output_tokens: Option<u64>,
    /// Whether the model supports image inputs.
    #[serde(default)]
    pub supports_images: bool,
    /// Whether the model supports tool/function calling.
    #[serde(default)]
    pub supports_tools: bool,
    /// Whether the model supports extended thinking/reasoning.
    #[serde(default)]
    pub supports_thinking: bool,
}
/// Authentication configuration for a language model provider.
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct LanguageModelAuthConfig {
/// Environment variable name for the API key.
#[serde(default)]
pub env_var: Option<String>,
/// Human-readable name for the credential shown in the UI input field (e.g., "API Key", "Access Token").
#[serde(default)]
pub credential_label: Option<String>,
/// OAuth configuration for web-based authentication flows.
#[serde(default)]
pub oauth: Option<OAuthConfig>,
}
/// OAuth configuration for web-based authentication.
#[derive(Clone, PartialEq, Eq, Debug, Deserialize, Serialize)]
pub struct OAuthConfig {
/// The text to display on the sign-in button (e.g., "Sign in with GitHub").
#[serde(default)]
pub sign_in_button_label: Option<String>,
/// The icon to display on the sign-in button (e.g., "github").
#[serde(default)]
pub sign_in_button_icon: Option<String>,
}
impl ExtensionManifest {
pub async fn load(fs: Arc<dyn Fs>, extension_dir: &Path) -> Result<Self> {
let extension_name = extension_dir
@@ -358,6 +425,7 @@ fn manifest_from_old_manifest(
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: Default::default(),
}
}
@@ -391,6 +459,7 @@ mod tests {
capabilities: vec![],
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}
}

View File

@@ -29,6 +29,27 @@ pub use wit::{
GithubRelease, GithubReleaseAsset, GithubReleaseOptions, github_release_by_tag_name,
latest_github_release,
},
zed::extension::llm_provider::{
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType,
ImageData as LlmImageData, MessageContent as LlmMessageContent,
MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
ModelInfo as LlmModelInfo, OauthHttpRequest as LlmOauthHttpRequest,
OauthHttpResponse as LlmOauthHttpResponse, OauthWebAuthConfig as LlmOauthWebAuthConfig,
OauthWebAuthResult as LlmOauthWebAuthResult, ProviderInfo as LlmProviderInfo,
RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
ToolUseJsonParseError as LlmToolUseJsonParseError,
delete_credential as llm_delete_credential, get_credential as llm_get_credential,
get_env_var as llm_get_env_var, oauth_open_browser as llm_oauth_open_browser,
oauth_start_web_auth as llm_oauth_start_web_auth,
request_credential as llm_request_credential,
send_oauth_http_request as llm_oauth_http_request,
store_credential as llm_store_credential,
},
zed::extension::nodejs::{
node_binary_path, npm_install_package, npm_package_installed_version,
npm_package_latest_version,
@@ -259,6 +280,94 @@ pub trait Extension: Send + Sync {
) -> Result<DebugRequest, String> {
Err("`run_dap_locator` not implemented".to_string())
}
/// Returns information about language model providers offered by this extension.
fn llm_providers(&self) -> Vec<LlmProviderInfo> {
Vec::new()
}
/// Returns the models available for a provider.
fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
Ok(Vec::new())
}
/// Returns markdown content to display in the provider's settings UI.
/// This can include setup instructions, links to documentation, etc.
fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
None
}
/// Check if the provider is authenticated.
fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
false
}
/// Start an OAuth device flow sign-in.
/// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
/// Opens the browser to the verification URL and returns the user code that should
/// be displayed to the user.
fn llm_provider_start_device_flow_sign_in(
&mut self,
_provider_id: &str,
) -> Result<String, String> {
Err("`llm_provider_start_device_flow_sign_in` not implemented".to_string())
}
/// Poll for device flow sign-in completion.
/// This is called after llm_provider_start_device_flow_sign_in returns the user code.
/// The extension should poll the OAuth provider until the user authorizes or the flow times out.
fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
Err("`llm_provider_poll_device_flow_sign_in` not implemented".to_string())
}
/// Reset credentials for the provider.
fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
Err("`llm_provider_reset_credentials` not implemented".to_string())
}
/// Count tokens for a request.
fn llm_count_tokens(
&self,
_provider_id: &str,
_model_id: &str,
_request: &LlmCompletionRequest,
) -> Result<u64, String> {
Err("`llm_count_tokens` not implemented".to_string())
}
/// Start streaming a completion from the model.
/// Returns a stream ID that can be used with `llm_stream_completion_next` and `llm_stream_completion_close`.
fn llm_stream_completion_start(
&mut self,
_provider_id: &str,
_model_id: &str,
_request: &LlmCompletionRequest,
) -> Result<String, String> {
Err("`llm_stream_completion_start` not implemented".to_string())
}
/// Get the next event from a completion stream.
/// Returns `Ok(None)` when the stream is complete.
fn llm_stream_completion_next(
&mut self,
_stream_id: &str,
) -> Result<Option<LlmCompletionEvent>, String> {
Err("`llm_stream_completion_next` not implemented".to_string())
}
/// Close a completion stream and release its resources.
fn llm_stream_completion_close(&mut self, _stream_id: &str) {
// Default implementation does nothing
}
/// Get cache configuration for a model (if prompt caching is supported).
fn llm_cache_configuration(
&self,
_provider_id: &str,
_model_id: &str,
) -> Option<LlmCacheConfiguration> {
None
}
}
/// Registers the provided type as a Zed extension.
@@ -518,6 +627,65 @@ impl wit::Guest for Component {
) -> Result<DebugRequest, String> {
extension().run_dap_locator(locator_name, build_task)
}
fn llm_providers() -> Vec<LlmProviderInfo> {
extension().llm_providers()
}
fn llm_provider_models(provider_id: String) -> Result<Vec<LlmModelInfo>, String> {
extension().llm_provider_models(&provider_id)
}
fn llm_provider_settings_markdown(provider_id: String) -> Option<String> {
extension().llm_provider_settings_markdown(&provider_id)
}
fn llm_provider_is_authenticated(provider_id: String) -> bool {
extension().llm_provider_is_authenticated(&provider_id)
}
fn llm_provider_start_device_flow_sign_in(provider_id: String) -> Result<String, String> {
extension().llm_provider_start_device_flow_sign_in(&provider_id)
}
fn llm_provider_poll_device_flow_sign_in(provider_id: String) -> Result<(), String> {
extension().llm_provider_poll_device_flow_sign_in(&provider_id)
}
fn llm_provider_reset_credentials(provider_id: String) -> Result<(), String> {
extension().llm_provider_reset_credentials(&provider_id)
}
fn llm_count_tokens(
provider_id: String,
model_id: String,
request: LlmCompletionRequest,
) -> Result<u64, String> {
extension().llm_count_tokens(&provider_id, &model_id, &request)
}
fn llm_stream_completion_start(
provider_id: String,
model_id: String,
request: LlmCompletionRequest,
) -> Result<String, String> {
extension().llm_stream_completion_start(&provider_id, &model_id, &request)
}
fn llm_stream_completion_next(stream_id: String) -> Result<Option<LlmCompletionEvent>, String> {
extension().llm_stream_completion_next(&stream_id)
}
fn llm_stream_completion_close(stream_id: String) {
extension().llm_stream_completion_close(&stream_id)
}
fn llm_cache_configuration(
provider_id: String,
model_id: String,
) -> Option<LlmCacheConfiguration> {
extension().llm_cache_configuration(&provider_id, &model_id)
}
}
/// The ID of a language server.

View File

@@ -8,6 +8,7 @@ world extension {
import platform;
import process;
import nodejs;
import llm-provider;
use common.{env-vars, range};
use context-server.{context-server-configuration};
@@ -15,6 +16,10 @@ world extension {
use lsp.{completion, symbol};
use process.{command};
use slash-command.{slash-command, slash-command-argument-completion, slash-command-output};
use llm-provider.{
provider-info, model-info, completion-request,
credential-type, cache-configuration, completion-event, token-usage
};
/// Initializes the extension.
export init-extension: func();
@@ -164,4 +169,74 @@ world extension {
export dap-config-to-scenario: func(config: debug-config) -> result<debug-scenario, string>;
export dap-locator-create-scenario: func(locator-name: string, build-config-template: build-task-template, resolved-label: string, debug-adapter-name: string) -> option<debug-scenario>;
export run-dap-locator: func(locator-name: string, config: resolved-task) -> result<debug-request, string>;
/// Returns information about language model providers offered by this extension.
export llm-providers: func() -> list<provider-info>;
/// Returns the models available for a provider.
export llm-provider-models: func(provider-id: string) -> result<list<model-info>, string>;
/// Returns markdown content to display in the provider's settings UI.
/// This can include setup instructions, links to documentation, etc.
export llm-provider-settings-markdown: func(provider-id: string) -> option<string>;
/// Check if the provider is authenticated.
export llm-provider-is-authenticated: func(provider-id: string) -> bool;
/// Start an OAuth device flow sign-in.
/// This is called when the user explicitly clicks "Sign in with GitHub" or similar.
///
/// The device flow works as follows:
/// 1. Extension requests a device code from the OAuth provider
/// 2. Extension opens the verification URL in the browser
/// 3. Extension returns the user code to display to the user
/// 4. Host displays the user code and calls llm-provider-poll-device-flow-sign-in
/// 5. Extension polls for the access token while user authorizes in browser
/// 6. Once authorized, extension stores the credential and returns success
///
/// Returns the user code that should be displayed to the user while they
/// complete authorization in the browser.
export llm-provider-start-device-flow-sign-in: func(provider-id: string) -> result<string, string>;
/// Poll for device flow sign-in completion.
/// This is called after llm-provider-start-device-flow-sign-in returns the user code.
/// The extension should poll the OAuth provider until the user authorizes or the flow times out.
/// Returns Ok(()) on successful authentication, or an error message on failure.
export llm-provider-poll-device-flow-sign-in: func(provider-id: string) -> result<_, string>;
/// Reset credentials for the provider.
export llm-provider-reset-credentials: func(provider-id: string) -> result<_, string>;
/// Count tokens for a request.
export llm-count-tokens: func(
provider-id: string,
model-id: string,
request: completion-request
) -> result<u64, string>;
/// Start streaming a completion from the model.
/// Returns a stream ID that can be used with llm-stream-next and llm-stream-close.
export llm-stream-completion-start: func(
provider-id: string,
model-id: string,
request: completion-request
) -> result<string, string>;
/// Get the next event from a completion stream.
/// Returns None when the stream is complete.
export llm-stream-completion-next: func(
stream-id: string
) -> result<option<completion-event>, string>;
/// Close a completion stream and release its resources.
export llm-stream-completion-close: func(
stream-id: string
);
/// Get cache configuration for a model (if prompt caching is supported).
export llm-cache-configuration: func(
provider-id: string,
model-id: string
) -> option<cache-configuration>;
}

View File

@@ -0,0 +1,348 @@
interface llm-provider {
/// Information about a language model provider.
record provider-info {
/// Unique identifier for the provider (e.g., "my-extension.my-provider").
id: string,
/// Display name for the provider.
name: string,
/// Path to an SVG icon file relative to the extension root (e.g., "icons/provider.svg").
icon: option<string>,
}
/// Capabilities of a language model.
record model-capabilities {
/// Whether the model supports image inputs.
supports-images: bool,
/// Whether the model supports tool/function calling.
supports-tools: bool,
/// Whether the model supports the "auto" tool choice.
supports-tool-choice-auto: bool,
/// Whether the model supports the "any" tool choice.
supports-tool-choice-any: bool,
/// Whether the model supports the "none" tool choice.
supports-tool-choice-none: bool,
/// Whether the model supports extended thinking/reasoning.
supports-thinking: bool,
/// The format for tool input schemas.
tool-input-format: tool-input-format,
}
/// Format for tool input schemas.
enum tool-input-format {
/// Standard JSON Schema format.
json-schema,
/// Simplified schema format for certain providers.
simplified,
}
/// Information about a specific model.
record model-info {
/// Unique identifier for the model.
id: string,
/// Display name for the model.
name: string,
/// Maximum input token count.
max-token-count: u64,
/// Maximum output tokens (optional).
max-output-tokens: option<u64>,
/// Model capabilities.
capabilities: model-capabilities,
/// Whether this is the default model for the provider.
is-default: bool,
/// Whether this is the default fast model.
is-default-fast: bool,
}
/// The role of a message participant.
enum message-role {
/// User message.
user,
/// Assistant message.
assistant,
/// System message.
system,
}
/// A message in a completion request.
record request-message {
/// The role of the message sender.
role: message-role,
/// The content of the message.
content: list<message-content>,
/// Whether to cache this message for prompt caching.
cache: bool,
}
/// Content within a message.
variant message-content {
/// Plain text content.
text(string),
/// Image content.
image(image-data),
/// A tool use request from the assistant.
tool-use(tool-use),
/// A tool result from the user.
tool-result(tool-result),
/// Thinking/reasoning content.
thinking(thinking-content),
/// Redacted/encrypted thinking content.
redacted-thinking(string),
}
/// Image data for vision models.
record image-data {
/// Base64-encoded image data.
source: string,
/// Image width in pixels (optional).
width: option<u32>,
/// Image height in pixels (optional).
height: option<u32>,
}
/// A tool use request from the model.
record tool-use {
/// Unique identifier for this tool use.
id: string,
/// The name of the tool being used.
name: string,
/// JSON string of the tool input arguments.
input: string,
/// Thought signature for providers that support it (e.g., Anthropic).
thought-signature: option<string>,
}
/// A tool result to send back to the model.
record tool-result {
/// The ID of the tool use this is a result for.
tool-use-id: string,
/// The name of the tool.
tool-name: string,
/// Whether this result represents an error.
is-error: bool,
/// The content of the result.
content: tool-result-content,
}
/// Content of a tool result.
variant tool-result-content {
/// Text result.
text(string),
/// Image result.
image(image-data),
}
/// Thinking/reasoning content from models that support extended thinking.
record thinking-content {
/// The thinking text.
text: string,
/// Signature for the thinking block (provider-specific).
signature: option<string>,
}
/// A tool definition for function calling.
record tool-definition {
/// The name of the tool.
name: string,
/// Description of what the tool does.
description: string,
/// JSON Schema for input parameters.
input-schema: string,
}
/// Tool choice preference for the model.
enum tool-choice {
/// Let the model decide whether to use tools.
auto,
/// Force the model to use at least one tool.
any,
/// Prevent the model from using tools.
none,
}
/// A completion request to send to the model.
record completion-request {
/// The messages in the conversation.
messages: list<request-message>,
/// Available tools for the model to use.
tools: list<tool-definition>,
/// Tool choice preference.
tool-choice: option<tool-choice>,
/// Stop sequences to end generation.
stop-sequences: list<string>,
/// Temperature for sampling (0.0-1.0).
temperature: option<f32>,
/// Whether thinking/reasoning is allowed.
thinking-allowed: bool,
/// Maximum tokens to generate.
max-tokens: option<u64>,
}
/// Events emitted during completion streaming.
variant completion-event {
/// Completion has started.
started,
/// Text content chunk.
text(string),
/// Thinking/reasoning content chunk.
thinking(thinking-content),
/// Redacted thinking (encrypted) chunk.
redacted-thinking(string),
/// Tool use request from the model.
tool-use(tool-use),
/// JSON parse error when parsing tool input.
tool-use-json-parse-error(tool-use-json-parse-error),
/// Completion stopped.
stop(stop-reason),
/// Token usage update.
usage(token-usage),
/// Reasoning details (provider-specific JSON).
reasoning-details(string),
}
/// Error information when tool use JSON parsing fails.
record tool-use-json-parse-error {
/// The tool use ID.
id: string,
/// The tool name.
tool-name: string,
/// The raw input that failed to parse.
raw-input: string,
/// The parse error message.
error: string,
}
/// Reason the completion stopped.
enum stop-reason {
/// The model finished generating.
end-turn,
/// Maximum tokens reached.
max-tokens,
/// The model wants to use a tool.
tool-use,
/// The model refused to respond.
refusal,
}
/// Token usage statistics.
record token-usage {
/// Number of input tokens used.
input-tokens: u64,
/// Number of output tokens generated.
output-tokens: u64,
/// Tokens used for cache creation (if supported).
cache-creation-input-tokens: option<u64>,
/// Tokens read from cache (if supported).
cache-read-input-tokens: option<u64>,
}
/// Credential types that can be requested.
enum credential-type {
/// An API key.
api-key,
/// An OAuth token.
oauth-token,
}
/// Cache configuration for prompt caching.
record cache-configuration {
/// Maximum number of cache anchors.
max-cache-anchors: u32,
/// Whether caching should be applied to tool definitions.
should-cache-tool-definitions: bool,
/// Minimum token count for a message to be cached.
min-total-token-count: u64,
}
/// Configuration for starting an OAuth web authentication flow.
record oauth-web-auth-config {
/// The URL to open in the user's browser to start authentication.
/// This should include client_id, redirect_uri, scope, state, etc.
/// Use `{port}` as a placeholder in the URL - it will be replaced with
/// the actual localhost port before opening the browser.
/// Example: "https://example.com/oauth?redirect_uri=http://127.0.0.1:{port}/callback"
auth-url: string,
/// The path to listen on for the OAuth callback (e.g., "/callback").
/// A localhost server will be started to receive the redirect.
callback-path: string,
/// Timeout in seconds to wait for the callback (default: 300 = 5 minutes).
timeout-secs: option<u32>,
}
/// Result of an OAuth web authentication flow.
record oauth-web-auth-result {
/// The full callback URL that was received, including query parameters.
/// The extension is responsible for parsing the code, state, etc.
callback-url: string,
/// The port that was used for the localhost callback server.
port: u32,
}
/// A generic HTTP request for OAuth token exchange.
record oauth-http-request {
/// The URL to request.
url: string,
/// HTTP method (e.g., "POST", "GET").
method: string,
/// Request headers as key-value pairs.
headers: list<tuple<string, string>>,
/// Request body as a string (for form-encoded or JSON bodies).
body: string,
}
/// Response from an OAuth HTTP request.
record oauth-http-response {
/// HTTP status code.
status: u16,
/// Response headers as key-value pairs.
headers: list<tuple<string, string>>,
/// Response body as a string.
body: string,
}
/// Request a credential from the user.
/// Returns true if the credential was provided, false if the user cancelled.
request-credential: func(
provider-id: string,
credential-type: credential-type,
label: string,
placeholder: string
) -> result<bool, string>;
/// Get a stored credential for this provider.
get-credential: func(provider-id: string) -> option<string>;
/// Store a credential for this provider.
store-credential: func(provider-id: string, value: string) -> result<_, string>;
/// Delete a stored credential for this provider.
delete-credential: func(provider-id: string) -> result<_, string>;
/// Read an environment variable.
get-env-var: func(name: string) -> option<string>;
/// Start an OAuth web authentication flow.
///
/// This will:
/// 1. Start a localhost server to receive the OAuth callback
/// 2. Open the auth URL in the user's default browser
/// 3. Wait for the callback (up to the timeout)
/// 4. Return the callback URL with query parameters
///
/// The extension is responsible for:
/// - Constructing the auth URL with client_id, redirect_uri, scope, state, etc.
/// - Parsing the callback URL to extract the authorization code
/// - Exchanging the code for tokens using oauth-http-request
oauth-start-web-auth: func(config: oauth-web-auth-config) -> result<oauth-web-auth-result, string>;
/// Make an HTTP request for OAuth token exchange.
///
/// This is a simple HTTP client for OAuth flows, allowing the extension
/// to handle token exchange with full control over serialization.
send-oauth-http-request: func(request: oauth-http-request) -> result<oauth-http-response, string>;
/// Open a URL in the user's default browser.
///
/// Useful for OAuth flows that need to open a browser but handle the
/// callback differently (e.g., polling-based flows).
oauth-open-browser: func(url: string) -> result<_, string>;
}

View File

@@ -255,6 +255,21 @@ async fn copy_extension_resources(
}
}
for (_, provider_entry) in &manifest.language_model_providers {
if let Some(icon_path) = &provider_entry.icon {
let source_icon = extension_path.join(icon_path);
let dest_icon = output_dir.join(icon_path);
// Create parent directory if needed
if let Some(parent) = dest_icon.parent() {
fs::create_dir_all(parent)?;
}
fs::copy(&source_icon, &dest_icon)
.with_context(|| format!("failed to copy LLM provider icon '{}'", icon_path))?;
}
}
if !manifest.languages.is_empty() {
let output_languages_dir = output_dir.join("languages");
fs::create_dir_all(&output_languages_dir)?;

View File

@@ -22,7 +22,10 @@ async-tar.workspace = true
async-trait.workspace = true
client.workspace = true
collections.workspace = true
credentials_provider.workspace = true
dap.workspace = true
dirs.workspace = true
editor.workspace = true
extension.workspace = true
fs.workspace = true
futures.workspace = true
@@ -30,8 +33,11 @@ gpui.workspace = true
gpui_tokio.workspace = true
http_client.workspace = true
language.workspace = true
language_model.workspace = true
log.workspace = true
markdown.workspace = true
lsp.workspace = true
menu.workspace = true
moka.workspace = true
node_runtime.workspace = true
paths.workspace = true
@@ -43,10 +49,13 @@ serde.workspace = true
serde_json.workspace = true
serde_json_lenient.workspace = true
settings.workspace = true
smol.workspace = true
task.workspace = true
telemetry.workspace = true
tempfile.workspace = true
theme.workspace = true
toml.workspace = true
ui.workspace = true
url.workspace = true
util.workspace = true
wasmparser.workspace = true

View File

@@ -148,6 +148,7 @@ fn manifest() -> ExtensionManifest {
)],
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}
}

View File

@@ -0,0 +1,124 @@
use credentials_provider::CredentialsProvider;
use gpui::App;
const ANTHROPIC_EXTENSION_ID: &str = "anthropic";
const ANTHROPIC_PROVIDER_ID: &str = "anthropic";
const ANTHROPIC_DEFAULT_API_URL: &str = "https://api.anthropic.com";
/// Migrates Anthropic API credentials from the old built-in provider location
/// to the new extension-based location.
///
/// This should only be called during auto-install of the extension.
pub fn migrate_anthropic_credentials_if_needed(extension_id: &str, cx: &mut App) {
    // Only act when the Anthropic extension itself is being installed.
    if extension_id != ANTHROPIC_EXTENSION_ID {
        return;
    }
    let credentials_provider = <dyn CredentialsProvider>::global(cx);
    let extension_credential_key = format!(
        "extension-llm-{}:{}",
        ANTHROPIC_EXTENSION_ID, ANTHROPIC_PROVIDER_ID
    );
    cx.spawn(async move |cx| {
        // Look up whatever was stored under the old built-in provider's key.
        let old_credential = credentials_provider
            .read_credentials(ANTHROPIC_DEFAULT_API_URL, &cx)
            .await
            .ok()
            .flatten();
        let Some((_, key_bytes)) = old_credential else {
            log::debug!("No existing Anthropic API key found to migrate");
            return;
        };
        let Ok(api_key) = String::from_utf8(key_bytes) else {
            log::error!("Failed to decode Anthropic API key as UTF-8");
            return;
        };
        if api_key.is_empty() {
            log::debug!("Existing Anthropic API key is empty, nothing to migrate");
            return;
        }
        log::info!("Migrating existing Anthropic API key to Anthropic extension");
        // Write to the extension-scoped credential key; the old entry is left
        // in place untouched.
        let result = credentials_provider
            .write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
            .await;
        match result {
            Ok(()) => {
                log::info!("Successfully migrated Anthropic API key to extension");
            }
            Err(err) => {
                log::error!("Failed to migrate Anthropic API key: {}", err);
            }
        }
    })
    .detach();
}
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    #[gpui::test]
    async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
        // Seed a credential at the legacy (built-in provider) location.
        let api_key = "sk-ant-test-key-12345";
        cx.write_credentials(ANTHROPIC_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
        cx.update(|cx| {
            migrate_anthropic_credentials_if_needed(ANTHROPIC_EXTENSION_ID, cx);
        });
        // The migration runs in a detached task; wait for it to settle.
        cx.run_until_parked();
        // The key should now also exist under the extension-scoped credential key.
        let migrated = cx.read_credentials("extension-llm-anthropic:anthropic");
        assert!(migrated.is_some(), "Credentials should have been migrated");
        let (username, password) = migrated.unwrap();
        assert_eq!(username, "Bearer");
        assert_eq!(String::from_utf8(password).unwrap(), api_key);
    }

    #[gpui::test]
    async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
        // No legacy credential is seeded, so the migration should be a no-op.
        cx.update(|cx| {
            migrate_anthropic_credentials_if_needed(ANTHROPIC_EXTENSION_ID, cx);
        });
        cx.run_until_parked();
        let credentials = cx.read_credentials("extension-llm-anthropic:anthropic");
        assert!(
            credentials.is_none(),
            "Should not create credentials if none existed"
        );
    }

    #[gpui::test]
    async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
        // A legacy credential exists, but a different extension is installing,
        // so the Anthropic migration must not touch it.
        let api_key = "sk-ant-test-key";
        cx.write_credentials(ANTHROPIC_DEFAULT_API_URL, "Bearer", api_key.as_bytes());
        cx.update(|cx| {
            migrate_anthropic_credentials_if_needed("some-other-extension", cx);
        });
        cx.run_until_parked();
        let credentials = cx.read_credentials("extension-llm-anthropic:anthropic");
        assert!(
            credentials.is_none(),
            "Should not migrate for other extensions"
        );
    }
}

View File

@@ -113,6 +113,7 @@ mod tests {
capabilities: vec![],
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}
}

View File

@@ -0,0 +1,216 @@
use credentials_provider::CredentialsProvider;
use gpui::App;
use std::path::{Path, PathBuf};
const COPILOT_CHAT_EXTENSION_ID: &str = "copilot-chat";
const COPILOT_CHAT_PROVIDER_ID: &str = "copilot-chat";
/// Migrates Copilot OAuth credentials from the GitHub Copilot config files
/// to the new extension-based credential location.
///
/// This should only be called during auto-install of the extension.
pub fn migrate_copilot_credentials_if_needed(extension_id: &str, cx: &mut App) {
    // Only act when the Copilot Chat extension itself is being installed.
    if extension_id != COPILOT_CHAT_EXTENSION_ID {
        return;
    }
    let credential_key = format!(
        "extension-llm-{}:{}",
        COPILOT_CHAT_EXTENSION_ID, COPILOT_CHAT_PROVIDER_ID
    );
    let credentials_provider = <dyn CredentialsProvider>::global(cx);
    // Fix: the closure parameter was named `_cx` (which marks it unused) even
    // though it is passed to `write_credentials` below; renamed to `cx`, which
    // also matches the equivalent Anthropic migration.
    cx.spawn(async move |cx| {
        // Read the OAuth token from the Copilot config files
        // (hosts.json / apps.json).
        let oauth_token = match read_copilot_oauth_token().await {
            Some(token) if !token.is_empty() => token,
            _ => {
                log::debug!("No existing Copilot OAuth token found to migrate");
                return;
            }
        };
        log::info!("Migrating existing Copilot OAuth token to Copilot Chat extension");
        match credentials_provider
            .write_credentials(&credential_key, "api_key", oauth_token.as_bytes(), &cx)
            .await
        {
            Ok(()) => {
                log::info!("Successfully migrated Copilot OAuth token to Copilot Chat extension");
            }
            Err(err) => {
                log::error!("Failed to migrate Copilot OAuth token: {}", err);
            }
        }
    })
    .detach();
}
/// Returns the first OAuth token found in any Copilot config file, if any.
async fn read_copilot_oauth_token() -> Option<String> {
    for path in copilot_config_paths() {
        let token = read_oauth_token_from_file(&path).await;
        if token.is_some() {
            return token;
        }
    }
    None
}
/// Returns the candidate GitHub Copilot config file locations, in the order
/// they should be checked.
fn copilot_config_paths() -> Vec<PathBuf> {
    // Windows keeps the config under the local app-data directory; elsewhere
    // we honor $XDG_CONFIG_HOME and fall back to ~/.config.
    let config_dir = if cfg!(target_os = "windows") {
        dirs::data_local_dir()
    } else {
        std::env::var("XDG_CONFIG_HOME")
            .ok()
            .map(PathBuf::from)
            .or_else(|| dirs::home_dir().map(|home| home.join(".config")))
    };

    match config_dir {
        Some(dir) => {
            let copilot_dir = dir.join("github-copilot");
            vec![
                copilot_dir.join("hosts.json"),
                copilot_dir.join("apps.json"),
            ]
        }
        None => Vec::new(),
    }
}
async fn read_oauth_token_from_file(path: &PathBuf) -> Option<String> {
let contents = match smol::fs::read_to_string(path).await {
Ok(contents) => contents,
Err(_) => return None,
};
extract_oauth_token(&contents, "github.com")
}
/// Extracts the `oauth_token` value for `domain` from Copilot's JSON config.
///
/// Keys may be either the bare domain (`"github.com"`) or suffixed with a
/// user (`"github.com:user"`). Returns `None` for invalid JSON, a non-object
/// root, or when no matching entry carries an `oauth_token` string.
fn extract_oauth_token(contents: &str, domain: &str) -> Option<String> {
    let value: serde_json::Value = serde_json::from_str(contents).ok()?;
    let obj = value.as_object()?;
    for (key, entry) in obj.iter() {
        // Match the domain exactly or with a ":user" suffix. A bare
        // `starts_with(domain)` check would also accept unrelated domains
        // that merely share the prefix (e.g. "github.community").
        let matches_domain = key == domain
            || key
                .strip_prefix(domain)
                .is_some_and(|rest| rest.starts_with(':'));
        if matches_domain {
            if let Some(token) = entry.get("oauth_token").and_then(|v| v.as_str()) {
                return Some(token.to_string());
            }
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    /// Every fixture in this module queries tokens for the "github.com" domain.
    fn github_token(contents: &str) -> Option<String> {
        extract_oauth_token(contents, "github.com")
    }

    #[test]
    fn test_extract_oauth_token_from_hosts_json() {
        let hosts = r#"{
            "github.com": {
                "oauth_token": "ghu_test_token_12345"
            }
        }"#;
        assert_eq!(github_token(hosts), Some("ghu_test_token_12345".to_string()));
    }

    #[test]
    fn test_extract_oauth_token_with_user_suffix() {
        // Copilot config keys may carry a ":user" suffix after the domain.
        let hosts = r#"{
            "github.com:user": {
                "oauth_token": "ghu_another_token"
            }
        }"#;
        assert_eq!(github_token(hosts), Some("ghu_another_token".to_string()));
    }

    #[test]
    fn test_extract_oauth_token_wrong_domain() {
        let hosts = r#"{
            "gitlab.com": {
                "oauth_token": "some_token"
            }
        }"#;
        assert_eq!(github_token(hosts), None);
    }

    #[test]
    fn test_extract_oauth_token_invalid_json() {
        assert_eq!(github_token("not valid json"), None);
    }

    #[test]
    fn test_extract_oauth_token_missing_oauth_token_field() {
        // A matching domain entry without an `oauth_token` field yields nothing.
        let hosts = r#"{
            "github.com": {
                "user": "testuser"
            }
        }"#;
        assert_eq!(github_token(hosts), None);
    }

    #[test]
    fn test_extract_oauth_token_multiple_entries_picks_first_match() {
        let hosts = r#"{
            "gitlab.com": {
                "oauth_token": "gitlab_token"
            },
            "github.com": {
                "oauth_token": "github_token"
            }
        }"#;
        assert_eq!(github_token(hosts), Some("github_token".to_string()));
    }

    #[gpui::test]
    async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
        // Migration must only run when the Copilot Chat extension itself
        // is being installed.
        cx.update(|cx| {
            migrate_copilot_credentials_if_needed("some-other-extension", cx);
        });
        cx.run_until_parked();

        let stored = cx.read_credentials("extension-llm-copilot-chat:copilot-chat");
        assert!(
            stored.is_none(),
            "Should not create credentials for other extensions"
        );
    }

    // Note: Unlike the other migrations, copilot migration reads from the filesystem
    // (copilot config files), not from the credentials provider. In tests, these files
    // don't exist, so no migration occurs.
    #[gpui::test]
    async fn test_no_credentials_when_no_copilot_config_exists(cx: &mut TestAppContext) {
        cx.update(|cx| {
            migrate_copilot_credentials_if_needed(COPILOT_CHAT_EXTENSION_ID, cx);
        });
        cx.run_until_parked();

        let stored = cx.read_credentials("extension-llm-copilot-chat:copilot-chat");
        assert!(
            stored.is_none(),
            "No credentials should be written when copilot config doesn't exist"
        );
    }
}

View File

@@ -1,6 +1,11 @@
mod anthropic_migration;
mod capability_granter;
mod copilot_migration;
pub mod extension_settings;
mod google_ai_migration;
pub mod headless_host;
mod open_router_migration;
mod openai_migration;
pub mod wasm_host;
#[cfg(test)]
@@ -12,13 +17,14 @@ use async_tar::Archive;
use client::ExtensionProvides;
use client::{Client, ExtensionMetadata, GetExtensionsResponse, proto, telemetry::Telemetry};
use collections::{BTreeMap, BTreeSet, HashSet, btree_map};
pub use extension::ExtensionManifest;
use extension::extension_builder::{CompileExtensionOptions, ExtensionBuilder};
use extension::{
ExtensionContextServerProxy, ExtensionDebugAdapterProviderProxy, ExtensionEvents,
ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageProxy,
ExtensionLanguageServerProxy, ExtensionSlashCommandProxy, ExtensionSnippetProxy,
ExtensionThemeProxy,
ExtensionGrammarProxy, ExtensionHostProxy, ExtensionLanguageModelProviderProxy,
ExtensionLanguageProxy, ExtensionLanguageServerProxy, ExtensionSlashCommandProxy,
ExtensionSnippetProxy, ExtensionThemeProxy,
};
use fs::{Fs, RemoveOptions};
use futures::future::join_all;
@@ -32,8 +38,8 @@ use futures::{
select_biased,
};
use gpui::{
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, Task, WeakEntity,
actions,
App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Global, SharedString, Task,
WeakEntity, actions,
};
use http_client::{AsyncBody, HttpClient, HttpClientWithUrl};
use language::{
@@ -53,15 +59,24 @@ use std::{
cmp::Ordering,
path::{self, Path, PathBuf},
sync::Arc,
time::{Duration, Instant},
time::Duration,
};
use url::Url;
use util::{ResultExt, paths::RemotePathBuf};
use wasm_host::llm_provider::ExtensionLanguageModelProvider;
use wasm_host::{
WasmExtension, WasmHost,
wit::{is_supported_wasm_api_version, wasm_api_version_range},
wit::{LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version, wasm_api_version_range},
};
/// A language-model provider declared by a wasm extension, bundled with the
/// data queried from the extension at load time so registration can happen
/// synchronously later.
struct LlmProviderWithModels {
    // Provider metadata returned by the extension's `llm_providers` call.
    provider_info: LlmProviderInfo,
    // Models returned by the extension's `llm_provider_models` call
    // (empty if the call failed).
    models: Vec<LlmModelInfo>,
    // Initial authentication state queried from the extension;
    // defaults to false if the query failed.
    is_authenticated: bool,
    // Absolute (canonicalized) path to the provider icon, if the provider
    // declared one.
    icon_path: Option<SharedString>,
    // Auth configuration from the extension manifest entry, if declared.
    auth_config: Option<extension::LanguageModelAuthConfig>,
}
pub use extension::{
ExtensionLibraryKind, GrammarManifestEntry, OldExtensionManifest, SchemaVersion,
};
@@ -70,6 +85,79 @@ pub use extension_settings::ExtensionSettings;
pub const RELOAD_DEBOUNCE_DURATION: Duration = Duration::from_millis(200);
const FS_WATCH_LATENCY: Duration = Duration::from_millis(100);
/// Extension IDs that are being migrated from hardcoded LLM providers.
/// For backwards compatibility, if the user has the corresponding env var set,
/// we automatically enable env var reading for these extensions on first install.
const LEGACY_LLM_EXTENSION_IDS: &[&str] = &[
"anthropic",
"copilot-chat",
"google-ai",
"openrouter",
"openai",
];
/// Migrates legacy LLM provider extensions by auto-enabling env var reading
/// if the env var is currently present in the environment.
///
/// This is idempotent: if the provider is already in `allowed_env_var_providers`,
/// we skip. This means if a user explicitly removes it, it will be re-added on
/// next launch if the env var is still set - but that's predictable behavior.
/// Migrates legacy LLM provider extensions by auto-enabling env var reading
/// when the provider's env var is currently set in the environment.
///
/// Idempotent: providers already present in `allowed_env_var_providers` are
/// skipped. A user who removes the entry will see it re-added on next launch
/// while the env var remains set — predictable, if slightly sticky, behavior.
fn migrate_legacy_llm_provider_env_var(manifest: &ExtensionManifest, cx: &mut App) {
    // The migration only applies to the known legacy LLM extensions.
    if !LEGACY_LLM_EXTENSION_IDS.contains(&manifest.id.as_ref()) {
        return;
    }

    for (provider_id, provider_entry) in &manifest.language_model_providers {
        // Skip providers with no auth config or no declared env var.
        let Some(env_var_name) = provider_entry
            .auth
            .as_ref()
            .and_then(|auth| auth.env_var.as_ref())
        else {
            continue;
        };

        // Nothing to do unless the env var is present and non-empty.
        if !std::env::var(env_var_name).is_ok_and(|value| !value.is_empty()) {
            continue;
        }

        let full_provider_id: Arc<str> = format!("{}:{}", manifest.id, provider_id).into();

        // Already opted in via settings? Then leave it alone.
        if ExtensionSettings::get_global(cx)
            .allowed_env_var_providers
            .contains(full_provider_id.as_ref())
        {
            continue;
        }

        // Persist the opt-in so the extension may read the key from the env var.
        settings::update_settings_file(<dyn fs::Fs>::global(cx), cx, {
            let full_provider_id = full_provider_id.clone();
            move |settings, _| {
                let providers = settings
                    .extension
                    .allowed_env_var_providers
                    .get_or_insert_with(Vec::new);
                if !providers
                    .iter()
                    .any(|id| id.as_ref() == full_provider_id.as_ref())
                {
                    providers.push(full_provider_id);
                }
            }
        });
    }
}
/// The current extension [`SchemaVersion`] supported by Zed.
const CURRENT_SCHEMA_VERSION: SchemaVersion = SchemaVersion(1);
@@ -131,6 +219,8 @@ pub struct ExtensionStore {
pub enum ExtensionOperation {
Upgrade,
Install,
/// Auto-install from settings - triggers legacy LLM provider migrations
AutoInstall,
Remove,
}
@@ -606,15 +696,68 @@ impl ExtensionStore {
.extension_index
.extensions
.contains_key(extension_id.as_ref());
!is_already_installed && !SUPPRESSED_EXTENSIONS.contains(&extension_id.as_ref())
let dominated = SUPPRESSED_EXTENSIONS.contains(&extension_id.as_ref());
!is_already_installed && !dominated
})
.cloned()
.collect::<Vec<_>>();
cx.spawn(async move |this, cx| {
for extension_id in extensions_to_install {
// When enabled, this checks if an extension exists locally in the repo's extensions/
// directory and installs it as a dev extension instead of fetching from the registry.
// This is useful for testing auto-installed extensions before they've been published.
// Set to `true` only during local development/testing of new auto-install extensions.
#[cfg(debug_assertions)]
const DEBUG_ALLOW_UNPUBLISHED_AUTO_EXTENSIONS: bool = false;
#[cfg(debug_assertions)]
if DEBUG_ALLOW_UNPUBLISHED_AUTO_EXTENSIONS {
let local_extension_path = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.parent()
.unwrap()
.parent()
.unwrap()
.join("extensions")
.join(extension_id.as_ref());
if local_extension_path.exists() {
// Force-remove existing extension directory if it exists and isn't a symlink
// This handles the case where the extension was previously installed from the registry
if let Some(installed_dir) = this
.update(cx, |this, _cx| this.installed_dir.clone())
.ok()
{
let existing_path = installed_dir.join(extension_id.as_ref());
if existing_path.exists() {
let metadata = std::fs::symlink_metadata(&existing_path);
let is_symlink = metadata.map(|m| m.is_symlink()).unwrap_or(false);
if !is_symlink {
if let Err(e) = std::fs::remove_dir_all(&existing_path) {
log::error!(
"Failed to remove existing extension directory {:?}: {}",
existing_path,
e
);
}
}
}
}
if let Some(task) = this
.update(cx, |this, cx| {
this.install_dev_extension(local_extension_path, cx)
})
.ok()
{
task.await.log_err();
}
continue;
}
}
this.update(cx, |this, cx| {
this.install_latest_extension(extension_id.clone(), cx);
this.auto_install_latest_extension(extension_id.clone(), cx);
})
.ok();
}
@@ -769,7 +912,10 @@ impl ExtensionStore {
this.update(cx, |this, cx| this.reload(Some(extension_id.clone()), cx))?
.await;
if let ExtensionOperation::Install = operation {
if matches!(
operation,
ExtensionOperation::Install | ExtensionOperation::AutoInstall
) {
this.update(cx, |this, cx| {
cx.emit(Event::ExtensionInstalled(extension_id.clone()));
if let Some(events) = ExtensionEvents::try_global(cx)
@@ -779,6 +925,27 @@ impl ExtensionStore {
this.emit(extension::Event::ExtensionInstalled(manifest.clone()), cx)
});
}
// Run legacy LLM provider migrations only for auto-installed extensions
if matches!(operation, ExtensionOperation::AutoInstall) {
if let Some(manifest) = this.extension_manifest_for_id(&extension_id) {
migrate_legacy_llm_provider_env_var(&manifest, cx);
}
copilot_migration::migrate_copilot_credentials_if_needed(&extension_id, cx);
anthropic_migration::migrate_anthropic_credentials_if_needed(
&extension_id,
cx,
);
google_ai_migration::migrate_google_ai_credentials_if_needed(
&extension_id,
cx,
);
openai_migration::migrate_openai_credentials_if_needed(&extension_id, cx);
open_router_migration::migrate_open_router_credentials_if_needed(
&extension_id,
cx,
);
}
})
.ok();
}
@@ -788,8 +955,24 @@ impl ExtensionStore {
}
/// Installs the latest available version of the given extension as a
/// user-initiated install (`ExtensionOperation::Install`).
pub fn install_latest_extension(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
    log::info!("installing extension {extension_id} latest version");
    self.install_latest_extension_with_operation(extension_id, ExtensionOperation::Install, cx);
}
/// Auto-installs an extension. Unlike a user-initiated install, this uses
/// `ExtensionOperation::AutoInstall`, which triggers the legacy LLM provider
/// migrations after installation completes.
fn auto_install_latest_extension(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
    let operation = ExtensionOperation::AutoInstall;
    self.install_latest_extension_with_operation(extension_id, operation, cx);
}
fn install_latest_extension_with_operation(
&mut self,
extension_id: Arc<str>,
operation: ExtensionOperation,
cx: &mut Context<Self>,
) {
let schema_versions = schema_version_range();
let wasm_api_versions = wasm_api_version_range(ReleaseChannel::global(cx));
@@ -812,13 +995,8 @@ impl ExtensionStore {
return;
};
self.install_or_upgrade_extension_at_endpoint(
extension_id,
url,
ExtensionOperation::Install,
cx,
)
.detach_and_log_err(cx);
self.install_or_upgrade_extension_at_endpoint(extension_id, url, operation, cx)
.detach_and_log_err(cx);
}
pub fn upgrade_extension(
@@ -837,7 +1015,6 @@ impl ExtensionStore {
operation: ExtensionOperation,
cx: &mut Context<Self>,
) -> Task<Result<()>> {
log::info!("installing extension {extension_id} {version}");
let Some(url) = self
.http_client
.build_zed_api_url(
@@ -1134,18 +1311,6 @@ impl ExtensionStore {
return Task::ready(());
}
let reload_count = extensions_to_unload
.iter()
.filter(|id| extensions_to_load.contains(id))
.count();
log::info!(
"extensions updated. loading {}, reloading {}, unloading {}",
extensions_to_load.len() - reload_count,
reload_count,
extensions_to_unload.len() - reload_count
);
let extension_ids = extensions_to_load
.iter()
.filter_map(|id| {
@@ -1220,6 +1385,11 @@ impl ExtensionStore {
for command_name in extension.manifest.slash_commands.keys() {
self.proxy.unregister_slash_command(command_name.clone());
}
for provider_id in extension.manifest.language_model_providers.keys() {
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
self.proxy
.unregister_language_model_provider(full_provider_id, cx);
}
}
self.wasm_extensions
@@ -1358,7 +1528,11 @@ impl ExtensionStore {
})
.await;
let mut wasm_extensions = Vec::new();
let mut wasm_extensions: Vec<(
Arc<ExtensionManifest>,
WasmExtension,
Vec<LlmProviderWithModels>,
)> = Vec::new();
for extension in extension_entries {
if extension.manifest.lib.kind.is_none() {
continue;
@@ -1376,7 +1550,122 @@ impl ExtensionStore {
match wasm_extension {
Ok(wasm_extension) => {
wasm_extensions.push((extension.manifest.clone(), wasm_extension))
// Query for LLM providers if the manifest declares any
let mut llm_providers_with_models = Vec::new();
if !extension.manifest.language_model_providers.is_empty() {
let providers_result = wasm_extension
.call(|ext, store| {
async move { ext.call_llm_providers(store).await }.boxed()
})
.await;
if let Ok(Ok(providers)) = providers_result {
for provider_info in providers {
let models_result = wasm_extension
.call({
let provider_id = provider_info.id.clone();
|ext, store| {
async move {
ext.call_llm_provider_models(store, &provider_id)
.await
}
.boxed()
}
})
.await;
let models: Vec<LlmModelInfo> = match models_result {
Ok(Ok(Ok(models))) => models,
Ok(Ok(Err(e))) => {
log::error!(
"Failed to get models for LLM provider {} in extension {}: {}",
provider_info.id,
extension.manifest.id,
e
);
Vec::new()
}
Ok(Err(e)) => {
log::error!(
"Wasm error calling llm_provider_models for {} in extension {}: {:?}",
provider_info.id,
extension.manifest.id,
e
);
Vec::new()
}
Err(e) => {
log::error!(
"Extension call failed for llm_provider_models {} in extension {}: {:?}",
provider_info.id,
extension.manifest.id,
e
);
Vec::new()
}
};
// Query initial authentication state
let is_authenticated = wasm_extension
.call({
let provider_id = provider_info.id.clone();
|ext, store| {
async move {
ext.call_llm_provider_is_authenticated(
store,
&provider_id,
)
.await
}
.boxed()
}
})
.await
.unwrap_or(Ok(false))
.unwrap_or(false);
// Resolve icon path if provided
let icon_path = provider_info.icon.as_ref().map(|icon| {
let icon_file_path = extension_path.join(icon);
// Canonicalize to resolve symlinks (dev extensions are symlinked)
let absolute_icon_path = icon_file_path
.canonicalize()
.unwrap_or(icon_file_path)
.to_string_lossy()
.to_string();
SharedString::from(absolute_icon_path)
});
let provider_id_arc: Arc<str> =
provider_info.id.as_str().into();
let auth_config = extension
.manifest
.language_model_providers
.get(&provider_id_arc)
.and_then(|entry| entry.auth.clone());
llm_providers_with_models.push(LlmProviderWithModels {
provider_info,
models,
is_authenticated,
icon_path,
auth_config,
});
}
} else {
log::error!(
"Failed to get LLM providers from extension {}: {:?}",
extension.manifest.id,
providers_result
);
}
}
wasm_extensions.push((
extension.manifest.clone(),
wasm_extension,
llm_providers_with_models,
))
}
Err(e) => {
log::error!(
@@ -1395,7 +1684,7 @@ impl ExtensionStore {
this.update(cx, |this, cx| {
this.reload_complete_senders.clear();
for (manifest, wasm_extension) in &wasm_extensions {
for (manifest, wasm_extension, llm_providers_with_models) in &wasm_extensions {
let extension = Arc::new(wasm_extension.clone());
for (language_server_id, language_server_config) in &manifest.language_servers {
@@ -1449,9 +1738,41 @@ impl ExtensionStore {
this.proxy
.register_debug_locator(extension.clone(), debug_adapter.clone());
}
// Register LLM providers
for llm_provider in llm_providers_with_models {
let provider_id: Arc<str> =
format!("{}:{}", manifest.id, llm_provider.provider_info.id).into();
let wasm_ext = extension.as_ref().clone();
let pinfo = llm_provider.provider_info.clone();
let mods = llm_provider.models.clone();
let auth = llm_provider.is_authenticated;
let icon = llm_provider.icon_path.clone();
let auth_config = llm_provider.auth_config.clone();
this.proxy.register_language_model_provider(
provider_id.clone(),
Box::new(move |cx: &mut App| {
let provider = Arc::new(ExtensionLanguageModelProvider::new(
wasm_ext, pinfo, mods, auth, icon, auth_config, cx,
));
language_model::LanguageModelRegistry::global(cx).update(
cx,
|registry, cx| {
registry.register_provider(provider, cx);
},
);
}),
cx,
);
}
}
this.wasm_extensions.extend(wasm_extensions);
let wasm_extensions_without_llm: Vec<_> = wasm_extensions
.into_iter()
.map(|(manifest, ext, _)| (manifest, ext))
.collect();
this.wasm_extensions.extend(wasm_extensions_without_llm);
this.proxy.set_extensions_loaded();
this.proxy.reload_current_theme(cx);
this.proxy.reload_current_icon_theme(cx);
@@ -1473,7 +1794,6 @@ impl ExtensionStore {
let index_path = self.index_path.clone();
let proxy = self.proxy.clone();
cx.background_spawn(async move {
let start_time = Instant::now();
let mut index = ExtensionIndex::default();
fs.create_dir(&work_dir).await.log_err();
@@ -1511,7 +1831,6 @@ impl ExtensionStore {
.log_err();
}
log::info!("rebuilt extension index in {:?}", start_time.elapsed());
index
})
}
@@ -1785,11 +2104,6 @@ impl ExtensionStore {
})?,
path_style,
);
log::info!(
"Uploading extension {} to {:?}",
missing_extension.clone().id,
dest_dir
);
client
.update(cx, |client, cx| {
@@ -1797,11 +2111,6 @@ impl ExtensionStore {
})?
.await?;
log::info!(
"Finished uploading extension {}",
missing_extension.clone().id
);
let result = client
.update(cx, |client, _cx| {
client.proto_client().request(proto::InstallExtension {

View File

@@ -1,4 +1,4 @@
use collections::HashMap;
use collections::{HashMap, HashSet};
use extension::{
DownloadFileCapability, ExtensionCapability, NpmInstallPackageCapability, ProcessExecCapability,
};
@@ -16,6 +16,10 @@ pub struct ExtensionSettings {
pub auto_install_extensions: HashMap<Arc<str>, bool>,
pub auto_update_extensions: HashMap<Arc<str>, bool>,
pub granted_capabilities: Vec<ExtensionCapability>,
/// The extension language model providers that are allowed to read API keys
/// from environment variables. Each entry is a provider ID in the format
/// "extension_id:provider_id".
pub allowed_env_var_providers: HashSet<Arc<str>>,
}
impl ExtensionSettings {
@@ -60,6 +64,13 @@ impl Settings for ExtensionSettings {
}
})
.collect(),
allowed_env_var_providers: content
.extension
.allowed_env_var_providers
.clone()
.unwrap_or_default()
.into_iter()
.collect(),
}
}
}

View File

@@ -165,6 +165,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}),
dev: false,
},
@@ -196,6 +197,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}),
dev: false,
},
@@ -376,6 +378,7 @@ async fn test_extension_store(cx: &mut TestAppContext) {
capabilities: Vec::new(),
debug_adapters: Default::default(),
debug_locators: Default::default(),
language_model_providers: BTreeMap::default(),
}),
dev: false,
},

View File

@@ -0,0 +1,124 @@
use credentials_provider::CredentialsProvider;
use gpui::App;
const GOOGLE_AI_EXTENSION_ID: &str = "google-ai";
const GOOGLE_AI_PROVIDER_ID: &str = "google-ai";
const GOOGLE_AI_DEFAULT_API_URL: &str = "https://generativelanguage.googleapis.com";
/// Migrates Google AI API credentials from the old built-in provider location
/// to the new extension-based location.
///
/// This should only be called during auto-install of the extension; any other
/// extension id is a no-op.
pub fn migrate_google_ai_credentials_if_needed(extension_id: &str, cx: &mut App) {
    if extension_id != GOOGLE_AI_EXTENSION_ID {
        return;
    }

    let extension_credential_key = format!(
        "extension-llm-{}:{}",
        GOOGLE_AI_EXTENSION_ID, GOOGLE_AI_PROVIDER_ID
    );
    let credentials_provider = <dyn CredentialsProvider>::global(cx);

    cx.spawn(async move |cx| {
        // Look up the credential stored by the old built-in provider.
        let Some((_, key_bytes)) = credentials_provider
            .read_credentials(GOOGLE_AI_DEFAULT_API_URL, &cx)
            .await
            .ok()
            .flatten()
        else {
            log::debug!("No existing Google AI API key found to migrate");
            return;
        };

        let api_key = match String::from_utf8(key_bytes) {
            Ok(key) if !key.is_empty() => key,
            Ok(_) => {
                log::debug!("Existing Google AI API key is empty, nothing to migrate");
                return;
            }
            Err(_) => {
                log::error!("Failed to decode Google AI API key as UTF-8");
                return;
            }
        };

        log::info!("Migrating existing Google AI API key to Google AI extension");

        match credentials_provider
            .write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
            .await
        {
            Ok(()) => {
                log::info!("Successfully migrated Google AI API key to extension");
            }
            Err(err) => {
                log::error!("Failed to migrate Google AI API key: {}", err);
            }
        }
    })
    .detach();
}
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    #[gpui::test]
    async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
        let key = "AIzaSy-test-key-12345";

        // Seed the credential slot used by the old built-in provider.
        cx.write_credentials(GOOGLE_AI_DEFAULT_API_URL, "Bearer", key.as_bytes());

        cx.update(|cx| {
            migrate_google_ai_credentials_if_needed(GOOGLE_AI_EXTENSION_ID, cx);
        });
        cx.run_until_parked();

        // The spawned task should have copied it to the extension location.
        let migrated = cx.read_credentials("extension-llm-google-ai:google-ai");
        assert!(migrated.is_some(), "Credentials should have been migrated");

        let (user, secret) = migrated.unwrap();
        assert_eq!(user, "Bearer");
        assert_eq!(String::from_utf8(secret).unwrap(), key);
    }

    #[gpui::test]
    async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
        // No old credential is seeded, so the migration must be a no-op.
        cx.update(|cx| {
            migrate_google_ai_credentials_if_needed(GOOGLE_AI_EXTENSION_ID, cx);
        });
        cx.run_until_parked();

        let stored = cx.read_credentials("extension-llm-google-ai:google-ai");
        assert!(
            stored.is_none(),
            "Should not create credentials if none existed"
        );
    }

    #[gpui::test]
    async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
        let key = "AIzaSy-test-key";

        // Migration must only run for the Google AI extension itself.
        cx.write_credentials(GOOGLE_AI_DEFAULT_API_URL, "Bearer", key.as_bytes());

        cx.update(|cx| {
            migrate_google_ai_credentials_if_needed("some-other-extension", cx);
        });
        cx.run_until_parked();

        let stored = cx.read_credentials("extension-llm-google-ai:google-ai");
        assert!(
            stored.is_none(),
            "Should not migrate for other extensions"
        );
    }
}

View File

@@ -0,0 +1,124 @@
use credentials_provider::CredentialsProvider;
use gpui::App;
const OPEN_ROUTER_EXTENSION_ID: &str = "openrouter";
const OPEN_ROUTER_PROVIDER_ID: &str = "openrouter";
const OPEN_ROUTER_DEFAULT_API_URL: &str = "https://openrouter.ai/api/v1";
/// Migrates OpenRouter API credentials from the old built-in provider location
/// to the new extension-based location.
///
/// This should only be called during auto-install of the extension; any other
/// extension id is a no-op.
pub fn migrate_open_router_credentials_if_needed(extension_id: &str, cx: &mut App) {
    if extension_id != OPEN_ROUTER_EXTENSION_ID {
        return;
    }

    let extension_credential_key = format!(
        "extension-llm-{}:{}",
        OPEN_ROUTER_EXTENSION_ID, OPEN_ROUTER_PROVIDER_ID
    );
    let credentials_provider = <dyn CredentialsProvider>::global(cx);

    cx.spawn(async move |cx| {
        // Look up the credential stored by the old built-in provider.
        let Some((_, key_bytes)) = credentials_provider
            .read_credentials(OPEN_ROUTER_DEFAULT_API_URL, &cx)
            .await
            .ok()
            .flatten()
        else {
            log::debug!("No existing OpenRouter API key found to migrate");
            return;
        };

        let api_key = match String::from_utf8(key_bytes) {
            Ok(key) if !key.is_empty() => key,
            Ok(_) => {
                log::debug!("Existing OpenRouter API key is empty, nothing to migrate");
                return;
            }
            Err(_) => {
                log::error!("Failed to decode OpenRouter API key as UTF-8");
                return;
            }
        };

        log::info!("Migrating existing OpenRouter API key to OpenRouter extension");

        match credentials_provider
            .write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
            .await
        {
            Ok(()) => {
                log::info!("Successfully migrated OpenRouter API key to extension");
            }
            Err(err) => {
                log::error!("Failed to migrate OpenRouter API key: {}", err);
            }
        }
    })
    .detach();
}
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    #[gpui::test]
    async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
        let key = "sk-or-test-key-12345";

        // Seed the credential slot used by the old built-in provider.
        cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", key.as_bytes());

        cx.update(|cx| {
            migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
        });
        cx.run_until_parked();

        // The spawned task should have copied it to the extension location.
        let migrated = cx.read_credentials("extension-llm-openrouter:openrouter");
        assert!(migrated.is_some(), "Credentials should have been migrated");

        let (user, secret) = migrated.unwrap();
        assert_eq!(user, "Bearer");
        assert_eq!(String::from_utf8(secret).unwrap(), key);
    }

    #[gpui::test]
    async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
        // No old credential is seeded, so the migration must be a no-op.
        cx.update(|cx| {
            migrate_open_router_credentials_if_needed(OPEN_ROUTER_EXTENSION_ID, cx);
        });
        cx.run_until_parked();

        let stored = cx.read_credentials("extension-llm-openrouter:openrouter");
        assert!(
            stored.is_none(),
            "Should not create credentials if none existed"
        );
    }

    #[gpui::test]
    async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
        let key = "sk-or-test-key";

        // Migration must only run for the OpenRouter extension itself.
        cx.write_credentials(OPEN_ROUTER_DEFAULT_API_URL, "Bearer", key.as_bytes());

        cx.update(|cx| {
            migrate_open_router_credentials_if_needed("some-other-extension", cx);
        });
        cx.run_until_parked();

        let stored = cx.read_credentials("extension-llm-openrouter:openrouter");
        assert!(
            stored.is_none(),
            "Should not migrate for other extensions"
        );
    }
}

View File

@@ -0,0 +1,124 @@
use credentials_provider::CredentialsProvider;
use gpui::App;
const OPENAI_EXTENSION_ID: &str = "openai";
const OPENAI_PROVIDER_ID: &str = "openai";
const OPENAI_DEFAULT_API_URL: &str = "https://api.openai.com/v1";
/// Migrates OpenAI API credentials from the old built-in provider location
/// to the new extension-based location.
///
/// This should only be called during auto-install of the extension; any other
/// extension id is a no-op.
pub fn migrate_openai_credentials_if_needed(extension_id: &str, cx: &mut App) {
    if extension_id != OPENAI_EXTENSION_ID {
        return;
    }

    let extension_credential_key = format!(
        "extension-llm-{}:{}",
        OPENAI_EXTENSION_ID, OPENAI_PROVIDER_ID
    );
    let credentials_provider = <dyn CredentialsProvider>::global(cx);

    cx.spawn(async move |cx| {
        // Look up the credential stored by the old built-in provider.
        let Some((_, key_bytes)) = credentials_provider
            .read_credentials(OPENAI_DEFAULT_API_URL, &cx)
            .await
            .ok()
            .flatten()
        else {
            log::debug!("No existing OpenAI API key found to migrate");
            return;
        };

        let api_key = match String::from_utf8(key_bytes) {
            Ok(key) if !key.is_empty() => key,
            Ok(_) => {
                log::debug!("Existing OpenAI API key is empty, nothing to migrate");
                return;
            }
            Err(_) => {
                log::error!("Failed to decode OpenAI API key as UTF-8");
                return;
            }
        };

        log::info!("Migrating existing OpenAI API key to OpenAI extension");

        match credentials_provider
            .write_credentials(&extension_credential_key, "Bearer", api_key.as_bytes(), &cx)
            .await
        {
            Ok(()) => {
                log::info!("Successfully migrated OpenAI API key to extension");
            }
            Err(err) => {
                log::error!("Failed to migrate OpenAI API key: {}", err);
            }
        }
    })
    .detach();
}
#[cfg(test)]
mod tests {
    use super::*;
    use gpui::TestAppContext;

    #[gpui::test]
    async fn test_migrates_credentials_from_old_location(cx: &mut TestAppContext) {
        let key = "sk-test-key-12345";

        // Seed the credential slot used by the old built-in provider.
        cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", key.as_bytes());

        cx.update(|cx| {
            migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
        });
        cx.run_until_parked();

        // The spawned task should have copied it to the extension location.
        let migrated = cx.read_credentials("extension-llm-openai:openai");
        assert!(migrated.is_some(), "Credentials should have been migrated");

        let (user, secret) = migrated.unwrap();
        assert_eq!(user, "Bearer");
        assert_eq!(String::from_utf8(secret).unwrap(), key);
    }

    #[gpui::test]
    async fn test_no_migration_if_no_old_credentials(cx: &mut TestAppContext) {
        // No old credential is seeded, so the migration must be a no-op.
        cx.update(|cx| {
            migrate_openai_credentials_if_needed(OPENAI_EXTENSION_ID, cx);
        });
        cx.run_until_parked();

        let stored = cx.read_credentials("extension-llm-openai:openai");
        assert!(
            stored.is_none(),
            "Should not create credentials if none existed"
        );
    }

    #[gpui::test]
    async fn test_skips_migration_for_other_extensions(cx: &mut TestAppContext) {
        let key = "sk-test-key";

        // Migration must only run for the OpenAI extension itself.
        cx.write_credentials(OPENAI_DEFAULT_API_URL, "Bearer", key.as_bytes());

        cx.update(|cx| {
            migrate_openai_credentials_if_needed("some-other-extension", cx);
        });
        cx.run_until_parked();

        let stored = cx.read_credentials("extension-llm-openai:openai");
        assert!(
            stored.is_none(),
            "Should not migrate for other extensions"
        );
    }
}

View File

@@ -1,9 +1,11 @@
pub mod llm_provider;
pub mod wit;
use crate::capability_granter::CapabilityGranter;
use crate::{ExtensionManifest, ExtensionSettings};
use anyhow::{Context as _, Result, anyhow, bail};
use async_trait::async_trait;
use dap::{DebugRequest, StartDebuggingRequestArgumentsRequest};
use extension::{
CodeLabel, Command, Completion, ContextServerConfiguration, DebugAdapterBinary,
@@ -64,7 +66,7 @@ pub struct WasmHost {
#[derive(Clone, Debug)]
pub struct WasmExtension {
tx: UnboundedSender<ExtensionCall>,
tx: Arc<UnboundedSender<ExtensionCall>>,
pub manifest: Arc<ExtensionManifest>,
pub work_dir: Arc<Path>,
#[allow(unused)]
@@ -74,7 +76,10 @@ pub struct WasmExtension {
impl Drop for WasmExtension {
fn drop(&mut self) {
self.tx.close_channel();
// Only close the channel when this is the last clone holding the sender
if Arc::strong_count(&self.tx) == 1 {
self.tx.close_channel();
}
}
}
@@ -671,7 +676,7 @@ impl WasmHost {
Ok(WasmExtension {
manifest,
work_dir,
tx,
tx: Arc::new(tx),
zed_api_version,
_task: task,
})

File diff suppressed because it is too large Load Diff

View File

@@ -16,7 +16,7 @@ use lsp::LanguageServerName;
use release_channel::ReleaseChannel;
use task::{DebugScenario, SpawnInTerminal, TaskTemplate, ZedDebugConfig};
use crate::wasm_host::wit::since_v0_6_0::dap::StartDebuggingRequestArgumentsRequest;
use crate::wasm_host::wit::since_v0_8_0::dap::StartDebuggingRequestArgumentsRequest;
use super::{WasmState, wasm_engine};
use anyhow::{Context as _, Result, anyhow};
@@ -33,6 +33,19 @@ pub use latest::CodeLabelSpanLiteral;
pub use latest::{
CodeLabel, CodeLabelSpan, Command, DebugAdapterBinary, ExtensionProject, Range, SlashCommand,
zed::extension::context_server::ContextServerConfiguration,
zed::extension::llm_provider::{
CacheConfiguration as LlmCacheConfiguration, CompletionEvent as LlmCompletionEvent,
CompletionRequest as LlmCompletionRequest, CredentialType as LlmCredentialType,
ImageData as LlmImageData, MessageContent as LlmMessageContent,
MessageRole as LlmMessageRole, ModelCapabilities as LlmModelCapabilities,
ModelInfo as LlmModelInfo, ProviderInfo as LlmProviderInfo,
RequestMessage as LlmRequestMessage, StopReason as LlmStopReason,
ThinkingContent as LlmThinkingContent, TokenUsage as LlmTokenUsage,
ToolChoice as LlmToolChoice, ToolDefinition as LlmToolDefinition,
ToolInputFormat as LlmToolInputFormat, ToolResult as LlmToolResult,
ToolResultContent as LlmToolResultContent, ToolUse as LlmToolUse,
ToolUseJsonParseError as LlmToolUseJsonParseError,
},
zed::extension::lsp::{
Completion, CompletionKind, CompletionLabelDetails, InsertTextFormat, Symbol, SymbolKind,
},
@@ -1007,6 +1020,20 @@ impl Extension {
resource: Resource<Arc<dyn WorktreeDelegate>>,
) -> Result<Result<DebugAdapterBinary, String>> {
match self {
Extension::V0_8_0(ext) => {
let dap_binary = ext
.call_get_dap_binary(
store,
&adapter_name,
&task.try_into()?,
user_installed_path.as_ref().and_then(|p| p.to_str()),
resource,
)
.await?
.map_err(|e| anyhow!("{e:?}"))?;
Ok(Ok(dap_binary))
}
Extension::V0_6_0(ext) => {
let dap_binary = ext
.call_get_dap_binary(
@@ -1032,6 +1059,16 @@ impl Extension {
config: serde_json::Value,
) -> Result<Result<StartDebuggingRequestArgumentsRequest, String>> {
match self {
Extension::V0_8_0(ext) => {
let config =
serde_json::to_string(&config).context("Adapter config is not a valid JSON")?;
let result = ext
.call_dap_request_kind(store, &adapter_name, &config)
.await?
.map_err(|e| anyhow!("{e:?}"))?;
Ok(Ok(result))
}
Extension::V0_6_0(ext) => {
let config =
serde_json::to_string(&config).context("Adapter config is not a valid JSON")?;
@@ -1052,6 +1089,15 @@ impl Extension {
config: ZedDebugConfig,
) -> Result<Result<DebugScenario, String>> {
match self {
Extension::V0_8_0(ext) => {
let config = config.into();
let result = ext
.call_dap_config_to_scenario(store, &config)
.await?
.map_err(|e| anyhow!("{e:?}"))?;
Ok(Ok(result.try_into()?))
}
Extension::V0_6_0(ext) => {
let config = config.into();
let dap_binary = ext
@@ -1074,6 +1120,20 @@ impl Extension {
debug_adapter_name: String,
) -> Result<Option<DebugScenario>> {
match self {
Extension::V0_8_0(ext) => {
let build_config_template = build_config_template.into();
let result = ext
.call_dap_locator_create_scenario(
store,
&locator_name,
&build_config_template,
&resolved_label,
&debug_adapter_name,
)
.await?;
Ok(result.map(TryInto::try_into).transpose()?)
}
Extension::V0_6_0(ext) => {
let build_config_template = build_config_template.into();
let dap_binary = ext
@@ -1099,6 +1159,15 @@ impl Extension {
resolved_build_task: SpawnInTerminal,
) -> Result<Result<DebugRequest, String>> {
match self {
Extension::V0_8_0(ext) => {
let build_config_template = resolved_build_task.try_into()?;
let dap_request = ext
.call_run_dap_locator(store, &locator_name, &build_config_template)
.await?
.map_err(|e| anyhow!("{e:?}"))?;
Ok(Ok(dap_request.into()))
}
Extension::V0_6_0(ext) => {
let build_config_template = resolved_build_task.try_into()?;
let dap_request = ext
@@ -1111,6 +1180,174 @@ impl Extension {
_ => anyhow::bail!("`dap_locator_create_scenario` not available prior to v0.6.0"),
}
}
/// Lists the LLM providers declared by this extension.
///
/// Extensions built against API versions older than v0.8.0 cannot declare
/// LLM providers, so for those we report an empty list rather than an error.
pub async fn call_llm_providers(
    &self,
    store: &mut Store<WasmState>,
) -> Result<Vec<latest::llm_provider::ProviderInfo>> {
    if let Extension::V0_8_0(ext) = self {
        ext.call_llm_providers(store).await
    } else {
        Ok(Vec::new())
    }
}
pub async fn call_llm_provider_models(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
) -> Result<Result<Vec<latest::llm_provider::ModelInfo>, String>> {
match self {
Extension::V0_8_0(ext) => ext.call_llm_provider_models(store, provider_id).await,
_ => anyhow::bail!("`llm_provider_models` not available prior to v0.8.0"),
}
}
/// Returns optional provider-specific settings documentation as markdown.
///
/// Older extension API versions have no such hook, so they yield `None`.
pub async fn call_llm_provider_settings_markdown(
    &self,
    store: &mut Store<WasmState>,
    provider_id: &str,
) -> Result<Option<String>> {
    if let Extension::V0_8_0(ext) = self {
        ext.call_llm_provider_settings_markdown(store, provider_id)
            .await
    } else {
        Ok(None)
    }
}
/// Asks the extension whether the given provider currently has credentials.
///
/// Pre-v0.8.0 extensions have no LLM support, so they are never authenticated.
pub async fn call_llm_provider_is_authenticated(
    &self,
    store: &mut Store<WasmState>,
    provider_id: &str,
) -> Result<bool> {
    if let Extension::V0_8_0(ext) = self {
        ext.call_llm_provider_is_authenticated(store, provider_id)
            .await
    } else {
        Ok(false)
    }
}
pub async fn call_llm_provider_start_device_flow_sign_in(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
) -> Result<Result<String, String>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_provider_start_device_flow_sign_in(store, provider_id)
.await
}
_ => {
anyhow::bail!(
"`llm_provider_start_device_flow_sign_in` not available prior to v0.8.0"
)
}
}
}
pub async fn call_llm_provider_poll_device_flow_sign_in(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
) -> Result<Result<(), String>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_provider_poll_device_flow_sign_in(store, provider_id)
.await
}
_ => {
anyhow::bail!(
"`llm_provider_poll_device_flow_sign_in` not available prior to v0.8.0"
)
}
}
}
pub async fn call_llm_provider_reset_credentials(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
) -> Result<Result<(), String>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_provider_reset_credentials(store, provider_id)
.await
}
_ => anyhow::bail!("`llm_provider_reset_credentials` not available prior to v0.8.0"),
}
}
pub async fn call_llm_count_tokens(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
model_id: &str,
request: &latest::llm_provider::CompletionRequest,
) -> Result<Result<u64, String>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_count_tokens(store, provider_id, model_id, request)
.await
}
_ => anyhow::bail!("`llm_count_tokens` not available prior to v0.8.0"),
}
}
pub async fn call_llm_stream_completion_start(
&self,
store: &mut Store<WasmState>,
provider_id: &str,
model_id: &str,
request: &latest::llm_provider::CompletionRequest,
) -> Result<Result<String, String>> {
match self {
Extension::V0_8_0(ext) => {
ext.call_llm_stream_completion_start(store, provider_id, model_id, request)
.await
}
_ => anyhow::bail!("`llm_stream_completion_start` not available prior to v0.8.0"),
}
}
pub async fn call_llm_stream_completion_next(
&self,
store: &mut Store<WasmState>,
stream_id: &str,
) -> Result<Result<Option<latest::llm_provider::CompletionEvent>, String>> {
match self {
Extension::V0_8_0(ext) => ext.call_llm_stream_completion_next(store, stream_id).await,
_ => anyhow::bail!("`llm_stream_completion_next` not available prior to v0.8.0"),
}
}
pub async fn call_llm_stream_completion_close(
&self,
store: &mut Store<WasmState>,
stream_id: &str,
) -> Result<()> {
match self {
Extension::V0_8_0(ext) => ext.call_llm_stream_completion_close(store, stream_id).await,
_ => anyhow::bail!("`llm_stream_completion_close` not available prior to v0.8.0"),
}
}
/// Returns the prompt-caching configuration for a provider/model, if the
/// extension declares one. Older API versions never cache, hence `None`.
pub async fn call_llm_cache_configuration(
    &self,
    store: &mut Store<WasmState>,
    provider_id: &str,
    model_id: &str,
) -> Result<Option<latest::llm_provider::CacheConfiguration>> {
    if let Extension::V0_8_0(ext) = self {
        ext.call_llm_cache_configuration(store, provider_id, model_id)
            .await
    } else {
        Ok(None)
    }
}
}
trait ToWasmtimeResult<T> {

View File

@@ -32,8 +32,6 @@ wasmtime::component::bindgen!({
},
});
pub use self::zed::extension::*;
mod settings {
#![allow(dead_code)]
include!(concat!(env!("OUT_DIR"), "/since_v0.6.0/settings.rs"));

View File

@@ -1,11 +1,11 @@
use crate::wasm_host::wit::since_v0_6_0::{
use crate::wasm_host::wit::since_v0_8_0::{
dap::{
AttachRequest, BuildTaskDefinition, BuildTaskDefinitionTemplatePayload, LaunchRequest,
StartDebuggingRequestArguments, TcpArguments, TcpArgumentsTemplate,
},
lsp::{CompletionKind, CompletionLabelDetails, InsertTextFormat, SymbolKind},
slash_command::SlashCommandOutputSection,
};
use crate::wasm_host::wit::{CompletionKind, CompletionLabelDetails, InsertTextFormat, SymbolKind};
use crate::wasm_host::{WasmState, wit::ToWasmtimeResult};
use ::http_client::{AsyncBody, HttpRequestExt};
use ::settings::{Settings, WorktreeId};
@@ -13,6 +13,7 @@ use anyhow::{Context as _, Result, bail};
use async_compression::futures::bufread::GzipDecoder;
use async_tar::Archive;
use async_trait::async_trait;
use credentials_provider::CredentialsProvider;
use extension::{
ExtensionLanguageServerProxy, KeyValueStoreDelegate, ProjectDelegate, WorktreeDelegate,
};
@@ -22,12 +23,14 @@ use gpui::{BackgroundExecutor, SharedString};
use language::{BinaryStatus, LanguageName, language_settings::AllLanguageSettings};
use project::project_settings::ProjectSettings;
use semver::Version;
use smol::net::TcpListener;
use std::{
env,
net::Ipv4Addr,
path::{Path, PathBuf},
str::FromStr,
sync::{Arc, OnceLock},
time::Duration,
};
use task::{SpawnInTerminal, ZedDebugConfig};
use url::Url;
@@ -1107,3 +1110,361 @@ impl ExtensionImports for WasmState {
.to_wasmtime_result()
}
}
// Host-side implementation of the `llm-provider` WIT imports that extensions
// call into: credential storage, gated env-var access, and OAuth helpers.
impl llm_provider::Host for WasmState {
// Stub: interactive credential prompting is not implemented yet.
async fn request_credential(
&mut self,
_provider_id: String,
_credential_type: llm_provider::CredentialType,
_label: String,
_placeholder: String,
) -> wasmtime::Result<Result<bool, String>> {
// For now, credential requests return false (not provided)
// Extensions should use get_env_var to check for env vars first,
// then store_credential/get_credential for manual storage
// Full UI credential prompting will be added in a future phase
Ok(Ok(false))
}
// Resolves a provider's credential: first from the provider's declared env
// var (only if the user allow-listed this provider in settings), then from
// the system credential store under "extension-llm-<ext>:<provider>".
async fn get_credential(&mut self, provider_id: String) -> wasmtime::Result<Option<String>> {
let extension_id = self.manifest.id.clone();
// Check if this provider has an env var configured and if the user has allowed it
let env_var_name = self
.manifest
.language_model_providers
.get(&Arc::<str>::from(provider_id.as_str()))
.and_then(|entry| entry.auth.as_ref())
.and_then(|auth| auth.env_var.clone());
if let Some(env_var_name) = env_var_name {
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
// Read settings dynamically to get current allowed_env_var_providers
let is_allowed = self
.on_main_thread({
let full_provider_id = full_provider_id.clone();
move |cx| {
async move {
cx.update(|cx| {
crate::extension_settings::ExtensionSettings::get_global(cx)
.allowed_env_var_providers
.contains(&full_provider_id)
})
}
.boxed_local()
}
})
.await
.unwrap_or(false);
// Empty env-var values are treated as absent.
if is_allowed {
if let Ok(value) = env::var(&env_var_name) {
if !value.is_empty() {
return Ok(Some(value));
}
}
}
}
// Fall back to credential store
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
self.on_main_thread(move |cx| {
async move {
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
// Read failures are deliberately swallowed and reported as "no credential".
let result = credentials_provider
.read_credentials(&credential_key, cx)
.await
.ok()
.flatten();
Ok(result.map(|(_, password)| String::from_utf8_lossy(&password).to_string()))
}
.boxed_local()
})
.await
}
// Persists a credential for this extension's provider in the system
// credential store, under a fixed "api_key" username.
async fn store_credential(
&mut self,
provider_id: String,
value: String,
) -> wasmtime::Result<Result<(), String>> {
let extension_id = self.manifest.id.clone();
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
self.on_main_thread(move |cx| {
async move {
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
credentials_provider
.write_credentials(&credential_key, "api_key", value.as_bytes(), cx)
.await
.map_err(|e| anyhow::anyhow!("{}", e))
}
.boxed_local()
})
.await
.to_wasmtime_result()
}
// Removes the stored credential for this extension's provider, if any.
async fn delete_credential(
&mut self,
provider_id: String,
) -> wasmtime::Result<Result<(), String>> {
let extension_id = self.manifest.id.clone();
let credential_key = format!("extension-llm-{}:{}", extension_id, provider_id);
self.on_main_thread(move |cx| {
async move {
let credentials_provider = cx.update(|cx| <dyn CredentialsProvider>::global(cx))?;
credentials_provider
.delete_credentials(&credential_key, cx)
.await
.map_err(|e| anyhow::anyhow!("{}", e))
}
.boxed_local()
})
.await
.to_wasmtime_result()
}
// Gated env-var read: an extension may only read an env var that one of its
// manifest-declared providers names in its auth config, and only when the
// user has allow-listed that provider. All denials return `None`.
async fn get_env_var(&mut self, name: String) -> wasmtime::Result<Option<String>> {
let extension_id = self.manifest.id.clone();
// Find which provider (if any) declares this env var in its auth config
let mut allowed_provider_id: Option<Arc<str>> = None;
for (provider_id, provider_entry) in &self.manifest.language_model_providers {
if let Some(auth_config) = &provider_entry.auth {
if auth_config.env_var.as_deref() == Some(&name) {
allowed_provider_id = Some(provider_id.clone());
break;
}
}
}
// If no provider declares this env var, deny access
let Some(provider_id) = allowed_provider_id else {
log::warn!(
"Extension {} attempted to read env var {} which is not declared in any provider auth config",
extension_id,
name
);
return Ok(None);
};
// Check if the user has allowed this provider to read env vars
// Read settings dynamically to get current allowed_env_var_providers
let full_provider_id: Arc<str> = format!("{}:{}", extension_id, provider_id).into();
let is_allowed = self
.on_main_thread({
let full_provider_id = full_provider_id.clone();
move |cx| {
async move {
cx.update(|cx| {
crate::extension_settings::ExtensionSettings::get_global(cx)
.allowed_env_var_providers
.contains(&full_provider_id)
})
}
.boxed_local()
}
})
.await
.unwrap_or(false);
if !is_allowed {
log::debug!(
"Extension {} provider {} is not allowed to read env var {}",
extension_id,
provider_id,
name
);
return Ok(None);
}
Ok(env::var(&name).ok())
}
// Runs a loopback-server OAuth flow: binds an ephemeral localhost port,
// opens the browser at `auth_url` (with "{port}" substituted), and waits —
// bounded by `timeout_secs`, default 300 — for the provider to redirect the
// browser to `callback_path`, returning the full callback URL.
async fn oauth_start_web_auth(
&mut self,
config: llm_provider::OauthWebAuthConfig,
) -> wasmtime::Result<Result<llm_provider::OauthWebAuthResult, String>> {
let auth_url = config.auth_url;
let callback_path = config.callback_path;
let timeout_secs = config.timeout_secs.unwrap_or(300);
self.on_main_thread(move |cx| {
async move {
// Port 0 asks the OS for any free port.
let listener = TcpListener::bind("127.0.0.1:0")
.await
.map_err(|e| anyhow::anyhow!("Failed to bind localhost server: {}", e))?;
let port = listener
.local_addr()
.map_err(|e| anyhow::anyhow!("Failed to get local address: {}", e))?
.port();
let auth_url_with_port = auth_url.replace("{port}", &port.to_string());
cx.update(|cx| {
cx.open_url(&auth_url_with_port);
})?;
let accept_future = async {
let (mut stream, _) = listener
.accept()
.await
.map_err(|e| anyhow::anyhow!("Failed to accept connection: {}", e))?;
// NOTE(review): only the first request line is read before responding;
// assumes the browser delivers the callback as a simple GET — confirm
// if any provider uses a POST-based redirect.
let mut request_line = String::new();
{
let mut reader = smol::io::BufReader::new(&mut stream);
smol::io::AsyncBufReadExt::read_line(&mut reader, &mut request_line)
.await
.map_err(|e| anyhow::anyhow!("Failed to read request: {}", e))?;
}
// Extract the path from "METHOD <path> HTTP/1.1" and verify it matches
// the expected callback path (with or without a leading slash).
let callback_url = if let Some(path_start) = request_line.find(' ') {
if let Some(path_end) = request_line[path_start + 1..].find(' ') {
let path = &request_line[path_start + 1..path_start + 1 + path_end];
if path.starts_with(&callback_path) || path.starts_with(&format!("/{}", callback_path.trim_start_matches('/'))) {
format!("http://localhost:{}{}", port, path)
} else {
return Err(anyhow::anyhow!(
"Unexpected callback path: {}",
path
));
}
} else {
return Err(anyhow::anyhow!("Malformed HTTP request"));
}
} else {
return Err(anyhow::anyhow!("Malformed HTTP request"));
};
// Minimal success page shown in the user's browser; write errors are
// ignored since the callback URL has already been captured.
let response = "HTTP/1.1 200 OK\r\n\
Content-Type: text/html\r\n\
Connection: close\r\n\
\r\n\
<!DOCTYPE html>\
<html><head><title>Authentication Complete</title></head>\
<body style=\"font-family: system-ui, sans-serif; display: flex; justify-content: center; align-items: center; height: 100vh; margin: 0;\">\
<div style=\"text-align: center;\">\
<h1>Authentication Complete</h1>\
<p>You can close this window and return to Zed.</p>\
</div></body></html>";
smol::io::AsyncWriteExt::write_all(&mut stream, response.as_bytes())
.await
.ok();
smol::io::AsyncWriteExt::flush(&mut stream).await.ok();
Ok(callback_url)
};
// Race the accept against a timeout; whichever resolves first wins.
let timeout_duration = Duration::from_secs(timeout_secs as u64);
let callback_url = smol::future::or(
accept_future,
async {
smol::Timer::after(timeout_duration).await;
Err(anyhow::anyhow!(
"OAuth callback timed out after {} seconds",
timeout_secs
))
},
)
.await?;
Ok(llm_provider::OauthWebAuthResult {
callback_url,
port: port as u32,
})
}
.boxed_local()
})
.await
.to_wasmtime_result()
}
// Performs an HTTP request on behalf of the extension (token exchange etc.)
// using the host's HTTP client; supports GET/POST/PUT/DELETE/PATCH only.
async fn send_oauth_http_request(
&mut self,
request: llm_provider::OauthHttpRequest,
) -> wasmtime::Result<Result<llm_provider::OauthHttpResponse, String>> {
let http_client = self.host.http_client.clone();
self.on_main_thread(move |_cx| {
async move {
let method = match request.method.to_uppercase().as_str() {
"GET" => ::http_client::Method::GET,
"POST" => ::http_client::Method::POST,
"PUT" => ::http_client::Method::PUT,
"DELETE" => ::http_client::Method::DELETE,
"PATCH" => ::http_client::Method::PATCH,
_ => {
return Err(anyhow::anyhow!(
"Unsupported HTTP method: {}",
request.method
));
}
};
let mut builder = ::http_client::Request::builder()
.method(method)
.uri(&request.url);
for (key, value) in &request.headers {
builder = builder.header(key.as_str(), value.as_str());
}
let body = if request.body.is_empty() {
AsyncBody::empty()
} else {
AsyncBody::from(request.body.into_bytes())
};
let http_request = builder
.body(body)
.map_err(|e| anyhow::anyhow!("Failed to build request: {}", e))?;
let mut response = http_client
.send(http_request)
.await
.map_err(|e| anyhow::anyhow!("HTTP request failed: {}", e))?;
let status = response.status().as_u16();
// Non-UTF-8 header values are passed through as empty strings.
let headers: Vec<(String, String)> = response
.headers()
.iter()
.map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
.collect();
let mut body_bytes = Vec::new();
futures::AsyncReadExt::read_to_end(response.body_mut(), &mut body_bytes)
.await
.map_err(|e| anyhow::anyhow!("Failed to read response body: {}", e))?;
// Body is lossily decoded; binary responses will be mangled.
let body = String::from_utf8_lossy(&body_bytes).to_string();
Ok(llm_provider::OauthHttpResponse {
status,
headers,
body,
})
}
.boxed_local()
})
.await
.to_wasmtime_result()
}
// Opens a URL in the user's default browser via the platform.
async fn oauth_open_browser(&mut self, url: String) -> wasmtime::Result<Result<(), String>> {
self.on_main_thread(move |cx| {
async move {
cx.update(|cx| {
cx.open_url(&url);
})?;
Ok(())
}
.boxed_local()
})
.await
.to_wasmtime_result()
}
}

View File

@@ -442,7 +442,9 @@ impl ExtensionsPage {
let extension_store = ExtensionStore::global(cx).read(cx);
match extension_store.outstanding_operations().get(extension_id) {
Some(ExtensionOperation::Install) => ExtensionStatus::Installing,
Some(ExtensionOperation::Install) | Some(ExtensionOperation::AutoInstall) => {
ExtensionStatus::Installing
}
Some(ExtensionOperation::Remove) => ExtensionStatus::Removing,
Some(ExtensionOperation::Upgrade) => ExtensionStatus::Upgrading,
None => match extension_store.installed_extensions().get(extension_id) {

View File

@@ -296,6 +296,20 @@ impl TestAppContext {
&self.text_system
}
/// Simulates storing a credential entry in the fake platform keychain.
pub fn write_credentials(&self, url: &str, username: &str, password: &[u8]) {
    // TestPlatform resolves this task immediately, so the returned Task can
    // be dropped without losing the write.
    drop(self.test_platform.write_credentials(url, username, password));
}
/// Simulates reading a credential entry back from the fake platform keychain.
pub fn read_credentials(&self, url: &str) -> Option<(String, Vec<u8>)> {
    let read_task = self.test_platform.read_credentials(url);
    // Block synchronously; the fake platform answers without real I/O.
    smol::block_on(read_task).ok().flatten()
}
/// Simulates writing to the platform clipboard
pub fn write_to_clipboard(&self, item: ClipboardItem) {
self.test_platform.write_to_clipboard(item)

View File

@@ -6,7 +6,7 @@ use crate::{
TestDisplay, TestWindow, WindowAppearance, WindowParams, size,
};
use anyhow::Result;
use collections::VecDeque;
use collections::{HashMap, VecDeque};
use futures::channel::oneshot;
use parking_lot::Mutex;
use std::{
@@ -32,6 +32,7 @@ pub(crate) struct TestPlatform {
current_clipboard_item: Mutex<Option<ClipboardItem>>,
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
current_primary_item: Mutex<Option<ClipboardItem>>,
credentials: Mutex<HashMap<String, (String, Vec<u8>)>>,
pub(crate) prompts: RefCell<TestPrompts>,
screen_capture_sources: RefCell<Vec<TestScreenCaptureSource>>,
pub opened_url: RefCell<Option<String>>,
@@ -117,6 +118,7 @@ impl TestPlatform {
current_clipboard_item: Mutex::new(None),
#[cfg(any(target_os = "linux", target_os = "freebsd"))]
current_primary_item: Mutex::new(None),
credentials: Mutex::new(HashMap::default()),
weak: weak.clone(),
opened_url: Default::default(),
#[cfg(target_os = "windows")]
@@ -416,15 +418,20 @@ impl Platform for TestPlatform {
self.current_clipboard_item.lock().clone()
}
fn write_credentials(&self, _url: &str, _username: &str, _password: &[u8]) -> Task<Result<()>> {
fn write_credentials(&self, url: &str, username: &str, password: &[u8]) -> Task<Result<()>> {
self.credentials
.lock()
.insert(url.to_string(), (username.to_string(), password.to_vec()));
Task::ready(Ok(()))
}
fn read_credentials(&self, _url: &str) -> Task<Result<Option<(String, Vec<u8>)>>> {
Task::ready(Ok(None))
fn read_credentials(&self, url: &str) -> Task<Result<Option<(String, Vec<u8>)>>> {
let result = self.credentials.lock().get(url).cloned();
Task::ready(Ok(result))
}
fn delete_credentials(&self, _url: &str) -> Task<Result<()>> {
fn delete_credentials(&self, url: &str) -> Task<Result<()>> {
self.credentials.lock().remove(url);
Task::ready(Ok(()))
}

View File

@@ -780,6 +780,11 @@ pub trait LanguageModelProvider: 'static {
fn icon(&self) -> IconName {
IconName::ZedAssistant
}
/// Returns the path to an external SVG icon for this provider, if any.
/// When present, this takes precedence over `icon()`.
fn icon_path(&self) -> Option<SharedString> {
None
}
fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>>;
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>>;

View File

@@ -2,12 +2,16 @@ use crate::{
LanguageModel, LanguageModelId, LanguageModelProvider, LanguageModelProviderId,
LanguageModelProviderState,
};
use collections::BTreeMap;
use collections::{BTreeMap, HashSet};
use gpui::{App, Context, Entity, EventEmitter, Global, prelude::*};
use std::{str::FromStr, sync::Arc};
use thiserror::Error;
use util::maybe;
/// Function type for checking if a built-in provider should be hidden.
/// Returns Some(extension_id) if the provider should be hidden when that extension is installed.
pub type BuiltinProviderHidingFn = Box<dyn Fn(&str) -> Option<&'static str> + Send + Sync>;
pub fn init(cx: &mut App) {
let registry = cx.new(|_cx| LanguageModelRegistry::default());
cx.set_global(GlobalLanguageModelRegistry(registry));
@@ -48,6 +52,11 @@ pub struct LanguageModelRegistry {
thread_summary_model: Option<ConfiguredModel>,
providers: BTreeMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
inline_alternatives: Vec<Arc<dyn LanguageModel>>,
/// Set of installed extension IDs that provide language models.
/// Used to determine which built-in providers should be hidden.
installed_llm_extension_ids: HashSet<Arc<str>>,
/// Function to check if a built-in provider should be hidden by an extension.
builtin_provider_hiding_fn: Option<BuiltinProviderHidingFn>,
}
#[derive(Debug)]
@@ -104,6 +113,8 @@ pub enum Event {
ProviderStateChanged(LanguageModelProviderId),
AddedProvider(LanguageModelProviderId),
RemovedProvider(LanguageModelProviderId),
/// Emitted when provider visibility changes due to extension install/uninstall.
ProvidersChanged,
}
impl EventEmitter<Event> for LanguageModelRegistry {}
@@ -183,6 +194,60 @@ impl LanguageModelRegistry {
providers
}
/// Returns the registered providers with hidden built-ins filtered out.
pub fn visible_providers(&self) -> Vec<Arc<dyn LanguageModelProvider>> {
    let mut visible = self.providers();
    visible.retain(|provider| !self.should_hide_provider(&provider.id()));
    visible
}
/// Sets the function used to check if a built-in provider should be hidden.
// Until this is set, `should_hide_provider` always returns false.
pub fn set_builtin_provider_hiding_fn(&mut self, hiding_fn: BuiltinProviderHidingFn) {
self.builtin_provider_hiding_fn = Some(hiding_fn);
}
/// Called when an extension is installed/loaded.
/// If the extension provides language models, track it so we can hide the corresponding built-in.
pub fn extension_installed(&mut self, extension_id: Arc<str>, cx: &mut Context<Self>) {
    // `insert` reports whether the id was newly added; only broadcast on a real change.
    let newly_tracked = self.installed_llm_extension_ids.insert(extension_id);
    if newly_tracked {
        cx.emit(Event::ProvidersChanged);
        cx.notify();
    }
}
/// Called when an extension is uninstalled/unloaded.
pub fn extension_uninstalled(&mut self, extension_id: &str, cx: &mut Context<Self>) {
    // `remove` is true only when the id was actually tracked; avoids spurious events.
    let was_tracked = self.installed_llm_extension_ids.remove(extension_id);
    if was_tracked {
        cx.emit(Event::ProvidersChanged);
        cx.notify();
    }
}
/// Replaces the tracked set of installed LLM-extension ids, notifying
/// observers only when the set actually changed.
pub fn sync_installed_llm_extensions(
    &mut self,
    extension_ids: HashSet<Arc<str>>,
    cx: &mut Context<Self>,
) {
    // Guard clause: identical sets mean nothing to do.
    if self.installed_llm_extension_ids == extension_ids {
        return;
    }
    self.installed_llm_extension_ids = extension_ids;
    cx.emit(Event::ProvidersChanged);
    cx.notify();
}
/// Returns true if a provider should be hidden from the UI.
/// Built-in providers are hidden when their corresponding extension is installed.
pub fn should_hide_provider(&self, provider_id: &LanguageModelProviderId) -> bool {
    // No hiding policy installed, or the policy maps this provider to no
    // extension, means the provider stays visible.
    self.builtin_provider_hiding_fn
        .as_ref()
        .and_then(|hiding_fn| hiding_fn(&provider_id.0))
        .is_some_and(|extension_id| self.installed_llm_extension_ids.contains(extension_id))
}
pub fn configuration_error(
&self,
model: Option<ConfiguredModel>,
@@ -416,4 +481,151 @@ mod tests {
let providers = registry.read(cx).providers();
assert!(providers.is_empty());
}
#[gpui::test]
// Installing the matching extension hides the built-in provider from
// `visible_providers()` while keeping it registered in `providers()`.
fn test_provider_hiding_on_extension_install(cx: &mut App) {
let registry = cx.new(|_| LanguageModelRegistry::default());
let provider = Arc::new(FakeLanguageModelProvider::default());
let provider_id = provider.id();
registry.update(cx, |registry, cx| {
registry.register_provider(provider.clone(), cx);
// Set up a hiding function that hides the fake provider when "fake-extension" is installed
registry.set_builtin_provider_hiding_fn(Box::new(|id| {
if id == "fake" {
Some("fake-extension")
} else {
None
}
}));
});
// Provider should be visible initially
let visible = registry.read(cx).visible_providers();
assert_eq!(visible.len(), 1);
assert_eq!(visible[0].id(), provider_id);
// Install the extension
registry.update(cx, |registry, cx| {
registry.extension_installed("fake-extension".into(), cx);
});
// Provider should now be hidden
let visible = registry.read(cx).visible_providers();
assert!(visible.is_empty());
// But still in providers()
let all = registry.read(cx).providers();
assert_eq!(all.len(), 1);
}
#[gpui::test]
// Uninstalling the extension restores the previously hidden built-in
// provider to visibility.
fn test_provider_unhiding_on_extension_uninstall(cx: &mut App) {
let registry = cx.new(|_| LanguageModelRegistry::default());
let provider = Arc::new(FakeLanguageModelProvider::default());
let provider_id = provider.id();
registry.update(cx, |registry, cx| {
registry.register_provider(provider.clone(), cx);
// Set up hiding function
registry.set_builtin_provider_hiding_fn(Box::new(|id| {
if id == "fake" {
Some("fake-extension")
} else {
None
}
}));
// Start with extension installed
registry.extension_installed("fake-extension".into(), cx);
});
// Provider should be hidden
let visible = registry.read(cx).visible_providers();
assert!(visible.is_empty());
// Uninstall the extension
registry.update(cx, |registry, cx| {
registry.extension_uninstalled("fake-extension", cx);
});
// Provider should now be visible again
let visible = registry.read(cx).visible_providers();
assert_eq!(visible.len(), 1);
assert_eq!(visible[0].id(), provider_id);
}
#[gpui::test]
// `should_hide_provider` only hides providers whose mapped extension is
// actually installed; unmapped or not-installed cases stay visible.
fn test_should_hide_provider(cx: &mut App) {
let registry = cx.new(|_| LanguageModelRegistry::default());
registry.update(cx, |registry, cx| {
// Set up hiding function
registry.set_builtin_provider_hiding_fn(Box::new(|id| {
if id == "anthropic" {
Some("anthropic")
} else if id == "openai" {
Some("openai")
} else {
None
}
}));
// Install only anthropic extension
registry.extension_installed("anthropic".into(), cx);
});
let registry_read = registry.read(cx);
// Anthropic should be hidden
assert!(registry_read.should_hide_provider(&LanguageModelProviderId("anthropic".into())));
// OpenAI should not be hidden (extension not installed)
assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("openai".into())));
// Unknown provider should not be hidden
assert!(!registry_read.should_hide_provider(&LanguageModelProviderId("unknown".into())));
}
#[gpui::test]
// `sync_installed_llm_extensions` replaces the whole tracked set: syncing in
// the extension hides the provider, syncing an empty set restores it.
fn test_sync_installed_llm_extensions(cx: &mut App) {
let registry = cx.new(|_| LanguageModelRegistry::default());
let provider = Arc::new(FakeLanguageModelProvider::default());
registry.update(cx, |registry, cx| {
registry.register_provider(provider.clone(), cx);
registry.set_builtin_provider_hiding_fn(Box::new(|id| {
if id == "fake" {
Some("fake-extension")
} else {
None
}
}));
});
// Sync with a set containing the extension
let mut extension_ids = HashSet::default();
extension_ids.insert(Arc::from("fake-extension"));
registry.update(cx, |registry, cx| {
registry.sync_installed_llm_extensions(extension_ids, cx);
});
// Provider should be hidden
assert!(registry.read(cx).visible_providers().is_empty());
// Sync with empty set
registry.update(cx, |registry, cx| {
registry.sync_installed_llm_extensions(HashSet::default(), cx);
});
// Provider should be visible again
assert_eq!(registry.read(cx).visible_providers().len(), 1);
}
}

View File

@@ -28,6 +28,8 @@ convert_case.workspace = true
copilot.workspace = true
credentials_provider.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
extension.workspace = true
extension_host.workspace = true
fs.workspace = true
futures.workspace = true
google_ai = { workspace = true, features = ["schemars"] }

View File

@@ -223,10 +223,6 @@ impl ApiKeyState {
}
impl ApiKey {
pub fn key(&self) -> &str {
&self.key
}
pub fn from_env(env_var_name: SharedString, key: &str) -> Self {
Self {
source: ApiKeySource::EnvVar(env_var_name),
@@ -234,16 +230,6 @@ impl ApiKey {
}
}
pub async fn load_from_system_keychain(
url: &str,
credentials_provider: &dyn CredentialsProvider,
cx: &AsyncApp,
) -> Result<Self, AuthenticateError> {
Self::load_from_system_keychain_impl(url, credentials_provider, cx)
.await
.into_authenticate_result()
}
async fn load_from_system_keychain_impl(
url: &str,
credentials_provider: &dyn CredentialsProvider,

View File

@@ -0,0 +1,68 @@
use ::extension::{
ExtensionHostProxy, ExtensionLanguageModelProviderProxy, LanguageModelProviderRegistration,
};
use collections::HashMap;
use gpui::{App, Entity};
use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use std::sync::{Arc, LazyLock};
/// Maps built-in provider IDs to their corresponding extension IDs.
/// When an extension with this ID is installed, the built-in provider should be hidden.
static BUILTIN_TO_EXTENSION_MAP: LazyLock<HashMap<&'static str, &'static str>> =
    LazyLock::new(|| {
        [
            ("anthropic", "anthropic"),
            ("openai", "openai"),
            ("google", "google-ai"),
            ("openrouter", "openrouter"),
            ("copilot_chat", "copilot-chat"),
        ]
        .into_iter()
        .collect()
    });
/// Returns the extension ID that should hide the given built-in provider.
///
/// Looks `provider_id` up in [`BUILTIN_TO_EXTENSION_MAP`]; returns `None` for
/// built-in providers that have no corresponding extension.
pub fn extension_for_builtin_provider(provider_id: &str) -> Option<&'static str> {
    BUILTIN_TO_EXTENSION_MAP.get(provider_id).copied()
}
/// Proxy that registers extension language model providers with the LanguageModelRegistry.
pub struct LanguageModelProviderRegistryProxy {
    // Registry that extension-provided language model providers are
    // unregistered from (registration goes through the extension's own
    // registration closure; see the trait impl below).
    registry: Entity<LanguageModelRegistry>,
}
impl LanguageModelProviderRegistryProxy {
    /// Creates a proxy that forwards provider lifecycle events to `registry`.
    pub fn new(registry: Entity<LanguageModelRegistry>) -> Self {
        Self { registry }
    }
}
impl ExtensionLanguageModelProviderProxy for LanguageModelProviderRegistryProxy {
    fn register_language_model_provider(
        &self,
        _provider_id: Arc<str>,
        register_fn: LanguageModelProviderRegistration,
        cx: &mut App,
    ) {
        // The registration closure performs the actual registration itself,
        // which is why the provider ID is unused here.
        register_fn(cx);
    }

    fn unregister_language_model_provider(&self, provider_id: Arc<str>, cx: &mut App) {
        // Unlike registration, removal goes directly through the registry.
        self.registry.update(cx, |registry, cx| {
            registry.unregister_provider(LanguageModelProviderId::from(provider_id), cx);
        });
    }
}
/// Initialize the extension language model provider proxy.
/// This must be called BEFORE extension_host::init to ensure the proxy is available
/// when extensions try to register their language model providers.
pub fn init_proxy(cx: &mut App) {
    let registry = LanguageModelRegistry::global(cx);

    // Install the predicate that decides which built-in providers are hidden
    // when a matching extension is installed.
    registry.update(cx, |registry, _cx| {
        registry.set_builtin_provider_hiding_fn(Box::new(extension_for_builtin_provider));
    });

    ExtensionHostProxy::default_global(cx)
        .register_language_model_provider_proxy(LanguageModelProviderRegistryProxy::new(registry));
}

View File

@@ -0,0 +1,43 @@
use anyhow::Result;
use credentials_provider::CredentialsProvider;
use gpui::{App, Task};
const GEMINI_API_KEY_VAR_NAME: &str = "GEMINI_API_KEY";
const GOOGLE_AI_API_KEY_VAR_NAME: &str = "GOOGLE_AI_API_KEY";
const GOOGLE_AI_EXTENSION_CREDENTIAL_KEY: &str = "extension-llm-google-ai:google-ai";
/// Returns the Google AI API key for use by the Gemini CLI.
///
/// This function checks the following sources in order:
/// 1. `GEMINI_API_KEY` environment variable
/// 2. `GOOGLE_AI_API_KEY` environment variable
/// 3. Extension credential store (`extension-llm-google-ai:google-ai`)
pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
    // Environment variables win over stored credentials; empty values are
    // treated as unset.
    for var_name in [GEMINI_API_KEY_VAR_NAME, GOOGLE_AI_API_KEY_VAR_NAME] {
        if let Ok(key) = std::env::var(var_name) {
            if !key.is_empty() {
                return Task::ready(Ok(key));
            }
        }
    }

    // Fall back to the credential store used by the Google AI extension.
    let credentials_provider = <dyn CredentialsProvider>::global(cx);
    cx.spawn(async move |cx| {
        match credentials_provider
            .read_credentials(GOOGLE_AI_EXTENSION_CREDENTIAL_KEY, &cx)
            .await?
        {
            Some((_, key_bytes)) => Ok(String::from_utf8(key_bytes)?),
            None => Err(anyhow::anyhow!("No Google AI API key found")),
        }
    })
}

View File

@@ -8,10 +8,14 @@ use language_model::{LanguageModelProviderId, LanguageModelRegistry};
use provider::deepseek::DeepSeekLanguageModelProvider;
mod api_key;
pub mod extension;
mod google_ai_api_key;
pub mod provider;
mod settings;
pub mod ui;
pub use crate::extension::init_proxy as init_extension_proxy;
pub use crate::google_ai_api_key::api_key_for_gemini_cli;
use crate::provider::anthropic::AnthropicLanguageModelProvider;
use crate::provider::bedrock::BedrockLanguageModelProvider;
use crate::provider::cloud::CloudLanguageModelProvider;
@@ -33,6 +37,61 @@ pub fn init(user_store: Entity<UserStore>, client: Arc<Client>, cx: &mut App) {
register_language_model_providers(registry, user_store, client.clone(), cx);
});
// Subscribe to extension store events to track LLM extension installations
if let Some(extension_store) = extension_host::ExtensionStore::try_global(cx) {
cx.subscribe(&extension_store, {
let registry = registry.clone();
move |extension_store, event, cx| {
match event {
extension_host::Event::ExtensionInstalled(extension_id) => {
// Check if this extension has language_model_providers
if let Some(manifest) = extension_store
.read(cx)
.extension_manifest_for_id(extension_id)
{
if !manifest.language_model_providers.is_empty() {
registry.update(cx, |registry, cx| {
registry.extension_installed(extension_id.clone(), cx);
});
}
}
}
extension_host::Event::ExtensionUninstalled(extension_id) => {
registry.update(cx, |registry, cx| {
registry.extension_uninstalled(extension_id, cx);
});
}
extension_host::Event::ExtensionsUpdated => {
// Re-sync installed extensions on bulk updates
let mut new_ids = HashSet::default();
for (extension_id, entry) in extension_store.read(cx).installed_extensions()
{
if !entry.manifest.language_model_providers.is_empty() {
new_ids.insert(extension_id.clone());
}
}
registry.update(cx, |registry, cx| {
registry.sync_installed_llm_extensions(new_ids, cx);
});
}
_ => {}
}
}
})
.detach();
// Initialize with currently installed extensions
registry.update(cx, |registry, cx| {
let mut initial_ids = HashSet::default();
for (extension_id, entry) in extension_store.read(cx).installed_extensions() {
if !entry.manifest.language_model_providers.is_empty() {
initial_ids.insert(extension_id.clone());
}
}
registry.sync_installed_llm_extensions(initial_ids, cx);
});
}
let mut openai_compatible_providers = AllLanguageModelSettings::get_global(cx)
.openai_compatible
.keys()

View File

@@ -1,6 +1,5 @@
use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use credentials_provider::CredentialsProvider;
use futures::{FutureExt, Stream, StreamExt, future, future::BoxFuture};
use google_ai::{
FunctionDeclaration, GenerateContentResponse, GoogleModelMode, Part, SystemInstruction,
@@ -33,7 +32,6 @@ use ui_input::InputField;
use util::ResultExt;
use zed_env_vars::EnvVar;
use crate::api_key::ApiKey;
use crate::api_key::ApiKeyState;
use crate::ui::{ConfiguredApiCard, InstructionListItem};
@@ -128,22 +126,6 @@ impl GoogleLanguageModelProvider {
})
}
pub fn api_key_for_gemini_cli(cx: &mut App) -> Task<Result<String>> {
if let Some(key) = API_KEY_ENV_VAR.value.clone() {
return Task::ready(Ok(key));
}
let credentials_provider = <dyn CredentialsProvider>::global(cx);
let api_url = Self::api_url(cx).to_string();
cx.spawn(async move |cx| {
Ok(
ApiKey::load_from_system_keychain(&api_url, credentials_provider.as_ref(), cx)
.await?
.key()
.to_string(),
)
})
}
fn settings(cx: &App) -> &GoogleSettings {
&crate::AllLanguageModelSettings::get_global(cx).google
}

View File

@@ -255,7 +255,6 @@ impl JsonSchema for LanguageModelProviderSetting {
"type": "string",
"enum": [
"amazon-bedrock",
"anthropic",
"copilot_chat",
"deepseek",
"google",

View File

@@ -20,6 +20,12 @@ pub struct ExtensionSettingsContent {
pub auto_update_extensions: HashMap<Arc<str>, bool>,
/// The capabilities granted to extensions.
pub granted_extension_capabilities: Option<Vec<ExtensionCapabilityContent>>,
/// Extension language model providers that are allowed to read API keys from
/// environment variables. Each entry is a provider ID in the format
/// "extension_id:provider_id" (e.g., "openai:openai").
///
/// Default: []
pub allowed_env_var_providers: Option<Vec<Arc<str>>>,
}
/// A capability for an extension.

View File

@@ -126,17 +126,6 @@ enum IconSource {
ExternalSvg(SharedString),
}
impl IconSource {
fn from_path(path: impl Into<SharedString>) -> Self {
let path = path.into();
if path.starts_with("icons/") {
Self::Embedded(path)
} else {
Self::External(Arc::from(PathBuf::from(path.as_ref())))
}
}
}
#[derive(IntoElement, RegisterComponent)]
pub struct Icon {
source: IconSource,
@@ -155,9 +144,18 @@ impl Icon {
}
}
/// Create an icon from a path. Uses a heuristic to determine if it's embedded or external:
/// - Paths starting with "icons/" are treated as embedded SVGs
/// - Other paths are treated as external raster images (from icon themes)
pub fn from_path(path: impl Into<SharedString>) -> Self {
let path = path.into();
let source = if path.starts_with("icons/") {
IconSource::Embedded(path)
} else {
IconSource::External(Arc::from(PathBuf::from(path.as_ref())))
};
Self {
source: IconSource::from_path(path),
source,
color: Color::default(),
size: IconSize::default().rems(),
transformation: Transformation::default(),

View File

@@ -555,6 +555,11 @@ pub fn main() {
dap_adapters::init(cx);
auto_update_ui::init(cx);
reliability::init(client.clone(), cx);
// Initialize the language model registry first, then set up the extension proxy
// BEFORE extension_host::init so that extensions can register their LLM providers
// when they load.
language_model::init(app_state.client.clone(), cx);
language_models::init_extension_proxy(cx);
extension_host::init(
extension_host_proxy.clone(),
app_state.fs.clone(),
@@ -580,7 +585,6 @@ pub fn main() {
cx,
);
supermaven::init(app_state.client.clone(), cx);
language_model::init(app_state.client.clone(), cx);
language_models::init(app_state.user_store.clone(), app_state.client.clone(), cx);
acp_tools::init(cx);
edit_prediction_ui::init(cx);

View File

@@ -2626,9 +2626,6 @@ These values take in the same options as the root-level settings with the same n
```json [settings]
{
"language_models": {
"anthropic": {
"api_url": "https://api.anthropic.com"
},
"google": {
"api_url": "https://generativelanguage.googleapis.com"
},

823
extensions/anthropic/Cargo.lock generated Normal file
View File

@@ -0,0 +1,823 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "adler2"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
[[package]]
name = "anthropic"
version = "0.1.0"
dependencies = [
"serde",
"serde_json",
"zed_extension_api",
]
[[package]]
name = "anyhow"
version = "1.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
[[package]]
name = "auditable-serde"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c7bf8143dfc3c0258df908843e169b5cc5fcf76c7718bd66135ef4a9cd558c5"
dependencies = [
"semver",
"serde",
"serde_json",
"topological-sort",
]
[[package]]
name = "bitflags"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
[[package]]
name = "cfg-if"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "crc32fast"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511"
dependencies = [
"cfg-if",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "equivalent"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "flate2"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "foldhash"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "form_urlencoded"
version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf"
dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
name = "futures-core"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7"
[[package]]
name = "futures-task"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-util"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
]
[[package]]
name = "hashbrown"
version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"foldhash",
]
[[package]]
name = "hashbrown"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "icu_collections"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43"
dependencies = [
"displaydoc",
"potential_utf",
"yoke",
"zerofrom",
"zerovec",
]
[[package]]
name = "icu_locale_core"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6"
dependencies = [
"displaydoc",
"litemap",
"tinystr",
"writeable",
"zerovec",
]
[[package]]
name = "icu_normalizer"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599"
dependencies = [
"icu_collections",
"icu_normalizer_data",
"icu_properties",
"icu_provider",
"smallvec",
"zerovec",
]
[[package]]
name = "icu_normalizer_data"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a"
[[package]]
name = "icu_properties"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99"
dependencies = [
"icu_collections",
"icu_locale_core",
"icu_properties_data",
"icu_provider",
"zerotrie",
"zerovec",
]
[[package]]
name = "icu_properties_data"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899"
[[package]]
name = "icu_provider"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614"
dependencies = [
"displaydoc",
"icu_locale_core",
"writeable",
"yoke",
"zerofrom",
"zerotrie",
"zerovec",
]
[[package]]
name = "id-arena"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005"
[[package]]
name = "idna"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de"
dependencies = [
"idna_adapter",
"smallvec",
"utf8_iter",
]
[[package]]
name = "idna_adapter"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344"
dependencies = [
"icu_normalizer",
"icu_properties",
]
[[package]]
name = "indexmap"
version = "2.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2"
dependencies = [
"equivalent",
"hashbrown 0.16.1",
"serde",
"serde_core",
]
[[package]]
name = "itoa"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "leb128fmt"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2"
[[package]]
name = "litemap"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
[[package]]
name = "log"
version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]]
name = "memchr"
version = "2.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
[[package]]
name = "miniz_oxide"
version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316"
dependencies = [
"adler2",
"simd-adler32",
]
[[package]]
name = "once_cell"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "percent-encoding"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220"
[[package]]
name = "pin-project-lite"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "potential_utf"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77"
dependencies = [
"zerovec",
]
[[package]]
name = "prettyplease"
version = "0.2.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
dependencies = [
"proc-macro2",
"syn",
]
[[package]]
name = "proc-macro2"
version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
dependencies = [
"proc-macro2",
]
[[package]]
name = "ryu"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "semver"
version = "1.0.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2"
dependencies = [
"serde",
"serde_core",
]
[[package]]
name = "serde"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
dependencies = [
"serde_core",
"serde_derive",
]
[[package]]
name = "serde_core"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.145"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
"serde_core",
]
[[package]]
name = "simd-adler32"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
[[package]]
name = "slab"
version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589"
[[package]]
name = "smallvec"
version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
[[package]]
name = "spdx"
version = "0.10.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3e17e880bafaeb362a7b751ec46bdc5b61445a188f80e0606e68167cd540fa3"
dependencies = [
"smallvec",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
[[package]]
name = "syn"
version = "2.0.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "synstructure"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tinystr"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869"
dependencies = [
"displaydoc",
"zerovec",
]
[[package]]
name = "topological-sort"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea68304e134ecd095ac6c3574494fc62b909f416c4fca77e440530221e549d3d"
[[package]]
name = "unicode-ident"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "url"
version = "2.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b"
dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
"serde",
]
[[package]]
name = "utf8_iter"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "wasm-encoder"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80bb72f02e7fbf07183443b27b0f3d4144abf8c114189f2e088ed95b696a7822"
dependencies = [
"leb128fmt",
"wasmparser",
]
[[package]]
name = "wasm-metadata"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce1ef0faabbbba6674e97a56bee857ccddf942785a336c8b47b42373c922a91d"
dependencies = [
"anyhow",
"auditable-serde",
"flate2",
"indexmap",
"serde",
"serde_derive",
"serde_json",
"spdx",
"url",
"wasm-encoder",
"wasmparser",
]
[[package]]
name = "wasmparser"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f51cad774fb3c9461ab9bccc9c62dfb7388397b5deda31bf40e8108ccd678b2"
dependencies = [
"bitflags",
"hashbrown 0.15.5",
"indexmap",
"semver",
]
[[package]]
name = "wit-bindgen"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10fb6648689b3929d56bbc7eb1acf70c9a42a29eb5358c67c10f54dbd5d695de"
dependencies = [
"wit-bindgen-rt",
"wit-bindgen-rust-macro",
]
[[package]]
name = "wit-bindgen-core"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92fa781d4f2ff6d3f27f3cc9b74a73327b31ca0dc4a3ef25a0ce2983e0e5af9b"
dependencies = [
"anyhow",
"heck",
"wit-parser",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4db52a11d4dfb0a59f194c064055794ee6564eb1ced88c25da2cf76e50c5621"
dependencies = [
"bitflags",
"futures",
"once_cell",
]
[[package]]
name = "wit-bindgen-rust"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d0809dc5ba19e2e98661bf32fc0addc5a3ca5bf3a6a7083aa6ba484085ff3ce"
dependencies = [
"anyhow",
"heck",
"indexmap",
"prettyplease",
"syn",
"wasm-metadata",
"wit-bindgen-core",
"wit-component",
]
[[package]]
name = "wit-bindgen-rust-macro"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad19eec017904e04c60719592a803ee5da76cb51c81e3f6fbf9457f59db49799"
dependencies = [
"anyhow",
"prettyplease",
"proc-macro2",
"quote",
"syn",
"wit-bindgen-core",
"wit-bindgen-rust",
]
[[package]]
name = "wit-component"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "635c3adc595422cbf2341a17fb73a319669cc8d33deed3a48368a841df86b676"
dependencies = [
"anyhow",
"bitflags",
"indexmap",
"log",
"serde",
"serde_derive",
"serde_json",
"wasm-encoder",
"wasm-metadata",
"wasmparser",
"wit-parser",
]
[[package]]
name = "wit-parser"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddf445ed5157046e4baf56f9138c124a0824d4d1657e7204d71886ad8ce2fc11"
dependencies = [
"anyhow",
"id-arena",
"indexmap",
"log",
"semver",
"serde",
"serde_derive",
"serde_json",
"unicode-xid",
"wasmparser",
]
[[package]]
name = "writeable"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
[[package]]
name = "yoke"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954"
dependencies = [
"stable_deref_trait",
"yoke-derive",
"zerofrom",
]
[[package]]
name = "yoke-derive"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]]
name = "zed_extension_api"
version = "0.8.0"
dependencies = [
"serde",
"serde_json",
"wit-bindgen",
]
[[package]]
name = "zerofrom"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5"
dependencies = [
"zerofrom-derive",
]
[[package]]
name = "zerofrom-derive"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]]
name = "zerotrie"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851"
dependencies = [
"displaydoc",
"yoke",
"zerofrom",
]
[[package]]
name = "zerovec"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002"
dependencies = [
"yoke",
"zerofrom",
"zerovec-derive",
]
[[package]]
name = "zerovec-derive"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]

View File

@@ -0,0 +1,17 @@
[package]
name = "anthropic"
version = "0.1.0"
edition = "2021"
# Built and distributed as a Zed extension, never published to crates.io.
publish = false
license = "Apache-2.0"

# Empty workspace table opts this crate out of any enclosing Cargo workspace,
# so the extension builds standalone with its own lockfile.
[workspace]

[lib]
path = "src/anthropic.rs"
# cdylib produces the dynamically loadable artifact Zed's extension host expects.
crate-type = ["cdylib"]

[dependencies]
zed_extension_api = { path = "../../crates/extension_api" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

View File

@@ -0,0 +1,13 @@
id = "anthropic"
name = "Anthropic"
description = "Anthropic Claude LLM provider for Zed."
version = "0.1.0"
schema_version = 1
authors = ["Zed Team"]
repository = "https://github.com/zed-industries/zed"

# Declares the language model provider this extension contributes.
[language_model_providers.anthropic]
name = "Anthropic"

# Auth configuration: the environment variable the provider's API key may be
# read from (subject to the user's `allowed_env_var_providers` setting).
[language_model_providers.anthropic.auth]
env_var = "ANTHROPIC_API_KEY"

View File

@@ -0,0 +1,11 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<g clip-path="url(#clip0_1896_18)">
<path d="M11.094 3.09999H8.952L12.858 12.9H15L11.094 3.09999Z" fill="black"/>
<path d="M4.906 3.09999L1 12.9H3.184L3.98284 10.842H8.06915L8.868 12.9H11.052L7.146 3.09999H4.906ZM4.68928 9.02199L6.026 5.57799L7.3627 9.02199H4.68928Z" fill="black"/>
</g>
<defs>
<clipPath id="clip0_1896_18">
<rect width="14" height="9.8" fill="white" transform="translate(1 3.09999)"/>
</clipPath>
</defs>
</svg>

After

Width:  |  Height:  |  Size: 526 B

View File

@@ -0,0 +1,754 @@
use std::collections::HashMap;
use std::sync::Mutex;
use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};
/// Extension-side state for the Anthropic language model provider.
struct AnthropicProvider {
    // In-flight streaming completions, keyed by stream ID.
    streams: Mutex<HashMap<String, StreamState>>,
    // Counter used to derive the next stream ID — presumably incremented per
    // new stream; allocation site not visible in this chunk.
    next_stream_id: Mutex<u64>,
}
/// Per-stream state for an in-flight completion response.
struct StreamState {
    // Underlying HTTP response stream; `None` once unavailable/consumed —
    // consumption site not visible in this chunk.
    response_stream: Option<HttpResponseStream>,
    // Partially received response text awaiting parsing.
    buffer: String,
    // Whether the stream has begun emitting output.
    started: bool,
    // Tool-use block currently being accumulated, if any.
    current_tool_use: Option<ToolUseState>,
    // Stop reason reported by the model, once known.
    stop_reason: Option<LlmStopReason>,
    // NOTE(review): purpose inferred from name — a signature received but not
    // yet attached to its content block; confirm against the parsing code.
    pending_signature: Option<String>,
}
/// A tool invocation being accumulated from the response stream.
struct ToolUseState {
    // Tool-use identifier from the model.
    id: String,
    // Name of the tool being invoked.
    name: String,
    // Accumulated JSON input for the tool call (streamed incrementally —
    // accumulation site not visible in this chunk).
    input_json: String,
}
/// Static description of one entry in Zed's model picker.
///
/// `display_name` doubles as the model id exposed to Zed (see
/// `llm_provider_models`); `real_id` is the Anthropic API model identifier.
/// Several picker entries share one `real_id` and differ only in whether
/// extended thinking is enabled.
struct ModelDefinition {
    real_id: &'static str,
    display_name: &'static str,
    // Context window size, in tokens.
    max_tokens: u64,
    // Maximum tokens the model may generate per response.
    max_output_tokens: u64,
    supports_images: bool,
    supports_thinking: bool,
    // Preselected as the default model in the picker.
    is_default: bool,
    // Preselected as the default "fast" model.
    is_default_fast: bool,
}

/// The fixed model catalog offered by this provider. Each "Thinking" entry
/// reuses the same API model id with `supports_thinking` enabled.
const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "claude-opus-4-5-20251101",
        display_name: "Claude Opus 4.5",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-opus-4-5-20251101",
        display_name: "Claude Opus 4.5 Thinking",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-sonnet-4-5-20250929",
        display_name: "Claude Sonnet 4.5",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: false,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-sonnet-4-5-20250929",
        display_name: "Claude Sonnet 4.5 Thinking",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-sonnet-4-20250514",
        display_name: "Claude Sonnet 4",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-sonnet-4-20250514",
        display_name: "Claude Sonnet 4 Thinking",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-haiku-4-5-20251001",
        display_name: "Claude Haiku 4.5",
        max_tokens: 200_000,
        max_output_tokens: 64_000,
        supports_images: true,
        supports_thinking: false,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "claude-haiku-4-5-20251001",
        display_name: "Claude Haiku 4.5 Thinking",
        max_tokens: 200_000,
        max_output_tokens: 64_000,
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-3-5-sonnet-latest",
        display_name: "Claude 3.5 Sonnet",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "claude-3-5-haiku-latest",
        display_name: "Claude 3.5 Haiku",
        max_tokens: 200_000,
        max_output_tokens: 8_192,
        supports_images: true,
        supports_thinking: false,
        is_default: false,
        is_default_fast: false,
    },
];
/// Looks up a catalog entry by its display name, which is also the model id
/// Zed hands back to this extension. Returns `None` for unknown names.
fn get_model_definition(display_name: &str) -> Option<&'static ModelDefinition> {
    for model in MODELS {
        if model.display_name == display_name {
            return Some(model);
        }
    }
    None
}
// Anthropic API Request Types
//
// Serde mirrors of the `/v1/messages` request body. Field names and serde
// attributes must match the Anthropic wire format exactly.

/// Top-level body of a `POST /v1/messages` call.
#[derive(Serialize)]
struct AnthropicRequest {
    model: String,
    max_tokens: u64,
    messages: Vec<AnthropicMessage>,
    // Anthropic takes the system prompt as a top-level field, not a message.
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking: Option<AnthropicThinking>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<AnthropicTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<AnthropicToolChoice>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop_sequences: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    stream: bool,
}

/// Extended-thinking configuration; `thinking_type` serializes as `"type"`.
#[derive(Serialize)]
struct AnthropicThinking {
    #[serde(rename = "type")]
    thinking_type: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    budget_tokens: Option<u32>,
}

/// One conversation turn: role ("user"/"assistant") plus content blocks.
#[derive(Serialize)]
struct AnthropicMessage {
    role: String,
    content: Vec<AnthropicContent>,
}

/// A single content block; the serde tag becomes the wire-format `"type"`.
#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum AnthropicContent {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "thinking")]
    Thinking { thinking: String, signature: String },
    #[serde(rename = "redacted_thinking")]
    RedactedThinking { data: String },
    #[serde(rename = "image")]
    Image { source: AnthropicImageSource },
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        is_error: bool,
        content: String,
    },
}

/// Base64-encoded image payload (`source_type` serializes as `"type"`).
#[derive(Serialize, Clone)]
struct AnthropicImageSource {
    #[serde(rename = "type")]
    source_type: String,
    media_type: String,
    data: String,
}

/// A tool the model may call; `input_schema` is a JSON Schema value.
#[derive(Serialize)]
struct AnthropicTool {
    name: String,
    description: String,
    input_schema: serde_json::Value,
}

/// Tool-choice policy; serializes as `{"type": "auto" | "any" | "none"}`.
#[derive(Serialize)]
#[serde(tag = "type", rename_all = "lowercase")]
enum AnthropicToolChoice {
    Auto,
    Any,
    None,
}
// Anthropic API Response Types
//
// Serde mirrors of the server-sent events emitted by the streaming
// `/v1/messages` endpoint. Unknown event types fail to deserialize and are
// skipped by `parse_sse_line`.

/// One SSE event from the streaming messages endpoint.
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
#[allow(dead_code)]
enum AnthropicEvent {
    #[serde(rename = "message_start")]
    MessageStart { message: AnthropicMessageResponse },
    #[serde(rename = "content_block_start")]
    ContentBlockStart {
        index: usize,
        content_block: AnthropicContentBlock,
    },
    #[serde(rename = "content_block_delta")]
    ContentBlockDelta { index: usize, delta: AnthropicDelta },
    #[serde(rename = "content_block_stop")]
    ContentBlockStop { index: usize },
    #[serde(rename = "message_delta")]
    MessageDelta {
        delta: AnthropicMessageDelta,
        usage: AnthropicUsage,
    },
    #[serde(rename = "message_stop")]
    MessageStop,
    #[serde(rename = "ping")]
    Ping,
    #[serde(rename = "error")]
    Error { error: AnthropicApiError },
}

/// Payload of `message_start`; only `usage` is read downstream.
#[derive(Deserialize, Debug)]
struct AnthropicMessageResponse {
    #[allow(dead_code)]
    id: String,
    #[allow(dead_code)]
    role: String,
    #[serde(default)]
    usage: AnthropicUsage,
}

/// Initial payload of a content block (`content_block_start`).
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
enum AnthropicContentBlock {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "thinking")]
    Thinking { thinking: String },
    #[serde(rename = "redacted_thinking")]
    RedactedThinking { data: String },
    #[serde(rename = "tool_use")]
    ToolUse { id: String, name: String },
}

/// Incremental update to a content block (`content_block_delta`).
#[derive(Deserialize, Debug)]
#[serde(tag = "type")]
enum AnthropicDelta {
    #[serde(rename = "text_delta")]
    TextDelta { text: String },
    #[serde(rename = "thinking_delta")]
    ThinkingDelta { thinking: String },
    #[serde(rename = "signature_delta")]
    SignatureDelta { signature: String },
    #[serde(rename = "input_json_delta")]
    InputJsonDelta { partial_json: String },
}

/// Message-level delta; carries the stop reason near end of stream.
#[derive(Deserialize, Debug)]
struct AnthropicMessageDelta {
    stop_reason: Option<String>,
}

/// Token accounting; all fields optional since events report them piecemeal.
#[derive(Deserialize, Debug, Default)]
struct AnthropicUsage {
    #[serde(default)]
    input_tokens: Option<u64>,
    #[serde(default)]
    output_tokens: Option<u64>,
    #[serde(default)]
    cache_creation_input_tokens: Option<u64>,
    #[serde(default)]
    cache_read_input_tokens: Option<u64>,
}

/// Error payload of an `error` event; only `message` is surfaced to Zed.
#[derive(Deserialize, Debug)]
struct AnthropicApiError {
    #[serde(rename = "type")]
    #[allow(dead_code)]
    error_type: String,
    message: String,
}
/// Translates Zed's provider-agnostic `LlmCompletionRequest` into an
/// Anthropic `/v1/messages` request body for the model named `model_id`.
///
/// `model_id` is the display name used as the model id in
/// `llm_provider_models`. System messages are merged (newline-joined) into
/// the top-level `system` field; user and assistant messages become
/// content-block arrays. Errors only on an unknown model id.
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<AnthropicRequest, String> {
    let model_def =
        get_model_definition(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;
    let mut messages: Vec<AnthropicMessage> = Vec::new();
    let mut system_message = String::new();
    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                // Anthropic takes the system prompt as a separate field, not a
                // chat message; concatenate all system texts with newlines.
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !system_message.is_empty() {
                            system_message.push('\n');
                        }
                        system_message.push_str(text);
                    }
                }
            }
            LlmMessageRole::User => {
                let mut contents: Vec<AnthropicContent> = Vec::new();
                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            // Empty text blocks are dropped.
                            if !text.is_empty() {
                                contents.push(AnthropicContent::Text { text: text.clone() });
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            // NOTE(review): media type is hard-coded to PNG —
                            // confirm upstream always supplies PNG data.
                            contents.push(AnthropicContent::Image {
                                source: AnthropicImageSource {
                                    source_type: "base64".to_string(),
                                    media_type: "image/png".to_string(),
                                    data: img.source.clone(),
                                },
                            });
                        }
                        LlmMessageContent::ToolResult(result) => {
                            // Image tool results are flattened to a placeholder
                            // string rather than re-encoded.
                            let content_text = match &result.content {
                                LlmToolResultContent::Text(t) => t.clone(),
                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
                            };
                            contents.push(AnthropicContent::ToolResult {
                                tool_use_id: result.tool_use_id.clone(),
                                is_error: result.is_error,
                                content: content_text,
                            });
                        }
                        // Other content kinds are not representable in a user
                        // message and are silently skipped.
                        _ => {}
                    }
                }
                // Anthropic rejects messages with no content blocks.
                if !contents.is_empty() {
                    messages.push(AnthropicMessage {
                        role: "user".to_string(),
                        content: contents,
                    });
                }
            }
            LlmMessageRole::Assistant => {
                let mut contents: Vec<AnthropicContent> = Vec::new();
                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                contents.push(AnthropicContent::Text { text: text.clone() });
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            // Tool input arrives as a JSON string; fall back to
                            // JSON null if it fails to parse.
                            let input: serde_json::Value =
                                serde_json::from_str(&tool_use.input).unwrap_or_default();
                            contents.push(AnthropicContent::ToolUse {
                                id: tool_use.id.clone(),
                                name: tool_use.name.clone(),
                                input,
                            });
                        }
                        LlmMessageContent::Thinking(thinking) => {
                            // Replay prior thinking blocks with their signature
                            // so the API can verify them.
                            if !thinking.text.is_empty() {
                                contents.push(AnthropicContent::Thinking {
                                    thinking: thinking.text.clone(),
                                    signature: thinking.signature.clone().unwrap_or_default(),
                                });
                            }
                        }
                        LlmMessageContent::RedactedThinking(data) => {
                            if !data.is_empty() {
                                contents.push(AnthropicContent::RedactedThinking {
                                    data: data.clone(),
                                });
                            }
                        }
                        _ => {}
                    }
                }
                if !contents.is_empty() {
                    messages.push(AnthropicMessage {
                        role: "assistant".to_string(),
                        content: contents,
                    });
                }
            }
        }
    }
    // Tool schemas arrive as JSON strings; unparseable ones degrade to an
    // empty object schema.
    let tools: Vec<AnthropicTool> = request
        .tools
        .iter()
        .map(|t| AnthropicTool {
            name: t.name.clone(),
            description: t.description.clone(),
            input_schema: serde_json::from_str(&t.input_schema)
                .unwrap_or(serde_json::Value::Object(Default::default())),
        })
        .collect();
    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => AnthropicToolChoice::Auto,
        LlmToolChoice::Any => AnthropicToolChoice::Any,
        LlmToolChoice::None => AnthropicToolChoice::None,
    });
    // Enable extended thinking only for thinking-capable picker entries, with
    // a fixed 4096-token budget.
    let thinking = if model_def.supports_thinking && request.thinking_allowed {
        Some(AnthropicThinking {
            thinking_type: "enabled".to_string(),
            budget_tokens: Some(4096),
        })
    } else {
        None
    };
    Ok(AnthropicRequest {
        model: model_def.real_id.to_string(),
        max_tokens: model_def.max_output_tokens,
        messages,
        system: if system_message.is_empty() {
            None
        } else {
            Some(system_message)
        },
        thinking,
        tools,
        tool_choice,
        stop_sequences: request.stop_sequences.clone(),
        temperature: request.temperature,
        stream: true,
    })
}
/// Parses one SSE line into an `AnthropicEvent`.
///
/// Per the SSE specification, the colon after the field name may be followed
/// by at most one optional space, so both `data: {...}` and `data:{...}` are
/// valid; the original `strip_prefix("data: ")` silently dropped the latter.
/// Returns `None` for non-`data` lines and for payloads that fail to
/// deserialize (unknown event shapes are intentionally skipped).
fn parse_sse_line(line: &str) -> Option<AnthropicEvent> {
    let data = line.strip_prefix("data:")?.trim_start();
    serde_json::from_str(data).ok()
}
impl zed::Extension for AnthropicProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    /// Advertises the single "anthropic" provider to Zed.
    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "anthropic".into(),
            name: "Anthropic".into(),
            icon: Some("icons/anthropic.svg".into()),
        }]
    }

    /// Returns the static model catalog. The display name is reused as the
    /// model id, which is how `convert_request` maps it back to an API id.
    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: Some(m.max_output_tokens),
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: m.supports_thinking,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    /// Authenticated iff a credential is stored under the "anthropic" key.
    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("anthropic").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "To use Anthropic, you need an API key. You can create one [here](https://console.anthropic.com/settings/keys).".to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("anthropic")
    }

    /// Kicks off a streaming completion: builds the request, opens the HTTP
    /// response stream, and registers a fresh `StreamState` under a newly
    /// minted id, which is returned to the caller for subsequent polling.
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("anthropic").ok_or_else(|| {
            "No API key configured. Please add your Anthropic API key in settings.".to_string()
        })?;
        let anthropic_request = convert_request(model_id, request)?;
        let body = serde_json::to_vec(&anthropic_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;
        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://api.anthropic.com/v1/messages".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("x-api-key".to_string(), api_key),
                ("anthropic-version".to_string(), "2023-06-01".to_string()),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };
        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;
        // Mint a unique id; the counter lock is scoped so it is released
        // before the streams map is locked below.
        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("anthropic-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };
        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                current_tool_use: None,
                stop_reason: None,
                pending_signature: None,
            },
        );
        Ok(stream_id)
    }

    /// Pulls the next completion event for `stream_id`.
    ///
    /// The first call emits a synthetic `Started`. After that, buffered bytes
    /// are consumed line-by-line as SSE; events that translate to a Zed event
    /// return immediately, others (ping, tool-input deltas, unparseable
    /// lines) are absorbed and the loop continues. When the buffer has no
    /// complete line, another chunk is read from the HTTP stream. `Ok(None)`
    /// signals end of stream.
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }
        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;
        loop {
            if let Some(newline_pos) = state.buffer.find('\n') {
                // Split off one complete line; the remainder stays buffered.
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();
                // Skip blank separators and `event:` lines; the event type is
                // carried redundantly inside the JSON `data:` payload.
                if line.trim().is_empty() || line.starts_with("event:") {
                    continue;
                }
                if let Some(event) = parse_sse_line(&line) {
                    match event {
                        AnthropicEvent::MessageStart { message } => {
                            // Only forward usage if both counts are present;
                            // otherwise fall through to the next line.
                            if let (Some(input), Some(output)) =
                                (message.usage.input_tokens, message.usage.output_tokens)
                            {
                                return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                    input_tokens: input,
                                    output_tokens: output,
                                    cache_creation_input_tokens: message
                                        .usage
                                        .cache_creation_input_tokens,
                                    cache_read_input_tokens: message.usage.cache_read_input_tokens,
                                })));
                            }
                        }
                        AnthropicEvent::ContentBlockStart { content_block, .. } => {
                            match content_block {
                                AnthropicContentBlock::Text { text } => {
                                    if !text.is_empty() {
                                        return Ok(Some(LlmCompletionEvent::Text(text)));
                                    }
                                }
                                AnthropicContentBlock::Thinking { thinking } => {
                                    return Ok(Some(LlmCompletionEvent::Thinking(
                                        LlmThinkingContent {
                                            text: thinking,
                                            signature: None,
                                        },
                                    )));
                                }
                                AnthropicContentBlock::RedactedThinking { data } => {
                                    return Ok(Some(LlmCompletionEvent::RedactedThinking(data)));
                                }
                                AnthropicContentBlock::ToolUse { id, name } => {
                                    // Begin accumulating tool input; emitted as
                                    // one event on ContentBlockStop.
                                    state.current_tool_use = Some(ToolUseState {
                                        id,
                                        name,
                                        input_json: String::new(),
                                    });
                                }
                            }
                        }
                        AnthropicEvent::ContentBlockDelta { delta, .. } => match delta {
                            AnthropicDelta::TextDelta { text } => {
                                if !text.is_empty() {
                                    return Ok(Some(LlmCompletionEvent::Text(text)));
                                }
                            }
                            AnthropicDelta::ThinkingDelta { thinking } => {
                                return Ok(Some(LlmCompletionEvent::Thinking(
                                    LlmThinkingContent {
                                        text: thinking,
                                        signature: None,
                                    },
                                )));
                            }
                            AnthropicDelta::SignatureDelta { signature } => {
                                // Remember the signature for the pending tool
                                // use, and surface it on the thinking event.
                                state.pending_signature = Some(signature.clone());
                                return Ok(Some(LlmCompletionEvent::Thinking(
                                    LlmThinkingContent {
                                        text: String::new(),
                                        signature: Some(signature),
                                    },
                                )));
                            }
                            AnthropicDelta::InputJsonDelta { partial_json } => {
                                if let Some(ref mut tool_use) = state.current_tool_use {
                                    tool_use.input_json.push_str(&partial_json);
                                }
                            }
                        },
                        AnthropicEvent::ContentBlockStop { .. } => {
                            // A completed tool-use block is emitted whole here;
                            // stops of text/thinking blocks produce no event.
                            if let Some(tool_use) = state.current_tool_use.take() {
                                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                    id: tool_use.id,
                                    name: tool_use.name,
                                    input: tool_use.input_json,
                                    thought_signature: state.pending_signature.take(),
                                })));
                            }
                        }
                        AnthropicEvent::MessageDelta { delta, usage } => {
                            // Stash the stop reason; it is emitted when
                            // MessageStop (or stream end) arrives.
                            if let Some(reason) = delta.stop_reason {
                                state.stop_reason = Some(match reason.as_str() {
                                    "end_turn" => LlmStopReason::EndTurn,
                                    "max_tokens" => LlmStopReason::MaxTokens,
                                    "tool_use" => LlmStopReason::ToolUse,
                                    _ => LlmStopReason::EndTurn,
                                });
                            }
                            if let Some(output) = usage.output_tokens {
                                return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                    input_tokens: usage.input_tokens.unwrap_or(0),
                                    output_tokens: output,
                                    cache_creation_input_tokens: usage.cache_creation_input_tokens,
                                    cache_read_input_tokens: usage.cache_read_input_tokens,
                                })));
                            }
                        }
                        AnthropicEvent::MessageStop => {
                            if let Some(stop_reason) = state.stop_reason.take() {
                                return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                            }
                            return Ok(Some(LlmCompletionEvent::Stop(LlmStopReason::EndTurn)));
                        }
                        AnthropicEvent::Ping => {}
                        AnthropicEvent::Error { error } => {
                            return Err(format!("API error: {}", error.message));
                        }
                    }
                }
                continue;
            }
            // No complete line buffered: read more bytes from the response.
            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    // Chunk boundaries may split UTF-8 sequences; lossy decode
                    // tolerates that at worst by inserting replacement chars.
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // Stream ended; flush a stashed stop reason if MessageStop
                    // never arrived, otherwise report end-of-stream.
                    if let Some(stop_reason) = state.stop_reason.take() {
                        return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                    }
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    /// Drops all state for the stream (and with it the HTTP response).
    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}
zed::register_extension!(AnthropicProvider);

823
extensions/copilot-chat/Cargo.lock generated Normal file
View File

@@ -0,0 +1,823 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "adler2"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
[[package]]
name = "anyhow"
version = "1.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
[[package]]
name = "auditable-serde"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c7bf8143dfc3c0258df908843e169b5cc5fcf76c7718bd66135ef4a9cd558c5"
dependencies = [
"semver",
"serde",
"serde_json",
"topological-sort",
]
[[package]]
name = "bitflags"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
[[package]]
name = "cfg-if"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "copilot-chat"
version = "0.1.0"
dependencies = [
"serde",
"serde_json",
"zed_extension_api",
]
[[package]]
name = "crc32fast"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511"
dependencies = [
"cfg-if",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "equivalent"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "flate2"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "foldhash"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "form_urlencoded"
version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf"
dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
name = "futures-core"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7"
[[package]]
name = "futures-task"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-util"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
]
[[package]]
name = "hashbrown"
version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"foldhash",
]
[[package]]
name = "hashbrown"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "icu_collections"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43"
dependencies = [
"displaydoc",
"potential_utf",
"yoke",
"zerofrom",
"zerovec",
]
[[package]]
name = "icu_locale_core"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6"
dependencies = [
"displaydoc",
"litemap",
"tinystr",
"writeable",
"zerovec",
]
[[package]]
name = "icu_normalizer"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599"
dependencies = [
"icu_collections",
"icu_normalizer_data",
"icu_properties",
"icu_provider",
"smallvec",
"zerovec",
]
[[package]]
name = "icu_normalizer_data"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a"
[[package]]
name = "icu_properties"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99"
dependencies = [
"icu_collections",
"icu_locale_core",
"icu_properties_data",
"icu_provider",
"zerotrie",
"zerovec",
]
[[package]]
name = "icu_properties_data"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899"
[[package]]
name = "icu_provider"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614"
dependencies = [
"displaydoc",
"icu_locale_core",
"writeable",
"yoke",
"zerofrom",
"zerotrie",
"zerovec",
]
[[package]]
name = "id-arena"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005"
[[package]]
name = "idna"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de"
dependencies = [
"idna_adapter",
"smallvec",
"utf8_iter",
]
[[package]]
name = "idna_adapter"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344"
dependencies = [
"icu_normalizer",
"icu_properties",
]
[[package]]
name = "indexmap"
version = "2.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2"
dependencies = [
"equivalent",
"hashbrown 0.16.1",
"serde",
"serde_core",
]
[[package]]
name = "itoa"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "leb128fmt"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2"
[[package]]
name = "litemap"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
[[package]]
name = "log"
version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]]
name = "memchr"
version = "2.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
[[package]]
name = "miniz_oxide"
version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316"
dependencies = [
"adler2",
"simd-adler32",
]
[[package]]
name = "once_cell"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "percent-encoding"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220"
[[package]]
name = "pin-project-lite"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "potential_utf"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77"
dependencies = [
"zerovec",
]
[[package]]
name = "prettyplease"
version = "0.2.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
dependencies = [
"proc-macro2",
"syn",
]
[[package]]
name = "proc-macro2"
version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
dependencies = [
"proc-macro2",
]
[[package]]
name = "ryu"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "semver"
version = "1.0.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2"
dependencies = [
"serde",
"serde_core",
]
[[package]]
name = "serde"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
dependencies = [
"serde_core",
"serde_derive",
]
[[package]]
name = "serde_core"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.145"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
"serde_core",
]
[[package]]
name = "simd-adler32"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
[[package]]
name = "slab"
version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589"
[[package]]
name = "smallvec"
version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
[[package]]
name = "spdx"
version = "0.10.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3e17e880bafaeb362a7b751ec46bdc5b61445a188f80e0606e68167cd540fa3"
dependencies = [
"smallvec",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
[[package]]
name = "syn"
version = "2.0.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "synstructure"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tinystr"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869"
dependencies = [
"displaydoc",
"zerovec",
]
[[package]]
name = "topological-sort"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea68304e134ecd095ac6c3574494fc62b909f416c4fca77e440530221e549d3d"
[[package]]
name = "unicode-ident"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "url"
version = "2.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b"
dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
"serde",
]
[[package]]
name = "utf8_iter"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "wasm-encoder"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80bb72f02e7fbf07183443b27b0f3d4144abf8c114189f2e088ed95b696a7822"
dependencies = [
"leb128fmt",
"wasmparser",
]
[[package]]
name = "wasm-metadata"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce1ef0faabbbba6674e97a56bee857ccddf942785a336c8b47b42373c922a91d"
dependencies = [
"anyhow",
"auditable-serde",
"flate2",
"indexmap",
"serde",
"serde_derive",
"serde_json",
"spdx",
"url",
"wasm-encoder",
"wasmparser",
]
[[package]]
name = "wasmparser"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f51cad774fb3c9461ab9bccc9c62dfb7388397b5deda31bf40e8108ccd678b2"
dependencies = [
"bitflags",
"hashbrown 0.15.5",
"indexmap",
"semver",
]
[[package]]
name = "wit-bindgen"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10fb6648689b3929d56bbc7eb1acf70c9a42a29eb5358c67c10f54dbd5d695de"
dependencies = [
"wit-bindgen-rt",
"wit-bindgen-rust-macro",
]
[[package]]
name = "wit-bindgen-core"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92fa781d4f2ff6d3f27f3cc9b74a73327b31ca0dc4a3ef25a0ce2983e0e5af9b"
dependencies = [
"anyhow",
"heck",
"wit-parser",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4db52a11d4dfb0a59f194c064055794ee6564eb1ced88c25da2cf76e50c5621"
dependencies = [
"bitflags",
"futures",
"once_cell",
]
[[package]]
name = "wit-bindgen-rust"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d0809dc5ba19e2e98661bf32fc0addc5a3ca5bf3a6a7083aa6ba484085ff3ce"
dependencies = [
"anyhow",
"heck",
"indexmap",
"prettyplease",
"syn",
"wasm-metadata",
"wit-bindgen-core",
"wit-component",
]
[[package]]
name = "wit-bindgen-rust-macro"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad19eec017904e04c60719592a803ee5da76cb51c81e3f6fbf9457f59db49799"
dependencies = [
"anyhow",
"prettyplease",
"proc-macro2",
"quote",
"syn",
"wit-bindgen-core",
"wit-bindgen-rust",
]
[[package]]
name = "wit-component"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "635c3adc595422cbf2341a17fb73a319669cc8d33deed3a48368a841df86b676"
dependencies = [
"anyhow",
"bitflags",
"indexmap",
"log",
"serde",
"serde_derive",
"serde_json",
"wasm-encoder",
"wasm-metadata",
"wasmparser",
"wit-parser",
]
[[package]]
name = "wit-parser"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddf445ed5157046e4baf56f9138c124a0824d4d1657e7204d71886ad8ce2fc11"
dependencies = [
"anyhow",
"id-arena",
"indexmap",
"log",
"semver",
"serde",
"serde_derive",
"serde_json",
"unicode-xid",
"wasmparser",
]
[[package]]
name = "writeable"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
[[package]]
name = "yoke"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954"
dependencies = [
"stable_deref_trait",
"yoke-derive",
"zerofrom",
]
[[package]]
name = "yoke-derive"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]]
name = "zed_extension_api"
version = "0.8.0"
dependencies = [
"serde",
"serde_json",
"wit-bindgen",
]
[[package]]
name = "zerofrom"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5"
dependencies = [
"zerofrom-derive",
]
[[package]]
name = "zerofrom-derive"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]]
name = "zerotrie"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851"
dependencies = [
"displaydoc",
"yoke",
"zerofrom",
]
[[package]]
name = "zerovec"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002"
dependencies = [
"yoke",
"zerofrom",
"zerovec-derive",
]
[[package]]
name = "zerovec-derive"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]

View File

@@ -0,0 +1,17 @@
# Cargo manifest for the Copilot Chat provider extension.
[package]
name = "copilot-chat"
version = "0.1.0"
edition = "2021"
publish = false
license = "Apache-2.0"
# Empty workspace table keeps this crate out of any parent workspace.
[workspace]
[lib]
path = "src/copilot_chat.rs"
# Built as a cdylib so the extension can be compiled to WebAssembly.
crate-type = ["cdylib"]
[dependencies]
zed_extension_api = { path = "../../crates/extension_api" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

View File

@@ -0,0 +1,15 @@
# Extension manifest for the Copilot Chat LLM provider.
id = "copilot-chat"
name = "Copilot Chat"
description = "GitHub Copilot Chat LLM provider for Zed."
version = "0.1.0"
schema_version = 1
authors = ["Zed Team"]
repository = "https://github.com/zed-industries/zed"
# The language-model provider contributed by this extension.
[language_model_providers.copilot-chat]
name = "Copilot Chat"
icon = "icons/copilot.svg"
# OAuth (device-flow) sign-in button shown in the provider settings UI.
[language_model_providers.copilot-chat.auth.oauth]
sign_in_button_label = "Sign in to use GitHub Copilot"
sign_in_button_icon = "github"

View File

@@ -0,0 +1,9 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M6.44643 8.76593C6.83106 8.76593 7.14286 9.0793 7.14286 9.46588V10.9825C7.14286 11.369 6.83106 11.6824 6.44643 11.6824C6.06181 11.6824 5.75 11.369 5.75 10.9825V9.46588C5.75 9.0793 6.06181 8.76593 6.44643 8.76593Z" fill="black"/>
<path fill-rule="evenodd" clip-rule="evenodd" d="M9.57168 8.76593C9.95631 8.76593 10.2681 9.0793 10.2681 9.46588V10.9825C10.2681 11.369 9.95631 11.6824 9.57168 11.6824C9.18705 11.6824 8.87524 11.369 8.87524 10.9825V9.46588C8.87524 9.0793 9.18705 8.76593 9.57168 8.76593Z" fill="black"/>
<path d="M7.99976 4.17853C7.99976 6.67853 5.83695 7.28202 4.30332 7.28202C2.76971 7.28202 2.44604 6.1547 2.44604 4.76409C2.44604 3.37347 3.68929 2.24615 5.2229 2.24615C6.75651 2.24615 7.99976 2.78791 7.99976 4.17853Z" fill="black" fill-opacity="0.5" stroke="black" stroke-width="1.2"/>
<path d="M8 4.17853C8 6.67853 10.1628 7.28202 11.6965 7.28202C13.2301 7.28202 13.5537 6.1547 13.5537 4.76409C13.5537 3.37347 12.3105 2.24615 10.7769 2.24615C9.24325 2.24615 8 2.78791 8 4.17853Z" fill="black" fill-opacity="0.5" stroke="black" stroke-width="1.2"/>
<path d="M12.5894 6.875C12.5894 6.875 13.3413 7.35585 13.7144 8.08398C14.0876 8.81212 14.0894 10.4985 13.7144 11.1064C13.3395 11.7143 12.8931 12.1429 11.7637 12.7543C10.6344 13.3657 9.143 13.7321 9.143 13.7321H6.85728C6.85728 13.7321 5.37513 13.4107 4.23656 12.7543C3.09798 12.0978 2.55371 11.6786 2.28585 11.1064C2.01799 10.5342 1.92871 8.85715 2.28585 8.08398C2.64299 7.31081 3.42871 6.875 3.42871 6.875" stroke="black" stroke-width="1.2" stroke-linejoin="round"/>
<path d="M11.9375 12.6016V7.33636L13.9052 7.99224V10.9255L11.9375 12.6016Z" fill="black" fill-opacity="0.75"/>
<path d="M4.01793 12.6016V7.33636L2.05029 7.99224V10.9255L4.01793 12.6016Z" fill="black" fill-opacity="0.75"/>
</svg>

After

Width:  |  Height:  |  Size: 1.9 KiB

View File

@@ -0,0 +1,958 @@
use std::collections::HashMap;
use std::sync::Mutex;
use std::thread;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};
// GitHub OAuth device-flow endpoints used for sign-in.
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
// Exchanges the stored GitHub OAuth token for a short-lived Copilot API token.
const GITHUB_COPILOT_TOKEN_URL: &str = "https://api.github.com/copilot_internal/v2/token";
// OAuth client id of the GitHub App used for the device flow.
// NOTE(review): presumably the same client id the official Copilot plugins
// use — confirm before changing.
const GITHUB_COPILOT_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
/// In-progress GitHub device-flow authorization, stored between the
/// "start" and "poll" steps of sign-in.
struct DeviceFlowState {
    device_code: String,
    // Minimum number of seconds to wait between token polls (from GitHub).
    interval: u64,
    // Seconds until the device code expires (from GitHub).
    expires_in: u64,
}
/// Short-lived Copilot API credential obtained from the GitHub OAuth
/// token, plus the API base URL completions should be sent to.
#[derive(Clone)]
struct ApiToken {
    api_key: String,
    api_endpoint: String,
}
/// One model entry from the Copilot `/models` endpoint.
#[derive(Clone, Deserialize)]
struct CopilotModel {
    id: String,
    name: String,
    // True for the account's default chat model.
    #[serde(default)]
    is_chat_default: bool,
    // True for the cheaper/faster fallback chat model.
    #[serde(default)]
    is_chat_fallback: bool,
    // Only models with this flag set are surfaced to the user.
    #[serde(default)]
    model_picker_enabled: bool,
    #[serde(default)]
    capabilities: ModelCapabilities,
    // May be absent; absence is treated as "enabled" when filtering.
    #[serde(default)]
    policy: Option<ModelPolicy>,
}
/// Capability metadata reported per model.
#[derive(Clone, Default, Deserialize)]
struct ModelCapabilities {
    #[serde(default)]
    family: String,
    #[serde(default)]
    limits: ModelLimits,
    #[serde(default)]
    supports: ModelSupportedFeatures,
    // Wire field is `type`; only "chat" models are offered.
    #[serde(rename = "type", default)]
    model_type: String,
}
/// Token limits; zero means the server did not report a value.
#[derive(Clone, Default, Deserialize)]
struct ModelLimits {
    #[serde(default)]
    max_context_window_tokens: u64,
    #[serde(default)]
    max_output_tokens: u64,
}
/// Per-model feature flags.
#[derive(Clone, Default, Deserialize)]
struct ModelSupportedFeatures {
    #[serde(default)]
    streaming: bool,
    #[serde(default)]
    tool_calls: bool,
    #[serde(default)]
    vision: bool,
}
/// Organization policy state for a model; only "enabled" models are used.
#[derive(Clone, Deserialize)]
struct ModelPolicy {
    state: String,
}
/// Extension state. Fields live behind `Mutex`es because the extension
/// trait hands out shared references from the host.
struct CopilotChatProvider {
    // Active completion streams, keyed by the id returned to the host.
    streams: Mutex<HashMap<String, StreamState>>,
    // Monotonic counter used to mint unique stream ids.
    next_stream_id: Mutex<u64>,
    // Device-flow sign-in in progress, if any.
    device_flow_state: Mutex<Option<DeviceFlowState>>,
    // Cached short-lived Copilot API token.
    api_token: Mutex<Option<ApiToken>>,
    // Cached result of the `/models` endpoint.
    cached_models: Mutex<Option<Vec<CopilotModel>>>,
}
/// Per-stream SSE parsing state.
struct StreamState {
    response_stream: Option<HttpResponseStream>,
    // Received-but-unparsed text; consumed one newline-terminated SSE
    // line at a time.
    buffer: String,
    // Whether the initial Started event has been emitted.
    started: bool,
    // Tool calls accumulated from streamed deltas, keyed by choice index.
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    tool_calls_emitted: bool,
}
/// A tool call assembled incrementally from streamed deltas.
#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}
/// OpenAI-style chat completion request body (serialize-only).
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenAiTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    stream_options: Option<StreamOptions>,
}
/// Asks the server to append a usage chunk at the end of the stream.
#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}
/// A single chat message on the wire.
#[derive(Serialize)]
struct OpenAiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenAiContent>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenAiToolCall>>,
    // Set only on `role: "tool"` messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
/// Message content: either a bare string or a list of typed parts.
#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenAiContent {
    Text(String),
    Parts(Vec<OpenAiContentPart>),
}
#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}
#[derive(Serialize, Clone)]
struct ImageUrl {
    // A data: or https: URL for the image.
    url: String,
}
/// A completed tool call echoed back inside assistant messages.
#[derive(Serialize, Clone)]
struct OpenAiToolCall {
    id: String,
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}
#[derive(Serialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    // JSON-encoded arguments, passed through verbatim.
    arguments: String,
}
/// A tool definition offered to the model.
#[derive(Serialize)]
struct OpenAiTool {
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}
#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    // JSON Schema describing the tool's parameters.
    parameters: serde_json::Value,
}
/// One SSE chunk of a streamed chat completion.
#[derive(Deserialize, Debug)]
struct OpenAiStreamResponse {
    choices: Vec<OpenAiStreamChoice>,
    // Present only on the final chunk when `include_usage` was requested.
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}
#[derive(Deserialize, Debug)]
struct OpenAiStreamChoice {
    delta: OpenAiDelta,
    // Set on the chunk that terminates this choice (e.g. "stop",
    // "tool_calls", "length").
    finish_reason: Option<String>,
}
/// Incremental content for a choice; every field is optional.
#[derive(Deserialize, Debug, Default)]
struct OpenAiDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenAiToolCallDelta>>,
}
/// A fragment of a tool call; fragments sharing `index` are merged.
#[derive(Deserialize, Debug)]
struct OpenAiToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenAiFunctionDelta>,
}
#[derive(Deserialize, Debug, Default)]
struct OpenAiFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}
/// Token counts reported by the API on the usage chunk.
#[derive(Deserialize, Debug)]
struct OpenAiUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}
/// Translates a host `LlmCompletionRequest` into the OpenAI-style chat
/// request body accepted by the Copilot completions endpoint.
///
/// Conversion rules:
/// - System text parts are newline-joined into one system message.
/// - User text/image parts form one user message; tool results become
///   separate `role: "tool"` messages appended after it.
/// - Assistant text and tool uses form one assistant message.
///
/// Streaming is always enabled, with usage totals requested in the
/// final chunk.
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let mut messages: Vec<OpenAiMessage> = Vec::new();
    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                // Collapse all text parts into a single newline-joined string.
                let mut text_content = String::new();
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text_content.is_empty() {
                            text_content.push('\n');
                        }
                        text_content.push_str(text);
                    }
                }
                if !text_content.is_empty() {
                    messages.push(OpenAiMessage {
                        role: "system".to_string(),
                        content: Some(OpenAiContent::Text(text_content)),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<OpenAiContentPart> = Vec::new();
                // Tool results must be standalone "tool" messages, emitted
                // after the user message itself.
                let mut tool_result_messages: Vec<OpenAiMessage> = Vec::new();
                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(OpenAiContentPart::Text { text: text.clone() });
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            // NOTE(review): assumes `img.source` is base64
                            // PNG data — confirm against the host API.
                            let data_url = format!("data:image/png;base64,{}", img.source);
                            parts.push(OpenAiContentPart::ImageUrl {
                                image_url: ImageUrl { url: data_url },
                            });
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let content_text = match &result.content {
                                LlmToolResultContent::Text(t) => t.clone(),
                                // Image tool results are not forwarded; a
                                // placeholder is sent instead.
                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
                            };
                            tool_result_messages.push(OpenAiMessage {
                                role: "tool".to_string(),
                                content: Some(OpenAiContent::Text(content_text)),
                                tool_calls: None,
                                tool_call_id: Some(result.tool_use_id.clone()),
                            });
                        }
                        _ => {}
                    }
                }
                if !parts.is_empty() {
                    // Prefer the plain-string form when there is exactly one
                    // text part.
                    let content = if parts.len() == 1 {
                        if let OpenAiContentPart::Text { text } = &parts[0] {
                            OpenAiContent::Text(text.clone())
                        } else {
                            OpenAiContent::Parts(parts)
                        }
                    } else {
                        OpenAiContent::Parts(parts)
                    };
                    messages.push(OpenAiMessage {
                        role: "user".to_string(),
                        content: Some(content),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
                messages.extend(tool_result_messages);
            }
            LlmMessageRole::Assistant => {
                let mut text_content = String::new();
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();
                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                if !text_content.is_empty() {
                                    text_content.push('\n');
                                }
                                text_content.push_str(text);
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }
                messages.push(OpenAiMessage {
                    role: "assistant".to_string(),
                    content: if text_content.is_empty() {
                        None
                    } else {
                        Some(OpenAiContent::Text(text_content))
                    },
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                    tool_call_id: None,
                });
            }
        }
    }
    let tools: Vec<OpenAiTool> = request
        .tools
        .iter()
        .map(|t| OpenAiTool {
            tool_type: "function".to_string(),
            function: OpenAiFunctionDef {
                name: t.name.clone(),
                description: t.description.clone(),
                // Fall back to an empty JSON object if the schema string is
                // not valid JSON, rather than failing the whole request.
                parameters: serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default())),
            },
        })
        .collect();
    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });
    let max_tokens = request.max_tokens;
    Ok(OpenAiRequest {
        model: model_id.to_string(),
        messages,
        max_tokens,
        tools,
        tool_choice,
        stop: request.stop_sequences.clone(),
        temperature: request.temperature,
        // Always stream; the usage chunk is requested explicitly.
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}
/// Parses one SSE line into a stream response.
///
/// Returns `None` for lines without a `data: ` prefix, for the `[DONE]`
/// sentinel, and for payloads that fail to deserialize.
fn parse_sse_line(line: &str) -> Option<OpenAiStreamResponse> {
    match line.strip_prefix("data: ") {
        Some(payload) if payload.trim() != "[DONE]" => serde_json::from_str(payload).ok(),
        _ => None,
    }
}
impl zed::Extension for CopilotChatProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
            device_flow_state: Mutex::new(None),
            api_token: Mutex::new(None),
            cached_models: Mutex::new(None),
        }
    }

    /// Advertises the single "copilot-chat" provider to the host.
    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "copilot-chat".into(),
            name: "Copilot Chat".into(),
            icon: Some("icons/copilot.svg".into()),
        }]
    }

    /// Lists available models, fetching (and caching) them from the
    /// Copilot API on first use. Returns an empty list when the user is
    /// not signed in.
    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        // Try to get models from cache first
        if let Some(models) = self.cached_models.lock().unwrap().as_ref() {
            return Ok(convert_models_to_llm_info(models));
        }
        // Need to fetch models - requires authentication
        let oauth_token = match llm_get_credential("copilot-chat") {
            Some(token) => token,
            None => return Ok(Vec::new()), // Not authenticated, return empty
        };
        // Get API token
        let api_token = self.get_api_token(&oauth_token)?;
        // Fetch models from API
        let models = self.fetch_models(&api_token)?;
        // Cache the models
        *self.cached_models.lock().unwrap() = Some(models.clone());
        Ok(convert_models_to_llm_info(&models))
    }

    /// Signed in == an OAuth credential is stored for this provider.
    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("copilot-chat").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "To use Copilot Chat, sign in with your GitHub account. This requires an active [GitHub Copilot subscription](https://github.com/features/copilot).".to_string(),
        )
    }

    /// Starts the GitHub OAuth device flow: requests a device/user code
    /// pair, opens the verification page in the browser, and returns the
    /// user code for the host to display.
    fn llm_provider_start_device_flow_sign_in(
        &mut self,
        _provider_id: &str,
    ) -> Result<String, String> {
        // Step 1: Request device and user verification codes
        let device_code_response = llm_oauth_http_request(&LlmOauthHttpRequest {
            url: GITHUB_DEVICE_CODE_URL.to_string(),
            method: "POST".to_string(),
            headers: vec![
                ("Accept".to_string(), "application/json".to_string()),
                (
                    "Content-Type".to_string(),
                    "application/x-www-form-urlencoded".to_string(),
                ),
            ],
            body: format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID),
        })?;
        if device_code_response.status != 200 {
            return Err(format!(
                "Failed to get device code: HTTP {}",
                device_code_response.status
            ));
        }
        #[derive(Deserialize)]
        struct DeviceCodeResponse {
            device_code: String,
            user_code: String,
            verification_uri: String,
            #[serde(default)]
            verification_uri_complete: Option<String>,
            expires_in: u64,
            interval: u64,
        }
        let device_info: DeviceCodeResponse = serde_json::from_str(&device_code_response.body)
            .map_err(|e| format!("Failed to parse device code response: {}", e))?;
        // Store device flow state for polling
        *self.device_flow_state.lock().unwrap() = Some(DeviceFlowState {
            device_code: device_info.device_code,
            interval: device_info.interval,
            expires_in: device_info.expires_in,
        });
        // Step 2: Open browser to verification URL
        // Use verification_uri_complete if available (has code pre-filled), otherwise construct URL
        let verification_url = device_info.verification_uri_complete.unwrap_or_else(|| {
            format!(
                "{}?user_code={}",
                device_info.verification_uri, &device_info.user_code
            )
        });
        llm_oauth_open_browser(&verification_url)?;
        // Return the user code for the host to display
        Ok(device_info.user_code)
    }

    /// Polls GitHub for the device-flow result (blocking between polls)
    /// and stores the OAuth token on success. Consumes the pending
    /// device-flow state, so a failed poll requires restarting sign-in.
    fn llm_provider_poll_device_flow_sign_in(&mut self, _provider_id: &str) -> Result<(), String> {
        let state = self
            .device_flow_state
            .lock()
            .unwrap()
            .take()
            .ok_or("No device flow in progress")?;
        // GitHub requires at least 5 seconds between polls.
        let poll_interval = Duration::from_secs(state.interval.max(5));
        let max_attempts = (state.expires_in / state.interval.max(5)) as usize;
        for _ in 0..max_attempts {
            thread::sleep(poll_interval);
            let token_response = llm_oauth_http_request(&LlmOauthHttpRequest {
                url: GITHUB_ACCESS_TOKEN_URL.to_string(),
                method: "POST".to_string(),
                headers: vec![
                    ("Accept".to_string(), "application/json".to_string()),
                    (
                        "Content-Type".to_string(),
                        "application/x-www-form-urlencoded".to_string(),
                    ),
                ],
                body: format!(
                    "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
                    GITHUB_COPILOT_CLIENT_ID, state.device_code
                ),
            })?;
            #[derive(Deserialize)]
            struct TokenResponse {
                access_token: Option<String>,
                error: Option<String>,
                error_description: Option<String>,
            }
            let token_json: TokenResponse = serde_json::from_str(&token_response.body)
                .map_err(|e| format!("Failed to parse token response: {}", e))?;
            if let Some(access_token) = token_json.access_token {
                llm_store_credential("copilot-chat", &access_token)?;
                return Ok(());
            }
            if let Some(error) = &token_json.error {
                match error.as_str() {
                    "authorization_pending" => {
                        // User hasn't authorized yet, keep polling
                        continue;
                    }
                    "slow_down" => {
                        // Need to slow down polling
                        thread::sleep(Duration::from_secs(5));
                        continue;
                    }
                    "expired_token" => {
                        return Err("Device code expired. Please try again.".to_string());
                    }
                    "access_denied" => {
                        return Err("Authorization was denied.".to_string());
                    }
                    _ => {
                        let description = token_json.error_description.unwrap_or_default();
                        return Err(format!("OAuth error: {} - {}", error, description));
                    }
                }
            }
        }
        Err("Authorization timed out. Please try again.".to_string())
    }

    /// Clears the stored OAuth credential and all derived cached state.
    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        // Clear cached API token and models
        *self.api_token.lock().unwrap() = None;
        *self.cached_models.lock().unwrap() = None;
        llm_delete_credential("copilot-chat")
    }

    /// Opens a streaming chat-completion HTTP request and registers it
    /// under a freshly minted stream id, which is returned to the host.
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let oauth_token = llm_get_credential("copilot-chat").ok_or_else(|| {
            "No token configured. Please add your GitHub Copilot token in settings.".to_string()
        })?;
        // Get or refresh API token
        let api_token = self.get_api_token(&oauth_token)?;
        let openai_request = convert_request(model_id, request)?;
        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;
        let completions_url = format!("{}/chat/completions", api_token.api_endpoint);
        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: completions_url,
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                (
                    "Authorization".to_string(),
                    format!("Bearer {}", api_token.api_key),
                ),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };
        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;
        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("copilot-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };
        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                tool_calls_emitted: false,
            },
        );
        Ok(stream_id)
    }

    /// Pulls the next event from a completion stream.
    ///
    /// Event order: `Started`, then `Text`/`ToolUse` events, then `Stop`,
    /// then optionally `Usage`; `Ok(None)` marks end of stream. When a
    /// chunk arrives with a finish reason and accumulated tool calls, the
    /// calls are drained one per invocation (in index order) followed by
    /// `Stop(ToolUse)` — previously only the first tool call was emitted
    /// and the Stop event for tool-call turns was lost.
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }
        // Drain mode: a previous chunk finished with tool calls queued.
        // Emit them one at a time before touching the HTTP stream, then
        // emit the terminating Stop(ToolUse).
        if state.tool_calls_emitted {
            if let Some(&index) = state.tool_calls.keys().min() {
                let tc = state
                    .tool_calls
                    .remove(&index)
                    .expect("key was taken from this map");
                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                    id: tc.id,
                    name: tc.name,
                    input: tc.arguments,
                    thought_signature: None,
                })));
            }
            state.tool_calls_emitted = false;
            return Ok(Some(LlmCompletionEvent::Stop(LlmStopReason::ToolUse)));
        }
        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;
        loop {
            // Consume one complete SSE line from the buffer, if available.
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();
                if line.trim().is_empty() {
                    continue;
                }
                if let Some(response) = parse_sse_line(&line) {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }
                        // Accumulate tool-call fragments by index; they are
                        // only emitted once a finish reason arrives.
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state.tool_calls.entry(tc.index).or_default();
                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }
                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }
                        if let Some(finish_reason) = &choice.finish_reason {
                            if !state.tool_calls.is_empty() {
                                // Enter drain mode: emit the first tool call
                                // now; subsequent calls emit the rest and
                                // then Stop(ToolUse).
                                state.tool_calls_emitted = true;
                                let index = *state
                                    .tool_calls
                                    .keys()
                                    .min()
                                    .expect("tool_calls is non-empty");
                                let tc = state
                                    .tool_calls
                                    .remove(&index)
                                    .expect("key was taken from this map");
                                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                    id: tc.id,
                                    name: tc.name,
                                    input: tc.arguments,
                                    thought_signature: None,
                                })));
                            }
                            let stop_reason = match finish_reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            };
                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }
                    }
                    if let Some(usage) = response.usage {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_tokens,
                            output_tokens: usage.completion_tokens,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }
                continue;
            }
            // No complete line buffered: read more bytes from the response.
            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    /// Drops all state for a stream; closing the HTTP response with it.
    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}
impl CopilotChatProvider {
    /// Exchanges the stored GitHub OAuth token for a short-lived Copilot
    /// API token plus the API endpoint to use, caching the result.
    ///
    /// NOTE(review): the cached token is kept until `reset_credentials`;
    /// any expiry field in the response is ignored, so a stale token may
    /// be reused after it expires — confirm and consider refreshing.
    fn get_api_token(&self, oauth_token: &str) -> Result<ApiToken, String> {
        // Check if we have a cached token
        if let Some(token) = self.api_token.lock().unwrap().clone() {
            return Ok(token);
        }
        // Request a new API token
        let http_request = HttpRequest {
            method: HttpMethod::Get,
            url: GITHUB_COPILOT_TOKEN_URL.to_string(),
            headers: vec![
                (
                    // GitHub OAuth tokens use the `token` scheme here, not
                    // `Bearer`.
                    "Authorization".to_string(),
                    format!("token {}", oauth_token),
                ),
                ("Accept".to_string(), "application/json".to_string()),
            ],
            body: None,
            redirect_policy: RedirectPolicy::FollowAll,
        };
        let response = http_request
            .fetch()
            .map_err(|e| format!("Failed to request API token: {}", e))?;
        #[derive(Deserialize)]
        struct ApiTokenResponse {
            token: String,
            endpoints: ApiEndpoints,
        }
        #[derive(Deserialize)]
        struct ApiEndpoints {
            api: String,
        }
        let token_response: ApiTokenResponse =
            serde_json::from_slice(&response.body).map_err(|e| {
                format!(
                    "Failed to parse API token response: {} - body: {}",
                    e,
                    String::from_utf8_lossy(&response.body)
                )
            })?;
        let api_token = ApiToken {
            api_key: token_response.token,
            api_endpoint: token_response.endpoints.api,
        };
        // Cache the token
        *self.api_token.lock().unwrap() = Some(api_token.clone());
        Ok(api_token)
    }

    /// Fetches the model list from the Copilot API, filters it down to
    /// picker-enabled chat models permitted by policy, and moves the
    /// default chat model to the front.
    fn fetch_models(&self, api_token: &ApiToken) -> Result<Vec<CopilotModel>, String> {
        let models_url = format!("{}/models", api_token.api_endpoint);
        let http_request = HttpRequest {
            method: HttpMethod::Get,
            url: models_url,
            headers: vec![
                (
                    "Authorization".to_string(),
                    format!("Bearer {}", api_token.api_key),
                ),
                ("Content-Type".to_string(), "application/json".to_string()),
                (
                    "Copilot-Integration-Id".to_string(),
                    "vscode-chat".to_string(),
                ),
                ("Editor-Version".to_string(), "Zed/1.0.0".to_string()),
                ("x-github-api-version".to_string(), "2025-05-01".to_string()),
            ],
            body: None,
            redirect_policy: RedirectPolicy::FollowAll,
        };
        let response = http_request
            .fetch()
            .map_err(|e| format!("Failed to fetch models: {}", e))?;
        #[derive(Deserialize)]
        struct ModelsResponse {
            data: Vec<CopilotModel>,
        }
        let models_response: ModelsResponse =
            serde_json::from_slice(&response.body).map_err(|e| {
                format!(
                    "Failed to parse models response: {} - body: {}",
                    e,
                    String::from_utf8_lossy(&response.body)
                )
            })?;
        // Filter models like the built-in Copilot Chat does
        let mut models: Vec<CopilotModel> = models_response
            .data
            .into_iter()
            .filter(|model| {
                model.model_picker_enabled
                    && model.capabilities.model_type == "chat"
                    // A missing policy counts as enabled.
                    && model
                        .policy
                        .as_ref()
                        .map(|p| p.state == "enabled")
                        .unwrap_or(true)
            })
            .collect();
        // Sort so default model is first
        if let Some(pos) = models.iter().position(|m| m.is_chat_default) {
            let default_model = models.remove(pos);
            models.insert(0, default_model);
        }
        Ok(models)
    }
}
/// Maps Copilot model metadata onto the host's `LlmModelInfo` shape.
///
/// A reported context window of zero (field absent on the wire) falls
/// back to 128k tokens; a zero output limit is reported as `None`.
fn convert_models_to_llm_info(models: &[CopilotModel]) -> Vec<LlmModelInfo> {
    models
        .iter()
        .map(|model| {
            let limits = &model.capabilities.limits;
            let supports = &model.capabilities.supports;
            let max_token_count = match limits.max_context_window_tokens {
                // Default fallback when the server did not report a limit.
                0 => 128_000,
                reported => reported,
            };
            let max_output_tokens =
                (limits.max_output_tokens > 0).then_some(limits.max_output_tokens);
            // All tool-choice modes are gated on the same tool-calls flag.
            let tools = supports.tool_calls;
            LlmModelInfo {
                id: model.id.clone(),
                name: model.name.clone(),
                max_token_count,
                max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: supports.vision,
                    supports_tools: tools,
                    supports_tool_choice_auto: tools,
                    supports_tool_choice_any: tools,
                    supports_tool_choice_none: tools,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: model.is_chat_default,
                is_default_fast: model.is_chat_fallback,
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The device-code request must carry our client id and scope.
    #[test]
    fn test_device_flow_request_body() {
        let request_body = format!("client_id={}&scope=read:user", GITHUB_COPILOT_CLIENT_ID);
        for expected in ["client_id=Iv1.b507a08c87ecfe98", "scope=read:user"] {
            assert!(request_body.contains(expected));
        }
    }

    /// The token-poll request must carry the client id, the device code,
    /// and the device-code grant type.
    #[test]
    fn test_token_poll_request_body() {
        let device_code = "test_device_code_123";
        let request_body = format!(
            "client_id={}&device_code={}&grant_type=urn:ietf:params:oauth:grant-type:device_code",
            GITHUB_COPILOT_CLIENT_ID, device_code
        );
        for expected in [
            "client_id=Iv1.b507a08c87ecfe98",
            "device_code=test_device_code_123",
            "grant_type=urn:ietf:params:oauth:grant-type:device_code",
        ] {
            assert!(request_body.contains(expected));
        }
    }
}
zed::register_extension!(CopilotChatProvider);

823
extensions/google-ai/Cargo.lock generated Normal file
View File

@@ -0,0 +1,823 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "adler2"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
[[package]]
name = "anyhow"
version = "1.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
[[package]]
name = "auditable-serde"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c7bf8143dfc3c0258df908843e169b5cc5fcf76c7718bd66135ef4a9cd558c5"
dependencies = [
"semver",
"serde",
"serde_json",
"topological-sort",
]
[[package]]
name = "bitflags"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
[[package]]
name = "cfg-if"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "crc32fast"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511"
dependencies = [
"cfg-if",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "equivalent"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "flate2"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "foldhash"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "form_urlencoded"
version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf"
dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
name = "futures-core"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7"
[[package]]
name = "futures-task"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-util"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
]
[[package]]
name = "google-ai"
version = "0.1.0"
dependencies = [
"serde",
"serde_json",
"zed_extension_api",
]
[[package]]
name = "hashbrown"
version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"foldhash",
]
[[package]]
name = "hashbrown"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "icu_collections"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43"
dependencies = [
"displaydoc",
"potential_utf",
"yoke",
"zerofrom",
"zerovec",
]
[[package]]
name = "icu_locale_core"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6"
dependencies = [
"displaydoc",
"litemap",
"tinystr",
"writeable",
"zerovec",
]
[[package]]
name = "icu_normalizer"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599"
dependencies = [
"icu_collections",
"icu_normalizer_data",
"icu_properties",
"icu_provider",
"smallvec",
"zerovec",
]
[[package]]
name = "icu_normalizer_data"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a"
[[package]]
name = "icu_properties"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99"
dependencies = [
"icu_collections",
"icu_locale_core",
"icu_properties_data",
"icu_provider",
"zerotrie",
"zerovec",
]
[[package]]
name = "icu_properties_data"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899"
[[package]]
name = "icu_provider"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614"
dependencies = [
"displaydoc",
"icu_locale_core",
"writeable",
"yoke",
"zerofrom",
"zerotrie",
"zerovec",
]
[[package]]
name = "id-arena"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005"
[[package]]
name = "idna"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de"
dependencies = [
"idna_adapter",
"smallvec",
"utf8_iter",
]
[[package]]
name = "idna_adapter"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344"
dependencies = [
"icu_normalizer",
"icu_properties",
]
[[package]]
name = "indexmap"
version = "2.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2"
dependencies = [
"equivalent",
"hashbrown 0.16.1",
"serde",
"serde_core",
]
[[package]]
name = "itoa"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "leb128fmt"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2"
[[package]]
name = "litemap"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
[[package]]
name = "log"
version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]]
name = "memchr"
version = "2.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
[[package]]
name = "miniz_oxide"
version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316"
dependencies = [
"adler2",
"simd-adler32",
]
[[package]]
name = "once_cell"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "percent-encoding"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220"
[[package]]
name = "pin-project-lite"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "potential_utf"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77"
dependencies = [
"zerovec",
]
[[package]]
name = "prettyplease"
version = "0.2.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
dependencies = [
"proc-macro2",
"syn",
]
[[package]]
name = "proc-macro2"
version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
dependencies = [
"proc-macro2",
]
[[package]]
name = "ryu"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "semver"
version = "1.0.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2"
dependencies = [
"serde",
"serde_core",
]
[[package]]
name = "serde"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
dependencies = [
"serde_core",
"serde_derive",
]
[[package]]
name = "serde_core"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.145"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
"serde_core",
]
[[package]]
name = "simd-adler32"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
[[package]]
name = "slab"
version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589"
[[package]]
name = "smallvec"
version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
[[package]]
name = "spdx"
version = "0.10.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3e17e880bafaeb362a7b751ec46bdc5b61445a188f80e0606e68167cd540fa3"
dependencies = [
"smallvec",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
[[package]]
name = "syn"
version = "2.0.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "synstructure"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tinystr"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869"
dependencies = [
"displaydoc",
"zerovec",
]
[[package]]
name = "topological-sort"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea68304e134ecd095ac6c3574494fc62b909f416c4fca77e440530221e549d3d"
[[package]]
name = "unicode-ident"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "url"
version = "2.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b"
dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
"serde",
]
[[package]]
name = "utf8_iter"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "wasm-encoder"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80bb72f02e7fbf07183443b27b0f3d4144abf8c114189f2e088ed95b696a7822"
dependencies = [
"leb128fmt",
"wasmparser",
]
[[package]]
name = "wasm-metadata"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce1ef0faabbbba6674e97a56bee857ccddf942785a336c8b47b42373c922a91d"
dependencies = [
"anyhow",
"auditable-serde",
"flate2",
"indexmap",
"serde",
"serde_derive",
"serde_json",
"spdx",
"url",
"wasm-encoder",
"wasmparser",
]
[[package]]
name = "wasmparser"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f51cad774fb3c9461ab9bccc9c62dfb7388397b5deda31bf40e8108ccd678b2"
dependencies = [
"bitflags",
"hashbrown 0.15.5",
"indexmap",
"semver",
]
[[package]]
name = "wit-bindgen"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10fb6648689b3929d56bbc7eb1acf70c9a42a29eb5358c67c10f54dbd5d695de"
dependencies = [
"wit-bindgen-rt",
"wit-bindgen-rust-macro",
]
[[package]]
name = "wit-bindgen-core"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92fa781d4f2ff6d3f27f3cc9b74a73327b31ca0dc4a3ef25a0ce2983e0e5af9b"
dependencies = [
"anyhow",
"heck",
"wit-parser",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4db52a11d4dfb0a59f194c064055794ee6564eb1ced88c25da2cf76e50c5621"
dependencies = [
"bitflags",
"futures",
"once_cell",
]
[[package]]
name = "wit-bindgen-rust"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d0809dc5ba19e2e98661bf32fc0addc5a3ca5bf3a6a7083aa6ba484085ff3ce"
dependencies = [
"anyhow",
"heck",
"indexmap",
"prettyplease",
"syn",
"wasm-metadata",
"wit-bindgen-core",
"wit-component",
]
[[package]]
name = "wit-bindgen-rust-macro"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad19eec017904e04c60719592a803ee5da76cb51c81e3f6fbf9457f59db49799"
dependencies = [
"anyhow",
"prettyplease",
"proc-macro2",
"quote",
"syn",
"wit-bindgen-core",
"wit-bindgen-rust",
]
[[package]]
name = "wit-component"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "635c3adc595422cbf2341a17fb73a319669cc8d33deed3a48368a841df86b676"
dependencies = [
"anyhow",
"bitflags",
"indexmap",
"log",
"serde",
"serde_derive",
"serde_json",
"wasm-encoder",
"wasm-metadata",
"wasmparser",
"wit-parser",
]
[[package]]
name = "wit-parser"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddf445ed5157046e4baf56f9138c124a0824d4d1657e7204d71886ad8ce2fc11"
dependencies = [
"anyhow",
"id-arena",
"indexmap",
"log",
"semver",
"serde",
"serde_derive",
"serde_json",
"unicode-xid",
"wasmparser",
]
[[package]]
name = "writeable"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
[[package]]
name = "yoke"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954"
dependencies = [
"stable_deref_trait",
"yoke-derive",
"zerofrom",
]
[[package]]
name = "yoke-derive"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]]
name = "zed_extension_api"
version = "0.8.0"
dependencies = [
"serde",
"serde_json",
"wit-bindgen",
]
[[package]]
name = "zerofrom"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5"
dependencies = [
"zerofrom-derive",
]
[[package]]
name = "zerofrom-derive"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]]
name = "zerotrie"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851"
dependencies = [
"displaydoc",
"yoke",
"zerofrom",
]
[[package]]
name = "zerovec"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002"
dependencies = [
"yoke",
"zerofrom",
"zerovec-derive",
]
[[package]]
name = "zerovec-derive"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]

View File

@@ -0,0 +1,17 @@
[package]
name = "google-ai"
version = "0.1.0"
edition = "2021"
publish = false
license = "Apache-2.0"
[workspace]
[lib]
path = "src/google_ai.rs"
crate-type = ["cdylib"]
[dependencies]
zed_extension_api = { path = "../../crates/extension_api" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

View File

@@ -0,0 +1,13 @@
id = "google-ai"
name = "Google AI"
description = "Google Gemini LLM provider for Zed."
version = "0.1.0"
schema_version = 1
authors = ["Zed Team"]
repository = "https://github.com/zed-industries/zed"
[language_model_providers.google-ai]
name = "Google AI"
[language_model_providers.google-ai.auth]
env_var = "GEMINI_API_KEY"

View File

@@ -0,0 +1,3 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M7.44 12.27C7.81333 13.1217 8 14.0317 8 15C8 14.0317 8.18083 13.1217 8.5425 12.27C8.91583 11.4183 9.4175 10.6775 10.0475 10.0475C10.6775 9.4175 11.4183 8.92167 12.27 8.56C13.1217 8.18667 14.0317 8 15 8C14.0317 8 13.1217 7.81917 12.27 7.4575C11.4411 7.1001 10.6871 6.5895 10.0475 5.9525C9.4105 5.31293 8.8999 4.55891 8.5425 3.73C8.18083 2.87833 8 1.96833 8 1C8 1.96833 7.81333 2.87833 7.44 3.73C7.07833 4.58167 6.5825 5.3225 5.9525 5.9525C5.31293 6.5895 4.55891 7.1001 3.73 7.4575C2.87833 7.81917 1.96833 8 1 8C1.96833 8 2.87833 8.18667 3.73 8.56C4.58167 8.92167 5.3225 9.4175 5.9525 10.0475C6.5825 10.6775 7.07833 11.4183 7.44 12.27Z" fill="black"/>
</svg>

After

Width:  |  Height:  |  Size: 762 B

View File

@@ -0,0 +1,797 @@
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Mutex;
use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};
// Process-wide counter; presumably used to mint unique tool-call ids — its
// consumers are outside this chunk, so confirm usage before relying on that.
static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);
/// Extension state for the Google AI (Gemini) LLM provider.
struct GoogleAiProvider {
    // In-flight streaming completions, keyed by stream id.
    // Mutex provides interior mutability under shared access.
    streams: Mutex<HashMap<String, StreamState>>,
    // Source of fresh stream ids; incremented per new stream.
    next_stream_id: Mutex<u64>,
}

/// Bookkeeping for a single in-flight streaming completion.
struct StreamState {
    // HTTP response body stream; `None` presumably once consumed or failed —
    // confirm against the consumer code (outside this chunk).
    response_stream: Option<HttpResponseStream>,
    // Partial, not-yet-parsed response data carried across reads.
    buffer: String,
    // Whether stream processing has begun (exact semantics defined where read).
    started: bool,
    // Stop reason reported by the model, once known.
    stop_reason: Option<LlmStopReason>,
    // Set when the model requested a tool/function call.
    wants_tool_use: bool,
}
/// Compile-time description of one Gemini model exposed by this provider.
struct ModelDefinition {
    /// Model id sent on the wire to the Google API.
    real_id: &'static str,
    /// Human-readable name; also the id used on the Zed side.
    display_name: &'static str,
    max_tokens: u64,
    max_output_tokens: Option<u64>,
    supports_images: bool,
    supports_thinking: bool,
    is_default: bool,
    is_default_fast: bool,
}

/// Static catalog of the models this extension offers.
const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gemini-2.5-flash-lite",
        display_name: "Gemini 2.5 Flash-Lite",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gemini-2.5-flash",
        display_name: "Gemini 2.5 Flash",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-2.5-pro",
        display_name: "Gemini 2.5 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gemini-3-pro-preview",
        display_name: "Gemini 3 Pro",
        max_tokens: 1_048_576,
        max_output_tokens: Some(65_536),
        supports_images: true,
        supports_thinking: true,
        is_default: false,
        is_default_fast: false,
    },
];

/// Finds the catalog entry whose display name matches, if any.
fn find_model_by_display_name(display_name: &str) -> Option<&'static ModelDefinition> {
    MODELS.iter().find(|model| model.display_name == display_name)
}

/// Maps a display name to the wire-format model id, or `None` if unknown.
fn get_real_model_id(display_name: &str) -> Option<&'static str> {
    find_model_by_display_name(display_name).map(|model| model.real_id)
}

/// Reports whether the named model supports thinking; unknown names yield `false`.
fn get_model_supports_thinking(display_name: &str) -> bool {
    find_model_by_display_name(display_name)
        .map(|model| model.supports_thinking)
        .unwrap_or(false)
}
/// Adapts a JSON schema to be compatible with Google's API subset.
/// Google only supports a specific subset of JSON Schema fields.
/// See: https://ai.google.dev/api/caching#Schema
fn adapt_schema_for_google(json: &mut serde_json::Value) {
    adapt_schema_for_google_impl(json, true);
}

/// Recursive worker for `adapt_schema_for_google`: rewrites `json` in place.
///
/// `is_schema` gates stripping of unsupported keys from the current object.
/// NOTE(review): every call site visible in this chunk passes `true` (the
/// `properties` map itself is never passed here, only its value schemas) —
/// confirm whether any `false` caller exists elsewhere in the file before
/// removing the parameter.
fn adapt_schema_for_google_impl(json: &mut serde_json::Value, is_schema: bool) {
    if let serde_json::Value::Object(obj) = json {
        // Google's Schema only supports these fields:
        // type, format, title, description, nullable, enum, maxItems, minItems,
        // properties, required, minProperties, maxProperties, minLength, maxLength,
        // pattern, example, anyOf, propertyOrdering, default, items, minimum, maximum
        const ALLOWED_KEYS: &[&str] = &[
            "type",
            "format",
            "title",
            "description",
            "nullable",
            "enum",
            "maxItems",
            "minItems",
            "properties",
            "required",
            "minProperties",
            "maxProperties",
            "minLength",
            "maxLength",
            "pattern",
            "example",
            "anyOf",
            "propertyOrdering",
            "default",
            "items",
            "minimum",
            "maximum",
        ];
        // Convert oneOf to anyOf before filtering keys
        // (Google supports anyOf but not oneOf).
        if let Some(one_of) = obj.remove("oneOf") {
            obj.insert("anyOf".to_string(), one_of);
        }
        // If type is an array (e.g., ["string", "null"]), take just the first type
        if let Some(type_field) = obj.get_mut("type") {
            if let serde_json::Value::Array(types) = type_field {
                if let Some(first_type) = types.first().cloned() {
                    *type_field = first_type;
                }
            }
        }
        // Only filter keys if this is a schema object, not a properties map
        if is_schema {
            obj.retain(|key, _| ALLOWED_KEYS.contains(&key.as_str()));
        }
        // Recursively process nested values
        // "properties" contains a map of property names -> schemas
        // "items" and "anyOf" contain schemas directly
        for (key, value) in obj.iter_mut() {
            if key == "properties" {
                // properties is a map of property_name -> schema
                if let serde_json::Value::Object(props) = value {
                    for (_, prop_schema) in props.iter_mut() {
                        adapt_schema_for_google_impl(prop_schema, true);
                    }
                }
            } else if key == "items" {
                // items is a schema
                adapt_schema_for_google_impl(value, true);
            } else if key == "anyOf" {
                // anyOf is an array of schemas
                if let serde_json::Value::Array(arr) = value {
                    for item in arr.iter_mut() {
                        adapt_schema_for_google_impl(item, true);
                    }
                }
            }
        }
    } else if let serde_json::Value::Array(arr) = json {
        // A bare array at this position is treated as a list of schemas.
        for item in arr.iter_mut() {
            adapt_schema_for_google_impl(item, true);
        }
    }
}
/// Top-level request payload serialized to the Google AI API.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleRequest {
    /// Conversation turns (user and model), in order.
    contents: Vec<GoogleContent>,
    /// System prompt, carried separately from `contents`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<GoogleSystemInstruction>,
    #[serde(skip_serializing_if = "Option::is_none")]
    generation_config: Option<GoogleGenerationConfig>,
    /// Function declarations offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<GoogleTool>>,
    /// Controls how/whether the model may call the declared functions.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_config: Option<GoogleToolConfig>,
}

/// Wrapper holding the system prompt's parts.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleSystemInstruction {
    parts: Vec<GooglePart>,
}
/// One conversation turn: a list of parts plus an optional role
/// ("user" or "model" at the visible construction sites).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleContent {
    parts: Vec<GooglePart>,
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
}

/// A single content part. `untagged`: serde picks the variant by which fields
/// are present, so variant order matters for deserialization.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(untagged)]
enum GooglePart {
    Text(GoogleTextPart),
    InlineData(GoogleInlineDataPart),
    FunctionCall(GoogleFunctionCallPart),
    FunctionResponse(GoogleFunctionResponsePart),
    Thought(GoogleThoughtPart),
}

/// Plain text part.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleTextPart {
    text: String,
}

/// Binary payload part (e.g. an image) embedded inline.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleInlineDataPart {
    inline_data: GoogleBlob,
}

/// Inline binary data with its MIME type; `data` is the encoded payload
/// (taken verbatim from the message's image source).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleBlob {
    mime_type: String,
    data: String,
}

/// Model-issued function call, optionally paired with a thought signature.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallPart {
    function_call: GoogleFunctionCall,
    #[serde(skip_serializing_if = "Option::is_none")]
    thought_signature: Option<String>,
}

/// Name and JSON arguments of a function call.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCall {
    name: String,
    args: serde_json::Value,
}

/// Tool result sent back to the model.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponsePart {
    function_response: GoogleFunctionResponse,
}

/// Name of the responding function plus its JSON response value.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionResponse {
    name: String,
    response: serde_json::Value,
}

/// Thought marker carrying an opaque signature (echoed back from
/// assistant-message thinking content that has a signature).
#[derive(Serialize, Deserialize, Debug, Clone)]
#[serde(rename_all = "camelCase")]
struct GoogleThoughtPart {
    thought: bool,
    thought_signature: String,
}
/// Sampling / output controls; all fields optional and omitted when unset.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleGenerationConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    candidate_count: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    stop_sequences: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_output_tokens: Option<usize>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    thinking_config: Option<GoogleThinkingConfig>,
}

/// Token budget allotted to the model's thinking phase.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleThinkingConfig {
    thinking_budget: u32,
}

/// A group of function declarations exposed to the model.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleTool {
    function_declarations: Vec<GoogleFunctionDeclaration>,
}

/// One callable function: name, description, and a JSON-schema `parameters`
/// object (pre-adapted to Google's supported schema subset).
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionDeclaration {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

/// Wrapper for the function-calling configuration.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleToolConfig {
    function_calling_config: GoogleFunctionCallingConfig,
}

/// Function-calling mode plus an optional allow-list of callable names.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct GoogleFunctionCallingConfig {
    mode: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    allowed_function_names: Option<Vec<String>>,
}
/// One chunk of a streamed generate-content response.
/// `#[serde(default)]` tolerates chunks that omit either field.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleStreamResponse {
    #[serde(default)]
    candidates: Vec<GoogleCandidate>,
    #[serde(default)]
    usage_metadata: Option<GoogleUsageMetadata>,
}

/// A single candidate completion: incremental content plus, on the final
/// chunk, a finish reason string.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleCandidate {
    #[serde(default)]
    content: Option<GoogleContent>,
    #[serde(default)]
    finish_reason: Option<String>,
}

/// Token accounting reported by the API; missing counts default to 0.
#[derive(Deserialize, Debug)]
#[serde(rename_all = "camelCase")]
struct GoogleUsageMetadata {
    #[serde(default)]
    prompt_token_count: u64,
    #[serde(default)]
    candidates_token_count: u64,
}
/// Converts a provider-agnostic `LlmCompletionRequest` into a Google
/// `generateContent` request body, returning it together with the real
/// (API-facing) model id resolved from `model_id`.
///
/// Returns `Err` when `model_id` does not resolve via `get_real_model_id`.
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<(GoogleRequest, String), String> {
    let real_model_id =
        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;
    let supports_thinking = get_model_supports_thinking(model_id);
    let mut contents: Vec<GoogleContent> = Vec::new();
    // System messages are collected separately: Google takes them via the
    // top-level `systemInstruction` field, not as entries in `contents`.
    let mut system_parts: Vec<GooglePart> = Vec::new();
    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                // Only non-empty text content contributes to the system instruction.
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text.is_empty() {
                            system_parts
                                .push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                        }
                    }
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<GooglePart> = Vec::new();
                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            // NOTE(review): mime type is hard-coded — assumes every
                            // image source is PNG; confirm against how `img.source`
                            // is produced upstream.
                            parts.push(GooglePart::InlineData(GoogleInlineDataPart {
                                inline_data: GoogleBlob {
                                    mime_type: "image/png".to_string(),
                                    data: img.source.clone(),
                                },
                            }));
                        }
                        LlmMessageContent::ToolResult(result) => {
                            // Google expects function responses as JSON objects, so
                            // wrap plain text (and an image placeholder) in `output`.
                            let response_value = match &result.content {
                                LlmToolResultContent::Text(t) => {
                                    serde_json::json!({ "output": t })
                                }
                                LlmToolResultContent::Image(_) => {
                                    serde_json::json!({ "output": "Tool responded with an image" })
                                }
                            };
                            parts.push(GooglePart::FunctionResponse(GoogleFunctionResponsePart {
                                function_response: GoogleFunctionResponse {
                                    name: result.tool_name.clone(),
                                    response: response_value,
                                },
                            }));
                        }
                        // Other content kinds carry nothing Google accepts in a
                        // user turn; drop them.
                        _ => {}
                    }
                }
                // Skip messages that produced no parts — Google rejects empty turns.
                if !parts.is_empty() {
                    contents.push(GoogleContent {
                        parts,
                        role: Some("user".to_string()),
                    });
                }
            }
            LlmMessageRole::Assistant => {
                let mut parts: Vec<GooglePart> = Vec::new();
                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(GooglePart::Text(GoogleTextPart { text: text.clone() }));
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            // Preserve the model's thought signature (if any, and
                            // non-empty) so it round-trips back with the call.
                            let thought_signature =
                                tool_use.thought_signature.clone().filter(|s| !s.is_empty());
                            // Best-effort parse: malformed tool input degrades to
                            // JSON null rather than failing the whole request.
                            let args: serde_json::Value =
                                serde_json::from_str(&tool_use.input).unwrap_or_default();
                            parts.push(GooglePart::FunctionCall(GoogleFunctionCallPart {
                                function_call: GoogleFunctionCall {
                                    name: tool_use.name.clone(),
                                    args,
                                },
                                thought_signature,
                            }));
                        }
                        LlmMessageContent::Thinking(thinking) => {
                            // Only the signature is re-sent; the thought text
                            // itself is not transmitted back to the API.
                            if let Some(ref signature) = thinking.signature {
                                if !signature.is_empty() {
                                    parts.push(GooglePart::Thought(GoogleThoughtPart {
                                        thought: true,
                                        thought_signature: signature.clone(),
                                    }));
                                }
                            }
                        }
                        _ => {}
                    }
                }
                if !parts.is_empty() {
                    contents.push(GoogleContent {
                        parts,
                        // Google's name for the assistant role is "model".
                        role: Some("model".to_string()),
                    });
                }
            }
        }
    }
    let system_instruction = if system_parts.is_empty() {
        None
    } else {
        Some(GoogleSystemInstruction {
            parts: system_parts,
        })
    };
    // All tool declarations are grouped under a single `GoogleTool` entry, with
    // each input schema adapted to the subset of JSON Schema Google accepts.
    let tools: Option<Vec<GoogleTool>> = if request.tools.is_empty() {
        None
    } else {
        let declarations: Vec<GoogleFunctionDeclaration> = request
            .tools
            .iter()
            .map(|t| {
                // An unparsable schema degrades to an empty object schema.
                let mut parameters: serde_json::Value = serde_json::from_str(&t.input_schema)
                    .unwrap_or(serde_json::Value::Object(Default::default()));
                adapt_schema_for_google(&mut parameters);
                GoogleFunctionDeclaration {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    parameters,
                }
            })
            .collect();
        Some(vec![GoogleTool {
            function_declarations: declarations,
        }])
    };
    let tool_config = request.tool_choice.as_ref().map(|tc| {
        let mode = match tc {
            LlmToolChoice::Auto => "AUTO",
            LlmToolChoice::Any => "ANY",
            LlmToolChoice::None => "NONE",
        };
        GoogleToolConfig {
            function_calling_config: GoogleFunctionCallingConfig {
                mode: mode.to_string(),
                allowed_function_names: None,
            },
        }
    });
    // Thinking is enabled only when both the model supports it and the caller
    // allows it; the budget is a fixed 8192 tokens.
    let thinking_config = if supports_thinking && request.thinking_allowed {
        Some(GoogleThinkingConfig {
            thinking_budget: 8192,
        })
    } else {
        None
    };
    let generation_config = Some(GoogleGenerationConfig {
        candidate_count: Some(1),
        stop_sequences: if request.stop_sequences.is_empty() {
            None
        } else {
            Some(request.stop_sequences.clone())
        },
        max_output_tokens: None,
        // Default temperature is 1.0 when the request leaves it unset.
        temperature: request.temperature.map(|t| t as f64).or(Some(1.0)),
        thinking_config,
    });
    Ok((
        GoogleRequest {
            contents,
            system_instruction,
            generation_config,
            tools,
            tool_config,
        },
        real_model_id.to_string(),
    ))
}
/// Attempts to decode one line of the streaming response body into a
/// `GoogleStreamResponse`.
///
/// Tolerates both SSE framing (`data: {...}`) and JSON-array framing
/// (`[`, `]`, and `,` separator lines); any line carrying no JSON payload,
/// or whose payload fails to deserialize, yields `None`.
fn parse_stream_line(line: &str) -> Option<GoogleStreamResponse> {
    let stripped = line.trim();
    // Blank lines and bare JSON-array framing tokens carry no payload.
    if matches!(stripped, "" | "[" | "]" | ",") {
        return None;
    }
    // Remove an SSE "data: " prefix if present, then any leading separator
    // commas and surrounding whitespace.
    let payload = stripped
        .strip_prefix("data: ")
        .unwrap_or(stripped)
        .trim_start_matches(',')
        .trim();
    if payload.is_empty() {
        None
    } else {
        serde_json::from_str(payload).ok()
    }
}
impl zed::Extension for GoogleAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    /// Advertises the single "google-ai" provider implemented by this extension.
    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "google-ai".into(),
            name: "Google AI".into(),
            icon: Some("icons/google-ai.svg".into()),
        }]
    }

    /// Lists the statically-known models from `MODELS` with their capabilities.
    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                // NOTE(review): the display name doubles as the model id here;
                // `convert_request` resolves it via `get_real_model_id`.
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: m.supports_thinking,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    /// Authenticated iff an API key is stored under the "google-ai" credential.
    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("google-ai").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "To use Google AI, you need an API key. You can create one [here](https://aistudio.google.com/apikey).".to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("google-ai")
    }

    /// Starts a streaming completion: builds the Google request, issues the
    /// HTTP call, and registers a new `StreamState` under a fresh stream id.
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("google-ai").ok_or_else(|| {
            "No API key configured. Please add your Google AI API key in settings.".to_string()
        })?;
        let (google_request, real_model_id) = convert_request(model_id, request)?;
        let body = serde_json::to_vec(&google_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;
        // `alt=sse` requests server-sent-event framing; the key is passed as a
        // query parameter per Google's REST API.
        let url = format!(
            "https://generativelanguage.googleapis.com/v1beta/models/{}:streamGenerateContent?alt=sse&key={}",
            real_model_id, api_key
        );
        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url,
            headers: vec![("Content-Type".to_string(), "application/json".to_string())],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };
        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;
        // Allocate a process-unique stream id from the shared counter.
        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("google-ai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };
        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                stop_reason: None,
                wants_tool_use: false,
            },
        );
        Ok(stream_id)
    }

    /// Pumps the stream: returns the next completion event, or an error once
    /// the underlying HTTP stream fails or ends without a finish reason.
    ///
    /// Emits `Started` first, then `Text` / `ToolUse` / `Thinking` / `Usage`
    /// events as chunks arrive, and finally `Stop` when the stream closes
    /// after a finish reason was seen.
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }
        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;
        loop {
            // Drain complete lines already buffered before reading more bytes.
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();
                if let Some(response) = parse_stream_line(&line) {
                    for candidate in response.candidates {
                        if let Some(finish_reason) = &candidate.finish_reason {
                            // Remember the stop reason; it is emitted as the
                            // final event once the HTTP stream ends.
                            state.stop_reason = Some(match finish_reason.as_str() {
                                "STOP" => {
                                    if state.wants_tool_use {
                                        LlmStopReason::ToolUse
                                    } else {
                                        LlmStopReason::EndTurn
                                    }
                                }
                                "MAX_TOKENS" => LlmStopReason::MaxTokens,
                                "SAFETY" => LlmStopReason::Refusal,
                                _ => LlmStopReason::EndTurn,
                            });
                        }
                        if let Some(content) = candidate.content {
                            for part in content.parts {
                                match part {
                                    GooglePart::Text(text_part) => {
                                        if !text_part.text.is_empty() {
                                            return Ok(Some(LlmCompletionEvent::Text(
                                                text_part.text,
                                            )));
                                        }
                                    }
                                    GooglePart::FunctionCall(fc_part) => {
                                        state.wants_tool_use = true;
                                        // Google does not supply call ids, so
                                        // synthesize one from a global counter.
                                        let next_tool_id =
                                            TOOL_CALL_COUNTER.fetch_add(1, Ordering::SeqCst);
                                        let id = format!(
                                            "{}-{}",
                                            fc_part.function_call.name, next_tool_id
                                        );
                                        let thought_signature =
                                            fc_part.thought_signature.filter(|s| !s.is_empty());
                                        return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                            id,
                                            name: fc_part.function_call.name,
                                            input: fc_part.function_call.args.to_string(),
                                            thought_signature,
                                        })));
                                    }
                                    GooglePart::Thought(thought_part) => {
                                        // Thought content is opaque; only the
                                        // signature is surfaced.
                                        return Ok(Some(LlmCompletionEvent::Thinking(
                                            LlmThinkingContent {
                                                text: "(Encrypted thought)".to_string(),
                                                signature: Some(thought_part.thought_signature),
                                            },
                                        )));
                                    }
                                    _ => {}
                                }
                            }
                        }
                    }
                    if let Some(usage) = response.usage_metadata {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_token_count,
                            output_tokens: usage.candidates_token_count,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }
                continue;
            }
            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // Stream ended - check if we have a stop reason
                    if let Some(stop_reason) = state.stop_reason.take() {
                        return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                    }
                    // No stop reason - this is unexpected. Check if buffer contains error info
                    let mut error_msg = String::from("Stream ended unexpectedly.");
                    // Try to parse remaining buffer as potential error response
                    if !state.buffer.is_empty() {
                        // Truncate on a char boundary: slicing at a raw byte
                        // offset panics when byte 1000 falls inside a
                        // multi-byte UTF-8 sequence.
                        let mut end = state.buffer.len().min(1000);
                        while !state.buffer.is_char_boundary(end) {
                            end -= 1;
                        }
                        error_msg.push_str(&format!(
                            "\nRemaining buffer: {}",
                            &state.buffer[..end]
                        ));
                    }
                    return Err(error_msg);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    /// Drops the per-stream state (and thereby the HTTP response stream).
    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}
zed::register_extension!(GoogleAiProvider);

823
extensions/openai/Cargo.lock generated Normal file
View File

@@ -0,0 +1,823 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "adler2"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
[[package]]
name = "anyhow"
version = "1.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
[[package]]
name = "auditable-serde"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c7bf8143dfc3c0258df908843e169b5cc5fcf76c7718bd66135ef4a9cd558c5"
dependencies = [
"semver",
"serde",
"serde_json",
"topological-sort",
]
[[package]]
name = "bitflags"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
[[package]]
name = "cfg-if"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "crc32fast"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511"
dependencies = [
"cfg-if",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "equivalent"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "flate2"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "foldhash"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "form_urlencoded"
version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf"
dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
name = "futures-core"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7"
[[package]]
name = "futures-task"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-util"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
]
[[package]]
name = "hashbrown"
version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"foldhash",
]
[[package]]
name = "hashbrown"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "icu_collections"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43"
dependencies = [
"displaydoc",
"potential_utf",
"yoke",
"zerofrom",
"zerovec",
]
[[package]]
name = "icu_locale_core"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6"
dependencies = [
"displaydoc",
"litemap",
"tinystr",
"writeable",
"zerovec",
]
[[package]]
name = "icu_normalizer"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599"
dependencies = [
"icu_collections",
"icu_normalizer_data",
"icu_properties",
"icu_provider",
"smallvec",
"zerovec",
]
[[package]]
name = "icu_normalizer_data"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a"
[[package]]
name = "icu_properties"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99"
dependencies = [
"icu_collections",
"icu_locale_core",
"icu_properties_data",
"icu_provider",
"zerotrie",
"zerovec",
]
[[package]]
name = "icu_properties_data"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899"
[[package]]
name = "icu_provider"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614"
dependencies = [
"displaydoc",
"icu_locale_core",
"writeable",
"yoke",
"zerofrom",
"zerotrie",
"zerovec",
]
[[package]]
name = "id-arena"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005"
[[package]]
name = "idna"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de"
dependencies = [
"idna_adapter",
"smallvec",
"utf8_iter",
]
[[package]]
name = "idna_adapter"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344"
dependencies = [
"icu_normalizer",
"icu_properties",
]
[[package]]
name = "indexmap"
version = "2.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2"
dependencies = [
"equivalent",
"hashbrown 0.16.1",
"serde",
"serde_core",
]
[[package]]
name = "itoa"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "leb128fmt"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2"
[[package]]
name = "litemap"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
[[package]]
name = "log"
version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]]
name = "memchr"
version = "2.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
[[package]]
name = "miniz_oxide"
version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316"
dependencies = [
"adler2",
"simd-adler32",
]
[[package]]
name = "once_cell"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "openai"
version = "0.1.0"
dependencies = [
"serde",
"serde_json",
"zed_extension_api",
]
[[package]]
name = "percent-encoding"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220"
[[package]]
name = "pin-project-lite"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "potential_utf"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77"
dependencies = [
"zerovec",
]
[[package]]
name = "prettyplease"
version = "0.2.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
dependencies = [
"proc-macro2",
"syn",
]
[[package]]
name = "proc-macro2"
version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
dependencies = [
"proc-macro2",
]
[[package]]
name = "ryu"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "semver"
version = "1.0.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2"
dependencies = [
"serde",
"serde_core",
]
[[package]]
name = "serde"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
dependencies = [
"serde_core",
"serde_derive",
]
[[package]]
name = "serde_core"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.145"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
"serde_core",
]
[[package]]
name = "simd-adler32"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
[[package]]
name = "slab"
version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589"
[[package]]
name = "smallvec"
version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
[[package]]
name = "spdx"
version = "0.10.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3e17e880bafaeb362a7b751ec46bdc5b61445a188f80e0606e68167cd540fa3"
dependencies = [
"smallvec",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
[[package]]
name = "syn"
version = "2.0.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "synstructure"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tinystr"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869"
dependencies = [
"displaydoc",
"zerovec",
]
[[package]]
name = "topological-sort"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea68304e134ecd095ac6c3574494fc62b909f416c4fca77e440530221e549d3d"
[[package]]
name = "unicode-ident"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "url"
version = "2.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b"
dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
"serde",
]
[[package]]
name = "utf8_iter"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "wasm-encoder"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80bb72f02e7fbf07183443b27b0f3d4144abf8c114189f2e088ed95b696a7822"
dependencies = [
"leb128fmt",
"wasmparser",
]
[[package]]
name = "wasm-metadata"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce1ef0faabbbba6674e97a56bee857ccddf942785a336c8b47b42373c922a91d"
dependencies = [
"anyhow",
"auditable-serde",
"flate2",
"indexmap",
"serde",
"serde_derive",
"serde_json",
"spdx",
"url",
"wasm-encoder",
"wasmparser",
]
[[package]]
name = "wasmparser"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f51cad774fb3c9461ab9bccc9c62dfb7388397b5deda31bf40e8108ccd678b2"
dependencies = [
"bitflags",
"hashbrown 0.15.5",
"indexmap",
"semver",
]
[[package]]
name = "wit-bindgen"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10fb6648689b3929d56bbc7eb1acf70c9a42a29eb5358c67c10f54dbd5d695de"
dependencies = [
"wit-bindgen-rt",
"wit-bindgen-rust-macro",
]
[[package]]
name = "wit-bindgen-core"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92fa781d4f2ff6d3f27f3cc9b74a73327b31ca0dc4a3ef25a0ce2983e0e5af9b"
dependencies = [
"anyhow",
"heck",
"wit-parser",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4db52a11d4dfb0a59f194c064055794ee6564eb1ced88c25da2cf76e50c5621"
dependencies = [
"bitflags",
"futures",
"once_cell",
]
[[package]]
name = "wit-bindgen-rust"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d0809dc5ba19e2e98661bf32fc0addc5a3ca5bf3a6a7083aa6ba484085ff3ce"
dependencies = [
"anyhow",
"heck",
"indexmap",
"prettyplease",
"syn",
"wasm-metadata",
"wit-bindgen-core",
"wit-component",
]
[[package]]
name = "wit-bindgen-rust-macro"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad19eec017904e04c60719592a803ee5da76cb51c81e3f6fbf9457f59db49799"
dependencies = [
"anyhow",
"prettyplease",
"proc-macro2",
"quote",
"syn",
"wit-bindgen-core",
"wit-bindgen-rust",
]
[[package]]
name = "wit-component"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "635c3adc595422cbf2341a17fb73a319669cc8d33deed3a48368a841df86b676"
dependencies = [
"anyhow",
"bitflags",
"indexmap",
"log",
"serde",
"serde_derive",
"serde_json",
"wasm-encoder",
"wasm-metadata",
"wasmparser",
"wit-parser",
]
[[package]]
name = "wit-parser"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddf445ed5157046e4baf56f9138c124a0824d4d1657e7204d71886ad8ce2fc11"
dependencies = [
"anyhow",
"id-arena",
"indexmap",
"log",
"semver",
"serde",
"serde_derive",
"serde_json",
"unicode-xid",
"wasmparser",
]
[[package]]
name = "writeable"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
[[package]]
name = "yoke"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954"
dependencies = [
"stable_deref_trait",
"yoke-derive",
"zerofrom",
]
[[package]]
name = "yoke-derive"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]]
name = "zed_extension_api"
version = "0.8.0"
dependencies = [
"serde",
"serde_json",
"wit-bindgen",
]
[[package]]
name = "zerofrom"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5"
dependencies = [
"zerofrom-derive",
]
[[package]]
name = "zerofrom-derive"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]]
name = "zerotrie"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851"
dependencies = [
"displaydoc",
"yoke",
"zerofrom",
]
[[package]]
name = "zerovec"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002"
dependencies = [
"yoke",
"zerofrom",
"zerovec-derive",
]
[[package]]
name = "zerovec-derive"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]

View File

@@ -0,0 +1,17 @@
[package]
name = "openai"
version = "0.1.0"
edition = "2021"
publish = false
license = "Apache-2.0"
[workspace]
[lib]
path = "src/openai.rs"
crate-type = ["cdylib"]
[dependencies]
zed_extension_api = { path = "../../crates/extension_api" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

View File

@@ -0,0 +1,13 @@
id = "openai"
name = "OpenAI"
description = "OpenAI GPT LLM provider for Zed."
version = "0.1.0"
schema_version = 1
authors = ["Zed Team"]
repository = "https://github.com/zed-industries/zed"
[language_model_providers.openai]
name = "OpenAI"
[language_model_providers.openai.auth]
env_var = "OPENAI_API_KEY"

View File

@@ -0,0 +1,3 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M14.5768 6.73011C14.8987 5.77678 14.7879 4.73245 14.2731 3.86531C13.4989 2.53528 11.9427 1.85102 10.4227 2.17303C9.74656 1.42139 8.7751 0.993944 7.75664 1.00006C6.20301 0.996569 4.82452 1.98358 4.34655 3.44224C3.34849 3.64393 2.48699 4.26038 1.98286 5.13408C1.20294 6.46061 1.38074 8.13277 2.4227 9.27029C2.1008 10.2236 2.21164 11.268 2.72642 12.1351C3.50057 13.4651 5.05686 14.1494 6.57679 13.8274C7.25251 14.579 8.22441 15.0064 9.24287 14.9999C10.7974 15.0038 12.1763 14.0159 12.6543 12.556C13.6524 12.3543 14.5139 11.7379 15.018 10.8641C15.797 9.5376 15.6188 7.86676 14.5773 6.72924L14.5768 6.73011ZM9.24376 14.0851C8.62169 14.0859 8.01912 13.8711 7.5416 13.4778C7.56332 13.4664 7.60101 13.4459 7.6254 13.431L10.4507 11.821C10.5952 11.7401 10.6839 11.5882 10.683 11.4242V7.49401L11.877 8.17433C11.8899 8.18045 11.8983 8.1927 11.9001 8.2067V11.4614C11.8983 12.9086 10.7105 14.082 9.24376 14.0851ZM3.53116 11.6775C3.21946 11.1464 3.10729 10.5237 3.21414 9.91955C3.23498 9.9318 3.27178 9.95411 3.29794 9.96898L6.1232 11.5791C6.26642 11.6617 6.44377 11.6617 6.58743 11.5791L10.0365 9.61373V10.9744C10.0374 10.9884 10.0308 11.002 10.0197 11.0107L7.16383 12.6378C5.89175 13.3606 4.26674 12.9309 3.53116 11.6775ZM2.7876 5.59215C3.09797 5.06014 3.58792 4.65326 4.17141 4.44195C4.17141 4.46601 4.17008 4.50845 4.17008 4.5382V7.75869C4.1692 7.92232 4.25787 8.07414 4.40198 8.15508L7.85108 10.1199L6.65704 10.8002C6.64507 10.8081 6.62999 10.8094 6.61669 10.8037L3.76039 9.17535C2.49098 8.44995 2.05601 6.84692 2.7876 5.59215ZM12.598 7.84488L9.14887 5.8796L10.3429 5.19971C10.3549 5.19183 10.37 5.19052 10.3833 5.19621L13.2396 6.8233C14.5112 7.54826 14.947 9.15347 14.2124 10.4082C13.9015 10.9394 13.412 11.3463 12.829 11.5581V8.24127C12.8303 8.07764 12.7417 7.92626 12.598 7.84488ZM13.7863 6.07998C13.7654 6.06729 13.7286 6.04541 13.7025 6.03054L10.8772 4.42051C10.734 4.33782 10.5566 4.33782 10.413 4.42051L6.96386 6.3858V5.02514C6.96298 5.01114 6.96963 4.99758 6.98071 4.98883L9.83657 
3.36305C11.1086 2.63898 12.735 3.06992 13.4683 4.32557C13.7783 4.85583 13.8914 5.47665 13.7863 6.07998ZM6.31475 8.50509L5.12026 7.82476C5.1074 7.81863 5.09898 7.80638 5.09721 7.79238V4.53776C5.09809 3.08873 6.28947 1.91446 7.75797 1.91533C8.37916 1.91533 8.98039 2.13059 9.45792 2.52259C9.43619 2.53397 9.39894 2.55453 9.37412 2.56941L6.54885 4.17944C6.40431 4.26038 6.31563 4.41176 6.31652 4.57582L6.31475 8.50509ZM6.96342 7.12518L8.49976 6.24973L10.0361 7.12475V8.87521L8.49976 9.75023L6.96342 8.87521V7.12518Z" fill="black"/>
</svg>

After

Width:  |  Height:  |  Size: 2.6 KiB

View File

@@ -0,0 +1,680 @@
use std::collections::HashMap;
use std::sync::Mutex;
use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};
/// Extension state for the OpenAI LLM provider.
///
/// Fields are wrapped in `Mutex` so the state can be mutated through the
/// `&self` / `&mut self` callbacks the host invokes — presumably the host may
/// call in from more than one context (NOTE(review): confirm host threading).
struct OpenAiProvider {
    // Active completion streams, keyed by the id returned from
    // `llm_stream_completion_start`.
    streams: Mutex<HashMap<String, StreamState>>,
    // Monotonic counter used to mint unique stream ids.
    next_stream_id: Mutex<u64>,
}
/// Per-stream bookkeeping for one in-flight streaming completion.
struct StreamState {
    // Underlying HTTP response body; `None` would mean the stream was closed.
    response_stream: Option<HttpResponseStream>,
    // Bytes received but not yet consumed, decoded as UTF-8; complete
    // SSE lines are split off at '\n'.
    buffer: String,
    // Whether the initial `Started` event has been delivered to the host.
    started: bool,
    // Tool calls assembled from streaming deltas, keyed by the `index`
    // field OpenAI assigns to each parallel call.
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    // Guards against emitting the accumulated tool calls more than once.
    tool_calls_emitted: bool,
}
/// A tool call assembled incrementally from streaming deltas: `id` and
/// `name` arrive once, while `arguments` is concatenated fragment by
/// fragment until the call is complete.
#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    // JSON-encoded arguments, built up from streamed fragments.
    arguments: String,
}
/// Static description of one OpenAI model offered by this extension.
struct ModelDefinition {
    // Model id exactly as the OpenAI API expects it (e.g. "gpt-4o").
    real_id: &'static str,
    // Name shown in the UI. Also used as the model id reported to the host
    // (see `llm_provider_models`); `get_real_model_id` maps it back.
    display_name: &'static str,
    // Total context window, in tokens.
    max_tokens: u64,
    // Per-response output cap, if any.
    max_output_tokens: Option<u64>,
    // Whether image content parts may be sent to this model.
    supports_images: bool,
    // Marks the provider's default model (exactly one entry sets this).
    is_default: bool,
    // Marks the default "fast" model (exactly one entry sets this).
    is_default_fast: bool,
}

/// Fixed catalog of models this extension advertises. Token limits are
/// hard-coded; update this table when OpenAI's published limits change.
const MODELS: &[ModelDefinition] = &[
    ModelDefinition {
        real_id: "gpt-4o",
        display_name: "GPT-4o",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4o-mini",
        display_name: "GPT-4o-mini",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        real_id: "gpt-4.1",
        display_name: "GPT-4.1",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4.1-mini",
        display_name: "GPT-4.1-mini",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-4.1-nano",
        display_name: "GPT-4.1-nano",
        max_tokens: 1_047_576,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-5",
        display_name: "GPT-5",
        max_tokens: 272_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "gpt-5-mini",
        display_name: "GPT-5-mini",
        max_tokens: 272_000,
        max_output_tokens: Some(32_768),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o1",
        display_name: "o1",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o3",
        display_name: "o3",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o3-mini",
        display_name: "o3-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        real_id: "o4-mini",
        display_name: "o4-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        is_default: false,
        is_default_fast: false,
    },
];
/// Translate a UI-facing model name back to the id the OpenAI API expects.
///
/// The extension registers models under their display names (see
/// `llm_provider_models`), so the host hands a display name back to us.
/// Returns `None` when the name is not in the catalog.
fn get_real_model_id(display_name: &str) -> Option<&'static str> {
    for model in MODELS {
        if model.display_name == display_name {
            return Some(model.real_id);
        }
    }
    None
}
/// Request body for `POST /v1/chat/completions`.
#[derive(Serialize)]
struct OpenAiRequest {
    model: String,
    messages: Vec<OpenAiMessage>,
    // Optional fields are dropped from the JSON entirely when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<OpenAiTool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // Stop sequences; omitted from the JSON when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    stream: bool,
    stream_options: Option<StreamOptions>,
}

/// Streaming flags; `include_usage` requests a final usage-only chunk.
#[derive(Serialize)]
struct StreamOptions {
    include_usage: bool,
}

/// One chat message; serialized with a `"role"` tag per variant.
#[derive(Serialize)]
#[serde(tag = "role")]
enum OpenAiMessage {
    #[serde(rename = "system")]
    System { content: String },
    // User content is an array of parts so text and images can be mixed.
    #[serde(rename = "user")]
    User { content: Vec<OpenAiContentPart> },
    #[serde(rename = "assistant")]
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<OpenAiToolCall>>,
    },
    // Result of a prior tool call, matched up via `tool_call_id`.
    #[serde(rename = "tool")]
    Tool {
        tool_call_id: String,
        content: String,
    },
}

/// A single element of a user message's content array.
#[derive(Serialize)]
#[serde(tag = "type")]
enum OpenAiContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

/// Image reference; in this extension always a base64 `data:` URL
/// (see `convert_request`).
#[derive(Serialize)]
struct ImageUrl {
    url: String,
}

/// A fully-formed tool call attached to an assistant message.
#[derive(Serialize, Deserialize, Clone)]
struct OpenAiToolCall {
    id: String,
    // Always "function" in this extension.
    #[serde(rename = "type")]
    call_type: String,
    function: OpenAiFunctionCall,
}

/// Function name plus its JSON-encoded arguments string.
#[derive(Serialize, Deserialize, Clone)]
struct OpenAiFunctionCall {
    name: String,
    arguments: String,
}

/// A tool made available to the model.
#[derive(Serialize)]
struct OpenAiTool {
    // Always "function" in this extension.
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenAiFunctionDef,
}

/// Function schema: name, description, and JSON-Schema parameters object.
#[derive(Serialize)]
struct OpenAiFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
/// One decoded SSE payload from the streaming completions endpoint.
#[derive(Deserialize, Debug)]
struct OpenAiStreamEvent {
    choices: Vec<OpenAiStreamChoiceNeverRenamed>,
    // Present on the final chunk when `stream_options.include_usage` is
    // set; that chunk arrives with an empty `choices` array.
    #[serde(default)]
    usage: Option<OpenAiUsage>,
}
/// Translate a host-provided completion request into the OpenAI
/// Chat Completions wire format.
///
/// `model_id` is the *display name* the host knows the model by; it is
/// mapped back to the real API id via `get_real_model_id`, and an error
/// is returned for unknown names so the caller gets a clear message
/// instead of an opaque API failure.
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenAiRequest, String> {
    let real_model_id =
        get_real_model_id(model_id).ok_or_else(|| format!("Unknown model: {}", model_id))?;
    let mut messages = Vec::new();
    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                // The system role takes plain string content; join all text
                // parts into one message.
                let text: String = msg
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        LlmMessageContent::Text(t) => Some(t.as_str()),
                        _ => None,
                    })
                    .collect::<Vec<_>>()
                    .join("\n");
                if !text.is_empty() {
                    messages.push(OpenAiMessage::System { content: text });
                }
            }
            LlmMessageRole::User => {
                // Text and images become parts of a single user message;
                // tool results must instead be sent as separate `tool`-role
                // messages, so they are split out below.
                let parts: Vec<OpenAiContentPart> = msg
                    .content
                    .iter()
                    .filter_map(|c| match c {
                        LlmMessageContent::Text(t) => {
                            Some(OpenAiContentPart::Text { text: t.clone() })
                        }
                        LlmMessageContent::Image(img) => Some(OpenAiContentPart::ImageUrl {
                            image_url: ImageUrl {
                                url: format!("data:image/png;base64,{}", img.source),
                            },
                        }),
                        _ => None,
                    })
                    .collect();
                // Tool results are emitted before the user message itself,
                // preserving the ordering previously produced here.
                for content in &msg.content {
                    if let LlmMessageContent::ToolResult(result) = content {
                        let content_text = match &result.content {
                            LlmToolResultContent::Text(t) => t.clone(),
                            // OpenAI tool messages only carry text; stand in
                            // for image results with a placeholder.
                            LlmToolResultContent::Image(_) => "[Image]".to_string(),
                        };
                        messages.push(OpenAiMessage::Tool {
                            tool_call_id: result.tool_use_id.clone(),
                            content: content_text,
                        });
                    }
                }
                if !parts.is_empty() {
                    messages.push(OpenAiMessage::User { content: parts });
                }
            }
            LlmMessageRole::Assistant => {
                // Accumulate *all* text parts. The previous implementation
                // overwrote `content_text` for each part, silently dropping
                // everything but the last text segment; join with '\n' for
                // consistency with the System branch above.
                let mut content_text: Option<String> = None;
                let mut tool_calls: Vec<OpenAiToolCall> = Vec::new();
                for c in &msg.content {
                    match c {
                        LlmMessageContent::Text(t) => match content_text.as_mut() {
                            Some(existing) => {
                                existing.push('\n');
                                existing.push_str(t);
                            }
                            None => content_text = Some(t.clone()),
                        },
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenAiToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenAiFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }
                messages.push(OpenAiMessage::Assistant {
                    content: content_text,
                    tool_calls: if tool_calls.is_empty() {
                        None
                    } else {
                        Some(tool_calls)
                    },
                });
            }
        }
    }
    let tools: Option<Vec<OpenAiTool>> = if request.tools.is_empty() {
        None
    } else {
        Some(
            request
                .tools
                .iter()
                .map(|t| OpenAiTool {
                    tool_type: "function".to_string(),
                    function: OpenAiFunctionDef {
                        name: t.name.clone(),
                        description: t.description.clone(),
                        // Fall back to an empty schema object if the host
                        // hands us a schema that is not valid JSON.
                        parameters: serde_json::from_str(&t.input_schema)
                            .unwrap_or(serde_json::Value::Object(Default::default())),
                    },
                })
                .collect(),
        )
    };
    let tool_choice = request.tool_choice.as_ref().map(|tc| match tc {
        LlmToolChoice::Auto => "auto".to_string(),
        LlmToolChoice::Any => "required".to_string(),
        LlmToolChoice::None => "none".to_string(),
    });
    Ok(OpenAiRequest {
        model: real_model_id.to_string(),
        messages,
        tools,
        tool_choice,
        temperature: request.temperature,
        max_tokens: request.max_tokens,
        stop: request.stop_sequences.clone(),
        stream: true,
        stream_options: Some(StreamOptions {
            include_usage: true,
        }),
    })
}
/// Decode a single server-sent-events line into a stream event.
///
/// Yields `None` for lines without the `data: ` prefix, for the `[DONE]`
/// sentinel, and for payloads that fail to deserialize — callers simply
/// skip such lines.
fn parse_sse_line(line: &str) -> Option<OpenAiStreamEvent> {
    let payload = line.strip_prefix("data: ")?;
    if payload == "[DONE]" {
        return None;
    }
    serde_json::from_str(payload).ok()
}
impl zed::Extension for OpenAiProvider {
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    /// Advertise the single "openai" provider to the host.
    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "openai".into(),
            name: "OpenAI".into(),
            icon: Some("icons/openai.svg".into()),
        }]
    }

    /// Report the static model catalog.
    ///
    /// Note the model `id` given to the host is the *display name*;
    /// `convert_request` maps it back to the real API id.
    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.display_name.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: true,
                    supports_tool_choice_auto: true,
                    supports_tool_choice_any: true,
                    supports_tool_choice_none: true,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    /// Authenticated iff a credential is stored under "openai".
    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("openai").is_some()
    }

    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "To use OpenAI, you need an API key. You can create one [here](https://platform.openai.com/api-keys).".to_string(),
        )
    }

    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("openai")
    }

    /// Begin a streaming completion: build the OpenAI request, open the
    /// HTTP stream, and register it under a freshly minted stream id.
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("openai").ok_or_else(|| {
            "No API key configured. Please add your OpenAI API key in settings.".to_string()
        })?;
        let openai_request = convert_request(model_id, request)?;
        let body = serde_json::to_vec(&openai_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;
        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://api.openai.com/v1/chat/completions".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("Authorization".to_string(), format!("Bearer {}", api_key)),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };
        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;
        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("openai-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };
        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                tool_calls_emitted: false,
            },
        );
        Ok(stream_id)
    }

    /// Pull the next completion event for `stream_id`.
    ///
    /// The first poll always yields `Started`. After that, complete SSE
    /// lines are drained from the buffer (refilled from the HTTP stream as
    /// needed) and translated into events. Accumulated tool calls are
    /// emitted one per poll, in ascending index order — the previous
    /// implementation emitted a single *arbitrary* HashMap entry and then
    /// marked emission done, dropping any parallel tool calls.
    /// `Ok(None)` signals the stream is exhausted.
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }
        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;
        loop {
            // Drain complete lines from the buffer before reading more bytes.
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].trim().to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();
                if line.is_empty() {
                    continue;
                }
                if let Some(event) = parse_sse_line(&line) {
                    if let Some(choice) = event.choices.first() {
                        // Merge tool-call fragments into the accumulator,
                        // keyed by the call's index.
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state.tool_calls.entry(tc.index).or_default();
                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }
                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }
                        if let Some(reason) = &choice.finish_reason {
                            if reason == "tool_calls" && !state.tool_calls_emitted {
                                // Emit the lowest-index pending call; the rest
                                // are drained on subsequent polls (via the
                                // stream-end branch below). Only flag emission
                                // done once the map is empty.
                                if let Some(&index) = state.tool_calls.keys().min() {
                                    if let Some(tool_call) = state.tool_calls.remove(&index) {
                                        if state.tool_calls.is_empty() {
                                            state.tool_calls_emitted = true;
                                        }
                                        return Ok(Some(LlmCompletionEvent::ToolUse(
                                            LlmToolUse {
                                                id: tool_call.id,
                                                name: tool_call.name,
                                                input: tool_call.arguments,
                                                thought_signature: None,
                                            },
                                        )));
                                    }
                                }
                            }
                            let stop_reason = match reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                // Unknown reasons are treated as a normal end
                                // of turn rather than an error.
                                _ => LlmStopReason::EndTurn,
                            };
                            // If usage rides on the finishing chunk, report it
                            // (the Stop event is superseded in that case, as
                            // in the original implementation).
                            if let Some(usage) = event.usage {
                                return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                    input_tokens: usage.prompt_tokens,
                                    output_tokens: usage.completion_tokens,
                                    cache_creation_input_tokens: None,
                                    cache_read_input_tokens: None,
                                })));
                            }
                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }
                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }
                    }
                    // With `stream_options.include_usage`, the final chunk
                    // arrives with empty `choices` and only usage totals.
                    if event.choices.is_empty() {
                        if let Some(usage) = event.usage {
                            return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                                input_tokens: usage.prompt_tokens,
                                output_tokens: usage.completion_tokens,
                                cache_creation_input_tokens: None,
                                cache_read_input_tokens: None,
                            })));
                        }
                    }
                }
                continue;
            }
            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    // NOTE(review): lossy decoding would mangle a multi-byte
                    // UTF-8 character split across chunk boundaries; matches
                    // the original behavior and SSE JSON is ASCII in practice.
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    // Stream exhausted: drain any tool calls that are still
                    // pending (either the "tool_calls" finish line was never
                    // seen, or not all parallel calls have been emitted yet),
                    // one per poll, in ascending index order.
                    if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
                        if let Some(&index) = state.tool_calls.keys().min() {
                            if let Some(tool_call) = state.tool_calls.remove(&index) {
                                if state.tool_calls.is_empty() {
                                    state.tool_calls_emitted = true;
                                }
                                return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                    id: tool_call.id,
                                    name: tool_call.name,
                                    input: tool_call.arguments,
                                    thought_signature: None,
                                })));
                            }
                        }
                    }
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    /// Drop all state for a finished or cancelled stream.
    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}
zed::register_extension!(OpenAiProvider);

823
extensions/openrouter/Cargo.lock generated Normal file
View File

@@ -0,0 +1,823 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 4
[[package]]
name = "adler2"
version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa"
[[package]]
name = "anyhow"
version = "1.0.100"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61"
[[package]]
name = "auditable-serde"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c7bf8143dfc3c0258df908843e169b5cc5fcf76c7718bd66135ef4a9cd558c5"
dependencies = [
"semver",
"serde",
"serde_json",
"topological-sort",
]
[[package]]
name = "bitflags"
version = "2.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "812e12b5285cc515a9c72a5c1d3b6d46a19dac5acfef5265968c166106e31dd3"
[[package]]
name = "cfg-if"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801"
[[package]]
name = "crc32fast"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9481c1c90cbf2ac953f07c8d4a58aa3945c425b7185c9154d67a65e4230da511"
dependencies = [
"cfg-if",
]
[[package]]
name = "displaydoc"
version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "equivalent"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f"
[[package]]
name = "flate2"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb"
dependencies = [
"crc32fast",
"miniz_oxide",
]
[[package]]
name = "foldhash"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2"
[[package]]
name = "form_urlencoded"
version = "1.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cb4cb245038516f5f85277875cdaa4f7d2c9a0fa0468de06ed190163b1581fcf"
dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
name = "futures-core"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e"
[[package]]
name = "futures-executor"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6"
[[package]]
name = "futures-macro"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7"
[[package]]
name = "futures-task"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f90f7dce0722e95104fcb095585910c0977252f286e354b5e3bd38902cd99988"
[[package]]
name = "futures-util"
version = "0.3.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
]
[[package]]
name = "hashbrown"
version = "0.15.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9229cfe53dfd69f0609a49f65461bd93001ea1ef889cd5529dd176593f5338a1"
dependencies = [
"foldhash",
]
[[package]]
name = "hashbrown"
version = "0.16.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100"
[[package]]
name = "heck"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea"
[[package]]
name = "icu_collections"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c6b649701667bbe825c3b7e6388cb521c23d88644678e83c0c4d0a621a34b43"
dependencies = [
"displaydoc",
"potential_utf",
"yoke",
"zerofrom",
"zerovec",
]
[[package]]
name = "icu_locale_core"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "edba7861004dd3714265b4db54a3c390e880ab658fec5f7db895fae2046b5bb6"
dependencies = [
"displaydoc",
"litemap",
"tinystr",
"writeable",
"zerovec",
]
[[package]]
name = "icu_normalizer"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f6c8828b67bf8908d82127b2054ea1b4427ff0230ee9141c54251934ab1b599"
dependencies = [
"icu_collections",
"icu_normalizer_data",
"icu_properties",
"icu_provider",
"smallvec",
"zerovec",
]
[[package]]
name = "icu_normalizer_data"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7aedcccd01fc5fe81e6b489c15b247b8b0690feb23304303a9e560f37efc560a"
[[package]]
name = "icu_properties"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e93fcd3157766c0c8da2f8cff6ce651a31f0810eaa1c51ec363ef790bbb5fb99"
dependencies = [
"icu_collections",
"icu_locale_core",
"icu_properties_data",
"icu_provider",
"zerotrie",
"zerovec",
]
[[package]]
name = "icu_properties_data"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02845b3647bb045f1100ecd6480ff52f34c35f82d9880e029d329c21d1054899"
[[package]]
name = "icu_provider"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "85962cf0ce02e1e0a629cc34e7ca3e373ce20dda4c4d7294bbd0bf1fdb59e614"
dependencies = [
"displaydoc",
"icu_locale_core",
"writeable",
"yoke",
"zerofrom",
"zerotrie",
"zerovec",
]
[[package]]
name = "id-arena"
version = "2.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25a2bc672d1148e28034f176e01fffebb08b35768468cc954630da77a1449005"
[[package]]
name = "idna"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b0875f23caa03898994f6ddc501886a45c7d3d62d04d2d90788d47be1b1e4de"
dependencies = [
"idna_adapter",
"smallvec",
"utf8_iter",
]
[[package]]
name = "idna_adapter"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3acae9609540aa318d1bc588455225fb2085b9ed0c4f6bd0d9d5bcd86f1a0344"
dependencies = [
"icu_normalizer",
"icu_properties",
]
[[package]]
name = "indexmap"
version = "2.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ad4bb2b565bca0645f4d68c5c9af97fba094e9791da685bf83cb5f3ce74acf2"
dependencies = [
"equivalent",
"hashbrown 0.16.1",
"serde",
"serde_core",
]
[[package]]
name = "itoa"
version = "1.0.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c"
[[package]]
name = "leb128fmt"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2"
[[package]]
name = "litemap"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77"
[[package]]
name = "log"
version = "0.4.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897"
[[package]]
name = "memchr"
version = "2.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273"
[[package]]
name = "miniz_oxide"
version = "0.8.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316"
dependencies = [
"adler2",
"simd-adler32",
]
[[package]]
name = "once_cell"
version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "openrouter"
version = "0.1.0"
dependencies = [
"serde",
"serde_json",
"zed_extension_api",
]
[[package]]
name = "percent-encoding"
version = "2.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220"
[[package]]
name = "pin-project-lite"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "potential_utf"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b73949432f5e2a09657003c25bca5e19a0e9c84f8058ca374f49e0ebe605af77"
dependencies = [
"zerovec",
]
[[package]]
name = "prettyplease"
version = "0.2.37"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "479ca8adacdd7ce8f1fb39ce9ecccbfe93a3f1344b3d0d97f20bc0196208f62b"
dependencies = [
"proc-macro2",
"syn",
]
[[package]]
name = "proc-macro2"
version = "1.0.103"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee95bc4ef87b8d5ba32e8b7714ccc834865276eab0aed5c9958d00ec45f49e8"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.42"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a338cc41d27e6cc6dce6cefc13a0729dfbb81c262b1f519331575dd80ef3067f"
dependencies = [
"proc-macro2",
]
[[package]]
name = "ryu"
version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "semver"
version = "1.0.27"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d767eb0aabc880b29956c35734170f26ed551a859dbd361d140cdbeca61ab1e2"
dependencies = [
"serde",
"serde_core",
]
[[package]]
name = "serde"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
dependencies = [
"serde_core",
"serde_derive",
]
[[package]]
name = "serde_core"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.145"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "402a6f66d8c709116cf22f558eab210f5a50187f702eb4d7e5ef38d9a7f1c79c"
dependencies = [
"itoa",
"memchr",
"ryu",
"serde",
"serde_core",
]
[[package]]
name = "simd-adler32"
version = "0.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d66dc143e6b11c1eddc06d5c423cfc97062865baf299914ab64caa38182078fe"
[[package]]
name = "slab"
version = "0.4.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589"
[[package]]
name = "smallvec"
version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67b1b7a3b5fe4f1376887184045fcf45c69e92af734b7aaddc05fb777b6fbd03"
[[package]]
name = "spdx"
version = "0.10.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c3e17e880bafaeb362a7b751ec46bdc5b61445a188f80e0606e68167cd540fa3"
dependencies = [
"smallvec",
]
[[package]]
name = "stable_deref_trait"
version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
[[package]]
name = "syn"
version = "2.0.111"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "390cc9a294ab71bdb1aa2e99d13be9c753cd2d7bd6560c77118597410c4d2e87"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "synstructure"
version = "0.13.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "tinystr"
version = "0.8.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42d3e9c45c09de15d06dd8acf5f4e0e399e85927b7f00711024eb7ae10fa4869"
dependencies = [
"displaydoc",
"zerovec",
]
[[package]]
name = "topological-sort"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea68304e134ecd095ac6c3574494fc62b909f416c4fca77e440530221e549d3d"
[[package]]
name = "unicode-ident"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5"
[[package]]
name = "unicode-xid"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853"
[[package]]
name = "url"
version = "2.5.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08bc136a29a3d1758e07a9cca267be308aeebf5cfd5a10f3f67ab2097683ef5b"
dependencies = [
"form_urlencoded",
"idna",
"percent-encoding",
"serde",
]
[[package]]
name = "utf8_iter"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be"
[[package]]
name = "wasm-encoder"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "80bb72f02e7fbf07183443b27b0f3d4144abf8c114189f2e088ed95b696a7822"
dependencies = [
"leb128fmt",
"wasmparser",
]
[[package]]
name = "wasm-metadata"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ce1ef0faabbbba6674e97a56bee857ccddf942785a336c8b47b42373c922a91d"
dependencies = [
"anyhow",
"auditable-serde",
"flate2",
"indexmap",
"serde",
"serde_derive",
"serde_json",
"spdx",
"url",
"wasm-encoder",
"wasmparser",
]
[[package]]
name = "wasmparser"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f51cad774fb3c9461ab9bccc9c62dfb7388397b5deda31bf40e8108ccd678b2"
dependencies = [
"bitflags",
"hashbrown 0.15.5",
"indexmap",
"semver",
]
[[package]]
name = "wit-bindgen"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "10fb6648689b3929d56bbc7eb1acf70c9a42a29eb5358c67c10f54dbd5d695de"
dependencies = [
"wit-bindgen-rt",
"wit-bindgen-rust-macro",
]
[[package]]
name = "wit-bindgen-core"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92fa781d4f2ff6d3f27f3cc9b74a73327b31ca0dc4a3ef25a0ce2983e0e5af9b"
dependencies = [
"anyhow",
"heck",
"wit-parser",
]
[[package]]
name = "wit-bindgen-rt"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c4db52a11d4dfb0a59f194c064055794ee6564eb1ced88c25da2cf76e50c5621"
dependencies = [
"bitflags",
"futures",
"once_cell",
]
[[package]]
name = "wit-bindgen-rust"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d0809dc5ba19e2e98661bf32fc0addc5a3ca5bf3a6a7083aa6ba484085ff3ce"
dependencies = [
"anyhow",
"heck",
"indexmap",
"prettyplease",
"syn",
"wasm-metadata",
"wit-bindgen-core",
"wit-component",
]
[[package]]
name = "wit-bindgen-rust-macro"
version = "0.41.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad19eec017904e04c60719592a803ee5da76cb51c81e3f6fbf9457f59db49799"
dependencies = [
"anyhow",
"prettyplease",
"proc-macro2",
"quote",
"syn",
"wit-bindgen-core",
"wit-bindgen-rust",
]
[[package]]
name = "wit-component"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "635c3adc595422cbf2341a17fb73a319669cc8d33deed3a48368a841df86b676"
dependencies = [
"anyhow",
"bitflags",
"indexmap",
"log",
"serde",
"serde_derive",
"serde_json",
"wasm-encoder",
"wasm-metadata",
"wasmparser",
"wit-parser",
]
[[package]]
name = "wit-parser"
version = "0.227.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ddf445ed5157046e4baf56f9138c124a0824d4d1657e7204d71886ad8ce2fc11"
dependencies = [
"anyhow",
"id-arena",
"indexmap",
"log",
"semver",
"serde",
"serde_derive",
"serde_json",
"unicode-xid",
"wasmparser",
]
[[package]]
name = "writeable"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9edde0db4769d2dc68579893f2306b26c6ecfbe0ef499b013d731b7b9247e0b9"
[[package]]
name = "yoke"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72d6e5c6afb84d73944e5cedb052c4680d5657337201555f9f2a16b7406d4954"
dependencies = [
"stable_deref_trait",
"yoke-derive",
"zerofrom",
]
[[package]]
name = "yoke-derive"
version = "0.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b659052874eb698efe5b9e8cf382204678a0086ebf46982b79d6ca3182927e5d"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]]
name = "zed_extension_api"
version = "0.8.0"
dependencies = [
"serde",
"serde_json",
"wit-bindgen",
]
[[package]]
name = "zerofrom"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "50cc42e0333e05660c3587f3bf9d0478688e15d870fab3346451ce7f8c9fbea5"
dependencies = [
"zerofrom-derive",
]
[[package]]
name = "zerofrom-derive"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502"
dependencies = [
"proc-macro2",
"quote",
"syn",
"synstructure",
]
[[package]]
name = "zerotrie"
version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2a59c17a5562d507e4b54960e8569ebee33bee890c70aa3fe7b97e85a9fd7851"
dependencies = [
"displaydoc",
"yoke",
"zerofrom",
]
[[package]]
name = "zerovec"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c28719294829477f525be0186d13efa9a3c602f7ec202ca9e353d310fb9a002"
dependencies = [
"yoke",
"zerofrom",
"zerovec-derive",
]
[[package]]
name = "zerovec-derive"
version = "0.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eadce39539ca5cb3985590102671f2567e659fca9666581ad3411d59207951f3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]

View File

@@ -0,0 +1,17 @@
[package]
name = "openrouter"
version = "0.1.0"
edition = "2021"
publish = false
license = "Apache-2.0"
[workspace]
[lib]
path = "src/open_router.rs"
crate-type = ["cdylib"]
[dependencies]
zed_extension_api = { path = "../../crates/extension_api" }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"

View File

@@ -0,0 +1,13 @@
id = "openrouter"
name = "OpenRouter"
description = "OpenRouter LLM provider - access multiple AI models through a unified API."
version = "0.1.0"
schema_version = 1
authors = ["Zed Team"]
repository = "https://github.com/zed-industries/zed"
[language_model_providers.openrouter]
name = "OpenRouter"
[language_model_providers.openrouter.auth]
env_var = "OPENROUTER_API_KEY"

View File

@@ -0,0 +1,8 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M2.54131 7.78012C2.89456 7.78012 4.25937 7.47507 4.96588 7.07512C5.67239 6.67517 5.67239 6.67517 7.13135 5.63951C8.97897 4.32817 10.2858 4.76729 12.4272 4.76729" fill="black"/>
<path d="M2.54131 7.78012C2.89456 7.78012 4.25937 7.47507 4.96588 7.07512C5.67239 6.67517 5.67239 6.67517 7.13135 5.63951C8.97897 4.32817 10.2858 4.76729 12.4272 4.76729" stroke="black" stroke-width="2.8125"/>
<path d="M14.4985 4.7801L10.8793 6.86949V2.6907L14.4985 4.7801Z" fill="black" stroke="black"/>
<path d="M2.47052 7.78088C2.82377 7.78088 4.18859 8.08593 4.8951 8.48588C5.60161 8.88583 5.6016 8.88583 7.06057 9.92149C8.90819 11.2328 10.2142 10.7937 12.3564 10.7937" fill="black"/>
<path d="M2.47052 7.78088C2.82377 7.78088 4.18859 8.08593 4.8951 8.48588C5.60161 8.88583 5.6016 8.88583 7.06057 9.92149C8.90819 11.2328 10.2142 10.7937 12.3564 10.7937" stroke="black" stroke-width="2.8125"/>
<path d="M14.4277 10.7809L10.8085 8.6915V12.8703L14.4277 10.7809Z" fill="black" stroke="black"/>
</svg>

After

Width:  |  Height:  |  Size: 1.1 KiB

View File

@@ -0,0 +1,749 @@
use std::collections::HashMap;
use std::sync::Mutex;
use serde::{Deserialize, Serialize};
use zed_extension_api::http_client::{HttpMethod, HttpRequest, HttpResponseStream, RedirectPolicy};
use zed_extension_api::{self as zed, *};
/// OpenRouter LLM provider extension state.
///
/// Holds the set of in-flight completion streams, keyed by the string
/// stream id returned from `llm_stream_completion_start`.
struct OpenRouterProvider {
    // Active streams by id. Guards are held only for brief map operations.
    streams: Mutex<HashMap<String, StreamState>>,
    // Monotonic counter used to mint unique stream ids.
    next_stream_id: Mutex<u64>,
}

/// Per-stream state for one in-flight completion request.
struct StreamState {
    // The underlying HTTP response body stream; `Some` until closed.
    response_stream: Option<HttpResponseStream>,
    // Bytes received but not yet consumed as complete SSE lines.
    buffer: String,
    // Whether the initial `Started` event has been emitted to the host.
    started: bool,
    // Tool-call fragments accumulated across deltas, keyed by choice index.
    tool_calls: HashMap<usize, AccumulatedToolCall>,
    // Set once accumulated tool calls have been flushed at finish time.
    tool_calls_emitted: bool,
}

/// A tool call assembled incrementally from streaming deltas: the id and
/// name arrive once, the JSON `arguments` string arrives in pieces.
#[derive(Clone, Default)]
struct AccumulatedToolCall {
    id: String,
    name: String,
    arguments: String,
}

/// Static metadata for one entry in the curated model list.
struct ModelDefinition {
    // OpenRouter model slug, e.g. "anthropic/claude-sonnet-4".
    id: &'static str,
    display_name: &'static str,
    // Context window size in tokens.
    max_tokens: u64,
    // Default completion cap applied when the request sets none.
    max_output_tokens: Option<u64>,
    supports_images: bool,
    supports_tools: bool,
    // Whether this model is the provider's default / default-fast pick.
    is_default: bool,
    is_default_fast: bool,
}
/// Curated model catalog surfaced to Zed, grouped by vendor.
///
/// Exactly one entry sets `is_default` (Claude Sonnet 4) and one sets
/// `is_default_fast` (Claude Haiku 4). `max_tokens` is the context window;
/// `max_output_tokens` backstops requests that do not set a cap.
const MODELS: &[ModelDefinition] = &[
    // Anthropic Models
    ModelDefinition {
        id: "anthropic/claude-sonnet-4",
        display_name: "Claude Sonnet 4",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: true,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "anthropic/claude-opus-4",
        display_name: "Claude Opus 4",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "anthropic/claude-haiku-4",
        display_name: "Claude Haiku 4",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: true,
    },
    ModelDefinition {
        id: "anthropic/claude-3.5-sonnet",
        display_name: "Claude 3.5 Sonnet",
        max_tokens: 200_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    // OpenAI Models
    ModelDefinition {
        id: "openai/gpt-4o",
        display_name: "GPT-4o",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "openai/gpt-4o-mini",
        display_name: "GPT-4o Mini",
        max_tokens: 128_000,
        max_output_tokens: Some(16_384),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "openai/o1",
        display_name: "o1",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: true,
        supports_tools: false,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "openai/o3-mini",
        display_name: "o3-mini",
        max_tokens: 200_000,
        max_output_tokens: Some(100_000),
        supports_images: false,
        supports_tools: false,
        is_default: false,
        is_default_fast: false,
    },
    // Google Models
    ModelDefinition {
        id: "google/gemini-2.0-flash-001",
        display_name: "Gemini 2.0 Flash",
        max_tokens: 1_000_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "google/gemini-2.5-pro-preview",
        display_name: "Gemini 2.5 Pro",
        max_tokens: 1_000_000,
        max_output_tokens: Some(8_192),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    // Meta Models
    ModelDefinition {
        id: "meta-llama/llama-3.3-70b-instruct",
        display_name: "Llama 3.3 70B",
        max_tokens: 128_000,
        max_output_tokens: Some(4_096),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "meta-llama/llama-4-maverick",
        display_name: "Llama 4 Maverick",
        max_tokens: 128_000,
        max_output_tokens: Some(4_096),
        supports_images: true,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    // Mistral Models
    ModelDefinition {
        id: "mistralai/mistral-large-2411",
        display_name: "Mistral Large",
        max_tokens: 128_000,
        max_output_tokens: Some(4_096),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "mistralai/codestral-latest",
        display_name: "Codestral",
        max_tokens: 32_000,
        max_output_tokens: Some(4_096),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    // DeepSeek Models
    ModelDefinition {
        id: "deepseek/deepseek-chat-v3-0324",
        display_name: "DeepSeek V3",
        max_tokens: 64_000,
        max_output_tokens: Some(8_192),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
    ModelDefinition {
        id: "deepseek/deepseek-r1",
        display_name: "DeepSeek R1",
        max_tokens: 64_000,
        max_output_tokens: Some(8_192),
        supports_images: false,
        supports_tools: false,
        is_default: false,
        is_default_fast: false,
    },
    // Qwen Models
    ModelDefinition {
        id: "qwen/qwen3-235b-a22b",
        display_name: "Qwen 3 235B",
        max_tokens: 40_000,
        max_output_tokens: Some(8_192),
        supports_images: false,
        supports_tools: true,
        is_default: false,
        is_default_fast: false,
    },
];
/// Look up the static metadata for `model_id` in the curated [`MODELS`]
/// table, or `None` if the id is not one of the curated models.
fn get_model_definition(model_id: &str) -> Option<&'static ModelDefinition> {
    for model in MODELS {
        if model.id == model_id {
            return Some(model);
        }
    }
    None
}
/// Outgoing chat-completions request body (OpenAI-compatible wire format).
/// Field names here are the wire format; do not rename without `serde` attrs.
#[derive(Serialize)]
struct OpenRouterRequest {
    model: String,
    messages: Vec<OpenRouterMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u64>,
    // Omitted entirely (not sent as `[]`) when no tools are supplied.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    tools: Vec<OpenRouterTool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
    #[serde(skip_serializing_if = "Vec::is_empty")]
    stop: Vec<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    // Always true in this extension; responses are consumed as SSE.
    stream: bool,
}

/// One chat message; `role` is "system" | "user" | "assistant" | "tool".
#[derive(Serialize)]
struct OpenRouterMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<OpenRouterContent>,
    // Present only on assistant messages that invoke tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<OpenRouterToolCall>>,
    // Present only on `role: "tool"` result messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}

/// Message content: either a plain string or a multi-part array
/// (text and image parts). `untagged` serializes whichever form is used.
#[derive(Serialize, Clone)]
#[serde(untagged)]
enum OpenRouterContent {
    Text(String),
    Parts(Vec<OpenRouterContentPart>),
}

/// One element of a multi-part content array, tagged by `"type"`.
#[derive(Serialize, Clone)]
#[serde(tag = "type")]
enum OpenRouterContentPart {
    #[serde(rename = "text")]
    Text { text: String },
    #[serde(rename = "image_url")]
    ImageUrl { image_url: ImageUrl },
}

/// Image reference; `url` is a `data:` URL in this extension.
#[derive(Serialize, Clone)]
struct ImageUrl {
    url: String,
}

/// A completed tool invocation echoed back in assistant history.
#[derive(Serialize, Clone)]
struct OpenRouterToolCall {
    id: String,
    // Always "function".
    #[serde(rename = "type")]
    call_type: String,
    function: OpenRouterFunctionCall,
}

/// Function name plus its arguments as a raw JSON string.
#[derive(Serialize, Clone)]
struct OpenRouterFunctionCall {
    name: String,
    arguments: String,
}

/// A tool made available to the model.
#[derive(Serialize)]
struct OpenRouterTool {
    // Always "function".
    #[serde(rename = "type")]
    tool_type: String,
    function: OpenRouterFunctionDef,
}

/// Tool declaration: name, description, and a JSON-Schema `parameters` object.
#[derive(Serialize)]
struct OpenRouterFunctionDef {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

/// One decoded SSE chunk of a streaming completion response.
#[derive(Deserialize, Debug)]
struct OpenRouterStreamResponse {
    choices: Vec<OpenRouterStreamChoice>,
    // Usually present only on the final (usage) chunk.
    #[serde(default)]
    usage: Option<OpenRouterUsage>,
}

/// One choice within a stream chunk; `finish_reason` is set on the last
/// content-bearing chunk for that choice.
#[derive(Deserialize, Debug)]
struct OpenRouterStreamChoice {
    delta: OpenRouterDelta,
    finish_reason: Option<String>,
}

/// Incremental message delta: text and/or tool-call fragments.
#[derive(Deserialize, Debug, Default)]
struct OpenRouterDelta {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<OpenRouterToolCallDelta>>,
}

/// Fragment of a tool call; `index` identifies which call it extends.
#[derive(Deserialize, Debug)]
struct OpenRouterToolCallDelta {
    index: usize,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    function: Option<OpenRouterFunctionDelta>,
}

/// Fragment of the function payload: name arrives once, arguments in pieces.
#[derive(Deserialize, Debug, Default)]
struct OpenRouterFunctionDelta {
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    arguments: Option<String>,
}

/// Token accounting reported at the end of the stream.
#[derive(Deserialize, Debug)]
struct OpenRouterUsage {
    prompt_tokens: u64,
    completion_tokens: u64,
}
/// Convert Zed's provider-agnostic completion request into the OpenRouter
/// (OpenAI-compatible) chat-completions wire format.
///
/// Role handling:
/// - System: all text parts are joined with newlines into one message;
///   non-text parts are ignored.
/// - User: text/image parts become a content array (collapsed to a plain
///   string when there is a single text part); tool results are split out
///   into separate `role: "tool"` messages emitted after the user message.
/// - Assistant: text is joined; tool uses become `tool_calls`.
///
/// Tools and `tool_choice` are dropped when the curated model metadata says
/// the model does not support them. `max_tokens` falls back to the model's
/// default output cap when the request leaves it unset.
///
/// # Errors
/// Returns `Err` with a human-readable message on serialization problems
/// (currently none arise here; the `Result` keeps the interface stable).
fn convert_request(
    model_id: &str,
    request: &LlmCompletionRequest,
) -> Result<OpenRouterRequest, String> {
    let mut messages: Vec<OpenRouterMessage> = Vec::new();
    for msg in &request.messages {
        match msg.role {
            LlmMessageRole::System => {
                let mut text_content = String::new();
                for content in &msg.content {
                    if let LlmMessageContent::Text(text) = content {
                        if !text_content.is_empty() {
                            text_content.push('\n');
                        }
                        text_content.push_str(text);
                    }
                }
                if !text_content.is_empty() {
                    messages.push(OpenRouterMessage {
                        role: "system".to_string(),
                        content: Some(OpenRouterContent::Text(text_content)),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
            }
            LlmMessageRole::User => {
                let mut parts: Vec<OpenRouterContentPart> = Vec::new();
                // Tool results must become standalone `role: "tool"` messages
                // placed after the user message they accompany.
                let mut tool_result_messages: Vec<OpenRouterMessage> = Vec::new();
                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                parts.push(OpenRouterContentPart::Text { text: text.clone() });
                            }
                        }
                        LlmMessageContent::Image(img) => {
                            // NOTE(review): assumes the image source is PNG;
                            // the actual format is not inspected — confirm.
                            let data_url = format!("data:image/png;base64,{}", img.source);
                            parts.push(OpenRouterContentPart::ImageUrl {
                                image_url: ImageUrl { url: data_url },
                            });
                        }
                        LlmMessageContent::ToolResult(result) => {
                            let content_text = match &result.content {
                                LlmToolResultContent::Text(t) => t.clone(),
                                LlmToolResultContent::Image(_) => "[Image]".to_string(),
                            };
                            tool_result_messages.push(OpenRouterMessage {
                                role: "tool".to_string(),
                                content: Some(OpenRouterContent::Text(content_text)),
                                tool_calls: None,
                                tool_call_id: Some(result.tool_use_id.clone()),
                            });
                        }
                        _ => {}
                    }
                }
                if !parts.is_empty() {
                    // Collapse a lone text part to a plain string for maximum
                    // compatibility with models that reject content arrays.
                    let content = if parts.len() == 1 {
                        if let OpenRouterContentPart::Text { text } = &parts[0] {
                            OpenRouterContent::Text(text.clone())
                        } else {
                            OpenRouterContent::Parts(parts)
                        }
                    } else {
                        OpenRouterContent::Parts(parts)
                    };
                    messages.push(OpenRouterMessage {
                        role: "user".to_string(),
                        content: Some(content),
                        tool_calls: None,
                        tool_call_id: None,
                    });
                }
                messages.extend(tool_result_messages);
            }
            LlmMessageRole::Assistant => {
                let mut text_content = String::new();
                let mut tool_calls: Vec<OpenRouterToolCall> = Vec::new();
                for content in &msg.content {
                    match content {
                        LlmMessageContent::Text(text) => {
                            if !text.is_empty() {
                                if !text_content.is_empty() {
                                    text_content.push('\n');
                                }
                                text_content.push_str(text);
                            }
                        }
                        LlmMessageContent::ToolUse(tool_use) => {
                            tool_calls.push(OpenRouterToolCall {
                                id: tool_use.id.clone(),
                                call_type: "function".to_string(),
                                function: OpenRouterFunctionCall {
                                    name: tool_use.name.clone(),
                                    arguments: tool_use.input.clone(),
                                },
                            });
                        }
                        _ => {}
                    }
                }
                // Skip assistant turns that carry neither text nor tool
                // calls: OpenAI-compatible endpoints reject messages with no
                // content, and previously such turns were sent as
                // `content: null, tool_calls: null`.
                if !text_content.is_empty() || !tool_calls.is_empty() {
                    messages.push(OpenRouterMessage {
                        role: "assistant".to_string(),
                        content: if text_content.is_empty() {
                            None
                        } else {
                            Some(OpenRouterContent::Text(text_content))
                        },
                        tool_calls: if tool_calls.is_empty() {
                            None
                        } else {
                            Some(tool_calls)
                        },
                        tool_call_id: None,
                    });
                }
            }
        }
    }
    let model_def = get_model_definition(model_id);
    // Unknown (non-curated) model ids are assumed to support tools.
    let supports_tools = model_def.map(|m| m.supports_tools).unwrap_or(true);
    let tools: Vec<OpenRouterTool> = if supports_tools {
        request
            .tools
            .iter()
            .map(|t| OpenRouterTool {
                tool_type: "function".to_string(),
                function: OpenRouterFunctionDef {
                    name: t.name.clone(),
                    description: t.description.clone(),
                    // Fall back to an empty schema object if the tool's
                    // input schema is not valid JSON.
                    parameters: serde_json::from_str(&t.input_schema)
                        .unwrap_or(serde_json::Value::Object(Default::default())),
                },
            })
            .collect()
    } else {
        Vec::new()
    };
    let tool_choice = if supports_tools {
        request.tool_choice.as_ref().map(|tc| match tc {
            LlmToolChoice::Auto => "auto".to_string(),
            LlmToolChoice::Any => "required".to_string(),
            LlmToolChoice::None => "none".to_string(),
        })
    } else {
        None
    };
    // Default the output cap to the model's known limit when unset.
    let max_tokens = request
        .max_tokens
        .or(model_def.and_then(|m| m.max_output_tokens));
    Ok(OpenRouterRequest {
        model: model_id.to_string(),
        messages,
        max_tokens,
        tools,
        tool_choice,
        stop: request.stop_sequences.clone(),
        temperature: request.temperature,
        stream: true,
    })
}
/// Parse one Server-Sent-Events line into a decoded stream chunk.
///
/// Per the SSE specification, the `data` field's colon may be followed by an
/// optional single space, so both `data: {...}` and `data:{...}` are valid;
/// the previous implementation required the space and silently dropped
/// spec-conformant lines without one.
///
/// Returns `None` for non-`data` lines, the `[DONE]` terminator sentinel,
/// and payloads that fail to deserialize.
fn parse_sse_line(line: &str) -> Option<OpenRouterStreamResponse> {
    let data = line.strip_prefix("data:")?;
    // Strip at most one leading space, as the SSE grammar specifies.
    let data = data.strip_prefix(' ').unwrap_or(data);
    if data.trim() == "[DONE]" {
        return None;
    }
    serde_json::from_str(data).ok()
}
impl zed::Extension for OpenRouterProvider {
    /// Create the provider with no active streams.
    fn new() -> Self {
        Self {
            streams: Mutex::new(HashMap::new()),
            next_stream_id: Mutex::new(0),
        }
    }

    /// Advertise the single "openrouter" provider with its bundled icon.
    fn llm_providers(&self) -> Vec<LlmProviderInfo> {
        vec![LlmProviderInfo {
            id: "openrouter".into(),
            name: "OpenRouter".into(),
            icon: Some("icons/openrouter.svg".into()),
        }]
    }

    /// Map the static [`MODELS`] table into the host's model-info shape.
    /// Tool-choice capabilities mirror `supports_tools`; thinking is not
    /// advertised for any model.
    fn llm_provider_models(&self, _provider_id: &str) -> Result<Vec<LlmModelInfo>, String> {
        Ok(MODELS
            .iter()
            .map(|m| LlmModelInfo {
                id: m.id.to_string(),
                name: m.display_name.to_string(),
                max_token_count: m.max_tokens,
                max_output_tokens: m.max_output_tokens,
                capabilities: LlmModelCapabilities {
                    supports_images: m.supports_images,
                    supports_tools: m.supports_tools,
                    supports_tool_choice_auto: m.supports_tools,
                    supports_tool_choice_any: m.supports_tools,
                    supports_tool_choice_none: m.supports_tools,
                    supports_thinking: false,
                    tool_input_format: LlmToolInputFormat::JsonSchema,
                },
                is_default: m.is_default,
                is_default_fast: m.is_default_fast,
            })
            .collect())
    }

    /// Authenticated iff a stored credential exists.
    /// NOTE(review): the credential key "open_router" differs from the
    /// provider id "openrouter" — possibly intentional for credentials
    /// migrated from the old extension name; confirm before changing.
    fn llm_provider_is_authenticated(&self, _provider_id: &str) -> bool {
        llm_get_credential("open_router").is_some()
    }

    /// Markdown shown in settings explaining how to obtain an API key.
    fn llm_provider_settings_markdown(&self, _provider_id: &str) -> Option<String> {
        Some(
            "To use OpenRouter, you need an API key. You can create one [here](https://openrouter.ai/keys).".to_string(),
        )
    }

    /// Delete the stored API key (same "open_router" key as above).
    fn llm_provider_reset_credentials(&mut self, _provider_id: &str) -> Result<(), String> {
        llm_delete_credential("open_router")
    }

    /// Start a streaming completion: build the request, open the SSE HTTP
    /// stream, and register a fresh `StreamState` under a new stream id.
    ///
    /// Returns the stream id for subsequent `llm_stream_completion_next`
    /// calls, or an error if no key is stored or the request fails.
    fn llm_stream_completion_start(
        &mut self,
        _provider_id: &str,
        model_id: &str,
        request: &LlmCompletionRequest,
    ) -> Result<String, String> {
        let api_key = llm_get_credential("open_router").ok_or_else(|| {
            "No API key configured. Please add your OpenRouter API key in settings.".to_string()
        })?;
        let openrouter_request = convert_request(model_id, request)?;
        let body = serde_json::to_vec(&openrouter_request)
            .map_err(|e| format!("Failed to serialize request: {}", e))?;
        let http_request = HttpRequest {
            method: HttpMethod::Post,
            url: "https://openrouter.ai/api/v1/chat/completions".to_string(),
            headers: vec![
                ("Content-Type".to_string(), "application/json".to_string()),
                ("Authorization".to_string(), format!("Bearer {}", api_key)),
                // Referer/title identify the app per OpenRouter convention.
                ("HTTP-Referer".to_string(), "https://zed.dev".to_string()),
                ("X-Title".to_string(), "Zed Editor".to_string()),
            ],
            body: Some(body),
            redirect_policy: RedirectPolicy::FollowAll,
        };
        let response_stream = http_request
            .fetch_stream()
            .map_err(|e| format!("HTTP request failed: {}", e))?;
        // Mint a unique id from the monotonic counter.
        let stream_id = {
            let mut id_counter = self.next_stream_id.lock().unwrap();
            let id = format!("openrouter-stream-{}", *id_counter);
            *id_counter += 1;
            id
        };
        self.streams.lock().unwrap().insert(
            stream_id.clone(),
            StreamState {
                response_stream: Some(response_stream),
                buffer: String::new(),
                started: false,
                tool_calls: HashMap::new(),
                tool_calls_emitted: false,
            },
        );
        Ok(stream_id)
    }

    /// Pull the next event from a stream: consume buffered SSE lines first,
    /// then read more chunks from the HTTP stream as needed.
    ///
    /// Event order: `Started` on first call, then `Text` / `ToolUse` /
    /// `Usage` events, `Stop` on a finish reason, and `Ok(None)` at EOF.
    ///
    /// NOTE(review): when a finish reason arrives with multiple accumulated
    /// tool calls, only the lowest-indexed one is emitted — the rest are
    /// drained and dropped, and because the SSE line is already consumed, no
    /// `Stop` event follows the `ToolUse`. Likewise, a chunk that carries
    /// both `content` and a `finish_reason` returns the `Text` and loses the
    /// finish event. Both look like latent bugs; confirm against the host's
    /// expectations before restructuring.
    fn llm_stream_completion_next(
        &mut self,
        stream_id: &str,
    ) -> Result<Option<LlmCompletionEvent>, String> {
        let mut streams = self.streams.lock().unwrap();
        let state = streams
            .get_mut(stream_id)
            .ok_or_else(|| format!("Unknown stream: {}", stream_id))?;
        // Emit the synthetic Started event exactly once, before any I/O.
        if !state.started {
            state.started = true;
            return Ok(Some(LlmCompletionEvent::Started));
        }
        let response_stream = state
            .response_stream
            .as_mut()
            .ok_or_else(|| "Stream already closed".to_string())?;
        loop {
            // Drain one complete line from the buffer, if available.
            if let Some(newline_pos) = state.buffer.find('\n') {
                let line = state.buffer[..newline_pos].to_string();
                state.buffer = state.buffer[newline_pos + 1..].to_string();
                if line.trim().is_empty() {
                    continue;
                }
                if let Some(response) = parse_sse_line(&line) {
                    if let Some(choice) = response.choices.first() {
                        if let Some(content) = &choice.delta.content {
                            if !content.is_empty() {
                                return Ok(Some(LlmCompletionEvent::Text(content.clone())));
                            }
                        }
                        // Merge tool-call fragments into the accumulator,
                        // keyed by the delta's index.
                        if let Some(tool_calls) = &choice.delta.tool_calls {
                            for tc in tool_calls {
                                let entry = state
                                    .tool_calls
                                    .entry(tc.index)
                                    .or_insert_with(AccumulatedToolCall::default);
                                if let Some(id) = &tc.id {
                                    entry.id = id.clone();
                                }
                                if let Some(func) = &tc.function {
                                    if let Some(name) = &func.name {
                                        entry.name = name.clone();
                                    }
                                    if let Some(args) = &func.arguments {
                                        entry.arguments.push_str(args);
                                    }
                                }
                            }
                        }
                        if let Some(finish_reason) = &choice.finish_reason {
                            // Flush accumulated tool calls at finish time
                            // (see NOTE(review) above: only the first one is
                            // actually returned).
                            if !state.tool_calls.is_empty() && !state.tool_calls_emitted {
                                state.tool_calls_emitted = true;
                                let mut tool_calls: Vec<_> = state.tool_calls.drain().collect();
                                tool_calls.sort_by_key(|(idx, _)| *idx);
                                if let Some((_, tc)) = tool_calls.into_iter().next() {
                                    return Ok(Some(LlmCompletionEvent::ToolUse(LlmToolUse {
                                        id: tc.id,
                                        name: tc.name,
                                        input: tc.arguments,
                                        thought_signature: None,
                                    })));
                                }
                            }
                            let stop_reason = match finish_reason.as_str() {
                                "stop" => LlmStopReason::EndTurn,
                                "length" => LlmStopReason::MaxTokens,
                                "tool_calls" => LlmStopReason::ToolUse,
                                "content_filter" => LlmStopReason::Refusal,
                                // Unknown reasons are treated as a normal end.
                                _ => LlmStopReason::EndTurn,
                            };
                            return Ok(Some(LlmCompletionEvent::Stop(stop_reason)));
                        }
                    }
                    // Usage arrives on its own (choice-less) final chunk.
                    if let Some(usage) = response.usage {
                        return Ok(Some(LlmCompletionEvent::Usage(LlmTokenUsage {
                            input_tokens: usage.prompt_tokens,
                            output_tokens: usage.completion_tokens,
                            cache_creation_input_tokens: None,
                            cache_read_input_tokens: None,
                        })));
                    }
                }
                continue;
            }
            // No complete line buffered: read more bytes from the network.
            match response_stream.next_chunk() {
                Ok(Some(chunk)) => {
                    // NOTE(review): lossy UTF-8 decode can corrupt a
                    // multi-byte character split across chunk boundaries.
                    let text = String::from_utf8_lossy(&chunk);
                    state.buffer.push_str(&text);
                }
                Ok(None) => {
                    return Ok(None);
                }
                Err(e) => {
                    return Err(format!("Stream error: {}", e));
                }
            }
        }
    }

    /// Drop all state for the stream; the HTTP stream closes on drop.
    fn llm_stream_completion_close(&mut self, stream_id: &str) {
        self.streams.lock().unwrap().remove(stream_id);
    }
}
// Entry point: registers this provider with the Zed extension runtime.
zed::register_extension!(OpenRouterProvider);