Compare commits

...

5 Commits

Author SHA1 Message Date
Mikayla Maki
5b74ddf020 WIP: Remove provider name 2025-11-18 23:56:05 -08:00
Mikayla Maki
54f20ae5d5 Remove provider name 2025-11-18 22:13:04 -08:00
Mikayla Maki
811efa45d0 Disambiguate similar completion events 2025-11-18 21:04:35 -08:00
Mikayla Maki
74501e0936 Simplify LanguageModelCompletionEvent enum, remove
`StatusUpdate::Failed` state by turning it into an error early, and then
flatten `StatusUpdate` into LanguageModelCompletionEvent
2025-11-18 20:56:11 -08:00
Michael Benfield
2ad8bd00ce Simplifying errors
Co-authored-by: Mikayla <mikayla@zed.dev>
2025-11-18 20:45:22 -08:00
19 changed files with 263 additions and 273 deletions
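
Taken together, these commits drop the per-error provider name, flatten `StatusUpdate(CompletionRequestStatus::…)` into top-level event variants, and rename `UsageUpdate` to `TokenUsage`. Before the per-file diffs, here is a minimal illustrative sketch (not taken verbatim from the diff) of how a consumer might match on the flattened enum; the variant names and shapes mirror the `language_model` crate as it appears in this diff, while the handler bodies are placeholders.

```rust
// Illustrative only: matching on the flattened LanguageModelCompletionEvent.
// The former StatusUpdate(CompletionRequestStatus::…) wrapper becomes the top-level
// Queued, Started, RequestUsage, and ToolUseLimitReached variants, and
// UsageUpdate(TokenUsage) is renamed to TokenUsage(TokenUsage).
use language_model::LanguageModelCompletionEvent;

fn handle_event(event: LanguageModelCompletionEvent) {
    match event {
        LanguageModelCompletionEvent::Queued { position } => {
            // e.g. surface the queue position in the UI
            let _ = position;
        }
        LanguageModelCompletionEvent::Started => {}
        LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
            // plan/request usage, previously CompletionRequestStatus::UsageUpdated
            let _ = (amount, limit);
        }
        LanguageModelCompletionEvent::ToolUseLimitReached => {
            // previously CompletionRequestStatus::ToolUseLimitReached
        }
        LanguageModelCompletionEvent::TokenUsage(usage) => {
            // token accounting, previously UsageUpdate(TokenUsage)
            let _ = usage;
        }
        // Text, Thinking, ToolUse, StartMessage, Stop, etc. are unchanged by this refactor.
        _ => {}
    }
}
```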

View File

@@ -21,9 +21,9 @@ use gpui::{
use indoc::indoc;
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelToolResult, LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent,
Role, StopReason, fake_provider::FakeLanguageModel,
};
use pretty_assertions::assert_eq;
use project::{
@@ -664,9 +664,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
);
// Simulate reaching tool use limit.
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
fake_model.end_last_completion_stream();
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!(
@@ -749,9 +747,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
};
fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
fake_model.end_last_completion_stream();
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!(
@@ -1533,7 +1529,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
});
fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
language_model::TokenUsage {
input_tokens: 32_000,
output_tokens: 16_000,
@@ -1591,7 +1587,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
});
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
language_model::TokenUsage {
input_tokens: 40_000,
output_tokens: 20_000,
@@ -1636,7 +1632,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
.unwrap();
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Message 1 response");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
language_model::TokenUsage {
input_tokens: 32_000,
output_tokens: 16_000,
@@ -1683,7 +1679,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Message 2 response");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
language_model::TokenUsage {
input_tokens: 40_000,
output_tokens: 20_000,
@@ -2159,7 +2155,6 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
fake_model.send_last_completion_stream_text_chunk("Hey,");
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
});
fake_model.end_last_completion_stream();
@@ -2235,7 +2230,6 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
tool_use_1.clone(),
));
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
});
fake_model.end_last_completion_stream();
@@ -2302,7 +2296,6 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
fake_model.send_last_completion_stream_error(
LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)),
},
);

View File

@@ -15,7 +15,7 @@ use agent_settings::{
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage, UserStore};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
use collections::{HashMap, HashSet, IndexMap};
use fs::Fs;
use futures::stream;
@@ -30,11 +30,11 @@ use gpui::{
};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
LanguageModelToolUse, LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
ZED_CLOUD_PROVIDER_ID,
LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role,
SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID,
};
use project::Project;
use prompt_store::ProjectContext;
@@ -1295,9 +1295,10 @@ impl Thread {
if let Some(error) = error {
attempt += 1;
let provider = model.upstream_provider_name();
let retry = this.update(cx, |this, cx| {
let user_store = this.user_store.read(cx);
this.handle_completion_error(error, attempt, user_store.plan())
this.handle_completion_error(provider, error, attempt, user_store.plan())
})??;
let timer = cx.background_executor().timer(retry.duration);
event_stream.send_retry(retry);
@@ -1323,6 +1324,7 @@ impl Thread {
fn handle_completion_error(
&mut self,
provider: LanguageModelProviderName,
error: LanguageModelCompletionError,
attempt: u8,
plan: Option<Plan>,
@@ -1389,7 +1391,7 @@ impl Thread {
use LanguageModelCompletionEvent::*;
match event {
StartMessage { .. } => {
LanguageModelCompletionEvent::StartMessage { .. } => {
self.flush_pending_message(cx);
self.pending_message = Some(AgentMessage::default());
}
@@ -1416,7 +1418,7 @@ impl Thread {
),
)));
}
UsageUpdate(usage) => {
TokenUsage(usage) => {
telemetry::event!(
"Agent Thread Completion Usage Updated",
thread_id = self.id.to_string(),
@@ -1430,20 +1432,16 @@ impl Thread {
);
self.update_token_usage(usage, cx);
}
StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
RequestUsage { amount, limit } => {
self.update_model_request_usage(amount, limit, cx);
}
StatusUpdate(
CompletionRequestStatus::Started
| CompletionRequestStatus::Queued { .. }
| CompletionRequestStatus::Failed { .. },
) => {}
StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
ToolUseLimitReached => {
self.tool_use_limit_reached = true;
}
Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
Started | Queued { .. } => {}
}
Ok(None)
@@ -1687,9 +1685,7 @@ impl Thread {
let event = event.log_err()?;
let text = match event {
LanguageModelCompletionEvent::Text(text) => text,
LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::UsageUpdated { amount, limit },
) => {
LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
this.update(cx, |thread, cx| {
thread.update_model_request_usage(amount, limit, cx);
})
@@ -1753,9 +1749,7 @@ impl Thread {
let event = event?;
let text = match event {
LanguageModelCompletionEvent::Text(text) => text,
LanguageModelCompletionEvent::StatusUpdate(
CompletionRequestStatus::UsageUpdated { amount, limit },
) => {
LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
this.update(cx, |thread, cx| {
thread.update_model_request_usage(amount, limit, cx);
})?;

View File

@@ -7,9 +7,10 @@ use assistant_slash_command::{
use assistant_slash_commands::FileCommandMetadata;
use client::{self, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry};
use clock::ReplicaId;
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
use cloud_llm_client::{CompletionIntent, UsageLimit};
use collections::{HashMap, HashSet};
use fs::{Fs, RenameOptions};
use futures::{FutureExt, StreamExt, future::Shared};
use gpui::{
App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription,
@@ -2073,14 +2074,15 @@ impl TextThread {
});
match event {
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
if let CompletionRequestStatus::UsageUpdated { amount, limit } = status_update {
this.update_model_request_usage(
amount as u32,
limit,
cx,
);
}
LanguageModelCompletionEvent::Started |
LanguageModelCompletionEvent::Queued {..} |
LanguageModelCompletionEvent::ToolUseLimitReached { .. } => {}
LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
this.update_model_request_usage(
amount as u32,
limit,
cx,
);
}
LanguageModelCompletionEvent::StartMessage { .. } => {}
LanguageModelCompletionEvent::Stop(reason) => {
@@ -2142,7 +2144,7 @@ impl TextThread {
}
LanguageModelCompletionEvent::ToolUse(_) |
LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
LanguageModelCompletionEvent::UsageUpdate(_) => {}
LanguageModelCompletionEvent::TokenUsage(_) => {}
}
});

View File

@@ -1250,9 +1250,12 @@ pub fn response_events_to_markdown(
));
}
Ok(
LanguageModelCompletionEvent::UsageUpdate(_)
LanguageModelCompletionEvent::TokenUsage(_)
| LanguageModelCompletionEvent::ToolUseLimitReached
| LanguageModelCompletionEvent::StartMessage { .. }
| LanguageModelCompletionEvent::StatusUpdate { .. },
| LanguageModelCompletionEvent::RequestUsage { .. }
| LanguageModelCompletionEvent::Queued { .. }
| LanguageModelCompletionEvent::Started,
) => {}
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
json_parse_error, ..
@@ -1335,11 +1338,14 @@ impl ThreadDialog {
}
// Skip these
Ok(LanguageModelCompletionEvent::UsageUpdate(_))
Ok(LanguageModelCompletionEvent::TokenUsage(_))
| Ok(LanguageModelCompletionEvent::RedactedThinking { .. })
| Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}
| Ok(LanguageModelCompletionEvent::Stop(_))
| Ok(LanguageModelCompletionEvent::Queued { .. })
| Ok(LanguageModelCompletionEvent::Started)
| Ok(LanguageModelCompletionEvent::RequestUsage { .. })
| Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => {}
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
json_parse_error,

View File

@@ -12,7 +12,7 @@ pub mod fake_provider;
use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow};
use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
@@ -70,7 +70,15 @@ pub fn init_settings(cx: &mut App) {
/// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent {
StatusUpdate(CompletionRequestStatus),
Queued {
position: usize,
},
Started,
RequestUsage {
amount: usize,
limit: UsageLimit,
},
ToolUseLimitReached,
Stop(StopReason),
Text(String),
Thinking {
@@ -90,88 +98,93 @@ pub enum LanguageModelCompletionEvent {
StartMessage {
message_id: String,
},
UsageUpdate(TokenUsage),
TokenUsage(TokenUsage),
}
impl LanguageModelCompletionEvent {
pub fn from_completion_request_status(
status: CompletionRequestStatus,
) -> Result<Self, LanguageModelCompletionError> {
match status {
CompletionRequestStatus::Queued { position } => {
Ok(LanguageModelCompletionEvent::Queued { position })
}
CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
CompletionRequestStatus::UsageUpdated { amount, limit } => {
Ok(LanguageModelCompletionEvent::RequestUsage { amount, limit })
}
CompletionRequestStatus::ToolUseLimitReached => {
Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
}
CompletionRequestStatus::Failed {
code,
message,
request_id: _,
retry_after,
} => Err(LanguageModelCompletionError::from_cloud_failure(
code,
message,
retry_after.map(Duration::from_secs_f64),
)),
}
}
}
#[derive(Error, Debug)]
pub enum LanguageModelCompletionError {
#[error("prompt too large for context window")]
PromptTooLarge { tokens: Option<u64> },
#[error("missing {provider} API key")]
NoApiKey { provider: LanguageModelProviderName },
#[error("{provider}'s API rate limit exceeded")]
RateLimitExceeded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API servers are overloaded right now")]
ServerOverloaded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API server reported an internal server error: {message}")]
ApiInternalServerError {
provider: LanguageModelProviderName,
message: String,
},
#[error("missing API key")]
NoApiKey,
#[error("API rate limit exceeded")]
RateLimitExceeded { retry_after: Option<Duration> },
#[error("API servers are overloaded right now")]
ServerOverloaded { retry_after: Option<Duration> },
#[error("API server reported an internal server error: {message}")]
ApiInternalServerError { message: String },
#[error("{message}")]
UpstreamProviderError {
message: String,
status: StatusCode,
retry_after: Option<Duration>,
},
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
#[error("HTTP response error from API: status {status_code} - {message:?}")]
HttpResponseError {
provider: LanguageModelProviderName,
status_code: StatusCode,
message: String,
},
// Client errors
#[error("invalid request format to {provider}'s API: {message}")]
BadRequestFormat {
provider: LanguageModelProviderName,
message: String,
},
#[error("authentication error with {provider}'s API: {message}")]
AuthenticationError {
provider: LanguageModelProviderName,
message: String,
},
#[error("Permission error with {provider}'s API: {message}")]
PermissionError {
provider: LanguageModelProviderName,
message: String,
},
#[error("invalid request format to API: {message}")]
BadRequestFormat { message: String },
#[error("authentication error with API: {message}")]
AuthenticationError { message: String },
#[error("Permission error with API: {message}")]
PermissionError { message: String },
#[error("language model provider API endpoint not found")]
ApiEndpointNotFound { provider: LanguageModelProviderName },
#[error("I/O error reading response from {provider}'s API")]
ApiEndpointNotFound,
#[error("I/O error reading response from API")]
ApiReadResponseError {
provider: LanguageModelProviderName,
#[source]
error: io::Error,
},
#[error("error serializing request to {provider} API")]
#[error("error serializing request to API")]
SerializeRequest {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
#[error("error building request body to {provider} API")]
#[error("error building request body to API")]
BuildRequestBody {
provider: LanguageModelProviderName,
#[source]
error: http::Error,
},
#[error("error sending HTTP request to {provider} API")]
#[error("error sending HTTP request to API")]
HttpSend {
provider: LanguageModelProviderName,
#[source]
error: anyhow::Error,
},
#[error("error deserializing {provider} API response")]
#[error("error deserializing API response")]
DeserializeResponse {
provider: LanguageModelProviderName,
#[source]
error: serde_json::Error,
},
@@ -182,6 +195,72 @@ pub enum LanguageModelCompletionError {
}
impl LanguageModelCompletionError {
fn display_format(&self, provider: LanguageModelProviderName) {
// match self {
// #[error("prompt too large for context window")]
// PromptTooLarge { tokens: Option<u64> },
// #[error("missing API key")]
// NoApiKey,
// #[error("API rate limit exceeded")]
// RateLimitExceeded { retry_after: Option<Duration> },
// #[error("API servers are overloaded right now")]
// ServerOverloaded { retry_after: Option<Duration> },
// #[error("API server reported an internal server error: {message}")]
// ApiInternalServerError { message: String },
// #[error("{message}")]
// UpstreamProviderError {
// message: String,
// status: StatusCode,
// retry_after: Option<Duration>,
// },
// #[error("HTTP response error from API: status {status_code} - {message:?}")]
// HttpResponseError {
// status_code: StatusCode,
// message: String,
// },
// // Client errors
// #[error("invalid request format to API: {message}")]
// BadRequestFormat { message: String },
// #[error("authentication error with API: {message}")]
// AuthenticationError { message: String },
// #[error("Permission error with API: {message}")]
// PermissionError { message: String },
// #[error("language model provider API endpoint not found")]
// ApiEndpointNotFound,
// #[error("I/O error reading response from API")]
// ApiReadResponseError {
// #[source]
// error: io::Error,
// },
// #[error("error serializing request to API")]
// SerializeRequest {
// #[source]
// error: serde_json::Error,
// },
// #[error("error building request body to API")]
// BuildRequestBody {
// #[source]
// error: http::Error,
// },
// #[error("error sending HTTP request to API")]
// HttpSend {
// #[source]
// error: anyhow::Error,
// },
// #[error("error deserializing API response")]
// DeserializeResponse {
// #[source]
// error: serde_json::Error,
// },
// // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
// #[error(transparent)]
// Other(#[from] anyhow::Error),
// }
}
fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
let upstream_status = error_json
@@ -198,7 +277,6 @@ impl LanguageModelCompletionError {
}
pub fn from_cloud_failure(
upstream_provider: LanguageModelProviderName,
code: String,
message: String,
retry_after: Option<Duration>,
@@ -214,58 +292,46 @@ impl LanguageModelCompletionError {
if let Some((upstream_status, inner_message)) =
Self::parse_upstream_error_json(&message)
{
return Self::from_http_status(
upstream_provider,
upstream_status,
inner_message,
retry_after,
);
return Self::from_http_status(upstream_status, inner_message, retry_after);
}
anyhow!("completion request failed, code: {code}, message: {message}").into()
Self::Other(anyhow!(
"completion request failed, code: {code}, message: {message}"
))
} else if let Some(status_code) = code
.strip_prefix("upstream_http_")
.and_then(|code| StatusCode::from_str(code).ok())
{
Self::from_http_status(upstream_provider, status_code, message, retry_after)
Self::from_http_status(status_code, message, retry_after)
} else if let Some(status_code) = code
.strip_prefix("http_")
.and_then(|code| StatusCode::from_str(code).ok())
{
Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
Self::from_http_status(status_code, message, retry_after)
} else {
anyhow!("completion request failed, code: {code}, message: {message}").into()
Self::Other(anyhow!(
"completion request failed, code: {code}, message: {message}"
))
}
}
pub fn from_http_status(
provider: LanguageModelProviderName,
status_code: StatusCode,
message: String,
retry_after: Option<Duration>,
) -> Self {
match status_code {
StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
StatusCode::BAD_REQUEST => Self::BadRequestFormat { message },
StatusCode::UNAUTHORIZED => Self::AuthenticationError { message },
StatusCode::FORBIDDEN => Self::PermissionError { message },
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound,
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
tokens: parse_prompt_too_long(&message),
},
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
provider,
retry_after,
},
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
provider,
retry_after,
},
_ if status_code.as_u16() == 529 => Self::ServerOverloaded {
provider,
retry_after,
},
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded { retry_after },
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { message },
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded { retry_after },
_ if status_code.as_u16() == 529 => Self::ServerOverloaded { retry_after },
_ => Self::HttpResponseError {
provider,
status_code,
message,
},
@@ -275,31 +341,25 @@ impl LanguageModelCompletionError {
impl From<AnthropicError> for LanguageModelCompletionError {
fn from(error: AnthropicError) -> Self {
let provider = ANTHROPIC_PROVIDER_NAME;
match error {
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
AnthropicError::DeserializeResponse(error) => {
Self::DeserializeResponse { provider, error }
}
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { error },
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { error },
AnthropicError::HttpSend(error) => Self::HttpSend { error },
AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse { error },
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { error },
AnthropicError::HttpResponseError {
status_code,
message,
} => Self::HttpResponseError {
provider,
status_code,
message,
},
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
provider,
retry_after: Some(retry_after),
},
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
provider,
retry_after,
},
AnthropicError::ServerOverloaded { retry_after } => {
Self::ServerOverloaded { retry_after }
}
AnthropicError::ApiError(api_error) => api_error.into(),
}
}
@@ -308,37 +368,26 @@ impl From<AnthropicError> for LanguageModelCompletionError {
impl From<anthropic::ApiError> for LanguageModelCompletionError {
fn from(error: anthropic::ApiError) -> Self {
use anthropic::ApiErrorCode::*;
let provider = ANTHROPIC_PROVIDER_NAME;
match error.code() {
Some(code) => match code {
InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message,
},
AuthenticationError => Self::AuthenticationError {
provider,
message: error.message,
},
PermissionError => Self::PermissionError {
provider,
message: error.message,
},
NotFoundError => Self::ApiEndpointNotFound { provider },
NotFoundError => Self::ApiEndpointNotFound,
RequestTooLarge => Self::PromptTooLarge {
tokens: parse_prompt_too_long(&error.message),
},
RateLimitError => Self::RateLimitExceeded {
provider,
retry_after: None,
},
RateLimitError => Self::RateLimitExceeded { retry_after: None },
ApiError => Self::ApiInternalServerError {
provider,
message: error.message,
},
OverloadedError => Self::ServerOverloaded {
provider,
retry_after: None,
},
OverloadedError => Self::ServerOverloaded { retry_after: None },
},
None => Self::Other(error.into()),
}
@@ -349,7 +398,7 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
fn from(error: open_ai::RequestError) -> Self {
match error {
open_ai::RequestError::HttpResponseError {
provider,
provider: _,
status_code,
body,
headers,
@@ -359,7 +408,7 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
.and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
.map(Duration::from_secs);
Self::from_http_status(provider.into(), status_code, body, retry_after)
Self::from_http_status(status_code, body, retry_after)
}
open_ai::RequestError::Other(e) => Self::Other(e),
}
@@ -368,23 +417,18 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
impl From<OpenRouterError> for LanguageModelCompletionError {
fn from(error: OpenRouterError) -> Self {
let provider = LanguageModelProviderName::new("OpenRouter");
match error {
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
OpenRouterError::DeserializeResponse(error) => {
Self::DeserializeResponse { provider, error }
}
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { error },
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { error },
OpenRouterError::HttpSend(error) => Self::HttpSend { error },
OpenRouterError::DeserializeResponse(error) => Self::DeserializeResponse { error },
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { error },
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
provider,
retry_after: Some(retry_after),
},
OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
provider,
retry_after,
},
OpenRouterError::ServerOverloaded { retry_after } => {
Self::ServerOverloaded { retry_after }
}
OpenRouterError::ApiError(api_error) => api_error.into(),
}
}
@@ -393,41 +437,28 @@ impl From<OpenRouterError> for LanguageModelCompletionError {
impl From<open_router::ApiError> for LanguageModelCompletionError {
fn from(error: open_router::ApiError) -> Self {
use open_router::ApiErrorCode::*;
let provider = LanguageModelProviderName::new("OpenRouter");
match error.code {
InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message,
},
AuthenticationError => Self::AuthenticationError {
provider,
message: error.message,
},
PaymentRequiredError => Self::AuthenticationError {
provider,
message: format!("Payment required: {}", error.message),
},
PermissionError => Self::PermissionError {
provider,
message: error.message,
},
RequestTimedOut => Self::HttpResponseError {
provider,
status_code: StatusCode::REQUEST_TIMEOUT,
message: error.message,
},
RateLimitError => Self::RateLimitExceeded {
provider,
retry_after: None,
},
RateLimitError => Self::RateLimitExceeded { retry_after: None },
ApiError => Self::ApiInternalServerError {
provider,
message: error.message,
},
OverloadedError => Self::ServerOverloaded {
provider,
retry_after: None,
},
OverloadedError => Self::ServerOverloaded { retry_after: None },
}
}
}
@@ -633,7 +664,10 @@ pub trait LanguageModel: Send + Sync {
let last_token_usage = last_token_usage.clone();
async move {
match result {
Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
Ok(LanguageModelCompletionEvent::Started) => None,
Ok(LanguageModelCompletionEvent::RequestUsage { .. }) => None,
Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
@@ -643,7 +677,7 @@ pub trait LanguageModel: Send + Sync {
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
..
}) => None,
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
Ok(LanguageModelCompletionEvent::TokenUsage(token_usage)) => {
*last_token_usage.lock() = token_usage;
None
}
@@ -832,16 +866,13 @@ mod tests {
#[test]
fn test_from_cloud_failure_with_upstream_http_error() {
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
None,
);
match error {
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
assert_eq!(provider.0, "anthropic");
}
LanguageModelCompletionError::ServerOverloaded { .. } => {}
_ => panic!(
"Expected ServerOverloaded error for 503 status, got: {:?}",
error
@@ -849,15 +880,13 @@ mod tests {
}
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
None,
);
match error {
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
assert_eq!(provider.0, "anthropic");
LanguageModelCompletionError::ApiInternalServerError { message } => {
assert_eq!(message, "Internal server error");
}
_ => panic!(
@@ -870,16 +899,13 @@ mod tests {
#[test]
fn test_from_cloud_failure_with_standard_format() {
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_503".to_string(),
"Service unavailable".to_string(),
None,
);
match error {
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
assert_eq!(provider.0, "anthropic");
}
LanguageModelCompletionError::ServerOverloaded { .. } => {}
_ => panic!("Expected ServerOverloaded error for upstream_http_503"),
}
}
@@ -887,16 +913,13 @@ mod tests {
#[test]
fn test_upstream_http_error_connection_timeout() {
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
None,
);
match error {
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
assert_eq!(provider.0, "anthropic");
}
LanguageModelCompletionError::ServerOverloaded { .. } => {}
_ => panic!(
"Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
error
@@ -904,15 +927,13 @@ mod tests {
}
let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
None,
);
match error {
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
assert_eq!(provider.0, "anthropic");
LanguageModelCompletionError::ApiInternalServerError { message } => {
assert_eq!(
message,
"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"

View File

@@ -320,9 +320,7 @@ impl AnthropicModel {
async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
return Err(LanguageModelCompletionError::NoApiKey);
};
let request = anthropic::stream_completion(
http_client.as_ref(),
@@ -756,7 +754,7 @@ impl AnthropicEventMapper {
Event::MessageStart { message } => {
update_usage(&mut self.usage, &message.usage);
vec![
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
&self.usage,
))),
Ok(LanguageModelCompletionEvent::StartMessage {
@@ -778,9 +776,9 @@ impl AnthropicEventMapper {
}
};
}
vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&self.usage),
))]
vec![Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
&self.usage,
)))]
}
Event::MessageStop => {
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]

View File

@@ -975,7 +975,7 @@ pub fn map_to_language_model_completion_events(
))
}),
ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: metadata.input_tokens as u64,
output_tokens: metadata.output_tokens as u64,
cache_creation_input_tokens: metadata

View File

@@ -541,7 +541,6 @@ impl From<ApiError> for LanguageModelCompletionError {
}
return LanguageModelCompletionError::from_http_status(
PROVIDER_NAME,
error.status,
cloud_error.message,
None,
@@ -549,12 +548,7 @@ impl From<ApiError> for LanguageModelCompletionError {
}
let retry_after = None;
LanguageModelCompletionError::from_http_status(
PROVIDER_NAME,
error.status,
error.body,
retry_after,
)
LanguageModelCompletionError::from_http_status(error.status, error.body, retry_after)
}
}
@@ -961,7 +955,7 @@ where
vec![Err(LanguageModelCompletionError::from(error))]
}
Ok(CompletionEvent::Status(event)) => {
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
vec![LanguageModelCompletionEvent::from_completion_request_status(event)]
}
Ok(CompletionEvent::Event(event)) => map_callback(event),
})
@@ -1313,8 +1307,7 @@ mod tests {
let completion_error: LanguageModelCompletionError = api_error.into();
match completion_error {
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
assert_eq!(provider, PROVIDER_NAME);
LanguageModelCompletionError::ApiInternalServerError { message } => {
assert_eq!(message, "Regular internal server error");
}
_ => panic!(
@@ -1362,9 +1355,7 @@ mod tests {
let completion_error: LanguageModelCompletionError = api_error.into();
match completion_error {
LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
assert_eq!(provider, PROVIDER_NAME);
}
LanguageModelCompletionError::ApiInternalServerError { .. } => {}
_ => panic!(
"Expected ApiInternalServerError for invalid JSON, got: {:?}",
completion_error

View File

@@ -422,14 +422,12 @@ pub fn map_to_language_model_completion_events(
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
},
)));
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
})));
}
match choice.finish_reason.as_deref() {
@@ -610,7 +608,7 @@ impl CopilotResponsesEventMapper {
copilot::copilot_responses::StreamEvent::Completed { response } => {
let mut events = Vec::new();
if let Some(usage) = response.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.input_tokens.unwrap_or(0),
output_tokens: usage.output_tokens.unwrap_or(0),
cache_creation_input_tokens: 0,
@@ -643,7 +641,7 @@ impl CopilotResponsesEventMapper {
let mut events = Vec::new();
if let Some(usage) = response.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.input_tokens.unwrap_or(0),
output_tokens: usage.output_tokens.unwrap_or(0),
cache_creation_input_tokens: 0,
@@ -655,7 +653,6 @@ impl CopilotResponsesEventMapper {
}
copilot::copilot_responses::StreamEvent::Failed { response } => {
let provider = PROVIDER_NAME;
let (status_code, message) = match response.error {
Some(error) => {
let status_code = StatusCode::from_str(&error.code)
@@ -668,7 +665,6 @@ impl CopilotResponsesEventMapper {
),
};
vec![Err(LanguageModelCompletionError::HttpResponseError {
provider,
status_code,
message,
})]
@@ -1099,7 +1095,7 @@ mod tests {
));
assert!(matches!(
mapped[2],
LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: 5,
output_tokens: 3,
..
@@ -1207,7 +1203,7 @@ mod tests {
let mapped = map_events(events);
assert!(matches!(
mapped[0],
LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: 10,
output_tokens: 0,
..

View File

@@ -224,9 +224,7 @@ impl DeepSeekLanguageModel {
let future = self.request_limiter.stream(async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
return Err(LanguageModelCompletionError::NoApiKey);
};
let request =
deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
@@ -479,7 +477,7 @@ impl DeepSeekEventMapper {
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,

View File

@@ -350,10 +350,7 @@ impl LanguageModel for GoogleLanguageModel {
async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
}
.into());
return Err(LanguageModelCompletionError::NoApiKey.into());
};
let response = google_ai::count_tokens(
http_client.as_ref(),
@@ -608,9 +605,9 @@ impl GoogleEventMapper {
let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut self.usage, &usage_metadata);
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
convert_usage(&self.usage),
)))
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
&self.usage,
))))
}
if let Some(prompt_feedback) = event.prompt_feedback

View File

@@ -547,7 +547,7 @@ impl LmStudioEventMapper {
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,

View File

@@ -291,9 +291,7 @@ impl MistralLanguageModel {
let future = self.request_limiter.stream(async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
return Err(LanguageModelCompletionError::NoApiKey);
};
let request =
mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
@@ -672,7 +670,7 @@ impl MistralEventMapper {
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,

View File

@@ -603,7 +603,7 @@ fn map_to_language_model_completion_events(
};
if delta.done {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: delta.prompt_eval_count.unwrap_or(0),
output_tokens: delta.eval_count.unwrap_or(0),
cache_creation_input_tokens: 0,

View File

@@ -228,7 +228,7 @@ impl OpenAiLanguageModel {
let future = self.request_limiter.stream(async move {
let provider = PROVIDER_NAME;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider });
return Err(LanguageModelCompletionError::NoApiKey);
};
let request = stream_completion(
http_client.as_ref(),
@@ -534,7 +534,7 @@ impl OpenAiEventMapper {
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let mut events = Vec::new();
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,

View File

@@ -227,7 +227,7 @@ impl OpenAiCompatibleLanguageModel {
let provider = self.provider_name.clone();
let future = self.request_limiter.stream(async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider });
return Err(LanguageModelCompletionError::NoApiKey);
};
let request = stream_completion(
http_client.as_ref(),

View File

@@ -84,9 +84,7 @@ impl State {
let http_client = self.http_client.clone();
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
let Some(api_key) = self.api_key_state.key(&api_url) else {
return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
}));
return Task::ready(Err(LanguageModelCompletionError::NoApiKey));
};
cx.spawn(async move |this, cx| {
let models = list_models(http_client.as_ref(), &api_url, &api_key)
@@ -288,9 +286,7 @@ impl OpenRouterLanguageModel {
async move {
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey {
provider: PROVIDER_NAME,
});
return Err(LanguageModelCompletionError::NoApiKey);
};
let request =
open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
@@ -613,7 +609,7 @@ impl OpenRouterEventMapper {
}
if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0,

View File

@@ -222,7 +222,7 @@ impl VercelLanguageModel {
let future = self.request_limiter.stream(async move {
let provider = PROVIDER_NAME;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider });
return Err(LanguageModelCompletionError::NoApiKey);
};
let request = open_ai::stream_completion(
http_client.as_ref(),

View File

@@ -230,7 +230,7 @@ impl XAiLanguageModel {
let future = self.request_limiter.stream(async move {
let provider = PROVIDER_NAME;
let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider });
return Err(LanguageModelCompletionError::NoApiKey);
};
let request = open_ai::stream_completion(
http_client.as_ref(),