Compare commits

...

8 Commits

Author SHA1 Message Date
Agus Zubiaga
615fcb50cf Fix test 2025-12-15 14:33:20 -03:00
Agus Zubiaga
690f0db743 Fix var name 2025-12-15 14:17:11 -03:00
Agus Zubiaga
2b8280083f Refactor custom url/auth handling 2025-12-15 14:09:53 -03:00
Agus Zubiaga
3f0cb7ee3d Merge branch 'main' into dominic.burkart/zeta-works-while-signed-out-if-ZED_PREDICT_EDITS_URL-set 2025-12-15 13:24:20 -03:00
Dominic Burkart
f0d00cff1a style: Apply cargo fmt to edit_prediction_tests 2025-12-12 14:54:40 +01:00
Dominic Burkart
587e00746a refactor(zeta): Make auth tests CI-compatible
Add test-only infrastructure to allow testing custom URL auth behavior
without requiring environment variable manipulation:

- Add TEST_FORCE_CUSTOM_URL_MODE AtomicBool with #[cfg(test)]
- Add has_custom_url() helper that checks test override first
- Combine auth tests into single test_authentication_behavior function
  for sequential execution (avoids concurrency issues)
- Remove #[ignore] from custom URL test - now runs in CI

The test-only code has zero impact on production builds.
2025-12-12 14:46:13 +01:00
Dominic Burkart
0e41202703 feat(zeta): Allow custom URLs without authentication
When ZED_PREDICT_EDITS_URL is set (custom URL), authentication is now
optional - the prediction will use authentication if available, and
proceed without it if not available.

When no custom URL is set (standard Zed infrastructure), authentication
remains required to maintain the existing behavior.

Changes:
- Token acquisition uses .ok() for custom URLs (converts error to None)
- Authorization header only added when token is present
- Token refresh only attempted when we had a token to begin with
2025-12-11 13:53:33 +01:00
Dominic Burkart
c1cd943b30 test(zeta): Add authentication bypass tests for custom URLs
Add tests to verify authentication behavior with custom URLs:
- test_unauthenticated_without_custom_url_blocks_prediction: verifies
  that without authentication and without a custom URL, prediction fails
- test_unauthenticated_with_custom_url_allows_prediction: verifies that
  with a custom URL set via ZED_PREDICT_EDITS_URL, prediction should
  proceed even without authentication

These tests establish the expected behavior: when using custom URLs,
authentication should be optional (use if available, proceed without).
2025-12-11 13:53:24 +01:00
3 changed files with 249 additions and 42 deletions

View File

@@ -19,6 +19,7 @@ use futures::{
select_biased,
};
use gpui::BackgroundExecutor;
use gpui::http_client::Url;
use gpui::{
App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
http_client::{self, AsyncBody, Method},
@@ -127,15 +128,6 @@ static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
}
.to_string()
});
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
if *USE_OLLAMA {
Some("http://localhost:11434/v1/chat/completions".into())
} else {
None
}
})
});
pub struct Zeta2FeatureFlag;
@@ -170,6 +162,7 @@ pub struct EditPredictionStore {
reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
shown_predictions: VecDeque<EditPrediction>,
rated_predictions: HashSet<EditPredictionId>,
custom_predict_edits_url: Option<Arc<Url>>,
}
#[derive(Copy, Clone, Default, PartialEq, Eq)]
@@ -568,6 +561,20 @@ impl EditPredictionStore {
reject_predictions_tx: reject_tx,
rated_predictions: Default::default(),
shown_predictions: Default::default(),
custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
Err(_) => {
if *USE_OLLAMA {
Some(
Url::parse("http://localhost:11434/v1/chat/completions")
.unwrap()
.into(),
)
} else {
None
}
}
},
};
this.configure_context_retrieval(cx);
@@ -586,6 +593,11 @@ impl EditPredictionStore {
this
}
/// Test-only setter that overwrites `custom_predict_edits_url` with `url`.
///
/// Lets a test put the store on the custom-endpoint path directly instead of
/// going through the `ZED_PREDICT_EDITS_URL` environment variable.
#[cfg(test)]
pub fn set_custom_predict_edits_url(&mut self, url: Url) {
self.custom_predict_edits_url = Some(url.into());
}
/// Sets the store's `edit_prediction_model` field to `model`.
pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
self.edit_prediction_model = model;
}
@@ -1015,8 +1027,13 @@ impl EditPredictionStore {
}
fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
match self.edit_prediction_model {
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
if self.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
return;
}
}
EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
}
@@ -1036,12 +1053,15 @@ impl EditPredictionStore {
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
cx.spawn(async move |this, cx| {
let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
http_client::Url::parse(&predict_edits_url)?
let (url, require_auth) = if let Some(accept_edits_url) = custom_accept_url {
(http_client::Url::parse(&accept_edits_url)?, false)
} else {
client
.http_client()
.build_zed_llm_url("/predict_edits/accept", &[])?
(
client
.http_client()
.build_zed_llm_url("/predict_edits/accept", &[])?,
true,
)
};
let response = cx
@@ -1058,6 +1078,7 @@ impl EditPredictionStore {
client,
llm_token,
app_version,
require_auth,
))
.await;
@@ -1116,6 +1137,7 @@ impl EditPredictionStore {
client.clone(),
llm_token.clone(),
app_version.clone(),
true,
)
.await;
@@ -1161,7 +1183,11 @@ impl EditPredictionStore {
was_shown: bool,
) {
match self.edit_prediction_model {
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
if self.custom_predict_edits_url.is_some() {
return;
}
}
EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
}
@@ -1671,13 +1697,9 @@ impl EditPredictionStore {
#[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
#[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
http_client::Url::parse(&predict_edits_url)?
} else {
client
.http_client()
.build_zed_llm_url("/predict_edits/raw", &[])?
};
let url = client
.http_client()
.build_zed_llm_url("/predict_edits/raw", &[])?;
#[cfg(feature = "cli-support")]
let cache_key = if let Some(cache) = eval_cache {
@@ -1710,6 +1732,7 @@ impl EditPredictionStore {
client,
llm_token,
app_version,
true,
)
.await?;
@@ -1770,23 +1793,34 @@ impl EditPredictionStore {
client: Arc<Client>,
llm_token: LlmApiToken,
app_version: Version,
require_auth: bool,
) -> Result<(Res, Option<EditPredictionUsage>)>
where
Res: DeserializeOwned,
{
let http_client = client.http_client();
let mut token = llm_token.acquire(&client).await?;
let mut token = if require_auth {
Some(llm_token.acquire(&client).await?)
} else {
llm_token.acquire(&client).await.ok()
};
let mut did_retry = false;
loop {
let request_builder = http_client::Request::builder().method(Method::POST);
let request = build(
request_builder
.header("Content-Type", "application/json")
.header("Authorization", format!("Bearer {}", token))
.header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
)?;
let mut request_builder = request_builder
.header("Content-Type", "application/json")
.header(ZED_VERSION_HEADER_NAME, app_version.to_string());
// Only add Authorization header if we have a token
if let Some(ref token_value) = token {
request_builder =
request_builder.header("Authorization", format!("Bearer {}", token_value));
}
let request = build(request_builder)?;
let mut response = http_client.send(request).await?;
@@ -1810,13 +1844,14 @@ impl EditPredictionStore {
response.body_mut().read_to_end(&mut body).await?;
return Ok((serde_json::from_slice(&body)?, usage));
} else if !did_retry
&& token.is_some()
&& response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
{
did_retry = true;
token = llm_token.refresh(&client).await?;
token = Some(llm_token.refresh(&client).await?);
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;

View File

@@ -1914,6 +1914,174 @@ fn from_completion_edits(
.collect()
}
/// Verifies that a Zeta1 prediction request fails when every HTTP request is
/// answered with 401 and no custom predict-edits URL is installed on the store.
///
/// NOTE(review): the store is built with `EditPredictionStore::new`, so this
/// presumably relies on `ZED_PREDICT_EDITS_URL` being unset in the test
/// environment — confirm.
#[gpui::test]
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
init_test(cx);
// Minimal on-disk project: a single `main.rs` in a fake file system.
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/project",
serde_json::json!({
"main.rs": "fn main() {\n \n}\n"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
// Fake HTTP layer that rejects every request with 401 "Unauthorized".
let http_client = FakeHttpClient::create(|_req| async move {
Ok(gpui::http_client::Response::builder()
.status(401)
.body("Unauthorized".into())
.unwrap())
});
let client =
cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
cx.update(|cx| {
language_model::RefreshLlmTokenListener::register(client.clone(), cx);
});
let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
let buffer = project
.update(cx, |project, cx| {
let path = project
.find_project_path(path!("/project/main.rs"), cx)
.unwrap();
project.open_buffer(path, cx)
})
.await
.unwrap();
// Anchor the cursor at `Point::new(1, 4)` (presumably row 1, column 4 of `main.rs`).
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
ep_store.update(cx, |ep_store, cx| {
ep_store.register_buffer(&buffer, &project, cx)
});
// Run the executor until it is parked before issuing the prediction request.
cx.background_executor.run_until_parked();
// Request a prediction with the Zeta1 model; no custom URL is set on the store here.
let completion_task = ep_store.update(cx, |ep_store, cx| {
ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
});
let result = completion_task.await;
assert!(
result.is_err(),
"Without authentication and without custom URL, prediction should fail"
);
}
/// Verifies that, with a custom predict-edits URL installed on the store, a
/// Zeta1 prediction request reaches the predict endpoint even though every
/// other request is answered with 401.
///
/// The test asserts only that the endpoint was hit; the result of the
/// prediction task itself is discarded.
#[gpui::test]
async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
init_test(cx);
// Minimal on-disk project: a single `main.rs` in a fake file system.
let fs = FakeFs::new(cx.executor());
fs.insert_tree(
"/project",
serde_json::json!({
"main.rs": "fn main() {\n \n}\n"
}),
)
.await;
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
// Shared flag flipped by the fake HTTP handler when a predict request arrives.
let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
let predict_called_clone = predict_called.clone();
let http_client = FakeHttpClient::create({
move |req| {
let uri = req.uri().path().to_string();
let predict_called = predict_called_clone.clone();
async move {
// Any request whose path contains "predict" is treated as the prediction
// endpoint: record the call and reply with a canned `open_ai::Response`.
if uri.contains("predict") {
predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
Ok(gpui::http_client::Response::builder()
.body(
serde_json::to_string(&open_ai::Response {
id: "test-123".to_string(),
object: "chat.completion".to_string(),
created: 0,
model: "test".to_string(),
usage: open_ai::Usage {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0,
},
choices: vec![open_ai::Choice {
index: 0,
message: open_ai::RequestMessage::Assistant {
content: Some(open_ai::MessageContent::Plain(
indoc! {"
```main.rs
<|start_of_file|>
<|editable_region_start|>
fn main() {
println!(\"Hello, world!\");
}
<|editable_region_end|>
```
"}
.to_string(),
)),
tool_calls: vec![],
},
finish_reason: Some("stop".to_string()),
}],
})
.unwrap()
.into(),
)
.unwrap())
} else {
// Every other request is rejected with 401 "Unauthorized".
Ok(gpui::http_client::Response::builder()
.status(401)
.body("Unauthorized".into())
.unwrap())
}
}
}
});
let client =
cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
cx.update(|cx| {
language_model::RefreshLlmTokenListener::register(client.clone(), cx);
});
let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
let buffer = project
.update(cx, |project, cx| {
let path = project
.find_project_path(path!("/project/main.rs"), cx)
.unwrap();
project.open_buffer(path, cx)
})
.await
.unwrap();
// Anchor the cursor at `Point::new(1, 4)` (presumably row 1, column 4 of `main.rs`).
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
ep_store.update(cx, |ep_store, cx| {
ep_store.register_buffer(&buffer, &project, cx)
});
// Run the executor until it is parked before issuing the prediction request.
cx.background_executor.run_until_parked();
// Install the custom URL on the store, then request a Zeta1 prediction.
let completion_task = ep_store.update(cx, |ep_store, cx| {
ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
});
// The task's result is discarded; only the endpoint hit is checked below.
let _ = completion_task.await;
assert!(
predict_called.load(std::sync::atomic::Ordering::SeqCst),
"With custom URL, predict endpoint should be called even without authentication"
);
}
#[ctor::ctor]
fn init_logger() {
zlog::init_test();

View File

@@ -78,6 +78,19 @@ pub(crate) fn request_prediction_with_zeta1(
cx,
);
let (uri, require_auth) = match &store.custom_predict_edits_url {
Some(custom_url) => (custom_url.clone(), false),
None => {
match client
.http_client()
.build_zed_llm_url("/predict_edits/v2", &[])
{
Ok(url) => (url.into(), true),
Err(err) => return Task::ready(Err(err)),
}
}
};
cx.spawn(async move |this, cx| {
let GatherContextOutput {
mut body,
@@ -102,25 +115,16 @@ pub(crate) fn request_prediction_with_zeta1(
body.input_excerpt
);
let http_client = client.http_client();
let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
|request| {
let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
predict_edits_url
} else {
http_client
.build_zed_llm_url("/predict_edits/v2", &[])?
.as_str()
.into()
};
Ok(request
.uri(uri)
.uri(uri.as_str())
.body(serde_json::to_string(&body)?.into())?)
},
client,
llm_token,
app_version,
require_auth,
)
.await;