Compare commits: lua-run-cl...gemini
6 commits
| Author | SHA1 | Date |
|---|---|---|
| | 9445013dd6 | |
| | 74cf3d2d92 | |
| | 427491a24f | |
| | 2781b1cce1 | |
| | ab69c05d99 | |
| | f06c3b5670 | |
Cargo.lock (generated)
```diff
@@ -4819,6 +4819,7 @@ name = "google_ai"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "chrono",
  "futures 0.3.28",
  "http 0.1.0",
  "serde",
```
```diff
@@ -189,8 +189,12 @@ pub struct LanguageModelRequest {
     pub messages: Vec<LanguageModelRequestMessage>,
     pub stop: Vec<String>,
     pub temperature: f32,
+    pub cached_contents: Vec<CachedContentId>,
 }
 
+#[derive(Clone, Debug, Default, Serialize, Deserialize)]
+pub struct CachedContentId(String);
+
 impl LanguageModelRequest {
     pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
         proto::CompleteWithLanguageModel {
```
```diff
@@ -200,6 +204,7 @@ impl LanguageModelRequest {
             temperature: self.temperature,
             tool_choice: None,
             tools: Vec::new(),
+            cached_contents: self.cached_contents.iter().map(|id| id.0.clone()).collect(),
         }
     }
 
```
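Together, the two hunks above add a `CachedContentId` newtype to `LanguageModelRequest` and flatten it back to plain strings when converting to the proto message. A minimal standalone sketch of that flattening (types reduced to the relevant fields and derives; the cache name is made up, though Google's API does name caches `cachedContents/{id}`):

```rust
// Sketch: CachedContentId wraps a server-issued cache name; the proto
// conversion unwraps each id back to its inner String for the wire.
#[derive(Clone, Debug, Default)]
pub struct CachedContentId(String);

pub struct LanguageModelRequest {
    pub cached_contents: Vec<CachedContentId>,
}

impl LanguageModelRequest {
    // Mirrors the diff's `self.cached_contents.iter().map(|id| id.0.clone()).collect()`.
    fn cached_contents_for_proto(&self) -> Vec<String> {
        self.cached_contents.iter().map(|id| id.0.clone()).collect()
    }
}

fn main() {
    let request = LanguageModelRequest {
        cached_contents: vec![CachedContentId("cachedContents/example-id".into())],
    };
    assert_eq!(
        request.cached_contents_for_proto(),
        vec!["cachedContents/example-id".to_string()]
    );
}
```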
```diff
@@ -27,6 +27,8 @@ pub enum CloudModel {
     Claude3Opus,
     Claude3Sonnet,
     Claude3Haiku,
+    Gemini15Pro,
+    Gemini15Flash,
     Custom(String),
 }
 
```
```diff
@@ -109,6 +111,8 @@ impl CloudModel {
             Self::Claude3Opus => "claude-3-opus",
             Self::Claude3Sonnet => "claude-3-sonnet",
             Self::Claude3Haiku => "claude-3-haiku",
+            Self::Gemini15Pro => "gemini-1.5-pro",
+            Self::Gemini15Flash => "gemini-1.5-flash",
             Self::Custom(id) => id,
         }
     }
```
```diff
@@ -123,6 +127,8 @@ impl CloudModel {
             Self::Claude3Opus => "Claude 3 Opus",
             Self::Claude3Sonnet => "Claude 3 Sonnet",
             Self::Claude3Haiku => "Claude 3 Haiku",
+            Self::Gemini15Pro => "Gemini 1.5 Pro",
+            Self::Gemini15Flash => "Gemini 1.5 Flash",
             Self::Custom(id) => id.as_str(),
         }
     }
```
```diff
@@ -136,6 +142,8 @@ impl CloudModel {
             | Self::Claude3Opus
             | Self::Claude3Sonnet
             | Self::Claude3Haiku => 200000,
+            Self::Gemini15Pro => 128000,
+            Self::Gemini15Flash => 32000,
             Self::Custom(_) => 4096, // TODO: Make this configurable
         }
     }
```
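The three `impl CloudModel` hunks above register the Gemini variants in the model-id, display-name, and context-window tables. A reduced sketch of how those tables line up (enum and match arms trimmed to the new variants; the method names are paraphrased from the hunks, and the values are copied from the diff):

```rust
// Trimmed to the two new variants; the full enum also carries GPT/Claude arms.
enum CloudModel {
    Gemini15Pro,
    Gemini15Flash,
}

impl CloudModel {
    fn id(&self) -> &str {
        match self {
            Self::Gemini15Pro => "gemini-1.5-pro",
            Self::Gemini15Flash => "gemini-1.5-flash",
        }
    }

    fn max_token_count(&self) -> usize {
        match self {
            Self::Gemini15Pro => 128000,
            Self::Gemini15Flash => 32000,
        }
    }
}

fn main() {
    let model = CloudModel::Gemini15Flash;
    println!("{} supports up to {} tokens", model.id(), model.max_token_count());
}
```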
```diff
@@ -152,6 +152,11 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
             temperature: request.temperature,
             tools: Vec::new(),
             tool_choice: None,
+            cached_contents: request
+                .cached_contents
+                .iter()
+                .map(|id| id.0.clone())
+                .collect(),
         };
 
         self.client
```
```diff
@@ -1258,6 +1258,7 @@ impl Context {
             messages: messages.collect(),
             stop: vec![],
             temperature: 1.0,
+            cached_contents: Vec::new(), // todo!("support context caching")
         }
     }
 
```
```diff
@@ -1513,6 +1514,7 @@ impl Context {
             messages: messages.collect(),
             stop: vec![],
             temperature: 1.0,
+            cached_contents: Vec::new(), // todo!
         };
 
         let stream = CompletionProvider::global(cx).complete(request, cx);
```
```diff
@@ -851,6 +851,7 @@ impl InlineAssistant {
                 messages,
                 stop: vec!["|END|>".to_string()],
                 temperature,
+                cached_contents: Vec::new(),
             })
         })
     }
```
```diff
@@ -744,6 +744,7 @@ impl PromptLibrary {
                     }],
                     stop: Vec::new(),
                     temperature: 1.,
+                    cached_contents: Vec::new(),
                 },
                 cx,
             )
```
```diff
@@ -269,6 +269,7 @@ impl TerminalInlineAssistant {
                 messages,
                 stop: Vec::new(),
                 temperature: 1.0,
+                cached_contents: Vec::new(), // todo!
             })
     }
 
```
```diff
@@ -101,6 +101,7 @@ pub fn language_model_request_to_google_ai(
             .collect::<Result<Vec<_>>>()?,
         generation_config: None,
         safety_settings: None,
+        cached_content: request.cached_contents.into_iter().next(),
     })
 }
 
```
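Note the `into_iter().next()` on the added line: the proto carries a repeated `cached_contents` field, but the `GenerateContentRequest` in this diff takes a single optional `cached_content`, so only the first entry survives the conversion. A tiny illustration:

```rust
fn main() {
    // Two cache names arrive over the wire; only the first is forwarded.
    let cached_contents: Vec<String> =
        vec!["cachedContents/first".into(), "cachedContents/second".into()];
    let cached_content: Option<String> = cached_contents.into_iter().next();
    assert_eq!(cached_content.as_deref(), Some("cachedContents/first"));
}
```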
```diff
@@ -125,6 +126,60 @@ pub fn language_model_request_message_to_google_ai(
     })
 }
 
+pub fn cache_language_model_content_request_to_google_ai(
+    request: proto::CacheLanguageModelContent,
+) -> Result<google_ai::CreateCachedContentRequest> {
+    Ok(google_ai::CreateCachedContentRequest {
+        contents: request
+            .messages
+            .into_iter()
+            .map(language_model_request_message_to_google_ai)
+            .collect::<Result<Vec<_>>>()?,
+        tools: if request.tools.is_empty() {
+            None
+        } else {
+            Some(
+                request
+                    .tools
+                    .into_iter()
+                    .try_fold(Vec::new(), |mut acc, tool| {
+                        if let Some(variant) = tool.variant {
+                            match variant {
+                                proto::chat_completion_tool::Variant::Function(f) => {
+                                    let description = f.description.ok_or_else(|| {
+                                        anyhow!("Function tool is missing a description")
+                                    })?;
+                                    let parameters = f.parameters.ok_or_else(|| {
+                                        anyhow!("Function tool is missing parameters")
+                                    })?;
+                                    let parsed_parameters = serde_json::from_str(&parameters)
+                                        .map_err(|e| {
+                                            anyhow!("Failed to parse parameters: {}", e)
+                                        })?;
+                                    acc.push(google_ai::Tool {
+                                        function_declarations: vec![
+                                            google_ai::FunctionDeclaration {
+                                                name: f.name,
+                                                description,
+                                                parameters: parsed_parameters,
+                                            },
+                                        ],
+                                    });
+                                }
+                            }
+                        }
+                        anyhow::Ok(acc)
+                    })?,
+            )
+        },
+        ttl: request.ttl_seconds.map(|s| format!("{}s", s)),
+        display_name: None,
+        model: request.model,
+        system_instruction: None,
+        tool_config: None,
+    })
+}
+
 pub fn count_tokens_request_to_google_ai(
     request: proto::CountTokensWithLanguageModel,
 ) -> Result<google_ai::CountTokensRequest> {
```
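One detail worth calling out in the new conversion: `ttl_seconds` is an optional integer on the proto, while Google's REST API expects the TTL as a protobuf `Duration` string, i.e. decimal seconds with an `s` suffix, hence the `format!("{}s", s)`. A quick check of that mapping:

```rust
fn main() {
    // Mirrors `ttl: request.ttl_seconds.map(|s| format!("{}s", s))` from the diff.
    let ttl_seconds: Option<u64> = Some(300);
    let ttl: Option<String> = ttl_seconds.map(|s| format!("{}s", s));
    assert_eq!(ttl.as_deref(), Some("300s"));

    // No TTL requested: the field stays None rather than being sent as "0s".
    let absent: Option<String> = None::<u64>.map(|s| format!("{}s", s));
    assert_eq!(absent, None);
}
```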
```diff
@@ -616,6 +616,17 @@ impl Server {
                     )
                 }
             })
+            .add_request_handler({
+                let app_state = app_state.clone();
+                user_handler(move |request, response, session| {
+                    cache_language_model_content(
+                        request,
+                        response,
+                        session,
+                        app_state.config.google_ai_api_key.clone(),
+                    )
+                })
+            })
             .add_request_handler({
                 let app_state = app_state.clone();
                 user_handler(move |request, response, session| {
```
```diff
@@ -4723,6 +4734,43 @@ async fn complete_with_anthropic(
     Ok(())
 }
 
+async fn cache_language_model_content(
+    request: proto::CacheLanguageModelContent,
+    response: Response<proto::CacheLanguageModelContent>,
+    session: UserSession,
+    google_ai_api_key: Option<Arc<str>>,
+) -> Result<()> {
+    authorize_access_to_language_models(&session).await?;
+
+    if !request.model.starts_with("gemini") {
+        return Err(anyhow!(
+            "caching content for model: {:?} is not supported",
+            request.model
+        ))?;
+    }
+
+    let api_key = google_ai_api_key
+        .ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
+
+    let cached_content = google_ai::create_cached_content(
+        session.http_client.as_ref(),
+        google_ai::API_URL,
+        &api_key,
+        crate::ai::cache_language_model_content_request_to_google_ai(request)?,
+    )
+    .await?;
+
+    response.send(proto::CacheLanguageModelContentResponse {
+        name: cached_content.name,
+        create_time: cached_content.create_time.timestamp() as u64,
+        update_time: cached_content.update_time.timestamp() as u64,
+        expire_time: cached_content.expire_time.map(|t| t.timestamp() as u64),
+        total_token_count: cached_content.usage_metadata.total_token_count,
+    })?;
+
+    Ok(())
+}
+
 struct CountTokensWithLanguageModelRateLimit;
 
 impl RateLimit for CountTokensWithLanguageModelRateLimit {
```
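The handler flattens the `chrono` timestamps from `CreateCachedContentResponse` into epoch seconds before sending them over the wire. A standalone sketch of that conversion (the date is arbitrary; `chrono` is already a dependency per the Cargo.lock and Cargo.toml hunks):

```rust
use chrono::{TimeZone, Utc};

fn main() {
    // As in the diff: DateTime<Utc> -> u64 epoch seconds for the proto response.
    let create_time = Utc.with_ymd_and_hms(2024, 6, 1, 12, 0, 0).unwrap();
    assert_eq!(create_time.timestamp() as u64, 1717243200);

    // expire_time is optional, so it maps through Option the same way.
    let expire_time: Option<chrono::DateTime<Utc>> = None;
    assert_eq!(expire_time.map(|t| t.timestamp() as u64), None);
}
```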
```diff
@@ -14,3 +14,4 @@ futures.workspace = true
 http.workspace = true
 serde.workspace = true
 serde_json.workspace = true
+chrono.workspace = true
```
```diff
@@ -97,6 +97,7 @@ pub struct GenerateContentRequest {
     pub contents: Vec<Content>,
     pub generation_config: Option<GenerationConfig>,
     pub safety_settings: Option<Vec<SafetySetting>>,
+    pub cached_content: Option<String>,
 }
 
 #[derive(Debug, Deserialize)]
```
```diff
@@ -267,3 +268,115 @@ pub struct CountTokensRequest {
 pub struct CountTokensResponse {
     pub total_tokens: usize,
 }
+
+pub async fn create_cached_content<T: HttpClient>(
+    client: &T,
+    api_url: &str,
+    api_key: &str,
+    request: CreateCachedContentRequest,
+) -> Result<CreateCachedContentResponse> {
+    let uri = format!("{}/v1beta/cachedContents?key={}", api_url, api_key);
+    let request = serde_json::to_string(&request)?;
+    let mut response = client.post_json(&uri, request.into()).await?;
+    let mut text = String::new();
+    response.body_mut().read_to_string(&mut text).await?;
+    if response.status().is_success() {
+        Ok(serde_json::from_str::<CreateCachedContentResponse>(&text)?)
+    } else {
+        Err(anyhow!(
+            "error during createCachedContent, status code: {:?}, body: {}",
+            response.status(),
+            text
+        ))
+    }
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CreateCachedContentRequest {
+    pub contents: Vec<Content>,
+    pub tools: Option<Vec<Tool>>,
+    pub ttl: Option<String>,
+    pub display_name: Option<String>,
+    pub model: String,
+    pub system_instruction: Option<Content>,
+    pub tool_config: Option<ToolConfig>,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct CreateCachedContentResponse {
+    pub name: String,
+    pub create_time: chrono::DateTime<chrono::Utc>,
+    pub update_time: chrono::DateTime<chrono::Utc>,
+    pub expire_time: Option<chrono::DateTime<chrono::Utc>>,
+    pub usage_metadata: UsageMetadata,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct UsageMetadata {
+    pub total_token_count: u32,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Tool {
+    pub function_declarations: Vec<FunctionDeclaration>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct FunctionDeclaration {
+    pub name: String,
+    pub description: String,
+    pub parameters: Schema,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Schema {
+    #[serde(rename = "type")]
+    pub schema_type: SchemaType,
+    pub format: Option<String>,
+    pub description: Option<String>,
+    pub nullable: Option<bool>,
+    pub enum_values: Option<Vec<String>>,
+    pub properties: Option<std::collections::HashMap<String, Box<Schema>>>,
+    pub required: Option<Vec<String>>,
+    pub items: Option<Box<Schema>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum SchemaType {
+    TypeUnspecified,
+    String,
+    Number,
+    Integer,
+    Boolean,
+    Array,
+    Object,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ToolConfig {
+    pub function_calling_config: Option<FunctionCallingConfig>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "camelCase")]
+pub struct FunctionCallingConfig {
+    pub mode: Option<FunctionCallingMode>,
+    pub allowed_function_names: Option<Vec<String>>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
+pub enum FunctionCallingMode {
+    ModeUnspecified,
+    Auto,
+    Any,
+    None,
+}
```
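The new request and response types lean entirely on serde's rename attributes for Google's wire format, so a reduced round-trip sketch shows what actually goes over HTTP: field names become camelCase, `SchemaType` variants become SCREAMING_SNAKE_CASE strings, and because no `skip_serializing_if` is applied, absent `Option` fields serialize as explicit `null`s. The `get_weather` declaration below is invented for illustration:

```rust
use serde::{Deserialize, Serialize};

// Reduced copies of the diff's types, just enough to show the wire shape.
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct FunctionDeclaration {
    name: String,
    description: String,
    parameters: Schema,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct Schema {
    #[serde(rename = "type")]
    schema_type: SchemaType,
    description: Option<String>,
    required: Option<Vec<String>>,
}

#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
enum SchemaType {
    String,
    Object,
}

fn main() {
    let decl = FunctionDeclaration {
        name: "get_weather".into(), // hypothetical tool, illustration only
        description: "Look up the current weather".into(),
        parameters: Schema {
            schema_type: SchemaType::Object,
            description: None,
            required: Some(vec!["city".into()]),
        },
    };
    let json = serde_json::to_string(&decl).unwrap();
    // {"name":"get_weather","description":"Look up the current weather",
    //  "parameters":{"type":"OBJECT","description":null,"required":["city"]}}
    println!("{json}");
}
```

Whether the endpoint tolerates the explicit `null`s is worth verifying; adding `#[serde(skip_serializing_if = "Option::is_none")]` would omit them instead.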
```diff
@@ -203,6 +203,10 @@ message Envelope {
 
         CompleteWithLanguageModel complete_with_language_model = 166;
         LanguageModelResponse language_model_response = 167;
+
+        CacheLanguageModelContent cache_language_model_content = 219;
+        CacheLanguageModelContentResponse cache_language_model_content_response = 220; // current max
+
         CountTokensWithLanguageModel count_tokens_with_language_model = 168;
         CountTokensResponse count_tokens_response = 169;
         GetCachedEmbeddings get_cached_embeddings = 189;
```
```diff
@@ -265,7 +269,7 @@
         SynchronizeContextsResponse synchronize_contexts_response = 216;
 
         GetSignatureHelp get_signature_help = 217;
-        GetSignatureHelpResponse get_signature_help_response = 218; // current max
+        GetSignatureHelpResponse get_signature_help_response = 218;
     }
 
     reserved 158 to 161;
```
```diff
@@ -2025,6 +2029,25 @@ message CompleteWithLanguageModel {
     float temperature = 4;
     repeated ChatCompletionTool tools = 5;
     optional string tool_choice = 6;
+    repeated string cached_contents = 7;
 }
 
+message CacheLanguageModelContent {
+    string model = 1;
+    repeated LanguageModelRequestMessage messages = 2;
+    repeated string stop = 3;
+    float temperature = 4;
+    repeated ChatCompletionTool tools = 5;
+    optional string tool_choice = 6;
+    optional uint64 ttl_seconds = 7;
+}
+
+message CacheLanguageModelContentResponse {
+    string name = 1;
+    uint64 create_time = 2;
+    uint64 update_time = 3;
+    optional uint64 expire_time = 4;
+    uint32 total_token_count = 5;
+}
+
 // A tool presented to the language model for its use
```
```diff
@@ -194,6 +194,8 @@ messages!(
     (ApplyCompletionAdditionalEditsResponse, Background),
     (BufferReloaded, Foreground),
     (BufferSaved, Foreground),
+    (CacheLanguageModelContent, Background),
+    (CacheLanguageModelContentResponse, Background),
     (Call, Foreground),
     (CallCanceled, Foreground),
     (CancelCall, Foreground),
```
```diff
@@ -400,6 +402,7 @@ request_messages!(
     (
         ApplyCompletionAdditionalEdits,
         ApplyCompletionAdditionalEditsResponse
     ),
+    (CacheLanguageModelContent, CacheLanguageModelContentResponse),
     (Call, Ack),
     (CancelCall, Ack),
     (CopyProjectEntry, ProjectEntryResponse),
```