Compare commits
5 Commits
remote-pro ... simplify-e
| Author | SHA1 | Date |
|---|---|---|
|  | 5b74ddf020 |  |
|  | 54f20ae5d5 |  |
|  | 811efa45d0 |  |
|  | 74501e0936 |  |
|  | 2ad8bd00ce |  |
@@ -21,9 +21,9 @@ use gpui::{
 use indoc::indoc;
 use language_model::{
     LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
-    LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
-    LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
-    LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+    LanguageModelToolResult, LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent,
+    Role, StopReason, fake_provider::FakeLanguageModel,
 };
 use pretty_assertions::assert_eq;
 use project::{
@@ -664,9 +664,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
     );

     // Simulate reaching tool use limit.
-    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
-        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
-    ));
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
     fake_model.end_last_completion_stream();
     let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
     assert!(
@@ -749,9 +747,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
     };
     fake_model
         .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
-    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
-        cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
-    ));
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
     fake_model.end_last_completion_stream();
     let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
     assert!(
@@ -1533,7 +1529,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
     });

     fake_model.send_last_completion_stream_text_chunk("Hey!");
-    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
         language_model::TokenUsage {
             input_tokens: 32_000,
             output_tokens: 16_000,
@@ -1591,7 +1587,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
     });
     cx.run_until_parked();
     fake_model.send_last_completion_stream_text_chunk("Ahoy!");
-    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
         language_model::TokenUsage {
             input_tokens: 40_000,
             output_tokens: 20_000,
@@ -1636,7 +1632,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
         .unwrap();
     cx.run_until_parked();
     fake_model.send_last_completion_stream_text_chunk("Message 1 response");
-    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
         language_model::TokenUsage {
             input_tokens: 32_000,
             output_tokens: 16_000,
@@ -1683,7 +1679,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
     cx.run_until_parked();

     fake_model.send_last_completion_stream_text_chunk("Message 2 response");
-    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
+    fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
         language_model::TokenUsage {
             input_tokens: 40_000,
             output_tokens: 20_000,
@@ -2159,7 +2155,6 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {

     fake_model.send_last_completion_stream_text_chunk("Hey,");
     fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
-        provider: LanguageModelProviderName::new("Anthropic"),
         retry_after: Some(Duration::from_secs(3)),
     });
     fake_model.end_last_completion_stream();
@@ -2235,7 +2230,6 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
         tool_use_1.clone(),
     ));
     fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
-        provider: LanguageModelProviderName::new("Anthropic"),
         retry_after: Some(Duration::from_secs(3)),
     });
     fake_model.end_last_completion_stream();
@@ -2302,7 +2296,6 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
     for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
         fake_model.send_last_completion_stream_error(
             LanguageModelCompletionError::ServerOverloaded {
-                provider: LanguageModelProviderName::new("Anthropic"),
                 retry_after: Some(Duration::from_secs(3)),
             },
         );
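All of the test hunks above make the same two substitutions: statuses that used to arrive wrapped in `LanguageModelCompletionEvent::StatusUpdate(...)` are now emitted as first-class events (`ToolUseLimitReached`, `TokenUsage`), and retry errors such as `ServerOverloaded` no longer carry a `provider` field. A minimal, self-contained sketch of the flattened event flow; the types here are hypothetical stand-ins, not zed's real `FakeLanguageModel` API:

```rust
// Stand-in event enum mirroring the flattened shape the tests exercise.
#[derive(Debug, PartialEq)]
enum Event {
    Text(String),
    // Previously wrapped: StatusUpdate(CompletionRequestStatus::ToolUseLimitReached).
    ToolUseLimitReached,
}

fn main() {
    // The fake stream now pushes the limit notification directly.
    let events = vec![Event::Text("Hey!".into()), Event::ToolUseLimitReached];
    assert_eq!(events.last(), Some(&Event::ToolUseLimitReached));
}
```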
@@ -15,7 +15,7 @@ use agent_settings::{
 use anyhow::{Context as _, Result, anyhow};
 use chrono::{DateTime, Utc};
 use client::{ModelRequestUsage, RequestUsage, UserStore};
-use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
+use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
 use collections::{HashMap, HashSet, IndexMap};
 use fs::Fs;
 use futures::stream;
@@ -30,11 +30,11 @@ use gpui::{
 };
 use language_model::{
     LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
-    LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry,
-    LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
-    LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
-    LanguageModelToolUse, LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
-    ZED_CLOUD_PROVIDER_ID,
+    LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelProviderName,
+    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
+    LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
+    LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role,
+    SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID,
 };
 use project::Project;
 use prompt_store::ProjectContext;
@@ -1295,9 +1295,10 @@ impl Thread {

                 if let Some(error) = error {
                     attempt += 1;
+                    let provider = model.upstream_provider_name();
                     let retry = this.update(cx, |this, cx| {
                         let user_store = this.user_store.read(cx);
-                        this.handle_completion_error(error, attempt, user_store.plan())
+                        this.handle_completion_error(provider, error, attempt, user_store.plan())
                     })??;
                     let timer = cx.background_executor().timer(retry.duration);
                     event_stream.send_retry(retry);
@@ -1323,6 +1324,7 @@ impl Thread {

     fn handle_completion_error(
         &mut self,
+        provider: LanguageModelProviderName,
         error: LanguageModelCompletionError,
         attempt: u8,
         plan: Option<Plan>,
@@ -1389,7 +1391,7 @@ impl Thread {
         use LanguageModelCompletionEvent::*;

         match event {
-            StartMessage { .. } => {
+            LanguageModelCompletionEvent::StartMessage { .. } => {
                 self.flush_pending_message(cx);
                 self.pending_message = Some(AgentMessage::default());
             }
@@ -1416,7 +1418,7 @@ impl Thread {
                     ),
                 )));
             }
-            UsageUpdate(usage) => {
+            TokenUsage(usage) => {
                 telemetry::event!(
                     "Agent Thread Completion Usage Updated",
                     thread_id = self.id.to_string(),
@@ -1430,20 +1432,16 @@ impl Thread {
                 );
                 self.update_token_usage(usage, cx);
             }
-            StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
+            RequestUsage { amount, limit } => {
                 self.update_model_request_usage(amount, limit, cx);
             }
-            StatusUpdate(
-                CompletionRequestStatus::Started
-                | CompletionRequestStatus::Queued { .. }
-                | CompletionRequestStatus::Failed { .. },
-            ) => {}
-            StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
+            ToolUseLimitReached => {
                 self.tool_use_limit_reached = true;
             }
             Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
             Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
             Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
+            Started | Queued { .. } => {}
         }

         Ok(None)
@@ -1687,9 +1685,7 @@ impl Thread {
             let event = event.log_err()?;
             let text = match event {
                 LanguageModelCompletionEvent::Text(text) => text,
-                LanguageModelCompletionEvent::StatusUpdate(
-                    CompletionRequestStatus::UsageUpdated { amount, limit },
-                ) => {
+                LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
                     this.update(cx, |thread, cx| {
                         thread.update_model_request_usage(amount, limit, cx);
                     })
@@ -1753,9 +1749,7 @@ impl Thread {
             let event = event?;
             let text = match event {
                 LanguageModelCompletionEvent::Text(text) => text,
-                LanguageModelCompletionEvent::StatusUpdate(
-                    CompletionRequestStatus::UsageUpdated { amount, limit },
-                ) => {
+                LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
                     this.update(cx, |thread, cx| {
                         thread.update_model_request_usage(amount, limit, cx);
                     })?;
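The `Thread` hunks show the other half of the provider-field removal: the thread looks up the provider name once (`model.upstream_provider_name()`) and passes it alongside the error, instead of reading it back out of the error. A small runnable sketch of that call shape, using hypothetical stand-in types rather than zed's real ones:

```rust
use std::time::Duration;

// Stand-ins: errors no longer carry a provider name, so callers pair the
// error with the provider they already know about.
#[derive(Debug)]
enum CompletionError {
    ServerOverloaded { retry_after: Option<Duration> },
}

struct ProviderName(&'static str);

fn handle_completion_error(provider: ProviderName, error: CompletionError) -> String {
    match error {
        CompletionError::ServerOverloaded { retry_after } => format!(
            "{}'s servers are overloaded (retry in {:?})",
            provider.0, retry_after
        ),
    }
}

fn main() {
    let msg = handle_completion_error(
        ProviderName("Anthropic"),
        CompletionError::ServerOverloaded {
            retry_after: Some(Duration::from_secs(3)),
        },
    );
    println!("{msg}");
}
```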
@@ -7,9 +7,10 @@ use assistant_slash_command::{
 use assistant_slash_commands::FileCommandMetadata;
 use client::{self, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry};
 use clock::ReplicaId;
-use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
+use cloud_llm_client::{CompletionIntent, UsageLimit};
 use collections::{HashMap, HashSet};
 use fs::{Fs, RenameOptions};
+
 use futures::{FutureExt, StreamExt, future::Shared};
 use gpui::{
     App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription,
@@ -2073,14 +2074,15 @@ impl TextThread {
                     });

                     match event {
-                        LanguageModelCompletionEvent::StatusUpdate(status_update) => {
-                            if let CompletionRequestStatus::UsageUpdated { amount, limit } = status_update {
-                                this.update_model_request_usage(
-                                    amount as u32,
-                                    limit,
-                                    cx,
-                                );
-                            }
-                        }
+                        LanguageModelCompletionEvent::Started |
+                        LanguageModelCompletionEvent::Queued {..} |
+                        LanguageModelCompletionEvent::ToolUseLimitReached { .. } => {}
+                        LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
+                            this.update_model_request_usage(
+                                amount as u32,
+                                limit,
+                                cx,
+                            );
+                        }
                         LanguageModelCompletionEvent::StartMessage { .. } => {}
                         LanguageModelCompletionEvent::Stop(reason) => {
@@ -2142,7 +2144,7 @@ impl TextThread {
                         }
                         LanguageModelCompletionEvent::ToolUse(_) |
                         LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
-                        LanguageModelCompletionEvent::UsageUpdate(_) => {}
+                        LanguageModelCompletionEvent::TokenUsage(_) => {}
                     }
                 });
@@ -1250,9 +1250,12 @@ pub fn response_events_to_markdown(
             ));
         }
         Ok(
-            LanguageModelCompletionEvent::UsageUpdate(_)
+            LanguageModelCompletionEvent::TokenUsage(_)
             | LanguageModelCompletionEvent::ToolUseLimitReached
             | LanguageModelCompletionEvent::StartMessage { .. }
-            | LanguageModelCompletionEvent::StatusUpdate { .. },
+            | LanguageModelCompletionEvent::RequestUsage { .. }
+            | LanguageModelCompletionEvent::Queued { .. }
+            | LanguageModelCompletionEvent::Started,
         ) => {}
         Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
             json_parse_error, ..
@@ -1335,11 +1338,14 @@ impl ThreadDialog {
                     }

                     // Skip these
-                    Ok(LanguageModelCompletionEvent::UsageUpdate(_))
+                    Ok(LanguageModelCompletionEvent::TokenUsage(_))
                     | Ok(LanguageModelCompletionEvent::RedactedThinking { .. })
-                    | Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
                     | Ok(LanguageModelCompletionEvent::StartMessage { .. })
-                    | Ok(LanguageModelCompletionEvent::Stop(_)) => {}
+                    | Ok(LanguageModelCompletionEvent::Stop(_))
+                    | Ok(LanguageModelCompletionEvent::Queued { .. })
+                    | Ok(LanguageModelCompletionEvent::Started)
+                    | Ok(LanguageModelCompletionEvent::RequestUsage { .. })
+                    | Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => {}

                     Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                         json_parse_error,
@@ -12,7 +12,7 @@ pub mod fake_provider;
 use anthropic::{AnthropicError, parse_prompt_too_long};
 use anyhow::{Result, anyhow};
 use client::Client;
-use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
+use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
 use futures::FutureExt;
 use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
 use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
@@ -70,7 +70,15 @@ pub fn init_settings(cx: &mut App) {
 /// A completion event from a language model.
 #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
 pub enum LanguageModelCompletionEvent {
-    StatusUpdate(CompletionRequestStatus),
+    Queued {
+        position: usize,
+    },
+    Started,
+    RequestUsage {
+        amount: usize,
+        limit: UsageLimit,
+    },
+    ToolUseLimitReached,
     Stop(StopReason),
     Text(String),
     Thinking {
@@ -90,88 +98,93 @@ pub enum LanguageModelCompletionEvent {
     StartMessage {
         message_id: String,
     },
-    UsageUpdate(TokenUsage),
+    TokenUsage(TokenUsage),
 }

+impl LanguageModelCompletionEvent {
+    pub fn from_completion_request_status(
+        status: CompletionRequestStatus,
+    ) -> Result<Self, LanguageModelCompletionError> {
+        match status {
+            CompletionRequestStatus::Queued { position } => {
+                Ok(LanguageModelCompletionEvent::Queued { position })
+            }
+            CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
+            CompletionRequestStatus::UsageUpdated { amount, limit } => {
+                Ok(LanguageModelCompletionEvent::RequestUsage { amount, limit })
+            }
+            CompletionRequestStatus::ToolUseLimitReached => {
+                Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
+            }
+            CompletionRequestStatus::Failed {
+                code,
+                message,
+                request_id: _,
+                retry_after,
+            } => Err(LanguageModelCompletionError::from_cloud_failure(
+                code,
+                message,
+                retry_after.map(Duration::from_secs_f64),
+            )),
+        }
+    }
+}
+
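`from_completion_request_status` is the seam between the wire protocol (`CompletionRequestStatus`) and the flattened event enum. A simplified, runnable mirror of that mapping, using stand-in types; the real function also converts the `Failed` status into a `LanguageModelCompletionError`:

```rust
// Stand-ins for the wire status and the flattened event.
#[derive(Debug)]
enum Status {
    Queued { position: usize },
    Started,
    ToolUseLimitReached,
}

#[derive(Debug)]
enum Event {
    Queued { position: usize },
    Started,
    ToolUseLimitReached,
}

// Each status maps one-to-one onto a first-class event.
fn from_status(status: Status) -> Event {
    match status {
        Status::Queued { position } => Event::Queued { position },
        Status::Started => Event::Started,
        Status::ToolUseLimitReached => Event::ToolUseLimitReached,
    }
}

fn main() {
    println!("{:?}", from_status(Status::Queued { position: 2 }));
}
```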
 #[derive(Error, Debug)]
 pub enum LanguageModelCompletionError {
     #[error("prompt too large for context window")]
     PromptTooLarge { tokens: Option<u64> },
-    #[error("missing {provider} API key")]
-    NoApiKey { provider: LanguageModelProviderName },
-    #[error("{provider}'s API rate limit exceeded")]
-    RateLimitExceeded {
-        provider: LanguageModelProviderName,
-        retry_after: Option<Duration>,
-    },
-    #[error("{provider}'s API servers are overloaded right now")]
-    ServerOverloaded {
-        provider: LanguageModelProviderName,
-        retry_after: Option<Duration>,
-    },
-    #[error("{provider}'s API server reported an internal server error: {message}")]
-    ApiInternalServerError {
-        provider: LanguageModelProviderName,
-        message: String,
-    },
+    #[error("missing API key")]
+    NoApiKey,
+    #[error("API rate limit exceeded")]
+    RateLimitExceeded { retry_after: Option<Duration> },
+    #[error("API servers are overloaded right now")]
+    ServerOverloaded { retry_after: Option<Duration> },
+    #[error("API server reported an internal server error: {message}")]
+    ApiInternalServerError { message: String },
     #[error("{message}")]
     UpstreamProviderError {
         message: String,
         status: StatusCode,
         retry_after: Option<Duration>,
     },
-    #[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
+    #[error("HTTP response error from API: status {status_code} - {message:?}")]
     HttpResponseError {
-        provider: LanguageModelProviderName,
         status_code: StatusCode,
         message: String,
     },

     // Client errors
-    #[error("invalid request format to {provider}'s API: {message}")]
-    BadRequestFormat {
-        provider: LanguageModelProviderName,
-        message: String,
-    },
-    #[error("authentication error with {provider}'s API: {message}")]
-    AuthenticationError {
-        provider: LanguageModelProviderName,
-        message: String,
-    },
-    #[error("Permission error with {provider}'s API: {message}")]
-    PermissionError {
-        provider: LanguageModelProviderName,
-        message: String,
-    },
+    #[error("invalid request format to API: {message}")]
+    BadRequestFormat { message: String },
+    #[error("authentication error with API: {message}")]
+    AuthenticationError { message: String },
+    #[error("Permission error with API: {message}")]
+    PermissionError { message: String },
     #[error("language model provider API endpoint not found")]
-    ApiEndpointNotFound { provider: LanguageModelProviderName },
-    #[error("I/O error reading response from {provider}'s API")]
+    ApiEndpointNotFound,
+    #[error("I/O error reading response from API")]
     ApiReadResponseError {
-        provider: LanguageModelProviderName,
         #[source]
         error: io::Error,
     },
-    #[error("error serializing request to {provider} API")]
+    #[error("error serializing request to API")]
     SerializeRequest {
-        provider: LanguageModelProviderName,
         #[source]
         error: serde_json::Error,
     },
-    #[error("error building request body to {provider} API")]
+    #[error("error building request body to API")]
     BuildRequestBody {
-        provider: LanguageModelProviderName,
         #[source]
         error: http::Error,
     },
-    #[error("error sending HTTP request to {provider} API")]
+    #[error("error sending HTTP request to API")]
     HttpSend {
-        provider: LanguageModelProviderName,
         #[source]
         error: anyhow::Error,
     },
-    #[error("error deserializing {provider} API response")]
+    #[error("error deserializing API response")]
     DeserializeResponse {
-        provider: LanguageModelProviderName,
         #[source]
         error: serde_json::Error,
     },
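With the `provider` fields gone, the `#[error(...)]` strings above become provider-agnostic, and the commented-out `display_format` stub in the next hunk suggests provider-aware display is meant to move into one place that receives the provider as an argument. A speculative, self-contained sketch of that pattern with stand-in types:

```rust
use std::fmt;

// Stand-in error: Display carries only the generic text, and the caller
// prepends the provider name it already knows.
#[derive(Debug)]
enum Error {
    RateLimitExceeded,
    ServerOverloaded,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Error::RateLimitExceeded => write!(f, "API rate limit exceeded"),
            Error::ServerOverloaded => write!(f, "API servers are overloaded right now"),
        }
    }
}

// Hypothetical analogue of display_format(&self, provider).
fn display_with_provider(provider: &str, error: &Error) -> String {
    format!("{provider}: {error}")
}

fn main() {
    println!("{}", display_with_provider("Anthropic", &Error::ServerOverloaded));
}
```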
@@ -182,6 +195,72 @@ pub enum LanguageModelCompletionError {
 }

 impl LanguageModelCompletionError {
+    fn display_format(&self, provider: LanguageModelProviderName) {
+        // match self {
+        //     #[error("prompt too large for context window")]
+        //     PromptTooLarge { tokens: Option<u64> },
+        //     #[error("missing API key")]
+        //     NoApiKey,
+        //     #[error("API rate limit exceeded")]
+        //     RateLimitExceeded { retry_after: Option<Duration> },
+        //     #[error("API servers are overloaded right now")]
+        //     ServerOverloaded { retry_after: Option<Duration> },
+        //     #[error("API server reported an internal server error: {message}")]
+        //     ApiInternalServerError { message: String },
+        //     #[error("{message}")]
+        //     UpstreamProviderError {
+        //         message: String,
+        //         status: StatusCode,
+        //         retry_after: Option<Duration>,
+        //     },
+        //     #[error("HTTP response error from API: status {status_code} - {message:?}")]
+        //     HttpResponseError {
+        //         status_code: StatusCode,
+        //         message: String,
+        //     },
+
+        //     // Client errors
+        //     #[error("invalid request format to API: {message}")]
+        //     BadRequestFormat { message: String },
+        //     #[error("authentication error with API: {message}")]
+        //     AuthenticationError { message: String },
+        //     #[error("Permission error with API: {message}")]
+        //     PermissionError { message: String },
+        //     #[error("language model provider API endpoint not found")]
+        //     ApiEndpointNotFound,
+        //     #[error("I/O error reading response from API")]
+        //     ApiReadResponseError {
+        //         #[source]
+        //         error: io::Error,
+        //     },
+        //     #[error("error serializing request to API")]
+        //     SerializeRequest {
+        //         #[source]
+        //         error: serde_json::Error,
+        //     },
+        //     #[error("error building request body to API")]
+        //     BuildRequestBody {
+        //         #[source]
+        //         error: http::Error,
+        //     },
+        //     #[error("error sending HTTP request to API")]
+        //     HttpSend {
+        //         #[source]
+        //         error: anyhow::Error,
+        //     },
+        //     #[error("error deserializing API response")]
+        //     DeserializeResponse {
+        //         #[source]
+        //         error: serde_json::Error,
+        //     },
+
+        //     // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
+        //     #[error(transparent)]
+        //     Other(#[from] anyhow::Error),
+
+        // }
+    }
+
     fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
         let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
         let upstream_status = error_json
@@ -198,7 +277,6 @@ impl LanguageModelCompletionError {
     }

     pub fn from_cloud_failure(
-        upstream_provider: LanguageModelProviderName,
         code: String,
         message: String,
         retry_after: Option<Duration>,
@@ -214,58 +292,46 @@ impl LanguageModelCompletionError {
             if let Some((upstream_status, inner_message)) =
                 Self::parse_upstream_error_json(&message)
             {
-                return Self::from_http_status(
-                    upstream_provider,
-                    upstream_status,
-                    inner_message,
-                    retry_after,
-                );
+                return Self::from_http_status(upstream_status, inner_message, retry_after);
             }
-            anyhow!("completion request failed, code: {code}, message: {message}").into()
+            Self::Other(anyhow!(
+                "completion request failed, code: {code}, message: {message}"
+            ))
         } else if let Some(status_code) = code
             .strip_prefix("upstream_http_")
             .and_then(|code| StatusCode::from_str(code).ok())
         {
-            Self::from_http_status(upstream_provider, status_code, message, retry_after)
+            Self::from_http_status(status_code, message, retry_after)
         } else if let Some(status_code) = code
             .strip_prefix("http_")
             .and_then(|code| StatusCode::from_str(code).ok())
         {
-            Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
+            Self::from_http_status(status_code, message, retry_after)
         } else {
-            anyhow!("completion request failed, code: {code}, message: {message}").into()
+            Self::Other(anyhow!(
+                "completion request failed, code: {code}, message: {message}"
+            ))
         }
     }

     pub fn from_http_status(
-        provider: LanguageModelProviderName,
         status_code: StatusCode,
         message: String,
         retry_after: Option<Duration>,
     ) -> Self {
         match status_code {
-            StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
-            StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
-            StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
-            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
+            StatusCode::BAD_REQUEST => Self::BadRequestFormat { message },
+            StatusCode::UNAUTHORIZED => Self::AuthenticationError { message },
+            StatusCode::FORBIDDEN => Self::PermissionError { message },
+            StatusCode::NOT_FOUND => Self::ApiEndpointNotFound,
             StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
                 tokens: parse_prompt_too_long(&message),
             },
-            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
-                provider,
-                retry_after,
-            },
-            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
-            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
-                provider,
-                retry_after,
-            },
-            _ if status_code.as_u16() == 529 => Self::ServerOverloaded {
-                provider,
-                retry_after,
-            },
+            StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded { retry_after },
+            StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { message },
+            StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded { retry_after },
+            _ if status_code.as_u16() == 529 => Self::ServerOverloaded { retry_after },
             _ => Self::HttpResponseError {
-                provider,
                 status_code,
                 message,
             },
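`from_http_status` is now a pure status-code-to-variant mapping with no provider parameter. A condensed, runnable sketch of the same decision table, using a hypothetical stand-in enum (529 is the non-standard status some providers return when overloaded):

```rust
// Stand-in for the simplified mapping; the real code matches http::StatusCode.
#[derive(Debug, PartialEq)]
enum Error {
    BadRequestFormat,
    RateLimitExceeded,
    ServerOverloaded,
    HttpResponseError(u16),
}

fn from_http_status(status: u16) -> Error {
    match status {
        400 => Error::BadRequestFormat,
        429 => Error::RateLimitExceeded,
        // 503 and the non-standard 529 both signal an overloaded server.
        503 | 529 => Error::ServerOverloaded,
        other => Error::HttpResponseError(other),
    }
}

fn main() {
    assert_eq!(from_http_status(529), Error::ServerOverloaded);
    println!("{:?}", from_http_status(418));
}
```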
@@ -275,31 +341,25 @@ impl LanguageModelCompletionError {

 impl From<AnthropicError> for LanguageModelCompletionError {
     fn from(error: AnthropicError) -> Self {
-        let provider = ANTHROPIC_PROVIDER_NAME;
         match error {
-            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
-            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
-            AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
-            AnthropicError::DeserializeResponse(error) => {
-                Self::DeserializeResponse { provider, error }
-            }
-            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
+            AnthropicError::SerializeRequest(error) => Self::SerializeRequest { error },
+            AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { error },
+            AnthropicError::HttpSend(error) => Self::HttpSend { error },
+            AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse { error },
+            AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { error },
             AnthropicError::HttpResponseError {
                 status_code,
                 message,
             } => Self::HttpResponseError {
-                provider,
                 status_code,
                 message,
             },
             AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
-                provider,
                 retry_after: Some(retry_after),
             },
-            AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
-                provider,
-                retry_after,
-            },
+            AnthropicError::ServerOverloaded { retry_after } => {
+                Self::ServerOverloaded { retry_after }
+            }
             AnthropicError::ApiError(api_error) => api_error.into(),
         }
     }
@@ -308,37 +368,26 @@ impl From<AnthropicError> for LanguageModelCompletionError {
 impl From<anthropic::ApiError> for LanguageModelCompletionError {
     fn from(error: anthropic::ApiError) -> Self {
         use anthropic::ApiErrorCode::*;
-        let provider = ANTHROPIC_PROVIDER_NAME;
         match error.code() {
             Some(code) => match code {
                 InvalidRequestError => Self::BadRequestFormat {
-                    provider,
                     message: error.message,
                 },
                 AuthenticationError => Self::AuthenticationError {
-                    provider,
                     message: error.message,
                 },
                 PermissionError => Self::PermissionError {
-                    provider,
                     message: error.message,
                 },
-                NotFoundError => Self::ApiEndpointNotFound { provider },
+                NotFoundError => Self::ApiEndpointNotFound,
                 RequestTooLarge => Self::PromptTooLarge {
                     tokens: parse_prompt_too_long(&error.message),
                 },
-                RateLimitError => Self::RateLimitExceeded {
-                    provider,
-                    retry_after: None,
-                },
+                RateLimitError => Self::RateLimitExceeded { retry_after: None },
                 ApiError => Self::ApiInternalServerError {
-                    provider,
                     message: error.message,
                 },
-                OverloadedError => Self::ServerOverloaded {
-                    provider,
-                    retry_after: None,
-                },
+                OverloadedError => Self::ServerOverloaded { retry_after: None },
             },
             None => Self::Other(error.into()),
         }
@@ -349,7 +398,7 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
     fn from(error: open_ai::RequestError) -> Self {
         match error {
             open_ai::RequestError::HttpResponseError {
-                provider,
+                provider: _,
                 status_code,
                 body,
                 headers,
@@ -359,7 +408,7 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
                     .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
                     .map(Duration::from_secs);

-                Self::from_http_status(provider.into(), status_code, body, retry_after)
+                Self::from_http_status(status_code, body, retry_after)
             }
             open_ai::RequestError::Other(e) => Self::Other(e),
         }
@@ -368,23 +417,18 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {

 impl From<OpenRouterError> for LanguageModelCompletionError {
     fn from(error: OpenRouterError) -> Self {
-        let provider = LanguageModelProviderName::new("OpenRouter");
         match error {
-            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
-            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
-            OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
-            OpenRouterError::DeserializeResponse(error) => {
-                Self::DeserializeResponse { provider, error }
-            }
-            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
+            OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { error },
+            OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { error },
+            OpenRouterError::HttpSend(error) => Self::HttpSend { error },
+            OpenRouterError::DeserializeResponse(error) => Self::DeserializeResponse { error },
+            OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { error },
             OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
-                provider,
                 retry_after: Some(retry_after),
             },
-            OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
-                provider,
-                retry_after,
-            },
+            OpenRouterError::ServerOverloaded { retry_after } => {
+                Self::ServerOverloaded { retry_after }
+            }
             OpenRouterError::ApiError(api_error) => api_error.into(),
         }
     }
@@ -393,41 +437,28 @@ impl From<OpenRouterError> for LanguageModelCompletionError {
 impl From<open_router::ApiError> for LanguageModelCompletionError {
     fn from(error: open_router::ApiError) -> Self {
         use open_router::ApiErrorCode::*;
-        let provider = LanguageModelProviderName::new("OpenRouter");
         match error.code {
             InvalidRequestError => Self::BadRequestFormat {
-                provider,
                 message: error.message,
             },
             AuthenticationError => Self::AuthenticationError {
-                provider,
                 message: error.message,
             },
             PaymentRequiredError => Self::AuthenticationError {
-                provider,
                 message: format!("Payment required: {}", error.message),
             },
             PermissionError => Self::PermissionError {
-                provider,
                 message: error.message,
             },
             RequestTimedOut => Self::HttpResponseError {
-                provider,
                 status_code: StatusCode::REQUEST_TIMEOUT,
                 message: error.message,
             },
-            RateLimitError => Self::RateLimitExceeded {
-                provider,
-                retry_after: None,
-            },
+            RateLimitError => Self::RateLimitExceeded { retry_after: None },
             ApiError => Self::ApiInternalServerError {
-                provider,
                 message: error.message,
             },
-            OverloadedError => Self::ServerOverloaded {
-                provider,
-                retry_after: None,
-            },
+            OverloadedError => Self::ServerOverloaded { retry_after: None },
         }
     }
 }
@@ -633,7 +664,10 @@ pub trait LanguageModel: Send + Sync {
             let last_token_usage = last_token_usage.clone();
             async move {
                 match result {
-                    Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
+                    Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
+                    Ok(LanguageModelCompletionEvent::Started) => None,
+                    Ok(LanguageModelCompletionEvent::RequestUsage { .. }) => None,
+                    Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
                     Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
                     Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
                     Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
@@ -643,7 +677,7 @@ pub trait LanguageModel: Send + Sync {
                     Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                         ..
                     }) => None,
-                    Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
+                    Ok(LanguageModelCompletionEvent::TokenUsage(token_usage)) => {
                         *last_token_usage.lock() = token_usage;
                         None
                     }
@@ -832,16 +866,13 @@ mod tests {
     #[test]
     fn test_from_cloud_failure_with_upstream_http_error() {
         let error = LanguageModelCompletionError::from_cloud_failure(
-            String::from("anthropic").into(),
             "upstream_http_error".to_string(),
             r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
             None,
         );

         match error {
-            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
-                assert_eq!(provider.0, "anthropic");
-            }
+            LanguageModelCompletionError::ServerOverloaded { .. } => {}
             _ => panic!(
                 "Expected ServerOverloaded error for 503 status, got: {:?}",
                 error
@@ -849,15 +880,13 @@ mod tests {
         }

         let error = LanguageModelCompletionError::from_cloud_failure(
-            String::from("anthropic").into(),
             "upstream_http_error".to_string(),
             r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
             None,
         );

         match error {
-            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
-                assert_eq!(provider.0, "anthropic");
+            LanguageModelCompletionError::ApiInternalServerError { message } => {
                 assert_eq!(message, "Internal server error");
             }
             _ => panic!(
@@ -870,16 +899,13 @@ mod tests {
     #[test]
     fn test_from_cloud_failure_with_standard_format() {
         let error = LanguageModelCompletionError::from_cloud_failure(
-            String::from("anthropic").into(),
             "upstream_http_503".to_string(),
             "Service unavailable".to_string(),
             None,
         );

         match error {
-            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
-                assert_eq!(provider.0, "anthropic");
-            }
+            LanguageModelCompletionError::ServerOverloaded { .. } => {}
             _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
         }
     }
@@ -887,16 +913,13 @@ mod tests {
     #[test]
     fn test_upstream_http_error_connection_timeout() {
         let error = LanguageModelCompletionError::from_cloud_failure(
-            String::from("anthropic").into(),
             "upstream_http_error".to_string(),
             r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
             None,
         );

         match error {
-            LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
-                assert_eq!(provider.0, "anthropic");
-            }
+            LanguageModelCompletionError::ServerOverloaded { .. } => {}
             _ => panic!(
                 "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
                 error
@@ -904,15 +927,13 @@ mod tests {
         }

         let error = LanguageModelCompletionError::from_cloud_failure(
-            String::from("anthropic").into(),
             "upstream_http_error".to_string(),
             r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
             None,
         );

         match error {
-            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
-                assert_eq!(provider.0, "anthropic");
+            LanguageModelCompletionError::ApiInternalServerError { message } => {
                 assert_eq!(
                     message,
                     "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
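The remaining hunks in this file shrink every `From<...>` impl the same way: provider-specific error types map onto the shared enum without tagging a provider. A self-contained sketch of the pattern with hypothetical stand-in types:

```rust
use std::time::Duration;

// Stand-in for a provider crate's error type.
#[derive(Debug)]
enum UpstreamError {
    RateLimit { retry_after: Duration },
}

// Stand-in for the shared, provider-agnostic error enum.
#[derive(Debug)]
enum SharedError {
    RateLimitExceeded { retry_after: Option<Duration> },
}

impl From<UpstreamError> for SharedError {
    fn from(error: UpstreamError) -> Self {
        match error {
            // No provider field to fill in; the variant carries only the data.
            UpstreamError::RateLimit { retry_after } => SharedError::RateLimitExceeded {
                retry_after: Some(retry_after),
            },
        }
    }
}

fn main() {
    let e: SharedError = UpstreamError::RateLimit {
        retry_after: Duration::from_secs(3),
    }
    .into();
    println!("{e:?}");
}
```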
@@ -320,9 +320,7 @@ impl AnthropicModel {

         async move {
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
+                return Err(LanguageModelCompletionError::NoApiKey);
             };
             let request = anthropic::stream_completion(
                 http_client.as_ref(),
@@ -756,7 +754,7 @@ impl AnthropicEventMapper {
             Event::MessageStart { message } => {
                 update_usage(&mut self.usage, &message.usage);
                 vec![
-                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
+                    Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
                         &self.usage,
                     ))),
                     Ok(LanguageModelCompletionEvent::StartMessage {
@@ -778,9 +776,9 @@ impl AnthropicEventMapper {
                    }
                };
            }
-                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
-                    convert_usage(&self.usage),
-                ))]
+                vec![Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
+                    &self.usage,
+                )))]
             }
             Event::MessageStop => {
                 vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
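The Anthropic mapper shows the renamed event's role: usage deltas are accumulated across stream events and re-emitted as `TokenUsage` payloads. A stand-in sketch of that accumulation (simplified; the real `TokenUsage` also tracks cache-creation and cache-read token counts):

```rust
// Stand-in for the usage accumulator kept on the event mapper.
#[derive(Debug, Default, Clone, Copy)]
struct TokenUsage {
    input_tokens: u64,
    output_tokens: u64,
}

// Analogue of update_usage: fold each delta into the running total.
fn update_usage(total: &mut TokenUsage, delta: TokenUsage) {
    total.input_tokens += delta.input_tokens;
    total.output_tokens += delta.output_tokens;
}

fn main() {
    let mut usage = TokenUsage::default();
    update_usage(&mut usage, TokenUsage { input_tokens: 5, output_tokens: 0 });
    update_usage(&mut usage, TokenUsage { input_tokens: 0, output_tokens: 3 });
    // The mapper would emit this as LanguageModelCompletionEvent::TokenUsage(usage).
    println!("TokenUsage event payload: {usage:?}");
}
```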
@@ -975,7 +975,7 @@ pub fn map_to_language_model_completion_events(
                 ))
             }),
             ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
-                Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
                     input_tokens: metadata.input_tokens as u64,
                     output_tokens: metadata.output_tokens as u64,
                     cache_creation_input_tokens: metadata
@@ -541,7 +541,6 @@ impl From<ApiError> for LanguageModelCompletionError {
                 }

                 return LanguageModelCompletionError::from_http_status(
-                    PROVIDER_NAME,
                     error.status,
                     cloud_error.message,
                     None,
@@ -549,12 +548,7 @@ impl From<ApiError> for LanguageModelCompletionError {
         }

         let retry_after = None;
-        LanguageModelCompletionError::from_http_status(
-            PROVIDER_NAME,
-            error.status,
-            error.body,
-            retry_after,
-        )
+        LanguageModelCompletionError::from_http_status(error.status, error.body, retry_after)
     }
 }

@@ -961,7 +955,7 @@ where
                 vec![Err(LanguageModelCompletionError::from(error))]
             }
             Ok(CompletionEvent::Status(event)) => {
-                vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
+                vec![LanguageModelCompletionEvent::from_completion_request_status(event)]
             }
             Ok(CompletionEvent::Event(event)) => map_callback(event),
         })
@@ -1313,8 +1307,7 @@ mod tests {
         let completion_error: LanguageModelCompletionError = api_error.into();

         match completion_error {
-            LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
-                assert_eq!(provider, PROVIDER_NAME);
+            LanguageModelCompletionError::ApiInternalServerError { message } => {
                 assert_eq!(message, "Regular internal server error");
             }
             _ => panic!(
@@ -1362,9 +1355,7 @@ mod tests {
         let completion_error: LanguageModelCompletionError = api_error.into();

         match completion_error {
-            LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
-                assert_eq!(provider, PROVIDER_NAME);
-            }
+            LanguageModelCompletionError::ApiInternalServerError { .. } => {}
             _ => panic!(
                 "Expected ApiInternalServerError for invalid JSON, got: {:?}",
                 completion_error
@@ -422,14 +422,12 @@ pub fn map_to_language_model_completion_events(
                     }

                     if let Some(usage) = event.usage {
-                        events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
-                            TokenUsage {
-                                input_tokens: usage.prompt_tokens,
-                                output_tokens: usage.completion_tokens,
-                                cache_creation_input_tokens: 0,
-                                cache_read_input_tokens: 0,
-                            },
-                        )));
+                        events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
+                            input_tokens: usage.prompt_tokens,
+                            output_tokens: usage.completion_tokens,
+                            cache_creation_input_tokens: 0,
+                            cache_read_input_tokens: 0,
+                        })));
                     }

                     match choice.finish_reason.as_deref() {
@@ -610,7 +608,7 @@ impl CopilotResponsesEventMapper {
             copilot::copilot_responses::StreamEvent::Completed { response } => {
                 let mut events = Vec::new();
                 if let Some(usage) = response.usage {
-                    events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                    events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
                         input_tokens: usage.input_tokens.unwrap_or(0),
                         output_tokens: usage.output_tokens.unwrap_or(0),
                         cache_creation_input_tokens: 0,
@@ -643,7 +641,7 @@ impl CopilotResponsesEventMapper {

                 let mut events = Vec::new();
                 if let Some(usage) = response.usage {
-                    events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                    events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
                         input_tokens: usage.input_tokens.unwrap_or(0),
                         output_tokens: usage.output_tokens.unwrap_or(0),
                         cache_creation_input_tokens: 0,
@@ -655,7 +653,6 @@ impl CopilotResponsesEventMapper {
             }

             copilot::copilot_responses::StreamEvent::Failed { response } => {
-                let provider = PROVIDER_NAME;
                 let (status_code, message) = match response.error {
                     Some(error) => {
                         let status_code = StatusCode::from_str(&error.code)
@@ -668,7 +665,6 @@ impl CopilotResponsesEventMapper {
                     ),
                 };
                 vec![Err(LanguageModelCompletionError::HttpResponseError {
-                    provider,
                     status_code,
                     message,
                 })]
@@ -1099,7 +1095,7 @@ mod tests {
         ));
         assert!(matches!(
             mapped[2],
-            LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+            LanguageModelCompletionEvent::TokenUsage(TokenUsage {
                 input_tokens: 5,
                 output_tokens: 3,
                 ..
@@ -1207,7 +1203,7 @@ mod tests {
         let mapped = map_events(events);
         assert!(matches!(
             mapped[0],
-            LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+            LanguageModelCompletionEvent::TokenUsage(TokenUsage {
                 input_tokens: 10,
                 output_tokens: 0,
                 ..
@@ -224,9 +224,7 @@ impl DeepSeekLanguageModel {

         let future = self.request_limiter.stream(async move {
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
+                return Err(LanguageModelCompletionError::NoApiKey);
             };
             let request =
                 deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
@@ -479,7 +477,7 @@ impl DeepSeekEventMapper {
         }

         if let Some(usage) = event.usage {
-            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+            events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
                 input_tokens: usage.prompt_tokens,
                 output_tokens: usage.completion_tokens,
                 cache_creation_input_tokens: 0,
@@ -350,10 +350,7 @@ impl LanguageModel for GoogleLanguageModel {

         async move {
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                }
-                .into());
+                return Err(LanguageModelCompletionError::NoApiKey.into());
             };
             let response = google_ai::count_tokens(
                 http_client.as_ref(),
@@ -608,9 +605,9 @@ impl GoogleEventMapper {
         let mut wants_to_use_tool = false;
         if let Some(usage_metadata) = event.usage_metadata {
             update_usage(&mut self.usage, &usage_metadata);
-            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
-                convert_usage(&self.usage),
-            )))
+            events.push(Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
+                &self.usage,
+            ))))
         }

         if let Some(prompt_feedback) = event.prompt_feedback
@@ -547,7 +547,7 @@ impl LmStudioEventMapper {
         }

         if let Some(usage) = event.usage {
-            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+            events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
                 input_tokens: usage.prompt_tokens,
                 output_tokens: usage.completion_tokens,
                 cache_creation_input_tokens: 0,
@@ -291,9 +291,7 @@ impl MistralLanguageModel {

         let future = self.request_limiter.stream(async move {
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
+                return Err(LanguageModelCompletionError::NoApiKey);
             };
             let request =
                 mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
@@ -672,7 +670,7 @@ impl MistralEventMapper {
         }

         if let Some(usage) = event.usage {
-            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+            events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
                 input_tokens: usage.prompt_tokens,
                 output_tokens: usage.completion_tokens,
                 cache_creation_input_tokens: 0,
@@ -603,7 +603,7 @@ fn map_to_language_model_completion_events(
                 };

                 if delta.done {
-                    events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+                    events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
                         input_tokens: delta.prompt_eval_count.unwrap_or(0),
                         output_tokens: delta.eval_count.unwrap_or(0),
                         cache_creation_input_tokens: 0,
@@ -228,7 +228,7 @@ impl OpenAiLanguageModel {
         let future = self.request_limiter.stream(async move {
             let provider = PROVIDER_NAME;
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey { provider });
+                return Err(LanguageModelCompletionError::NoApiKey);
             };
             let request = stream_completion(
                 http_client.as_ref(),
@@ -534,7 +534,7 @@ impl OpenAiEventMapper {
     ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
         let mut events = Vec::new();
         if let Some(usage) = event.usage {
-            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+            events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
                 input_tokens: usage.prompt_tokens,
                 output_tokens: usage.completion_tokens,
                 cache_creation_input_tokens: 0,
@@ -227,7 +227,7 @@ impl OpenAiCompatibleLanguageModel {
         let provider = self.provider_name.clone();
         let future = self.request_limiter.stream(async move {
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey { provider });
+                return Err(LanguageModelCompletionError::NoApiKey);
             };
             let request = stream_completion(
                 http_client.as_ref(),
@@ -84,9 +84,7 @@ impl State {
         let http_client = self.http_client.clone();
         let api_url = OpenRouterLanguageModelProvider::api_url(cx);
         let Some(api_key) = self.api_key_state.key(&api_url) else {
-            return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
-                provider: PROVIDER_NAME,
-            }));
+            return Task::ready(Err(LanguageModelCompletionError::NoApiKey));
         };
         cx.spawn(async move |this, cx| {
             let models = list_models(http_client.as_ref(), &api_url, &api_key)
@@ -288,9 +286,7 @@ impl OpenRouterLanguageModel {

         async move {
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey {
-                    provider: PROVIDER_NAME,
-                });
+                return Err(LanguageModelCompletionError::NoApiKey);
             };
             let request =
                 open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
@@ -613,7 +609,7 @@ impl OpenRouterEventMapper {
         }

         if let Some(usage) = event.usage {
-            events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
+            events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
                 input_tokens: usage.prompt_tokens,
                 output_tokens: usage.completion_tokens,
                 cache_creation_input_tokens: 0,
@@ -222,7 +222,7 @@ impl VercelLanguageModel {
         let future = self.request_limiter.stream(async move {
             let provider = PROVIDER_NAME;
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey { provider });
+                return Err(LanguageModelCompletionError::NoApiKey);
             };
             let request = open_ai::stream_completion(
                 http_client.as_ref(),
@@ -230,7 +230,7 @@ impl XAiLanguageModel {
         let future = self.request_limiter.stream(async move {
             let provider = PROVIDER_NAME;
             let Some(api_key) = api_key else {
-                return Err(LanguageModelCompletionError::NoApiKey { provider });
+                return Err(LanguageModelCompletionError::NoApiKey);
             };
             let request = open_ai::stream_completion(
                 http_client.as_ref(),