Compare commits

...

34 Commits

Author SHA1 Message Date
Michael Sloan
fc3a67f046 Fix mild memory leak for failed cache tasks 2025-05-04 22:32:22 +02:00
Michael Sloan
670813b9de Remove old ideation comment 2025-05-04 22:32:09 +02:00
Michael Sloan
61d615549f Clippy + other polish 2025-05-04 22:19:06 +02:00
Michael Sloan
3931d74275 Merge branch 'main' into gemini-caching 2025-05-04 22:12:18 +02:00
Michael Sloan
2cea66c5cb Change some todo! to TODO 2025-05-04 18:11:09 +02:00
Michael Sloan
09e26204ad Update cache TTL if it already exists 2025-05-04 18:09:49 +02:00
Michael Sloan
bfea3e5285 Remove todo about special handling of response for cache missing 2025-05-04 14:19:36 +02:00
Michael Sloan
f704e0578a Check cache creation 404 text 2025-05-04 14:14:52 +02:00
Michael Sloan
04b16fedde Reorganize 2025-05-04 14:05:47 +02:00
Michael Sloan
bde4bd5b3d Keep track of models lacking caching, based on cache creation status code 2025-05-03 08:11:49 -06:00
Michael Sloan
85e26b7f02 Resolve some todo! + other polish 2025-05-03 07:48:13 -06:00
Michael Sloan
0cc51f72e5 Fix compilation 2025-05-03 05:00:13 -06:00
Michael Sloan
abbb3dc7a6 Attempt to drop expired cache entries. Untested - written on a plane 2025-05-02 14:07:43 -06:00
Michael Sloan
141e0a702a Remove logic for checking if a model supports caching 2025-05-01 21:37:54 -06:00
Michael Sloan
bb47b766a1 Stable IDs for gemini-1.5 variants 2025-05-01 21:37:10 -06:00
Michael Sloan
515cdb9ae6 Progress towards better cache awaiting, written without internet on an airplane 2025-05-01 21:25:15 -06:00
Michael Sloan
2666ff7873 Remove some agent generated comments 2025-05-01 09:11:52 -06:00
Michael Sloan
1d0cc37205 Progress towards blocking on a specific cache
Co-authored-by: Max <max@zed.dev>
2025-04-30 23:26:47 -06:00
Michael Sloan
1257d44998 Initial implementation of also caching every agent request
Co-authored-by: Max <max@zed.dev>
2025-04-30 16:43:42 -06:00
Michael Sloan
5fdcdc1926 Add missing fields in provider code 2025-04-30 12:13:54 -06:00
Michael Sloan
8c91fc3153 Use model IDs which support caching for Gemini 2025-04-30 12:13:50 -06:00
Michael Sloan
b1595dba71 Comment out some code from interface sketching 2025-04-30 11:52:08 -06:00
Michael Sloan
d4702209ea Support for using cache in generation requests etc 2025-04-30 00:40:49 -06:00
Michael Sloan
8440ec03ad Fixes after merge 2025-04-29 23:32:51 -06:00
Michael Sloan
8c8eabe96d Merge branch 'fix-gemini-token-counting' into gemini-caching 2025-04-29 23:23:49 -06:00
Michael Sloan
10dfa36c91 Fix Gemini token counting + add support for counting whole requests
* Now provides the model id in the path instead of always `gemini-pro`, which appears to have stopped working.

* `CountTokensRequest` now takes a full `GenerateContentRequest` instead of just content.

* Fixes handling of `models/` prefix in `model` field of `GenerateContentRequest`, since that's required for use in `CountTokensRequest`. This didn't cause issues before because it was always cleared and used in the path.
2025-04-29 23:19:57 -06:00
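A minimal sketch of the resulting call shape, written against the final `google_ai` types shown in the diffs below (`ModelName`, the `cached_content` field) rather than the state at this particular commit; the helper name and prompt text are illustrative only:

```rust
// Hypothetical helper illustrating the new shape: `CountTokensRequest` wraps a
// full `GenerateContentRequest`, and the model ID from that request is used in
// the URL path instead of a hard-coded `gemini-pro`.
use google_ai::{
    Content, CountTokensRequest, GenerateContentRequest, ModelName, Part, Role, TextPart,
    count_tokens,
};
use http_client::HttpClient;

async fn count_request_tokens(client: &dyn HttpClient, api_key: &str) -> anyhow::Result<usize> {
    let generate_content_request = GenerateContentRequest {
        model: ModelName {
            model_id: "gemini-1.5-flash-002".to_string(),
        },
        contents: vec![Content {
            role: Role::User,
            parts: vec![Part::TextPart(TextPart {
                text: "Hello".to_string(),
            })],
        }],
        system_instruction: None,
        generation_config: None,
        safety_settings: None,
        tools: None,
        tool_config: None,
        cached_content: None,
    };
    // The whole request is wrapped, rather than passing only its contents.
    let request = CountTokensRequest {
        generate_content_request,
    };
    let response = count_tokens(client, google_ai::API_URL, api_key, request).await?;
    Ok(response.total_tokens)
}
```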
Michael Sloan
88c7893913 Merge branch 'main' into gemini-caching 2025-04-29 23:03:14 -06:00
Michael Sloan
c2151f0082 Progress 2025-04-29 23:01:16 -06:00
Michael Sloan
d677117a48 Undo a change from Option<Vec<Tool>> -> Vec<Tool>
While this change is safe here, it is not safe on the llm worker side, which needs to deserialize `null` from older Zeds. Better to keep the definitions consistent.
2025-04-29 22:45:58 -06:00
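A minimal sketch of the compatibility concern described in the commit above, using simplified stand-in types (the struct names here are illustrative, not the llm worker's actual definitions):

```rust
// Illustrates why `tools` stays `Option<Vec<Tool>>`: older Zed clients may send
// `"tools": null`, which deserializes to `None` for an Option but is rejected
// outright for a plain `Vec`.
use serde::Deserialize;

#[allow(dead_code)]
#[derive(Deserialize)]
struct Tool {
    name: String,
}

#[allow(dead_code)]
#[derive(Deserialize)]
struct RequestWithOption {
    tools: Option<Vec<Tool>>,
}

#[allow(dead_code)]
#[derive(Deserialize)]
struct RequestWithVec {
    tools: Vec<Tool>,
}

fn main() {
    let from_older_zed = r#"{ "tools": null }"#;
    // Accepted: `tools` becomes `None`.
    assert!(serde_json::from_str::<RequestWithOption>(from_older_zed).is_ok());
    // Rejected: serde_json does not treat `null` as an empty sequence.
    assert!(serde_json::from_str::<RequestWithVec>(from_older_zed).is_err());
}
```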
Michael Sloan
f86b552a20 Merge branch 'main' into gemini-caching 2025-04-29 19:06:57 -06:00
Michael Sloan
1c6040e54f Add --prompt-file + improve request types 2025-04-29 14:17:37 -06:00
Michael Sloan
b861e1ca8c Add a cli example for google ai API 2025-04-29 13:58:49 -06:00
Michael Sloan
961e7dd52a Wrap gemini cache creation and update APIs 2025-04-29 13:03:06 -06:00
Michael Sloan
9c548fecbc WIP 2025-04-29 11:05:59 -06:00
13 changed files with 1362 additions and 80 deletions

Cargo.lock generated
View File

@@ -6198,12 +6198,17 @@ name = "google_ai"
version = "0.1.0"
dependencies = [
"anyhow",
"clap",
"futures 0.3.31",
"http_client",
"log",
"reqwest_client",
"schemars",
"serde",
"serde_json",
"strum 0.27.1",
"time",
"tokio",
"workspace-hack",
]
@@ -7909,6 +7914,7 @@ dependencies = [
"mistral",
"ollama",
"open_ai",
"parking_lot",
"partial-json-fixer",
"project",
"proto",
@@ -7921,6 +7927,7 @@ dependencies = [
"theme",
"thiserror 2.0.12",
"tiktoken-rs",
"time",
"tokio",
"ui",
"util",

View File

@@ -107,6 +107,10 @@ impl Model {
}
}
pub fn matches_id(&self, other_id: &str) -> bool {
self.id() == other_id
}
/// The id of the model that should be used for making API requests
pub fn request_id(&self) -> &str {
match self {

View File

@@ -18,8 +18,15 @@ schemars = ["dep:schemars"]
anyhow.workspace = true
futures.workspace = true
http_client.workspace = true
log.workspace = true
schemars = { workspace = true, optional = true }
serde.workspace = true
serde_json.workspace = true
strum.workspace = true
time.workspace = true
workspace-hack.workspace = true
[dev-dependencies]
clap = { version = "4.4", features = ["derive", "env"] }
tokio = { version = "1.28", features = ["full"] }
reqwest_client.workspace = true

View File

@@ -0,0 +1,482 @@
//! CLI for interacting with Gemini, 99% generated by Zed + Claude 3.7 Sonnet. Feel free to delete
//! it rather than maintain it; it only took 10 minutes to generate.
use anyhow::{Result, anyhow};
use clap::{Parser, Subcommand};
use futures::StreamExt;
use google_ai::{
CacheName, Content, CountTokensRequest, CreateCacheRequest, CreateCacheResponse,
GenerateContentRequest, GenerationConfig, ModelName, Part, Role, SystemInstruction, TextPart,
UpdateCacheRequest, count_tokens, create_cache, stream_generate_content, update_cache,
};
use reqwest_client::ReqwestClient;
use std::io::Write;
use std::{fs, io, path::PathBuf, sync::Arc, time::Duration};
#[derive(Parser)]
#[command(name = "google_ai_cli")]
#[command(author = "Zed Team")]
#[command(version = "0.1.0")]
#[command(about = "Interface with the Google Generative AI API", long_about = None)]
struct Cli {
/// Gemini API key
#[arg(long, env = "GEMINI_API_KEY")]
api_key: String,
/// API URL (defaults to https://generativelanguage.googleapis.com)
#[arg(long, global = true, default_value = google_ai::API_URL)]
api_url: String,
/// The model to use
#[arg(long, global = true, default_value = "gemini-1.5-flash-002")]
model: String,
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand)]
enum Commands {
/// Generate content from a prompt
Generate {
/// The prompt text
#[arg(long, conflicts_with = "prompt_file")]
prompt: Option<String>,
/// File containing the prompt text
#[arg(long, conflicts_with = "prompt")]
prompt_file: Option<PathBuf>,
/// System instruction for the model
#[arg(long)]
system_instruction: Option<String>,
/// Maximum number of tokens to generate
#[arg(long)]
max_tokens: Option<usize>,
/// Temperature for generation (0.0 to 1.0)
#[arg(long)]
temperature: Option<f64>,
/// Top-p sampling parameter (0.0 to 1.0)
#[arg(long)]
top_p: Option<f64>,
/// Top-k sampling parameter
#[arg(long)]
top_k: Option<usize>,
/// Use cached content by specifying the cache name
#[arg(long)]
cached_content: Option<String>,
},
/// Count tokens in a prompt
CountTokens {
/// The prompt text
#[arg(long, conflicts_with = "prompt_file")]
prompt: Option<String>,
/// File containing the prompt text
#[arg(long, conflicts_with = "prompt")]
prompt_file: Option<PathBuf>,
},
/// Cache content for faster repeated access
CreateCache {
/// The prompt text
#[arg(long, conflicts_with = "prompt_file")]
prompt: Option<String>,
/// File containing the prompt text
#[arg(long, conflicts_with = "prompt")]
prompt_file: Option<PathBuf>,
/// System instruction for the model
#[arg(long)]
system_instruction: Option<String>,
/// Time-to-live for the cache in seconds
#[arg(long, default_value = "3600")]
ttl: u64,
},
/// Update cache TTL
UpdateCache {
/// The cache name to update
#[arg(long)]
cache_id: String,
/// New time-to-live for the cache in seconds
#[arg(long)]
ttl: u64,
},
/// Interactive conversation with the model
Chat {
/// System instruction for the model
#[arg(long)]
system_instruction: Option<String>,
/// Maximum number of tokens to generate
#[arg(long)]
max_tokens: Option<usize>,
/// Temperature for generation (0.0 to 1.0)
#[arg(long)]
temperature: Option<f64>,
/// Load a JSON file containing chat history
#[arg(long)]
history_file: Option<PathBuf>,
},
}
// Helper function to get prompt text from either prompt or prompt_file
fn get_prompt_text(prompt: &Option<String>, prompt_file: &Option<PathBuf>) -> Result<String> {
match (prompt, prompt_file) {
(Some(text), None) => Ok(text.clone()),
(None, Some(file_path)) => {
fs::read_to_string(file_path).map_err(|e| anyhow!("Failed to read prompt file: {}", e))
}
(None, None) => Err(anyhow!(
"Either --prompt or --prompt-file must be specified"
)),
(Some(_), Some(_)) => Err(anyhow!("Cannot specify both --prompt and --prompt-file")),
}
}
#[tokio::main]
async fn main() -> Result<()> {
let cli = Cli::parse();
let http_client = Arc::new(ReqwestClient::new());
match &cli.command {
Commands::Generate {
prompt,
prompt_file,
system_instruction,
max_tokens,
temperature,
top_p,
top_k,
cached_content,
} => {
let prompt_text = get_prompt_text(prompt, prompt_file)?;
let user_content = Content {
role: Role::User,
parts: vec![Part::TextPart(TextPart { text: prompt_text })],
};
let request = GenerateContentRequest {
model: ModelName {
model_id: cli.model.clone(),
},
contents: vec![user_content],
system_instruction: system_instruction.as_ref().map(|instruction| {
SystemInstruction {
parts: vec![Part::TextPart(TextPart {
text: instruction.clone(),
})],
}
}),
generation_config: Some(GenerationConfig {
max_output_tokens: *max_tokens,
temperature: *temperature,
top_p: *top_p,
top_k: *top_k,
candidate_count: None,
stop_sequences: None,
}),
safety_settings: None,
tools: None,
tool_config: None,
cached_content: cached_content.as_ref().map(|cache_id| {
if let Some(cache_id) = cache_id.strip_prefix("cachedContents/") {
CacheName {
cache_id: cache_id.to_string(),
}
} else {
CacheName {
cache_id: cache_id.clone(),
}
}
}),
};
println!("Generating content with model: {}", cli.model);
let mut stream =
stream_generate_content(http_client.as_ref(), &cli.api_url, &cli.api_key, request)
.await?;
println!("Response:");
while let Some(response) = stream.next().await {
match response {
Ok(resp) => {
if let Some(candidates) = &resp.candidates {
for candidate in candidates {
for part in &candidate.content.parts {
match part {
Part::TextPart(text_part) => {
print!("{}", text_part.text);
}
_ => {
println!("[Received non-text response part]");
}
}
}
}
}
}
Err(e) => {
eprintln!("Error: {}", e);
break;
}
}
}
println!();
}
Commands::CountTokens {
prompt,
prompt_file,
} => {
let prompt_text = get_prompt_text(prompt, prompt_file)?;
let user_content = Content {
role: Role::User,
parts: vec![Part::TextPart(TextPart { text: prompt_text })],
};
let generate_content_request = GenerateContentRequest {
model: ModelName {
model_id: cli.model.clone(),
},
contents: vec![user_content],
system_instruction: None,
generation_config: None,
safety_settings: None,
tools: None,
tool_config: None,
cached_content: None,
};
let request = CountTokensRequest {
generate_content_request,
};
let response =
count_tokens(http_client.as_ref(), &cli.api_url, &cli.api_key, request).await?;
println!("Total tokens: {}", response.total_tokens);
}
Commands::CreateCache {
prompt,
prompt_file,
system_instruction,
ttl,
} => {
let prompt_text = get_prompt_text(prompt, prompt_file)?;
let request = CreateCacheRequest {
ttl: Duration::from_secs(*ttl),
model: ModelName {
model_id: cli.model.clone(),
},
contents: vec![Content {
role: Role::User,
parts: vec![Part::TextPart(TextPart { text: prompt_text })],
}],
system_instruction: system_instruction.as_ref().map(|instruction| {
SystemInstruction {
parts: vec![Part::TextPart(TextPart {
text: instruction.clone(),
})],
}
}),
tools: None,
tool_config: None,
};
let response =
create_cache(http_client.as_ref(), &cli.api_url, &cli.api_key, request).await?;
match response {
CreateCacheResponse::Created(created_cache) => {
println!("Cache created:");
println!(" ID: {:?}", created_cache.name.cache_id);
println!(" Expires: {}", created_cache.expire_time);
if let Some(token_count) = created_cache.usage_metadata.total_token_count {
println!(" Total tokens: {}", token_count);
}
}
CreateCacheResponse::CachingNotSupportedByModel => {
println!("Cache not created, assuming this is due to the specified model ID.");
}
}
}
Commands::UpdateCache { cache_id, ttl } => {
let cache_name = if let Some(cache_id) = cache_id.strip_prefix("cachedContents/") {
CacheName {
cache_id: cache_id.to_string(),
}
} else {
CacheName {
cache_id: cache_id.clone(),
}
};
let request = UpdateCacheRequest {
name: cache_name.clone(),
ttl: Duration::from_secs(*ttl),
};
let response =
update_cache(http_client.as_ref(), &cli.api_url, &cli.api_key, request).await?;
println!("Cache updated:");
println!(" ID: {}", cache_name.cache_id);
println!(" New expiration: {}", response.expire_time);
}
Commands::Chat {
system_instruction,
max_tokens,
temperature,
history_file,
} => {
// Initialize a vector to store conversation history
let mut history = Vec::new();
// Load history if provided
if let Some(file_path) = history_file {
if file_path.exists() {
let file_content = fs::read_to_string(file_path)?;
history = serde_json::from_str(&file_content)?;
}
}
// Add system instruction if present
if let Some(instruction) = system_instruction {
println!("System: {}", instruction);
println!();
}
loop {
// Get user input
print!("You: ");
io::stdout().flush()?;
let mut input = String::new();
io::stdin().read_line(&mut input)?;
if input.trim().eq_ignore_ascii_case("exit")
|| input.trim().eq_ignore_ascii_case("quit")
{
break;
}
// Add user message to history
let user_content = Content {
role: Role::User,
parts: vec![Part::TextPart(TextPart {
text: input.trim().to_string(),
})],
};
history.push(user_content);
// Create request with history
let request = GenerateContentRequest {
model: ModelName {
model_id: cli.model.clone(),
},
contents: history
.iter()
.map(|content| Content {
role: content.role,
parts: content
.parts
.iter()
.map(|part| match part {
Part::TextPart(text_part) => Part::TextPart(TextPart {
text: text_part.text.clone(),
}),
_ => panic!("Unsupported part type in history"),
})
.collect(),
})
.collect(),
system_instruction: system_instruction.as_ref().map(|instruction| {
SystemInstruction {
parts: vec![Part::TextPart(TextPart {
text: instruction.clone(),
})],
}
}),
generation_config: Some(GenerationConfig {
max_output_tokens: *max_tokens,
temperature: *temperature,
top_p: None,
top_k: None,
candidate_count: None,
stop_sequences: None,
}),
safety_settings: None,
tools: None,
tool_config: None,
cached_content: None,
};
// Get response
print!("Assistant: ");
let mut stream = stream_generate_content(
http_client.as_ref(),
&cli.api_url,
&cli.api_key,
request,
)
.await?;
let mut model_response = String::new();
while let Some(response) = stream.next().await {
match response {
Ok(resp) => {
if let Some(candidates) = &resp.candidates {
for candidate in candidates {
for part in &candidate.content.parts {
match part {
Part::TextPart(text_part) => {
print!("{}", text_part.text);
model_response.push_str(&text_part.text);
}
_ => {
println!("[Received non-text response part]");
}
}
}
}
}
}
Err(e) => {
eprintln!("\nError: {}", e);
break;
}
}
}
println!();
println!();
// Add model response to history
let model_content = Content {
role: Role::Model,
parts: vec![Part::TextPart(TextPart {
text: model_response,
})],
};
history.push(model_content);
}
}
}
Ok(())
}

View File

@@ -1,7 +1,11 @@
use std::time::Duration;
use std::{fmt::Display, mem};
use anyhow::{Result, anyhow, bail};
use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use time::OffsetDateTime;
pub const API_URL: &str = "https://generativelanguage.googleapis.com";
@@ -11,25 +15,13 @@ pub async fn stream_generate_content(
api_key: &str,
mut request: GenerateContentRequest,
) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
if request.contents.is_empty() {
bail!("Request must contain at least one content item");
}
validate_generate_content_request(&request)?;
if let Some(user_content) = request
.contents
.iter()
.find(|content| content.role == Role::User)
{
if user_content.parts.is_empty() {
bail!("User content must contain at least one part");
}
}
// The `model` field is emptied as it is provided as a path parameter.
let model_id = mem::take(&mut request.model.model_id);
let uri = format!(
"{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
model = request.model
);
request.model.clear();
let uri =
format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);
let request_builder = HttpRequest::builder()
.method(Method::POST)
@@ -65,7 +57,7 @@ pub async fn stream_generate_content(
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
Err(anyhow!(
"error during streamGenerateContent, status code: {:?}, body: {}",
"error during Gemini content generation, status code: {:?}, body: {}",
response.status(),
text
))
@@ -76,18 +68,22 @@ pub async fn count_tokens(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
model_id: &str,
request: CountTokensRequest,
) -> Result<CountTokensResponse> {
let uri = format!("{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",);
let request = serde_json::to_string(&request)?;
validate_generate_content_request(&request.generate_content_request)?;
let uri = format!(
"{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
model_id = &request.generate_content_request.model.model_id,
);
let request = serde_json::to_string(&request)?;
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(&uri)
.header("Content-Type", "application/json");
let http_request = request_builder.body(AsyncBody::from(request))?;
let mut response = client.send(http_request).await?;
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
@@ -95,13 +91,114 @@ pub async fn count_tokens(
Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
} else {
Err(anyhow!(
"error during countTokens, status code: {:?}, body: {}",
"error during Gemini token counting, status code: {:?}, body: {}",
response.status(),
text
))
}
}
pub async fn create_cache(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
request: CreateCacheRequest,
) -> Result<CreateCacheResponse> {
if let Some(user_content) = request
.contents
.iter()
.find(|content| content.role == Role::User)
{
if user_content.parts.is_empty() {
bail!("User content must contain at least one part");
}
}
let uri = format!("{api_url}/v1beta/cachedContents?key={api_key}");
let request_builder = HttpRequest::builder()
.method(Method::POST)
.uri(uri)
.header("Content-Type", "application/json");
let http_request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(http_request).await?;
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
if response.status().is_success() {
let created_cache = serde_json::from_str::<CreatedCache>(&text)?;
Ok(CreateCacheResponse::Created(created_cache))
} else if response.status().as_u16() == 404
&& text.contains("or is not supported for createCachedContent")
{
log::info!(
"will no longer attempt Gemini cache creation for model `{}` due to {:?} response: {}",
request.model.model_id,
response.status(),
text
);
Ok(CreateCacheResponse::CachingNotSupportedByModel)
} else {
Err(anyhow!(
"unexpected {:?} response during Gemini cache creation: {}",
response.status(),
text
))
}
}
pub async fn update_cache(
client: &dyn HttpClient,
api_url: &str,
api_key: &str,
mut request: UpdateCacheRequest,
) -> Result<UpdateCacheResponse> {
// The `name` field is emptied as it is provided as a path parameter.
let name = mem::take(&mut request.name);
let uri = format!(
"{api_url}/v1beta/cachedContents/{cache_id}?key={api_key}",
cache_id = &name.cache_id
);
let request_builder = HttpRequest::builder()
.method(Method::PATCH)
.uri(uri)
.header("Content-Type", "application/json");
let http_request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
let mut response = client.send(http_request).await?;
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
if response.status().is_success() {
Ok(serde_json::from_str::<UpdateCacheResponse>(&text)?)
} else {
Err(anyhow!(
"error during Gemini cache update, status code: {:?}, body: {}",
response.status(),
text
))
}
}
pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
if request.model.is_empty() {
bail!("Model must be specified");
}
if request.contents.is_empty() {
bail!("Request must contain at least one content item");
}
if let Some(user_content) = request
.contents
.iter()
.find(|content| content.role == Role::User)
{
if user_content.parts.is_empty() {
bail!("User content must contain at least one part");
}
}
Ok(())
}
#[derive(Debug, Serialize, Deserialize)]
pub enum Task {
#[serde(rename = "generateContent")]
@@ -119,8 +216,8 @@ pub enum Task {
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerateContentRequest {
#[serde(default, skip_serializing_if = "String::is_empty")]
pub model: String,
#[serde(default, skip_serializing_if = "ModelName::is_empty")]
pub model: ModelName,
pub contents: Vec<Content>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_instruction: Option<SystemInstruction>,
@@ -132,6 +229,8 @@ pub struct GenerateContentRequest {
pub tools: Option<Vec<Tool>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_config: Option<ToolConfig>,
#[serde(skip_serializing_if = "Option::is_none")]
pub cached_content: Option<CacheName>,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -161,7 +260,7 @@ pub struct GenerateContentCandidate {
pub citation_metadata: Option<CitationMetadata>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Content {
#[serde(default)]
@@ -169,20 +268,20 @@ pub struct Content {
pub role: Role,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SystemInstruction {
pub parts: Vec<Part>,
}
#[derive(Debug, PartialEq, Deserialize, Serialize)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub enum Role {
User,
Model,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(untagged)]
pub enum Part {
TextPart(TextPart),
@@ -191,32 +290,32 @@ pub enum Part {
FunctionResponsePart(FunctionResponsePart),
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct TextPart {
pub text: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct InlineDataPart {
pub inline_data: GenerativeContentBlob,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GenerativeContentBlob {
pub mime_type: String,
pub data: String,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallPart {
pub function_call: FunctionCall,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionResponsePart {
pub function_response: FunctionResponse,
@@ -251,7 +350,7 @@ pub struct PromptFeedback {
pub block_reason_message: Option<String>,
}
#[derive(Debug, Serialize, Deserialize, Default)]
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
#[serde(skip_serializing_if = "Option::is_none")]
@@ -350,7 +449,7 @@ pub struct SafetyRating {
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CountTokensRequest {
pub contents: Vec<Content>,
pub generate_content_request: GenerateContentRequest,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -359,31 +458,31 @@ pub struct CountTokensResponse {
pub total_tokens: usize,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct FunctionCall {
pub name: String,
pub args: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct FunctionResponse {
pub name: String,
pub response: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
pub function_declarations: Vec<FunctionDeclaration>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
pub function_calling_config: FunctionCallingConfig,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
pub mode: FunctionCallingMode,
@@ -391,7 +490,7 @@ pub struct FunctionCallingConfig {
pub allowed_function_names: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum FunctionCallingMode {
Auto,
@@ -399,27 +498,260 @@ pub enum FunctionCallingMode {
None,
}
#[derive(Debug, Serialize, Deserialize)]
#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub struct FunctionDeclaration {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct CreateCacheRequest {
#[serde(
serialize_with = "serialize_duration",
deserialize_with = "deserialize_duration"
)]
pub ttl: Duration,
pub model: ModelName,
pub contents: Vec<Content>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_instruction: Option<SystemInstruction>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<Tool>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_config: Option<ToolConfig>,
// Other fields that could be provided:
//
// name: The resource name referring to the cached content. Format: cachedContents/{id}
// display_name: user-generated meaningful display name of the cached content. Maximum 128 Unicode characters.
}
pub enum CreateCacheResponse {
Created(CreatedCache),
CachingNotSupportedByModel,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreatedCache {
pub name: CacheName,
#[serde(
serialize_with = "time::serde::rfc3339::serialize",
deserialize_with = "time::serde::rfc3339::deserialize"
)]
pub expire_time: OffsetDateTime,
pub usage_metadata: UsageMetadata,
// Other fields that could be provided:
//
// create_time: Creation time of the cache entry.
// update_time: When the cache entry was last updated in UTC time.
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UpdateCacheRequest {
pub name: CacheName,
#[serde(
serialize_with = "serialize_duration",
deserialize_with = "deserialize_duration"
)]
pub ttl: Duration,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UpdateCacheResponse {
#[serde(
serialize_with = "time::serde::rfc3339::serialize",
deserialize_with = "time::serde::rfc3339::deserialize"
)]
pub expire_time: OffsetDateTime,
}
const MODEL_NAME_PREFIX: &str = "models/";
const CACHE_NAME_PREFIX: &str = "cachedContents/";
#[derive(Clone, Debug, Default, Hash, PartialEq, Eq)]
pub struct ModelName {
pub model_id: String,
}
#[derive(Clone, Debug, Default)]
pub struct CacheName {
pub cache_id: String,
}
impl ModelName {
pub fn is_empty(&self) -> bool {
self.model_id.is_empty()
}
}
impl CacheName {
pub fn is_empty(&self) -> bool {
self.cache_id.is_empty()
}
}
impl Serialize for ModelName {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
}
}
impl<'de> Deserialize<'de> for ModelName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let string = String::deserialize(deserializer)?;
if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
Ok(Self {
model_id: id.to_string(),
})
} else {
return Err(serde::de::Error::custom(format!(
"Expected model name to begin with {}, got: {}",
MODEL_NAME_PREFIX, string
)));
}
}
}
impl Display for ModelName {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "{}", self.model_id)
}
}
impl Serialize for CacheName {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&format!("{CACHE_NAME_PREFIX}{}", &self.cache_id))
}
}
impl<'de> Deserialize<'de> for CacheName {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let string = String::deserialize(deserializer)?;
if let Some(id) = string.strip_prefix(CACHE_NAME_PREFIX) {
Ok(CacheName {
cache_id: id.to_string(),
})
} else {
return Err(serde::de::Error::custom(format!(
"Expected cache name to begin with {}, got: {}",
CACHE_NAME_PREFIX, string
)));
}
}
}
impl Display for CacheName {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(formatter, "{}", self.cache_id)
}
}
/// Serializes a Duration as a string in the format "X.Ys" where X is the whole seconds
/// and Y is up to 9 decimal places of fractional seconds.
pub fn serialize_duration<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let secs = duration.as_secs();
let nanos = duration.subsec_nanos();
let formatted = if nanos == 0 {
format!("{}s", secs)
} else {
// Remove trailing zeros from nanos
let mut nanos_str = format!("{:09}", nanos);
while nanos_str.ends_with('0') && nanos_str.len() > 1 {
nanos_str.pop();
}
format!("{}.{}s", secs, nanos_str)
};
serializer.serialize_str(&formatted)
}
pub fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
where
D: Deserializer<'de>,
{
let duration_str = String::deserialize(deserializer)?;
let Some(num_part) = duration_str.strip_suffix('s') else {
return Err(serde::de::Error::custom(format!(
"Duration must end with 's', got: {}",
duration_str
)));
};
if let Some(decimal_ix) = num_part.find('.') {
let secs_part = &num_part[0..decimal_ix];
let frac_len = (num_part.len() - (decimal_ix + 1)).min(9);
let frac_start_ix = decimal_ix + 1;
let frac_end_ix = frac_start_ix + frac_len;
let frac_part = &num_part[frac_start_ix..frac_end_ix];
let secs = u64::from_str_radix(secs_part, 10).map_err(|e| {
serde::de::Error::custom(format!(
"Invalid seconds in duration: {}. Error: {}",
duration_str, e
))
})?;
let frac_number = frac_part.parse::<u32>().map_err(|e| {
serde::de::Error::custom(format!(
"Invalid fractional seconds in duration: {}. Error: {}",
duration_str, e
))
})?;
let nanos = frac_number * 10u32.pow(9 - frac_len as u32);
Ok(Duration::new(secs, nanos))
} else {
let secs = u64::from_str_radix(num_part, 10).map_err(|e| {
serde::de::Error::custom(format!(
"Invalid duration format: {}. Error: {}",
duration_str, e
))
})?;
Ok(Duration::new(secs, 0))
}
}
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
#[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
pub enum Model {
#[serde(rename = "gemini-1.5-pro")]
#[serde(rename = "gemini-1.5-pro-002")]
Gemini15Pro,
#[serde(rename = "gemini-1.5-flash")]
#[serde(rename = "gemini-1.5-flash-002")]
Gemini15Flash,
/// Note: replaced by `gemini-2.5-pro-exp-03-25` (continues to work in API).
#[serde(rename = "gemini-2.0-pro-exp")]
Gemini20Pro,
#[serde(rename = "gemini-2.0-flash")]
#[default]
Gemini20Flash,
/// Note: replaced by `gemini-2.5-flash-preview-04-17` (continues to work in API).
#[serde(rename = "gemini-2.0-flash-thinking-exp")]
Gemini20FlashThinking,
/// Note: replaced by `gemini-2.0-flash-lite` (continues to work in API).
#[serde(rename = "gemini-2.0-flash-lite-preview")]
Gemini20FlashLite,
#[serde(rename = "gemini-2.5-pro-exp-03-25")]
@@ -444,8 +776,8 @@ impl Model {
pub fn id(&self) -> &str {
match self {
Model::Gemini15Pro => "gemini-1.5-pro",
Model::Gemini15Flash => "gemini-1.5-flash",
Model::Gemini15Pro => "gemini-1.5-pro-002",
Model::Gemini15Flash => "gemini-1.5-flash-002",
Model::Gemini20Pro => "gemini-2.0-pro-exp",
Model::Gemini20Flash => "gemini-2.0-flash",
Model::Gemini20FlashThinking => "gemini-2.0-flash-thinking-exp",
@@ -457,6 +789,19 @@ impl Model {
}
}
pub fn matches_id(&self, other_id: &str) -> bool {
if self.id() == other_id {
return true;
}
// These IDs are present in user settings. The `-002` stable model version is added in the
// ID used for the model so that caching works.
match self {
Model::Gemini15Pro => other_id == "gemini-1.5-pro",
Model::Gemini15Flash => other_id == "gemini-1.5-flash",
_ => false,
}
}
pub fn display_name(&self) -> &str {
match self {
Model::Gemini15Pro => "Gemini 1.5 Pro",
@@ -497,3 +842,54 @@ impl std::fmt::Display for Model {
write!(f, "{}", self.id())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_duration_serialization() {
#[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
struct Example {
#[serde(
serialize_with = "serialize_duration",
deserialize_with = "deserialize_duration"
)]
duration: Duration,
}
let example = Example {
duration: Duration::from_secs(5),
};
let serialized = serde_json::to_string(&example).unwrap();
let deserialized: Example = serde_json::from_str(&serialized).unwrap();
assert_eq!(serialized, r#"{"duration":"5s"}"#);
assert_eq!(deserialized, example);
let example = Example {
duration: Duration::from_millis(5534),
};
let serialized = serde_json::to_string(&example).unwrap();
let deserialized: Example = serde_json::from_str(&serialized).unwrap();
assert_eq!(serialized, r#"{"duration":"5.534s"}"#);
assert_eq!(deserialized, example);
let example = Example {
duration: Duration::from_nanos(12345678900),
};
let serialized = serde_json::to_string(&example).unwrap();
let deserialized: Example = serde_json::from_str(&serialized).unwrap();
assert_eq!(serialized, r#"{"duration":"12.3456789s"}"#);
assert_eq!(deserialized, example);
// Deserializer doesn't panic for too many fractional digits
let deserialized: Example =
serde_json::from_str(r#"{"duration":"5.12345678905s"}"#).unwrap();
assert_eq!(
deserialized,
Example {
duration: Duration::from_nanos(5123456789)
}
);
}
}

View File

@@ -228,12 +228,27 @@ impl Default for LanguageModelTextStream {
}
pub trait LanguageModel: Send + Sync {
/// The ID used for the model when making requests. If checking for match of a user-provided
/// string, use `matches_id` instead.
fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName;
fn provider_id(&self) -> LanguageModelProviderId;
fn provider_name(&self) -> LanguageModelProviderName;
fn telemetry_id(&self) -> String;
/// Use this to check whether this model should be used for the given user-provided model ID, to support multiple IDs.
/// Multiple IDs are used for:
///
/// * Preferring a particular ID in requests (Gemini only supports caching for stable IDs).
///
/// * When one model has been replaced by another. This allows the display name to update to
/// reflect which model is actually being used.
///
/// TODO: In the future consider replacing this mechanism with a settings migration.
fn matches_id(&self, other_id: &LanguageModelId) -> bool {
&self.id() == other_id
}
fn api_key(&self, _cx: &App) -> Option<String> {
None
}

View File

@@ -197,7 +197,7 @@ impl LanguageModelRegistry {
let model = provider
.provided_models(cx)
.iter()
.find(|model| model.id() == selected_model.model)?
.find(|model| model.matches_id(&selected_model.model))?
.clone();
Some(ConfiguredModel { provider, model })
}

View File

@@ -20,8 +20,8 @@ aws_http_client.workspace = true
bedrock.workspace = true
client.workspace = true
collections.workspace = true
credentials_provider.workspace = true
copilot = { workspace = true, features = ["schemars"] }
credentials_provider.workspace = true
deepseek = { workspace = true, features = ["schemars"] }
editor.workspace = true
feature_flags.workspace = true
@@ -38,6 +38,7 @@ menu.workspace = true
mistral = { workspace = true, features = ["schemars"] }
ollama = { workspace = true, features = ["schemars"] }
open_ai = { workspace = true, features = ["schemars"] }
parking_lot.workspace = true
partial-json-fixer.workspace = true
project.workspace = true
proto.workspace = true
@@ -50,6 +51,7 @@ strum.workspace = true
theme.workspace = true
thiserror.workspace = true
tiktoken-rs.workspace = true
time.workspace = true
tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
ui.workspace = true
util.workspace = true

View File

@@ -461,6 +461,7 @@ impl LanguageModel for AnthropicModel {
self.model.max_output_tokens(),
self.model.mode(),
);
let request = self.stream_completion(request, cx);
let future = self.request_limiter.stream(async move {
let response = request

View File

@@ -656,6 +656,14 @@ impl LanguageModel for CloudLanguageModel {
self.id.clone()
}
fn matches_id(&self, other_id: &LanguageModelId) -> bool {
match &self.model {
CloudModel::Anthropic(model) => model.matches_id(&other_id.0),
CloudModel::Google(model) => model.matches_id(&other_id.0),
CloudModel::OpenAi(model) => model.matches_id(&other_id.0),
}
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
@@ -718,7 +726,8 @@ impl LanguageModel for CloudLanguageModel {
CloudModel::Google(model) => {
let client = self.client.clone();
let llm_api_token = self.llm_api_token.clone();
let request = into_google(request, model.id().into());
let model_id = model.id().to_string();
let generate_content_request = into_google(request, model_id.clone());
async move {
let http_client = &client.http_client();
let token = llm_api_token.acquire(&client).await?;
@@ -736,9 +745,9 @@ impl LanguageModel for CloudLanguageModel {
};
let request_body = CountTokensBody {
provider: zed_llm_client::LanguageModelProvider::Google,
model: model.id().into(),
model: model_id,
provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
contents: request.contents,
generate_content_request,
})?,
};
let request = request_builder
@@ -895,7 +904,7 @@ impl LanguageModel for CloudLanguageModel {
prompt_id,
mode,
provider: zed_llm_client::LanguageModelProvider::Google,
model: request.model.clone(),
model: request.model.model_id.clone(),
provider_request: serde_json::to_value(&request)?,
},
)

View File

@@ -1,13 +1,17 @@
use anyhow::{Context as _, Result, anyhow};
use collections::BTreeMap;
use collections::{BTreeMap, FxHasher, HashMap, HashSet};
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
use futures::future::{self, Shared};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
use google_ai::{
FunctionDeclaration, GenerateContentResponse, Part, SystemInstruction, UsageMetadata,
CacheName, Content, CreateCacheRequest, CreateCacheResponse, FunctionDeclaration,
GenerateContentRequest, GenerateContentResponse, ModelName, Part, SystemInstruction,
UpdateCacheRequest, UpdateCacheResponse, UsageMetadata,
};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
AnyView, App, AppContext, AsyncApp, BackgroundExecutor, Context, Entity, FontStyle,
Subscription, Task, TextStyle, WhiteSpace,
};
use http_client::HttpClient;
use language_model::{
@@ -20,16 +24,20 @@ use language_model::{
LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
LanguageModelRequest, RateLimiter, Role,
};
use parking_lot::Mutex;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::hash::{Hash as _, Hasher as _};
use std::pin::Pin;
use std::sync::{
Arc,
atomic::{self, AtomicU64},
};
use std::time::Duration;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use time::UtcDateTime;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
use util::ResultExt;
@@ -38,6 +46,10 @@ use crate::ui::InstructionListItem;
const PROVIDER_ID: &str = "google";
const PROVIDER_NAME: &str = "Google AI";
const CACHE_TTL: Duration = Duration::from_secs(60 * 5);
/// Minimum amount of time left before a cache expires for it to be used.
const CACHE_TTL_MINIMUM_FOR_USAGE: Duration = Duration::from_secs(10);
#[derive(Default, Clone, Debug, PartialEq)]
pub struct GoogleSettings {
@@ -50,6 +62,7 @@ pub struct AvailableModel {
name: String,
display_name: Option<String>,
max_tokens: usize,
caching: bool,
}
pub struct GoogleLanguageModelProvider {
@@ -160,6 +173,7 @@ impl GoogleLanguageModelProvider {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
cache: Mutex::new(Cache::default()).into(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
})
@@ -227,6 +241,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
id: LanguageModelId::from(model.id().to_string()),
model,
state: self.state.clone(),
cache: Mutex::new(Cache::default()).into(),
http_client: self.http_client.clone(),
request_limiter: RateLimiter::new(4),
}) as Arc<dyn LanguageModel>
@@ -256,6 +271,7 @@ pub struct GoogleLanguageModel {
id: LanguageModelId,
model: google_ai::Model,
state: gpui::Entity<State>,
cache: Arc<Mutex<Cache>>,
http_client: Arc<dyn HttpClient>,
request_limiter: RateLimiter,
}
@@ -263,32 +279,77 @@ pub struct GoogleLanguageModel {
impl GoogleLanguageModel {
fn stream_completion(
&self,
request: google_ai::GenerateContentRequest,
request: impl 'static + Send + Future<Output = google_ai::GenerateContentRequest>,
cx: &AsyncApp,
) -> BoxFuture<
'static,
Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
> {
let http_client = self.http_client.clone();
match self.api_url_and_key(cx) {
Ok((api_url, api_key)) => {
let http_client = self.http_client.clone();
async move {
let request = google_ai::stream_generate_content(
http_client.as_ref(),
&api_url,
&api_key,
request.await,
);
request.await.context("failed to stream completion")
}
.boxed()
}
Err(err) => future::ready(Err(err)).boxed(),
}
}
let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
fn create_cache(
&self,
request: google_ai::CreateCacheRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<CreateCacheResponse>> {
match self.api_url_and_key(cx) {
Ok((api_url, api_key)) => {
let http_client = self.http_client.clone();
async move {
google_ai::create_cache(http_client.as_ref(), &api_url, &api_key, request).await
}
.boxed()
}
Err(err) => future::ready(Err(err)).boxed(),
}
}
fn update_cache(
&self,
request: UpdateCacheRequest,
cx: &AsyncApp,
) -> BoxFuture<'static, Result<UpdateCacheResponse>> {
match self.api_url_and_key(cx) {
Ok((api_url, api_key)) => {
let http_client = self.http_client.clone();
async move {
google_ai::update_cache(http_client.as_ref(), &api_url, &api_key, request).await
}
.boxed()
}
Err(err) => future::ready(Err(err)).boxed(),
}
}
fn api_url_and_key(&self, cx: &AsyncApp) -> Result<(String, String)> {
let Ok((api_url, api_key)) = self.state.read_with(cx, |state, cx| {
let settings = &AllLanguageModelSettings::get_global(cx).google;
(state.api_key.clone(), settings.api_url.clone())
(settings.api_url.clone(), state.api_key.clone())
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
return Err(anyhow!("App state dropped"));
};
async move {
let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API key"))?;
let request = google_ai::stream_generate_content(
http_client.as_ref(),
&api_url,
&api_key,
request,
);
request.await.context("failed to stream completion")
}
.boxed()
let Some(api_key) = api_key else {
return Err(anyhow!("Missing Google API key"));
};
Ok((api_url, api_key))
}
}
@@ -297,6 +358,10 @@ impl LanguageModel for GoogleLanguageModel {
self.id.clone()
}
fn matches_id(&self, other_id: &LanguageModelId) -> bool {
self.model.matches_id(&other_id.0)
}
fn name(&self) -> LanguageModelName {
LanguageModelName::from(self.model.display_name().to_string())
}
@@ -344,9 +409,8 @@ impl LanguageModel for GoogleLanguageModel {
http_client.as_ref(),
&api_url,
&api_key,
&model_id,
google_ai::CountTokensRequest {
contents: request.contents,
generate_content_request: request,
},
)
.await?;
@@ -368,10 +432,19 @@ impl LanguageModel for GoogleLanguageModel {
>,
>,
> {
let is_last_message_cached = request
.messages
.last()
.map_or(false, |content| content.cache);
let request = into_google(request, self.model.id().to_string());
let request = self.stream_completion(request, cx);
let request_task = self.use_cache_and_create_cache(is_last_message_cached, request, cx);
let stream_request = self.stream_completion(request_task, cx);
let future = self.request_limiter.stream(async move {
let response = request
let response = stream_request
.await
.map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
Ok(GoogleEventMapper::new().map_stream(response))
@@ -382,7 +455,7 @@ impl LanguageModel for GoogleLanguageModel {
pub fn into_google(
mut request: LanguageModelRequest,
model: String,
model_id: String,
) -> google_ai::GenerateContentRequest {
fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
content
@@ -442,7 +515,7 @@ pub fn into_google(
};
google_ai::GenerateContentRequest {
model,
model: google_ai::ModelName { model_id },
system_instruction: system_instructions,
contents: request
.messages
@@ -486,6 +559,7 @@ pub fn into_google(
}]
}),
tool_config: None,
cached_content: None,
}
}
@@ -642,6 +716,287 @@ fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
}
}
#[derive(Debug, Default)]
pub struct Cache {
task_map: HashMap<CacheKey, Shared<Task<Option<CacheEntry>>>>,
models_not_supported: HashSet<ModelName>,
}
#[derive(Debug, Clone)]
struct CacheEntry {
name: CacheName,
expire_time: UtcDateTime,
_drop_on_expire: Shared<Task<()>>,
}
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
struct CacheKey(u64);
#[derive(Debug)]
enum CacheStatus {
Absent,
Creating,
Failed,
Valid(CacheEntry),
Expired,
}
impl CacheEntry {
fn new(
name: CacheName,
expire_time: UtcDateTime,
cache_key: CacheKey,
cache: Arc<Mutex<Cache>>,
executor: BackgroundExecutor,
) -> Option<CacheEntry> {
let now = UtcDateTime::now();
let Some(remaining_time): Option<std::time::Duration> = (expire_time - now).try_into().ok()
else {
cache.lock().task_map.remove(&cache_key);
return None;
};
let cache_expired = executor.timer(remaining_time);
let drop_on_expire = executor
.spawn({
let name = name.clone();
async move {
cache_expired.await;
cache.lock().task_map.remove(&cache_key);
log::info!("removed cache `{name}` because it expired");
}
})
.shared();
Some(CacheEntry {
name,
expire_time,
_drop_on_expire: drop_on_expire,
})
}
fn is_expired(&self, now: UtcDateTime) -> bool {
self.expire_time - CACHE_TTL_MINIMUM_FOR_USAGE <= now
}
}
impl CacheKey {
fn base(request: &GenerateContentRequest) -> Self {
#[derive(Hash, PartialEq, Eq)]
pub struct CacheBaseRef<'a> {
pub model: &'a ModelName,
pub system_instruction: &'a Option<SystemInstruction>,
pub tools: &'a Option<Vec<google_ai::Tool>>,
pub tool_config: &'a Option<google_ai::ToolConfig>,
}
let cache_base = CacheBaseRef {
model: &request.model,
system_instruction: &request.system_instruction,
tools: &request.tools,
tool_config: &request.tool_config,
};
let mut hasher = FxHasher::default();
cache_base.hash(&mut hasher);
Self(hasher.finish())
}
fn for_message(predecessor: Self, message: &Content) -> Self {
let mut hasher = FxHasher::default();
predecessor.0.hash(&mut hasher);
message.hash(&mut hasher);
Self(hasher.finish())
}
}
impl Cache {
fn get_unexpired(
&self,
key: &CacheKey,
now: UtcDateTime,
) -> Option<Shared<Task<Option<CacheEntry>>>> {
let cache_task = self.task_map.get(key)?;
match cache_task.clone().now_or_never() {
Some(Some(existing_cache)) => {
if existing_cache.is_expired(now) {
None
} else {
Some(cache_task.clone())
}
}
Some(None) => {
// Cache creation failed
None
}
None => {
// Caching task pending, so use it.
Some(cache_task.clone())
}
}
}
fn get_status(&self, key: &CacheKey) -> CacheStatus {
if let Some(existing_task) = self.task_map.get(&key) {
match existing_task.clone().now_or_never() {
Some(Some(existing_cache)) => {
let now = UtcDateTime::now();
if existing_cache.is_expired(now) {
CacheStatus::Expired
} else {
CacheStatus::Valid(existing_cache)
}
}
Some(None) => CacheStatus::Failed,
None => CacheStatus::Creating,
}
} else {
CacheStatus::Absent
}
}
}
impl GoogleLanguageModel {
fn use_cache_and_create_cache(
&self,
should_create_cache: bool,
mut request: google_ai::GenerateContentRequest,
cx: &AsyncApp,
) -> Task<google_ai::GenerateContentRequest> {
if self
.cache
.lock()
.models_not_supported
.contains(&request.model)
{
return Task::ready(request);
}
let base_cache_key = CacheKey::base(&request);
let mut prev_cache_key = base_cache_key;
let content_cache_keys = request
.contents
.iter()
.map(|content| {
let key = CacheKey::for_message(prev_cache_key, content);
prev_cache_key = key;
key
})
.collect::<Vec<_>>();
if let (true, Some(cache_key)) = (should_create_cache, content_cache_keys.last().copied()) {
let cache_status = self.cache.lock().get_status(&cache_key);
match cache_status {
CacheStatus::Creating => {}
CacheStatus::Valid(existing_cache) => {
let name = existing_cache.name;
let update_cache_request = UpdateCacheRequest {
name: name.clone(),
ttl: CACHE_TTL,
};
// TODO: This is imperfect, but not in ways that are likely to matter in
// practice. Ideally the cache update task would get stored such that:
// * It's not possible for the cache to expire and get recreated while waiting
// for the update response.
// * It would get cancelled on drop.
let update_cache_future = self.update_cache(update_cache_request, cx);
let cache = self.cache.clone();
let executor = cx.background_executor().clone();
cx.background_spawn(async move {
if let Some(response) = update_cache_future.await.log_err() {
let cache_entry = CacheEntry::new(
name,
response.expire_time.to_utc(),
cache_key,
cache.clone(),
executor,
);
cache
.lock()
.task_map
.insert(cache_key, Task::ready(cache_entry).shared());
}
})
.detach();
}
CacheStatus::Absent | CacheStatus::Expired | CacheStatus::Failed => {
let create_request = CreateCacheRequest {
ttl: CACHE_TTL,
model: request.model.clone(),
contents: request.contents.clone(),
system_instruction: request.system_instruction.clone(),
tools: request.tools.clone(),
tool_config: request.tool_config.clone(),
};
let model = request.model.clone();
let create_cache_future = self
.request_limiter
.run(self.create_cache(create_request, cx));
let cache = self.cache.clone();
let executor = cx.background_executor().clone();
let task = cx.background_spawn(async move {
let result = match create_cache_future.await.log_err()? {
CreateCacheResponse::Created(created_cache) => {
log::info!("created cache `{}`", created_cache.name);
CacheEntry::new(
created_cache.name,
created_cache.expire_time.to_utc(),
cache_key,
cache.clone(),
executor,
)
}
CreateCacheResponse::CachingNotSupportedByModel => {
cache.lock().models_not_supported.insert(model);
None
}
};
if result.is_none() {
cache.lock().task_map.remove(&cache_key);
}
result
});
self.cache.lock().task_map.insert(cache_key, task.shared());
}
}
}
cx.background_spawn({
let cache = self.cache.clone();
async move {
let mut prefix_len = 0;
let mut found_cache_entry = None;
let mut now = UtcDateTime::now();
// The last key is skipped because `contents` must be non-empty.
for (ix, key) in content_cache_keys.iter().enumerate().rev().skip(1) {
let task = cache.lock().get_unexpired(&key, now);
// TODO: Measure if it's worth it to await on caches vs using ones that are already ready.
if let Some(task) = task {
if let Some(cache_entry) = task.await {
prefix_len = ix + 1;
found_cache_entry = Some(cache_entry);
break;
} else {
now = UtcDateTime::now();
}
}
}
if let Some(found_cache_entry) = found_cache_entry {
log::info!(
"using cache `{}` which has {prefix_len} messages",
found_cache_entry.name.cache_id
);
request.cached_content = Some(found_cache_entry.name);
request.contents.drain(..prefix_len);
request.system_instruction = None;
request.tools = None;
request.tool_config = None;
}
request
}
})
}
}
struct ConfigurationView {
api_key_editor: Entity<Editor>,
state: gpui::Entity<State>,

View File

@@ -142,6 +142,10 @@ impl Model {
}
}
pub fn matches_id(&self, other_id: &str) -> bool {
self.id() == other_id
}
pub fn display_name(&self) -> &str {
match self {
Self::ThreePointFiveTurbo => "gpt-3.5-turbo",

View File

@@ -540,7 +540,7 @@ impl SummaryIndex {
);
let Some(model) = LanguageModelRegistry::read_global(cx)
.available_models(cx)
.find(|model| &model.id() == &summary_model_id)
.find(|model| model.matches_id(&summary_model_id))
else {
return cx.background_spawn(async move {
Err(anyhow!("Couldn't find the preferred summarization model ({:?}) in the language registry's available models", summary_model_id))