Compare commits
34 commits: commit-vie...gemini-cac
| SHA1 |
|---|
| fc3a67f046 |
| 670813b9de |
| 61d615549f |
| 3931d74275 |
| 2cea66c5cb |
| 09e26204ad |
| bfea3e5285 |
| f704e0578a |
| 04b16fedde |
| bde4bd5b3d |
| 85e26b7f02 |
| 0cc51f72e5 |
| abbb3dc7a6 |
| 141e0a702a |
| bb47b766a1 |
| 515cdb9ae6 |
| 2666ff7873 |
| 1d0cc37205 |
| 1257d44998 |
| 5fdcdc1926 |
| 8c91fc3153 |
| b1595dba71 |
| d4702209ea |
| 8440ec03ad |
| 8c8eabe96d |
| 10dfa36c91 |
| 88c7893913 |
| c2151f0082 |
| d677117a48 |
| f86b552a20 |
| 1c6040e54f |
| b861e1ca8c |
| 961e7dd52a |
| 9c548fecbc |
Cargo.lock (7 changes, generated)
@@ -6198,12 +6198,17 @@ name = "google_ai"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "clap",
  "futures 0.3.31",
  "http_client",
  "log",
+ "reqwest_client",
  "schemars",
  "serde",
  "serde_json",
+ "strum 0.27.1",
+ "time",
+ "tokio",
  "workspace-hack",
 ]

@@ -7909,6 +7914,7 @@ dependencies = [
  "mistral",
  "ollama",
  "open_ai",
  "parking_lot",
+ "partial-json-fixer",
  "project",
  "proto",

@@ -7921,6 +7927,7 @@ dependencies = [
  "theme",
  "thiserror 2.0.12",
  "tiktoken-rs",
  "time",
+ "tokio",
  "ui",
  "util",
@@ -107,6 +107,10 @@ impl Model {
         }
     }

+    pub fn matches_id(&self, other_id: &str) -> bool {
+        self.id() == other_id
+    }
+
     /// The id of the model that should be used for making API requests
     pub fn request_id(&self) -> &str {
         match self {
@@ -18,8 +18,15 @@ schemars = ["dep:schemars"]
 anyhow.workspace = true
 futures.workspace = true
 http_client.workspace = true
 log.workspace = true
 schemars = { workspace = true, optional = true }
 serde.workspace = true
 serde_json.workspace = true
+strum.workspace = true
+time.workspace = true
 workspace-hack.workspace = true
+
+[dev-dependencies]
+clap = { version = "4.4", features = ["derive", "env"] }
+tokio = { version = "1.28", features = ["full"] }
+reqwest_client.workspace = true
crates/google_ai/examples/google_ai_cli.rs (new file, 482 lines)
@@ -0,0 +1,482 @@
//! CLI for interacting with Gemini, 99% generated by Zed + Claude 3.7 Sonnet. Feel free to delete
//! it rather than maintain, it took 10 minutes.

use anyhow::{Result, anyhow};
use clap::{Parser, Subcommand};
use futures::StreamExt;
use google_ai::{
    CacheName, Content, CountTokensRequest, CreateCacheRequest, CreateCacheResponse,
    GenerateContentRequest, GenerationConfig, ModelName, Part, Role, SystemInstruction, TextPart,
    UpdateCacheRequest, count_tokens, create_cache, stream_generate_content, update_cache,
};
use reqwest_client::ReqwestClient;
use std::io::Write;
use std::{fs, io, path::PathBuf, sync::Arc, time::Duration};

#[derive(Parser)]
#[command(name = "google_ai_cli")]
#[command(author = "Zed Team")]
#[command(version = "0.1.0")]
#[command(about = "Interface with the Google Generative AI API", long_about = None)]
struct Cli {
    /// Gemini API key
    #[arg(long, env = "GEMINI_API_KEY")]
    api_key: String,

    /// API URL (defaults to https://generativelanguage.googleapis.com)
    #[arg(long, global = true, default_value = google_ai::API_URL)]
    api_url: String,

    /// The model to use
    #[arg(long, global = true, default_value = "gemini-1.5-flash-002")]
    model: String,

    #[command(subcommand)]
    command: Commands,
}

#[derive(Subcommand)]
enum Commands {
    /// Generate content from a prompt
    Generate {
        /// The prompt text
        #[arg(long, conflicts_with = "prompt_file")]
        prompt: Option<String>,

        /// File containing the prompt text
        #[arg(long, conflicts_with = "prompt")]
        prompt_file: Option<PathBuf>,

        /// System instruction for the model
        #[arg(long)]
        system_instruction: Option<String>,

        /// Maximum number of tokens to generate
        #[arg(long)]
        max_tokens: Option<usize>,

        /// Temperature for generation (0.0 to 1.0)
        #[arg(long)]
        temperature: Option<f64>,

        /// Top-p sampling parameter (0.0 to 1.0)
        #[arg(long)]
        top_p: Option<f64>,

        /// Top-k sampling parameter
        #[arg(long)]
        top_k: Option<usize>,

        /// Use cached content by specifying the cache name
        #[arg(long)]
        cached_content: Option<String>,
    },

    /// Count tokens in a prompt
    CountTokens {
        /// The prompt text
        #[arg(long, conflicts_with = "prompt_file")]
        prompt: Option<String>,

        /// File containing the prompt text
        #[arg(long, conflicts_with = "prompt")]
        prompt_file: Option<PathBuf>,
    },

    /// Cache content for faster repeated access
    CreateCache {
        /// The prompt text
        #[arg(long, conflicts_with = "prompt_file")]
        prompt: Option<String>,

        /// File containing the prompt text
        #[arg(long, conflicts_with = "prompt")]
        prompt_file: Option<PathBuf>,

        /// System instruction for the model
        #[arg(long)]
        system_instruction: Option<String>,

        /// Time-to-live for the cache in seconds
        #[arg(long, default_value = "3600")]
        ttl: u64,
    },

    /// Update cache TTL
    UpdateCache {
        /// The cache name to update
        #[arg(long)]
        cache_id: String,

        /// New time-to-live for the cache in seconds
        #[arg(long)]
        ttl: u64,
    },

    /// Interactive conversation with the model
    Chat {
        /// System instruction for the model
        #[arg(long)]
        system_instruction: Option<String>,

        /// Maximum number of tokens to generate
        #[arg(long)]
        max_tokens: Option<usize>,

        /// Temperature for generation (0.0 to 1.0)
        #[arg(long)]
        temperature: Option<f64>,

        /// Load a JSON file containing chat history
        #[arg(long)]
        history_file: Option<PathBuf>,
    },
}

// Helper function to get prompt text from either prompt or prompt_file
fn get_prompt_text(prompt: &Option<String>, prompt_file: &Option<PathBuf>) -> Result<String> {
    match (prompt, prompt_file) {
        (Some(text), None) => Ok(text.clone()),
        (None, Some(file_path)) => {
            fs::read_to_string(file_path).map_err(|e| anyhow!("Failed to read prompt file: {}", e))
        }
        (None, None) => Err(anyhow!(
            "Either --prompt or --prompt-file must be specified"
        )),
        (Some(_), Some(_)) => Err(anyhow!("Cannot specify both --prompt and --prompt-file")),
    }
}

#[tokio::main]
async fn main() -> Result<()> {
    let cli = Cli::parse();
    let http_client = Arc::new(ReqwestClient::new());

    match &cli.command {
        Commands::Generate {
            prompt,
            prompt_file,
            system_instruction,
            max_tokens,
            temperature,
            top_p,
            top_k,
            cached_content,
        } => {
            let prompt_text = get_prompt_text(prompt, prompt_file)?;

            let user_content = Content {
                role: Role::User,
                parts: vec![Part::TextPart(TextPart { text: prompt_text })],
            };

            let request = GenerateContentRequest {
                model: ModelName {
                    model_id: cli.model.clone(),
                },
                contents: vec![user_content],
                system_instruction: system_instruction.as_ref().map(|instruction| {
                    SystemInstruction {
                        parts: vec![Part::TextPart(TextPart {
                            text: instruction.clone(),
                        })],
                    }
                }),
                generation_config: Some(GenerationConfig {
                    max_output_tokens: *max_tokens,
                    temperature: *temperature,
                    top_p: *top_p,
                    top_k: *top_k,
                    candidate_count: None,
                    stop_sequences: None,
                }),
                safety_settings: None,
                tools: None,
                tool_config: None,
                cached_content: cached_content.as_ref().map(|cache_id| {
                    if let Some(cache_id) = cache_id.strip_prefix("cachedContents/") {
                        CacheName {
                            cache_id: cache_id.to_string(),
                        }
                    } else {
                        CacheName {
                            cache_id: cache_id.clone(),
                        }
                    }
                }),
            };

            println!("Generating content with model: {}", cli.model);
            let mut stream =
                stream_generate_content(http_client.as_ref(), &cli.api_url, &cli.api_key, request)
                    .await?;

            println!("Response:");
            while let Some(response) = stream.next().await {
                match response {
                    Ok(resp) => {
                        if let Some(candidates) = &resp.candidates {
                            for candidate in candidates {
                                for part in &candidate.content.parts {
                                    match part {
                                        Part::TextPart(text_part) => {
                                            print!("{}", text_part.text);
                                        }
                                        _ => {
                                            println!("[Received non-text response part]");
                                        }
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => {
                        eprintln!("Error: {}", e);
                        break;
                    }
                }
            }
            println!();
        }

        Commands::CountTokens {
            prompt,
            prompt_file,
        } => {
            let prompt_text = get_prompt_text(prompt, prompt_file)?;

            let user_content = Content {
                role: Role::User,
                parts: vec![Part::TextPart(TextPart { text: prompt_text })],
            };

            let generate_content_request = GenerateContentRequest {
                model: ModelName {
                    model_id: cli.model.clone(),
                },
                contents: vec![user_content],
                system_instruction: None,
                generation_config: None,
                safety_settings: None,
                tools: None,
                tool_config: None,
                cached_content: None,
            };

            let request = CountTokensRequest {
                generate_content_request,
            };

            let response =
                count_tokens(http_client.as_ref(), &cli.api_url, &cli.api_key, request).await?;
            println!("Total tokens: {}", response.total_tokens);
        }

        Commands::CreateCache {
            prompt,
            prompt_file,
            system_instruction,
            ttl,
        } => {
            let prompt_text = get_prompt_text(prompt, prompt_file)?;

            let request = CreateCacheRequest {
                ttl: Duration::from_secs(*ttl),
                model: ModelName {
                    model_id: cli.model.clone(),
                },
                contents: vec![Content {
                    role: Role::User,
                    parts: vec![Part::TextPart(TextPart { text: prompt_text })],
                }],
                system_instruction: system_instruction.as_ref().map(|instruction| {
                    SystemInstruction {
                        parts: vec![Part::TextPart(TextPart {
                            text: instruction.clone(),
                        })],
                    }
                }),
                tools: None,
                tool_config: None,
            };

            let response =
                create_cache(http_client.as_ref(), &cli.api_url, &cli.api_key, request).await?;
            match response {
                CreateCacheResponse::Created(created_cache) => {
                    println!("Cache created:");
                    println!("  ID: {:?}", created_cache.name.cache_id);
                    println!("  Expires: {}", created_cache.expire_time);
                    if let Some(token_count) = created_cache.usage_metadata.total_token_count {
                        println!("  Total tokens: {}", token_count);
                    }
                }
                CreateCacheResponse::CachingNotSupportedByModel => {
                    println!("Cache not created, assuming this is due to the specified model ID.");
                }
            }
        }

        Commands::UpdateCache { cache_id, ttl } => {
            let cache_name = if let Some(cache_id) = cache_id.strip_prefix("cachedContents/") {
                CacheName {
                    cache_id: cache_id.to_string(),
                }
            } else {
                CacheName {
                    cache_id: cache_id.clone(),
                }
            };

            let request = UpdateCacheRequest {
                name: cache_name.clone(),
                ttl: Duration::from_secs(*ttl),
            };

            let response =
                update_cache(http_client.as_ref(), &cli.api_url, &cli.api_key, request).await?;
            println!("Cache updated:");
            println!("  ID: {}", cache_name.cache_id);
            println!("  New expiration: {}", response.expire_time);
        }

        Commands::Chat {
            system_instruction,
            max_tokens,
            temperature,
            history_file,
        } => {
            // Initialize a vector to store conversation history
            let mut history = Vec::new();

            // Load history if provided
            if let Some(file_path) = history_file {
                if file_path.exists() {
                    let file_content = fs::read_to_string(file_path)?;
                    history = serde_json::from_str(&file_content)?;
                }
            }

            // Add system instruction if present
            if let Some(instruction) = system_instruction {
                println!("System: {}", instruction);
                println!();
            }

            loop {
                // Get user input
                print!("You: ");
                io::stdout().flush()?;
                let mut input = String::new();
                io::stdin().read_line(&mut input)?;

                if input.trim().eq_ignore_ascii_case("exit")
                    || input.trim().eq_ignore_ascii_case("quit")
                {
                    break;
                }

                // Add user message to history
                let user_content = Content {
                    role: Role::User,
                    parts: vec![Part::TextPart(TextPart {
                        text: input.trim().to_string(),
                    })],
                };
                history.push(user_content);

                // Create request with history
                let request = GenerateContentRequest {
                    model: ModelName {
                        model_id: cli.model.clone(),
                    },
                    contents: history
                        .iter()
                        .map(|content| Content {
                            role: content.role,
                            parts: content
                                .parts
                                .iter()
                                .map(|part| match part {
                                    Part::TextPart(text_part) => Part::TextPart(TextPart {
                                        text: text_part.text.clone(),
                                    }),
                                    _ => panic!("Unsupported part type in history"),
                                })
                                .collect(),
                        })
                        .collect(),
                    system_instruction: system_instruction.as_ref().map(|instruction| {
                        SystemInstruction {
                            parts: vec![Part::TextPart(TextPart {
                                text: instruction.clone(),
                            })],
                        }
                    }),
                    generation_config: Some(GenerationConfig {
                        max_output_tokens: *max_tokens,
                        temperature: *temperature,
                        top_p: None,
                        top_k: None,
                        candidate_count: None,
                        stop_sequences: None,
                    }),
                    safety_settings: None,
                    tools: None,
                    tool_config: None,
                    cached_content: None,
                };

                // Get response
                print!("Assistant: ");
                let mut stream = stream_generate_content(
                    http_client.as_ref(),
                    &cli.api_url,
                    &cli.api_key,
                    request,
                )
                .await?;

                let mut model_response = String::new();
                while let Some(response) = stream.next().await {
                    match response {
                        Ok(resp) => {
                            if let Some(candidates) = &resp.candidates {
                                for candidate in candidates {
                                    for part in &candidate.content.parts {
                                        match part {
                                            Part::TextPart(text_part) => {
                                                print!("{}", text_part.text);
                                                model_response.push_str(&text_part.text);
                                            }
                                            _ => {
                                                println!("[Received non-text response part]");
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            eprintln!("\nError: {}", e);
                            break;
                        }
                    }
                }
                println!();
                println!();

                // Add model response to history
                let model_content = Content {
                    role: Role::Model,
                    parts: vec![Part::TextPart(TextPart {
                        text: model_response,
                    })],
                };
                history.push(model_content);
            }
        }
    }

    Ok(())
}
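Both the `Generate` and `UpdateCache` arms above normalize cache names by hand-stripping the `cachedContents/` prefix. A minimal sketch of how that duplication could be factored out; `parse_cache_name` is a hypothetical helper, not part of this PR:

```rust
use google_ai::CacheName;

// Accept either "cachedContents/<id>" (the full resource name) or a bare id.
fn parse_cache_name(raw: &str) -> CacheName {
    let cache_id = raw.strip_prefix("cachedContents/").unwrap_or(raw);
    CacheName {
        cache_id: cache_id.to_string(),
    }
}
```

With this, `parse_cache_name("cachedContents/abc123")` and `parse_cache_name("abc123")` both yield a `CacheName` whose `cache_id` is `"abc123"`.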
@@ -1,7 +1,11 @@
+use std::time::Duration;
+use std::{fmt::Display, mem};
+
 use anyhow::{Result, anyhow, bail};
 use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
 use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
-use serde::{Deserialize, Serialize};
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
+use time::OffsetDateTime;

 pub const API_URL: &str = "https://generativelanguage.googleapis.com";
@@ -11,25 +15,13 @@ pub async fn stream_generate_content(
     api_key: &str,
     mut request: GenerateContentRequest,
 ) -> Result<BoxStream<'static, Result<GenerateContentResponse>>> {
-    if request.contents.is_empty() {
-        bail!("Request must contain at least one content item");
-    }
+    validate_generate_content_request(&request)?;

-    if let Some(user_content) = request
-        .contents
-        .iter()
-        .find(|content| content.role == Role::User)
-    {
-        if user_content.parts.is_empty() {
-            bail!("User content must contain at least one part");
-        }
-    }
+    // The `model` field is emptied as it is provided as a path parameter.
+    let model_id = mem::take(&mut request.model.model_id);

-    let uri = format!(
-        "{api_url}/v1beta/models/{model}:streamGenerateContent?alt=sse&key={api_key}",
-        model = request.model
-    );
-    request.model.clear();
+    let uri =
+        format!("{api_url}/v1beta/models/{model_id}:streamGenerateContent?alt=sse&key={api_key}",);

     let request_builder = HttpRequest::builder()
         .method(Method::POST)
@@ -65,7 +57,7 @@ pub async fn stream_generate_content(
         let mut text = String::new();
         response.body_mut().read_to_string(&mut text).await?;
         Err(anyhow!(
-            "error during streamGenerateContent, status code: {:?}, body: {}",
+            "error during Gemini content generation, status code: {:?}, body: {}",
             response.status(),
             text
         ))
@@ -76,18 +68,22 @@ pub async fn count_tokens(
     client: &dyn HttpClient,
     api_url: &str,
     api_key: &str,
-    model_id: &str,
     request: CountTokensRequest,
 ) -> Result<CountTokensResponse> {
-    let uri = format!("{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",);
-    let request = serde_json::to_string(&request)?;
+    validate_generate_content_request(&request.generate_content_request)?;
+
+    let uri = format!(
+        "{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",
+        model_id = &request.generate_content_request.model.model_id,
+    );
+
+    let request = serde_json::to_string(&request)?;
     let request_builder = HttpRequest::builder()
         .method(Method::POST)
         .uri(&uri)
         .header("Content-Type", "application/json");

     let http_request = request_builder.body(AsyncBody::from(request))?;

     let mut response = client.send(http_request).await?;
     let mut text = String::new();
     response.body_mut().read_to_string(&mut text).await?;
@@ -95,13 +91,114 @@
         Ok(serde_json::from_str::<CountTokensResponse>(&text)?)
     } else {
         Err(anyhow!(
-            "error during countTokens, status code: {:?}, body: {}",
+            "error during Gemini token counting, status code: {:?}, body: {}",
             response.status(),
             text
         ))
     }
 }

+pub async fn create_cache(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    request: CreateCacheRequest,
+) -> Result<CreateCacheResponse> {
+    if let Some(user_content) = request
+        .contents
+        .iter()
+        .find(|content| content.role == Role::User)
+    {
+        if user_content.parts.is_empty() {
+            bail!("User content must contain at least one part");
+        }
+    }
+    let uri = format!("{api_url}/v1beta/cachedContents?key={api_key}");
+    let request_builder = HttpRequest::builder()
+        .method(Method::POST)
+        .uri(uri)
+        .header("Content-Type", "application/json");
+
+    let http_request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let mut response = client.send(http_request).await?;
+    let mut text = String::new();
+    response.body_mut().read_to_string(&mut text).await?;
+    if response.status().is_success() {
+        let created_cache = serde_json::from_str::<CreatedCache>(&text)?;
+        Ok(CreateCacheResponse::Created(created_cache))
+    } else if response.status().as_u16() == 404
+        && text.contains("or is not supported for createCachedContent")
+    {
+        log::info!(
+            "will no longer attempt Gemini cache creation for model `{}` due to {:?} response: {}",
+            request.model.model_id,
+            response.status(),
+            text
+        );
+        Ok(CreateCacheResponse::CachingNotSupportedByModel)
+    } else {
+        Err(anyhow!(
+            "unexpected {:?} response during Gemini cache creation: {}",
+            response.status(),
+            text
+        ))
+    }
+}
+
+pub async fn update_cache(
+    client: &dyn HttpClient,
+    api_url: &str,
+    api_key: &str,
+    mut request: UpdateCacheRequest,
+) -> Result<UpdateCacheResponse> {
+    // The `name` field is emptied as it is provided as a path parameter.
+    let name = mem::take(&mut request.name);
+    let uri = format!(
+        "{api_url}/v1beta/cachedContents/{cache_id}?key={api_key}",
+        cache_id = &name.cache_id
+    );
+    let request_builder = HttpRequest::builder()
+        .method(Method::PATCH)
+        .uri(uri)
+        .header("Content-Type", "application/json");
+
+    let http_request = request_builder.body(AsyncBody::from(serde_json::to_string(&request)?))?;
+    let mut response = client.send(http_request).await?;
+    let mut text = String::new();
+    response.body_mut().read_to_string(&mut text).await?;
+    if response.status().is_success() {
+        Ok(serde_json::from_str::<UpdateCacheResponse>(&text)?)
+    } else {
+        Err(anyhow!(
+            "error during Gemini cache update, status code: {:?}, body: {}",
+            response.status(),
+            text
+        ))
+    }
+}
+
+pub fn validate_generate_content_request(request: &GenerateContentRequest) -> Result<()> {
+    if request.model.is_empty() {
+        bail!("Model must be specified");
+    }
+
+    if request.contents.is_empty() {
+        bail!("Request must contain at least one content item");
+    }
+
+    if let Some(user_content) = request
+        .contents
+        .iter()
+        .find(|content| content.role == Role::User)
+    {
+        if user_content.parts.is_empty() {
+            bail!("User content must contain at least one part");
+        }
+    }
+
+    Ok(())
+}
+
 #[derive(Debug, Serialize, Deserialize)]
 pub enum Task {
     #[serde(rename = "generateContent")]
@@ -119,8 +216,8 @@ pub enum Task {
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct GenerateContentRequest {
-    #[serde(default, skip_serializing_if = "String::is_empty")]
-    pub model: String,
+    #[serde(default, skip_serializing_if = "ModelName::is_empty")]
+    pub model: ModelName,
     pub contents: Vec<Content>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub system_instruction: Option<SystemInstruction>,

@@ -132,6 +229,8 @@ pub struct GenerateContentRequest {
     pub tools: Option<Vec<Tool>>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub tool_config: Option<ToolConfig>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub cached_content: Option<CacheName>,
 }

 #[derive(Debug, Serialize, Deserialize)]
@@ -161,7 +260,7 @@ pub struct GenerateContentCandidate {
     pub citation_metadata: Option<CitationMetadata>,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct Content {
     #[serde(default)]

@@ -169,20 +268,20 @@ pub struct Content {
     pub role: Role,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct SystemInstruction {
     pub parts: Vec<Part>,
 }

-#[derive(Debug, PartialEq, Deserialize, Serialize)]
+#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Deserialize, Serialize)]
 #[serde(rename_all = "camelCase")]
 pub enum Role {
     User,
     Model,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(untagged)]
 pub enum Part {
     TextPart(TextPart),

@@ -191,32 +290,32 @@ pub enum Part {
     FunctionResponsePart(FunctionResponsePart),
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct TextPart {
     pub text: String,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct InlineDataPart {
     pub inline_data: GenerativeContentBlob,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct GenerativeContentBlob {
     pub mime_type: String,
     pub data: String,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct FunctionCallPart {
     pub function_call: FunctionCall,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct FunctionResponsePart {
     pub function_response: FunctionResponse,

@@ -251,7 +350,7 @@ pub struct PromptFeedback {
     pub block_reason_message: Option<String>,
 }

-#[derive(Debug, Serialize, Deserialize, Default)]
+#[derive(Debug, Clone, Serialize, Deserialize, Default)]
 #[serde(rename_all = "camelCase")]
 pub struct UsageMetadata {
     #[serde(skip_serializing_if = "Option::is_none")]

@@ -350,7 +449,7 @@ pub struct SafetyRating {
 #[derive(Debug, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct CountTokensRequest {
-    pub contents: Vec<Content>,
+    pub generate_content_request: GenerateContentRequest,
 }

 #[derive(Debug, Serialize, Deserialize)]

@@ -359,31 +458,31 @@ pub struct CountTokensResponse {
     pub total_tokens: usize,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 pub struct FunctionCall {
     pub name: String,
     pub args: serde_json::Value,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 pub struct FunctionResponse {
     pub name: String,
     pub response: serde_json::Value,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct Tool {
     pub function_declarations: Vec<FunctionDeclaration>,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct ToolConfig {
     pub function_calling_config: FunctionCallingConfig,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "camelCase")]
 pub struct FunctionCallingConfig {
     pub mode: FunctionCallingMode,

@@ -391,7 +490,7 @@ pub struct FunctionCallingConfig {
     pub allowed_function_names: Option<Vec<String>>,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 #[serde(rename_all = "lowercase")]
 pub enum FunctionCallingMode {
     Auto,
@@ -399,27 +498,260 @@ pub enum FunctionCallingMode {
     None,
 }

-#[derive(Debug, Serialize, Deserialize)]
+#[derive(Clone, Debug, Hash, PartialEq, Eq, Serialize, Deserialize)]
 pub struct FunctionDeclaration {
     pub name: String,
     pub description: String,
     pub parameters: serde_json::Value,
 }

+#[derive(Debug, Serialize, Deserialize, Hash, PartialEq, Eq)]
+#[serde(rename_all = "camelCase")]
+pub struct CreateCacheRequest {
+    #[serde(
+        serialize_with = "serialize_duration",
+        deserialize_with = "deserialize_duration"
+    )]
+    pub ttl: Duration,
+    pub model: ModelName,
+    pub contents: Vec<Content>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub system_instruction: Option<SystemInstruction>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<Tool>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_config: Option<ToolConfig>,
+    // Other fields that could be provided:
+    //
+    // name: The resource name referring to the cached content. Format: cachedContents/{id}
+    // display_name: user-generated meaningful display name of the cached content. Maximum 128 Unicode characters.
+}
+
+pub enum CreateCacheResponse {
+    Created(CreatedCache),
+    CachingNotSupportedByModel,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CreatedCache {
+    pub name: CacheName,
+    #[serde(
+        serialize_with = "time::serde::rfc3339::serialize",
+        deserialize_with = "time::serde::rfc3339::deserialize"
+    )]
+    pub expire_time: OffsetDateTime,
+    pub usage_metadata: UsageMetadata,
+    // Other fields that could be provided:
+    //
+    // create_time: Creation time of the cache entry.
+    // update_time: When the cache entry was last updated in UTC time.
+    // usage_metadata: Metadata on the usage of the cached content.
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UpdateCacheRequest {
+    pub name: CacheName,
+    #[serde(
+        serialize_with = "serialize_duration",
+        deserialize_with = "deserialize_duration"
+    )]
+    pub ttl: Duration,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UpdateCacheResponse {
+    #[serde(
+        serialize_with = "time::serde::rfc3339::serialize",
+        deserialize_with = "time::serde::rfc3339::deserialize"
+    )]
+    pub expire_time: OffsetDateTime,
+}
+
+const MODEL_NAME_PREFIX: &str = "models/";
+const CACHE_NAME_PREFIX: &str = "cachedContents/";
+
+#[derive(Clone, Debug, Default, Hash, PartialEq, Eq)]
+pub struct ModelName {
+    pub model_id: String,
+}
+
+#[derive(Clone, Debug, Default)]
+pub struct CacheName {
+    pub cache_id: String,
+}
+
+impl ModelName {
+    pub fn is_empty(&self) -> bool {
+        self.model_id.is_empty()
+    }
+}
+
+impl CacheName {
+    pub fn is_empty(&self) -> bool {
+        self.cache_id.is_empty()
+    }
+}
+
+impl Serialize for ModelName {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_str(&format!("{MODEL_NAME_PREFIX}{}", &self.model_id))
+    }
+}
+
+impl<'de> Deserialize<'de> for ModelName {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let string = String::deserialize(deserializer)?;
+        if let Some(id) = string.strip_prefix(MODEL_NAME_PREFIX) {
+            Ok(Self {
+                model_id: id.to_string(),
+            })
+        } else {
+            return Err(serde::de::Error::custom(format!(
+                "Expected model name to begin with {}, got: {}",
+                MODEL_NAME_PREFIX, string
+            )));
+        }
+    }
+}
+
+impl Display for ModelName {
+    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(formatter, "{}", self.model_id)
+    }
+}
+
+impl Serialize for CacheName {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_str(&format!("{CACHE_NAME_PREFIX}{}", &self.cache_id))
+    }
+}
+
+impl<'de> Deserialize<'de> for CacheName {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let string = String::deserialize(deserializer)?;
+        if let Some(id) = string.strip_prefix(CACHE_NAME_PREFIX) {
+            Ok(CacheName {
+                cache_id: id.to_string(),
+            })
+        } else {
+            return Err(serde::de::Error::custom(format!(
+                "Expected cache name to begin with {}, got: {}",
+                CACHE_NAME_PREFIX, string
+            )));
+        }
+    }
+}
+
+impl Display for CacheName {
+    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(formatter, "{}", self.cache_id)
+    }
+}
+
+/// Serializes a Duration as a string in the format "X.Ys" where X is the whole seconds
+/// and Y is up to 9 decimal places of fractional seconds.
+pub fn serialize_duration<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
+where
+    S: Serializer,
+{
+    let secs = duration.as_secs();
+    let nanos = duration.subsec_nanos();
+
+    let formatted = if nanos == 0 {
+        format!("{}s", secs)
+    } else {
+        // Remove trailing zeros from nanos
+        let mut nanos_str = format!("{:09}", nanos);
+        while nanos_str.ends_with('0') && nanos_str.len() > 1 {
+            nanos_str.pop();
+        }
+        format!("{}.{}s", secs, nanos_str)
+    };
+
+    serializer.serialize_str(&formatted)
+}
+
+pub fn deserialize_duration<'de, D>(deserializer: D) -> Result<Duration, D::Error>
+where
+    D: Deserializer<'de>,
+{
+    let duration_str = String::deserialize(deserializer)?;
+
+    let Some(num_part) = duration_str.strip_suffix('s') else {
+        return Err(serde::de::Error::custom(format!(
+            "Duration must end with 's', got: {}",
+            duration_str
+        )));
+    };
+
+    if let Some(decimal_ix) = num_part.find('.') {
+        let secs_part = &num_part[0..decimal_ix];
+        let frac_len = (num_part.len() - (decimal_ix + 1)).min(9);
+        let frac_start_ix = decimal_ix + 1;
+        let frac_end_ix = frac_start_ix + frac_len;
+        let frac_part = &num_part[frac_start_ix..frac_end_ix];
+
+        let secs = u64::from_str_radix(secs_part, 10).map_err(|e| {
+            serde::de::Error::custom(format!(
+                "Invalid seconds in duration: {}. Error: {}",
+                duration_str, e
+            ))
+        })?;
+
+        let frac_number = frac_part.parse::<u32>().map_err(|e| {
+            serde::de::Error::custom(format!(
+                "Invalid fractional seconds in duration: {}. Error: {}",
+                duration_str, e
+            ))
+        })?;
+
+        let nanos = frac_number * 10u32.pow(9 - frac_len as u32);
+
+        Ok(Duration::new(secs, nanos))
+    } else {
+        let secs = u64::from_str_radix(num_part, 10).map_err(|e| {
+            serde::de::Error::custom(format!(
+                "Invalid duration format: {}. Error: {}",
+                duration_str, e
+            ))
+        })?;
+
+        Ok(Duration::new(secs, 0))
+    }
+}
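The serializers above define the wire formats Gemini's REST API uses for resource names: `models/<id>` and `cachedContents/<id>`, with the prefixes stripped again on deserialization. A small sketch (not part of the PR) of the round trip, assuming the types above are in scope:

```rust
// Sketch: what ModelName/CacheName put on the wire and read back off it.
fn resource_name_demo() {
    let model = ModelName {
        model_id: "gemini-1.5-flash-002".to_string(),
    };
    // Serialized with the "models/" prefix the REST API expects.
    assert_eq!(
        serde_json::to_string(&model).unwrap(),
        r#""models/gemini-1.5-flash-002""#
    );

    // Deserialization strips the "cachedContents/" prefix back off.
    let cache: CacheName = serde_json::from_str(r#""cachedContents/abc123""#).unwrap();
    assert_eq!(cache.cache_id, "abc123");
}
```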
 #[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
 #[derive(Clone, Default, Debug, Deserialize, Serialize, PartialEq, Eq, strum::EnumIter)]
 pub enum Model {
-    #[serde(rename = "gemini-1.5-pro")]
+    #[serde(rename = "gemini-1.5-pro-002")]
     Gemini15Pro,
-    #[serde(rename = "gemini-1.5-flash")]
+    #[serde(rename = "gemini-1.5-flash-002")]
     Gemini15Flash,
+    /// Note: replaced by `gemini-2.5-pro-exp-03-25` (continues to work in API).
     #[serde(rename = "gemini-2.0-pro-exp")]
     Gemini20Pro,
     #[serde(rename = "gemini-2.0-flash")]
     #[default]
     Gemini20Flash,
+    /// Note: replaced by `gemini-2.5-flash-preview-04-17` (continues to work in API).
     #[serde(rename = "gemini-2.0-flash-thinking-exp")]
     Gemini20FlashThinking,
+    /// Note: replaced by `gemini-2.0-flash-lite` (continues to work in API).
     #[serde(rename = "gemini-2.0-flash-lite-preview")]
     Gemini20FlashLite,
     #[serde(rename = "gemini-2.5-pro-exp-03-25")]
@@ -444,8 +776,8 @@ impl Model {

     pub fn id(&self) -> &str {
         match self {
-            Model::Gemini15Pro => "gemini-1.5-pro",
-            Model::Gemini15Flash => "gemini-1.5-flash",
+            Model::Gemini15Pro => "gemini-1.5-pro-002",
+            Model::Gemini15Flash => "gemini-1.5-flash-002",
             Model::Gemini20Pro => "gemini-2.0-pro-exp",
             Model::Gemini20Flash => "gemini-2.0-flash",
             Model::Gemini20FlashThinking => "gemini-2.0-flash-thinking-exp",

@@ -457,6 +789,19 @@ impl Model {
         }
     }

+    pub fn matches_id(&self, other_id: &str) -> bool {
+        if self.id() == other_id {
+            return true;
+        }
+        // These IDs are present in user settings. The `-002` stable model version is added in the
+        // ID used for the model so that caching works.
+        match self {
+            Model::Gemini15Pro => other_id == "gemini-1.5-pro",
+            Model::Gemini15Flash => other_id == "gemini-1.5-flash",
+            _ => false,
+        }
+    }
+
     pub fn display_name(&self) -> &str {
         match self {
             Model::Gemini15Pro => "Gemini 1.5 Pro",

@@ -497,3 +842,54 @@ impl std::fmt::Display for Model {
         write!(f, "{}", self.id())
     }
 }
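`matches_id` lets a model whose canonical ID was pinned to a `-002` stable version (so that caching works) keep matching the unsuffixed IDs still present in user settings. A sketch of the intended behavior, not part of the PR:

```rust
// Sketch: the canonical ID carries the stable suffix, but old settings
// values continue to resolve to the same model.
fn matches_id_demo() {
    let model = Model::Gemini15Pro;
    assert_eq!(model.id(), "gemini-1.5-pro-002");
    // The old ID from user settings still matches...
    assert!(model.matches_id("gemini-1.5-pro"));
    assert!(model.matches_id("gemini-1.5-pro-002"));
    // ...but unrelated IDs do not.
    assert!(!model.matches_id("gemini-2.0-flash"));
}
```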
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_duration_serialization() {
+        #[derive(Debug, Deserialize, Serialize, PartialEq, Eq)]
+        struct Example {
+            #[serde(
+                serialize_with = "serialize_duration",
+                deserialize_with = "deserialize_duration"
+            )]
+            duration: Duration,
+        }
+
+        let example = Example {
+            duration: Duration::from_secs(5),
+        };
+        let serialized = serde_json::to_string(&example).unwrap();
+        let deserialized: Example = serde_json::from_str(&serialized).unwrap();
+        assert_eq!(serialized, r#"{"duration":"5s"}"#);
+        assert_eq!(deserialized, example);
+
+        let example = Example {
+            duration: Duration::from_millis(5534),
+        };
+        let serialized = serde_json::to_string(&example).unwrap();
+        let deserialized: Example = serde_json::from_str(&serialized).unwrap();
+        assert_eq!(serialized, r#"{"duration":"5.534s"}"#);
+        assert_eq!(deserialized, example);
+
+        let example = Example {
+            duration: Duration::from_nanos(12345678900),
+        };
+        let serialized = serde_json::to_string(&example).unwrap();
+        let deserialized: Example = serde_json::from_str(&serialized).unwrap();
+        assert_eq!(serialized, r#"{"duration":"12.3456789s"}"#);
+        assert_eq!(deserialized, example);
+
+        // Deserializer doesn't panic for too many fractional digits
+        let deserialized: Example =
+            serde_json::from_str(r#"{"duration":"5.12345678905s"}"#).unwrap();
+        assert_eq!(
+            deserialized,
+            Example {
+                duration: Duration::from_nanos(5123456789)
+            }
+        );
+    }
+}
@@ -228,12 +228,27 @@ impl Default for LanguageModelTextStream {
 }

 pub trait LanguageModel: Send + Sync {
+    /// The ID used for the model when making requests. If checking for match of a user-provided
+    /// string, use `matches_id` instead.
     fn id(&self) -> LanguageModelId;
     fn name(&self) -> LanguageModelName;
     fn provider_id(&self) -> LanguageModelProviderId;
     fn provider_name(&self) -> LanguageModelProviderName;
     fn telemetry_id(&self) -> String;

+    /// Use this to check whether this model should be used for the given user-provided model ID, to support multiple IDs.
+    /// Multiple IDs are used for:
+    ///
+    /// * Preferring a particular ID in requests (Gemini only supports caching for stable IDs).
+    ///
+    /// * When one model has been replaced by another. This allows the display name to update to
+    ///   reflect which model is actually being used.
+    ///
+    /// TODO: In the future consider replacing this mechanism with a settings migration.
+    fn matches_id(&self, other_id: &LanguageModelId) -> bool {
+        &self.id() == other_id
+    }
+
     fn api_key(&self, _cx: &App) -> Option<String> {
         None
     }
@@ -197,7 +197,7 @@ impl LanguageModelRegistry {
         let model = provider
             .provided_models(cx)
             .iter()
-            .find(|model| model.id() == selected_model.model)?
+            .find(|model| model.matches_id(&selected_model.model))?
             .clone();
         Some(ConfiguredModel { provider, model })
     }
@@ -20,8 +20,8 @@ aws_http_client.workspace = true
 bedrock.workspace = true
 client.workspace = true
 collections.workspace = true
-credentials_provider.workspace = true
 copilot = { workspace = true, features = ["schemars"] }
+credentials_provider.workspace = true
 deepseek = { workspace = true, features = ["schemars"] }
 editor.workspace = true
 feature_flags.workspace = true

@@ -38,6 +38,7 @@ menu.workspace = true
 mistral = { workspace = true, features = ["schemars"] }
 ollama = { workspace = true, features = ["schemars"] }
 open_ai = { workspace = true, features = ["schemars"] }
 parking_lot.workspace = true
+partial-json-fixer.workspace = true
 project.workspace = true
 proto.workspace = true

@@ -50,6 +51,7 @@ strum.workspace = true
 theme.workspace = true
 thiserror.workspace = true
 tiktoken-rs.workspace = true
 time.workspace = true
+tokio = { workspace = true, features = ["rt", "rt-multi-thread"] }
 ui.workspace = true
 util.workspace = true
@@ -461,6 +461,7 @@ impl LanguageModel for AnthropicModel {
             self.model.max_output_tokens(),
             self.model.mode(),
         );

         let request = self.stream_completion(request, cx);
         let future = self.request_limiter.stream(async move {
             let response = request
@@ -656,6 +656,14 @@ impl LanguageModel for CloudLanguageModel {
         self.id.clone()
     }

+    fn matches_id(&self, other_id: &LanguageModelId) -> bool {
+        match &self.model {
+            CloudModel::Anthropic(model) => model.matches_id(&other_id.0),
+            CloudModel::Google(model) => model.matches_id(&other_id.0),
+            CloudModel::OpenAi(model) => model.matches_id(&other_id.0),
+        }
+    }
+
     fn name(&self) -> LanguageModelName {
         LanguageModelName::from(self.model.display_name().to_string())
     }

@@ -718,7 +726,8 @@
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let llm_api_token = self.llm_api_token.clone();
-                let request = into_google(request, model.id().into());
+                let model_id = model.id().to_string();
+                let generate_content_request = into_google(request, model_id.clone());
                 async move {
                     let http_client = &client.http_client();
                     let token = llm_api_token.acquire(&client).await?;

@@ -736,9 +745,9 @@ impl LanguageModel for CloudLanguageModel {
                     };
                     let request_body = CountTokensBody {
                         provider: zed_llm_client::LanguageModelProvider::Google,
-                        model: model.id().into(),
+                        model: model_id,
                         provider_request: serde_json::to_value(&google_ai::CountTokensRequest {
-                            contents: request.contents,
+                            generate_content_request,
                         })?,
                     };
                     let request = request_builder

@@ -895,7 +904,7 @@ impl LanguageModel for CloudLanguageModel {
                         prompt_id,
                         mode,
                         provider: zed_llm_client::LanguageModelProvider::Google,
-                        model: request.model.clone(),
+                        model: request.model.model_id.clone(),
                         provider_request: serde_json::to_value(&request)?,
                     },
                 )
@@ -1,13 +1,17 @@
 use anyhow::{Context as _, Result, anyhow};
-use collections::BTreeMap;
+use collections::{BTreeMap, FxHasher, HashMap, HashSet};
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::future::{self, Shared};
 use futures::{FutureExt, Stream, StreamExt, future::BoxFuture};
 use google_ai::{
-    FunctionDeclaration, GenerateContentResponse, Part, SystemInstruction, UsageMetadata,
+    CacheName, Content, CreateCacheRequest, CreateCacheResponse, FunctionDeclaration,
+    GenerateContentRequest, GenerateContentResponse, ModelName, Part, SystemInstruction,
+    UpdateCacheRequest, UpdateCacheResponse, UsageMetadata,
 };
 use gpui::{
-    AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
+    AnyView, App, AppContext, AsyncApp, BackgroundExecutor, Context, Entity, FontStyle,
+    Subscription, Task, TextStyle, WhiteSpace,
 };
 use http_client::HttpClient;
 use language_model::{

@@ -20,16 +24,20 @@ use language_model::{
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
     LanguageModelRequest, RateLimiter, Role,
 };
+use parking_lot::Mutex;
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::hash::{Hash as _, Hasher as _};
 use std::pin::Pin;
 use std::sync::{
     Arc,
     atomic::{self, AtomicU64},
 };
+use std::time::Duration;
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
+use time::UtcDateTime;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
 use util::ResultExt;

@@ -38,6 +46,10 @@ use crate::ui::InstructionListItem;

 const PROVIDER_ID: &str = "google";
 const PROVIDER_NAME: &str = "Google AI";
+const CACHE_TTL: Duration = Duration::from_secs(60 * 5);
+
+/// Minimum amount of time left before a cache expires for it to be used.
+const CACHE_TTL_MINIMUM_FOR_USAGE: Duration = Duration::from_secs(10);

 #[derive(Default, Clone, Debug, PartialEq)]
 pub struct GoogleSettings {
@@ -50,6 +62,7 @@ pub struct AvailableModel {
     name: String,
     display_name: Option<String>,
     max_tokens: usize,
+    caching: bool,
 }

 pub struct GoogleLanguageModelProvider {

@@ -160,6 +173,7 @@ impl GoogleLanguageModelProvider {
             id: LanguageModelId::from(model.id().to_string()),
             model,
             state: self.state.clone(),
+            cache: Mutex::new(Cache::default()).into(),
             http_client: self.http_client.clone(),
             request_limiter: RateLimiter::new(4),
         })

@@ -227,6 +241,7 @@ impl LanguageModelProvider for GoogleLanguageModelProvider {
             id: LanguageModelId::from(model.id().to_string()),
             model,
             state: self.state.clone(),
+            cache: Mutex::new(Cache::default()).into(),
             http_client: self.http_client.clone(),
             request_limiter: RateLimiter::new(4),
         }) as Arc<dyn LanguageModel>

@@ -256,6 +271,7 @@ pub struct GoogleLanguageModel {
     id: LanguageModelId,
     model: google_ai::Model,
     state: gpui::Entity<State>,
+    cache: Arc<Mutex<Cache>>,
     http_client: Arc<dyn HttpClient>,
     request_limiter: RateLimiter,
 }
@@ -263,32 +279,77 @@ pub struct GoogleLanguageModel {
 impl GoogleLanguageModel {
     fn stream_completion(
         &self,
-        request: google_ai::GenerateContentRequest,
+        request: impl 'static + Send + Future<Output = google_ai::GenerateContentRequest>,
         cx: &AsyncApp,
     ) -> BoxFuture<
         'static,
         Result<futures::stream::BoxStream<'static, Result<GenerateContentResponse>>>,
     > {
-        let http_client = self.http_client.clone();
+        match self.api_url_and_key(cx) {
+            Ok((api_url, api_key)) => {
+                let http_client = self.http_client.clone();
+                async move {
+                    let request = google_ai::stream_generate_content(
+                        http_client.as_ref(),
+                        &api_url,
+                        &api_key,
+                        request.await,
+                    );
+                    request.await.context("failed to stream completion")
+                }
+                .boxed()
+            }
+            Err(err) => future::ready(Err(err)).boxed(),
+        }
+    }

-        let Ok((api_key, api_url)) = cx.read_entity(&self.state, |state, cx| {
+    fn create_cache(
+        &self,
+        request: google_ai::CreateCacheRequest,
+        cx: &AsyncApp,
+    ) -> BoxFuture<'static, Result<CreateCacheResponse>> {
+        match self.api_url_and_key(cx) {
+            Ok((api_url, api_key)) => {
+                let http_client = self.http_client.clone();
+                async move {
+                    google_ai::create_cache(http_client.as_ref(), &api_url, &api_key, request).await
+                }
+                .boxed()
+            }
+            Err(err) => future::ready(Err(err)).boxed(),
+        }
+    }
+
+    fn update_cache(
+        &self,
+        request: UpdateCacheRequest,
+        cx: &AsyncApp,
+    ) -> BoxFuture<'static, Result<UpdateCacheResponse>> {
+        match self.api_url_and_key(cx) {
+            Ok((api_url, api_key)) => {
+                let http_client = self.http_client.clone();
+                async move {
+                    google_ai::update_cache(http_client.as_ref(), &api_url, &api_key, request).await
+                }
+                .boxed()
+            }
+            Err(err) => future::ready(Err(err)).boxed(),
+        }
+    }
+
+    fn api_url_and_key(&self, cx: &AsyncApp) -> Result<(String, String)> {
+        let Ok((api_url, api_key)) = self.state.read_with(cx, |state, cx| {
             let settings = &AllLanguageModelSettings::get_global(cx).google;
-            (state.api_key.clone(), settings.api_url.clone())
+            (settings.api_url.clone(), state.api_key.clone())
         }) else {
-            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+            return Err(anyhow!("App state dropped"));
         };

-        async move {
-            let api_key = api_key.ok_or_else(|| anyhow!("Missing Google API key"))?;
-            let request = google_ai::stream_generate_content(
-                http_client.as_ref(),
-                &api_url,
-                &api_key,
-                request,
-            );
-            request.await.context("failed to stream completion")
-        }
-        .boxed()
+        let Some(api_key) = api_key else {
+            return Err(anyhow!("Missing Google API key"));
+        };
+
+        Ok((api_url, api_key))
     }
 }
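A design note on the hunk above: `stream_completion` now takes a future that yields the request, rather than a ready-made request, so the cache lookup/creation step can complete and attach a `cached_content` name before the request is sent. A standalone sketch of that pattern with plain futures; all names here are hypothetical stand-ins, not PR code:

```rust
use std::future::Future;

struct Request {
    cached_content: Option<String>,
}

// Stand-in for use_cache_and_create_cache: do async preparation work first.
async fn prepare(mut request: Request) -> Request {
    request.cached_content = Some("cachedContents/abc123".to_string());
    request
}

// Stand-in for stream_completion: await the prepared request, then "send" it.
async fn send(request: impl Future<Output = Request>) -> Option<String> {
    let request = request.await;
    request.cached_content
}

fn main() {
    let result =
        futures::executor::block_on(send(prepare(Request { cached_content: None })));
    assert_eq!(result.as_deref(), Some("cachedContents/abc123"));
}
```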
@@ -297,6 +358,10 @@ impl LanguageModel for GoogleLanguageModel {
         self.id.clone()
     }

+    fn matches_id(&self, other_id: &LanguageModelId) -> bool {
+        self.model.matches_id(&other_id.0)
+    }
+
     fn name(&self) -> LanguageModelName {
         LanguageModelName::from(self.model.display_name().to_string())
     }

@@ -344,9 +409,8 @@ impl LanguageModel for GoogleLanguageModel {
                 http_client.as_ref(),
                 &api_url,
                 &api_key,
-                &model_id,
                 google_ai::CountTokensRequest {
-                    contents: request.contents,
+                    generate_content_request: request,
                 },
             )
             .await?;

@@ -368,10 +432,19 @@ impl LanguageModel for GoogleLanguageModel {
             >,
         >,
     > {
+        let is_last_message_cached = request
+            .messages
+            .last()
+            .map_or(false, |content| content.cache);
+
         let request = into_google(request, self.model.id().to_string());
-        let request = self.stream_completion(request, cx);
+
+        let request_task = self.use_cache_and_create_cache(is_last_message_cached, request, cx);
+
+        let stream_request = self.stream_completion(request_task, cx);
+
         let future = self.request_limiter.stream(async move {
-            let response = request
+            let response = stream_request
                 .await
                 .map_err(|err| LanguageModelCompletionError::Other(anyhow!(err)))?;
             Ok(GoogleEventMapper::new().map_stream(response))
@@ -382,7 +455,7 @@ impl LanguageModel for GoogleLanguageModel {

 pub fn into_google(
     mut request: LanguageModelRequest,
-    model: String,
+    model_id: String,
 ) -> google_ai::GenerateContentRequest {
     fn map_content(content: Vec<MessageContent>) -> Vec<Part> {
         content

@@ -442,7 +515,7 @@ pub fn into_google(
     };

     google_ai::GenerateContentRequest {
-        model,
+        model: google_ai::ModelName { model_id },
         system_instruction: system_instructions,
         contents: request
             .messages

@@ -486,6 +559,7 @@ pub fn into_google(
             }]
         }),
         tool_config: None,
+        cached_content: None,
     }
 }
@@ -642,6 +716,287 @@ fn convert_usage(usage: &UsageMetadata) -> language_model::TokenUsage {
     }
 }

+#[derive(Debug, Default)]
+pub struct Cache {
+    task_map: HashMap<CacheKey, Shared<Task<Option<CacheEntry>>>>,
+    models_not_supported: HashSet<ModelName>,
+}
+
+#[derive(Debug, Clone)]
+struct CacheEntry {
+    name: CacheName,
+    expire_time: UtcDateTime,
+    _drop_on_expire: Shared<Task<()>>,
+}
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
+struct CacheKey(u64);
+
+#[derive(Debug)]
+enum CacheStatus {
+    Absent,
+    Creating,
+    Failed,
+    Valid(CacheEntry),
+    Expired,
+}
+
+impl CacheEntry {
+    fn new(
+        name: CacheName,
+        expire_time: UtcDateTime,
+        cache_key: CacheKey,
+        cache: Arc<Mutex<Cache>>,
+        executor: BackgroundExecutor,
+    ) -> Option<CacheEntry> {
+        let now = UtcDateTime::now();
+        let Some(remaining_time): Option<std::time::Duration> = (expire_time - now).try_into().ok()
+        else {
+            cache.lock().task_map.remove(&cache_key);
+            return None;
+        };
+        let cache_expired = executor.timer(remaining_time);
+        let drop_on_expire = executor
+            .spawn({
+                let name = name.clone();
+                async move {
+                    cache_expired.await;
+                    cache.lock().task_map.remove(&cache_key);
+                    log::info!("removed cache `{name}` because it expired");
+                }
+            })
+            .shared();
+
+        Some(CacheEntry {
+            name,
+            expire_time,
+            _drop_on_expire: drop_on_expire,
+        })
+    }
+
+    fn is_expired(&self, now: UtcDateTime) -> bool {
+        self.expire_time - CACHE_TTL_MINIMUM_FOR_USAGE > now
+    }
+}
+
+impl CacheKey {
+    fn base(request: &GenerateContentRequest) -> Self {
+        #[derive(Hash, PartialEq, Eq)]
+        pub struct CacheBaseRef<'a> {
+            pub model: &'a ModelName,
+            pub system_instruction: &'a Option<SystemInstruction>,
+            pub tools: &'a Option<Vec<google_ai::Tool>>,
+            pub tool_config: &'a Option<google_ai::ToolConfig>,
+        }
+
+        let cache_base = CacheBaseRef {
+            model: &request.model,
+            system_instruction: &request.system_instruction,
+            tools: &request.tools,
+            tool_config: &request.tool_config,
+        };
+
+        let mut hasher = FxHasher::default();
+        cache_base.hash(&mut hasher);
+        Self(hasher.finish())
+    }
+
+    fn for_message(predecessor: Self, message: &Content) -> Self {
+        let mut hasher = FxHasher::default();
+        predecessor.0.hash(&mut hasher);
+        message.hash(&mut hasher);
+        Self(hasher.finish())
+    }
+}
|
||||
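The two constructors form a hash chain: `base` folds everything outside the message list into a seed, and `for_message` folds each message into the hash of its predecessor. Two requests that agree on the base and on a prefix of messages therefore produce identical keys for that whole prefix, which is what makes a cached prefix discoverable later. A self-contained illustration with std's hasher standing in for `FxHasher`:

    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    fn chain_keys(base: u64, messages: &[&str]) -> Vec<u64> {
        let mut prev = base;
        messages
            .iter()
            .map(|message| {
                let mut hasher = DefaultHasher::new();
                prev.hash(&mut hasher); // fold in the predecessor key
                message.hash(&mut hasher); // then this message
                prev = hasher.finish();
                prev
            })
            .collect()
    }

    fn main() {
        let a = chain_keys(7, &["system", "hello", "world"]);
        let b = chain_keys(7, &["system", "hello", "goodbye"]);
        assert_eq!(a[..2], b[..2]); // shared prefix, shared keys
        assert_ne!(a[2], b[2]); // divergence changes every key from there on
    }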
+impl Cache {
+    fn get_unexpired(
+        &self,
+        key: &CacheKey,
+        now: UtcDateTime,
+    ) -> Option<Shared<Task<Option<CacheEntry>>>> {
+        let cache_task = self.task_map.get(key)?;
+        match cache_task.clone().now_or_never() {
+            Some(Some(existing_cache)) => {
+                if existing_cache.is_expired(now) {
+                    None
+                } else {
+                    Some(cache_task.clone())
+                }
+            }
+            Some(None) => {
+                // Cache creation failed
+                None
+            }
+            None => {
+                // Caching task pending, so use it.
+                Some(cache_task.clone())
+            }
+        }
+    }
+
+    fn get_status(&self, key: &CacheKey) -> CacheStatus {
+        if let Some(existing_task) = self.task_map.get(key) {
+            match existing_task.clone().now_or_never() {
+                Some(Some(existing_cache)) => {
+                    let now = UtcDateTime::now();
+                    if existing_cache.is_expired(now) {
+                        CacheStatus::Expired
+                    } else {
+                        CacheStatus::Valid(existing_cache)
+                    }
+                }
+                Some(None) => CacheStatus::Failed,
+                None => CacheStatus::Creating,
+            }
+        } else {
+            CacheStatus::Absent
+        }
+    }
+}
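Both lookups rely on `FutureExt::now_or_never`, which polls a future exactly once and yields `Some` only if it is already complete, so checking the map never blocks on an in-flight creation. The behavior in isolation:

    use futures::FutureExt;

    fn main() {
        // Already-ready future: resolves immediately.
        assert_eq!(async { 42 }.now_or_never(), Some(42));
        // Never-ready future: polled once, comes back pending.
        assert_eq!(futures::future::pending::<i32>().now_or_never(), None);
    }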
+impl GoogleLanguageModel {
+    fn use_cache_and_create_cache(
+        &self,
+        should_create_cache: bool,
+        mut request: google_ai::GenerateContentRequest,
+        cx: &AsyncApp,
+    ) -> Task<google_ai::GenerateContentRequest> {
+        if self
+            .cache
+            .lock()
+            .models_not_supported
+            .contains(&request.model)
+        {
+            return Task::ready(request);
+        }
+
+        let base_cache_key = CacheKey::base(&request);
+        let mut prev_cache_key = base_cache_key;
+        let content_cache_keys = request
+            .contents
+            .iter()
+            .map(|content| {
+                let key = CacheKey::for_message(prev_cache_key, content);
+                prev_cache_key = key;
+                key
+            })
+            .collect::<Vec<_>>();
+
+        if let (true, Some(cache_key)) = (should_create_cache, content_cache_keys.last().copied()) {
+            let cache_status = self.cache.lock().get_status(&cache_key);
+            match cache_status {
+                CacheStatus::Creating => {}
+                CacheStatus::Valid(existing_cache) => {
+                    let name = existing_cache.name;
+                    let update_cache_request = UpdateCacheRequest {
+                        name: name.clone(),
+                        ttl: CACHE_TTL,
+                    };
+                    // TODO: This is imperfect, but not in ways that are likely to matter in
+                    // practice. Ideally the cache update task would get stored such that:
+                    // * It's not possible for the cache to expire and get recreated while
+                    //   waiting for the update response.
+                    // * It would get cancelled on drop.
+                    let update_cache_future = self.update_cache(update_cache_request, cx);
+                    let cache = self.cache.clone();
+                    let executor = cx.background_executor().clone();
+                    cx.background_spawn(async move {
+                        if let Some(response) = update_cache_future.await.log_err() {
+                            let cache_entry = CacheEntry::new(
+                                name,
+                                response.expire_time.to_utc(),
+                                cache_key,
+                                cache.clone(),
+                                executor,
+                            );
+                            cache
+                                .lock()
+                                .task_map
+                                .insert(cache_key, Task::ready(cache_entry).shared());
+                        }
+                    })
+                    .detach();
+                }
+                CacheStatus::Absent | CacheStatus::Expired | CacheStatus::Failed => {
+                    let create_request = CreateCacheRequest {
+                        ttl: CACHE_TTL,
+                        model: request.model.clone(),
+                        contents: request.contents.clone(),
+                        system_instruction: request.system_instruction.clone(),
+                        tools: request.tools.clone(),
+                        tool_config: request.tool_config.clone(),
+                    };
+                    let model = request.model.clone();
+                    let create_cache_future = self
+                        .request_limiter
+                        .run(self.create_cache(create_request, cx));
+                    let cache = self.cache.clone();
+                    let executor = cx.background_executor().clone();
+                    let task = cx.background_spawn(async move {
+                        let result = match create_cache_future.await.log_err()? {
+                            CreateCacheResponse::Created(created_cache) => {
+                                log::info!("created cache `{}`", created_cache.name);
+                                CacheEntry::new(
+                                    created_cache.name,
+                                    created_cache.expire_time.to_utc(),
+                                    cache_key,
+                                    cache.clone(),
+                                    executor,
+                                )
+                            }
+                            CreateCacheResponse::CachingNotSupportedByModel => {
+                                cache.lock().models_not_supported.insert(model);
+                                None
+                            }
+                        };
+                        if result.is_none() {
+                            cache.lock().task_map.remove(&cache_key);
+                        }
+                        result
+                    });
+                    self.cache.lock().task_map.insert(cache_key, task.shared());
+                }
+            }
+        }
+
+        cx.background_spawn({
+            let cache = self.cache.clone();
+            async move {
+                let mut prefix_len = 0;
+                let mut found_cache_entry = None;
+                let mut now = UtcDateTime::now();
+                // The last key is skipped because `contents` must be non-empty.
+                for (ix, key) in content_cache_keys.iter().enumerate().rev().skip(1) {
+                    let task = cache.lock().get_unexpired(key, now);
+                    // TODO: Measure if it's worth it to await on caches vs using ones
+                    // that are already ready.
+                    if let Some(task) = task {
+                        if let Some(cache_entry) = task.await {
+                            prefix_len = ix + 1;
+                            found_cache_entry = Some(cache_entry);
+                            break;
+                        } else {
+                            now = UtcDateTime::now();
+                        }
+                    }
+                }
+                if let Some(found_cache_entry) = found_cache_entry {
+                    log::info!(
+                        "using cache `{}` which has {prefix_len} messages",
+                        found_cache_entry.name.cache_id
+                    );
+                    request.cached_content = Some(found_cache_entry.name);
+                    request.contents.drain(..prefix_len);
+                    request.system_instruction = None;
+                    request.tools = None;
+                    request.tool_config = None;
+                }
+                request
+            }
+        })
+    }
+}
+
 struct ConfigurationView {
     api_key_editor: Entity<Editor>,
     state: gpui::Entity<State>,
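The final spawned task is the read side: it walks the chained keys from newest to oldest, skipping the last key because the request must keep at least one message in `contents`, and splits the request at the longest prefix that still has a live cache. A standalone sketch of that search over plain `u64` keys (hypothetical helper, not from the diff):

    use std::collections::HashSet;

    /// How many leading messages are covered by an existing cache entry.
    fn longest_cached_prefix(keys: &[u64], live: &HashSet<u64>) -> usize {
        for (ix, key) in keys.iter().enumerate().rev().skip(1) {
            if live.contains(key) {
                return ix + 1;
            }
        }
        0
    }

    fn main() {
        let keys = [11, 22, 33, 44];
        let live: HashSet<u64> = [22].into_iter().collect();
        // Messages 0 and 1 are cached; the request keeps messages 2 and 3.
        assert_eq!(longest_cached_prefix(&keys, &live), 2);
    }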
@@ -142,6 +142,10 @@ impl Model {
         }
     }

+    pub fn matches_id(&self, other_id: &str) -> bool {
+        self.id() == other_id
+    }
+
     pub fn display_name(&self) -> &str {
         match self {
             Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
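A self-contained sketch of the helper's call shape (the real enum lives in the provider crate; the id string for the 3.5-turbo variant is assumed to match its display string shown above):

    enum Model {
        ThreePointFiveTurbo,
    }

    impl Model {
        fn id(&self) -> &str {
            match self {
                Self::ThreePointFiveTurbo => "gpt-3.5-turbo", // assumed id string
            }
        }

        fn matches_id(&self, other_id: &str) -> bool {
            self.id() == other_id
        }
    }

    fn main() {
        assert!(Model::ThreePointFiveTurbo.matches_id("gpt-3.5-turbo"));
        assert!(!Model::ThreePointFiveTurbo.matches_id("gpt-4"));
    }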
@@ -540,7 +540,7 @@ impl SummaryIndex {
         );
         let Some(model) = LanguageModelRegistry::read_global(cx)
             .available_models(cx)
-            .find(|model| &model.id() == &summary_model_id)
+            .find(|model| model.matches_id(&summary_model_id))
         else {
             return cx.background_spawn(async move {
                 Err(anyhow!("Couldn't find the preferred summarization model ({:?}) in the language registry's available models", summary_model_id))