More Gemini extension fixes

This commit is contained in:
Richard Feldman
2025-12-17 18:21:02 -05:00
parent 19833f0132
commit ca8279ca79
4 changed files with 77 additions and 43 deletions

View File

@@ -66,12 +66,16 @@ use util::{ResultExt, paths::RemotePathBuf};
use wasm_host::llm_provider::ExtensionLanguageModelProvider;
use wasm_host::{
WasmExtension, WasmHost,
wit::{LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version, wasm_api_version_range},
wit::{
LlmCacheConfiguration, LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version,
wasm_api_version_range,
},
};
struct LlmProviderWithModels {
provider_info: LlmProviderInfo,
models: Vec<LlmModelInfo>,
cache_configs: collections::HashMap<String, LlmCacheConfiguration>,
is_authenticated: bool,
icon_path: Option<SharedString>,
auth_config: Option<extension::LanguageModelAuthConfig>,
@@ -1635,6 +1639,32 @@ impl ExtensionStore {
}
};
// Query cache configurations for each model
let mut cache_configs = collections::HashMap::default();
for model in &models {
let cache_config_result = wasm_extension
.call({
let provider_id = provider_info.id.clone();
let model_id = model.id.clone();
|ext, store| {
async move {
ext.call_llm_cache_configuration(
store,
&provider_id,
&model_id,
)
.await
}
.boxed()
}
})
.await;
if let Ok(Ok(Some(config))) = cache_config_result {
cache_configs.insert(model.id.clone(), config);
}
}
// Query initial authentication state
let is_authenticated = wasm_extension
.call({
@@ -1677,6 +1707,7 @@ impl ExtensionStore {
llm_providers_with_models.push(LlmProviderWithModels {
provider_info,
models,
cache_configs,
is_authenticated,
icon_path,
auth_config,
@@ -1776,6 +1807,7 @@ impl ExtensionStore {
let wasm_ext = extension.as_ref().clone();
let pinfo = llm_provider.provider_info.clone();
let mods = llm_provider.models.clone();
let cache_cfgs = llm_provider.cache_configs.clone();
let auth = llm_provider.is_authenticated;
let icon = llm_provider.icon_path.clone();
let auth_config = llm_provider.auth_config.clone();
@@ -1784,7 +1816,7 @@ impl ExtensionStore {
provider_id.clone(),
Box::new(move |cx: &mut App| {
let provider = Arc::new(ExtensionLanguageModelProvider::new(
wasm_ext, pinfo, mods, auth, icon, auth_config, cx,
wasm_ext, pinfo, mods, cache_cfgs, auth, icon, auth_config, cx,
));
language_model::LanguageModelRegistry::global(cx).update(
cx,

View File

@@ -5,11 +5,12 @@ use crate::wasm_host::wit::LlmDeviceFlowPromptInfo;
use collections::HashSet;
use crate::wasm_host::wit::{
LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
LlmToolUse,
LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmImageData,
LlmMessageContent, LlmMessageRole, LlmModelInfo, LlmProviderInfo, LlmRequestMessage,
LlmStopReason, LlmThinkingContent, LlmToolChoice, LlmToolDefinition, LlmToolInputFormat,
LlmToolResult, LlmToolResultContent, LlmToolUse,
};
use collections::HashMap;
use anyhow::{Result, anyhow};
use credentials_provider::CredentialsProvider;
use extension::{LanguageModelAuthConfig, OAuthConfig};
@@ -58,6 +59,8 @@ pub struct ExtensionLanguageModelProvider {
pub struct ExtensionLlmProviderState {
is_authenticated: bool,
available_models: Vec<LlmModelInfo>,
/// Cache configurations for each model, keyed by model ID.
cache_configs: HashMap<String, LlmCacheConfiguration>,
/// Set of env var names that are allowed to be read for this provider.
allowed_env_vars: HashSet<String>,
/// If authenticated via env var, which one was used.
@@ -71,6 +74,7 @@ impl ExtensionLanguageModelProvider {
extension: WasmExtension,
provider_info: LlmProviderInfo,
models: Vec<LlmModelInfo>,
cache_configs: HashMap<String, LlmCacheConfiguration>,
is_authenticated: bool,
icon_path: Option<SharedString>,
auth_config: Option<LanguageModelAuthConfig>,
@@ -118,6 +122,7 @@ impl ExtensionLanguageModelProvider {
let state = cx.new(|_| ExtensionLlmProviderState {
is_authenticated,
available_models: models,
cache_configs,
allowed_env_vars,
env_var_name_used,
});
@@ -139,6 +144,30 @@ impl ExtensionLanguageModelProvider {
fn credential_key(&self) -> String {
format!("extension-llm-{}", self.provider_id_string())
}
/// Builds an `ExtensionLanguageModel` for `model_info`, attaching the
/// prompt-cache configuration the extension reported for that model id
/// (if any). Shared by `default_model`, `default_fast_model`, and
/// `provided_models` so all three construct models identically.
fn create_model(
    &self,
    model_info: &LlmModelInfo,
    // Cache configurations keyed by model id, as queried from the
    // extension at load time.
    cache_configs: &HashMap<String, LlmCacheConfiguration>,
) -> Arc<dyn LanguageModel> {
    // Translate the extension's wire-format cache config into the
    // language_model crate's type; a missing entry leaves `cache_config`
    // as `None`, meaning no prompt caching for this model.
    let cache_config = cache_configs.get(&model_info.id).map(|config| {
        LanguageModelCacheConfiguration {
            // Widening cast (wire integer -> usize); lossless on 32/64-bit targets.
            max_cache_anchors: config.max_cache_anchors as usize,
            // NOTE(review): speculation appears not to be exposed through the
            // extension API, so it is hard-coded off — confirm against the wit definition.
            should_speculate: false,
            min_total_token: config.min_total_token_count,
        }
    });
    Arc::new(ExtensionLanguageModel {
        extension: self.extension.clone(),
        model_info: model_info.clone(),
        provider_id: self.id(),
        provider_name: self.name(),
        provider_info: self.provider_info.clone(),
        // Limit to 4 concurrent requests per model, matching the limit the
        // replaced inline constructors used at every call site.
        request_limiter: RateLimiter::new(4),
        cache_config,
    })
}
}
impl LanguageModelProvider for ExtensionLanguageModelProvider {
@@ -165,16 +194,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
.iter()
.find(|m| m.is_default)
.or_else(|| state.available_models.first())
.map(|model_info| {
Arc::new(ExtensionLanguageModel {
extension: self.extension.clone(),
model_info: model_info.clone(),
provider_id: self.id(),
provider_name: self.name(),
provider_info: self.provider_info.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.map(|model_info| self.create_model(model_info, &state.cache_configs))
}
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
@@ -183,16 +203,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
.available_models
.iter()
.find(|m| m.is_default_fast)
.map(|model_info| {
Arc::new(ExtensionLanguageModel {
extension: self.extension.clone(),
model_info: model_info.clone(),
provider_id: self.id(),
provider_name: self.name(),
provider_info: self.provider_info.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.map(|model_info| self.create_model(model_info, &state.cache_configs))
}
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
@@ -200,16 +211,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
state
.available_models
.iter()
.map(|model_info| {
Arc::new(ExtensionLanguageModel {
extension: self.extension.clone(),
model_info: model_info.clone(),
provider_id: self.id(),
provider_name: self.name(),
provider_info: self.provider_info.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
})
.map(|model_info| self.create_model(model_info, &state.cache_configs))
.collect()
}
@@ -1595,6 +1597,7 @@ pub struct ExtensionLanguageModel {
provider_name: LanguageModelProviderName,
provider_info: LlmProviderInfo,
request_limiter: RateLimiter,
cache_config: Option<LanguageModelCacheConfiguration>,
}
impl LanguageModel for ExtensionLanguageModel {
@@ -1615,7 +1618,7 @@ impl LanguageModel for ExtensionLanguageModel {
}
fn telemetry_id(&self) -> String {
format!("extension-{}", self.model_info.id)
format!("{}/{}", self.provider_info.id, self.model_info.id)
}
fn supports_images(&self) -> bool {
@@ -1795,8 +1798,7 @@ impl LanguageModel for ExtensionLanguageModel {
}
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
// Extensions can implement this via llm_cache_configuration
None
self.cache_config.clone()
}
}

View File

@@ -6,8 +6,8 @@ schema_version = 1
authors = ["Zed Team"]
repository = "https://github.com/zed-industries/zed"
[language_model_providers.google-ai]
[language_model_providers.google]
name = "Google AI"
[language_model_providers.google-ai.auth]
[language_model_providers.google.auth]
env_vars = ["GEMINI_API_KEY", "GOOGLE_AI_API_KEY"]

View File

@@ -128,7 +128,7 @@ fn validate_generate_content_request(request: &GenerateContentRequest) -> Result
// Extension implementation
const PROVIDER_ID: &str = "google-ai";
const PROVIDER_ID: &str = "google";
const PROVIDER_NAME: &str = "Google AI";
struct GoogleAiExtension {
@@ -343,7 +343,7 @@ fn get_default_models() -> Vec<LlmModelInfo> {
supports_tool_choice_auto: true,
supports_tool_choice_any: true,
supports_tool_choice_none: true,
supports_thinking: true,
supports_thinking: false,
tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
},
is_default: false,