diff --git a/crates/edit_prediction_cli/src/main.rs b/crates/edit_prediction_cli/src/main.rs index 8faa8e3d04..50d3087743 100644 --- a/crates/edit_prediction_cli/src/main.rs +++ b/crates/edit_prediction_cli/src/main.rs @@ -151,7 +151,7 @@ struct PredictArgs { repetitions: usize, } -#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)] +#[derive(Clone, Copy, Debug, PartialEq, ValueEnum, Serialize, Deserialize)] enum PredictionProvider { Sweep, Mercury, diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 6579e61cf9..51f4523605 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -28,12 +28,16 @@ pub async fn run_prediction( app_state: Arc, mut cx: AsyncApp, ) -> anyhow::Result<()> { - if !example.predictions.is_empty() { - return Ok(()); - } - let provider = provider.context("provider is required")?; + if let Some(existing_prediction) = example.predictions.first() { + if existing_prediction.provider == provider { + return Ok(()); + } else { + example.predictions.clear(); + } + } + run_context_retrieval(example, app_state.clone(), cx.clone()).await?; if matches!(