Compare commits

...

5 Commits

Author SHA1 Message Date
Mikayla Maki
5b74ddf020 WIP: Remove provider name 2025-11-18 23:56:05 -08:00
Mikayla Maki
54f20ae5d5 Remove provider name 2025-11-18 22:13:04 -08:00
Mikayla Maki
811efa45d0 Disambiguate similar completion events 2025-11-18 21:04:35 -08:00
Mikayla Maki
74501e0936 Simplify LanguageModelCompletionEvent enum, remove
`StatusUpdate::Failed` state by turning it into an error early, and then
flatten `StatusUpdate` into LanguageModelCompletionEvent
2025-11-18 20:56:11 -08:00
Michael Benfield
2ad8bd00ce Simplifying errors
Co-authored-by: Mikayla <mikayla@zed.dev>
2025-11-18 20:45:22 -08:00
19 changed files with 263 additions and 273 deletions

View File

@@ -21,9 +21,9 @@ use gpui::{
use indoc::indoc; use indoc::indoc;
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat, LanguageModelToolResult, LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent,
LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel, Role, StopReason, fake_provider::FakeLanguageModel,
}; };
use pretty_assertions::assert_eq; use pretty_assertions::assert_eq;
use project::{ use project::{
@@ -664,9 +664,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
); );
// Simulate reaching tool use limit. // Simulate reaching tool use limit.
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate( fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
));
fake_model.end_last_completion_stream(); fake_model.end_last_completion_stream();
let last_event = events.collect::<Vec<_>>().await.pop().unwrap(); let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!( assert!(
@@ -749,9 +747,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
}; };
fake_model fake_model
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone())); .send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate( fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
));
fake_model.end_last_completion_stream(); fake_model.end_last_completion_stream();
let last_event = events.collect::<Vec<_>>().await.pop().unwrap(); let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
assert!( assert!(
@@ -1533,7 +1529,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
}); });
fake_model.send_last_completion_stream_text_chunk("Hey!"); fake_model.send_last_completion_stream_text_chunk("Hey!");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
language_model::TokenUsage { language_model::TokenUsage {
input_tokens: 32_000, input_tokens: 32_000,
output_tokens: 16_000, output_tokens: 16_000,
@@ -1591,7 +1587,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
}); });
cx.run_until_parked(); cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Ahoy!"); fake_model.send_last_completion_stream_text_chunk("Ahoy!");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
language_model::TokenUsage { language_model::TokenUsage {
input_tokens: 40_000, input_tokens: 40_000,
output_tokens: 20_000, output_tokens: 20_000,
@@ -1636,7 +1632,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
.unwrap(); .unwrap();
cx.run_until_parked(); cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Message 1 response"); fake_model.send_last_completion_stream_text_chunk("Message 1 response");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
language_model::TokenUsage { language_model::TokenUsage {
input_tokens: 32_000, input_tokens: 32_000,
output_tokens: 16_000, output_tokens: 16_000,
@@ -1683,7 +1679,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
cx.run_until_parked(); cx.run_until_parked();
fake_model.send_last_completion_stream_text_chunk("Message 2 response"); fake_model.send_last_completion_stream_text_chunk("Message 2 response");
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate( fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
language_model::TokenUsage { language_model::TokenUsage {
input_tokens: 40_000, input_tokens: 40_000,
output_tokens: 20_000, output_tokens: 20_000,
@@ -2159,7 +2155,6 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
fake_model.send_last_completion_stream_text_chunk("Hey,"); fake_model.send_last_completion_stream_text_chunk("Hey,");
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded { fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)), retry_after: Some(Duration::from_secs(3)),
}); });
fake_model.end_last_completion_stream(); fake_model.end_last_completion_stream();
@@ -2235,7 +2230,6 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
tool_use_1.clone(), tool_use_1.clone(),
)); ));
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded { fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)), retry_after: Some(Duration::from_secs(3)),
}); });
fake_model.end_last_completion_stream(); fake_model.end_last_completion_stream();
@@ -2302,7 +2296,6 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 { for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
fake_model.send_last_completion_stream_error( fake_model.send_last_completion_stream_error(
LanguageModelCompletionError::ServerOverloaded { LanguageModelCompletionError::ServerOverloaded {
provider: LanguageModelProviderName::new("Anthropic"),
retry_after: Some(Duration::from_secs(3)), retry_after: Some(Duration::from_secs(3)),
}, },
); );

View File

@@ -15,7 +15,7 @@ use agent_settings::{
use anyhow::{Context as _, Result, anyhow}; use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use client::{ModelRequestUsage, RequestUsage, UserStore}; use client::{ModelRequestUsage, RequestUsage, UserStore};
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit}; use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
use collections::{HashMap, HashSet, IndexMap}; use collections::{HashMap, HashSet, IndexMap};
use fs::Fs; use fs::Fs;
use futures::stream; use futures::stream;
@@ -30,11 +30,11 @@ use gpui::{
}; };
use language_model::{ use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry, LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat, LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
LanguageModelToolUse, LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage, LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role,
ZED_CLOUD_PROVIDER_ID, SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID,
}; };
use project::Project; use project::Project;
use prompt_store::ProjectContext; use prompt_store::ProjectContext;
@@ -1295,9 +1295,10 @@ impl Thread {
if let Some(error) = error { if let Some(error) = error {
attempt += 1; attempt += 1;
let provider = model.upstream_provider_name();
let retry = this.update(cx, |this, cx| { let retry = this.update(cx, |this, cx| {
let user_store = this.user_store.read(cx); let user_store = this.user_store.read(cx);
this.handle_completion_error(error, attempt, user_store.plan()) this.handle_completion_error(provider, error, attempt, user_store.plan())
})??; })??;
let timer = cx.background_executor().timer(retry.duration); let timer = cx.background_executor().timer(retry.duration);
event_stream.send_retry(retry); event_stream.send_retry(retry);
@@ -1323,6 +1324,7 @@ impl Thread {
fn handle_completion_error( fn handle_completion_error(
&mut self, &mut self,
provider: LanguageModelProviderName,
error: LanguageModelCompletionError, error: LanguageModelCompletionError,
attempt: u8, attempt: u8,
plan: Option<Plan>, plan: Option<Plan>,
@@ -1389,7 +1391,7 @@ impl Thread {
use LanguageModelCompletionEvent::*; use LanguageModelCompletionEvent::*;
match event { match event {
StartMessage { .. } => { LanguageModelCompletionEvent::StartMessage { .. } => {
self.flush_pending_message(cx); self.flush_pending_message(cx);
self.pending_message = Some(AgentMessage::default()); self.pending_message = Some(AgentMessage::default());
} }
@@ -1416,7 +1418,7 @@ impl Thread {
), ),
))); )));
} }
UsageUpdate(usage) => { TokenUsage(usage) => {
telemetry::event!( telemetry::event!(
"Agent Thread Completion Usage Updated", "Agent Thread Completion Usage Updated",
thread_id = self.id.to_string(), thread_id = self.id.to_string(),
@@ -1430,20 +1432,16 @@ impl Thread {
); );
self.update_token_usage(usage, cx); self.update_token_usage(usage, cx);
} }
StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => { RequestUsage { amount, limit } => {
self.update_model_request_usage(amount, limit, cx); self.update_model_request_usage(amount, limit, cx);
} }
StatusUpdate( ToolUseLimitReached => {
CompletionRequestStatus::Started
| CompletionRequestStatus::Queued { .. }
| CompletionRequestStatus::Failed { .. },
) => {}
StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
self.tool_use_limit_reached = true; self.tool_use_limit_reached = true;
} }
Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()), Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()), Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
Stop(StopReason::ToolUse | StopReason::EndTurn) => {} Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
Started | Queued { .. } => {}
} }
Ok(None) Ok(None)
@@ -1687,9 +1685,7 @@ impl Thread {
let event = event.log_err()?; let event = event.log_err()?;
let text = match event { let text = match event {
LanguageModelCompletionEvent::Text(text) => text, LanguageModelCompletionEvent::Text(text) => text,
LanguageModelCompletionEvent::StatusUpdate( LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
CompletionRequestStatus::UsageUpdated { amount, limit },
) => {
this.update(cx, |thread, cx| { this.update(cx, |thread, cx| {
thread.update_model_request_usage(amount, limit, cx); thread.update_model_request_usage(amount, limit, cx);
}) })
@@ -1753,9 +1749,7 @@ impl Thread {
let event = event?; let event = event?;
let text = match event { let text = match event {
LanguageModelCompletionEvent::Text(text) => text, LanguageModelCompletionEvent::Text(text) => text,
LanguageModelCompletionEvent::StatusUpdate( LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
CompletionRequestStatus::UsageUpdated { amount, limit },
) => {
this.update(cx, |thread, cx| { this.update(cx, |thread, cx| {
thread.update_model_request_usage(amount, limit, cx); thread.update_model_request_usage(amount, limit, cx);
})?; })?;

View File

@@ -7,9 +7,10 @@ use assistant_slash_command::{
use assistant_slash_commands::FileCommandMetadata; use assistant_slash_commands::FileCommandMetadata;
use client::{self, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry}; use client::{self, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry};
use clock::ReplicaId; use clock::ReplicaId;
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit}; use cloud_llm_client::{CompletionIntent, UsageLimit};
use collections::{HashMap, HashSet}; use collections::{HashMap, HashSet};
use fs::{Fs, RenameOptions}; use fs::{Fs, RenameOptions};
use futures::{FutureExt, StreamExt, future::Shared}; use futures::{FutureExt, StreamExt, future::Shared};
use gpui::{ use gpui::{
App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription, App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription,
@@ -2073,14 +2074,15 @@ impl TextThread {
}); });
match event { match event {
LanguageModelCompletionEvent::StatusUpdate(status_update) => { LanguageModelCompletionEvent::Started |
if let CompletionRequestStatus::UsageUpdated { amount, limit } = status_update { LanguageModelCompletionEvent::Queued {..} |
this.update_model_request_usage( LanguageModelCompletionEvent::ToolUseLimitReached { .. } => {}
amount as u32, LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
limit, this.update_model_request_usage(
cx, amount as u32,
); limit,
} cx,
);
} }
LanguageModelCompletionEvent::StartMessage { .. } => {} LanguageModelCompletionEvent::StartMessage { .. } => {}
LanguageModelCompletionEvent::Stop(reason) => { LanguageModelCompletionEvent::Stop(reason) => {
@@ -2142,7 +2144,7 @@ impl TextThread {
} }
LanguageModelCompletionEvent::ToolUse(_) | LanguageModelCompletionEvent::ToolUse(_) |
LanguageModelCompletionEvent::ToolUseJsonParseError { .. } | LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
LanguageModelCompletionEvent::UsageUpdate(_) => {} LanguageModelCompletionEvent::TokenUsage(_) => {}
} }
}); });

View File

@@ -1250,9 +1250,12 @@ pub fn response_events_to_markdown(
)); ));
} }
Ok( Ok(
LanguageModelCompletionEvent::UsageUpdate(_) LanguageModelCompletionEvent::TokenUsage(_)
| LanguageModelCompletionEvent::ToolUseLimitReached
| LanguageModelCompletionEvent::StartMessage { .. } | LanguageModelCompletionEvent::StartMessage { .. }
| LanguageModelCompletionEvent::StatusUpdate { .. }, | LanguageModelCompletionEvent::RequestUsage { .. }
| LanguageModelCompletionEvent::Queued { .. }
| LanguageModelCompletionEvent::Started,
) => {} ) => {}
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
json_parse_error, .. json_parse_error, ..
@@ -1335,11 +1338,14 @@ impl ThreadDialog {
} }
// Skip these // Skip these
Ok(LanguageModelCompletionEvent::UsageUpdate(_)) Ok(LanguageModelCompletionEvent::TokenUsage(_))
| Ok(LanguageModelCompletionEvent::RedactedThinking { .. }) | Ok(LanguageModelCompletionEvent::RedactedThinking { .. })
| Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
| Ok(LanguageModelCompletionEvent::StartMessage { .. }) | Ok(LanguageModelCompletionEvent::StartMessage { .. })
| Ok(LanguageModelCompletionEvent::Stop(_)) => {} | Ok(LanguageModelCompletionEvent::Stop(_))
| Ok(LanguageModelCompletionEvent::Queued { .. })
| Ok(LanguageModelCompletionEvent::Started)
| Ok(LanguageModelCompletionEvent::RequestUsage { .. })
| Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => {}
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
json_parse_error, json_parse_error,

View File

@@ -12,7 +12,7 @@ pub mod fake_provider;
use anthropic::{AnthropicError, parse_prompt_too_long}; use anthropic::{AnthropicError, parse_prompt_too_long};
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow};
use client::Client; use client::Client;
use cloud_llm_client::{CompletionMode, CompletionRequestStatus}; use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
use futures::FutureExt; use futures::FutureExt;
use futures::{StreamExt, future::BoxFuture, stream::BoxStream}; use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window}; use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
@@ -70,7 +70,15 @@ pub fn init_settings(cx: &mut App) {
/// A completion event from a language model. /// A completion event from a language model.
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)] #[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
pub enum LanguageModelCompletionEvent { pub enum LanguageModelCompletionEvent {
StatusUpdate(CompletionRequestStatus), Queued {
position: usize,
},
Started,
RequestUsage {
amount: usize,
limit: UsageLimit,
},
ToolUseLimitReached,
Stop(StopReason), Stop(StopReason),
Text(String), Text(String),
Thinking { Thinking {
@@ -90,88 +98,93 @@ pub enum LanguageModelCompletionEvent {
StartMessage { StartMessage {
message_id: String, message_id: String,
}, },
UsageUpdate(TokenUsage), TokenUsage(TokenUsage),
}
impl LanguageModelCompletionEvent {
pub fn from_completion_request_status(
status: CompletionRequestStatus,
) -> Result<Self, LanguageModelCompletionError> {
match status {
CompletionRequestStatus::Queued { position } => {
Ok(LanguageModelCompletionEvent::Queued { position })
}
CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
CompletionRequestStatus::UsageUpdated { amount, limit } => {
Ok(LanguageModelCompletionEvent::RequestUsage { amount, limit })
}
CompletionRequestStatus::ToolUseLimitReached => {
Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
}
CompletionRequestStatus::Failed {
code,
message,
request_id: _,
retry_after,
} => Err(LanguageModelCompletionError::from_cloud_failure(
code,
message,
retry_after.map(Duration::from_secs_f64),
)),
}
}
} }
#[derive(Error, Debug)] #[derive(Error, Debug)]
pub enum LanguageModelCompletionError { pub enum LanguageModelCompletionError {
#[error("prompt too large for context window")] #[error("prompt too large for context window")]
PromptTooLarge { tokens: Option<u64> }, PromptTooLarge { tokens: Option<u64> },
#[error("missing {provider} API key")] #[error("missing API key")]
NoApiKey { provider: LanguageModelProviderName }, NoApiKey,
#[error("{provider}'s API rate limit exceeded")] #[error("API rate limit exceeded")]
RateLimitExceeded { RateLimitExceeded { retry_after: Option<Duration> },
provider: LanguageModelProviderName, #[error("API servers are overloaded right now")]
retry_after: Option<Duration>, ServerOverloaded { retry_after: Option<Duration> },
}, #[error("API server reported an internal server error: {message}")]
#[error("{provider}'s API servers are overloaded right now")] ApiInternalServerError { message: String },
ServerOverloaded {
provider: LanguageModelProviderName,
retry_after: Option<Duration>,
},
#[error("{provider}'s API server reported an internal server error: {message}")]
ApiInternalServerError {
provider: LanguageModelProviderName,
message: String,
},
#[error("{message}")] #[error("{message}")]
UpstreamProviderError { UpstreamProviderError {
message: String, message: String,
status: StatusCode, status: StatusCode,
retry_after: Option<Duration>, retry_after: Option<Duration>,
}, },
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")] #[error("HTTP response error from API: status {status_code} - {message:?}")]
HttpResponseError { HttpResponseError {
provider: LanguageModelProviderName,
status_code: StatusCode, status_code: StatusCode,
message: String, message: String,
}, },
// Client errors // Client errors
#[error("invalid request format to {provider}'s API: {message}")] #[error("invalid request format to API: {message}")]
BadRequestFormat { BadRequestFormat { message: String },
provider: LanguageModelProviderName, #[error("authentication error with API: {message}")]
message: String, AuthenticationError { message: String },
}, #[error("Permission error with API: {message}")]
#[error("authentication error with {provider}'s API: {message}")] PermissionError { message: String },
AuthenticationError {
provider: LanguageModelProviderName,
message: String,
},
#[error("Permission error with {provider}'s API: {message}")]
PermissionError {
provider: LanguageModelProviderName,
message: String,
},
#[error("language model provider API endpoint not found")] #[error("language model provider API endpoint not found")]
ApiEndpointNotFound { provider: LanguageModelProviderName }, ApiEndpointNotFound,
#[error("I/O error reading response from {provider}'s API")] #[error("I/O error reading response from API")]
ApiReadResponseError { ApiReadResponseError {
provider: LanguageModelProviderName,
#[source] #[source]
error: io::Error, error: io::Error,
}, },
#[error("error serializing request to {provider} API")] #[error("error serializing request to API")]
SerializeRequest { SerializeRequest {
provider: LanguageModelProviderName,
#[source] #[source]
error: serde_json::Error, error: serde_json::Error,
}, },
#[error("error building request body to {provider} API")] #[error("error building request body to API")]
BuildRequestBody { BuildRequestBody {
provider: LanguageModelProviderName,
#[source] #[source]
error: http::Error, error: http::Error,
}, },
#[error("error sending HTTP request to {provider} API")] #[error("error sending HTTP request to API")]
HttpSend { HttpSend {
provider: LanguageModelProviderName,
#[source] #[source]
error: anyhow::Error, error: anyhow::Error,
}, },
#[error("error deserializing {provider} API response")] #[error("error deserializing API response")]
DeserializeResponse { DeserializeResponse {
provider: LanguageModelProviderName,
#[source] #[source]
error: serde_json::Error, error: serde_json::Error,
}, },
@@ -182,6 +195,72 @@ pub enum LanguageModelCompletionError {
} }
impl LanguageModelCompletionError { impl LanguageModelCompletionError {
fn display_format(&self, provider: LanguageModelProviderName) {
// match self {
// #[error("prompt too large for context window")]
// PromptTooLarge { tokens: Option<u64> },
// #[error("missing API key")]
// NoApiKey,
// #[error("API rate limit exceeded")]
// RateLimitExceeded { retry_after: Option<Duration> },
// #[error("API servers are overloaded right now")]
// ServerOverloaded { retry_after: Option<Duration> },
// #[error("API server reported an internal server error: {message}")]
// ApiInternalServerError { message: String },
// #[error("{message}")]
// UpstreamProviderError {
// message: String,
// status: StatusCode,
// retry_after: Option<Duration>,
// },
// #[error("HTTP response error from API: status {status_code} - {message:?}")]
// HttpResponseError {
// status_code: StatusCode,
// message: String,
// },
// // Client errors
// #[error("invalid request format to API: {message}")]
// BadRequestFormat { message: String },
// #[error("authentication error with API: {message}")]
// AuthenticationError { message: String },
// #[error("Permission error with API: {message}")]
// PermissionError { message: String },
// #[error("language model provider API endpoint not found")]
// ApiEndpointNotFound,
// #[error("I/O error reading response from API")]
// ApiReadResponseError {
// #[source]
// error: io::Error,
// },
// #[error("error serializing request to API")]
// SerializeRequest {
// #[source]
// error: serde_json::Error,
// },
// #[error("error building request body to API")]
// BuildRequestBody {
// #[source]
// error: http::Error,
// },
// #[error("error sending HTTP request to API")]
// HttpSend {
// #[source]
// error: anyhow::Error,
// },
// #[error("error deserializing API response")]
// DeserializeResponse {
// #[source]
// error: serde_json::Error,
// },
// // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
// #[error(transparent)]
// Other(#[from] anyhow::Error),
// }
}
fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> { fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?; let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
let upstream_status = error_json let upstream_status = error_json
@@ -198,7 +277,6 @@ impl LanguageModelCompletionError {
} }
pub fn from_cloud_failure( pub fn from_cloud_failure(
upstream_provider: LanguageModelProviderName,
code: String, code: String,
message: String, message: String,
retry_after: Option<Duration>, retry_after: Option<Duration>,
@@ -214,58 +292,46 @@ impl LanguageModelCompletionError {
if let Some((upstream_status, inner_message)) = if let Some((upstream_status, inner_message)) =
Self::parse_upstream_error_json(&message) Self::parse_upstream_error_json(&message)
{ {
return Self::from_http_status( return Self::from_http_status(upstream_status, inner_message, retry_after);
upstream_provider,
upstream_status,
inner_message,
retry_after,
);
} }
anyhow!("completion request failed, code: {code}, message: {message}").into() Self::Other(anyhow!(
"completion request failed, code: {code}, message: {message}"
))
} else if let Some(status_code) = code } else if let Some(status_code) = code
.strip_prefix("upstream_http_") .strip_prefix("upstream_http_")
.and_then(|code| StatusCode::from_str(code).ok()) .and_then(|code| StatusCode::from_str(code).ok())
{ {
Self::from_http_status(upstream_provider, status_code, message, retry_after) Self::from_http_status(status_code, message, retry_after)
} else if let Some(status_code) = code } else if let Some(status_code) = code
.strip_prefix("http_") .strip_prefix("http_")
.and_then(|code| StatusCode::from_str(code).ok()) .and_then(|code| StatusCode::from_str(code).ok())
{ {
Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after) Self::from_http_status(status_code, message, retry_after)
} else { } else {
anyhow!("completion request failed, code: {code}, message: {message}").into() Self::Other(anyhow!(
"completion request failed, code: {code}, message: {message}"
))
} }
} }
pub fn from_http_status( pub fn from_http_status(
provider: LanguageModelProviderName,
status_code: StatusCode, status_code: StatusCode,
message: String, message: String,
retry_after: Option<Duration>, retry_after: Option<Duration>,
) -> Self { ) -> Self {
match status_code { match status_code {
StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message }, StatusCode::BAD_REQUEST => Self::BadRequestFormat { message },
StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message }, StatusCode::UNAUTHORIZED => Self::AuthenticationError { message },
StatusCode::FORBIDDEN => Self::PermissionError { provider, message }, StatusCode::FORBIDDEN => Self::PermissionError { message },
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider }, StatusCode::NOT_FOUND => Self::ApiEndpointNotFound,
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge { StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
tokens: parse_prompt_too_long(&message), tokens: parse_prompt_too_long(&message),
}, },
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded { StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded { retry_after },
provider, StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { message },
retry_after, StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded { retry_after },
}, _ if status_code.as_u16() == 529 => Self::ServerOverloaded { retry_after },
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
provider,
retry_after,
},
_ if status_code.as_u16() == 529 => Self::ServerOverloaded {
provider,
retry_after,
},
_ => Self::HttpResponseError { _ => Self::HttpResponseError {
provider,
status_code, status_code,
message, message,
}, },
@@ -275,31 +341,25 @@ impl LanguageModelCompletionError {
impl From<AnthropicError> for LanguageModelCompletionError { impl From<AnthropicError> for LanguageModelCompletionError {
fn from(error: AnthropicError) -> Self { fn from(error: AnthropicError) -> Self {
let provider = ANTHROPIC_PROVIDER_NAME;
match error { match error {
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, AnthropicError::SerializeRequest(error) => Self::SerializeRequest { error },
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { error },
AnthropicError::HttpSend(error) => Self::HttpSend { provider, error }, AnthropicError::HttpSend(error) => Self::HttpSend { error },
AnthropicError::DeserializeResponse(error) => { AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse { error },
Self::DeserializeResponse { provider, error } AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { error },
}
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
AnthropicError::HttpResponseError { AnthropicError::HttpResponseError {
status_code, status_code,
message, message,
} => Self::HttpResponseError { } => Self::HttpResponseError {
provider,
status_code, status_code,
message, message,
}, },
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded { AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
provider,
retry_after: Some(retry_after), retry_after: Some(retry_after),
}, },
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded { AnthropicError::ServerOverloaded { retry_after } => {
provider, Self::ServerOverloaded { retry_after }
retry_after, }
},
AnthropicError::ApiError(api_error) => api_error.into(), AnthropicError::ApiError(api_error) => api_error.into(),
} }
} }
@@ -308,37 +368,26 @@ impl From<AnthropicError> for LanguageModelCompletionError {
impl From<anthropic::ApiError> for LanguageModelCompletionError { impl From<anthropic::ApiError> for LanguageModelCompletionError {
fn from(error: anthropic::ApiError) -> Self { fn from(error: anthropic::ApiError) -> Self {
use anthropic::ApiErrorCode::*; use anthropic::ApiErrorCode::*;
let provider = ANTHROPIC_PROVIDER_NAME;
match error.code() { match error.code() {
Some(code) => match code { Some(code) => match code {
InvalidRequestError => Self::BadRequestFormat { InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message, message: error.message,
}, },
AuthenticationError => Self::AuthenticationError { AuthenticationError => Self::AuthenticationError {
provider,
message: error.message, message: error.message,
}, },
PermissionError => Self::PermissionError { PermissionError => Self::PermissionError {
provider,
message: error.message, message: error.message,
}, },
NotFoundError => Self::ApiEndpointNotFound { provider }, NotFoundError => Self::ApiEndpointNotFound,
RequestTooLarge => Self::PromptTooLarge { RequestTooLarge => Self::PromptTooLarge {
tokens: parse_prompt_too_long(&error.message), tokens: parse_prompt_too_long(&error.message),
}, },
RateLimitError => Self::RateLimitExceeded { RateLimitError => Self::RateLimitExceeded { retry_after: None },
provider,
retry_after: None,
},
ApiError => Self::ApiInternalServerError { ApiError => Self::ApiInternalServerError {
provider,
message: error.message, message: error.message,
}, },
OverloadedError => Self::ServerOverloaded { OverloadedError => Self::ServerOverloaded { retry_after: None },
provider,
retry_after: None,
},
}, },
None => Self::Other(error.into()), None => Self::Other(error.into()),
} }
@@ -349,7 +398,7 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
fn from(error: open_ai::RequestError) -> Self { fn from(error: open_ai::RequestError) -> Self {
match error { match error {
open_ai::RequestError::HttpResponseError { open_ai::RequestError::HttpResponseError {
provider, provider: _,
status_code, status_code,
body, body,
headers, headers,
@@ -359,7 +408,7 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
.and_then(|val| val.to_str().ok()?.parse::<u64>().ok()) .and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
.map(Duration::from_secs); .map(Duration::from_secs);
Self::from_http_status(provider.into(), status_code, body, retry_after) Self::from_http_status(status_code, body, retry_after)
} }
open_ai::RequestError::Other(e) => Self::Other(e), open_ai::RequestError::Other(e) => Self::Other(e),
} }
@@ -368,23 +417,18 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
impl From<OpenRouterError> for LanguageModelCompletionError { impl From<OpenRouterError> for LanguageModelCompletionError {
fn from(error: OpenRouterError) -> Self { fn from(error: OpenRouterError) -> Self {
let provider = LanguageModelProviderName::new("OpenRouter");
match error { match error {
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error }, OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { error },
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error }, OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { error },
OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error }, OpenRouterError::HttpSend(error) => Self::HttpSend { error },
OpenRouterError::DeserializeResponse(error) => { OpenRouterError::DeserializeResponse(error) => Self::DeserializeResponse { error },
Self::DeserializeResponse { provider, error } OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { error },
}
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded { OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
provider,
retry_after: Some(retry_after), retry_after: Some(retry_after),
}, },
OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded { OpenRouterError::ServerOverloaded { retry_after } => {
provider, Self::ServerOverloaded { retry_after }
retry_after, }
},
OpenRouterError::ApiError(api_error) => api_error.into(), OpenRouterError::ApiError(api_error) => api_error.into(),
} }
} }
@@ -393,41 +437,28 @@ impl From<OpenRouterError> for LanguageModelCompletionError {
impl From<open_router::ApiError> for LanguageModelCompletionError { impl From<open_router::ApiError> for LanguageModelCompletionError {
fn from(error: open_router::ApiError) -> Self { fn from(error: open_router::ApiError) -> Self {
use open_router::ApiErrorCode::*; use open_router::ApiErrorCode::*;
let provider = LanguageModelProviderName::new("OpenRouter");
match error.code { match error.code {
InvalidRequestError => Self::BadRequestFormat { InvalidRequestError => Self::BadRequestFormat {
provider,
message: error.message, message: error.message,
}, },
AuthenticationError => Self::AuthenticationError { AuthenticationError => Self::AuthenticationError {
provider,
message: error.message, message: error.message,
}, },
PaymentRequiredError => Self::AuthenticationError { PaymentRequiredError => Self::AuthenticationError {
provider,
message: format!("Payment required: {}", error.message), message: format!("Payment required: {}", error.message),
}, },
PermissionError => Self::PermissionError { PermissionError => Self::PermissionError {
provider,
message: error.message, message: error.message,
}, },
RequestTimedOut => Self::HttpResponseError { RequestTimedOut => Self::HttpResponseError {
provider,
status_code: StatusCode::REQUEST_TIMEOUT, status_code: StatusCode::REQUEST_TIMEOUT,
message: error.message, message: error.message,
}, },
RateLimitError => Self::RateLimitExceeded { RateLimitError => Self::RateLimitExceeded { retry_after: None },
provider,
retry_after: None,
},
ApiError => Self::ApiInternalServerError { ApiError => Self::ApiInternalServerError {
provider,
message: error.message, message: error.message,
}, },
OverloadedError => Self::ServerOverloaded { OverloadedError => Self::ServerOverloaded { retry_after: None },
provider,
retry_after: None,
},
} }
} }
} }
@@ -633,7 +664,10 @@ pub trait LanguageModel: Send + Sync {
let last_token_usage = last_token_usage.clone(); let last_token_usage = last_token_usage.clone();
async move { async move {
match result { match result {
Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None, Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
Ok(LanguageModelCompletionEvent::Started) => None,
Ok(LanguageModelCompletionEvent::RequestUsage { .. }) => None,
Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None, Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)), Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
Ok(LanguageModelCompletionEvent::Thinking { .. }) => None, Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
@@ -643,7 +677,7 @@ pub trait LanguageModel: Send + Sync {
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError { Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
.. ..
}) => None, }) => None,
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => { Ok(LanguageModelCompletionEvent::TokenUsage(token_usage)) => {
*last_token_usage.lock() = token_usage; *last_token_usage.lock() = token_usage;
None None
} }
@@ -832,16 +866,13 @@ mod tests {
#[test] #[test]
fn test_from_cloud_failure_with_upstream_http_error() { fn test_from_cloud_failure_with_upstream_http_error() {
let error = LanguageModelCompletionError::from_cloud_failure( let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(), "upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
None, None,
); );
match error { match error {
LanguageModelCompletionError::ServerOverloaded { provider, .. } => { LanguageModelCompletionError::ServerOverloaded { .. } => {}
assert_eq!(provider.0, "anthropic");
}
_ => panic!( _ => panic!(
"Expected ServerOverloaded error for 503 status, got: {:?}", "Expected ServerOverloaded error for 503 status, got: {:?}",
error error
@@ -849,15 +880,13 @@ mod tests {
} }
let error = LanguageModelCompletionError::from_cloud_failure( let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(), "upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(), r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
None, None,
); );
match error { match error {
LanguageModelCompletionError::ApiInternalServerError { provider, message } => { LanguageModelCompletionError::ApiInternalServerError { message } => {
assert_eq!(provider.0, "anthropic");
assert_eq!(message, "Internal server error"); assert_eq!(message, "Internal server error");
} }
_ => panic!( _ => panic!(
@@ -870,16 +899,13 @@ mod tests {
#[test] #[test]
fn test_from_cloud_failure_with_standard_format() { fn test_from_cloud_failure_with_standard_format() {
let error = LanguageModelCompletionError::from_cloud_failure( let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_503".to_string(), "upstream_http_503".to_string(),
"Service unavailable".to_string(), "Service unavailable".to_string(),
None, None,
); );
match error { match error {
LanguageModelCompletionError::ServerOverloaded { provider, .. } => { LanguageModelCompletionError::ServerOverloaded { .. } => {}
assert_eq!(provider.0, "anthropic");
}
_ => panic!("Expected ServerOverloaded error for upstream_http_503"), _ => panic!("Expected ServerOverloaded error for upstream_http_503"),
} }
} }
@@ -887,16 +913,13 @@ mod tests {
#[test] #[test]
fn test_upstream_http_error_connection_timeout() { fn test_upstream_http_error_connection_timeout() {
let error = LanguageModelCompletionError::from_cloud_failure( let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(), "upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(), r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
None, None,
); );
match error { match error {
LanguageModelCompletionError::ServerOverloaded { provider, .. } => { LanguageModelCompletionError::ServerOverloaded { .. } => {}
assert_eq!(provider.0, "anthropic");
}
_ => panic!( _ => panic!(
"Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}", "Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
error error
@@ -904,15 +927,13 @@ mod tests {
} }
let error = LanguageModelCompletionError::from_cloud_failure( let error = LanguageModelCompletionError::from_cloud_failure(
String::from("anthropic").into(),
"upstream_http_error".to_string(), "upstream_http_error".to_string(),
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(), r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
None, None,
); );
match error { match error {
LanguageModelCompletionError::ApiInternalServerError { provider, message } => { LanguageModelCompletionError::ApiInternalServerError { message } => {
assert_eq!(provider.0, "anthropic");
assert_eq!( assert_eq!(
message, message,
"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout" "Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"

View File

@@ -320,9 +320,7 @@ impl AnthropicModel {
async move { async move {
let Some(api_key) = api_key else { let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { return Err(LanguageModelCompletionError::NoApiKey);
provider: PROVIDER_NAME,
});
}; };
let request = anthropic::stream_completion( let request = anthropic::stream_completion(
http_client.as_ref(), http_client.as_ref(),
@@ -756,7 +754,7 @@ impl AnthropicEventMapper {
Event::MessageStart { message } => { Event::MessageStart { message } => {
update_usage(&mut self.usage, &message.usage); update_usage(&mut self.usage, &message.usage);
vec![ vec![
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage( Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
&self.usage, &self.usage,
))), ))),
Ok(LanguageModelCompletionEvent::StartMessage { Ok(LanguageModelCompletionEvent::StartMessage {
@@ -778,9 +776,9 @@ impl AnthropicEventMapper {
} }
}; };
} }
vec![Ok(LanguageModelCompletionEvent::UsageUpdate( vec![Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
convert_usage(&self.usage), &self.usage,
))] )))]
} }
Event::MessageStop => { Event::MessageStop => {
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))] vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]

View File

@@ -975,7 +975,7 @@ pub fn map_to_language_model_completion_events(
)) ))
}), }),
ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| { ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: metadata.input_tokens as u64, input_tokens: metadata.input_tokens as u64,
output_tokens: metadata.output_tokens as u64, output_tokens: metadata.output_tokens as u64,
cache_creation_input_tokens: metadata cache_creation_input_tokens: metadata

View File

@@ -541,7 +541,6 @@ impl From<ApiError> for LanguageModelCompletionError {
} }
return LanguageModelCompletionError::from_http_status( return LanguageModelCompletionError::from_http_status(
PROVIDER_NAME,
error.status, error.status,
cloud_error.message, cloud_error.message,
None, None,
@@ -549,12 +548,7 @@ impl From<ApiError> for LanguageModelCompletionError {
} }
let retry_after = None; let retry_after = None;
LanguageModelCompletionError::from_http_status( LanguageModelCompletionError::from_http_status(error.status, error.body, retry_after)
PROVIDER_NAME,
error.status,
error.body,
retry_after,
)
} }
} }
@@ -961,7 +955,7 @@ where
vec![Err(LanguageModelCompletionError::from(error))] vec![Err(LanguageModelCompletionError::from(error))]
} }
Ok(CompletionEvent::Status(event)) => { Ok(CompletionEvent::Status(event)) => {
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))] vec![LanguageModelCompletionEvent::from_completion_request_status(event)]
} }
Ok(CompletionEvent::Event(event)) => map_callback(event), Ok(CompletionEvent::Event(event)) => map_callback(event),
}) })
@@ -1313,8 +1307,7 @@ mod tests {
let completion_error: LanguageModelCompletionError = api_error.into(); let completion_error: LanguageModelCompletionError = api_error.into();
match completion_error { match completion_error {
LanguageModelCompletionError::ApiInternalServerError { provider, message } => { LanguageModelCompletionError::ApiInternalServerError { message } => {
assert_eq!(provider, PROVIDER_NAME);
assert_eq!(message, "Regular internal server error"); assert_eq!(message, "Regular internal server error");
} }
_ => panic!( _ => panic!(
@@ -1362,9 +1355,7 @@ mod tests {
let completion_error: LanguageModelCompletionError = api_error.into(); let completion_error: LanguageModelCompletionError = api_error.into();
match completion_error { match completion_error {
LanguageModelCompletionError::ApiInternalServerError { provider, .. } => { LanguageModelCompletionError::ApiInternalServerError { .. } => {}
assert_eq!(provider, PROVIDER_NAME);
}
_ => panic!( _ => panic!(
"Expected ApiInternalServerError for invalid JSON, got: {:?}", "Expected ApiInternalServerError for invalid JSON, got: {:?}",
completion_error completion_error

View File

@@ -422,14 +422,12 @@ pub fn map_to_language_model_completion_events(
} }
if let Some(usage) = event.usage { if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
TokenUsage { input_tokens: usage.prompt_tokens,
input_tokens: usage.prompt_tokens, output_tokens: usage.completion_tokens,
output_tokens: usage.completion_tokens, cache_creation_input_tokens: 0,
cache_creation_input_tokens: 0, cache_read_input_tokens: 0,
cache_read_input_tokens: 0, })));
},
)));
} }
match choice.finish_reason.as_deref() { match choice.finish_reason.as_deref() {
@@ -610,7 +608,7 @@ impl CopilotResponsesEventMapper {
copilot::copilot_responses::StreamEvent::Completed { response } => { copilot::copilot_responses::StreamEvent::Completed { response } => {
let mut events = Vec::new(); let mut events = Vec::new();
if let Some(usage) = response.usage { if let Some(usage) = response.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.input_tokens.unwrap_or(0), input_tokens: usage.input_tokens.unwrap_or(0),
output_tokens: usage.output_tokens.unwrap_or(0), output_tokens: usage.output_tokens.unwrap_or(0),
cache_creation_input_tokens: 0, cache_creation_input_tokens: 0,
@@ -643,7 +641,7 @@ impl CopilotResponsesEventMapper {
let mut events = Vec::new(); let mut events = Vec::new();
if let Some(usage) = response.usage { if let Some(usage) = response.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.input_tokens.unwrap_or(0), input_tokens: usage.input_tokens.unwrap_or(0),
output_tokens: usage.output_tokens.unwrap_or(0), output_tokens: usage.output_tokens.unwrap_or(0),
cache_creation_input_tokens: 0, cache_creation_input_tokens: 0,
@@ -655,7 +653,6 @@ impl CopilotResponsesEventMapper {
} }
copilot::copilot_responses::StreamEvent::Failed { response } => { copilot::copilot_responses::StreamEvent::Failed { response } => {
let provider = PROVIDER_NAME;
let (status_code, message) = match response.error { let (status_code, message) = match response.error {
Some(error) => { Some(error) => {
let status_code = StatusCode::from_str(&error.code) let status_code = StatusCode::from_str(&error.code)
@@ -668,7 +665,6 @@ impl CopilotResponsesEventMapper {
), ),
}; };
vec![Err(LanguageModelCompletionError::HttpResponseError { vec![Err(LanguageModelCompletionError::HttpResponseError {
provider,
status_code, status_code,
message, message,
})] })]
@@ -1099,7 +1095,7 @@ mod tests {
)); ));
assert!(matches!( assert!(matches!(
mapped[2], mapped[2],
LanguageModelCompletionEvent::UsageUpdate(TokenUsage { LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: 5, input_tokens: 5,
output_tokens: 3, output_tokens: 3,
.. ..
@@ -1207,7 +1203,7 @@ mod tests {
let mapped = map_events(events); let mapped = map_events(events);
assert!(matches!( assert!(matches!(
mapped[0], mapped[0],
LanguageModelCompletionEvent::UsageUpdate(TokenUsage { LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: 10, input_tokens: 10,
output_tokens: 0, output_tokens: 0,
.. ..

View File

@@ -224,9 +224,7 @@ impl DeepSeekLanguageModel {
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
let Some(api_key) = api_key else { let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { return Err(LanguageModelCompletionError::NoApiKey);
provider: PROVIDER_NAME,
});
}; };
let request = let request =
deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request); deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
@@ -479,7 +477,7 @@ impl DeepSeekEventMapper {
} }
if let Some(usage) = event.usage { if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens, input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens, output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0, cache_creation_input_tokens: 0,

View File

@@ -350,10 +350,7 @@ impl LanguageModel for GoogleLanguageModel {
async move { async move {
let Some(api_key) = api_key else { let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { return Err(LanguageModelCompletionError::NoApiKey.into());
provider: PROVIDER_NAME,
}
.into());
}; };
let response = google_ai::count_tokens( let response = google_ai::count_tokens(
http_client.as_ref(), http_client.as_ref(),
@@ -608,9 +605,9 @@ impl GoogleEventMapper {
let mut wants_to_use_tool = false; let mut wants_to_use_tool = false;
if let Some(usage_metadata) = event.usage_metadata { if let Some(usage_metadata) = event.usage_metadata {
update_usage(&mut self.usage, &usage_metadata); update_usage(&mut self.usage, &usage_metadata);
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate( events.push(Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
convert_usage(&self.usage), &self.usage,
))) ))))
} }
if let Some(prompt_feedback) = event.prompt_feedback if let Some(prompt_feedback) = event.prompt_feedback

View File

@@ -547,7 +547,7 @@ impl LmStudioEventMapper {
} }
if let Some(usage) = event.usage { if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens, input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens, output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0, cache_creation_input_tokens: 0,

View File

@@ -291,9 +291,7 @@ impl MistralLanguageModel {
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
let Some(api_key) = api_key else { let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { return Err(LanguageModelCompletionError::NoApiKey);
provider: PROVIDER_NAME,
});
}; };
let request = let request =
mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request); mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
@@ -672,7 +670,7 @@ impl MistralEventMapper {
} }
if let Some(usage) = event.usage { if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens, input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens, output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0, cache_creation_input_tokens: 0,

View File

@@ -603,7 +603,7 @@ fn map_to_language_model_completion_events(
}; };
if delta.done { if delta.done {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: delta.prompt_eval_count.unwrap_or(0), input_tokens: delta.prompt_eval_count.unwrap_or(0),
output_tokens: delta.eval_count.unwrap_or(0), output_tokens: delta.eval_count.unwrap_or(0),
cache_creation_input_tokens: 0, cache_creation_input_tokens: 0,

View File

@@ -228,7 +228,7 @@ impl OpenAiLanguageModel {
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
let provider = PROVIDER_NAME; let provider = PROVIDER_NAME;
let Some(api_key) = api_key else { let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider }); return Err(LanguageModelCompletionError::NoApiKey);
}; };
let request = stream_completion( let request = stream_completion(
http_client.as_ref(), http_client.as_ref(),
@@ -534,7 +534,7 @@ impl OpenAiEventMapper {
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> { ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
let mut events = Vec::new(); let mut events = Vec::new();
if let Some(usage) = event.usage { if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens, input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens, output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0, cache_creation_input_tokens: 0,

View File

@@ -227,7 +227,7 @@ impl OpenAiCompatibleLanguageModel {
let provider = self.provider_name.clone(); let provider = self.provider_name.clone();
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
let Some(api_key) = api_key else { let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider }); return Err(LanguageModelCompletionError::NoApiKey);
}; };
let request = stream_completion( let request = stream_completion(
http_client.as_ref(), http_client.as_ref(),

View File

@@ -84,9 +84,7 @@ impl State {
let http_client = self.http_client.clone(); let http_client = self.http_client.clone();
let api_url = OpenRouterLanguageModelProvider::api_url(cx); let api_url = OpenRouterLanguageModelProvider::api_url(cx);
let Some(api_key) = self.api_key_state.key(&api_url) else { let Some(api_key) = self.api_key_state.key(&api_url) else {
return Task::ready(Err(LanguageModelCompletionError::NoApiKey { return Task::ready(Err(LanguageModelCompletionError::NoApiKey));
provider: PROVIDER_NAME,
}));
}; };
cx.spawn(async move |this, cx| { cx.spawn(async move |this, cx| {
let models = list_models(http_client.as_ref(), &api_url, &api_key) let models = list_models(http_client.as_ref(), &api_url, &api_key)
@@ -288,9 +286,7 @@ impl OpenRouterLanguageModel {
async move { async move {
let Some(api_key) = api_key else { let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { return Err(LanguageModelCompletionError::NoApiKey);
provider: PROVIDER_NAME,
});
}; };
let request = let request =
open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request); open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
@@ -613,7 +609,7 @@ impl OpenRouterEventMapper {
} }
if let Some(usage) = event.usage { if let Some(usage) = event.usage {
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage { events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
input_tokens: usage.prompt_tokens, input_tokens: usage.prompt_tokens,
output_tokens: usage.completion_tokens, output_tokens: usage.completion_tokens,
cache_creation_input_tokens: 0, cache_creation_input_tokens: 0,

View File

@@ -222,7 +222,7 @@ impl VercelLanguageModel {
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
let provider = PROVIDER_NAME; let provider = PROVIDER_NAME;
let Some(api_key) = api_key else { let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider }); return Err(LanguageModelCompletionError::NoApiKey);
}; };
let request = open_ai::stream_completion( let request = open_ai::stream_completion(
http_client.as_ref(), http_client.as_ref(),

View File

@@ -230,7 +230,7 @@ impl XAiLanguageModel {
let future = self.request_limiter.stream(async move { let future = self.request_limiter.stream(async move {
let provider = PROVIDER_NAME; let provider = PROVIDER_NAME;
let Some(api_key) = api_key else { let Some(api_key) = api_key else {
return Err(LanguageModelCompletionError::NoApiKey { provider }); return Err(LanguageModelCompletionError::NoApiKey);
}; };
let request = open_ai::stream_completion( let request = open_ai::stream_completion(
http_client.as_ref(), http_client.as_ref(),