More Gemini extension fixes
This commit is contained in:
@@ -66,12 +66,16 @@ use util::{ResultExt, paths::RemotePathBuf};
|
||||
use wasm_host::llm_provider::ExtensionLanguageModelProvider;
|
||||
use wasm_host::{
|
||||
WasmExtension, WasmHost,
|
||||
wit::{LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version, wasm_api_version_range},
|
||||
wit::{
|
||||
LlmCacheConfiguration, LlmModelInfo, LlmProviderInfo, is_supported_wasm_api_version,
|
||||
wasm_api_version_range,
|
||||
},
|
||||
};
|
||||
|
||||
struct LlmProviderWithModels {
|
||||
provider_info: LlmProviderInfo,
|
||||
models: Vec<LlmModelInfo>,
|
||||
cache_configs: collections::HashMap<String, LlmCacheConfiguration>,
|
||||
is_authenticated: bool,
|
||||
icon_path: Option<SharedString>,
|
||||
auth_config: Option<extension::LanguageModelAuthConfig>,
|
||||
@@ -1635,6 +1639,32 @@ impl ExtensionStore {
|
||||
}
|
||||
};
|
||||
|
||||
// Query cache configurations for each model
|
||||
let mut cache_configs = collections::HashMap::default();
|
||||
for model in &models {
|
||||
let cache_config_result = wasm_extension
|
||||
.call({
|
||||
let provider_id = provider_info.id.clone();
|
||||
let model_id = model.id.clone();
|
||||
|ext, store| {
|
||||
async move {
|
||||
ext.call_llm_cache_configuration(
|
||||
store,
|
||||
&provider_id,
|
||||
&model_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
if let Ok(Ok(Some(config))) = cache_config_result {
|
||||
cache_configs.insert(model.id.clone(), config);
|
||||
}
|
||||
}
|
||||
|
||||
// Query initial authentication state
|
||||
let is_authenticated = wasm_extension
|
||||
.call({
|
||||
@@ -1677,6 +1707,7 @@ impl ExtensionStore {
|
||||
llm_providers_with_models.push(LlmProviderWithModels {
|
||||
provider_info,
|
||||
models,
|
||||
cache_configs,
|
||||
is_authenticated,
|
||||
icon_path,
|
||||
auth_config,
|
||||
@@ -1776,6 +1807,7 @@ impl ExtensionStore {
|
||||
let wasm_ext = extension.as_ref().clone();
|
||||
let pinfo = llm_provider.provider_info.clone();
|
||||
let mods = llm_provider.models.clone();
|
||||
let cache_cfgs = llm_provider.cache_configs.clone();
|
||||
let auth = llm_provider.is_authenticated;
|
||||
let icon = llm_provider.icon_path.clone();
|
||||
let auth_config = llm_provider.auth_config.clone();
|
||||
@@ -1784,7 +1816,7 @@ impl ExtensionStore {
|
||||
provider_id.clone(),
|
||||
Box::new(move |cx: &mut App| {
|
||||
let provider = Arc::new(ExtensionLanguageModelProvider::new(
|
||||
wasm_ext, pinfo, mods, auth, icon, auth_config, cx,
|
||||
wasm_ext, pinfo, mods, cache_cfgs, auth, icon, auth_config, cx,
|
||||
));
|
||||
language_model::LanguageModelRegistry::global(cx).update(
|
||||
cx,
|
||||
|
||||
@@ -5,11 +5,12 @@ use crate::wasm_host::wit::LlmDeviceFlowPromptInfo;
|
||||
use collections::HashSet;
|
||||
|
||||
use crate::wasm_host::wit::{
|
||||
LlmCompletionEvent, LlmCompletionRequest, LlmImageData, LlmMessageContent, LlmMessageRole,
|
||||
LlmModelInfo, LlmProviderInfo, LlmRequestMessage, LlmStopReason, LlmThinkingContent,
|
||||
LlmToolChoice, LlmToolDefinition, LlmToolInputFormat, LlmToolResult, LlmToolResultContent,
|
||||
LlmToolUse,
|
||||
LlmCacheConfiguration, LlmCompletionEvent, LlmCompletionRequest, LlmImageData,
|
||||
LlmMessageContent, LlmMessageRole, LlmModelInfo, LlmProviderInfo, LlmRequestMessage,
|
||||
LlmStopReason, LlmThinkingContent, LlmToolChoice, LlmToolDefinition, LlmToolInputFormat,
|
||||
LlmToolResult, LlmToolResultContent, LlmToolUse,
|
||||
};
|
||||
use collections::HashMap;
|
||||
use anyhow::{Result, anyhow};
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use extension::{LanguageModelAuthConfig, OAuthConfig};
|
||||
@@ -58,6 +59,8 @@ pub struct ExtensionLanguageModelProvider {
|
||||
pub struct ExtensionLlmProviderState {
|
||||
is_authenticated: bool,
|
||||
available_models: Vec<LlmModelInfo>,
|
||||
/// Cache configurations for each model, keyed by model ID.
|
||||
cache_configs: HashMap<String, LlmCacheConfiguration>,
|
||||
/// Set of env var names that are allowed to be read for this provider.
|
||||
allowed_env_vars: HashSet<String>,
|
||||
/// If authenticated via env var, which one was used.
|
||||
@@ -71,6 +74,7 @@ impl ExtensionLanguageModelProvider {
|
||||
extension: WasmExtension,
|
||||
provider_info: LlmProviderInfo,
|
||||
models: Vec<LlmModelInfo>,
|
||||
cache_configs: HashMap<String, LlmCacheConfiguration>,
|
||||
is_authenticated: bool,
|
||||
icon_path: Option<SharedString>,
|
||||
auth_config: Option<LanguageModelAuthConfig>,
|
||||
@@ -118,6 +122,7 @@ impl ExtensionLanguageModelProvider {
|
||||
let state = cx.new(|_| ExtensionLlmProviderState {
|
||||
is_authenticated,
|
||||
available_models: models,
|
||||
cache_configs,
|
||||
allowed_env_vars,
|
||||
env_var_name_used,
|
||||
});
|
||||
@@ -139,6 +144,30 @@ impl ExtensionLanguageModelProvider {
|
||||
fn credential_key(&self) -> String {
|
||||
format!("extension-llm-{}", self.provider_id_string())
|
||||
}
|
||||
|
||||
/// Builds the `ExtensionLanguageModel` for `model_info` and returns it as an
/// `Arc<dyn LanguageModel>`.
///
/// `cache_configs` is searched by `model_info.id`. When an entry exists it is
/// converted into a `LanguageModelCacheConfiguration`; when it does not, the
/// model is created with `cache_config` set to `None`.
///
/// Every call constructs a new model value with its own `RateLimiter::new(4)`;
/// the extension handle and provider info are cloned from `self`.
fn create_model(
|
||||
&self,
|
||||
model_info: &LlmModelInfo,
|
||||
cache_configs: &HashMap<String, LlmCacheConfiguration>,
|
||||
) -> Arc<dyn LanguageModel> {
|
||||
// Optional per-model cache settings, keyed by the model's ID.
let cache_config = cache_configs.get(&model_info.id).map(|config| {
|
||||
LanguageModelCacheConfiguration {
|
||||
// NOTE(review): `as` cast — presumably widening from a smaller unsigned
// integer; confirm against the `LlmCacheConfiguration` field type.
max_cache_anchors: config.max_cache_anchors as usize,
|
||||
// Hard-coded literal; not read from `config`.
should_speculate: false,
|
||||
min_total_token: config.min_total_token_count,
|
||||
}
|
||||
});
|
||||
|
||||
Arc::new(ExtensionLanguageModel {
|
||||
extension: self.extension.clone(),
|
||||
model_info: model_info.clone(),
|
||||
provider_id: self.id(),
|
||||
provider_name: self.name(),
|
||||
provider_info: self.provider_info.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
cache_config,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl LanguageModelProvider for ExtensionLanguageModelProvider {
|
||||
@@ -165,16 +194,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
|
||||
.iter()
|
||||
.find(|m| m.is_default)
|
||||
.or_else(|| state.available_models.first())
|
||||
.map(|model_info| {
|
||||
Arc::new(ExtensionLanguageModel {
|
||||
extension: self.extension.clone(),
|
||||
model_info: model_info.clone(),
|
||||
provider_id: self.id(),
|
||||
provider_name: self.name(),
|
||||
provider_info: self.provider_info.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.map(|model_info| self.create_model(model_info, &state.cache_configs))
|
||||
}
|
||||
|
||||
fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
|
||||
@@ -183,16 +203,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
|
||||
.available_models
|
||||
.iter()
|
||||
.find(|m| m.is_default_fast)
|
||||
.map(|model_info| {
|
||||
Arc::new(ExtensionLanguageModel {
|
||||
extension: self.extension.clone(),
|
||||
model_info: model_info.clone(),
|
||||
provider_id: self.id(),
|
||||
provider_name: self.name(),
|
||||
provider_info: self.provider_info.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.map(|model_info| self.create_model(model_info, &state.cache_configs))
|
||||
}
|
||||
|
||||
fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
|
||||
@@ -200,16 +211,7 @@ impl LanguageModelProvider for ExtensionLanguageModelProvider {
|
||||
state
|
||||
.available_models
|
||||
.iter()
|
||||
.map(|model_info| {
|
||||
Arc::new(ExtensionLanguageModel {
|
||||
extension: self.extension.clone(),
|
||||
model_info: model_info.clone(),
|
||||
provider_id: self.id(),
|
||||
provider_name: self.name(),
|
||||
provider_info: self.provider_info.clone(),
|
||||
request_limiter: RateLimiter::new(4),
|
||||
}) as Arc<dyn LanguageModel>
|
||||
})
|
||||
.map(|model_info| self.create_model(model_info, &state.cache_configs))
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -1595,6 +1597,7 @@ pub struct ExtensionLanguageModel {
|
||||
provider_name: LanguageModelProviderName,
|
||||
provider_info: LlmProviderInfo,
|
||||
request_limiter: RateLimiter,
|
||||
cache_config: Option<LanguageModelCacheConfiguration>,
|
||||
}
|
||||
|
||||
impl LanguageModel for ExtensionLanguageModel {
|
||||
@@ -1615,7 +1618,7 @@ impl LanguageModel for ExtensionLanguageModel {
|
||||
}
|
||||
|
||||
fn telemetry_id(&self) -> String {
|
||||
format!("extension-{}", self.model_info.id)
|
||||
format!("{}/{}", self.provider_info.id, self.model_info.id)
|
||||
}
|
||||
|
||||
fn supports_images(&self) -> bool {
|
||||
@@ -1795,8 +1798,7 @@ impl LanguageModel for ExtensionLanguageModel {
|
||||
}
|
||||
|
||||
fn cache_configuration(&self) -> Option<LanguageModelCacheConfiguration> {
|
||||
// Extensions can implement this via llm_cache_configuration
|
||||
None
|
||||
self.cache_config.clone()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ schema_version = 1
|
||||
authors = ["Zed Team"]
|
||||
repository = "https://github.com/zed-industries/zed"
|
||||
|
||||
[language_model_providers.google-ai]
|
||||
[language_model_providers.google]
|
||||
name = "Google AI"
|
||||
|
||||
[language_model_providers.google-ai.auth]
|
||||
[language_model_providers.google.auth]
|
||||
env_vars = ["GEMINI_API_KEY", "GOOGLE_AI_API_KEY"]
|
||||
|
||||
@@ -128,7 +128,7 @@ fn validate_generate_content_request(request: &GenerateContentRequest) -> Result
|
||||
|
||||
// Extension implementation
|
||||
|
||||
const PROVIDER_ID: &str = "google-ai";
|
||||
const PROVIDER_ID: &str = "google";
|
||||
const PROVIDER_NAME: &str = "Google AI";
|
||||
|
||||
struct GoogleAiExtension {
|
||||
@@ -343,7 +343,7 @@ fn get_default_models() -> Vec<LlmModelInfo> {
|
||||
supports_tool_choice_auto: true,
|
||||
supports_tool_choice_any: true,
|
||||
supports_tool_choice_none: true,
|
||||
supports_thinking: true,
|
||||
supports_thinking: false,
|
||||
tool_input_format: LlmToolInputFormat::JsonSchemaSubset,
|
||||
},
|
||||
is_default: false,
|
||||
|
||||
Reference in New Issue
Block a user