Compare commits
8 Commits
main
...
dominic.bu
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
615fcb50cf | ||
|
|
690f0db743 | ||
|
|
2b8280083f | ||
|
|
3f0cb7ee3d | ||
|
|
f0d00cff1a | ||
|
|
587e00746a | ||
|
|
0e41202703 | ||
|
|
c1cd943b30 |
@@ -19,6 +19,7 @@ use futures::{
|
||||
select_biased,
|
||||
};
|
||||
use gpui::BackgroundExecutor;
|
||||
use gpui::http_client::Url;
|
||||
use gpui::{
|
||||
App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
|
||||
http_client::{self, AsyncBody, Method},
|
||||
@@ -127,15 +128,6 @@ static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
|
||||
}
|
||||
.to_string()
|
||||
});
|
||||
static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
|
||||
env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
|
||||
if *USE_OLLAMA {
|
||||
Some("http://localhost:11434/v1/chat/completions".into())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
});
|
||||
|
||||
pub struct Zeta2FeatureFlag;
|
||||
|
||||
@@ -170,6 +162,7 @@ pub struct EditPredictionStore {
|
||||
reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
|
||||
shown_predictions: VecDeque<EditPrediction>,
|
||||
rated_predictions: HashSet<EditPredictionId>,
|
||||
custom_predict_edits_url: Option<Arc<Url>>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Default, PartialEq, Eq)]
|
||||
@@ -568,6 +561,20 @@ impl EditPredictionStore {
|
||||
reject_predictions_tx: reject_tx,
|
||||
rated_predictions: Default::default(),
|
||||
shown_predictions: Default::default(),
|
||||
custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
|
||||
Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
|
||||
Err(_) => {
|
||||
if *USE_OLLAMA {
|
||||
Some(
|
||||
Url::parse("http://localhost:11434/v1/chat/completions")
|
||||
.unwrap()
|
||||
.into(),
|
||||
)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
this.configure_context_retrieval(cx);
|
||||
@@ -586,6 +593,11 @@ impl EditPredictionStore {
|
||||
this
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn set_custom_predict_edits_url(&mut self, url: Url) {
|
||||
self.custom_predict_edits_url = Some(url.into());
|
||||
}
|
||||
|
||||
pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
|
||||
self.edit_prediction_model = model;
|
||||
}
|
||||
@@ -1015,8 +1027,13 @@ impl EditPredictionStore {
|
||||
}
|
||||
|
||||
fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
|
||||
let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
|
||||
match self.edit_prediction_model {
|
||||
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
|
||||
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
|
||||
if self.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
|
||||
}
|
||||
|
||||
@@ -1036,12 +1053,15 @@ impl EditPredictionStore {
|
||||
let llm_token = self.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
cx.spawn(async move |this, cx| {
|
||||
let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
|
||||
http_client::Url::parse(&predict_edits_url)?
|
||||
let (url, require_auth) = if let Some(accept_edits_url) = custom_accept_url {
|
||||
(http_client::Url::parse(&accept_edits_url)?, false)
|
||||
} else {
|
||||
client
|
||||
.http_client()
|
||||
.build_zed_llm_url("/predict_edits/accept", &[])?
|
||||
(
|
||||
client
|
||||
.http_client()
|
||||
.build_zed_llm_url("/predict_edits/accept", &[])?,
|
||||
true,
|
||||
)
|
||||
};
|
||||
|
||||
let response = cx
|
||||
@@ -1058,6 +1078,7 @@ impl EditPredictionStore {
|
||||
client,
|
||||
llm_token,
|
||||
app_version,
|
||||
require_auth,
|
||||
))
|
||||
.await;
|
||||
|
||||
@@ -1116,6 +1137,7 @@ impl EditPredictionStore {
|
||||
client.clone(),
|
||||
llm_token.clone(),
|
||||
app_version.clone(),
|
||||
true,
|
||||
)
|
||||
.await;
|
||||
|
||||
@@ -1161,7 +1183,11 @@ impl EditPredictionStore {
|
||||
was_shown: bool,
|
||||
) {
|
||||
match self.edit_prediction_model {
|
||||
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
|
||||
EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
|
||||
if self.custom_predict_edits_url.is_some() {
|
||||
return;
|
||||
}
|
||||
}
|
||||
EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
|
||||
}
|
||||
|
||||
@@ -1671,13 +1697,9 @@ impl EditPredictionStore {
|
||||
#[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
|
||||
#[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
|
||||
) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
|
||||
let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
|
||||
http_client::Url::parse(&predict_edits_url)?
|
||||
} else {
|
||||
client
|
||||
.http_client()
|
||||
.build_zed_llm_url("/predict_edits/raw", &[])?
|
||||
};
|
||||
let url = client
|
||||
.http_client()
|
||||
.build_zed_llm_url("/predict_edits/raw", &[])?;
|
||||
|
||||
#[cfg(feature = "cli-support")]
|
||||
let cache_key = if let Some(cache) = eval_cache {
|
||||
@@ -1710,6 +1732,7 @@ impl EditPredictionStore {
|
||||
client,
|
||||
llm_token,
|
||||
app_version,
|
||||
true,
|
||||
)
|
||||
.await?;
|
||||
|
||||
@@ -1770,23 +1793,34 @@ impl EditPredictionStore {
|
||||
client: Arc<Client>,
|
||||
llm_token: LlmApiToken,
|
||||
app_version: Version,
|
||||
require_auth: bool,
|
||||
) -> Result<(Res, Option<EditPredictionUsage>)>
|
||||
where
|
||||
Res: DeserializeOwned,
|
||||
{
|
||||
let http_client = client.http_client();
|
||||
let mut token = llm_token.acquire(&client).await?;
|
||||
|
||||
let mut token = if require_auth {
|
||||
Some(llm_token.acquire(&client).await?)
|
||||
} else {
|
||||
llm_token.acquire(&client).await.ok()
|
||||
};
|
||||
let mut did_retry = false;
|
||||
|
||||
loop {
|
||||
let request_builder = http_client::Request::builder().method(Method::POST);
|
||||
|
||||
let request = build(
|
||||
request_builder
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Authorization", format!("Bearer {}", token))
|
||||
.header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
|
||||
)?;
|
||||
let mut request_builder = request_builder
|
||||
.header("Content-Type", "application/json")
|
||||
.header(ZED_VERSION_HEADER_NAME, app_version.to_string());
|
||||
|
||||
// Only add Authorization header if we have a token
|
||||
if let Some(ref token_value) = token {
|
||||
request_builder =
|
||||
request_builder.header("Authorization", format!("Bearer {}", token_value));
|
||||
}
|
||||
|
||||
let request = build(request_builder)?;
|
||||
|
||||
let mut response = http_client.send(request).await?;
|
||||
|
||||
@@ -1810,13 +1844,14 @@ impl EditPredictionStore {
|
||||
response.body_mut().read_to_end(&mut body).await?;
|
||||
return Ok((serde_json::from_slice(&body)?, usage));
|
||||
} else if !did_retry
|
||||
&& token.is_some()
|
||||
&& response
|
||||
.headers()
|
||||
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
|
||||
.is_some()
|
||||
{
|
||||
did_retry = true;
|
||||
token = llm_token.refresh(&client).await?;
|
||||
token = Some(llm_token.refresh(&client).await?);
|
||||
} else {
|
||||
let mut body = String::new();
|
||||
response.body_mut().read_to_string(&mut body).await?;
|
||||
|
||||
@@ -1914,6 +1914,174 @@ fn from_completion_edits(
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_unauthenticated_without_custom_url_blocks_prediction_impl(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/project",
|
||||
serde_json::json!({
|
||||
"main.rs": "fn main() {\n \n}\n"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
|
||||
let http_client = FakeHttpClient::create(|_req| async move {
|
||||
Ok(gpui::http_client::Response::builder()
|
||||
.status(401)
|
||||
.body("Unauthorized".into())
|
||||
.unwrap())
|
||||
});
|
||||
|
||||
let client =
|
||||
cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
|
||||
cx.update(|cx| {
|
||||
language_model::RefreshLlmTokenListener::register(client.clone(), cx);
|
||||
});
|
||||
|
||||
let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let path = project
|
||||
.find_project_path(path!("/project/main.rs"), cx)
|
||||
.unwrap();
|
||||
project.open_buffer(path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.register_buffer(&buffer, &project, cx)
|
||||
});
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
let completion_task = ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
|
||||
ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
|
||||
});
|
||||
|
||||
let result = completion_task.await;
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"Without authentication and without custom URL, prediction should fail"
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_unauthenticated_with_custom_url_allows_prediction_impl(cx: &mut TestAppContext) {
|
||||
init_test(cx);
|
||||
|
||||
let fs = FakeFs::new(cx.executor());
|
||||
fs.insert_tree(
|
||||
"/project",
|
||||
serde_json::json!({
|
||||
"main.rs": "fn main() {\n \n}\n"
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
|
||||
|
||||
let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
|
||||
let predict_called_clone = predict_called.clone();
|
||||
|
||||
let http_client = FakeHttpClient::create({
|
||||
move |req| {
|
||||
let uri = req.uri().path().to_string();
|
||||
let predict_called = predict_called_clone.clone();
|
||||
async move {
|
||||
if uri.contains("predict") {
|
||||
predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
|
||||
Ok(gpui::http_client::Response::builder()
|
||||
.body(
|
||||
serde_json::to_string(&open_ai::Response {
|
||||
id: "test-123".to_string(),
|
||||
object: "chat.completion".to_string(),
|
||||
created: 0,
|
||||
model: "test".to_string(),
|
||||
usage: open_ai::Usage {
|
||||
prompt_tokens: 0,
|
||||
completion_tokens: 0,
|
||||
total_tokens: 0,
|
||||
},
|
||||
choices: vec![open_ai::Choice {
|
||||
index: 0,
|
||||
message: open_ai::RequestMessage::Assistant {
|
||||
content: Some(open_ai::MessageContent::Plain(
|
||||
indoc! {"
|
||||
```main.rs
|
||||
<|start_of_file|>
|
||||
<|editable_region_start|>
|
||||
fn main() {
|
||||
println!(\"Hello, world!\");
|
||||
}
|
||||
<|editable_region_end|>
|
||||
```
|
||||
"}
|
||||
.to_string(),
|
||||
)),
|
||||
tool_calls: vec![],
|
||||
},
|
||||
finish_reason: Some("stop".to_string()),
|
||||
}],
|
||||
})
|
||||
.unwrap()
|
||||
.into(),
|
||||
)
|
||||
.unwrap())
|
||||
} else {
|
||||
Ok(gpui::http_client::Response::builder()
|
||||
.status(401)
|
||||
.body("Unauthorized".into())
|
||||
.unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let client =
|
||||
cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
|
||||
cx.update(|cx| {
|
||||
language_model::RefreshLlmTokenListener::register(client.clone(), cx);
|
||||
});
|
||||
|
||||
let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
let path = project
|
||||
.find_project_path(path!("/project/main.rs"), cx)
|
||||
.unwrap();
|
||||
project.open_buffer(path, cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.register_buffer(&buffer, &project, cx)
|
||||
});
|
||||
cx.background_executor.run_until_parked();
|
||||
|
||||
let completion_task = ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
|
||||
ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
|
||||
ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
|
||||
});
|
||||
|
||||
let _ = completion_task.await;
|
||||
|
||||
assert!(
|
||||
predict_called.load(std::sync::atomic::Ordering::SeqCst),
|
||||
"With custom URL, predict endpoint should be called even without authentication"
|
||||
);
|
||||
}
|
||||
|
||||
#[ctor::ctor]
|
||||
fn init_logger() {
|
||||
zlog::init_test();
|
||||
|
||||
@@ -78,6 +78,19 @@ pub(crate) fn request_prediction_with_zeta1(
|
||||
cx,
|
||||
);
|
||||
|
||||
let (uri, require_auth) = match &store.custom_predict_edits_url {
|
||||
Some(custom_url) => (custom_url.clone(), false),
|
||||
None => {
|
||||
match client
|
||||
.http_client()
|
||||
.build_zed_llm_url("/predict_edits/v2", &[])
|
||||
{
|
||||
Ok(url) => (url.into(), true),
|
||||
Err(err) => return Task::ready(Err(err)),
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
let GatherContextOutput {
|
||||
mut body,
|
||||
@@ -102,25 +115,16 @@ pub(crate) fn request_prediction_with_zeta1(
|
||||
body.input_excerpt
|
||||
);
|
||||
|
||||
let http_client = client.http_client();
|
||||
|
||||
let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
|
||||
|request| {
|
||||
let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
|
||||
predict_edits_url
|
||||
} else {
|
||||
http_client
|
||||
.build_zed_llm_url("/predict_edits/v2", &[])?
|
||||
.as_str()
|
||||
.into()
|
||||
};
|
||||
Ok(request
|
||||
.uri(uri)
|
||||
.uri(uri.as_str())
|
||||
.body(serde_json::to_string(&body)?.into())?)
|
||||
},
|
||||
client,
|
||||
llm_token,
|
||||
app_version,
|
||||
require_auth,
|
||||
)
|
||||
.await;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user