Compare commits

...

6 Commits

Author SHA1 Message Date
Nathan Sobo
9445013dd6 Merge branch 'main' into gemini
Co-Authored-By: Antonio <antonio@zed.dev>
2024-07-15 12:02:43 +02:00
Nathan Sobo
74cf3d2d92 Include cached content in stream generate content request
This commit updates the GenerateContentRequest struct in the Google AI module
to include an optional cached_content field. This allows for the use of
previously cached content in subsequent requests, potentially improving
response times and maintaining context across multiple interactions.

The changes are propagated through the assistant and collab crates to ensure
that the cached content can be passed from the user interface down to the
API request level.

Changes:
- Added cached_content field to GenerateContentRequest in google_ai.rs
- Updated LanguageModelRequest struct in assistant.rs to include cached_contents
- Modified relevant functions in assistant_panel.rs, completion_provider/cloud.rs,
  inline_assistant.rs, and prompt_library.rs to accommodate the new field
- Updated language_model_request_to_google_ai function in collab/ai.rs to
  pass cached content to the Google AI request
- Added cached_contents field to CompleteWithLanguageModel message in zed.proto
2024-07-01 21:18:37 -06:00
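The change above is easiest to see as a pair of struct fields plus one conversion. A minimal sketch, trimmed to the fields this commit touches (the real structs also carry messages, temperature, and so on, as the diffs below show):

// Sketch only: simplified versions of the types this commit touches.
// Field names follow the diffs further down; everything else is trimmed.

pub struct CachedContentId(pub String);

pub struct LanguageModelRequest {
    pub cached_contents: Vec<CachedContentId>, // new editor-side field
}

pub struct GenerateContentRequest {
    pub cached_content: Option<String>, // new Google AI-side field
}

// The collab server forwards at most one cached-content name per request.
fn to_google_ai(request: LanguageModelRequest) -> GenerateContentRequest {
    GenerateContentRequest {
        cached_content: request.cached_contents.into_iter().next().map(|id| id.0),
    }
}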
Nathan Sobo
427491a24f Add caching support for language model content
This commit adds support for caching language model content using the Google AI API. The changes include:

1. Adding a new CacheLanguageModelContent request and response to the protocol.
2. Implementing the cache_language_model_content function in the RPC server.
3. Updating the Google AI client to support creating cached content.
4. Modifying the proto definitions to include the new messages.
2024-07-01 20:18:14 -06:00
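Roughly, a caller on the editor side would build the new proto message and send it over the existing RPC connection. A hedged sketch: the message and field names match the zed.proto additions further down, while the client handle, the `request` call, and the TTL value are assumptions for illustration:

// Hedged sketch: assumes zed's `client::Client` and generated `proto` types
// are in scope; only the message and field names are taken from this change.
async fn cache_prompt_prefix(
    client: &client::Client,
    messages: Vec<proto::LanguageModelRequestMessage>,
) -> anyhow::Result<proto::CacheLanguageModelContentResponse> {
    let response = client
        .request(proto::CacheLanguageModelContent {
            model: "gemini-1.5-pro".into(),
            messages,
            ttl_seconds: Some(3600), // keep the cache entry alive for an hour
            ..Default::default()
        })
        .await?;
    // `response.name` is the handle to pass back later via
    // CompleteWithLanguageModel.cached_contents.
    Ok(response)
}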
Nathan Sobo
2781b1cce1 Add an API provider for creating cached content with Gemini.
Co-authored-by: mikayla <mikayla@zed.dev>
2024-07-01 15:31:20 -06:00
Nathan Sobo
ab69c05d99 Refactor token counting and add support for Gemini models
This commit refactors the token counting logic in the `CloudCompletionProvider`
struct to improve code organization and add support for Gemini models. Key changes include:

- Introduce a new `count_tokens_with_model` method to handle token counting
  for models that use the client's request mechanism.
- Update the `count_tokens` method to use the new helper method for Gemini
  and custom models.
- Replace the placeholder implementation for Gemini models with proper
  token counting support.

These changes provide a more consistent approach to token counting across
different model types and lay the groundwork for easier integration of
future models.

Co-authored-by: mikayla <mikayla@zed.dev>
2024-07-01 14:24:55 -06:00
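The diff for this refactor is not part of the comparison below, so the following is only a hedged sketch of the split the commit describes: Gemini and custom models count tokens by asking the server rather than a local tokenizer. The method name comes from the commit message; the proto field names (`model`, `messages`, `token_count`) are assumptions:

// Hedged sketch only; not the actual diff.
async fn count_tokens_with_model(
    client: std::sync::Arc<client::Client>,
    model: String,
    messages: Vec<proto::LanguageModelRequestMessage>,
) -> anyhow::Result<usize> {
    let response = client
        .request(proto::CountTokensWithLanguageModel {
            model,
            messages,
            ..Default::default()
        })
        .await?;
    Ok(response.token_count as usize)
}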
Nathan Sobo
f06c3b5670 Add Gemini 1.5 Pro and Gemini 1.5 Flash as cloud model options
This commit adds support for two new Gemini models:
- Gemini 1.5 Pro
- Gemini 1.5 Flash

Changes include:
1. Adding new variants to the CloudModel enum
2. Updating the id() and display_name() methods for the new models
3. Setting max_token_count for the new models (128000 for Pro, 32000 for Flash)
4. Adding token counting logic for Gemini models (currently using OpenAI's tokenizer as an approximation)

Note: A proper tokenizer for Gemini models should be implemented in the future.

Co-authored-by: mikayla <mikayla@zed.dev>
2024-07-01 14:11:40 -06:00
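Item 4 and the note above describe a stopgap: Gemini prompts are run through the same tokenizer used for OpenAI models until a real Gemini tokenizer is available. A sketch of what that approximation could look like, assuming the tiktoken_rs crate that backs the OpenAI path (the helper name and signature are illustrative):

// Hedged sketch of the approximation described above.
fn approximate_gemini_token_count(text: &str) -> anyhow::Result<usize> {
    // cl100k_base is OpenAI's tokenizer; counts will only approximate Gemini's
    // real tokenization until a proper tokenizer is wired up.
    let bpe = tiktoken_rs::cl100k_base()?;
    Ok(bpe.encode_with_special_tokens(text).len())
}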
14 changed files with 268 additions and 1 deletion

Cargo.lock (generated)
View File

@@ -4819,6 +4819,7 @@ name = "google_ai"
version = "0.1.0"
dependencies = [
"anyhow",
"chrono",
"futures 0.3.28",
"http 0.1.0",
"serde",

View File

@@ -189,8 +189,12 @@ pub struct LanguageModelRequest {
pub messages: Vec<LanguageModelRequestMessage>,
pub stop: Vec<String>,
pub temperature: f32,
pub cached_contents: Vec<CachedContentId>,
}
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct CachedContentId(String);
impl LanguageModelRequest {
pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
proto::CompleteWithLanguageModel {
@@ -200,6 +204,7 @@ impl LanguageModelRequest {
temperature: self.temperature,
tool_choice: None,
tools: Vec::new(),
cached_contents: self.cached_contents.iter().map(|id| id.0.clone()).collect(),
}
}

View File

@@ -27,6 +27,8 @@ pub enum CloudModel {
Claude3Opus,
Claude3Sonnet,
Claude3Haiku,
Gemini15Pro,
Gemini15Flash,
Custom(String),
}
@@ -109,6 +111,8 @@ impl CloudModel {
Self::Claude3Opus => "claude-3-opus",
Self::Claude3Sonnet => "claude-3-sonnet",
Self::Claude3Haiku => "claude-3-haiku",
Self::Gemini15Pro => "gemini-1.5-pro",
Self::Gemini15Flash => "gemini-1.5-flash",
Self::Custom(id) => id,
}
}
@@ -123,6 +127,8 @@ impl CloudModel {
Self::Claude3Opus => "Claude 3 Opus",
Self::Claude3Sonnet => "Claude 3 Sonnet",
Self::Claude3Haiku => "Claude 3 Haiku",
Self::Gemini15Pro => "Gemini 1.5 Pro",
Self::Gemini15Flash => "Gemini 1.5 Flash",
Self::Custom(id) => id.as_str(),
}
}
@@ -136,6 +142,8 @@ impl CloudModel {
| Self::Claude3Opus
| Self::Claude3Sonnet
| Self::Claude3Haiku => 200000,
Self::Gemini15Pro => 128000,
Self::Gemini15Flash => 32000,
Self::Custom(_) => 4096, // TODO: Make this configurable
}
}

View File

@@ -152,6 +152,11 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
temperature: request.temperature,
tools: Vec::new(),
tool_choice: None,
cached_contents: request
.cached_contents
.iter()
.map(|id| id.0.clone())
.collect(),
};
self.client

View File

@@ -1258,6 +1258,7 @@ impl Context {
messages: messages.collect(),
stop: vec![],
temperature: 1.0,
cached_contents: Vec::new(), // todo!("support context caching")
}
}
@@ -1513,6 +1514,7 @@ impl Context {
messages: messages.collect(),
stop: vec![],
temperature: 1.0,
cached_contents: Vec::new(), // todo!
};
let stream = CompletionProvider::global(cx).complete(request, cx);

View File

@@ -851,6 +851,7 @@ impl InlineAssistant {
messages,
stop: vec!["|END|>".to_string()],
temperature,
cached_contents: Vec::new(),
})
})
}

View File

@@ -744,6 +744,7 @@ impl PromptLibrary {
}],
stop: Vec::new(),
temperature: 1.,
cached_contents: Vec::new(),
},
cx,
)

View File

@@ -269,6 +269,7 @@ impl TerminalInlineAssistant {
messages,
stop: Vec::new(),
temperature: 1.0,
cached_contents: Vec::new(), // todo!
})
}

View File

@@ -101,6 +101,7 @@ pub fn language_model_request_to_google_ai(
.collect::<Result<Vec<_>>>()?,
generation_config: None,
safety_settings: None,
cached_content: request.cached_contents.into_iter().next(),
})
}
@@ -125,6 +126,60 @@ pub fn language_model_request_message_to_google_ai(
})
}
pub fn cache_language_model_content_request_to_google_ai(
request: proto::CacheLanguageModelContent,
) -> Result<google_ai::CreateCachedContentRequest> {
Ok(google_ai::CreateCachedContentRequest {
contents: request
.messages
.into_iter()
.map(language_model_request_message_to_google_ai)
.collect::<Result<Vec<_>>>()?,
tools: if request.tools.is_empty() {
None
} else {
Some(
request
.tools
.into_iter()
.try_fold(Vec::new(), |mut acc, tool| {
if let Some(variant) = tool.variant {
match variant {
proto::chat_completion_tool::Variant::Function(f) => {
let description = f.description.ok_or_else(|| {
anyhow!("Function tool is missing a description")
})?;
let parameters = f.parameters.ok_or_else(|| {
anyhow!("Function tool is missing parameters")
})?;
let parsed_parameters = serde_json::from_str(&parameters)
.map_err(|e| {
anyhow!("Failed to parse parameters: {}", e)
})?;
acc.push(google_ai::Tool {
function_declarations: vec![
google_ai::FunctionDeclaration {
name: f.name,
description,
parameters: parsed_parameters,
},
],
});
}
}
}
anyhow::Ok(acc)
})?,
)
},
ttl: request.ttl_seconds.map(|s| format!("{}s", s)),
display_name: None,
model: request.model,
system_instruction: None,
tool_config: None,
})
}
pub fn count_tokens_request_to_google_ai(
request: proto::CountTokensWithLanguageModel,
) -> Result<google_ai::CountTokensRequest> {

View File

@@ -616,6 +616,17 @@ impl Server {
)
}
})
.add_request_handler({
let app_state = app_state.clone();
user_handler(move |request, response, session| {
cache_language_model_content(
request,
response,
session,
app_state.config.google_ai_api_key.clone(),
)
})
})
.add_request_handler({
let app_state = app_state.clone();
user_handler(move |request, response, session| {
@@ -4723,6 +4734,43 @@ async fn complete_with_anthropic(
Ok(())
}
async fn cache_language_model_content(
request: proto::CacheLanguageModelContent,
response: Response<proto::CacheLanguageModelContent>,
session: UserSession,
google_ai_api_key: Option<Arc<str>>,
) -> Result<()> {
authorize_access_to_language_models(&session).await?;
if !request.model.starts_with("gemini") {
return Err(anyhow!(
"caching content for model: {:?} is not supported",
request.model
))?;
}
let api_key = google_ai_api_key
.ok_or_else(|| anyhow!("no Google AI API key configured on the server"))?;
let cached_content = google_ai::create_cached_content(
session.http_client.as_ref(),
google_ai::API_URL,
&api_key,
crate::ai::cache_language_model_content_request_to_google_ai(request)?,
)
.await?;
response.send(proto::CacheLanguageModelContentResponse {
name: cached_content.name,
create_time: cached_content.create_time.timestamp() as u64,
update_time: cached_content.update_time.timestamp() as u64,
expire_time: cached_content.expire_time.map(|t| t.timestamp() as u64),
total_token_count: cached_content.usage_metadata.total_token_count,
})?;
Ok(())
}
struct CountTokensWithLanguageModelRateLimit;
impl RateLimit for CountTokensWithLanguageModelRateLimit {

View File

@@ -14,3 +14,4 @@ futures.workspace = true
http.workspace = true
serde.workspace = true
serde_json.workspace = true
chrono.workspace = true

View File

@@ -97,6 +97,7 @@ pub struct GenerateContentRequest {
pub contents: Vec<Content>,
pub generation_config: Option<GenerationConfig>,
pub safety_settings: Option<Vec<SafetySetting>>,
pub cached_content: Option<String>,
}
#[derive(Debug, Deserialize)]
@@ -267,3 +268,115 @@ pub struct CountTokensRequest {
pub struct CountTokensResponse {
pub total_tokens: usize,
}
pub async fn create_cached_content<T: HttpClient>(
client: &T,
api_url: &str,
api_key: &str,
request: CreateCachedContentRequest,
) -> Result<CreateCachedContentResponse> {
let uri = format!("{}/v1beta/cachedContents?key={}", api_url, api_key);
let request = serde_json::to_string(&request)?;
let mut response = client.post_json(&uri, request.into()).await?;
let mut text = String::new();
response.body_mut().read_to_string(&mut text).await?;
if response.status().is_success() {
Ok(serde_json::from_str::<CreateCachedContentResponse>(&text)?)
} else {
Err(anyhow!(
"error during createCachedContent, status code: {:?}, body: {}",
response.status(),
text
))
}
}
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateCachedContentRequest {
pub contents: Vec<Content>,
pub tools: Option<Vec<Tool>>,
pub ttl: Option<String>,
pub display_name: Option<String>,
pub model: String,
pub system_instruction: Option<Content>,
pub tool_config: Option<ToolConfig>,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateCachedContentResponse {
pub name: String,
pub create_time: chrono::DateTime<chrono::Utc>,
pub update_time: chrono::DateTime<chrono::Utc>,
pub expire_time: Option<chrono::DateTime<chrono::Utc>>,
pub usage_metadata: UsageMetadata,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct UsageMetadata {
pub total_token_count: u32,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Tool {
pub function_declarations: Vec<FunctionDeclaration>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionDeclaration {
pub name: String,
pub description: String,
pub parameters: Schema,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Schema {
#[serde(rename = "type")]
pub schema_type: SchemaType,
pub format: Option<String>,
pub description: Option<String>,
pub nullable: Option<bool>,
pub enum_values: Option<Vec<String>>,
pub properties: Option<std::collections::HashMap<String, Box<Schema>>>,
pub required: Option<Vec<String>>,
pub items: Option<Box<Schema>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum SchemaType {
TypeUnspecified,
String,
Number,
Integer,
Boolean,
Array,
Object,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ToolConfig {
pub function_calling_config: Option<FunctionCallingConfig>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FunctionCallingConfig {
pub mode: Option<FunctionCallingMode>,
pub allowed_function_names: Option<Vec<String>>,
}
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FunctionCallingMode {
ModeUnspecified,
Auto,
Any,
None,
}
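Putting the additions in this file together, a caller with an HTTP client and an API key could create a cache entry like this. A hedged sketch: the `Content` type and `HttpClient` trait already exist in these crates but are not shown in this diff, and the model string is an assumption (Google's caching endpoint expects an explicit model name):

// Hedged sketch of using the new API surface above.
async fn cache_system_prompt<T: http::HttpClient>(
    client: &T,
    api_key: &str,
    contents: Vec<google_ai::Content>,
) -> anyhow::Result<google_ai::CreateCachedContentResponse> {
    let request = google_ai::CreateCachedContentRequest {
        model: "models/gemini-1.5-pro-001".into(), // assumed model identifier
        contents,
        tools: None,
        ttl: Some("3600s".into()), // TTLs are encoded as seconds with an "s" suffix
        display_name: None,
        system_instruction: None,
        tool_config: None,
    };
    google_ai::create_cached_content(client, google_ai::API_URL, api_key, request).await
}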

View File

@@ -203,6 +203,10 @@ message Envelope {
CompleteWithLanguageModel complete_with_language_model = 166;
LanguageModelResponse language_model_response = 167;
CacheLanguageModelContent cache_language_model_content = 219;
CacheLanguageModelContentResponse cache_language_model_content_response = 220; // current max
CountTokensWithLanguageModel count_tokens_with_language_model = 168;
CountTokensResponse count_tokens_response = 169;
GetCachedEmbeddings get_cached_embeddings = 189;
@@ -265,7 +269,7 @@ message Envelope {
SynchronizeContextsResponse synchronize_contexts_response = 216;
GetSignatureHelp get_signature_help = 217;
GetSignatureHelpResponse get_signature_help_response = 218; // current max
GetSignatureHelpResponse get_signature_help_response = 218;
}
reserved 158 to 161;
@@ -2025,6 +2029,25 @@ message CompleteWithLanguageModel {
float temperature = 4;
repeated ChatCompletionTool tools = 5;
optional string tool_choice = 6;
repeated string cached_contents = 7;
}
message CacheLanguageModelContent {
string model = 1;
repeated LanguageModelRequestMessage messages = 2;
repeated string stop = 3;
float temperature = 4;
repeated ChatCompletionTool tools = 5;
optional string tool_choice = 6;
optional uint64 ttl_seconds = 7;
}
message CacheLanguageModelContentResponse {
string name = 1;
uint64 create_time = 2;
uint64 update_time = 3;
optional uint64 expire_time = 4;
uint32 total_token_count = 5;
}
// A tool presented to the language model for its use

View File

@@ -194,6 +194,8 @@ messages!(
(ApplyCompletionAdditionalEditsResponse, Background),
(BufferReloaded, Foreground),
(BufferSaved, Foreground),
(CacheLanguageModelContent, Background),
(CacheLanguageModelContentResponse, Background),
(Call, Foreground),
(CallCanceled, Foreground),
(CancelCall, Foreground),
@@ -400,6 +402,7 @@ request_messages!(
ApplyCompletionAdditionalEdits,
ApplyCompletionAdditionalEditsResponse
),
(CacheLanguageModelContent, CacheLanguageModelContentResponse),
(Call, Ack),
(CancelCall, Ack),
(CopyProjectEntry, ProjectEntryResponse),