Compare commits

...

1 Commit

Author SHA1 Message Date
Nathan Sobo
139c5c59b7 WIP 2024-10-21 17:52:45 -06:00

View File

@@ -21,6 +21,24 @@ use axum::{
use chrono::{DateTime, Duration, Utc};
use collections::HashMap;
use db::TokenUsage;
/// Serializable request body carrying a model name and a batch of texts to embed.
///
/// NOTE(review): `compute_embeddings_http` in this file extracts
/// `proto::ComputeEmbeddings` rather than this type — confirm whether this
/// struct is meant to replace the `proto` message or is leftover scaffolding.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct ComputeEmbeddingsRequest {
/// Name of the embedding model to use.
/// NOTE(review): presumably the same `provider/model` form the handler
/// matches on (e.g. "openai/text-embedding-3-small") — confirm.
pub model: String,
/// Texts to compute embeddings for.
pub texts: Vec<String>,
}
/// Serializable response body holding a list of [`Embedding`] values.
///
/// NOTE(review): `compute_embeddings_http` in this file responds with
/// `proto::ComputeEmbeddingsResponse` rather than this type — confirm which
/// of the two is intended to be the wire format.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct ComputeEmbeddingsResponse {
/// The computed embeddings.
/// NOTE(review): whether this list is ordered like the request's texts is
/// not established by this declaration — confirm against the producer.
pub embeddings: Vec<Embedding>,
}
/// A single embedding: a byte digest paired with a vector of `f32` values.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct Embedding {
/// Digest bytes identifying the embedded text.
/// NOTE(review): presumably the SHA-256 of the text, as
/// `compute_embeddings_http` computes for `proto::Embedding` — confirm.
pub digest: Vec<u8>,
/// The embedding vector's components (one `f32` per dimension).
pub dimensions: Vec<f32>,
}
use db::{usage_measure::UsageMeasure, ActiveUserCount, LlmDatabase};
use futures::{Stream, StreamExt as _};
use reqwest_client::ReqwestClient;
@@ -28,6 +46,7 @@ use rpc::{
proto::Plan, LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME,
};
use rpc::{ListModelsResponse, MAX_LLM_MONTHLY_SPEND_REACHED_HEADER_NAME};
use sha2::Sha256;
use std::{
pin::Pin,
sync::Arc,
@@ -113,10 +132,79 @@ impl LlmState {
}
}
/// HTTP handler for `POST /compute_embeddings`.
///
/// Applies the caller's plan-specific rate limit, asks OpenAI to embed every
/// text in the request, and responds with one `proto::Embedding` per distinct
/// text, each identified by the SHA-256 digest of that text. Because the
/// results pass through a map keyed by digest, identical texts collapse into a
/// single entry and the response order is the map's iteration order, not the
/// request order; callers should match results by digest.
///
/// The digest/vector pairs are also written to the database. That write is
/// best-effort: a failure is traced and the response is still returned.
///
/// # Errors
///
/// Returns an error when no OpenAI API key is configured, when the rate-limit
/// check fails, when `request.model` is not a supported embedding model, when
/// the OpenAI request fails, or when the provider returns a different number
/// of embeddings than texts were submitted.
async fn compute_embeddings_http(
    Extension(state): Extension<Arc<LlmState>>,
    Extension(claims): Extension<LlmTokenClaims>,
    Json(request): Json<proto::ComputeEmbeddings>,
) -> Result<impl IntoResponse> {
    let api_key = state
        .config
        .openai_api_key
        .as_ref()
        .context("no OpenAI API key configured on the server")?;

    // Each plan has its own embeddings rate limit.
    let rate_limit: Box<dyn RateLimit> = match claims.plan {
        proto::Plan::ZedPro => Box::new(ZedProComputeEmbeddingsRateLimit),
        proto::Plan::Free => Box::new(FreeComputeEmbeddingsRateLimit),
    };

    state
        .app_state
        .rate_limiter
        .check(&*rate_limit, UserId::from_proto(claims.user_id))
        .await?;

    let embeddings = match request.model.as_str() {
        "openai/text-embedding-3-small" => {
            open_ai::embed(
                &state.http_client,
                OPEN_AI_API_URL,
                api_key,
                OpenAiEmbeddingModel::TextEmbedding3Small,
                request.texts.iter().map(|text| text.as_str()),
            )
            .await?
        }
        provider => return Err(anyhow!("unsupported embedding provider {:?}", provider))?,
    };

    // `zip` below stops at the shorter side, so a short (or long) provider
    // response would otherwise be silently truncated and then persisted and
    // returned as if it were complete. Reject the mismatch explicitly.
    if embeddings.data.len() != request.texts.len() {
        return Err(anyhow!(
            "embedding provider returned {} embeddings for {} texts",
            embeddings.data.len(),
            request.texts.len()
        )
        .into());
    }

    // Pair the SHA-256 digest of each input text with the vector returned at
    // the same position.
    let embeddings = request
        .texts
        .iter()
        .map(|text| {
            let mut hasher = Sha256::new();
            hasher.update(text.as_bytes());
            hasher.finalize().to_vec()
        })
        .zip(
            embeddings
                .data
                .into_iter()
                .map(|embedding| embedding.embedding),
        )
        .collect::<HashMap<_, _>>();

    // Persisting is best-effort: trace a failure but still answer the caller.
    state
        .db
        .save_embeddings(&request.model, &embeddings)
        .await
        .context("failed to save embeddings")
        .trace_err();

    Ok(Json(proto::ComputeEmbeddingsResponse {
        embeddings: embeddings
            .into_iter()
            .map(|(digest, dimensions)| proto::Embedding { digest, dimensions })
            .collect(),
    }))
}
/// Builds the router for the LLM service's HTTP endpoints.
///
/// Registers `GET /models`, `POST /completion` and
/// `POST /compute_embeddings`, then wraps them in the `validate_api_token`
/// middleware. The layer is added after the routes so that it covers all of
/// them.
pub fn routes() -> Router<(), Body> {
    let endpoints = Router::new()
        .route("/models", get(list_models))
        .route("/completion", post(perform_completion))
        .route("/compute_embeddings", post(compute_embeddings_http));

    let token_validation = middleware::from_fn(validate_api_token);
    endpoints.layer(token_validation)
}