diff --git a/Cargo.lock b/Cargo.lock index 85a62c9519..d949da7e8a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1546,6 +1546,18 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bge" +version = "0.1.0" +dependencies = [ + "anyhow", + "futures 0.3.30", + "http_client", + "isahc", + "serde", + "serde_json", +] + [[package]] name = "bigdecimal" version = "0.4.5" @@ -2543,6 +2555,7 @@ dependencies = [ "axum", "axum-extra", "base64 0.22.1", + "bge", "call", "channel", "chrono", diff --git a/Cargo.toml b/Cargo.toml index c72fec020f..cd20002fab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "crates/assistant_tool", "crates/audio", "crates/auto_update", + "crates/bge", "crates/breadcrumbs", "crates/call", "crates/channel", @@ -187,6 +188,7 @@ assistant_slash_command = { path = "crates/assistant_slash_command" } assistant_tool = { path = "crates/assistant_tool" } audio = { path = "crates/audio" } auto_update = { path = "crates/auto_update" } +bge = { path = "crates/bge" } breadcrumbs = { path = "crates/breadcrumbs" } call = { path = "crates/call" } channel = { path = "crates/channel" } diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index ad43d2d1f0..03041d306a 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -27,6 +27,7 @@ aws-sdk-s3 = { version = "1.15.0" } axum = { version = "0.6", features = ["json", "headers", "ws"] } axum-extra = { version = "0.4", features = ["erased-json"] } base64.workspace = true +bge.workspace = true chrono.workspace = true clock.workspace = true clickhouse.workspace = true diff --git a/crates/collab/k8s/collab.template.yml b/crates/collab/k8s/collab.template.yml index 7d4ea6eb9a..17df872372 100644 --- a/crates/collab/k8s/collab.template.yml +++ b/crates/collab/k8s/collab.template.yml @@ -149,6 +149,18 @@ spec: secretKeyRef: name: google-ai key: api_key + - name: EMBEDDINGS_API_KEY + valueFrom: + secretKeyRef: + name: embeddings + key: api_key + optional: true + - name: EMBEDDINGS_API_URL + valueFrom: + secretKeyRef: + name: embeddings + key: url + optional: true - name: BLOB_STORE_ACCESS_KEY valueFrom: secretKeyRef: diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 6c32023a97..bd2f2a7982 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -170,6 +170,8 @@ pub struct Config { pub anthropic_api_key: Option>, pub anthropic_staff_api_key: Option>, pub llm_closed_beta_model_name: Option>, + pub embeddings_api_key: Option>, + pub embeddings_api_url: Option>, pub zed_client_checksum_seed: Option, pub slack_panics_webhook: Option, pub auto_join_channel_id: Option, @@ -233,6 +235,8 @@ impl Config { stripe_api_key: None, stripe_price_id: None, supermaven_admin_api_key: None, + embeddings_api_key: None, + embeddings_api_url: None, user_backfiller_github_access_token: None, } } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index d9683fb8b3..6969ffe395 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -37,7 +37,6 @@ pub use connection_pool::{ConnectionPool, ZedVersion}; use core::fmt::{self, Debug, Formatter}; use http_client::HttpClient; use isahc_http_client::IsahcHttpClient; -use open_ai::{OpenAiEmbeddingModel, OPEN_AI_API_URL}; use sha2::Digest; use supermaven_api::{CreateExternalUserRequest, SupermavenAdminApi}; @@ -641,7 +640,8 @@ impl Server { request, response, session, - app_state.config.openai_api_key.clone(), + app_state.config.embeddings_api_url.clone(), + app_state.config.embeddings_api_key.clone(), ) }) }); @@ -4585,9 +4585,11 @@ async fn compute_embeddings( request: proto::ComputeEmbeddings, response: Response, session: UserSession, + api_url: Option>, api_key: Option>, ) -> Result<()> { - let api_key = api_key.context("no OpenAI API key configured on the server")?; + let embeddings_url = api_url.context("no embeddings API URL configured on the server")?; + let embeddings_api_key = api_key.context("no embeddings API key configured on the server")?; authorize_access_to_legacy_llm_endpoints(&session).await?; let rate_limit: Box = match session.current_plan(session.db().await).await? { @@ -4601,19 +4603,13 @@ async fn compute_embeddings( .check(&*rate_limit, session.user_id()) .await?; - let embeddings = match request.model.as_str() { - "openai/text-embedding-3-small" => { - open_ai::embed( - session.http_client.as_ref(), - OPEN_AI_API_URL, - &api_key, - OpenAiEmbeddingModel::TextEmbedding3Small, - request.texts.iter().map(|text| text.as_str()), - ) - .await? - } - provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?, - }; + let embeddings = bge::embed( + session.http_client.as_ref(), + &embeddings_url, + &embeddings_api_key, + request.texts.iter().map(|text| text.as_str()), + ) + .await?; let embeddings = request .texts diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index 5ff4a72074..bf5f2c7365 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -679,6 +679,8 @@ impl TestServer { stripe_api_key: None, stripe_price_id: None, supermaven_admin_api_key: None, + embeddings_api_key: None, + embeddings_api_url: None, user_backfiller_github_access_token: None, }, })