Compare commits
5 Commits
ex
...
simplify-e
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5b74ddf020 | ||
|
|
54f20ae5d5 | ||
|
|
811efa45d0 | ||
|
|
74501e0936 | ||
|
|
2ad8bd00ce |
@@ -21,9 +21,9 @@ use gpui::{
|
|||||||
use indoc::indoc;
|
use indoc::indoc;
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
|
||||||
LanguageModelProviderName, LanguageModelRegistry, LanguageModelRequest,
|
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||||
LanguageModelRequestMessage, LanguageModelToolResult, LanguageModelToolSchemaFormat,
|
LanguageModelToolResult, LanguageModelToolSchemaFormat, LanguageModelToolUse, MessageContent,
|
||||||
LanguageModelToolUse, MessageContent, Role, StopReason, fake_provider::FakeLanguageModel,
|
Role, StopReason, fake_provider::FakeLanguageModel,
|
||||||
};
|
};
|
||||||
use pretty_assertions::assert_eq;
|
use pretty_assertions::assert_eq;
|
||||||
use project::{
|
use project::{
|
||||||
@@ -664,9 +664,7 @@ async fn test_resume_after_tool_use_limit(cx: &mut TestAppContext) {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Simulate reaching tool use limit.
|
// Simulate reaching tool use limit.
|
||||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
|
||||||
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
|
|
||||||
));
|
|
||||||
fake_model.end_last_completion_stream();
|
fake_model.end_last_completion_stream();
|
||||||
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
|
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
@@ -749,9 +747,7 @@ async fn test_send_after_tool_use_limit(cx: &mut TestAppContext) {
|
|||||||
};
|
};
|
||||||
fake_model
|
fake_model
|
||||||
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUse(tool_use.clone()));
|
||||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::StatusUpdate(
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::ToolUseLimitReached);
|
||||||
cloud_llm_client::CompletionRequestStatus::ToolUseLimitReached,
|
|
||||||
));
|
|
||||||
fake_model.end_last_completion_stream();
|
fake_model.end_last_completion_stream();
|
||||||
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
|
let last_event = events.collect::<Vec<_>>().await.pop().unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
@@ -1533,7 +1529,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
|
|||||||
});
|
});
|
||||||
|
|
||||||
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
fake_model.send_last_completion_stream_text_chunk("Hey!");
|
||||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
|
||||||
language_model::TokenUsage {
|
language_model::TokenUsage {
|
||||||
input_tokens: 32_000,
|
input_tokens: 32_000,
|
||||||
output_tokens: 16_000,
|
output_tokens: 16_000,
|
||||||
@@ -1591,7 +1587,7 @@ async fn test_truncate_first_message(cx: &mut TestAppContext) {
|
|||||||
});
|
});
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
|
fake_model.send_last_completion_stream_text_chunk("Ahoy!");
|
||||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
|
||||||
language_model::TokenUsage {
|
language_model::TokenUsage {
|
||||||
input_tokens: 40_000,
|
input_tokens: 40_000,
|
||||||
output_tokens: 20_000,
|
output_tokens: 20_000,
|
||||||
@@ -1636,7 +1632,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
|
|||||||
.unwrap();
|
.unwrap();
|
||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
fake_model.send_last_completion_stream_text_chunk("Message 1 response");
|
fake_model.send_last_completion_stream_text_chunk("Message 1 response");
|
||||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
|
||||||
language_model::TokenUsage {
|
language_model::TokenUsage {
|
||||||
input_tokens: 32_000,
|
input_tokens: 32_000,
|
||||||
output_tokens: 16_000,
|
output_tokens: 16_000,
|
||||||
@@ -1683,7 +1679,7 @@ async fn test_truncate_second_message(cx: &mut TestAppContext) {
|
|||||||
cx.run_until_parked();
|
cx.run_until_parked();
|
||||||
|
|
||||||
fake_model.send_last_completion_stream_text_chunk("Message 2 response");
|
fake_model.send_last_completion_stream_text_chunk("Message 2 response");
|
||||||
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::UsageUpdate(
|
fake_model.send_last_completion_stream_event(LanguageModelCompletionEvent::TokenUsage(
|
||||||
language_model::TokenUsage {
|
language_model::TokenUsage {
|
||||||
input_tokens: 40_000,
|
input_tokens: 40_000,
|
||||||
output_tokens: 20_000,
|
output_tokens: 20_000,
|
||||||
@@ -2159,7 +2155,6 @@ async fn test_send_retry_on_error(cx: &mut TestAppContext) {
|
|||||||
|
|
||||||
fake_model.send_last_completion_stream_text_chunk("Hey,");
|
fake_model.send_last_completion_stream_text_chunk("Hey,");
|
||||||
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
|
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
|
||||||
provider: LanguageModelProviderName::new("Anthropic"),
|
|
||||||
retry_after: Some(Duration::from_secs(3)),
|
retry_after: Some(Duration::from_secs(3)),
|
||||||
});
|
});
|
||||||
fake_model.end_last_completion_stream();
|
fake_model.end_last_completion_stream();
|
||||||
@@ -2235,7 +2230,6 @@ async fn test_send_retry_finishes_tool_calls_on_error(cx: &mut TestAppContext) {
|
|||||||
tool_use_1.clone(),
|
tool_use_1.clone(),
|
||||||
));
|
));
|
||||||
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
|
fake_model.send_last_completion_stream_error(LanguageModelCompletionError::ServerOverloaded {
|
||||||
provider: LanguageModelProviderName::new("Anthropic"),
|
|
||||||
retry_after: Some(Duration::from_secs(3)),
|
retry_after: Some(Duration::from_secs(3)),
|
||||||
});
|
});
|
||||||
fake_model.end_last_completion_stream();
|
fake_model.end_last_completion_stream();
|
||||||
@@ -2302,7 +2296,6 @@ async fn test_send_max_retries_exceeded(cx: &mut TestAppContext) {
|
|||||||
for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
|
for _ in 0..crate::thread::MAX_RETRY_ATTEMPTS + 1 {
|
||||||
fake_model.send_last_completion_stream_error(
|
fake_model.send_last_completion_stream_error(
|
||||||
LanguageModelCompletionError::ServerOverloaded {
|
LanguageModelCompletionError::ServerOverloaded {
|
||||||
provider: LanguageModelProviderName::new("Anthropic"),
|
|
||||||
retry_after: Some(Duration::from_secs(3)),
|
retry_after: Some(Duration::from_secs(3)),
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ use agent_settings::{
|
|||||||
use anyhow::{Context as _, Result, anyhow};
|
use anyhow::{Context as _, Result, anyhow};
|
||||||
use chrono::{DateTime, Utc};
|
use chrono::{DateTime, Utc};
|
||||||
use client::{ModelRequestUsage, RequestUsage, UserStore};
|
use client::{ModelRequestUsage, RequestUsage, UserStore};
|
||||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, Plan, UsageLimit};
|
use cloud_llm_client::{CompletionIntent, Plan, UsageLimit};
|
||||||
use collections::{HashMap, HashSet, IndexMap};
|
use collections::{HashMap, HashSet, IndexMap};
|
||||||
use fs::Fs;
|
use fs::Fs;
|
||||||
use futures::stream;
|
use futures::stream;
|
||||||
@@ -30,11 +30,11 @@ use gpui::{
|
|||||||
};
|
};
|
||||||
use language_model::{
|
use language_model::{
|
||||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
|
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelExt,
|
||||||
LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelRegistry,
|
LanguageModelId, LanguageModelImage, LanguageModelProviderId, LanguageModelProviderName,
|
||||||
LanguageModelRequest, LanguageModelRequestMessage, LanguageModelRequestTool,
|
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||||
LanguageModelToolResult, LanguageModelToolResultContent, LanguageModelToolSchemaFormat,
|
LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolResultContent,
|
||||||
LanguageModelToolUse, LanguageModelToolUseId, Role, SelectedModel, StopReason, TokenUsage,
|
LanguageModelToolSchemaFormat, LanguageModelToolUse, LanguageModelToolUseId, Role,
|
||||||
ZED_CLOUD_PROVIDER_ID,
|
SelectedModel, StopReason, TokenUsage, ZED_CLOUD_PROVIDER_ID,
|
||||||
};
|
};
|
||||||
use project::Project;
|
use project::Project;
|
||||||
use prompt_store::ProjectContext;
|
use prompt_store::ProjectContext;
|
||||||
@@ -1295,9 +1295,10 @@ impl Thread {
|
|||||||
|
|
||||||
if let Some(error) = error {
|
if let Some(error) = error {
|
||||||
attempt += 1;
|
attempt += 1;
|
||||||
|
let provider = model.upstream_provider_name();
|
||||||
let retry = this.update(cx, |this, cx| {
|
let retry = this.update(cx, |this, cx| {
|
||||||
let user_store = this.user_store.read(cx);
|
let user_store = this.user_store.read(cx);
|
||||||
this.handle_completion_error(error, attempt, user_store.plan())
|
this.handle_completion_error(provider, error, attempt, user_store.plan())
|
||||||
})??;
|
})??;
|
||||||
let timer = cx.background_executor().timer(retry.duration);
|
let timer = cx.background_executor().timer(retry.duration);
|
||||||
event_stream.send_retry(retry);
|
event_stream.send_retry(retry);
|
||||||
@@ -1323,6 +1324,7 @@ impl Thread {
|
|||||||
|
|
||||||
fn handle_completion_error(
|
fn handle_completion_error(
|
||||||
&mut self,
|
&mut self,
|
||||||
|
provider: LanguageModelProviderName,
|
||||||
error: LanguageModelCompletionError,
|
error: LanguageModelCompletionError,
|
||||||
attempt: u8,
|
attempt: u8,
|
||||||
plan: Option<Plan>,
|
plan: Option<Plan>,
|
||||||
@@ -1389,7 +1391,7 @@ impl Thread {
|
|||||||
use LanguageModelCompletionEvent::*;
|
use LanguageModelCompletionEvent::*;
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
StartMessage { .. } => {
|
LanguageModelCompletionEvent::StartMessage { .. } => {
|
||||||
self.flush_pending_message(cx);
|
self.flush_pending_message(cx);
|
||||||
self.pending_message = Some(AgentMessage::default());
|
self.pending_message = Some(AgentMessage::default());
|
||||||
}
|
}
|
||||||
@@ -1416,7 +1418,7 @@ impl Thread {
|
|||||||
),
|
),
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
UsageUpdate(usage) => {
|
TokenUsage(usage) => {
|
||||||
telemetry::event!(
|
telemetry::event!(
|
||||||
"Agent Thread Completion Usage Updated",
|
"Agent Thread Completion Usage Updated",
|
||||||
thread_id = self.id.to_string(),
|
thread_id = self.id.to_string(),
|
||||||
@@ -1430,20 +1432,16 @@ impl Thread {
|
|||||||
);
|
);
|
||||||
self.update_token_usage(usage, cx);
|
self.update_token_usage(usage, cx);
|
||||||
}
|
}
|
||||||
StatusUpdate(CompletionRequestStatus::UsageUpdated { amount, limit }) => {
|
RequestUsage { amount, limit } => {
|
||||||
self.update_model_request_usage(amount, limit, cx);
|
self.update_model_request_usage(amount, limit, cx);
|
||||||
}
|
}
|
||||||
StatusUpdate(
|
ToolUseLimitReached => {
|
||||||
CompletionRequestStatus::Started
|
|
||||||
| CompletionRequestStatus::Queued { .. }
|
|
||||||
| CompletionRequestStatus::Failed { .. },
|
|
||||||
) => {}
|
|
||||||
StatusUpdate(CompletionRequestStatus::ToolUseLimitReached) => {
|
|
||||||
self.tool_use_limit_reached = true;
|
self.tool_use_limit_reached = true;
|
||||||
}
|
}
|
||||||
Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
|
Stop(StopReason::Refusal) => return Err(CompletionError::Refusal.into()),
|
||||||
Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
|
Stop(StopReason::MaxTokens) => return Err(CompletionError::MaxTokens.into()),
|
||||||
Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
|
Stop(StopReason::ToolUse | StopReason::EndTurn) => {}
|
||||||
|
Started | Queued { .. } => {}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(None)
|
Ok(None)
|
||||||
@@ -1687,9 +1685,7 @@ impl Thread {
|
|||||||
let event = event.log_err()?;
|
let event = event.log_err()?;
|
||||||
let text = match event {
|
let text = match event {
|
||||||
LanguageModelCompletionEvent::Text(text) => text,
|
LanguageModelCompletionEvent::Text(text) => text,
|
||||||
LanguageModelCompletionEvent::StatusUpdate(
|
LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
|
||||||
CompletionRequestStatus::UsageUpdated { amount, limit },
|
|
||||||
) => {
|
|
||||||
this.update(cx, |thread, cx| {
|
this.update(cx, |thread, cx| {
|
||||||
thread.update_model_request_usage(amount, limit, cx);
|
thread.update_model_request_usage(amount, limit, cx);
|
||||||
})
|
})
|
||||||
@@ -1753,9 +1749,7 @@ impl Thread {
|
|||||||
let event = event?;
|
let event = event?;
|
||||||
let text = match event {
|
let text = match event {
|
||||||
LanguageModelCompletionEvent::Text(text) => text,
|
LanguageModelCompletionEvent::Text(text) => text,
|
||||||
LanguageModelCompletionEvent::StatusUpdate(
|
LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
|
||||||
CompletionRequestStatus::UsageUpdated { amount, limit },
|
|
||||||
) => {
|
|
||||||
this.update(cx, |thread, cx| {
|
this.update(cx, |thread, cx| {
|
||||||
thread.update_model_request_usage(amount, limit, cx);
|
thread.update_model_request_usage(amount, limit, cx);
|
||||||
})?;
|
})?;
|
||||||
|
|||||||
@@ -7,9 +7,10 @@ use assistant_slash_command::{
|
|||||||
use assistant_slash_commands::FileCommandMetadata;
|
use assistant_slash_commands::FileCommandMetadata;
|
||||||
use client::{self, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry};
|
use client::{self, ModelRequestUsage, RequestUsage, proto, telemetry::Telemetry};
|
||||||
use clock::ReplicaId;
|
use clock::ReplicaId;
|
||||||
use cloud_llm_client::{CompletionIntent, CompletionRequestStatus, UsageLimit};
|
use cloud_llm_client::{CompletionIntent, UsageLimit};
|
||||||
use collections::{HashMap, HashSet};
|
use collections::{HashMap, HashSet};
|
||||||
use fs::{Fs, RenameOptions};
|
use fs::{Fs, RenameOptions};
|
||||||
|
|
||||||
use futures::{FutureExt, StreamExt, future::Shared};
|
use futures::{FutureExt, StreamExt, future::Shared};
|
||||||
use gpui::{
|
use gpui::{
|
||||||
App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription,
|
App, AppContext as _, Context, Entity, EventEmitter, RenderImage, SharedString, Subscription,
|
||||||
@@ -2073,14 +2074,15 @@ impl TextThread {
|
|||||||
});
|
});
|
||||||
|
|
||||||
match event {
|
match event {
|
||||||
LanguageModelCompletionEvent::StatusUpdate(status_update) => {
|
LanguageModelCompletionEvent::Started |
|
||||||
if let CompletionRequestStatus::UsageUpdated { amount, limit } = status_update {
|
LanguageModelCompletionEvent::Queued {..} |
|
||||||
this.update_model_request_usage(
|
LanguageModelCompletionEvent::ToolUseLimitReached { .. } => {}
|
||||||
amount as u32,
|
LanguageModelCompletionEvent::RequestUsage { amount, limit } => {
|
||||||
limit,
|
this.update_model_request_usage(
|
||||||
cx,
|
amount as u32,
|
||||||
);
|
limit,
|
||||||
}
|
cx,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
LanguageModelCompletionEvent::StartMessage { .. } => {}
|
LanguageModelCompletionEvent::StartMessage { .. } => {}
|
||||||
LanguageModelCompletionEvent::Stop(reason) => {
|
LanguageModelCompletionEvent::Stop(reason) => {
|
||||||
@@ -2142,7 +2144,7 @@ impl TextThread {
|
|||||||
}
|
}
|
||||||
LanguageModelCompletionEvent::ToolUse(_) |
|
LanguageModelCompletionEvent::ToolUse(_) |
|
||||||
LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
|
LanguageModelCompletionEvent::ToolUseJsonParseError { .. } |
|
||||||
LanguageModelCompletionEvent::UsageUpdate(_) => {}
|
LanguageModelCompletionEvent::TokenUsage(_) => {}
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
@@ -1250,9 +1250,12 @@ pub fn response_events_to_markdown(
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
Ok(
|
Ok(
|
||||||
LanguageModelCompletionEvent::UsageUpdate(_)
|
LanguageModelCompletionEvent::TokenUsage(_)
|
||||||
|
| LanguageModelCompletionEvent::ToolUseLimitReached
|
||||||
| LanguageModelCompletionEvent::StartMessage { .. }
|
| LanguageModelCompletionEvent::StartMessage { .. }
|
||||||
| LanguageModelCompletionEvent::StatusUpdate { .. },
|
| LanguageModelCompletionEvent::RequestUsage { .. }
|
||||||
|
| LanguageModelCompletionEvent::Queued { .. }
|
||||||
|
| LanguageModelCompletionEvent::Started,
|
||||||
) => {}
|
) => {}
|
||||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
json_parse_error, ..
|
json_parse_error, ..
|
||||||
@@ -1335,11 +1338,14 @@ impl ThreadDialog {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Skip these
|
// Skip these
|
||||||
Ok(LanguageModelCompletionEvent::UsageUpdate(_))
|
Ok(LanguageModelCompletionEvent::TokenUsage(_))
|
||||||
| Ok(LanguageModelCompletionEvent::RedactedThinking { .. })
|
| Ok(LanguageModelCompletionEvent::RedactedThinking { .. })
|
||||||
| Ok(LanguageModelCompletionEvent::StatusUpdate { .. })
|
|
||||||
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
|
| Ok(LanguageModelCompletionEvent::StartMessage { .. })
|
||||||
| Ok(LanguageModelCompletionEvent::Stop(_)) => {}
|
| Ok(LanguageModelCompletionEvent::Stop(_))
|
||||||
|
| Ok(LanguageModelCompletionEvent::Queued { .. })
|
||||||
|
| Ok(LanguageModelCompletionEvent::Started)
|
||||||
|
| Ok(LanguageModelCompletionEvent::RequestUsage { .. })
|
||||||
|
| Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => {}
|
||||||
|
|
||||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
json_parse_error,
|
json_parse_error,
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ pub mod fake_provider;
|
|||||||
use anthropic::{AnthropicError, parse_prompt_too_long};
|
use anthropic::{AnthropicError, parse_prompt_too_long};
|
||||||
use anyhow::{Result, anyhow};
|
use anyhow::{Result, anyhow};
|
||||||
use client::Client;
|
use client::Client;
|
||||||
use cloud_llm_client::{CompletionMode, CompletionRequestStatus};
|
use cloud_llm_client::{CompletionMode, CompletionRequestStatus, UsageLimit};
|
||||||
use futures::FutureExt;
|
use futures::FutureExt;
|
||||||
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
use futures::{StreamExt, future::BoxFuture, stream::BoxStream};
|
||||||
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
|
use gpui::{AnyView, App, AsyncApp, SharedString, Task, Window};
|
||||||
@@ -70,7 +70,15 @@ pub fn init_settings(cx: &mut App) {
|
|||||||
/// A completion event from a language model.
|
/// A completion event from a language model.
|
||||||
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
#[derive(Debug, PartialEq, Clone, Serialize, Deserialize)]
|
||||||
pub enum LanguageModelCompletionEvent {
|
pub enum LanguageModelCompletionEvent {
|
||||||
StatusUpdate(CompletionRequestStatus),
|
Queued {
|
||||||
|
position: usize,
|
||||||
|
},
|
||||||
|
Started,
|
||||||
|
RequestUsage {
|
||||||
|
amount: usize,
|
||||||
|
limit: UsageLimit,
|
||||||
|
},
|
||||||
|
ToolUseLimitReached,
|
||||||
Stop(StopReason),
|
Stop(StopReason),
|
||||||
Text(String),
|
Text(String),
|
||||||
Thinking {
|
Thinking {
|
||||||
@@ -90,88 +98,93 @@ pub enum LanguageModelCompletionEvent {
|
|||||||
StartMessage {
|
StartMessage {
|
||||||
message_id: String,
|
message_id: String,
|
||||||
},
|
},
|
||||||
UsageUpdate(TokenUsage),
|
TokenUsage(TokenUsage),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LanguageModelCompletionEvent {
|
||||||
|
pub fn from_completion_request_status(
|
||||||
|
status: CompletionRequestStatus,
|
||||||
|
) -> Result<Self, LanguageModelCompletionError> {
|
||||||
|
match status {
|
||||||
|
CompletionRequestStatus::Queued { position } => {
|
||||||
|
Ok(LanguageModelCompletionEvent::Queued { position })
|
||||||
|
}
|
||||||
|
CompletionRequestStatus::Started => Ok(LanguageModelCompletionEvent::Started),
|
||||||
|
CompletionRequestStatus::UsageUpdated { amount, limit } => {
|
||||||
|
Ok(LanguageModelCompletionEvent::RequestUsage { amount, limit })
|
||||||
|
}
|
||||||
|
CompletionRequestStatus::ToolUseLimitReached => {
|
||||||
|
Ok(LanguageModelCompletionEvent::ToolUseLimitReached)
|
||||||
|
}
|
||||||
|
CompletionRequestStatus::Failed {
|
||||||
|
code,
|
||||||
|
message,
|
||||||
|
request_id: _,
|
||||||
|
retry_after,
|
||||||
|
} => Err(LanguageModelCompletionError::from_cloud_failure(
|
||||||
|
code,
|
||||||
|
message,
|
||||||
|
retry_after.map(Duration::from_secs_f64),
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum LanguageModelCompletionError {
|
pub enum LanguageModelCompletionError {
|
||||||
#[error("prompt too large for context window")]
|
#[error("prompt too large for context window")]
|
||||||
PromptTooLarge { tokens: Option<u64> },
|
PromptTooLarge { tokens: Option<u64> },
|
||||||
#[error("missing {provider} API key")]
|
#[error("missing API key")]
|
||||||
NoApiKey { provider: LanguageModelProviderName },
|
NoApiKey,
|
||||||
#[error("{provider}'s API rate limit exceeded")]
|
#[error("API rate limit exceeded")]
|
||||||
RateLimitExceeded {
|
RateLimitExceeded { retry_after: Option<Duration> },
|
||||||
provider: LanguageModelProviderName,
|
#[error("API servers are overloaded right now")]
|
||||||
retry_after: Option<Duration>,
|
ServerOverloaded { retry_after: Option<Duration> },
|
||||||
},
|
#[error("API server reported an internal server error: {message}")]
|
||||||
#[error("{provider}'s API servers are overloaded right now")]
|
ApiInternalServerError { message: String },
|
||||||
ServerOverloaded {
|
|
||||||
provider: LanguageModelProviderName,
|
|
||||||
retry_after: Option<Duration>,
|
|
||||||
},
|
|
||||||
#[error("{provider}'s API server reported an internal server error: {message}")]
|
|
||||||
ApiInternalServerError {
|
|
||||||
provider: LanguageModelProviderName,
|
|
||||||
message: String,
|
|
||||||
},
|
|
||||||
#[error("{message}")]
|
#[error("{message}")]
|
||||||
UpstreamProviderError {
|
UpstreamProviderError {
|
||||||
message: String,
|
message: String,
|
||||||
status: StatusCode,
|
status: StatusCode,
|
||||||
retry_after: Option<Duration>,
|
retry_after: Option<Duration>,
|
||||||
},
|
},
|
||||||
#[error("HTTP response error from {provider}'s API: status {status_code} - {message:?}")]
|
#[error("HTTP response error from API: status {status_code} - {message:?}")]
|
||||||
HttpResponseError {
|
HttpResponseError {
|
||||||
provider: LanguageModelProviderName,
|
|
||||||
status_code: StatusCode,
|
status_code: StatusCode,
|
||||||
message: String,
|
message: String,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Client errors
|
// Client errors
|
||||||
#[error("invalid request format to {provider}'s API: {message}")]
|
#[error("invalid request format to API: {message}")]
|
||||||
BadRequestFormat {
|
BadRequestFormat { message: String },
|
||||||
provider: LanguageModelProviderName,
|
#[error("authentication error with API: {message}")]
|
||||||
message: String,
|
AuthenticationError { message: String },
|
||||||
},
|
#[error("Permission error with API: {message}")]
|
||||||
#[error("authentication error with {provider}'s API: {message}")]
|
PermissionError { message: String },
|
||||||
AuthenticationError {
|
|
||||||
provider: LanguageModelProviderName,
|
|
||||||
message: String,
|
|
||||||
},
|
|
||||||
#[error("Permission error with {provider}'s API: {message}")]
|
|
||||||
PermissionError {
|
|
||||||
provider: LanguageModelProviderName,
|
|
||||||
message: String,
|
|
||||||
},
|
|
||||||
#[error("language model provider API endpoint not found")]
|
#[error("language model provider API endpoint not found")]
|
||||||
ApiEndpointNotFound { provider: LanguageModelProviderName },
|
ApiEndpointNotFound,
|
||||||
#[error("I/O error reading response from {provider}'s API")]
|
#[error("I/O error reading response from API")]
|
||||||
ApiReadResponseError {
|
ApiReadResponseError {
|
||||||
provider: LanguageModelProviderName,
|
|
||||||
#[source]
|
#[source]
|
||||||
error: io::Error,
|
error: io::Error,
|
||||||
},
|
},
|
||||||
#[error("error serializing request to {provider} API")]
|
#[error("error serializing request to API")]
|
||||||
SerializeRequest {
|
SerializeRequest {
|
||||||
provider: LanguageModelProviderName,
|
|
||||||
#[source]
|
#[source]
|
||||||
error: serde_json::Error,
|
error: serde_json::Error,
|
||||||
},
|
},
|
||||||
#[error("error building request body to {provider} API")]
|
#[error("error building request body to API")]
|
||||||
BuildRequestBody {
|
BuildRequestBody {
|
||||||
provider: LanguageModelProviderName,
|
|
||||||
#[source]
|
#[source]
|
||||||
error: http::Error,
|
error: http::Error,
|
||||||
},
|
},
|
||||||
#[error("error sending HTTP request to {provider} API")]
|
#[error("error sending HTTP request to API")]
|
||||||
HttpSend {
|
HttpSend {
|
||||||
provider: LanguageModelProviderName,
|
|
||||||
#[source]
|
#[source]
|
||||||
error: anyhow::Error,
|
error: anyhow::Error,
|
||||||
},
|
},
|
||||||
#[error("error deserializing {provider} API response")]
|
#[error("error deserializing API response")]
|
||||||
DeserializeResponse {
|
DeserializeResponse {
|
||||||
provider: LanguageModelProviderName,
|
|
||||||
#[source]
|
#[source]
|
||||||
error: serde_json::Error,
|
error: serde_json::Error,
|
||||||
},
|
},
|
||||||
@@ -182,6 +195,72 @@ pub enum LanguageModelCompletionError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl LanguageModelCompletionError {
|
impl LanguageModelCompletionError {
|
||||||
|
fn display_format(&self, provider: LanguageModelProviderName) {
|
||||||
|
// match self {
|
||||||
|
// #[error("prompt too large for context window")]
|
||||||
|
// PromptTooLarge { tokens: Option<u64> },
|
||||||
|
// #[error("missing API key")]
|
||||||
|
// NoApiKey,
|
||||||
|
// #[error("API rate limit exceeded")]
|
||||||
|
// RateLimitExceeded { retry_after: Option<Duration> },
|
||||||
|
// #[error("API servers are overloaded right now")]
|
||||||
|
// ServerOverloaded { retry_after: Option<Duration> },
|
||||||
|
// #[error("API server reported an internal server error: {message}")]
|
||||||
|
// ApiInternalServerError { message: String },
|
||||||
|
// #[error("{message}")]
|
||||||
|
// UpstreamProviderError {
|
||||||
|
// message: String,
|
||||||
|
// status: StatusCode,
|
||||||
|
// retry_after: Option<Duration>,
|
||||||
|
// },
|
||||||
|
// #[error("HTTP response error from API: status {status_code} - {message:?}")]
|
||||||
|
// HttpResponseError {
|
||||||
|
// status_code: StatusCode,
|
||||||
|
// message: String,
|
||||||
|
// },
|
||||||
|
|
||||||
|
// // Client errors
|
||||||
|
// #[error("invalid request format to API: {message}")]
|
||||||
|
// BadRequestFormat { message: String },
|
||||||
|
// #[error("authentication error with API: {message}")]
|
||||||
|
// AuthenticationError { message: String },
|
||||||
|
// #[error("Permission error with API: {message}")]
|
||||||
|
// PermissionError { message: String },
|
||||||
|
// #[error("language model provider API endpoint not found")]
|
||||||
|
// ApiEndpointNotFound,
|
||||||
|
// #[error("I/O error reading response from API")]
|
||||||
|
// ApiReadResponseError {
|
||||||
|
// #[source]
|
||||||
|
// error: io::Error,
|
||||||
|
// },
|
||||||
|
// #[error("error serializing request to API")]
|
||||||
|
// SerializeRequest {
|
||||||
|
// #[source]
|
||||||
|
// error: serde_json::Error,
|
||||||
|
// },
|
||||||
|
// #[error("error building request body to API")]
|
||||||
|
// BuildRequestBody {
|
||||||
|
// #[source]
|
||||||
|
// error: http::Error,
|
||||||
|
// },
|
||||||
|
// #[error("error sending HTTP request to API")]
|
||||||
|
// HttpSend {
|
||||||
|
// #[source]
|
||||||
|
// error: anyhow::Error,
|
||||||
|
// },
|
||||||
|
// #[error("error deserializing API response")]
|
||||||
|
// DeserializeResponse {
|
||||||
|
// #[source]
|
||||||
|
// error: serde_json::Error,
|
||||||
|
// },
|
||||||
|
|
||||||
|
// // TODO: Ideally this would be removed in favor of having a comprehensive list of errors.
|
||||||
|
// #[error(transparent)]
|
||||||
|
// Other(#[from] anyhow::Error),
|
||||||
|
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
|
fn parse_upstream_error_json(message: &str) -> Option<(StatusCode, String)> {
|
||||||
let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
|
let error_json = serde_json::from_str::<serde_json::Value>(message).ok()?;
|
||||||
let upstream_status = error_json
|
let upstream_status = error_json
|
||||||
@@ -198,7 +277,6 @@ impl LanguageModelCompletionError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_cloud_failure(
|
pub fn from_cloud_failure(
|
||||||
upstream_provider: LanguageModelProviderName,
|
|
||||||
code: String,
|
code: String,
|
||||||
message: String,
|
message: String,
|
||||||
retry_after: Option<Duration>,
|
retry_after: Option<Duration>,
|
||||||
@@ -214,58 +292,46 @@ impl LanguageModelCompletionError {
|
|||||||
if let Some((upstream_status, inner_message)) =
|
if let Some((upstream_status, inner_message)) =
|
||||||
Self::parse_upstream_error_json(&message)
|
Self::parse_upstream_error_json(&message)
|
||||||
{
|
{
|
||||||
return Self::from_http_status(
|
return Self::from_http_status(upstream_status, inner_message, retry_after);
|
||||||
upstream_provider,
|
|
||||||
upstream_status,
|
|
||||||
inner_message,
|
|
||||||
retry_after,
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
anyhow!("completion request failed, code: {code}, message: {message}").into()
|
Self::Other(anyhow!(
|
||||||
|
"completion request failed, code: {code}, message: {message}"
|
||||||
|
))
|
||||||
} else if let Some(status_code) = code
|
} else if let Some(status_code) = code
|
||||||
.strip_prefix("upstream_http_")
|
.strip_prefix("upstream_http_")
|
||||||
.and_then(|code| StatusCode::from_str(code).ok())
|
.and_then(|code| StatusCode::from_str(code).ok())
|
||||||
{
|
{
|
||||||
Self::from_http_status(upstream_provider, status_code, message, retry_after)
|
Self::from_http_status(status_code, message, retry_after)
|
||||||
} else if let Some(status_code) = code
|
} else if let Some(status_code) = code
|
||||||
.strip_prefix("http_")
|
.strip_prefix("http_")
|
||||||
.and_then(|code| StatusCode::from_str(code).ok())
|
.and_then(|code| StatusCode::from_str(code).ok())
|
||||||
{
|
{
|
||||||
Self::from_http_status(ZED_CLOUD_PROVIDER_NAME, status_code, message, retry_after)
|
Self::from_http_status(status_code, message, retry_after)
|
||||||
} else {
|
} else {
|
||||||
anyhow!("completion request failed, code: {code}, message: {message}").into()
|
Self::Other(anyhow!(
|
||||||
|
"completion request failed, code: {code}, message: {message}"
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_http_status(
|
pub fn from_http_status(
|
||||||
provider: LanguageModelProviderName,
|
|
||||||
status_code: StatusCode,
|
status_code: StatusCode,
|
||||||
message: String,
|
message: String,
|
||||||
retry_after: Option<Duration>,
|
retry_after: Option<Duration>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
match status_code {
|
match status_code {
|
||||||
StatusCode::BAD_REQUEST => Self::BadRequestFormat { provider, message },
|
StatusCode::BAD_REQUEST => Self::BadRequestFormat { message },
|
||||||
StatusCode::UNAUTHORIZED => Self::AuthenticationError { provider, message },
|
StatusCode::UNAUTHORIZED => Self::AuthenticationError { message },
|
||||||
StatusCode::FORBIDDEN => Self::PermissionError { provider, message },
|
StatusCode::FORBIDDEN => Self::PermissionError { message },
|
||||||
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound { provider },
|
StatusCode::NOT_FOUND => Self::ApiEndpointNotFound,
|
||||||
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
|
StatusCode::PAYLOAD_TOO_LARGE => Self::PromptTooLarge {
|
||||||
tokens: parse_prompt_too_long(&message),
|
tokens: parse_prompt_too_long(&message),
|
||||||
},
|
},
|
||||||
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded {
|
StatusCode::TOO_MANY_REQUESTS => Self::RateLimitExceeded { retry_after },
|
||||||
provider,
|
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { message },
|
||||||
retry_after,
|
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded { retry_after },
|
||||||
},
|
_ if status_code.as_u16() == 529 => Self::ServerOverloaded { retry_after },
|
||||||
StatusCode::INTERNAL_SERVER_ERROR => Self::ApiInternalServerError { provider, message },
|
|
||||||
StatusCode::SERVICE_UNAVAILABLE => Self::ServerOverloaded {
|
|
||||||
provider,
|
|
||||||
retry_after,
|
|
||||||
},
|
|
||||||
_ if status_code.as_u16() == 529 => Self::ServerOverloaded {
|
|
||||||
provider,
|
|
||||||
retry_after,
|
|
||||||
},
|
|
||||||
_ => Self::HttpResponseError {
|
_ => Self::HttpResponseError {
|
||||||
provider,
|
|
||||||
status_code,
|
status_code,
|
||||||
message,
|
message,
|
||||||
},
|
},
|
||||||
@@ -275,31 +341,25 @@ impl LanguageModelCompletionError {
|
|||||||
|
|
||||||
impl From<AnthropicError> for LanguageModelCompletionError {
|
impl From<AnthropicError> for LanguageModelCompletionError {
|
||||||
fn from(error: AnthropicError) -> Self {
|
fn from(error: AnthropicError) -> Self {
|
||||||
let provider = ANTHROPIC_PROVIDER_NAME;
|
|
||||||
match error {
|
match error {
|
||||||
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
AnthropicError::SerializeRequest(error) => Self::SerializeRequest { error },
|
||||||
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
AnthropicError::BuildRequestBody(error) => Self::BuildRequestBody { error },
|
||||||
AnthropicError::HttpSend(error) => Self::HttpSend { provider, error },
|
AnthropicError::HttpSend(error) => Self::HttpSend { error },
|
||||||
AnthropicError::DeserializeResponse(error) => {
|
AnthropicError::DeserializeResponse(error) => Self::DeserializeResponse { error },
|
||||||
Self::DeserializeResponse { provider, error }
|
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { error },
|
||||||
}
|
|
||||||
AnthropicError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
|
||||||
AnthropicError::HttpResponseError {
|
AnthropicError::HttpResponseError {
|
||||||
status_code,
|
status_code,
|
||||||
message,
|
message,
|
||||||
} => Self::HttpResponseError {
|
} => Self::HttpResponseError {
|
||||||
provider,
|
|
||||||
status_code,
|
status_code,
|
||||||
message,
|
message,
|
||||||
},
|
},
|
||||||
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
AnthropicError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||||
provider,
|
|
||||||
retry_after: Some(retry_after),
|
retry_after: Some(retry_after),
|
||||||
},
|
},
|
||||||
AnthropicError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
AnthropicError::ServerOverloaded { retry_after } => {
|
||||||
provider,
|
Self::ServerOverloaded { retry_after }
|
||||||
retry_after,
|
}
|
||||||
},
|
|
||||||
AnthropicError::ApiError(api_error) => api_error.into(),
|
AnthropicError::ApiError(api_error) => api_error.into(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -308,37 +368,26 @@ impl From<AnthropicError> for LanguageModelCompletionError {
|
|||||||
impl From<anthropic::ApiError> for LanguageModelCompletionError {
|
impl From<anthropic::ApiError> for LanguageModelCompletionError {
|
||||||
fn from(error: anthropic::ApiError) -> Self {
|
fn from(error: anthropic::ApiError) -> Self {
|
||||||
use anthropic::ApiErrorCode::*;
|
use anthropic::ApiErrorCode::*;
|
||||||
let provider = ANTHROPIC_PROVIDER_NAME;
|
|
||||||
match error.code() {
|
match error.code() {
|
||||||
Some(code) => match code {
|
Some(code) => match code {
|
||||||
InvalidRequestError => Self::BadRequestFormat {
|
InvalidRequestError => Self::BadRequestFormat {
|
||||||
provider,
|
|
||||||
message: error.message,
|
message: error.message,
|
||||||
},
|
},
|
||||||
AuthenticationError => Self::AuthenticationError {
|
AuthenticationError => Self::AuthenticationError {
|
||||||
provider,
|
|
||||||
message: error.message,
|
message: error.message,
|
||||||
},
|
},
|
||||||
PermissionError => Self::PermissionError {
|
PermissionError => Self::PermissionError {
|
||||||
provider,
|
|
||||||
message: error.message,
|
message: error.message,
|
||||||
},
|
},
|
||||||
NotFoundError => Self::ApiEndpointNotFound { provider },
|
NotFoundError => Self::ApiEndpointNotFound,
|
||||||
RequestTooLarge => Self::PromptTooLarge {
|
RequestTooLarge => Self::PromptTooLarge {
|
||||||
tokens: parse_prompt_too_long(&error.message),
|
tokens: parse_prompt_too_long(&error.message),
|
||||||
},
|
},
|
||||||
RateLimitError => Self::RateLimitExceeded {
|
RateLimitError => Self::RateLimitExceeded { retry_after: None },
|
||||||
provider,
|
|
||||||
retry_after: None,
|
|
||||||
},
|
|
||||||
ApiError => Self::ApiInternalServerError {
|
ApiError => Self::ApiInternalServerError {
|
||||||
provider,
|
|
||||||
message: error.message,
|
message: error.message,
|
||||||
},
|
},
|
||||||
OverloadedError => Self::ServerOverloaded {
|
OverloadedError => Self::ServerOverloaded { retry_after: None },
|
||||||
provider,
|
|
||||||
retry_after: None,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
None => Self::Other(error.into()),
|
None => Self::Other(error.into()),
|
||||||
}
|
}
|
||||||
@@ -349,7 +398,7 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
|
|||||||
fn from(error: open_ai::RequestError) -> Self {
|
fn from(error: open_ai::RequestError) -> Self {
|
||||||
match error {
|
match error {
|
||||||
open_ai::RequestError::HttpResponseError {
|
open_ai::RequestError::HttpResponseError {
|
||||||
provider,
|
provider: _,
|
||||||
status_code,
|
status_code,
|
||||||
body,
|
body,
|
||||||
headers,
|
headers,
|
||||||
@@ -359,7 +408,7 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
|
|||||||
.and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
|
.and_then(|val| val.to_str().ok()?.parse::<u64>().ok())
|
||||||
.map(Duration::from_secs);
|
.map(Duration::from_secs);
|
||||||
|
|
||||||
Self::from_http_status(provider.into(), status_code, body, retry_after)
|
Self::from_http_status(status_code, body, retry_after)
|
||||||
}
|
}
|
||||||
open_ai::RequestError::Other(e) => Self::Other(e),
|
open_ai::RequestError::Other(e) => Self::Other(e),
|
||||||
}
|
}
|
||||||
@@ -368,23 +417,18 @@ impl From<open_ai::RequestError> for LanguageModelCompletionError {
|
|||||||
|
|
||||||
impl From<OpenRouterError> for LanguageModelCompletionError {
|
impl From<OpenRouterError> for LanguageModelCompletionError {
|
||||||
fn from(error: OpenRouterError) -> Self {
|
fn from(error: OpenRouterError) -> Self {
|
||||||
let provider = LanguageModelProviderName::new("OpenRouter");
|
|
||||||
match error {
|
match error {
|
||||||
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { provider, error },
|
OpenRouterError::SerializeRequest(error) => Self::SerializeRequest { error },
|
||||||
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { provider, error },
|
OpenRouterError::BuildRequestBody(error) => Self::BuildRequestBody { error },
|
||||||
OpenRouterError::HttpSend(error) => Self::HttpSend { provider, error },
|
OpenRouterError::HttpSend(error) => Self::HttpSend { error },
|
||||||
OpenRouterError::DeserializeResponse(error) => {
|
OpenRouterError::DeserializeResponse(error) => Self::DeserializeResponse { error },
|
||||||
Self::DeserializeResponse { provider, error }
|
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { error },
|
||||||
}
|
|
||||||
OpenRouterError::ReadResponse(error) => Self::ApiReadResponseError { provider, error },
|
|
||||||
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
OpenRouterError::RateLimit { retry_after } => Self::RateLimitExceeded {
|
||||||
provider,
|
|
||||||
retry_after: Some(retry_after),
|
retry_after: Some(retry_after),
|
||||||
},
|
},
|
||||||
OpenRouterError::ServerOverloaded { retry_after } => Self::ServerOverloaded {
|
OpenRouterError::ServerOverloaded { retry_after } => {
|
||||||
provider,
|
Self::ServerOverloaded { retry_after }
|
||||||
retry_after,
|
}
|
||||||
},
|
|
||||||
OpenRouterError::ApiError(api_error) => api_error.into(),
|
OpenRouterError::ApiError(api_error) => api_error.into(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -393,41 +437,28 @@ impl From<OpenRouterError> for LanguageModelCompletionError {
|
|||||||
impl From<open_router::ApiError> for LanguageModelCompletionError {
|
impl From<open_router::ApiError> for LanguageModelCompletionError {
|
||||||
fn from(error: open_router::ApiError) -> Self {
|
fn from(error: open_router::ApiError) -> Self {
|
||||||
use open_router::ApiErrorCode::*;
|
use open_router::ApiErrorCode::*;
|
||||||
let provider = LanguageModelProviderName::new("OpenRouter");
|
|
||||||
match error.code {
|
match error.code {
|
||||||
InvalidRequestError => Self::BadRequestFormat {
|
InvalidRequestError => Self::BadRequestFormat {
|
||||||
provider,
|
|
||||||
message: error.message,
|
message: error.message,
|
||||||
},
|
},
|
||||||
AuthenticationError => Self::AuthenticationError {
|
AuthenticationError => Self::AuthenticationError {
|
||||||
provider,
|
|
||||||
message: error.message,
|
message: error.message,
|
||||||
},
|
},
|
||||||
PaymentRequiredError => Self::AuthenticationError {
|
PaymentRequiredError => Self::AuthenticationError {
|
||||||
provider,
|
|
||||||
message: format!("Payment required: {}", error.message),
|
message: format!("Payment required: {}", error.message),
|
||||||
},
|
},
|
||||||
PermissionError => Self::PermissionError {
|
PermissionError => Self::PermissionError {
|
||||||
provider,
|
|
||||||
message: error.message,
|
message: error.message,
|
||||||
},
|
},
|
||||||
RequestTimedOut => Self::HttpResponseError {
|
RequestTimedOut => Self::HttpResponseError {
|
||||||
provider,
|
|
||||||
status_code: StatusCode::REQUEST_TIMEOUT,
|
status_code: StatusCode::REQUEST_TIMEOUT,
|
||||||
message: error.message,
|
message: error.message,
|
||||||
},
|
},
|
||||||
RateLimitError => Self::RateLimitExceeded {
|
RateLimitError => Self::RateLimitExceeded { retry_after: None },
|
||||||
provider,
|
|
||||||
retry_after: None,
|
|
||||||
},
|
|
||||||
ApiError => Self::ApiInternalServerError {
|
ApiError => Self::ApiInternalServerError {
|
||||||
provider,
|
|
||||||
message: error.message,
|
message: error.message,
|
||||||
},
|
},
|
||||||
OverloadedError => Self::ServerOverloaded {
|
OverloadedError => Self::ServerOverloaded { retry_after: None },
|
||||||
provider,
|
|
||||||
retry_after: None,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -633,7 +664,10 @@ pub trait LanguageModel: Send + Sync {
|
|||||||
let last_token_usage = last_token_usage.clone();
|
let last_token_usage = last_token_usage.clone();
|
||||||
async move {
|
async move {
|
||||||
match result {
|
match result {
|
||||||
Ok(LanguageModelCompletionEvent::StatusUpdate { .. }) => None,
|
Ok(LanguageModelCompletionEvent::Queued { .. }) => None,
|
||||||
|
Ok(LanguageModelCompletionEvent::Started) => None,
|
||||||
|
Ok(LanguageModelCompletionEvent::RequestUsage { .. }) => None,
|
||||||
|
Ok(LanguageModelCompletionEvent::ToolUseLimitReached) => None,
|
||||||
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
|
Ok(LanguageModelCompletionEvent::StartMessage { .. }) => None,
|
||||||
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
|
Ok(LanguageModelCompletionEvent::Text(text)) => Some(Ok(text)),
|
||||||
Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
|
Ok(LanguageModelCompletionEvent::Thinking { .. }) => None,
|
||||||
@@ -643,7 +677,7 @@ pub trait LanguageModel: Send + Sync {
|
|||||||
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
|
||||||
..
|
..
|
||||||
}) => None,
|
}) => None,
|
||||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
Ok(LanguageModelCompletionEvent::TokenUsage(token_usage)) => {
|
||||||
*last_token_usage.lock() = token_usage;
|
*last_token_usage.lock() = token_usage;
|
||||||
None
|
None
|
||||||
}
|
}
|
||||||
@@ -832,16 +866,13 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_from_cloud_failure_with_upstream_http_error() {
|
fn test_from_cloud_failure_with_upstream_http_error() {
|
||||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||||
String::from("anthropic").into(),
|
|
||||||
"upstream_http_error".to_string(),
|
"upstream_http_error".to_string(),
|
||||||
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
|
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
match error {
|
match error {
|
||||||
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
|
LanguageModelCompletionError::ServerOverloaded { .. } => {}
|
||||||
assert_eq!(provider.0, "anthropic");
|
|
||||||
}
|
|
||||||
_ => panic!(
|
_ => panic!(
|
||||||
"Expected ServerOverloaded error for 503 status, got: {:?}",
|
"Expected ServerOverloaded error for 503 status, got: {:?}",
|
||||||
error
|
error
|
||||||
@@ -849,15 +880,13 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||||
String::from("anthropic").into(),
|
|
||||||
"upstream_http_error".to_string(),
|
"upstream_http_error".to_string(),
|
||||||
r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
|
r#"{"code":"upstream_http_error","message":"Internal server error","upstream_status":500}"#.to_string(),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
match error {
|
match error {
|
||||||
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
|
LanguageModelCompletionError::ApiInternalServerError { message } => {
|
||||||
assert_eq!(provider.0, "anthropic");
|
|
||||||
assert_eq!(message, "Internal server error");
|
assert_eq!(message, "Internal server error");
|
||||||
}
|
}
|
||||||
_ => panic!(
|
_ => panic!(
|
||||||
@@ -870,16 +899,13 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_from_cloud_failure_with_standard_format() {
|
fn test_from_cloud_failure_with_standard_format() {
|
||||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||||
String::from("anthropic").into(),
|
|
||||||
"upstream_http_503".to_string(),
|
"upstream_http_503".to_string(),
|
||||||
"Service unavailable".to_string(),
|
"Service unavailable".to_string(),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
match error {
|
match error {
|
||||||
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
|
LanguageModelCompletionError::ServerOverloaded { .. } => {}
|
||||||
assert_eq!(provider.0, "anthropic");
|
|
||||||
}
|
|
||||||
_ => panic!("Expected ServerOverloaded error for upstream_http_503"),
|
_ => panic!("Expected ServerOverloaded error for upstream_http_503"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -887,16 +913,13 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_upstream_http_error_connection_timeout() {
|
fn test_upstream_http_error_connection_timeout() {
|
||||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||||
String::from("anthropic").into(),
|
|
||||||
"upstream_http_error".to_string(),
|
"upstream_http_error".to_string(),
|
||||||
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
|
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":503}"#.to_string(),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
match error {
|
match error {
|
||||||
LanguageModelCompletionError::ServerOverloaded { provider, .. } => {
|
LanguageModelCompletionError::ServerOverloaded { .. } => {}
|
||||||
assert_eq!(provider.0, "anthropic");
|
|
||||||
}
|
|
||||||
_ => panic!(
|
_ => panic!(
|
||||||
"Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
|
"Expected ServerOverloaded error for connection timeout with 503 status, got: {:?}",
|
||||||
error
|
error
|
||||||
@@ -904,15 +927,13 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let error = LanguageModelCompletionError::from_cloud_failure(
|
let error = LanguageModelCompletionError::from_cloud_failure(
|
||||||
String::from("anthropic").into(),
|
|
||||||
"upstream_http_error".to_string(),
|
"upstream_http_error".to_string(),
|
||||||
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
|
r#"{"code":"upstream_http_error","message":"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout","upstream_status":500}"#.to_string(),
|
||||||
None,
|
None,
|
||||||
);
|
);
|
||||||
|
|
||||||
match error {
|
match error {
|
||||||
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
|
LanguageModelCompletionError::ApiInternalServerError { message } => {
|
||||||
assert_eq!(provider.0, "anthropic");
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
message,
|
message,
|
||||||
"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
|
"Received an error from the Anthropic API: upstream connect error or disconnect/reset before headers. reset reason: connection timeout"
|
||||||
|
|||||||
@@ -320,9 +320,7 @@ impl AnthropicModel {
|
|||||||
|
|
||||||
async move {
|
async move {
|
||||||
let Some(api_key) = api_key else {
|
let Some(api_key) = api_key else {
|
||||||
return Err(LanguageModelCompletionError::NoApiKey {
|
return Err(LanguageModelCompletionError::NoApiKey);
|
||||||
provider: PROVIDER_NAME,
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
let request = anthropic::stream_completion(
|
let request = anthropic::stream_completion(
|
||||||
http_client.as_ref(),
|
http_client.as_ref(),
|
||||||
@@ -756,7 +754,7 @@ impl AnthropicEventMapper {
|
|||||||
Event::MessageStart { message } => {
|
Event::MessageStart { message } => {
|
||||||
update_usage(&mut self.usage, &message.usage);
|
update_usage(&mut self.usage, &message.usage);
|
||||||
vec![
|
vec![
|
||||||
Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
|
Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
|
||||||
&self.usage,
|
&self.usage,
|
||||||
))),
|
))),
|
||||||
Ok(LanguageModelCompletionEvent::StartMessage {
|
Ok(LanguageModelCompletionEvent::StartMessage {
|
||||||
@@ -778,9 +776,9 @@ impl AnthropicEventMapper {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
|
vec![Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
|
||||||
convert_usage(&self.usage),
|
&self.usage,
|
||||||
))]
|
)))]
|
||||||
}
|
}
|
||||||
Event::MessageStop => {
|
Event::MessageStop => {
|
||||||
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
|
vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
|
||||||
|
|||||||
@@ -975,7 +975,7 @@ pub fn map_to_language_model_completion_events(
|
|||||||
))
|
))
|
||||||
}),
|
}),
|
||||||
ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
|
ConverseStreamOutput::Metadata(cb_meta) => cb_meta.usage.map(|metadata| {
|
||||||
Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||||
input_tokens: metadata.input_tokens as u64,
|
input_tokens: metadata.input_tokens as u64,
|
||||||
output_tokens: metadata.output_tokens as u64,
|
output_tokens: metadata.output_tokens as u64,
|
||||||
cache_creation_input_tokens: metadata
|
cache_creation_input_tokens: metadata
|
||||||
|
|||||||
@@ -541,7 +541,6 @@ impl From<ApiError> for LanguageModelCompletionError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
return LanguageModelCompletionError::from_http_status(
|
return LanguageModelCompletionError::from_http_status(
|
||||||
PROVIDER_NAME,
|
|
||||||
error.status,
|
error.status,
|
||||||
cloud_error.message,
|
cloud_error.message,
|
||||||
None,
|
None,
|
||||||
@@ -549,12 +548,7 @@ impl From<ApiError> for LanguageModelCompletionError {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let retry_after = None;
|
let retry_after = None;
|
||||||
LanguageModelCompletionError::from_http_status(
|
LanguageModelCompletionError::from_http_status(error.status, error.body, retry_after)
|
||||||
PROVIDER_NAME,
|
|
||||||
error.status,
|
|
||||||
error.body,
|
|
||||||
retry_after,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -961,7 +955,7 @@ where
|
|||||||
vec![Err(LanguageModelCompletionError::from(error))]
|
vec![Err(LanguageModelCompletionError::from(error))]
|
||||||
}
|
}
|
||||||
Ok(CompletionEvent::Status(event)) => {
|
Ok(CompletionEvent::Status(event)) => {
|
||||||
vec![Ok(LanguageModelCompletionEvent::StatusUpdate(event))]
|
vec![LanguageModelCompletionEvent::from_completion_request_status(event)]
|
||||||
}
|
}
|
||||||
Ok(CompletionEvent::Event(event)) => map_callback(event),
|
Ok(CompletionEvent::Event(event)) => map_callback(event),
|
||||||
})
|
})
|
||||||
@@ -1313,8 +1307,7 @@ mod tests {
|
|||||||
let completion_error: LanguageModelCompletionError = api_error.into();
|
let completion_error: LanguageModelCompletionError = api_error.into();
|
||||||
|
|
||||||
match completion_error {
|
match completion_error {
|
||||||
LanguageModelCompletionError::ApiInternalServerError { provider, message } => {
|
LanguageModelCompletionError::ApiInternalServerError { message } => {
|
||||||
assert_eq!(provider, PROVIDER_NAME);
|
|
||||||
assert_eq!(message, "Regular internal server error");
|
assert_eq!(message, "Regular internal server error");
|
||||||
}
|
}
|
||||||
_ => panic!(
|
_ => panic!(
|
||||||
@@ -1362,9 +1355,7 @@ mod tests {
|
|||||||
let completion_error: LanguageModelCompletionError = api_error.into();
|
let completion_error: LanguageModelCompletionError = api_error.into();
|
||||||
|
|
||||||
match completion_error {
|
match completion_error {
|
||||||
LanguageModelCompletionError::ApiInternalServerError { provider, .. } => {
|
LanguageModelCompletionError::ApiInternalServerError { .. } => {}
|
||||||
assert_eq!(provider, PROVIDER_NAME);
|
|
||||||
}
|
|
||||||
_ => panic!(
|
_ => panic!(
|
||||||
"Expected ApiInternalServerError for invalid JSON, got: {:?}",
|
"Expected ApiInternalServerError for invalid JSON, got: {:?}",
|
||||||
completion_error
|
completion_error
|
||||||
|
|||||||
@@ -422,14 +422,12 @@ pub fn map_to_language_model_completion_events(
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(usage) = event.usage {
|
if let Some(usage) = event.usage {
|
||||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
|
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||||
TokenUsage {
|
input_tokens: usage.prompt_tokens,
|
||||||
input_tokens: usage.prompt_tokens,
|
output_tokens: usage.completion_tokens,
|
||||||
output_tokens: usage.completion_tokens,
|
cache_creation_input_tokens: 0,
|
||||||
cache_creation_input_tokens: 0,
|
cache_read_input_tokens: 0,
|
||||||
cache_read_input_tokens: 0,
|
})));
|
||||||
},
|
|
||||||
)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
match choice.finish_reason.as_deref() {
|
match choice.finish_reason.as_deref() {
|
||||||
@@ -610,7 +608,7 @@ impl CopilotResponsesEventMapper {
|
|||||||
copilot::copilot_responses::StreamEvent::Completed { response } => {
|
copilot::copilot_responses::StreamEvent::Completed { response } => {
|
||||||
let mut events = Vec::new();
|
let mut events = Vec::new();
|
||||||
if let Some(usage) = response.usage {
|
if let Some(usage) = response.usage {
|
||||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||||
input_tokens: usage.input_tokens.unwrap_or(0),
|
input_tokens: usage.input_tokens.unwrap_or(0),
|
||||||
output_tokens: usage.output_tokens.unwrap_or(0),
|
output_tokens: usage.output_tokens.unwrap_or(0),
|
||||||
cache_creation_input_tokens: 0,
|
cache_creation_input_tokens: 0,
|
||||||
@@ -643,7 +641,7 @@ impl CopilotResponsesEventMapper {
|
|||||||
|
|
||||||
let mut events = Vec::new();
|
let mut events = Vec::new();
|
||||||
if let Some(usage) = response.usage {
|
if let Some(usage) = response.usage {
|
||||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||||
input_tokens: usage.input_tokens.unwrap_or(0),
|
input_tokens: usage.input_tokens.unwrap_or(0),
|
||||||
output_tokens: usage.output_tokens.unwrap_or(0),
|
output_tokens: usage.output_tokens.unwrap_or(0),
|
||||||
cache_creation_input_tokens: 0,
|
cache_creation_input_tokens: 0,
|
||||||
@@ -655,7 +653,6 @@ impl CopilotResponsesEventMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
copilot::copilot_responses::StreamEvent::Failed { response } => {
|
copilot::copilot_responses::StreamEvent::Failed { response } => {
|
||||||
let provider = PROVIDER_NAME;
|
|
||||||
let (status_code, message) = match response.error {
|
let (status_code, message) = match response.error {
|
||||||
Some(error) => {
|
Some(error) => {
|
||||||
let status_code = StatusCode::from_str(&error.code)
|
let status_code = StatusCode::from_str(&error.code)
|
||||||
@@ -668,7 +665,6 @@ impl CopilotResponsesEventMapper {
|
|||||||
),
|
),
|
||||||
};
|
};
|
||||||
vec![Err(LanguageModelCompletionError::HttpResponseError {
|
vec![Err(LanguageModelCompletionError::HttpResponseError {
|
||||||
provider,
|
|
||||||
status_code,
|
status_code,
|
||||||
message,
|
message,
|
||||||
})]
|
})]
|
||||||
@@ -1099,7 +1095,7 @@ mod tests {
|
|||||||
));
|
));
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
mapped[2],
|
mapped[2],
|
||||||
LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||||
input_tokens: 5,
|
input_tokens: 5,
|
||||||
output_tokens: 3,
|
output_tokens: 3,
|
||||||
..
|
..
|
||||||
@@ -1207,7 +1203,7 @@ mod tests {
|
|||||||
let mapped = map_events(events);
|
let mapped = map_events(events);
|
||||||
assert!(matches!(
|
assert!(matches!(
|
||||||
mapped[0],
|
mapped[0],
|
||||||
LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||||
input_tokens: 10,
|
input_tokens: 10,
|
||||||
output_tokens: 0,
|
output_tokens: 0,
|
||||||
..
|
..
|
||||||
|
|||||||
@@ -224,9 +224,7 @@ impl DeepSeekLanguageModel {
|
|||||||
|
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let Some(api_key) = api_key else {
|
let Some(api_key) = api_key else {
|
||||||
return Err(LanguageModelCompletionError::NoApiKey {
|
return Err(LanguageModelCompletionError::NoApiKey);
|
||||||
provider: PROVIDER_NAME,
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
let request =
|
let request =
|
||||||
deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
deepseek::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||||
@@ -479,7 +477,7 @@ impl DeepSeekEventMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(usage) = event.usage {
|
if let Some(usage) = event.usage {
|
||||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||||
input_tokens: usage.prompt_tokens,
|
input_tokens: usage.prompt_tokens,
|
||||||
output_tokens: usage.completion_tokens,
|
output_tokens: usage.completion_tokens,
|
||||||
cache_creation_input_tokens: 0,
|
cache_creation_input_tokens: 0,
|
||||||
|
|||||||
@@ -350,10 +350,7 @@ impl LanguageModel for GoogleLanguageModel {
|
|||||||
|
|
||||||
async move {
|
async move {
|
||||||
let Some(api_key) = api_key else {
|
let Some(api_key) = api_key else {
|
||||||
return Err(LanguageModelCompletionError::NoApiKey {
|
return Err(LanguageModelCompletionError::NoApiKey.into());
|
||||||
provider: PROVIDER_NAME,
|
|
||||||
}
|
|
||||||
.into());
|
|
||||||
};
|
};
|
||||||
let response = google_ai::count_tokens(
|
let response = google_ai::count_tokens(
|
||||||
http_client.as_ref(),
|
http_client.as_ref(),
|
||||||
@@ -608,9 +605,9 @@ impl GoogleEventMapper {
|
|||||||
let mut wants_to_use_tool = false;
|
let mut wants_to_use_tool = false;
|
||||||
if let Some(usage_metadata) = event.usage_metadata {
|
if let Some(usage_metadata) = event.usage_metadata {
|
||||||
update_usage(&mut self.usage, &usage_metadata);
|
update_usage(&mut self.usage, &usage_metadata);
|
||||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(
|
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(convert_usage(
|
||||||
convert_usage(&self.usage),
|
&self.usage,
|
||||||
)))
|
))))
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(prompt_feedback) = event.prompt_feedback
|
if let Some(prompt_feedback) = event.prompt_feedback
|
||||||
|
|||||||
@@ -547,7 +547,7 @@ impl LmStudioEventMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(usage) = event.usage {
|
if let Some(usage) = event.usage {
|
||||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||||
input_tokens: usage.prompt_tokens,
|
input_tokens: usage.prompt_tokens,
|
||||||
output_tokens: usage.completion_tokens,
|
output_tokens: usage.completion_tokens,
|
||||||
cache_creation_input_tokens: 0,
|
cache_creation_input_tokens: 0,
|
||||||
|
|||||||
@@ -291,9 +291,7 @@ impl MistralLanguageModel {
|
|||||||
|
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let Some(api_key) = api_key else {
|
let Some(api_key) = api_key else {
|
||||||
return Err(LanguageModelCompletionError::NoApiKey {
|
return Err(LanguageModelCompletionError::NoApiKey);
|
||||||
provider: PROVIDER_NAME,
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
let request =
|
let request =
|
||||||
mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
mistral::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||||
@@ -672,7 +670,7 @@ impl MistralEventMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(usage) = event.usage {
|
if let Some(usage) = event.usage {
|
||||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||||
input_tokens: usage.prompt_tokens,
|
input_tokens: usage.prompt_tokens,
|
||||||
output_tokens: usage.completion_tokens,
|
output_tokens: usage.completion_tokens,
|
||||||
cache_creation_input_tokens: 0,
|
cache_creation_input_tokens: 0,
|
||||||
|
|||||||
@@ -603,7 +603,7 @@ fn map_to_language_model_completion_events(
|
|||||||
};
|
};
|
||||||
|
|
||||||
if delta.done {
|
if delta.done {
|
||||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||||
input_tokens: delta.prompt_eval_count.unwrap_or(0),
|
input_tokens: delta.prompt_eval_count.unwrap_or(0),
|
||||||
output_tokens: delta.eval_count.unwrap_or(0),
|
output_tokens: delta.eval_count.unwrap_or(0),
|
||||||
cache_creation_input_tokens: 0,
|
cache_creation_input_tokens: 0,
|
||||||
|
|||||||
@@ -228,7 +228,7 @@ impl OpenAiLanguageModel {
|
|||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let provider = PROVIDER_NAME;
|
let provider = PROVIDER_NAME;
|
||||||
let Some(api_key) = api_key else {
|
let Some(api_key) = api_key else {
|
||||||
return Err(LanguageModelCompletionError::NoApiKey { provider });
|
return Err(LanguageModelCompletionError::NoApiKey);
|
||||||
};
|
};
|
||||||
let request = stream_completion(
|
let request = stream_completion(
|
||||||
http_client.as_ref(),
|
http_client.as_ref(),
|
||||||
@@ -534,7 +534,7 @@ impl OpenAiEventMapper {
|
|||||||
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
|
||||||
let mut events = Vec::new();
|
let mut events = Vec::new();
|
||||||
if let Some(usage) = event.usage {
|
if let Some(usage) = event.usage {
|
||||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||||
input_tokens: usage.prompt_tokens,
|
input_tokens: usage.prompt_tokens,
|
||||||
output_tokens: usage.completion_tokens,
|
output_tokens: usage.completion_tokens,
|
||||||
cache_creation_input_tokens: 0,
|
cache_creation_input_tokens: 0,
|
||||||
|
|||||||
@@ -227,7 +227,7 @@ impl OpenAiCompatibleLanguageModel {
|
|||||||
let provider = self.provider_name.clone();
|
let provider = self.provider_name.clone();
|
||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let Some(api_key) = api_key else {
|
let Some(api_key) = api_key else {
|
||||||
return Err(LanguageModelCompletionError::NoApiKey { provider });
|
return Err(LanguageModelCompletionError::NoApiKey);
|
||||||
};
|
};
|
||||||
let request = stream_completion(
|
let request = stream_completion(
|
||||||
http_client.as_ref(),
|
http_client.as_ref(),
|
||||||
|
|||||||
@@ -84,9 +84,7 @@ impl State {
|
|||||||
let http_client = self.http_client.clone();
|
let http_client = self.http_client.clone();
|
||||||
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
|
let api_url = OpenRouterLanguageModelProvider::api_url(cx);
|
||||||
let Some(api_key) = self.api_key_state.key(&api_url) else {
|
let Some(api_key) = self.api_key_state.key(&api_url) else {
|
||||||
return Task::ready(Err(LanguageModelCompletionError::NoApiKey {
|
return Task::ready(Err(LanguageModelCompletionError::NoApiKey));
|
||||||
provider: PROVIDER_NAME,
|
|
||||||
}));
|
|
||||||
};
|
};
|
||||||
cx.spawn(async move |this, cx| {
|
cx.spawn(async move |this, cx| {
|
||||||
let models = list_models(http_client.as_ref(), &api_url, &api_key)
|
let models = list_models(http_client.as_ref(), &api_url, &api_key)
|
||||||
@@ -288,9 +286,7 @@ impl OpenRouterLanguageModel {
|
|||||||
|
|
||||||
async move {
|
async move {
|
||||||
let Some(api_key) = api_key else {
|
let Some(api_key) = api_key else {
|
||||||
return Err(LanguageModelCompletionError::NoApiKey {
|
return Err(LanguageModelCompletionError::NoApiKey);
|
||||||
provider: PROVIDER_NAME,
|
|
||||||
});
|
|
||||||
};
|
};
|
||||||
let request =
|
let request =
|
||||||
open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
open_router::stream_completion(http_client.as_ref(), &api_url, &api_key, request);
|
||||||
@@ -613,7 +609,7 @@ impl OpenRouterEventMapper {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(usage) = event.usage {
|
if let Some(usage) = event.usage {
|
||||||
events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
|
events.push(Ok(LanguageModelCompletionEvent::TokenUsage(TokenUsage {
|
||||||
input_tokens: usage.prompt_tokens,
|
input_tokens: usage.prompt_tokens,
|
||||||
output_tokens: usage.completion_tokens,
|
output_tokens: usage.completion_tokens,
|
||||||
cache_creation_input_tokens: 0,
|
cache_creation_input_tokens: 0,
|
||||||
|
|||||||
@@ -222,7 +222,7 @@ impl VercelLanguageModel {
|
|||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let provider = PROVIDER_NAME;
|
let provider = PROVIDER_NAME;
|
||||||
let Some(api_key) = api_key else {
|
let Some(api_key) = api_key else {
|
||||||
return Err(LanguageModelCompletionError::NoApiKey { provider });
|
return Err(LanguageModelCompletionError::NoApiKey);
|
||||||
};
|
};
|
||||||
let request = open_ai::stream_completion(
|
let request = open_ai::stream_completion(
|
||||||
http_client.as_ref(),
|
http_client.as_ref(),
|
||||||
|
|||||||
@@ -230,7 +230,7 @@ impl XAiLanguageModel {
|
|||||||
let future = self.request_limiter.stream(async move {
|
let future = self.request_limiter.stream(async move {
|
||||||
let provider = PROVIDER_NAME;
|
let provider = PROVIDER_NAME;
|
||||||
let Some(api_key) = api_key else {
|
let Some(api_key) = api_key else {
|
||||||
return Err(LanguageModelCompletionError::NoApiKey { provider });
|
return Err(LanguageModelCompletionError::NoApiKey);
|
||||||
};
|
};
|
||||||
let request = open_ai::stream_completion(
|
let request = open_ai::stream_completion(
|
||||||
http_client.as_ref(),
|
http_client.as_ref(),
|
||||||
|
|||||||
Reference in New Issue
Block a user