diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index b5dd870710..9baeebf620 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -2099,12 +2099,41 @@ impl ZedAgent { // cx.emit(ThreadEvent::ShowError( // ThreadError::ModelRequestLimitReached { plan: error.plan }, // )); + } else if let Some(completion_error) = + error.downcast_ref::() + { + match completion_error { + LanguageModelCompletionError::RateLimitExceeded { + retry_after, + } => { + if !retry(Some(*retry_after)) { + break; + } + // todo! + } + LanguageModelCompletionError::Overloaded => { + if !retry(None) { + break; + } + // todo! + } + LanguageModelCompletionError::ApiInternalServerError => { + if !retry(None) { + break; + } + // todo! + } + _ => { + // todo!(emit_generic_error(error, cx);) + break; + } + } } else if let Some(known_error) = error.downcast_ref::() { match known_error { LanguageModelKnownError::ContextWindowLimitExceeded { - tokens, + tokens: _, } => { // todo! // this.exceeded_window_error = @@ -2147,7 +2176,6 @@ impl ZedAgent { // ); if !retry(None) { - // todo! show err break; } } @@ -2189,7 +2217,7 @@ impl ZedAgent { done })?; - if done { + if done && retry_state.is_none() { break; } else { intent = CompletionIntent::ToolResults; @@ -4605,7 +4633,7 @@ mod tests { const TEST_RATE_LIMIT_RETRY_SECS: u64 = 30; use agent_settings::{AgentProfileId, AgentSettings, LanguageModelParameters}; use assistant_tool::{ToolRegistry, ToolSource}; - use futures::StreamExt; + use futures::future::BoxFuture; use futures::stream::BoxStream; use gpui::TestAppContext; @@ -5599,21 +5627,224 @@ fn main() {{ InternalServerError, } - struct ErrorInjector { + #[gpui::test] + async fn test_retry_single_attempt(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_workspace, _thread_store, agent, thread, _context_store, _model) = + setup_test_environment(cx, project.clone()).await; + + // Create a model that fails once then succeeds + let attempt_count = Arc::new(Mutex::new(0)); + let attempt_count_clone = attempt_count.clone(); + + let retry_model = Arc::new(RetryTestModel { + inner: Arc::new(FakeLanguageModel::default()), + attempt_count: attempt_count_clone, + fail_attempts: 1, + error_type: TestError::Overloaded, + }); + + agent.update(cx, |agent, cx| { + agent.send_message("Hello", retry_model.clone(), None, cx); + }); + + // First attempt should fail + cx.run_until_parked(); + assert_eq!(*attempt_count.lock(), 1); + + // Advance clock for retry delay + cx.executor().advance_clock(BASE_RETRY_DELAY); + cx.run_until_parked(); + + // Second attempt should succeed + assert_eq!(*attempt_count.lock(), 2); + + // Simulate successful response + let fake_model = retry_model.as_fake(); + fake_model.stream_last_completion_response("Assistant response"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Verify the message was sent successfully + thread.read_with(cx, |thread, _cx| { + assert_eq!(thread.messages.len(), 2); + assert_eq!(thread.messages[0].role, Role::User); + assert_eq!(thread.messages[1].role, Role::Assistant); + assert_eq!( + &thread.messages[1].segments[0], + &MessageSegment::Text("Assistant response".to_string()) + ); + }); + } + + #[gpui::test] + async fn test_retry_max_attempts_exceeded(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_workspace, _thread_store, agent, thread, _context_store, _model) = + setup_test_environment(cx, project.clone()).await; + + // Create a model that always fails + let attempt_count = Arc::new(Mutex::new(0)); + let attempt_count_clone = attempt_count.clone(); + + let retry_model = Arc::new(RetryTestModel { + inner: Arc::new(FakeLanguageModel::default()), + attempt_count: attempt_count_clone, + fail_attempts: (MAX_RETRY_ATTEMPTS + 1) as usize, + error_type: TestError::InternalServerError, + }); + + agent.update(cx, |agent, cx| { + agent.send_message("Hello", retry_model.clone(), None, cx); + }); + + // Run through all retry attempts + for attempt in 1..=MAX_RETRY_ATTEMPTS { + cx.run_until_parked(); + assert_eq!(*attempt_count.lock(), attempt as usize); + + if attempt < MAX_RETRY_ATTEMPTS { + // Advance clock for exponential backoff + let delay = BASE_RETRY_DELAY * 2_u32.pow((attempt - 1) as u32); + cx.executor().advance_clock(delay); + cx.run_until_parked(); + } + } + + cx.run_until_parked(); + + // Should not retry beyond MAX_RETRY_ATTEMPTS + assert_eq!(*attempt_count.lock(), MAX_RETRY_ATTEMPTS as usize); + + // Verify no messages were added (failure case) + thread.read_with(cx, |thread, _cx| { + assert_eq!(thread.messages.len(), 1); // Only user message + assert_eq!(thread.messages[0].role, Role::User); + }); + } + + #[gpui::test] + async fn test_retry_exponential_backoff(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_workspace, _thread_store, agent, thread, _context_store, _model) = + setup_test_environment(cx, project.clone()).await; + + // Create a model that fails multiple times + let attempt_count = Arc::new(Mutex::new(0)); + let attempt_count_clone = attempt_count.clone(); + + let retry_model = Arc::new(RetryTestModel { + inner: Arc::new(FakeLanguageModel::default()), + attempt_count: attempt_count_clone, + fail_attempts: 3, + error_type: TestError::Overloaded, + }); + + agent.update(cx, |agent, cx| { + agent.send_message("Hello", retry_model.clone(), None, cx); + }); + + // First attempt + cx.run_until_parked(); + assert_eq!(*attempt_count.lock(), 1); + + // Second attempt after 5 seconds + cx.executor().advance_clock(BASE_RETRY_DELAY); + cx.run_until_parked(); + assert_eq!(*attempt_count.lock(), 2); + + // Third attempt after 10 seconds (5 * 2^1) + cx.executor().advance_clock(BASE_RETRY_DELAY * 2); + cx.run_until_parked(); + assert_eq!(*attempt_count.lock(), 3); + + // Fourth attempt after 20 seconds (5 * 2^2) + cx.executor().advance_clock(BASE_RETRY_DELAY * 4); + cx.run_until_parked(); + assert_eq!(*attempt_count.lock(), 4); + + // Simulate successful response + let fake_model = retry_model.as_fake(); + fake_model.stream_last_completion_response("Assistant response"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Verify the message was sent successfully + thread.read_with(cx, |thread, _cx| { + assert_eq!(thread.messages.len(), 2); + assert_eq!(thread.messages[0].role, Role::User); + assert_eq!(thread.messages[1].role, Role::Assistant); + }); + } + + #[gpui::test] + async fn test_retry_rate_limit_with_custom_delay(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + let (_workspace, _thread_store, agent, thread, _context_store, _model) = + setup_test_environment(cx, project.clone()).await; + + // Create a model that returns rate limit error with custom delay + let custom_delay = Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS); + let attempt_count = Arc::new(Mutex::new(0)); + let attempt_count_clone = attempt_count.clone(); + + let retry_model = Arc::new(RateLimitTestModel { + inner: Arc::new(FakeLanguageModel::default()), + attempt_count: attempt_count_clone, + fail_attempts: 1, + retry_after: custom_delay, + }); + + agent.update(cx, |agent, cx| { + agent.send_message("Hello", retry_model.clone(), None, cx); + }); + + // First attempt should fail with rate limit + cx.run_until_parked(); + assert_eq!(*attempt_count.lock(), 1); + + // Advance clock by less than custom delay - should not retry yet + cx.executor().advance_clock(custom_delay / 2); + cx.run_until_parked(); + assert_eq!(*attempt_count.lock(), 1); + + // Advance clock to complete custom delay + cx.executor().advance_clock(custom_delay / 2); + cx.run_until_parked(); + + // Second attempt should succeed + assert_eq!(*attempt_count.lock(), 2); + + // Simulate successful response + let fake_model = retry_model.as_fake(); + fake_model.stream_last_completion_response("Assistant response"); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Verify success + thread.read_with(cx, |thread, _cx| { + assert_eq!(thread.messages.len(), 2); + assert_eq!(thread.messages[1].role, Role::Assistant); + }); + } + + // Test model that fails a specific number of times + struct RetryTestModel { inner: Arc, + attempt_count: Arc>, + fail_attempts: usize, error_type: TestError, } - impl ErrorInjector { - fn new(error_type: TestError) -> Self { - Self { - inner: Arc::new(FakeLanguageModel::default()), - error_type, - } - } - } - - impl LanguageModel for ErrorInjector { + impl LanguageModel for RetryTestModel { fn id(&self) -> LanguageModelId { self.inner.id() } @@ -5660,8 +5891,8 @@ fn main() {{ fn stream_completion( &self, - _request: LanguageModelRequest, - _cx: &AsyncApp, + request: LanguageModelRequest, + cx: &AsyncApp, ) -> BoxFuture< 'static, Result< @@ -5672,17 +5903,22 @@ fn main() {{ LanguageModelCompletionError, >, > { - let error = match self.error_type { - TestError::Overloaded => LanguageModelCompletionError::Overloaded, - TestError::InternalServerError => { - LanguageModelCompletionError::ApiInternalServerError - } - }; - async move { - let stream = futures::stream::once(async move { Err(error) }); - Ok(stream.boxed()) + let mut count = self.attempt_count.lock(); + *count += 1; + let current_attempt = *count; + drop(count); + + if current_attempt <= self.fail_attempts { + let error = match self.error_type { + TestError::Overloaded => LanguageModelCompletionError::Overloaded, + TestError::InternalServerError => { + LanguageModelCompletionError::ApiInternalServerError + } + }; + async move { Err(error) }.boxed() + } else { + self.inner.stream_completion(request, cx) } - .boxed() } fn as_fake(&self) -> &FakeLanguageModel { @@ -5690,1071 +5926,91 @@ fn main() {{ } } - #[gpui::test] - async fn test_retry_on_overloaded_error(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, agent, thread, _, _base_model) = - setup_test_environment(cx, project.clone()).await; - - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); - - // Insert a user message - agent.update(cx, |agent, cx| { - agent.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Start completion - agent.update(cx, |agent, cx| { - agent.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - cx.run_until_parked(); - - agent.read_with(cx, |agent, _| { - assert!(agent.retry_state.is_some(), "Should have retry state"); - let retry_state = agent.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have default max attempts" - ); - }); - - // Check that a retry message was added - thread.read_with(cx, |thread, _cx| { - let mut messages = thread.messages(); - assert!( - messages.any(|msg| { - msg.role == Role::System - && msg.ui_only - && msg.segments.iter().any(|seg| { - if let MessageSegment::Text(text) = seg { - text.contains("overloaded") - && text - .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) - } else { - false - } - }) - }), - "Should have added a system retry message" - ); - }); - - let retry_count = thread.update(cx, |thread, _cx| { - thread - .messages() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - - assert_eq!(retry_count, 1, "Should have one retry message"); + // Test model for rate limit errors + struct RateLimitTestModel { + inner: Arc, + attempt_count: Arc>, + fail_attempts: usize, + retry_after: Duration, } - #[gpui::test] - async fn test_retry_on_internal_server_error(cx: &mut TestAppContext) { - init_test_settings(cx); + impl LanguageModel for RateLimitTestModel { + fn id(&self) -> LanguageModelId { + self.inner.id() + } - let project = create_test_project(cx, json!({})).await; - let (_, _, agent, thread, _, _base_model) = - setup_test_environment(cx, project.clone()).await; + fn name(&self) -> LanguageModelName { + self.inner.name() + } - // Create model that returns internal server error - let model = Arc::new(ErrorInjector::new(TestError::InternalServerError)); + fn provider_id(&self) -> LanguageModelProviderId { + self.inner.provider_id() + } - // Insert a user message - agent.update(cx, |agent, cx| { - agent.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); + fn provider_name(&self) -> LanguageModelProviderName { + self.inner.provider_name() + } - // Start completion - agent.update(cx, |agent, cx| { - agent.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); + fn supports_tools(&self) -> bool { + self.inner.supports_tools() + } - cx.run_until_parked(); + fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { + self.inner.supports_tool_choice(choice) + } - // Check retry state on thread - agent.read_with(cx, |agent, _| { - assert!(agent.retry_state.is_some(), "Should have retry state"); - let retry_state = agent.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have correct max attempts" - ); - }); + fn supports_images(&self) -> bool { + self.inner.supports_images() + } - // Check that a retry message was added with provider name - thread.read_with(cx, |thread, _cx| { - let mut messages = thread.messages(); - assert!( - messages.any(|msg| { - msg.role == Role::System - && msg.ui_only - && msg.segments.iter().any(|seg| { - if let MessageSegment::Text(text) = seg { - text.contains("internal") - && text.contains("Fake") - && text - .contains(&format!("attempt 1 of {}", MAX_RETRY_ATTEMPTS)) - } else { - false - } - }) - }), - "Should have added a system retry message with provider name" - ); - }); + fn telemetry_id(&self) -> String { + self.inner.telemetry_id() + } - // Count retry messages - let retry_count = thread.update(cx, |thread, _cx| { - thread - .messages() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); + fn max_token_count(&self) -> u64 { + self.inner.max_token_count() + } - assert_eq!(retry_count, 1, "Should have one retry message"); - } + fn count_tokens( + &self, + request: LanguageModelRequest, + cx: &App, + ) -> BoxFuture<'static, Result> { + self.inner.count_tokens(request, cx) + } - #[gpui::test] - async fn test_exponential_backoff_on_retries(cx: &mut TestAppContext) { - init_test_settings(cx); + fn stream_completion( + &self, + request: LanguageModelRequest, + cx: &AsyncApp, + ) -> BoxFuture< + 'static, + Result< + BoxStream< + 'static, + Result, + >, + LanguageModelCompletionError, + >, + > { + let mut count = self.attempt_count.lock(); + *count += 1; + let current_attempt = *count; + drop(count); - let project = create_test_project(cx, json!({})).await; - let (_, _, agent, thread, _, _base_model) = - setup_test_environment(cx, project.clone()).await; - - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); - - // Insert a user message - agent.update(cx, |agent, cx| { - agent.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Track retry events and completion count - // Track completion events - let completion_count = Arc::new(Mutex::new(0)); - let completion_count_clone = completion_count.clone(); - - let _subscription = agent.update(cx, |_, cx| { - cx.subscribe(&agent, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::NewRequest = event { - *completion_count_clone.lock() += 1; - } - }) - }); - - // First attempt - agent.update(cx, |agent, cx| { - agent.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - cx.run_until_parked(); - - // Should have scheduled first retry - count retry messages - let retry_count = thread.update(cx, |thread, _cx| { - thread - .messages() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!(retry_count, 1, "Should have scheduled first retry"); - - // Check retry state - agent.read_with(cx, |agent, _| { - assert!(agent.retry_state.is_some(), "Should have retry state"); - let retry_state = agent.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 1, "Should be first retry attempt"); - }); - - // Advance clock for first retry - cx.executor().advance_clock(BASE_RETRY_DELAY); - cx.run_until_parked(); - - // Should have scheduled second retry - count retry messages - let retry_count = thread.update(cx, |thread, _cx| { - thread - .messages() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!(retry_count, 2, "Should have scheduled second retry"); - - // Check retry state updated - agent.read_with(cx, |agent, _| { - assert!(agent.retry_state.is_some(), "Should have retry state"); - let retry_state = agent.retry_state.as_ref().unwrap(); - assert_eq!(retry_state.attempt, 2, "Should be second retry attempt"); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have correct max attempts" - ); - }); - - // Advance clock for second retry (exponential backoff) - cx.executor().advance_clock(BASE_RETRY_DELAY * 2); - cx.run_until_parked(); - - // Should have scheduled third retry - // Count all retry messages now - let retry_count = thread.update(cx, |thread, _cx| { - thread - .messages() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should have scheduled third retry" - ); - - // Check retry state updated - agent.read_with(cx, |agent, _| { - assert!(agent.retry_state.is_some(), "Should have retry state"); - let retry_state = agent.retry_state.as_ref().unwrap(); - assert_eq!( - retry_state.attempt, MAX_RETRY_ATTEMPTS, - "Should be at max retry attempt" - ); - assert_eq!( - retry_state.max_attempts, MAX_RETRY_ATTEMPTS, - "Should have correct max attempts" - ); - }); - - // Advance clock for third retry (exponential backoff) - cx.executor().advance_clock(BASE_RETRY_DELAY * 4); - cx.run_until_parked(); - - // No more retries should be scheduled after clock was advanced. - let retry_count = thread.update(cx, |thread, _cx| { - thread - .messages() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should not exceed max retries" - ); - - // Final completion count should be initial + max retries - assert_eq!( - *completion_count.lock(), - (MAX_RETRY_ATTEMPTS + 1) as usize, - "Should have made initial + max retry attempts" - ); - } - - #[gpui::test] - async fn test_max_retries_exceeded(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, agent, thread, _, _base_model) = - setup_test_environment(cx, project.clone()).await; - - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); - - // Insert a user message - agent.update(cx, |agent, cx| { - agent.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Track events - let retries_failed = Arc::new(Mutex::new(false)); - let retries_failed_clone = retries_failed.clone(); - - let _subscription = agent.update(cx, |_, cx| { - cx.subscribe(&agent, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::RetriesFailed { .. } = event { - *retries_failed_clone.lock() = true; - } - }) - }); - - // Start initial completion - agent.update(cx, |agent, cx| { - agent.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - cx.run_until_parked(); - - // Advance through all retries - for i in 0..MAX_RETRY_ATTEMPTS { - let delay = if i == 0 { - BASE_RETRY_DELAY + if current_attempt <= self.fail_attempts { + let error = LanguageModelCompletionError::RateLimitExceeded { + retry_after: self.retry_after, + }; + async move { Err(error) }.boxed() } else { - BASE_RETRY_DELAY * 2u32.pow(i as u32 - 1) - }; - cx.executor().advance_clock(delay); - cx.run_until_parked(); - } - - // After the 3rd retry is scheduled, we need to wait for it to execute and fail - // The 3rd retry has a delay of BASE_RETRY_DELAY_SECS * 4 (20 seconds) - let final_delay = BASE_RETRY_DELAY * 2u32.pow(MAX_RETRY_ATTEMPTS as u32 - 1); - cx.executor().advance_clock(final_delay); - cx.run_until_parked(); - - let retry_count = thread.update(cx, |thread, _cx| { - thread - .messages() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - .count() - }); - - // After max retries, should emit RetriesFailed event - assert_eq!( - retry_count, MAX_RETRY_ATTEMPTS as usize, - "Should have attempted max retries" - ); - assert!( - *retries_failed.lock(), - "Should emit RetriesFailed event after max retries exceeded" - ); - - // Retry state should be cleared - agent.read_with(cx, |agent, cx| { - assert!( - agent.retry_state.is_none(), - "Retry state should be cleared after max retries" - ); - - // Verify we have the expected number of retry messages - let retry_messages = thread - .read(cx) - .messages() - .filter(|msg| msg.ui_only && msg.role == Role::System) - .count(); - assert_eq!( - retry_messages, MAX_RETRY_ATTEMPTS as usize, - "Should have one retry message per attempt" - ); - }); - } - - #[gpui::test] - async fn test_retry_message_removed_on_retry(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, agent, thread, _, _base_model) = - setup_test_environment(cx, project.clone()).await; - - // We'll use a wrapper to switch behavior after first failure - struct RetryTestModel { - inner: Arc, - failed_once: Arc>, - } - - impl LanguageModel for RetryTestModel { - fn id(&self) -> LanguageModelId { - self.inner.id() - } - - fn name(&self) -> LanguageModelName { - self.inner.name() - } - - fn provider_id(&self) -> LanguageModelProviderId { - self.inner.provider_id() - } - - fn provider_name(&self) -> LanguageModelProviderName { - self.inner.provider_name() - } - - fn supports_tools(&self) -> bool { - self.inner.supports_tools() - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - self.inner.supports_tool_choice(choice) - } - - fn supports_images(&self) -> bool { - self.inner.supports_images() - } - - fn telemetry_id(&self) -> String { - self.inner.telemetry_id() - } - - fn max_token_count(&self) -> u64 { - self.inner.max_token_count() - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - self.inner.count_tokens(request, cx) - } - - fn stream_completion( - &self, - request: LanguageModelRequest, - cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream< - 'static, - Result, - >, - LanguageModelCompletionError, - >, - > { - if !*self.failed_once.lock() { - *self.failed_once.lock() = true; - // Return error on first attempt - let stream = futures::stream::once(async move { - Err(LanguageModelCompletionError::Overloaded) - }); - async move { Ok(stream.boxed()) }.boxed() - } else { - // Succeed on retry - self.inner.stream_completion(request, cx) - } - } - - fn as_fake(&self) -> &FakeLanguageModel { - &self.inner + self.inner.stream_completion(request, cx) } } - let model = Arc::new(RetryTestModel { - inner: Arc::new(FakeLanguageModel::default()), - failed_once: Arc::new(Mutex::new(false)), - }); - - // Insert a user message - agent.update(cx, |agent, cx| { - agent.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Track message deletions - // Track when retry completes successfully - let retry_completed = Arc::new(Mutex::new(false)); - let retry_completed_clone = retry_completed.clone(); - - let _subscription = agent.update(cx, |_, cx| { - cx.subscribe(&agent, move |_, _, event: &ThreadEvent, _| { - if let ThreadEvent::StreamedCompletion = event { - *retry_completed_clone.lock() = true; - } - }) - }); - - // Start completion - agent.update(cx, |agent, cx| { - agent.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - cx.run_until_parked(); - - // Get the retry message ID - let retry_message_id = thread.read_with(cx, |thread, _cx| { - thread - .messages() - .find(|msg| msg.role == Role::System && msg.ui_only) - .map(|msg| msg.id) - .expect("Should have a retry message") - }); - - // Wait for retry - cx.executor().advance_clock(BASE_RETRY_DELAY); - cx.run_until_parked(); - - // Stream some successful content - let fake_model = model.as_fake(); - // After the retry, there should be a new pending completion - let pending = fake_model.pending_completions(); - assert!( - !pending.is_empty(), - "Should have a pending completion after retry" - ); - fake_model.stream_completion_response(&pending[0], "Success!"); - fake_model.end_completion_stream(&pending[0]); - cx.run_until_parked(); - - // Check that the retry completed successfully - assert!( - *retry_completed.lock(), - "Retry should have completed successfully" - ); - - // Retry message should still exist but be marked as ui_only - thread.read_with(cx, |thread, _cx| { - let retry_msg = thread - .message(retry_message_id) - .expect("Retry message should still exist"); - assert!(retry_msg.ui_only, "Retry message should be ui_only"); - assert_eq!( - retry_msg.role, - Role::System, - "Retry message should have System role" - ); - }); - } - - #[gpui::test] - async fn test_successful_completion_clears_retry_state(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, agent, thread, _, _base_model) = - setup_test_environment(cx, project.clone()).await; - - // Create a model that fails once then succeeds - struct FailOnceModel { - inner: Arc, - failed_once: Arc>, + fn as_fake(&self) -> &FakeLanguageModel { + &self.inner } - - impl LanguageModel for FailOnceModel { - fn id(&self) -> LanguageModelId { - self.inner.id() - } - - fn name(&self) -> LanguageModelName { - self.inner.name() - } - - fn provider_id(&self) -> LanguageModelProviderId { - self.inner.provider_id() - } - - fn provider_name(&self) -> LanguageModelProviderName { - self.inner.provider_name() - } - - fn supports_tools(&self) -> bool { - self.inner.supports_tools() - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - self.inner.supports_tool_choice(choice) - } - - fn supports_images(&self) -> bool { - self.inner.supports_images() - } - - fn telemetry_id(&self) -> String { - self.inner.telemetry_id() - } - - fn max_token_count(&self) -> u64 { - self.inner.max_token_count() - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - self.inner.count_tokens(request, cx) - } - - fn stream_completion( - &self, - request: LanguageModelRequest, - cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream< - 'static, - Result, - >, - LanguageModelCompletionError, - >, - > { - if !*self.failed_once.lock() { - *self.failed_once.lock() = true; - // Return error on first attempt - let stream = futures::stream::once(async move { - Err(LanguageModelCompletionError::Overloaded) - }); - async move { Ok(stream.boxed()) }.boxed() - } else { - // Succeed on retry - self.inner.stream_completion(request, cx) - } - } - } - - let fail_once_model = Arc::new(FailOnceModel { - inner: Arc::new(FakeLanguageModel::default()), - failed_once: Arc::new(Mutex::new(false)), - }); - - // Insert a user message - agent.update(cx, |agent, cx| { - agent.insert_user_message( - "Test message", - ContextLoadResult::default(), - None, - vec![], - cx, - ); - }); - - // Start completion with fail-once model - agent.update(cx, |agent, cx| { - agent.send_to_model( - fail_once_model.clone(), - CompletionIntent::UserPrompt, - None, - cx, - ); - }); - - cx.run_until_parked(); - - // Verify retry state exists after first failure - agent.read_with(cx, |agent, _| { - assert!( - agent.retry_state.is_some(), - "Should have retry state after failure" - ); - }); - - // Wait for retry delay - cx.executor().advance_clock(BASE_RETRY_DELAY); - cx.run_until_parked(); - - // The retry should now use our FailOnceModel which should succeed - // We need to help the FakeLanguageModel complete the stream - let inner_fake = fail_once_model.inner.clone(); - - // Wait a bit for the retry to start - cx.run_until_parked(); - - // Check for pending completions and complete them - if let Some(pending) = inner_fake.pending_completions().first() { - inner_fake.stream_completion_response(pending, "Success!"); - inner_fake.end_completion_stream(pending); - } - cx.run_until_parked(); - - agent.read_with(cx, |agent, _| { - assert!( - agent.retry_state.is_none(), - "Retry state should be cleared after successful completion" - ); - }); - - thread.read_with(cx, |thread, _| { - let has_assistant_message = thread - .messages() - .any(|msg| msg.role == Role::Assistant && !msg.ui_only); - assert!( - has_assistant_message, - "Should have an assistant message after successful retry" - ); - }); - } - - #[gpui::test] - async fn test_rate_limit_retry_single_attempt(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, agent, thread, _, _base_model) = - setup_test_environment(cx, project.clone()).await; - - // Create a model that returns rate limit error with retry_after - struct RateLimitModel { - inner: Arc, - } - - impl LanguageModel for RateLimitModel { - fn id(&self) -> LanguageModelId { - self.inner.id() - } - - fn name(&self) -> LanguageModelName { - self.inner.name() - } - - fn provider_id(&self) -> LanguageModelProviderId { - self.inner.provider_id() - } - - fn provider_name(&self) -> LanguageModelProviderName { - self.inner.provider_name() - } - - fn supports_tools(&self) -> bool { - self.inner.supports_tools() - } - - fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool { - self.inner.supports_tool_choice(choice) - } - - fn supports_images(&self) -> bool { - self.inner.supports_images() - } - - fn telemetry_id(&self) -> String { - self.inner.telemetry_id() - } - - fn max_token_count(&self) -> u64 { - self.inner.max_token_count() - } - - fn count_tokens( - &self, - request: LanguageModelRequest, - cx: &App, - ) -> BoxFuture<'static, Result> { - self.inner.count_tokens(request, cx) - } - - fn stream_completion( - &self, - _request: LanguageModelRequest, - _cx: &AsyncApp, - ) -> BoxFuture< - 'static, - Result< - BoxStream< - 'static, - Result, - >, - LanguageModelCompletionError, - >, - > { - async move { - let stream = futures::stream::once(async move { - Err(LanguageModelCompletionError::RateLimitExceeded { - retry_after: Duration::from_secs(TEST_RATE_LIMIT_RETRY_SECS), - }) - }); - Ok(stream.boxed()) - } - .boxed() - } - - fn as_fake(&self) -> &FakeLanguageModel { - &self.inner - } - } - - let model = Arc::new(RateLimitModel { - inner: Arc::new(FakeLanguageModel::default()), - }); - - // Insert a user message - agent.update(cx, |agent, cx| { - agent.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Start completion - agent.update(cx, |agent, cx| { - agent.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - cx.run_until_parked(); - - let retry_count = thread.update(cx, |thread, _| { - thread - .messages() - .filter(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("rate limit exceeded") - } else { - false - } - }) - }) - .count() - }); - assert_eq!(retry_count, 1, "Should have scheduled one retry"); - - agent.read_with(cx, |agent, _| { - assert!( - agent.retry_state.is_none(), - "Rate limit errors should not set retry_state" - ); - }); - - // Verify we have one retry message - thread.read_with(cx, |thread, _| { - let retry_messages = thread - .messages() - .filter(|msg| { - msg.ui_only - && msg.segments.iter().any(|seg| { - if let MessageSegment::Text(text) = seg { - text.contains("rate limit exceeded") - } else { - false - } - }) - }) - .count(); - assert_eq!( - retry_messages, 1, - "Should have one rate limit retry message" - ); - }); - - // Check that retry message doesn't include attempt count - thread.read_with(cx, |thread, _cx| { - let retry_message = thread - .messages() - .find(|msg| msg.role == Role::System && msg.ui_only) - .expect("Should have a retry message"); - - // Check that the message doesn't contain attempt count - if let Some(MessageSegment::Text(text)) = retry_message.segments.first() { - assert!( - !text.contains("attempt"), - "Rate limit retry message should not contain attempt count" - ); - assert!( - text.contains(&format!( - "Retrying in {} seconds", - TEST_RATE_LIMIT_RETRY_SECS - )), - "Rate limit retry message should contain retry delay" - ); - } - }); - } - - #[gpui::test] - async fn test_ui_only_messages_not_sent_to_model(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, agent, thread, _, model) = setup_test_environment(cx, project.clone()).await; - - // Insert a regular user message - agent.update(cx, |agent, cx| { - agent.send_message("Hello!", model.clone(), None, cx) - }); - - // Insert a UI-only message (like our retry notifications) - thread.update(cx, |thread, cx| { - let id = thread.next_message_id.post_inc(); - thread.messages.push(Message { - id, - role: Role::System, - segments: vec![MessageSegment::Text( - "This is a UI-only message that should not be sent to the model".to_string(), - )], - loaded_context: LoadedContext::default(), - creases: Vec::new(), - is_hidden: true, - ui_only: true, - }); - cx.emit(ThreadEvent::MessageAdded(id)); - }); - - // Insert another regular message - agent.update(cx, |agent, cx| { - agent.send_message("How are you?", model.clone(), None, cx) - }); - - // Generate the completion request - let request = agent.update(cx, |agent, cx| { - agent.build_request(&model, CompletionIntent::UserPrompt, cx) - }); - - // Verify that the request only contains non-UI-only messages - // Should have system prompt + 2 user messages, but not the UI-only message - let user_messages: Vec<_> = request - .messages - .iter() - .filter(|msg| msg.role == Role::User) - .collect(); - assert_eq!( - user_messages.len(), - 2, - "Should have exactly 2 user messages" - ); - - // Verify the UI-only content is not present anywhere in the request - let request_text = request - .messages - .iter() - .flat_map(|msg| &msg.content) - .filter_map(|content| match content { - MessageContent::Text(text) => Some(text.as_str()), - _ => None, - }) - .collect::(); - - assert!( - !request_text.contains("UI-only message"), - "UI-only message content should not be in the request" - ); - - // Verify the thread still has all 3 messages (including UI-only) - thread.read_with(cx, |thread, _cx| { - assert_eq!( - thread.messages().count(), - 3, - "Thread should have 3 messages" - ); - assert_eq!( - thread.messages().filter(|m| m.ui_only).count(), - 1, - "Thread should have 1 UI-only message" - ); - }); - - // Verify that UI-only messages are not serialized - let serialized = agent - .update(cx, |agent, cx| agent.serialize(cx)) - .await - .unwrap(); - assert_eq!( - serialized.messages.len(), - 2, - "Serialized thread should only have 2 messages (no UI-only)" - ); - } - - #[gpui::test] - async fn test_retry_cancelled_on_stop(cx: &mut TestAppContext) { - init_test_settings(cx); - - let project = create_test_project(cx, json!({})).await; - let (_, _, agent, thread, _, _base_model) = - setup_test_environment(cx, project.clone()).await; - - // Create model that returns overloaded error - let model = Arc::new(ErrorInjector::new(TestError::Overloaded)); - - // Insert a user message - agent.update(cx, |agent, cx| { - agent.insert_user_message("Hello!", ContextLoadResult::default(), None, vec![], cx); - }); - - // Start completion - agent.update(cx, |agent, cx| { - agent.send_to_model(model.clone(), CompletionIntent::UserPrompt, None, cx); - }); - - cx.run_until_parked(); - - // Verify retry was scheduled by checking for retry message - let has_retry_message = thread.read_with(cx, |thread, _| { - thread.messages().any(|m| { - m.ui_only - && m.segments.iter().any(|s| { - if let MessageSegment::Text(text) = s { - text.contains("Retrying") && text.contains("seconds") - } else { - false - } - }) - }) - }); - assert!(has_retry_message, "Should have scheduled a retry"); - - // Cancel the completion before the retry happens - agent.update(cx, |agent, cx| { - agent.cancel_last_completion(None, cx); - }); - - cx.run_until_parked(); - - // The retry should not have happened - no pending completions - let fake_model = model.as_fake(); - assert_eq!( - fake_model.pending_completions().len(), - 0, - "Should have no pending completions after cancellation" - ); - - // Verify the retry was cancelled by checking retry state - agent.read_with(cx, |agent, _| { - if let Some(retry_state) = &agent.retry_state { - panic!( - "retry_state should be cleared after cancellation, but found: attempt={}, max_attempts={}, intent={:?}", - retry_state.attempt, retry_state.max_attempts, retry_state.intent - ); - } - }); } fn test_summarize_error(