Compare commits
21 Commits
deleter
...
inline-ass
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aef1106e7a | ||
|
|
605c594181 | ||
|
|
f332b08a34 | ||
|
|
cfb61624c3 | ||
|
|
f79dd8784b | ||
|
|
0c5bebae93 | ||
|
|
2098b67304 | ||
|
|
5a6198cc39 | ||
|
|
cda78c12ab | ||
|
|
f4378672b8 | ||
|
|
ecb8d3d4dd | ||
|
|
95dbc0efc2 | ||
|
|
8572c19a02 | ||
|
|
045c14593f | ||
|
|
0ff3b68a5e | ||
|
|
a6b9524d78 | ||
|
|
7ed5d42696 | ||
|
|
25d74480aa | ||
|
|
37077a8ebb | ||
|
|
7c4a85f5f1 | ||
|
|
d21628c349 |
2
.github/workflows/run_tests.yml
vendored
2
.github/workflows/run_tests.yml
vendored
@@ -497,6 +497,8 @@ jobs:
|
||||
env:
|
||||
GIT_AUTHOR_NAME: Protobuf Action
|
||||
GIT_AUTHOR_EMAIL: ci@zed.dev
|
||||
GIT_COMMITTER_NAME: Protobuf Action
|
||||
GIT_COMMITTER_EMAIL: ci@zed.dev
|
||||
steps:
|
||||
- name: steps::checkout_repo
|
||||
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683
|
||||
|
||||
30
Cargo.lock
generated
30
Cargo.lock
generated
@@ -3111,16 +3111,6 @@ dependencies = [
|
||||
"uuid",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cloud_zeta2_prompt"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"cloud_llm_client",
|
||||
"indoc",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "cmake"
|
||||
version = "0.1.54"
|
||||
@@ -5119,7 +5109,6 @@ dependencies = [
|
||||
"clock",
|
||||
"cloud_api_types",
|
||||
"cloud_llm_client",
|
||||
"cloud_zeta2_prompt",
|
||||
"collections",
|
||||
"copilot",
|
||||
"credentials_provider",
|
||||
@@ -5150,8 +5139,6 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"settings",
|
||||
"smol",
|
||||
"strsim",
|
||||
"strum 0.27.2",
|
||||
"telemetry",
|
||||
"telemetry_events",
|
||||
@@ -5162,6 +5149,7 @@ dependencies = [
|
||||
"workspace",
|
||||
"worktree",
|
||||
"zed_actions",
|
||||
"zeta_prompt",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
@@ -5175,11 +5163,10 @@ dependencies = [
|
||||
"clap",
|
||||
"client",
|
||||
"cloud_llm_client",
|
||||
"cloud_zeta2_prompt",
|
||||
"collections",
|
||||
"debug_adapter_extension",
|
||||
"dirs 4.0.0",
|
||||
"edit_prediction",
|
||||
"edit_prediction_context",
|
||||
"extension",
|
||||
"fs",
|
||||
"futures 0.3.31",
|
||||
@@ -5209,9 +5196,10 @@ dependencies = [
|
||||
"sqlez",
|
||||
"sqlez_macros",
|
||||
"terminal_view",
|
||||
"toml 0.8.23",
|
||||
"util",
|
||||
"wasmtime",
|
||||
"watch",
|
||||
"zeta_prompt",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
@@ -5239,6 +5227,7 @@ dependencies = [
|
||||
"text",
|
||||
"tree-sitter",
|
||||
"util",
|
||||
"zeta_prompt",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
@@ -5260,7 +5249,6 @@ dependencies = [
|
||||
"buffer_diff",
|
||||
"client",
|
||||
"cloud_llm_client",
|
||||
"cloud_zeta2_prompt",
|
||||
"codestral",
|
||||
"command_palette_hooks",
|
||||
"copilot",
|
||||
@@ -5291,6 +5279,7 @@ dependencies = [
|
||||
"util",
|
||||
"workspace",
|
||||
"zed_actions",
|
||||
"zeta_prompt",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -20933,6 +20922,13 @@ dependencies = [
|
||||
"syn 2.0.106",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zeta_prompt"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "zip"
|
||||
version = "0.6.6"
|
||||
|
||||
@@ -32,7 +32,6 @@ members = [
|
||||
"crates/cloud_api_client",
|
||||
"crates/cloud_api_types",
|
||||
"crates/cloud_llm_client",
|
||||
"crates/cloud_zeta2_prompt",
|
||||
"crates/collab",
|
||||
"crates/collab_ui",
|
||||
"crates/collections",
|
||||
@@ -202,6 +201,7 @@ members = [
|
||||
"crates/zed_actions",
|
||||
"crates/zed_env_vars",
|
||||
"crates/edit_prediction_cli",
|
||||
"crates/zeta_prompt",
|
||||
"crates/zlog",
|
||||
"crates/zlog_settings",
|
||||
"crates/ztracing",
|
||||
@@ -266,7 +266,6 @@ clock = { path = "crates/clock" }
|
||||
cloud_api_client = { path = "crates/cloud_api_client" }
|
||||
cloud_api_types = { path = "crates/cloud_api_types" }
|
||||
cloud_llm_client = { path = "crates/cloud_llm_client" }
|
||||
cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
|
||||
collab_ui = { path = "crates/collab_ui" }
|
||||
collections = { path = "crates/collections", version = "0.1.0" }
|
||||
command_palette = { path = "crates/command_palette" }
|
||||
@@ -425,6 +424,7 @@ zed = { path = "crates/zed" }
|
||||
zed_actions = { path = "crates/zed_actions" }
|
||||
zed_env_vars = { path = "crates/zed_env_vars" }
|
||||
edit_prediction = { path = "crates/edit_prediction" }
|
||||
zeta_prompt = { path = "crates/zeta_prompt" }
|
||||
zlog = { path = "crates/zlog" }
|
||||
zlog_settings = { path = "crates/zlog_settings" }
|
||||
ztracing = { path = "crates/ztracing" }
|
||||
@@ -631,7 +631,7 @@ shellexpand = "2.1.0"
|
||||
shlex = "1.3.0"
|
||||
simplelog = "0.12.2"
|
||||
slotmap = "1.0.6"
|
||||
smallvec = { version = "1.6", features = ["union"] }
|
||||
smallvec = { version = "1.6", features = ["union", "const_new"] }
|
||||
smol = "2.0"
|
||||
sqlformat = "0.2"
|
||||
stacksafe = "0.1"
|
||||
@@ -657,6 +657,7 @@ time = { version = "0.3", features = [
|
||||
tiny_http = "0.8"
|
||||
tokio = { version = "1" }
|
||||
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
|
||||
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io", "tokio"] }
|
||||
toml = "0.8"
|
||||
toml_edit = { version = "0.22", default-features = false, features = ["display", "parse", "serde"] }
|
||||
tower-http = "0.4.4"
|
||||
|
||||
@@ -489,8 +489,8 @@
|
||||
"bindings": {
|
||||
"ctrl-[": "editor::Outdent",
|
||||
"ctrl-]": "editor::Indent",
|
||||
"ctrl-shift-alt-up": ["editor::AddSelectionAbove", { "skip_soft_wrap": true }], // Insert Cursor Above
|
||||
"ctrl-shift-alt-down": ["editor::AddSelectionBelow", { "skip_soft_wrap": true }], // Insert Cursor Below
|
||||
"ctrl-alt-up": ["editor::AddSelectionAbove", { "skip_soft_wrap": true }], // Insert Cursor Above
|
||||
"ctrl-alt-down": ["editor::AddSelectionBelow", { "skip_soft_wrap": true }], // Insert Cursor Below
|
||||
"ctrl-shift-k": "editor::DeleteLine",
|
||||
"alt-up": "editor::MoveLineUp",
|
||||
"alt-down": "editor::MoveLineDown",
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
|
||||
use agent_settings::AgentSettings;
|
||||
use anyhow::{Context as _, Result};
|
||||
|
||||
use client::telemetry::Telemetry;
|
||||
use cloud_llm_client::CompletionIntent;
|
||||
use collections::HashSet;
|
||||
@@ -11,12 +12,14 @@ use futures::{
|
||||
channel::mpsc,
|
||||
future::{LocalBoxFuture, Shared},
|
||||
join,
|
||||
stream::BoxStream,
|
||||
};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
|
||||
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
|
||||
use language_model::{
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest,
|
||||
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role,
|
||||
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
|
||||
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
|
||||
LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolUse, Role, TokenUsage,
|
||||
report_assistant_event,
|
||||
};
|
||||
use multi_buffer::MultiBufferRow;
|
||||
@@ -46,6 +49,7 @@ pub struct FailureMessageInput {
|
||||
/// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
|
||||
///
|
||||
/// The message may use markdown formatting if you wish.
|
||||
#[serde(default)]
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
@@ -56,9 +60,11 @@ pub struct RewriteSectionInput {
|
||||
///
|
||||
/// The description may use markdown formatting if you wish.
|
||||
/// This is optional - if the edit is simple or obvious, you should leave it empty.
|
||||
#[serde(default)]
|
||||
pub description: String,
|
||||
|
||||
/// The text to replace the section with.
|
||||
#[serde(default)]
|
||||
pub replacement_text: String,
|
||||
}
|
||||
|
||||
@@ -400,9 +406,15 @@ impl CodegenAlternative {
|
||||
|
||||
if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
|
||||
let request = self.build_request(&model, user_prompt, context_task, cx)?;
|
||||
let tool_use =
|
||||
cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await);
|
||||
self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
|
||||
let completion_events =
|
||||
cx.spawn(async move |_, cx| model.stream_completion(request.await, cx).await);
|
||||
self.generation = self.handle_completion(
|
||||
telemetry_id,
|
||||
provider_id.to_string(),
|
||||
api_key,
|
||||
completion_events,
|
||||
cx,
|
||||
);
|
||||
} else {
|
||||
let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
|
||||
if user_prompt.trim().to_lowercase() == "delete" {
|
||||
@@ -414,7 +426,8 @@ impl CodegenAlternative {
|
||||
})
|
||||
.boxed_local()
|
||||
};
|
||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||
self.generation =
|
||||
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -603,7 +616,7 @@ impl CodegenAlternative {
|
||||
model_api_key: Option<String>,
|
||||
stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
) -> Task<()> {
|
||||
let start_time = Instant::now();
|
||||
|
||||
// Make a new snapshot and re-resolve anchor in case the document was modified.
|
||||
@@ -659,7 +672,8 @@ impl CodegenAlternative {
|
||||
let completion = Arc::new(Mutex::new(String::new()));
|
||||
let completion_clone = completion.clone();
|
||||
|
||||
self.generation = cx.spawn(async move |codegen, cx| {
|
||||
cx.notify();
|
||||
cx.spawn(async move |codegen, cx| {
|
||||
let stream = stream.await;
|
||||
|
||||
let token_usage = stream
|
||||
@@ -685,6 +699,7 @@ impl CodegenAlternative {
|
||||
stream?.stream.map_err(|error| error.into()),
|
||||
);
|
||||
futures::pin_mut!(chunks);
|
||||
|
||||
let mut diff = StreamingDiff::new(selected_text.to_string());
|
||||
let mut line_diff = LineDiff::default();
|
||||
|
||||
@@ -694,6 +709,7 @@ impl CodegenAlternative {
|
||||
let mut first_line = true;
|
||||
|
||||
while let Some(chunk) = chunks.next().await {
|
||||
dbg!(&chunk);
|
||||
if response_latency.is_none() {
|
||||
response_latency = Some(request_start.elapsed());
|
||||
}
|
||||
@@ -876,8 +892,7 @@ impl CodegenAlternative {
|
||||
cx.notify();
|
||||
})
|
||||
.ok();
|
||||
});
|
||||
cx.notify();
|
||||
})
|
||||
}
|
||||
|
||||
pub fn current_completion(&self) -> Option<String> {
|
||||
@@ -1060,21 +1075,29 @@ impl CodegenAlternative {
|
||||
})
|
||||
}
|
||||
|
||||
fn handle_tool_use(
|
||||
fn handle_completion(
|
||||
&mut self,
|
||||
_telemetry_id: String,
|
||||
_provider_id: String,
|
||||
_api_key: Option<String>,
|
||||
tool_use: impl 'static
|
||||
+ Future<
|
||||
Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
|
||||
telemetry_id: String,
|
||||
provider_id: String,
|
||||
api_key: Option<String>,
|
||||
completion_stream: Task<
|
||||
Result<
|
||||
BoxStream<
|
||||
'static,
|
||||
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
|
||||
>,
|
||||
LanguageModelCompletionError,
|
||||
>,
|
||||
>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
) -> Task<()> {
|
||||
self.diff = Diff::default();
|
||||
self.status = CodegenStatus::Pending;
|
||||
|
||||
self.generation = cx.spawn(async move |codegen, cx| {
|
||||
cx.notify();
|
||||
// Leaving this in generation so that STOP equivalent events are respected even
|
||||
// while we're still pre-processing the completion event
|
||||
cx.spawn(async move |codegen, cx| {
|
||||
let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
|
||||
let _ = codegen.update(cx, |this, cx| {
|
||||
this.status = status;
|
||||
@@ -1083,76 +1106,177 @@ impl CodegenAlternative {
|
||||
});
|
||||
};
|
||||
|
||||
let tool_use = tool_use.await;
|
||||
let mut completion_events = match completion_stream.await {
|
||||
Ok(events) => events,
|
||||
Err(err) => {
|
||||
finish_with_status(CodegenStatus::Error(err.into()), cx);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match tool_use {
|
||||
Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
|
||||
// Parse the input JSON into RewriteSectionInput
|
||||
match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
|
||||
Ok(input) => {
|
||||
// Store the description if non-empty
|
||||
let description = if !input.description.trim().is_empty() {
|
||||
Some(input.description.clone())
|
||||
} else {
|
||||
None
|
||||
let chars_read_so_far = Arc::new(Mutex::new(0usize));
|
||||
let tool_to_text_and_message =
|
||||
move |tool_use: LanguageModelToolUse| -> (Option<String>, Option<String>) {
|
||||
let mut chars_read_so_far = chars_read_so_far.lock();
|
||||
match tool_use.name.as_ref() {
|
||||
"rewrite_section" => {
|
||||
let Ok(mut input) =
|
||||
serde_json::from_value::<RewriteSectionInput>(tool_use.input)
|
||||
else {
|
||||
return (None, None);
|
||||
};
|
||||
let value = input.replacement_text[*chars_read_so_far..].to_string();
|
||||
*chars_read_so_far = input.replacement_text.len();
|
||||
(Some(value), Some(std::mem::take(&mut input.description)))
|
||||
}
|
||||
"failure_message" => {
|
||||
let Ok(mut input) =
|
||||
serde_json::from_value::<FailureMessageInput>(tool_use.input)
|
||||
else {
|
||||
return (None, None);
|
||||
};
|
||||
(None, Some(std::mem::take(&mut input.message)))
|
||||
}
|
||||
_ => (None, None),
|
||||
}
|
||||
};
|
||||
|
||||
// Apply the replacement text to the buffer and compute diff
|
||||
let batch_diff_task = codegen
|
||||
.update(cx, |this, cx| {
|
||||
this.model_explanation = description.map(Into::into);
|
||||
let range = this.range.clone();
|
||||
this.apply_edits(
|
||||
std::iter::once((range, input.replacement_text)),
|
||||
cx,
|
||||
);
|
||||
this.reapply_batch_diff(cx)
|
||||
})
|
||||
.ok();
|
||||
let mut message_id = None;
|
||||
let mut first_text = None;
|
||||
let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
|
||||
let total_text = Arc::new(Mutex::new(String::new()));
|
||||
|
||||
// Wait for the diff computation to complete
|
||||
if let Some(diff_task) = batch_diff_task {
|
||||
diff_task.await;
|
||||
loop {
|
||||
if let Some(first_event) = completion_events.next().await {
|
||||
dbg!(&first_event);
|
||||
match first_event {
|
||||
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
|
||||
message_id = Some(id);
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
|
||||
if matches!(
|
||||
tool_use.name.as_ref(),
|
||||
"rewrite_section" | "failure_message"
|
||||
) =>
|
||||
{
|
||||
let is_complete = tool_use.is_input_complete;
|
||||
let (text, message) = tool_to_text_and_message(tool_use);
|
||||
// Only update the model explanation if the tool use is complete.
|
||||
// Otherwise the UI element bounces around as it's updated.
|
||||
if is_complete {
|
||||
let _ = codegen.update(cx, |this, _cx| {
|
||||
this.model_explanation = message.map(Into::into);
|
||||
});
|
||||
}
|
||||
|
||||
finish_with_status(CodegenStatus::Done, cx);
|
||||
return;
|
||||
first_text = text;
|
||||
if first_text.is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
*last_token_usage.lock() = token_usage;
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
let mut lock = total_text.lock();
|
||||
lock.push_str(&text);
|
||||
}
|
||||
Ok(e) => {
|
||||
log::warn!("Unexpected event: {:?}", e);
|
||||
break;
|
||||
}
|
||||
Err(e) => {
|
||||
finish_with_status(CodegenStatus::Error(e.into()), cx);
|
||||
return;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
|
||||
// Handle failure message tool use
|
||||
match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
|
||||
Ok(input) => {
|
||||
let _ = codegen.update(cx, |this, _cx| {
|
||||
// Store the failure message as the tool description
|
||||
this.model_explanation = Some(input.message.into());
|
||||
});
|
||||
finish_with_status(CodegenStatus::Done, cx);
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
finish_with_status(CodegenStatus::Error(e.into()), cx);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(_tool_use) => {
|
||||
// Unexpected tool.
|
||||
finish_with_status(CodegenStatus::Done, cx);
|
||||
return;
|
||||
}
|
||||
Err(e) => {
|
||||
finish_with_status(CodegenStatus::Error(e.into()), cx);
|
||||
return;
|
||||
}
|
||||
}
|
||||
});
|
||||
cx.notify();
|
||||
|
||||
let Some(first_text) = first_text else {
|
||||
finish_with_status(CodegenStatus::Done, cx);
|
||||
return;
|
||||
};
|
||||
|
||||
let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded();
|
||||
|
||||
cx.spawn({
|
||||
let codegen = codegen.clone();
|
||||
async move |cx| {
|
||||
while let Some(message) = message_rx.next().await {
|
||||
let _ = codegen.update(cx, |this, _cx| {
|
||||
this.model_explanation = message;
|
||||
});
|
||||
}
|
||||
}
|
||||
})
|
||||
.detach();
|
||||
|
||||
let move_last_token_usage = last_token_usage.clone();
|
||||
|
||||
let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
|
||||
completion_events.filter_map(move |e| {
|
||||
let tool_to_text_and_message = tool_to_text_and_message.clone();
|
||||
let last_token_usage = move_last_token_usage.clone();
|
||||
let total_text = total_text.clone();
|
||||
let mut message_tx = message_tx.clone();
|
||||
async move {
|
||||
match e {
|
||||
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
|
||||
if matches!(
|
||||
tool_use.name.as_ref(),
|
||||
"rewrite_section" | "failure_message"
|
||||
) =>
|
||||
{
|
||||
let is_complete = tool_use.is_input_complete;
|
||||
let (text, message) = tool_to_text_and_message(tool_use);
|
||||
if is_complete {
|
||||
// Again only send the message when complete to not get a bouncing UI element.
|
||||
let _ = message_tx.send(message.map(Into::into)).await;
|
||||
}
|
||||
text.map(Ok)
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
|
||||
*last_token_usage.lock() = token_usage;
|
||||
None
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Text(text)) => {
|
||||
let mut lock = total_text.lock();
|
||||
lock.push_str(&text);
|
||||
None
|
||||
}
|
||||
Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
|
||||
e => {
|
||||
log::error!("UNEXPECTED EVENT {:?}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}),
|
||||
));
|
||||
|
||||
let language_model_text_stream = LanguageModelTextStream {
|
||||
message_id: message_id,
|
||||
stream: text_stream,
|
||||
last_token_usage,
|
||||
};
|
||||
|
||||
let Some(task) = codegen
|
||||
.update(cx, move |codegen, cx| {
|
||||
codegen.handle_stream(
|
||||
telemetry_id,
|
||||
provider_id,
|
||||
api_key,
|
||||
async { Ok(language_model_text_stream) },
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.ok()
|
||||
else {
|
||||
return;
|
||||
};
|
||||
|
||||
task.await;
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1679,7 +1803,7 @@ mod tests {
|
||||
) -> mpsc::UnboundedSender<String> {
|
||||
let (chunks_tx, chunks_rx) = mpsc::unbounded();
|
||||
codegen.update(cx, |codegen, cx| {
|
||||
codegen.handle_stream(
|
||||
codegen.generation = codegen.handle_stream(
|
||||
String::new(),
|
||||
String::new(),
|
||||
None,
|
||||
|
||||
@@ -1455,60 +1455,8 @@ impl InlineAssistant {
|
||||
let old_snapshot = codegen.snapshot(cx);
|
||||
let old_buffer = codegen.old_buffer(cx);
|
||||
let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();
|
||||
// let model_explanation = codegen.model_explanation(cx);
|
||||
|
||||
editor.update(cx, |editor, cx| {
|
||||
// Update tool description block
|
||||
// if let Some(description) = model_explanation {
|
||||
// if let Some(block_id) = decorations.model_explanation {
|
||||
// editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
|
||||
// let new_block_id = editor.insert_blocks(
|
||||
// [BlockProperties {
|
||||
// style: BlockStyle::Flex,
|
||||
// placement: BlockPlacement::Below(assist.range.end),
|
||||
// height: Some(1),
|
||||
// render: Arc::new({
|
||||
// let description = description.clone();
|
||||
// move |cx| {
|
||||
// div()
|
||||
// .w_full()
|
||||
// .py_1()
|
||||
// .px_2()
|
||||
// .bg(cx.theme().colors().editor_background)
|
||||
// .border_y_1()
|
||||
// .border_color(cx.theme().status().info_border)
|
||||
// .child(
|
||||
// Label::new(description.clone())
|
||||
// .color(Color::Muted)
|
||||
// .size(LabelSize::Small),
|
||||
// )
|
||||
// .into_any_element()
|
||||
// }
|
||||
// }),
|
||||
// priority: 0,
|
||||
// }],
|
||||
// None,
|
||||
// cx,
|
||||
// );
|
||||
// decorations.model_explanation = new_block_id.into_iter().next();
|
||||
// }
|
||||
// } else if let Some(block_id) = decorations.model_explanation {
|
||||
// // Hide the block if there's no description
|
||||
// editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
|
||||
// let new_block_id = editor.insert_blocks(
|
||||
// [BlockProperties {
|
||||
// style: BlockStyle::Flex,
|
||||
// placement: BlockPlacement::Below(assist.range.end),
|
||||
// height: Some(0),
|
||||
// render: Arc::new(|_cx| div().into_any_element()),
|
||||
// priority: 0,
|
||||
// }],
|
||||
// None,
|
||||
// cx,
|
||||
// );
|
||||
// decorations.model_explanation = new_block_id.into_iter().next();
|
||||
// }
|
||||
|
||||
let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
|
||||
editor.remove_blocks(old_blocks, None, cx);
|
||||
|
||||
|
||||
@@ -429,6 +429,19 @@ impl Model {
|
||||
let mut headers = vec![];
|
||||
|
||||
match self {
|
||||
Self::ClaudeOpus4
|
||||
| Self::ClaudeOpus4_1
|
||||
| Self::ClaudeOpus4_5
|
||||
| Self::ClaudeSonnet4
|
||||
| Self::ClaudeSonnet4_5
|
||||
| Self::ClaudeOpus4Thinking
|
||||
| Self::ClaudeOpus4_1Thinking
|
||||
| Self::ClaudeOpus4_5Thinking
|
||||
| Self::ClaudeSonnet4Thinking
|
||||
| Self::ClaudeSonnet4_5Thinking => {
|
||||
// Fine-grained tool streaming for newer models
|
||||
headers.push("fine-grained-tool-streaming-2025-05-14".to_string());
|
||||
}
|
||||
Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => {
|
||||
// Try beta token-efficient tool use (supported in Claude 3.7 Sonnet only)
|
||||
// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use
|
||||
|
||||
@@ -53,7 +53,7 @@ text.workspace = true
|
||||
thiserror.workspace = true
|
||||
time.workspace = true
|
||||
tiny_http.workspace = true
|
||||
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
|
||||
tokio-socks.workspace = true
|
||||
tokio.workspace = true
|
||||
url.workspace = true
|
||||
util.workspace = true
|
||||
|
||||
@@ -1,18 +0,0 @@
|
||||
[package]
|
||||
name = "cloud_zeta2_prompt"
|
||||
version = "0.1.0"
|
||||
publish.workspace = true
|
||||
edition.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/cloud_zeta2_prompt.rs"
|
||||
|
||||
[dependencies]
|
||||
anyhow.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
indoc.workspace = true
|
||||
serde.workspace = true
|
||||
@@ -1,485 +0,0 @@
|
||||
use anyhow::Result;
|
||||
use cloud_llm_client::predict_edits_v3::{
|
||||
self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
|
||||
};
|
||||
use indoc::indoc;
|
||||
use std::cmp;
|
||||
use std::fmt::Write;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
|
||||
|
||||
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
|
||||
/// NOTE: Differs from zed version of constant - includes a newline
|
||||
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
|
||||
/// NOTE: Differs from zed version of constant - includes a newline
|
||||
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
|
||||
|
||||
const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
|
||||
You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
|
||||
|
||||
## Edit History
|
||||
|
||||
"#};
|
||||
|
||||
const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
|
||||
---
|
||||
|
||||
Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
|
||||
Do not include the cursor marker in your output.
|
||||
If you're editing multiple files, be sure to reflect filename in the hunk's header.
|
||||
"};
|
||||
|
||||
const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#"
|
||||
# Instructions
|
||||
|
||||
You are an edit prediction agent in a code editor.
|
||||
|
||||
Analyze the history of edits made by the user in order to infer what they are currently trying to accomplish.
|
||||
Then complete the remainder of the current change if it is incomplete, or predict the next edit the user intends to make.
|
||||
Always continue along the user's current trajectory, rather than changing course.
|
||||
|
||||
## Output Format
|
||||
|
||||
You should briefly explain your understanding of the user's overall goal in one sentence, then explain what the next change
|
||||
along the users current trajectory will be in another, and finally specify the next edit using the following XML-like format:
|
||||
|
||||
<edits path="my-project/src/myapp/cli.py">
|
||||
<old_text>
|
||||
OLD TEXT 1 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 1 HERE
|
||||
</new_text>
|
||||
|
||||
<old_text>
|
||||
OLD TEXT 1 HERE
|
||||
</old_text>
|
||||
<new_text>
|
||||
NEW TEXT 1 HERE
|
||||
</new_text>
|
||||
</edits>
|
||||
|
||||
- Specify the file to edit using the `path` attribute.
|
||||
- Use `<old_text>` and `<new_text>` tags to replace content
|
||||
- `<old_text>` must exactly match existing file content, including indentation
|
||||
- `<old_text>` cannot be empty
|
||||
- Do not escape quotes, newlines, or other characters within tags
|
||||
- Always close all tags properly
|
||||
- Don't include the <|user_cursor|> marker in your output.
|
||||
|
||||
## Edit History
|
||||
|
||||
"#};
|
||||
|
||||
const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
|
||||
---
|
||||
|
||||
Remember that the edits in the edit history have already been applied.
|
||||
"#};
|
||||
|
||||
pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
|
||||
let prompt_data = PromptData {
|
||||
events: request.events.clone(),
|
||||
cursor_point: request.cursor_point,
|
||||
cursor_path: request.excerpt_path.clone(),
|
||||
included_files: request.related_files.clone(),
|
||||
};
|
||||
match request.prompt_format {
|
||||
PromptFormat::MinimalQwen => {
|
||||
return Ok(MinimalQwenPrompt.render(&prompt_data));
|
||||
}
|
||||
PromptFormat::SeedCoder1120 => {
|
||||
return Ok(SeedCoder1120Prompt.render(&prompt_data));
|
||||
}
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let insertions = match request.prompt_format {
|
||||
PromptFormat::Minimal | PromptFormat::OldTextNewText => {
|
||||
vec![(request.cursor_point, CURSOR_MARKER)]
|
||||
}
|
||||
PromptFormat::OnlySnippets => vec![],
|
||||
PromptFormat::MinimalQwen => unreachable!(),
|
||||
PromptFormat::SeedCoder1120 => unreachable!(),
|
||||
};
|
||||
|
||||
let mut prompt = match request.prompt_format {
|
||||
PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::OnlySnippets => String::new(),
|
||||
PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
|
||||
PromptFormat::MinimalQwen => unreachable!(),
|
||||
PromptFormat::SeedCoder1120 => unreachable!(),
|
||||
};
|
||||
|
||||
if request.events.is_empty() {
|
||||
prompt.push_str("(No edit history)\n\n");
|
||||
} else {
|
||||
let edit_preamble = if request.prompt_format == PromptFormat::Minimal {
|
||||
"The following are the latest edits made by the user, from earlier to later.\n\n"
|
||||
} else {
|
||||
"Here are the latest edits made by the user, from earlier to later.\n\n"
|
||||
};
|
||||
prompt.push_str(edit_preamble);
|
||||
push_events(&mut prompt, &request.events);
|
||||
}
|
||||
|
||||
let excerpts_preamble = match request.prompt_format {
|
||||
PromptFormat::Minimal => indoc! {"
|
||||
## Part of the file under the cursor
|
||||
|
||||
(The cursor marker <|user_cursor|> indicates the current user cursor position.
|
||||
The file is in current state, edits from edit history has been applied.
|
||||
We only show part of the file around the cursor.
|
||||
You can only edit exactly this part of the file.
|
||||
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
|
||||
"},
|
||||
PromptFormat::OldTextNewText => indoc! {"
|
||||
## Code Excerpts
|
||||
|
||||
Here is some excerpts of code that you should take into account to predict the next edit.
|
||||
|
||||
The cursor position is marked by `<|user_cursor|>` as it stands after the last edit in the history.
|
||||
|
||||
In addition other excerpts are included to better understand what the edit will be, including the declaration
|
||||
or references of symbols around the cursor, or other similar code snippets that may need to be updated
|
||||
following patterns that appear in the edit history.
|
||||
|
||||
Consider each of them carefully in relation to the edit history, and that the user may not have navigated
|
||||
to the next place they want to edit yet.
|
||||
|
||||
Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
|
||||
"},
|
||||
PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
|
||||
indoc! {"
|
||||
## Code Excerpts
|
||||
|
||||
The cursor marker <|user_cursor|> indicates the current user cursor position.
|
||||
The file is in current state, edits from edit history have been applied.
|
||||
"}
|
||||
}
|
||||
};
|
||||
|
||||
prompt.push_str(excerpts_preamble);
|
||||
prompt.push('\n');
|
||||
|
||||
let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
|
||||
for related_file in &request.related_files {
|
||||
if request.prompt_format == PromptFormat::Minimal {
|
||||
write_codeblock_with_filename(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
} else {
|
||||
write_codeblock(
|
||||
&related_file.path,
|
||||
&related_file.excerpts,
|
||||
if related_file.path == request.excerpt_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut prompt,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
match request.prompt_format {
|
||||
PromptFormat::OldTextNewText => {
|
||||
prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
|
||||
}
|
||||
PromptFormat::Minimal => {
|
||||
prompt.push_str(MINIMAL_PROMPT_REMINDER);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(prompt)
|
||||
}
|
||||
|
||||
pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
|
||||
match prompt_format {
|
||||
PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(),
|
||||
_ => GenerationParams::default(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_codeblock<'a>(
|
||||
path: &Path,
|
||||
excerpts: impl IntoIterator<Item = &'a Excerpt>,
|
||||
sorted_insertions: &[(Point, &str)],
|
||||
file_line_count: Line,
|
||||
include_line_numbers: bool,
|
||||
output: &'a mut String,
|
||||
) {
|
||||
writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
|
||||
|
||||
write_excerpts(
|
||||
excerpts,
|
||||
sorted_insertions,
|
||||
file_line_count,
|
||||
include_line_numbers,
|
||||
output,
|
||||
);
|
||||
write!(output, "`````\n\n").unwrap();
|
||||
}
|
||||
|
||||
fn write_codeblock_with_filename<'a>(
|
||||
path: &Path,
|
||||
excerpts: impl IntoIterator<Item = &'a Excerpt>,
|
||||
sorted_insertions: &[(Point, &str)],
|
||||
file_line_count: Line,
|
||||
include_line_numbers: bool,
|
||||
output: &'a mut String,
|
||||
) {
|
||||
writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap();
|
||||
|
||||
write_excerpts(
|
||||
excerpts,
|
||||
sorted_insertions,
|
||||
file_line_count,
|
||||
include_line_numbers,
|
||||
output,
|
||||
);
|
||||
write!(output, "`````\n\n").unwrap();
|
||||
}
|
||||
|
||||
pub fn write_excerpts<'a>(
|
||||
excerpts: impl IntoIterator<Item = &'a Excerpt>,
|
||||
sorted_insertions: &[(Point, &str)],
|
||||
file_line_count: Line,
|
||||
include_line_numbers: bool,
|
||||
output: &mut String,
|
||||
) {
|
||||
let mut current_row = Line(0);
|
||||
let mut sorted_insertions = sorted_insertions.iter().peekable();
|
||||
|
||||
for excerpt in excerpts {
|
||||
if excerpt.start_line > current_row {
|
||||
writeln!(output, "…").unwrap();
|
||||
}
|
||||
if excerpt.text.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
current_row = excerpt.start_line;
|
||||
|
||||
for mut line in excerpt.text.lines() {
|
||||
if include_line_numbers {
|
||||
write!(output, "{}|", current_row.0 + 1).unwrap();
|
||||
}
|
||||
|
||||
while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() {
|
||||
match current_row.cmp(&insertion_location.line) {
|
||||
cmp::Ordering::Equal => {
|
||||
let (prefix, suffix) = line.split_at(insertion_location.column as usize);
|
||||
output.push_str(prefix);
|
||||
output.push_str(insertion_marker);
|
||||
line = suffix;
|
||||
sorted_insertions.next();
|
||||
}
|
||||
cmp::Ordering::Less => break,
|
||||
cmp::Ordering::Greater => {
|
||||
sorted_insertions.next();
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
output.push_str(line);
|
||||
output.push('\n');
|
||||
current_row.0 += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if current_row < file_line_count {
|
||||
writeln!(output, "…").unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>]) {
|
||||
if events.is_empty() {
|
||||
return;
|
||||
};
|
||||
|
||||
writeln!(output, "`````diff").unwrap();
|
||||
for event in events {
|
||||
writeln!(output, "{}", event).unwrap();
|
||||
}
|
||||
writeln!(output, "`````\n").unwrap();
|
||||
}
|
||||
|
||||
struct PromptData {
|
||||
events: Vec<Arc<Event>>,
|
||||
cursor_point: Point,
|
||||
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
|
||||
included_files: Vec<RelatedFile>,
|
||||
}
|
||||
|
||||
/// Optional sampling parameters attached to a generation request.
///
/// Every field defaults to `None`, meaning the parameter is left unspecified.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct GenerationParams {
    /// Sampling temperature, if overridden.
    pub temperature: Option<f32>,
    /// Nucleus-sampling `top_p` value, if overridden.
    pub top_p: Option<f32>,
    /// Stop sequences, if any.
    pub stop: Option<Vec<String>>,
}
|
||||
|
||||
trait PromptFormatter {
|
||||
fn render(&self, data: &PromptData) -> String;
|
||||
|
||||
fn generation_params() -> GenerationParams {
|
||||
return GenerationParams::default();
|
||||
}
|
||||
}
|
||||
|
||||
struct MinimalQwenPrompt;
|
||||
|
||||
impl PromptFormatter for MinimalQwenPrompt {
|
||||
fn render(&self, data: &PromptData) -> String {
|
||||
let edit_history = self.fmt_edit_history(data);
|
||||
let context = self.fmt_context(data);
|
||||
|
||||
format!(
|
||||
"{instructions}\n\n{edit_history}\n\n{context}",
|
||||
instructions = MinimalQwenPrompt::INSTRUCTIONS,
|
||||
edit_history = edit_history,
|
||||
context = context
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl MinimalQwenPrompt {
|
||||
const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n";
|
||||
|
||||
fn fmt_edit_history(&self, data: &PromptData) -> String {
|
||||
if data.events.is_empty() {
|
||||
"(No edit history)\n\n".to_string()
|
||||
} else {
|
||||
let mut events_str = String::new();
|
||||
push_events(&mut events_str, &data.events);
|
||||
format!(
|
||||
"The following are the latest edits made by the user, from earlier to later.\n\n{}",
|
||||
events_str
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
fn fmt_context(&self, data: &PromptData) -> String {
|
||||
let mut context = String::new();
|
||||
let include_line_numbers = true;
|
||||
|
||||
for related_file in &data.included_files {
|
||||
writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap();
|
||||
|
||||
if related_file.path == data.cursor_path {
|
||||
write!(context, "<|fim_prefix|>").unwrap();
|
||||
write_excerpts(
|
||||
&related_file.excerpts,
|
||||
&[(data.cursor_point, "<|fim_suffix|>")],
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut context,
|
||||
);
|
||||
writeln!(context, "<|fim_middle|>").unwrap();
|
||||
} else {
|
||||
write_excerpts(
|
||||
&related_file.excerpts,
|
||||
&[],
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut context,
|
||||
);
|
||||
}
|
||||
}
|
||||
context
|
||||
}
|
||||
}
|
||||
|
||||
struct SeedCoder1120Prompt;
|
||||
|
||||
impl PromptFormatter for SeedCoder1120Prompt {
|
||||
fn render(&self, data: &PromptData) -> String {
|
||||
let edit_history = self.fmt_edit_history(data);
|
||||
let context = self.fmt_context(data);
|
||||
|
||||
format!(
|
||||
"# Edit History:\n{edit_history}\n\n{context}",
|
||||
edit_history = edit_history,
|
||||
context = context
|
||||
)
|
||||
}
|
||||
|
||||
fn generation_params() -> GenerationParams {
|
||||
GenerationParams {
|
||||
temperature: Some(0.2),
|
||||
top_p: Some(0.9),
|
||||
stop: Some(vec!["<[end_of_sentence]>".into()]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl SeedCoder1120Prompt {
|
||||
fn fmt_edit_history(&self, data: &PromptData) -> String {
|
||||
if data.events.is_empty() {
|
||||
"(No edit history)\n\n".to_string()
|
||||
} else {
|
||||
let mut events_str = String::new();
|
||||
push_events(&mut events_str, &data.events);
|
||||
events_str
|
||||
}
|
||||
}
|
||||
|
||||
fn fmt_context(&self, data: &PromptData) -> String {
|
||||
let mut context = String::new();
|
||||
let include_line_numbers = true;
|
||||
|
||||
for related_file in &data.included_files {
|
||||
writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap();
|
||||
|
||||
if related_file.path == data.cursor_path {
|
||||
let fim_prompt = self.fmt_fim(&related_file, data.cursor_point);
|
||||
context.push_str(&fim_prompt);
|
||||
} else {
|
||||
write_excerpts(
|
||||
&related_file.excerpts,
|
||||
&[],
|
||||
related_file.max_row,
|
||||
include_line_numbers,
|
||||
&mut context,
|
||||
);
|
||||
}
|
||||
}
|
||||
context
|
||||
}
|
||||
|
||||
fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
|
||||
let mut buf = String::new();
|
||||
const FIM_SUFFIX: &str = "<[fim-suffix]>";
|
||||
const FIM_PREFIX: &str = "<[fim-prefix]>";
|
||||
const FIM_MIDDLE: &str = "<[fim-middle]>";
|
||||
write!(buf, "{}", FIM_PREFIX).unwrap();
|
||||
write_excerpts(
|
||||
&file.excerpts,
|
||||
&[(cursor_point, FIM_SUFFIX)],
|
||||
file.max_row,
|
||||
true,
|
||||
&mut buf,
|
||||
);
|
||||
|
||||
// Swap prefix and suffix parts
|
||||
let index = buf.find(FIM_SUFFIX).unwrap();
|
||||
let prefix = &buf[..index];
|
||||
let suffix = &buf[index..];
|
||||
|
||||
format!("{}{}{}", suffix, prefix, FIM_MIDDLE)
|
||||
}
|
||||
}
|
||||
@@ -516,7 +516,8 @@ impl Copilot {
|
||||
None,
|
||||
Default::default(),
|
||||
cx,
|
||||
)?;
|
||||
)
|
||||
.await?;
|
||||
|
||||
server
|
||||
.on_notification::<StatusNotification, _>(|_, _| { /* Silence the notification */ })
|
||||
|
||||
@@ -1045,54 +1045,47 @@ async fn heuristic_syntactic_expand(
|
||||
let node_range = node_start..node_end;
|
||||
let row_count = node_end.row - node_start.row + 1;
|
||||
let mut ancestor_range = None;
|
||||
let reached_outline_node = cx.background_executor().scoped({
|
||||
let node_range = node_range.clone();
|
||||
let outline_range = outline_range.clone();
|
||||
let ancestor_range = &mut ancestor_range;
|
||||
|scope| {
|
||||
scope.spawn(async move {
|
||||
// Stop if we've exceeded the row count or reached an outline node. Then, find the interval
|
||||
// of node children which contains the query range. For example, this allows just returning
|
||||
// the header of a declaration rather than the entire declaration.
|
||||
if row_count > max_row_count || outline_range == Some(node_range.clone()) {
|
||||
let mut cursor = node.walk();
|
||||
let mut included_child_start = None;
|
||||
let mut included_child_end = None;
|
||||
let mut previous_end = node_start;
|
||||
if cursor.goto_first_child() {
|
||||
loop {
|
||||
let child_node = cursor.node();
|
||||
let child_range =
|
||||
previous_end..Point::from_ts_point(child_node.end_position());
|
||||
if included_child_start.is_none()
|
||||
&& child_range.contains(&input_range.start)
|
||||
{
|
||||
included_child_start = Some(child_range.start);
|
||||
}
|
||||
if child_range.contains(&input_range.end) {
|
||||
included_child_end = Some(child_range.end);
|
||||
}
|
||||
previous_end = child_range.end;
|
||||
if !cursor.goto_next_sibling() {
|
||||
break;
|
||||
}
|
||||
cx.background_executor()
|
||||
.await_on_background(async {
|
||||
// Stop if we've exceeded the row count or reached an outline node. Then, find the interval
|
||||
// of node children which contains the query range. For example, this allows just returning
|
||||
// the header of a declaration rather than the entire declaration.
|
||||
if row_count > max_row_count || outline_range == Some(node_range.clone()) {
|
||||
let mut cursor = node.walk();
|
||||
let mut included_child_start = None;
|
||||
let mut included_child_end = None;
|
||||
let mut previous_end = node_start;
|
||||
if cursor.goto_first_child() {
|
||||
loop {
|
||||
let child_node = cursor.node();
|
||||
let child_range =
|
||||
previous_end..Point::from_ts_point(child_node.end_position());
|
||||
if included_child_start.is_none()
|
||||
&& child_range.contains(&input_range.start)
|
||||
{
|
||||
included_child_start = Some(child_range.start);
|
||||
}
|
||||
if child_range.contains(&input_range.end) {
|
||||
included_child_end = Some(child_range.end);
|
||||
}
|
||||
previous_end = child_range.end;
|
||||
if !cursor.goto_next_sibling() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let end = included_child_end.unwrap_or(node_range.end);
|
||||
if let Some(start) = included_child_start {
|
||||
let row_count = end.row - start.row;
|
||||
if row_count < max_row_count {
|
||||
*ancestor_range =
|
||||
Some(Some(RangeInclusive::new(start.row, end.row)));
|
||||
return;
|
||||
}
|
||||
}
|
||||
*ancestor_range = Some(None);
|
||||
}
|
||||
})
|
||||
}
|
||||
});
|
||||
reached_outline_node.await;
|
||||
let end = included_child_end.unwrap_or(node_range.end);
|
||||
if let Some(start) = included_child_start {
|
||||
let row_count = end.row - start.row;
|
||||
if row_count < max_row_count {
|
||||
ancestor_range = Some(Some(RangeInclusive::new(start.row, end.row)));
|
||||
return;
|
||||
}
|
||||
}
|
||||
ancestor_range = Some(None);
|
||||
}
|
||||
})
|
||||
.await;
|
||||
if let Some(node) = ancestor_range {
|
||||
return node;
|
||||
}
|
||||
|
||||
@@ -21,7 +21,6 @@ arrayvec.workspace = true
|
||||
brotli.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
collections.workspace = true
|
||||
copilot.workspace = true
|
||||
credentials_provider.workspace = true
|
||||
@@ -50,8 +49,6 @@ semver.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
settings.workspace = true
|
||||
smol.workspace = true
|
||||
strsim.workspace = true
|
||||
strum.workspace = true
|
||||
telemetry.workspace = true
|
||||
telemetry_events.workspace = true
|
||||
@@ -62,6 +59,7 @@ uuid.workspace = true
|
||||
workspace.workspace = true
|
||||
worktree.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
clock = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
use anyhow::Result;
|
||||
use arrayvec::ArrayVec;
|
||||
use client::{Client, EditPredictionUsage, UserStore};
|
||||
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
|
||||
use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
|
||||
use cloud_llm_client::{
|
||||
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
|
||||
EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
|
||||
MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
|
||||
ZED_VERSION_HEADER_NAME,
|
||||
};
|
||||
use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
|
||||
use collections::{HashMap, HashSet};
|
||||
use db::kvp::{Dismissable, KEY_VALUE_STORE};
|
||||
use edit_prediction_context::EditPredictionExcerptOptions;
|
||||
@@ -16,10 +15,7 @@ use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, Rel
|
||||
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
|
||||
use futures::{
|
||||
AsyncReadExt as _, FutureExt as _, StreamExt as _,
|
||||
channel::{
|
||||
mpsc::{self, UnboundedReceiver},
|
||||
oneshot,
|
||||
},
|
||||
channel::mpsc::{self, UnboundedReceiver},
|
||||
select_biased,
|
||||
};
|
||||
use gpui::BackgroundExecutor;
|
||||
@@ -58,8 +54,10 @@ mod onboarding_modal;
|
||||
pub mod open_ai_response;
|
||||
mod prediction;
|
||||
pub mod sweep_ai;
|
||||
|
||||
#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
|
||||
pub mod udiff;
|
||||
mod xml_edits;
|
||||
|
||||
mod zed_edit_prediction_delegate;
|
||||
pub mod zeta1;
|
||||
pub mod zeta2;
|
||||
@@ -72,7 +70,6 @@ use crate::mercury::Mercury;
|
||||
use crate::onboarding_modal::ZedPredictModal;
|
||||
pub use crate::prediction::EditPrediction;
|
||||
pub use crate::prediction::EditPredictionId;
|
||||
pub use crate::prediction::EditPredictionInputs;
|
||||
use crate::prediction::EditPredictionResult;
|
||||
pub use crate::sweep_ai::SweepAi;
|
||||
pub use telemetry_events::EditPredictionRating;
|
||||
@@ -112,7 +109,6 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
|
||||
min_bytes: 128,
|
||||
target_before_cursor_over_total_bytes: 0.5,
|
||||
},
|
||||
max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
|
||||
prompt_format: PromptFormat::DEFAULT,
|
||||
};
|
||||
|
||||
@@ -162,7 +158,6 @@ pub struct EditPredictionStore {
|
||||
use_context: bool,
|
||||
options: ZetaOptions,
|
||||
update_required: bool,
|
||||
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
|
||||
#[cfg(feature = "eval-support")]
|
||||
eval_cache: Option<Arc<dyn EvalCache>>,
|
||||
edit_prediction_model: EditPredictionModel,
|
||||
@@ -183,10 +178,22 @@ pub enum EditPredictionModel {
|
||||
Mercury,
|
||||
}
|
||||
|
||||
pub struct EditPredictionModelInput {
|
||||
project: Entity<Project>,
|
||||
buffer: Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: Anchor,
|
||||
events: Vec<Arc<zeta_prompt::Event>>,
|
||||
related_files: Arc<[RelatedFile]>,
|
||||
recent_paths: VecDeque<ProjectPath>,
|
||||
trigger: PredictEditsRequestTrigger,
|
||||
diagnostic_search_range: Range<Point>,
|
||||
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct ZetaOptions {
|
||||
pub context: EditPredictionExcerptOptions,
|
||||
pub max_prompt_bytes: usize,
|
||||
pub prompt_format: predict_edits_v3::PromptFormat,
|
||||
}
|
||||
|
||||
@@ -194,7 +201,8 @@ pub struct ZetaOptions {
|
||||
pub enum DebugEvent {
|
||||
ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
|
||||
ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
|
||||
EditPredictionRequested(EditPredictionRequestedDebugEvent),
|
||||
EditPredictionStarted(EditPredictionStartedDebugEvent),
|
||||
EditPredictionFinished(EditPredictionFinishedDebugEvent),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@@ -212,27 +220,30 @@ pub struct ContextRetrievalFinishedDebugEvent {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EditPredictionRequestedDebugEvent {
|
||||
pub inputs: EditPredictionInputs,
|
||||
pub retrieval_time: Duration,
|
||||
pub struct EditPredictionStartedDebugEvent {
|
||||
pub buffer: WeakEntity<Buffer>,
|
||||
pub position: Anchor,
|
||||
pub local_prompt: Result<String, String>,
|
||||
pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
|
||||
pub prompt: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EditPredictionFinishedDebugEvent {
|
||||
pub buffer: WeakEntity<Buffer>,
|
||||
pub position: Anchor,
|
||||
pub model_output: Option<String>,
|
||||
}
|
||||
|
||||
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
|
||||
|
||||
struct ProjectState {
|
||||
events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
|
||||
events: VecDeque<Arc<zeta_prompt::Event>>,
|
||||
last_event: Option<LastEvent>,
|
||||
recent_paths: VecDeque<ProjectPath>,
|
||||
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
|
||||
current_prediction: Option<CurrentEditPrediction>,
|
||||
next_pending_prediction_id: usize,
|
||||
pending_predictions: ArrayVec<PendingPrediction, 2>,
|
||||
context_updates_tx: smol::channel::Sender<()>,
|
||||
context_updates_rx: smol::channel::Receiver<()>,
|
||||
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
|
||||
last_prediction_refresh: Option<(EntityId, Instant)>,
|
||||
cancelled_predictions: HashSet<usize>,
|
||||
context: Entity<RelatedExcerptStore>,
|
||||
@@ -241,7 +252,7 @@ struct ProjectState {
|
||||
}
|
||||
|
||||
impl ProjectState {
|
||||
pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
|
||||
pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
|
||||
self.events
|
||||
.iter()
|
||||
.cloned()
|
||||
@@ -272,6 +283,18 @@ impl ProjectState {
|
||||
})
|
||||
.detach()
|
||||
}
|
||||
|
||||
fn active_buffer(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Option<(Entity<Buffer>, Option<Anchor>)> {
|
||||
let project = project.read(cx);
|
||||
let active_path = project.path_for_entry(project.active_entry()?, cx)?;
|
||||
let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
|
||||
let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
|
||||
Some((active_buffer, registered_buffer.last_position))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
@@ -362,6 +385,7 @@ impl std::ops::Deref for BufferEditPrediction<'_> {
|
||||
|
||||
struct RegisteredBuffer {
|
||||
snapshot: BufferSnapshot,
|
||||
last_position: Option<Anchor>,
|
||||
_subscriptions: [gpui::Subscription; 2],
|
||||
}
|
||||
|
||||
@@ -376,7 +400,7 @@ impl LastEvent {
|
||||
&self,
|
||||
license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
|
||||
cx: &App,
|
||||
) -> Option<Arc<predict_edits_v3::Event>> {
|
||||
) -> Option<Arc<zeta_prompt::Event>> {
|
||||
let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
|
||||
let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
|
||||
|
||||
@@ -396,7 +420,7 @@ impl LastEvent {
|
||||
if path == old_path && diff.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(Arc::new(predict_edits_v3::Event::BufferChange {
|
||||
Some(Arc::new(zeta_prompt::Event::BufferChange {
|
||||
old_path,
|
||||
path,
|
||||
diff,
|
||||
@@ -481,7 +505,6 @@ impl EditPredictionStore {
|
||||
},
|
||||
),
|
||||
update_required: false,
|
||||
debug_tx: None,
|
||||
#[cfg(feature = "eval-support")]
|
||||
eval_cache: None,
|
||||
edit_prediction_model: EditPredictionModel::Zeta2,
|
||||
@@ -536,12 +559,6 @@ impl EditPredictionStore {
|
||||
self.eval_cache = Some(cache);
|
||||
}
|
||||
|
||||
pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
|
||||
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
|
||||
self.debug_tx = Some(debug_watch_tx);
|
||||
debug_watch_rx
|
||||
}
|
||||
|
||||
pub fn options(&self) -> &ZetaOptions {
|
||||
&self.options
|
||||
}
|
||||
@@ -560,15 +577,35 @@ impl EditPredictionStore {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn edit_history_for_project(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
) -> Vec<Arc<zeta_prompt::Event>> {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.map(|project_state| project_state.events.iter().cloned().collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn context_for_project<'a>(
|
||||
&'a self,
|
||||
project: &Entity<Project>,
|
||||
cx: &'a App,
|
||||
) -> &'a [RelatedFile] {
|
||||
) -> Arc<[RelatedFile]> {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.map(|project| project.context.read(cx).related_files())
|
||||
.unwrap_or(&[])
|
||||
.unwrap_or_else(|| vec![].into())
|
||||
}
|
||||
|
||||
pub fn context_for_project_with_buffers<'a>(
|
||||
&'a self,
|
||||
project: &Entity<Project>,
|
||||
cx: &'a App,
|
||||
) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
|
||||
self.projects
|
||||
.get(&project.entity_id())
|
||||
.map(|project| project.context.read(cx).related_files_with_buffers())
|
||||
}
|
||||
|
||||
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
|
||||
@@ -599,85 +636,21 @@ impl EditPredictionStore {
|
||||
cx: &mut Context<Self>,
|
||||
) -> &mut ProjectState {
|
||||
let entity_id = project.entity_id();
|
||||
let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
|
||||
self.projects
|
||||
.entry(entity_id)
|
||||
.or_insert_with(|| ProjectState {
|
||||
context: {
|
||||
let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
|
||||
cx.subscribe(
|
||||
&related_excerpt_store,
|
||||
move |this, _, event, _| match event {
|
||||
RelatedExcerptStoreEvent::StartedRefresh => {
|
||||
if let Some(debug_tx) = this.debug_tx.clone() {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::ContextRetrievalStarted(
|
||||
ContextRetrievalStartedDebugEvent {
|
||||
project_entity_id: entity_id,
|
||||
timestamp: Instant::now(),
|
||||
search_prompt: String::new(),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
RelatedExcerptStoreEvent::FinishedRefresh {
|
||||
cache_hit_count,
|
||||
cache_miss_count,
|
||||
mean_definition_latency,
|
||||
max_definition_latency,
|
||||
} => {
|
||||
if let Some(debug_tx) = this.debug_tx.clone() {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::ContextRetrievalFinished(
|
||||
ContextRetrievalFinishedDebugEvent {
|
||||
project_entity_id: entity_id,
|
||||
timestamp: Instant::now(),
|
||||
metadata: vec![
|
||||
(
|
||||
"Cache Hits",
|
||||
format!(
|
||||
"{}/{}",
|
||||
cache_hit_count,
|
||||
cache_hit_count + cache_miss_count
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
(
|
||||
"Max LSP Time",
|
||||
format!(
|
||||
"{} ms",
|
||||
max_definition_latency.as_millis()
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
(
|
||||
"Mean LSP Time",
|
||||
format!(
|
||||
"{} ms",
|
||||
mean_definition_latency.as_millis()
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
if let Some(project_state) = this.projects.get(&entity_id) {
|
||||
project_state.context_updates_tx.send_blocking(()).ok();
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
|
||||
this.handle_excerpt_store_event(entity_id, event);
|
||||
})
|
||||
.detach();
|
||||
related_excerpt_store
|
||||
},
|
||||
events: VecDeque::new(),
|
||||
last_event: None,
|
||||
recent_paths: VecDeque::new(),
|
||||
context_updates_rx,
|
||||
context_updates_tx,
|
||||
debug_tx: None,
|
||||
registered_buffers: HashMap::default(),
|
||||
current_prediction: None,
|
||||
cancelled_predictions: HashSet::default(),
|
||||
@@ -689,12 +662,79 @@ impl EditPredictionStore {
|
||||
})
|
||||
}
|
||||
|
||||
pub fn project_context_updates(
|
||||
&self,
|
||||
/// Drops all state held for `project`; does nothing if there is none.
pub fn remove_project(&mut self, project: &Entity<Project>) {
    let project_id = project.entity_id();
    self.projects.remove(&project_id);
}
|
||||
|
||||
fn handle_excerpt_store_event(
|
||||
&mut self,
|
||||
project_entity_id: EntityId,
|
||||
event: &RelatedExcerptStoreEvent,
|
||||
) {
|
||||
if let Some(project_state) = self.projects.get(&project_entity_id) {
|
||||
if let Some(debug_tx) = project_state.debug_tx.clone() {
|
||||
match event {
|
||||
RelatedExcerptStoreEvent::StartedRefresh => {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::ContextRetrievalStarted(
|
||||
ContextRetrievalStartedDebugEvent {
|
||||
project_entity_id: project_entity_id,
|
||||
timestamp: Instant::now(),
|
||||
search_prompt: String::new(),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
RelatedExcerptStoreEvent::FinishedRefresh {
|
||||
cache_hit_count,
|
||||
cache_miss_count,
|
||||
mean_definition_latency,
|
||||
max_definition_latency,
|
||||
} => {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::ContextRetrievalFinished(
|
||||
ContextRetrievalFinishedDebugEvent {
|
||||
project_entity_id: project_entity_id,
|
||||
timestamp: Instant::now(),
|
||||
metadata: vec![
|
||||
(
|
||||
"Cache Hits",
|
||||
format!(
|
||||
"{}/{}",
|
||||
cache_hit_count,
|
||||
cache_hit_count + cache_miss_count
|
||||
)
|
||||
.into(),
|
||||
),
|
||||
(
|
||||
"Max LSP Time",
|
||||
format!("{} ms", max_definition_latency.as_millis())
|
||||
.into(),
|
||||
),
|
||||
(
|
||||
"Mean LSP Time",
|
||||
format!("{} ms", mean_definition_latency.as_millis())
|
||||
.into(),
|
||||
),
|
||||
],
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn debug_info(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
) -> Option<smol::channel::Receiver<()>> {
|
||||
let project_state = self.projects.get(&project.entity_id())?;
|
||||
Some(project_state.context_updates_rx.clone())
|
||||
cx: &mut Context<Self>,
|
||||
) -> mpsc::UnboundedReceiver<DebugEvent> {
|
||||
let project_state = self.get_or_init_project(project, cx);
|
||||
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
|
||||
project_state.debug_tx = Some(debug_watch_tx);
|
||||
debug_watch_rx
|
||||
}
|
||||
|
||||
fn handle_project_event(
|
||||
@@ -768,6 +808,7 @@ impl EditPredictionStore {
|
||||
let project_entity_id = project.entity_id();
|
||||
entry.insert(RegisteredBuffer {
|
||||
snapshot,
|
||||
last_position: None,
|
||||
_subscriptions: [
|
||||
cx.subscribe(buffer, {
|
||||
let project = project.downgrade();
|
||||
@@ -855,13 +896,21 @@ impl EditPredictionStore {
|
||||
});
|
||||
}
|
||||
|
||||
fn current_prediction_for_buffer(
|
||||
&self,
|
||||
fn prediction_at(
|
||||
&mut self,
|
||||
buffer: &Entity<Buffer>,
|
||||
position: Option<language::Anchor>,
|
||||
project: &Entity<Project>,
|
||||
cx: &App,
|
||||
) -> Option<BufferEditPrediction<'_>> {
|
||||
let project_state = self.projects.get(&project.entity_id())?;
|
||||
let project_state = self.projects.get_mut(&project.entity_id())?;
|
||||
if let Some(position) = position
|
||||
&& let Some(buffer) = project_state
|
||||
.registered_buffers
|
||||
.get_mut(&buffer.entity_id())
|
||||
{
|
||||
buffer.last_position = Some(position);
|
||||
}
|
||||
|
||||
let CurrentEditPrediction {
|
||||
requested_by,
|
||||
@@ -1104,12 +1153,21 @@ impl EditPredictionStore {
|
||||
};
|
||||
|
||||
self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
|
||||
let Some(open_buffer_task) = project
|
||||
.update(cx, |project, cx| {
|
||||
project
|
||||
.active_entry()
|
||||
.and_then(|entry| project.path_for_entry(entry, cx))
|
||||
.map(|path| project.open_buffer(path, cx))
|
||||
let Some((active_buffer, snapshot, cursor_point)) = this
|
||||
.read_with(cx, |this, cx| {
|
||||
let project_state = this.projects.get(&project.entity_id())?;
|
||||
let (buffer, position) = project_state.active_buffer(&project, cx)?;
|
||||
let snapshot = buffer.read(cx).snapshot();
|
||||
|
||||
if !Self::predictions_enabled_at(&snapshot, position, cx) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let cursor_point = position
|
||||
.map(|pos| pos.to_point(&snapshot))
|
||||
.unwrap_or_default();
|
||||
|
||||
Some((buffer, snapshot, cursor_point))
|
||||
})
|
||||
.log_err()
|
||||
.flatten()
|
||||
@@ -1118,14 +1176,11 @@ impl EditPredictionStore {
|
||||
};
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let active_buffer = open_buffer_task.await?;
|
||||
let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
|
||||
|
||||
let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
|
||||
active_buffer,
|
||||
&snapshot,
|
||||
Default::default(),
|
||||
Default::default(),
|
||||
cursor_point,
|
||||
&project,
|
||||
cx,
|
||||
)
|
||||
@@ -1170,6 +1225,37 @@ impl EditPredictionStore {
|
||||
});
|
||||
}
|
||||
|
||||
fn predictions_enabled_at(
|
||||
snapshot: &BufferSnapshot,
|
||||
position: Option<language::Anchor>,
|
||||
cx: &App,
|
||||
) -> bool {
|
||||
let file = snapshot.file();
|
||||
let all_settings = all_language_settings(file, cx);
|
||||
if !all_settings.show_edit_predictions(snapshot.language(), cx)
|
||||
|| file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(last_position) = position {
|
||||
let settings = snapshot.settings_at(last_position, cx);
|
||||
|
||||
if !settings.edit_predictions_disabled_in.is_empty()
|
||||
&& let Some(scope) = snapshot.language_scope_at(last_position)
|
||||
&& let Some(scope_name) = scope.override_name()
|
||||
&& settings
|
||||
.edit_predictions_disabled_in
|
||||
.iter()
|
||||
.any(|s| s == scope_name)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
|
||||
#[cfg(test)]
|
||||
@@ -1348,6 +1434,7 @@ impl EditPredictionStore {
|
||||
let project_state = self.projects.get(&project.entity_id()).unwrap();
|
||||
let events = project_state.events(cx);
|
||||
let has_events = !events.is_empty();
|
||||
let debug_tx = project_state.debug_tx.clone();
|
||||
|
||||
let snapshot = active_buffer.read(cx).snapshot();
|
||||
let cursor_point = position.to_point(&snapshot);
|
||||
@@ -1357,55 +1444,29 @@ impl EditPredictionStore {
|
||||
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
|
||||
|
||||
let related_files = if self.use_context {
|
||||
self.context_for_project(&project, cx).to_vec()
|
||||
self.context_for_project(&project, cx)
|
||||
} else {
|
||||
Vec::new()
|
||||
Vec::new().into()
|
||||
};
|
||||
|
||||
let inputs = EditPredictionModelInput {
|
||||
project: project.clone(),
|
||||
buffer: active_buffer.clone(),
|
||||
snapshot: snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
related_files,
|
||||
recent_paths: project_state.recent_paths.clone(),
|
||||
trigger,
|
||||
diagnostic_search_range: diagnostic_search_range.clone(),
|
||||
debug_tx,
|
||||
};
|
||||
|
||||
let task = match self.edit_prediction_model {
|
||||
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
|
||||
self,
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
trigger,
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
|
||||
self,
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
related_files,
|
||||
trigger,
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
&project_state.recent_paths,
|
||||
related_files,
|
||||
diagnostic_search_range.clone(),
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Mercury => self.mercury.request_prediction(
|
||||
&project,
|
||||
&active_buffer,
|
||||
snapshot.clone(),
|
||||
position,
|
||||
events,
|
||||
&project_state.recent_paths,
|
||||
related_files,
|
||||
diagnostic_search_range.clone(),
|
||||
cx,
|
||||
),
|
||||
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
|
||||
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
|
||||
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
|
||||
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
|
||||
};
|
||||
|
||||
cx.spawn(async move |this, cx| {
|
||||
@@ -1706,6 +1767,20 @@ impl EditPredictionStore {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
pub fn set_context_for_buffer(
|
||||
&mut self,
|
||||
project: &Entity<Project>,
|
||||
related_files: Vec<RelatedFile>,
|
||||
cx: &mut Context<Self>,
|
||||
) {
|
||||
self.get_or_init_project(project, cx)
|
||||
.context
|
||||
.update(cx, |store, _| {
|
||||
store.set_related_files(related_files);
|
||||
});
|
||||
}
|
||||
|
||||
fn is_file_open_source(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
@@ -1729,14 +1804,14 @@ impl EditPredictionStore {
|
||||
self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
|
||||
}
|
||||
|
||||
fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
|
||||
fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
|
||||
if !self.data_collection_choice.is_enabled() {
|
||||
return false;
|
||||
}
|
||||
events.iter().all(|event| {
|
||||
matches!(
|
||||
event.as_ref(),
|
||||
Event::BufferChange {
|
||||
zeta_prompt::Event::BufferChange {
|
||||
in_open_source_repo: true,
|
||||
..
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use super::*;
|
||||
use crate::zeta1::MAX_EVENT_TOKENS;
|
||||
use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
|
||||
use client::{UserStore, test::FakeServer};
|
||||
use clock::{FakeSystemClock, ReplicaId};
|
||||
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
|
||||
@@ -7,7 +7,6 @@ use cloud_llm_client::{
|
||||
EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
|
||||
RejectEditPredictionsBody,
|
||||
};
|
||||
use edit_prediction_context::Line;
|
||||
use futures::{
|
||||
AsyncReadExt, StreamExt,
|
||||
channel::{mpsc, oneshot},
|
||||
@@ -28,6 +27,7 @@ use settings::SettingsStore;
|
||||
use std::{path::Path, sync::Arc, time::Duration};
|
||||
use util::{path, rel_path::rel_path};
|
||||
use uuid::Uuid;
|
||||
use zeta_prompt::ZetaPromptInput;
|
||||
|
||||
use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
|
||||
|
||||
@@ -45,10 +45,6 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
.await;
|
||||
let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.register_project(&project, cx);
|
||||
});
|
||||
|
||||
let buffer1 = project
|
||||
.update(cx, |project, cx| {
|
||||
let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
|
||||
@@ -60,30 +56,38 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
|
||||
let position = snapshot1.anchor_before(language::Point::new(1, 3));
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.register_project(&project, cx);
|
||||
ep_store.register_buffer(&buffer1, &project, cx);
|
||||
});
|
||||
|
||||
// Prediction for current file
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
|
||||
});
|
||||
let (_request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
|
||||
respond_tx
|
||||
.send(model_response(indoc! {r"
|
||||
--- a/root/1.txt
|
||||
+++ b/root/1.txt
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"}))
|
||||
.send(model_response(
|
||||
request,
|
||||
indoc! {r"
|
||||
--- a/root/1.txt
|
||||
+++ b/root/1.txt
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"},
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
let prediction = ep_store
|
||||
.current_prediction_for_buffer(&buffer1, &project, cx)
|
||||
.prediction_at(&buffer1, None, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(prediction, BufferEditPrediction::Local { .. });
|
||||
});
|
||||
@@ -120,22 +124,26 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
});
|
||||
});
|
||||
|
||||
let (_request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
respond_tx
|
||||
.send(model_response(indoc! {r#"
|
||||
--- a/root/2.txt
|
||||
+++ b/root/2.txt
|
||||
Hola!
|
||||
-Como
|
||||
+Como estas?
|
||||
Adios
|
||||
"#}))
|
||||
.send(model_response(
|
||||
request,
|
||||
indoc! {r#"
|
||||
--- a/root/2.txt
|
||||
+++ b/root/2.txt
|
||||
@@ ... @@
|
||||
Hola!
|
||||
-Como
|
||||
+Como estas?
|
||||
Adios
|
||||
"#},
|
||||
))
|
||||
.unwrap();
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
let prediction = ep_store
|
||||
.current_prediction_for_buffer(&buffer1, &project, cx)
|
||||
.prediction_at(&buffer1, None, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(
|
||||
prediction,
|
||||
@@ -151,9 +159,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
let prediction = ep_store
|
||||
.current_prediction_for_buffer(&buffer2, &project, cx)
|
||||
.prediction_at(&buffer2, None, &project, cx)
|
||||
.unwrap();
|
||||
assert_matches!(prediction, BufferEditPrediction::Local { .. });
|
||||
});
|
||||
@@ -186,7 +194,7 @@ async fn test_simple_request(cx: &mut TestAppContext) {
|
||||
ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
|
||||
// TODO Put back when we have a structured request again
|
||||
// assert_eq!(
|
||||
@@ -202,15 +210,18 @@ async fn test_simple_request(cx: &mut TestAppContext) {
|
||||
// );
|
||||
|
||||
respond_tx
|
||||
.send(model_response(indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"}))
|
||||
.send(model_response(
|
||||
request,
|
||||
indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"},
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
|
||||
@@ -276,15 +287,18 @@ async fn test_request_events(cx: &mut TestAppContext) {
|
||||
);
|
||||
|
||||
respond_tx
|
||||
.send(model_response(indoc! {r#"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"#}))
|
||||
.send(model_response(
|
||||
request,
|
||||
indoc! {r#"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are you?
|
||||
Bye
|
||||
"#},
|
||||
))
|
||||
.unwrap();
|
||||
|
||||
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
|
||||
@@ -324,27 +338,17 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
const NO_OP_DIFF: &str = indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How
|
||||
Bye
|
||||
"};
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let response = model_response(NO_OP_DIFF);
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let response = model_response(request, "");
|
||||
let id = response.id.clone();
|
||||
respond_tx.send(response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
assert!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.is_none()
|
||||
);
|
||||
});
|
||||
@@ -389,22 +393,22 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.set_text("Hello!\nHow are you?\nBye", cx);
|
||||
});
|
||||
|
||||
let response = model_response(SIMPLE_DIFF);
|
||||
let response = model_response(request, SIMPLE_DIFF);
|
||||
let id = response.id.clone();
|
||||
respond_tx.send(response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
assert!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.is_none()
|
||||
);
|
||||
});
|
||||
@@ -459,17 +463,17 @@ async fn test_replace_current(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let first_response = model_response(SIMPLE_DIFF);
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let first_response = model_response(request, SIMPLE_DIFF);
|
||||
let first_id = first_response.id.clone();
|
||||
respond_tx.send(first_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -482,18 +486,18 @@ async fn test_replace_current(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let second_response = model_response(SIMPLE_DIFF);
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let second_response = model_response(request, SIMPLE_DIFF);
|
||||
let second_id = second_response.id.clone();
|
||||
respond_tx.send(second_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
// second replaces first
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -541,17 +545,17 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let first_response = model_response(SIMPLE_DIFF);
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let first_response = model_response(request, SIMPLE_DIFF);
|
||||
let first_id = first_response.id.clone();
|
||||
respond_tx.send(first_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -564,27 +568,30 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_tx) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_tx) = requests.predict.next().await.unwrap();
|
||||
// worse than current prediction
|
||||
let second_response = model_response(indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are
|
||||
Bye
|
||||
"});
|
||||
let second_response = model_response(
|
||||
request,
|
||||
indoc! { r"
|
||||
--- a/root/foo.md
|
||||
+++ b/root/foo.md
|
||||
@@ ... @@
|
||||
Hello!
|
||||
-How
|
||||
+How are
|
||||
Bye
|
||||
"},
|
||||
);
|
||||
let second_id = second_response.id.clone();
|
||||
respond_tx.send(second_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
// first is preferred over second
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -633,29 +640,29 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_first) = requests.predict.next().await.unwrap();
|
||||
let (request1, respond_first) = requests.predict.next().await.unwrap();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_second) = requests.predict.next().await.unwrap();
|
||||
let (request, respond_second) = requests.predict.next().await.unwrap();
|
||||
|
||||
// wait for throttle
|
||||
cx.run_until_parked();
|
||||
|
||||
// second responds first
|
||||
let second_response = model_response(SIMPLE_DIFF);
|
||||
let second_response = model_response(request, SIMPLE_DIFF);
|
||||
let second_id = second_response.id.clone();
|
||||
respond_second.send(second_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
// current prediction is second
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -663,17 +670,17 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
|
||||
);
|
||||
});
|
||||
|
||||
let first_response = model_response(SIMPLE_DIFF);
|
||||
let first_response = model_response(request1, SIMPLE_DIFF);
|
||||
let first_id = first_response.id.clone();
|
||||
respond_first.send(first_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
// current prediction is still second, since first was cancelled
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -724,13 +731,13 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_first) = requests.predict.next().await.unwrap();
|
||||
let (request1, respond_first) = requests.predict.next().await.unwrap();
|
||||
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
|
||||
});
|
||||
|
||||
let (_, respond_second) = requests.predict.next().await.unwrap();
|
||||
let (request2, respond_second) = requests.predict.next().await.unwrap();
|
||||
|
||||
// wait for throttle, so requests are sent
|
||||
cx.run_until_parked();
|
||||
@@ -754,19 +761,19 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
|
||||
// wait for throttle
|
||||
cx.run_until_parked();
|
||||
|
||||
let (_, respond_third) = requests.predict.next().await.unwrap();
|
||||
let (request3, respond_third) = requests.predict.next().await.unwrap();
|
||||
|
||||
let first_response = model_response(SIMPLE_DIFF);
|
||||
let first_response = model_response(request1, SIMPLE_DIFF);
|
||||
let first_id = first_response.id.clone();
|
||||
respond_first.send(first_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
// current prediction is first
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -774,17 +781,17 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
|
||||
);
|
||||
});
|
||||
|
||||
let cancelled_response = model_response(SIMPLE_DIFF);
|
||||
let cancelled_response = model_response(request2, SIMPLE_DIFF);
|
||||
let cancelled_id = cancelled_response.id.clone();
|
||||
respond_second.send(cancelled_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
// current prediction is still first, since second was cancelled
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -792,17 +799,17 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
|
||||
);
|
||||
});
|
||||
|
||||
let third_response = model_response(SIMPLE_DIFF);
|
||||
let third_response = model_response(request3, SIMPLE_DIFF);
|
||||
let third_response_id = third_response.id.clone();
|
||||
respond_third.send(third_response).unwrap();
|
||||
|
||||
cx.run_until_parked();
|
||||
|
||||
ep_store.read_with(cx, |ep_store, cx| {
|
||||
ep_store.update(cx, |ep_store, cx| {
|
||||
// third completes and replaces first
|
||||
assert_eq!(
|
||||
ep_store
|
||||
.current_prediction_for_buffer(&buffer, &project, cx)
|
||||
.prediction_at(&buffer, None, &project, cx)
|
||||
.unwrap()
|
||||
.id
|
||||
.0,
|
||||
@@ -1036,7 +1043,24 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
|
||||
// );
|
||||
// }
|
||||
|
||||
fn model_response(text: &str) -> open_ai::Response {
|
||||
// Generate a model response that would apply the given diff to the active file.
|
||||
fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
|
||||
let prompt = match &request.messages[0] {
|
||||
open_ai::RequestMessage::User {
|
||||
content: open_ai::MessageContent::Plain(content),
|
||||
} => content,
|
||||
_ => panic!("unexpected request {request:?}"),
|
||||
};
|
||||
|
||||
let open = "<editable_region>\n";
|
||||
let close = "</editable_region>";
|
||||
let cursor = "<|user_cursor|>";
|
||||
|
||||
let start_ix = open.len() + prompt.find(open).unwrap();
|
||||
let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
|
||||
let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
|
||||
let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
|
||||
|
||||
open_ai::Response {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
object: "response".into(),
|
||||
@@ -1045,7 +1069,7 @@ fn model_response(text: &str) -> open_ai::Response {
|
||||
choices: vec![open_ai::Choice {
|
||||
index: 0,
|
||||
message: open_ai::RequestMessage::Assistant {
|
||||
content: Some(open_ai::MessageContent::Plain(text.to_string())),
|
||||
content: Some(open_ai::MessageContent::Plain(new_excerpt)),
|
||||
tool_calls: vec![],
|
||||
},
|
||||
finish_reason: None,
|
||||
@@ -1160,20 +1184,19 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
|
||||
.await;
|
||||
|
||||
let completion = EditPrediction {
|
||||
let prediction = EditPrediction {
|
||||
edits,
|
||||
edit_preview,
|
||||
buffer: buffer.clone(),
|
||||
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
|
||||
id: EditPredictionId("the-id".into()),
|
||||
inputs: EditPredictionInputs {
|
||||
inputs: ZetaPromptInput {
|
||||
events: Default::default(),
|
||||
included_files: Default::default(),
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
line: Line(0),
|
||||
column: 0,
|
||||
},
|
||||
related_files: Default::default(),
|
||||
cursor_path: Path::new("").into(),
|
||||
cursor_excerpt: "".into(),
|
||||
editable_range_in_excerpt: 0..0,
|
||||
cursor_offset_in_excerpt: 0,
|
||||
},
|
||||
buffer_snapshotted_at: Instant::now(),
|
||||
response_received_at: Instant::now(),
|
||||
@@ -1182,7 +1205,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
cx.update(|cx| {
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1192,7 +1215,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1202,7 +1225,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.undo(cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1212,7 +1235,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1222,7 +1245,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1232,7 +1255,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1242,7 +1265,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1252,7 +1275,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
|
||||
assert_eq!(
|
||||
from_completion_edits(
|
||||
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
|
||||
&buffer,
|
||||
cx
|
||||
),
|
||||
@@ -1260,7 +1283,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
);
|
||||
|
||||
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
|
||||
assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
|
||||
assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -1,20 +1,17 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use cloud_llm_client::predict_edits_v3::Event;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use edit_prediction_context::RelatedFile;
|
||||
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
|
||||
use gpui::{
|
||||
App, AppContext as _, Entity, Task,
|
||||
App, AppContext as _, Task,
|
||||
http_client::{self, AsyncBody, Method},
|
||||
};
|
||||
use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
|
||||
use project::{Project, ProjectPath};
|
||||
use std::{
|
||||
collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
|
||||
};
|
||||
use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
|
||||
use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
|
||||
use zeta_prompt::ZetaPromptInput;
|
||||
|
||||
use crate::{
|
||||
EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
|
||||
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
|
||||
EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
|
||||
prediction::EditPredictionResult,
|
||||
};
|
||||
|
||||
@@ -38,16 +35,17 @@ impl Mercury {
|
||||
store_api_token_in_keychain(api_token, cx)
|
||||
}
|
||||
|
||||
pub fn request_prediction(
|
||||
pub(crate) fn request_prediction(
|
||||
&self,
|
||||
_project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
_recent_paths: &VecDeque<ProjectPath>,
|
||||
related_files: Vec<RelatedFile>,
|
||||
_diagnostic_search_range: Range<Point>,
|
||||
EditPredictionModelInput {
|
||||
buffer,
|
||||
snapshot,
|
||||
position,
|
||||
events,
|
||||
related_files,
|
||||
debug_tx,
|
||||
..
|
||||
}: EditPredictionModelInput,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
|
||||
@@ -62,6 +60,7 @@ impl Mercury {
|
||||
let http_client = cx.http_client();
|
||||
let cursor_point = position.to_point(&snapshot);
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
let active_buffer = buffer.clone();
|
||||
|
||||
let result = cx.background_spawn(async move {
|
||||
let (editable_range, context_range) =
|
||||
@@ -72,39 +71,39 @@ impl Mercury {
|
||||
MAX_REWRITE_TOKENS,
|
||||
);
|
||||
|
||||
let offset_range = editable_range.to_offset(&snapshot);
|
||||
let prompt = build_prompt(
|
||||
&events,
|
||||
&related_files,
|
||||
&snapshot,
|
||||
full_path.as_ref(),
|
||||
cursor_point,
|
||||
editable_range,
|
||||
context_range.clone(),
|
||||
);
|
||||
let context_offset_range = context_range.to_offset(&snapshot);
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
events: events,
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
|
||||
path: full_path.clone(),
|
||||
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
|
||||
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
|
||||
start_line: cloud_llm_client::predict_edits_v3::Line(
|
||||
context_range.start.row,
|
||||
),
|
||||
text: snapshot
|
||||
.text_for_range(context_range.clone())
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
}],
|
||||
}],
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
column: cursor_point.column,
|
||||
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
|
||||
},
|
||||
let editable_offset_range = editable_range.to_offset(&snapshot);
|
||||
|
||||
let inputs = zeta_prompt::ZetaPromptInput {
|
||||
events,
|
||||
related_files,
|
||||
cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
|
||||
- context_range.start.to_offset(&snapshot),
|
||||
cursor_path: full_path.clone(),
|
||||
cursor_excerpt: snapshot
|
||||
.text_for_range(context_range)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
editable_range_in_excerpt: (editable_offset_range.start
|
||||
- context_offset_range.start)
|
||||
..(editable_offset_range.end - context_offset_range.start),
|
||||
};
|
||||
|
||||
let prompt = build_prompt(&inputs);
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionStarted(
|
||||
EditPredictionStartedDebugEvent {
|
||||
buffer: active_buffer.downgrade(),
|
||||
prompt: Some(prompt.clone()),
|
||||
position,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let request_body = open_ai::Request {
|
||||
model: "mercury-coder".into(),
|
||||
messages: vec![open_ai::RequestMessage::User {
|
||||
@@ -160,6 +159,18 @@ impl Mercury {
|
||||
let id = mem::take(&mut response.id);
|
||||
let response_str = text_from_response(response).unwrap_or_default();
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionFinished(
|
||||
EditPredictionFinishedDebugEvent {
|
||||
buffer: active_buffer.downgrade(),
|
||||
model_output: Some(response_str.clone()),
|
||||
position,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
|
||||
let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
|
||||
|
||||
@@ -168,15 +179,16 @@ impl Mercury {
|
||||
|
||||
if response_str != NO_PREDICTION_OUTPUT {
|
||||
let old_text = snapshot
|
||||
.text_for_range(offset_range.clone())
|
||||
.text_for_range(editable_offset_range.clone())
|
||||
.collect::<String>();
|
||||
edits.extend(
|
||||
language::text_diff(&old_text, &response_str)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(offset_range.start + range.start)
|
||||
..snapshot.anchor_before(offset_range.start + range.end),
|
||||
snapshot.anchor_after(editable_offset_range.start + range.start)
|
||||
..snapshot
|
||||
.anchor_before(editable_offset_range.start + range.end),
|
||||
text,
|
||||
)
|
||||
}),
|
||||
@@ -186,8 +198,6 @@ impl Mercury {
|
||||
anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
|
||||
});
|
||||
|
||||
let buffer = active_buffer.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let (id, edits, old_snapshot, response_received_at, inputs) =
|
||||
result.await.context("Mercury edit prediction failed")?;
|
||||
@@ -208,15 +218,7 @@ impl Mercury {
|
||||
}
|
||||
}
|
||||
|
||||
fn build_prompt(
|
||||
events: &[Arc<Event>],
|
||||
related_files: &[RelatedFile],
|
||||
cursor_buffer: &BufferSnapshot,
|
||||
cursor_buffer_path: &Path,
|
||||
cursor_point: Point,
|
||||
editable_range: Range<Point>,
|
||||
context_range: Range<Point>,
|
||||
) -> String {
|
||||
fn build_prompt(inputs: &ZetaPromptInput) -> String {
|
||||
const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
|
||||
const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
|
||||
const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
|
||||
@@ -237,14 +239,14 @@ fn build_prompt(
|
||||
&mut prompt,
|
||||
RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
|
||||
|prompt| {
|
||||
for related_file in related_files {
|
||||
for related_file in inputs.related_files.iter() {
|
||||
for related_excerpt in &related_file.excerpts {
|
||||
push_delimited(
|
||||
prompt,
|
||||
RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
|
||||
|prompt| {
|
||||
prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
|
||||
prompt.push_str(related_file.path.path.as_unix_str());
|
||||
prompt.push_str(related_file.path.to_string_lossy().as_ref());
|
||||
prompt.push('\n');
|
||||
prompt.push_str(&related_excerpt.text.to_string());
|
||||
},
|
||||
@@ -259,21 +261,22 @@ fn build_prompt(
|
||||
CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
|
||||
|prompt| {
|
||||
prompt.push_str(CURRENT_FILE_PATH_PREFIX);
|
||||
prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
|
||||
prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
|
||||
prompt.push('\n');
|
||||
|
||||
let prefix_range = context_range.start..editable_range.start;
|
||||
let suffix_range = editable_range.end..context_range.end;
|
||||
|
||||
prompt.extend(cursor_buffer.text_for_range(prefix_range));
|
||||
prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]);
|
||||
push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
|
||||
let range_before_cursor = editable_range.start..cursor_point;
|
||||
let range_after_cursor = cursor_point..editable_range.end;
|
||||
prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
|
||||
prompt.push_str(
|
||||
&inputs.cursor_excerpt
|
||||
[inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt],
|
||||
);
|
||||
prompt.push_str(CURSOR_TAG);
|
||||
prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
|
||||
prompt.push_str(
|
||||
&inputs.cursor_excerpt
|
||||
[inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end],
|
||||
);
|
||||
});
|
||||
prompt.extend(cursor_buffer.text_for_range(suffix_range));
|
||||
prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]);
|
||||
},
|
||||
);
|
||||
|
||||
@@ -281,8 +284,8 @@ fn build_prompt(
|
||||
&mut prompt,
|
||||
EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
|
||||
|prompt| {
|
||||
for event in events {
|
||||
writeln!(prompt, "{event}").unwrap();
|
||||
for event in inputs.events.iter() {
|
||||
zeta_prompt::write_event(prompt, &event);
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
use std::{
|
||||
ops::Range,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
@@ -9,7 +8,7 @@ use cloud_llm_client::EditPredictionRejectReason;
|
||||
use edit_prediction_types::interpolate_edits;
|
||||
use gpui::{AsyncApp, Entity, SharedString};
|
||||
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
|
||||
use serde::Serialize;
|
||||
use zeta_prompt::ZetaPromptInput;
|
||||
|
||||
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
|
||||
pub struct EditPredictionId(pub SharedString);
|
||||
@@ -40,7 +39,7 @@ impl EditPredictionResult {
|
||||
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
|
||||
buffer_snapshotted_at: Instant,
|
||||
response_received_at: Instant,
|
||||
inputs: EditPredictionInputs,
|
||||
inputs: ZetaPromptInput,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Self {
|
||||
if edits.is_empty() {
|
||||
@@ -94,15 +93,7 @@ pub struct EditPrediction {
|
||||
pub buffer: Entity<Buffer>,
|
||||
pub buffer_snapshotted_at: Instant,
|
||||
pub response_received_at: Instant,
|
||||
pub inputs: EditPredictionInputs,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct EditPredictionInputs {
|
||||
pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
|
||||
pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
|
||||
pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
|
||||
pub cursor_path: Arc<Path>,
|
||||
pub inputs: zeta_prompt::ZetaPromptInput,
|
||||
}
|
||||
|
||||
impl EditPrediction {
|
||||
@@ -133,9 +124,12 @@ impl std::fmt::Debug for EditPrediction {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::path::Path;
|
||||
|
||||
use super::*;
|
||||
use gpui::{App, Entity, TestAppContext, prelude::*};
|
||||
use language::{Buffer, ToOffset as _};
|
||||
use zeta_prompt::ZetaPromptInput;
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
|
||||
@@ -154,14 +148,13 @@ mod tests {
|
||||
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
|
||||
buffer: buffer.clone(),
|
||||
edit_preview,
|
||||
inputs: EditPredictionInputs {
|
||||
inputs: ZetaPromptInput {
|
||||
events: vec![],
|
||||
included_files: vec![],
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
line: cloud_llm_client::predict_edits_v3::Line(0),
|
||||
column: 0,
|
||||
},
|
||||
related_files: vec![].into(),
|
||||
cursor_path: Path::new("path.txt").into(),
|
||||
cursor_offset_in_excerpt: 0,
|
||||
cursor_excerpt: "".into(),
|
||||
editable_range_in_excerpt: 0..0,
|
||||
},
|
||||
buffer_snapshotted_at: Instant::now(),
|
||||
response_received_at: Instant::now(),
|
||||
|
||||
@@ -1,26 +1,21 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use cloud_llm_client::predict_edits_v3::Event;
|
||||
use credentials_provider::CredentialsProvider;
|
||||
use edit_prediction_context::RelatedFile;
|
||||
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
|
||||
use gpui::{
|
||||
App, AppContext as _, Entity, Task,
|
||||
App, AppContext as _, Task,
|
||||
http_client::{self, AsyncBody, Method},
|
||||
};
|
||||
use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
|
||||
use language::{Point, ToOffset as _};
|
||||
use lsp::DiagnosticSeverity;
|
||||
use project::{Project, ProjectPath};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
collections::VecDeque,
|
||||
fmt::{self, Write as _},
|
||||
ops::Range,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::Instant,
|
||||
};
|
||||
|
||||
use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
|
||||
use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
|
||||
|
||||
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
|
||||
|
||||
@@ -44,40 +39,34 @@ impl SweepAi {
|
||||
|
||||
pub fn request_prediction_with_sweep(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
recent_paths: &VecDeque<ProjectPath>,
|
||||
related_files: Vec<RelatedFile>,
|
||||
diagnostic_search_range: Range<Point>,
|
||||
inputs: EditPredictionModelInput,
|
||||
cx: &mut App,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let debug_info = self.debug_info.clone();
|
||||
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
|
||||
return Task::ready(Ok(None));
|
||||
};
|
||||
let full_path: Arc<Path> = snapshot
|
||||
let full_path: Arc<Path> = inputs
|
||||
.snapshot
|
||||
.file()
|
||||
.map(|file| file.full_path(cx))
|
||||
.unwrap_or_else(|| "untitled".into())
|
||||
.into();
|
||||
|
||||
let project_file = project::File::from_dyn(snapshot.file());
|
||||
let project_file = project::File::from_dyn(inputs.snapshot.file());
|
||||
let repo_name = project_file
|
||||
.map(|file| file.worktree.read(cx).root_name_str())
|
||||
.unwrap_or("untitled")
|
||||
.into();
|
||||
let offset = position.to_offset(&snapshot);
|
||||
let offset = inputs.position.to_offset(&inputs.snapshot);
|
||||
|
||||
let recent_buffers = recent_paths.iter().cloned();
|
||||
let recent_buffers = inputs.recent_paths.iter().cloned();
|
||||
let http_client = cx.http_client();
|
||||
|
||||
let recent_buffer_snapshots = recent_buffers
|
||||
.filter_map(|project_path| {
|
||||
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
|
||||
if active_buffer == &buffer {
|
||||
let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
|
||||
if inputs.buffer == buffer {
|
||||
None
|
||||
} else {
|
||||
Some(buffer.read(cx).snapshot())
|
||||
@@ -86,14 +75,13 @@ impl SweepAi {
|
||||
.take(3)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let cursor_point = position.to_point(&snapshot);
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
|
||||
let result = cx.background_spawn(async move {
|
||||
let text = snapshot.text();
|
||||
let text = inputs.snapshot.text();
|
||||
|
||||
let mut recent_changes = String::new();
|
||||
for event in &events {
|
||||
for event in &inputs.events {
|
||||
write_event(event.as_ref(), &mut recent_changes).unwrap();
|
||||
}
|
||||
|
||||
@@ -122,20 +110,23 @@ impl SweepAi {
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let retrieval_chunks = related_files
|
||||
let retrieval_chunks = inputs
|
||||
.related_files
|
||||
.iter()
|
||||
.flat_map(|related_file| {
|
||||
related_file.excerpts.iter().map(|excerpt| FileChunk {
|
||||
file_path: related_file.path.path.as_unix_str().to_string(),
|
||||
start_line: excerpt.point_range.start.row as usize,
|
||||
end_line: excerpt.point_range.end.row as usize,
|
||||
file_path: related_file.path.to_string_lossy().to_string(),
|
||||
start_line: excerpt.row_range.start as usize,
|
||||
end_line: excerpt.row_range.end as usize,
|
||||
content: excerpt.text.to_string(),
|
||||
timestamp: None,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
|
||||
let diagnostic_entries = inputs
|
||||
.snapshot
|
||||
.diagnostics_in_range(inputs.diagnostic_search_range, false);
|
||||
let mut diagnostic_content = String::new();
|
||||
let mut diagnostic_count = 0;
|
||||
|
||||
@@ -195,21 +186,14 @@ impl SweepAi {
|
||||
serde_json::to_writer(writer, &request_body)?;
|
||||
let body: AsyncBody = buf.into();
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
events,
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
|
||||
path: full_path.clone(),
|
||||
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
|
||||
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
|
||||
start_line: cloud_llm_client::predict_edits_v3::Line(0),
|
||||
text: request_body.file_contents.into(),
|
||||
}],
|
||||
}],
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
column: cursor_point.column,
|
||||
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
|
||||
},
|
||||
let ep_inputs = zeta_prompt::ZetaPromptInput {
|
||||
events: inputs.events,
|
||||
related_files: inputs.related_files.clone(),
|
||||
cursor_path: full_path.clone(),
|
||||
cursor_excerpt: request_body.file_contents.into(),
|
||||
// we actually don't know
|
||||
editable_range_in_excerpt: 0..inputs.snapshot.len(),
|
||||
cursor_offset_in_excerpt: request_body.cursor_position,
|
||||
};
|
||||
|
||||
let request = http_client::Request::builder()
|
||||
@@ -237,15 +221,20 @@ impl SweepAi {
|
||||
|
||||
let response: AutocompleteResponse = serde_json::from_slice(&body)?;
|
||||
|
||||
let old_text = snapshot
|
||||
let old_text = inputs
|
||||
.snapshot
|
||||
.text_for_range(response.start_index..response.end_index)
|
||||
.collect::<String>();
|
||||
let edits = language::text_diff(&old_text, &response.completion)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(response.start_index + range.start)
|
||||
..snapshot.anchor_before(response.start_index + range.end),
|
||||
inputs
|
||||
.snapshot
|
||||
.anchor_after(response.start_index + range.start)
|
||||
..inputs
|
||||
.snapshot
|
||||
.anchor_before(response.start_index + range.end),
|
||||
text,
|
||||
)
|
||||
})
|
||||
@@ -254,13 +243,13 @@ impl SweepAi {
|
||||
anyhow::Ok((
|
||||
response.autocomplete_id,
|
||||
edits,
|
||||
snapshot,
|
||||
inputs.snapshot,
|
||||
response_received_at,
|
||||
inputs,
|
||||
ep_inputs,
|
||||
))
|
||||
});
|
||||
|
||||
let buffer = active_buffer.clone();
|
||||
let buffer = inputs.buffer.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
|
||||
@@ -403,12 +392,9 @@ struct AdditionalCompletion {
|
||||
pub finish_reason: Option<String>,
|
||||
}
|
||||
|
||||
fn write_event(
|
||||
event: &cloud_llm_client::predict_edits_v3::Event,
|
||||
f: &mut impl fmt::Write,
|
||||
) -> fmt::Result {
|
||||
fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
|
||||
match event {
|
||||
cloud_llm_client::predict_edits_v3::Event::BufferChange {
|
||||
zeta_prompt::Event::BufferChange {
|
||||
old_path,
|
||||
path,
|
||||
diff,
|
||||
|
||||
@@ -14,68 +14,18 @@ use anyhow::anyhow;
|
||||
use collections::HashMap;
|
||||
use gpui::AsyncApp;
|
||||
use gpui::Entity;
|
||||
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
|
||||
use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
|
||||
use project::Project;
|
||||
|
||||
pub async fn parse_diff<'a>(
|
||||
diff_str: &'a str,
|
||||
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
|
||||
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
|
||||
let mut diff = DiffParser::new(diff_str);
|
||||
let mut edited_buffer = None;
|
||||
let mut edits = Vec::new();
|
||||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
DiffEvent::Hunk {
|
||||
path: file_path,
|
||||
hunk,
|
||||
} => {
|
||||
let (buffer, ranges) = match edited_buffer {
|
||||
None => {
|
||||
edited_buffer = get_buffer(&Path::new(file_path.as_ref()));
|
||||
edited_buffer
|
||||
.as_ref()
|
||||
.context("Model tried to edit a file that wasn't included")?
|
||||
}
|
||||
Some(ref current) => current,
|
||||
};
|
||||
|
||||
edits.extend(
|
||||
resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
|
||||
.with_context(|| format!("Diff:\n{diff_str}"))?,
|
||||
);
|
||||
}
|
||||
DiffEvent::FileEnd { renamed_to } => {
|
||||
let (buffer, _) = edited_buffer
|
||||
.take()
|
||||
.context("Got a FileEnd event before an Hunk event")?;
|
||||
|
||||
if renamed_to.is_some() {
|
||||
anyhow::bail!("edit predictions cannot rename files");
|
||||
}
|
||||
|
||||
if diff.next()?.is_some() {
|
||||
anyhow::bail!("Edited more than one file");
|
||||
}
|
||||
|
||||
return Ok((buffer, edits));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Err(anyhow::anyhow!("No EOF"))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Cow<'a, str>, Entity<Buffer>>);
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
|
||||
|
||||
#[must_use]
|
||||
pub async fn apply_diff<'a>(
|
||||
diff_str: &'a str,
|
||||
pub async fn apply_diff(
|
||||
diff_str: &str,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<OpenedBuffers<'a>> {
|
||||
) -> Result<OpenedBuffers> {
|
||||
let mut included_files = HashMap::default();
|
||||
|
||||
for line in diff_str.lines() {
|
||||
@@ -94,7 +44,7 @@ pub async fn apply_diff<'a>(
|
||||
})??
|
||||
.await?;
|
||||
|
||||
included_files.insert(path, buffer);
|
||||
included_files.insert(path.to_string(), buffer);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -113,7 +63,7 @@ pub async fn apply_diff<'a>(
|
||||
let (buffer, ranges) = match current_file {
|
||||
None => {
|
||||
let buffer = included_files
|
||||
.get_mut(&file_path)
|
||||
.get_mut(file_path.as_ref())
|
||||
.expect("Opened all files in diff");
|
||||
|
||||
current_file = Some((buffer, ranges.as_slice()));
|
||||
@@ -167,6 +117,29 @@ pub async fn apply_diff<'a>(
|
||||
Ok(OpenedBuffers(included_files))
|
||||
}
|
||||
|
||||
pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
|
||||
let mut diff = DiffParser::new(diff_str);
|
||||
|
||||
let mut text = text.to_string();
|
||||
|
||||
while let Some(event) = diff.next()? {
|
||||
match event {
|
||||
DiffEvent::Hunk { hunk, .. } => {
|
||||
let hunk_offset = text
|
||||
.find(&hunk.context)
|
||||
.ok_or_else(|| anyhow!("couldn't result hunk {:?}", hunk.context))?;
|
||||
for edit in hunk.edits.iter().rev() {
|
||||
let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
|
||||
text.replace_range(range, &edit.text);
|
||||
}
|
||||
}
|
||||
DiffEvent::FileEnd { .. } => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(text)
|
||||
}
|
||||
|
||||
struct PatchFile<'a> {
|
||||
old_path: Cow<'a, str>,
|
||||
new_path: Cow<'a, str>,
|
||||
@@ -492,7 +465,6 @@ mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use language::Point;
|
||||
use pretty_assertions::assert_eq;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
@@ -817,137 +789,6 @@ mod tests {
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx);
|
||||
|
||||
let buffer_1_text = indoc! {r#"
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
"# };
|
||||
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"file1": buffer_1_text,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer(path!("/root/file1"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
one
|
||||
two
|
||||
-three
|
||||
+3
|
||||
four
|
||||
five
|
||||
"#};
|
||||
|
||||
let final_text = indoc! {r#"
|
||||
one
|
||||
two
|
||||
three
|
||||
four
|
||||
five
|
||||
one
|
||||
two
|
||||
3
|
||||
four
|
||||
five
|
||||
"#};
|
||||
|
||||
apply_diff(diff, &project, &mut cx.to_async())
|
||||
.await
|
||||
.expect_err("Non-unique edits should fail");
|
||||
|
||||
let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0))
|
||||
..buffer_snapshot.anchor_after(buffer_snapshot.max_point())];
|
||||
|
||||
let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges)))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id());
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
buffer.edit(edits, None, cx);
|
||||
assert_eq!(buffer.text(), final_text);
|
||||
});
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx);
|
||||
|
||||
let buffer_1_text = indoc! {r#"
|
||||
one two three four
|
||||
five six seven eight
|
||||
nine ten eleven twelve
|
||||
"# };
|
||||
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"file1": buffer_1_text,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer(path!("/root/file1"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let diff = indoc! {r#"
|
||||
--- a/root/file1
|
||||
+++ b/root/file1
|
||||
one two three four
|
||||
-five six seven eight
|
||||
+five SIX seven eight!
|
||||
nine ten eleven twelve
|
||||
"#};
|
||||
|
||||
let (buffer, edits) = parse_diff(diff, |_path| {
|
||||
Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let edits = edits
|
||||
.into_iter()
|
||||
.map(|(range, text)| (range.to_point(&buffer), text))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
edits,
|
||||
&[
|
||||
(Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
|
||||
(Point::new(1, 20)..Point::new(1, 20), "!".into())
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx);
|
||||
|
||||
@@ -1,637 +0,0 @@
|
||||
use anyhow::{Context as _, Result};
|
||||
use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
|
||||
use std::{cmp, ops::Range, path::Path, sync::Arc};
|
||||
|
||||
/// Name of the outer element that wraps all replacements.
const EDITS_TAG_NAME: &str = "edits";
/// Name of the element holding the text to be replaced.
const OLD_TEXT_TAG_NAME: &str = "old_text";
/// Name of the element holding the replacement text.
const NEW_TEXT_TAG_NAME: &str = "new_text";
/// Every tag name the edit parser treats as structural.
const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];
|
||||
|
||||
pub async fn parse_xml_edits<'a>(
|
||||
input: &'a str,
|
||||
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
|
||||
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
|
||||
parse_xml_edits_inner(input, get_buffer)
|
||||
.await
|
||||
.with_context(|| format!("Failed to parse XML edits:\n{input}"))
|
||||
}
|
||||
|
||||
async fn parse_xml_edits_inner<'a>(
|
||||
input: &'a str,
|
||||
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
|
||||
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
|
||||
let xml_edits = extract_xml_replacements(input)?;
|
||||
|
||||
let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
|
||||
.with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;
|
||||
|
||||
let mut all_edits = vec![];
|
||||
for (old_text, new_text) in xml_edits.replacements {
|
||||
let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
|
||||
let matched_old_text = buffer
|
||||
.text_for_range(match_range.clone())
|
||||
.collect::<String>();
|
||||
let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
|
||||
all_edits.extend(
|
||||
edits_within_hunk
|
||||
.into_iter()
|
||||
.map(move |(inner_range, inner_text)| {
|
||||
(
|
||||
buffer.anchor_after(match_range.start + inner_range.start)
|
||||
..buffer.anchor_before(match_range.start + inner_range.end),
|
||||
inner_text,
|
||||
)
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
Ok((buffer, all_edits))
|
||||
}
|
||||
|
||||
fn fuzzy_match_in_ranges(
|
||||
old_text: &str,
|
||||
buffer: &BufferSnapshot,
|
||||
context_ranges: &[Range<Anchor>],
|
||||
) -> Result<Range<usize>> {
|
||||
let mut state = FuzzyMatcher::new(buffer, old_text);
|
||||
let mut best_match = None;
|
||||
let mut tie_match_range = None;
|
||||
|
||||
for range in context_ranges {
|
||||
let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
|
||||
match (best_match_cost, state.match_range(range.to_offset(buffer))) {
|
||||
(Some(lowest_cost), Some((new_cost, new_range))) => {
|
||||
if new_cost == lowest_cost {
|
||||
tie_match_range = Some(new_range);
|
||||
} else if new_cost < lowest_cost {
|
||||
tie_match_range.take();
|
||||
best_match = Some((new_cost, new_range));
|
||||
}
|
||||
}
|
||||
(None, Some(new_match)) => {
|
||||
best_match = Some(new_match);
|
||||
}
|
||||
(None, None) | (Some(_), None) => {}
|
||||
};
|
||||
}
|
||||
|
||||
if let Some((_, best_match_range)) = best_match {
|
||||
if let Some(tie_match_range) = tie_match_range {
|
||||
anyhow::bail!(
|
||||
"Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
|
||||
best_match_range.clone(),
|
||||
buffer.text_for_range(best_match_range).collect::<String>(),
|
||||
tie_match_range.clone(),
|
||||
buffer.text_for_range(tie_match_range).collect::<String>()
|
||||
);
|
||||
}
|
||||
return Ok(best_match_range);
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
|
||||
old_text,
|
||||
context_ranges
|
||||
.iter()
|
||||
.map(|range| buffer.text_for_range(range.clone()).collect::<String>())
|
||||
.collect::<Vec<String>>()
|
||||
.join("```\n```")
|
||||
);
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct XmlEdits<'a> {
|
||||
file_path: &'a str,
|
||||
/// Vec of (old_text, new_text) pairs
|
||||
replacements: Vec<(&'a str, &'a str)>,
|
||||
}
|
||||
|
||||
fn extract_xml_replacements(input: &str) -> Result<XmlEdits<'_>> {
|
||||
let mut cursor = 0;
|
||||
|
||||
let (edits_body_start, edits_attrs) =
|
||||
find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;
|
||||
|
||||
let file_path = edits_attrs
|
||||
.trim_start()
|
||||
.strip_prefix("path")
|
||||
.context("no path attribute on edits tag")?
|
||||
.trim_end()
|
||||
.strip_prefix('=')
|
||||
.context("no value for path attribute")?
|
||||
.trim()
|
||||
.trim_start_matches('"')
|
||||
.trim_end_matches('"');
|
||||
|
||||
cursor = edits_body_start;
|
||||
let mut edits_list = Vec::new();
|
||||
|
||||
while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
|
||||
let old_body_end = find_tag_close(input, &mut cursor)?;
|
||||
let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);
|
||||
|
||||
let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
|
||||
.context("no new_text tag following old_text")?;
|
||||
let new_body_end = find_tag_close(input, &mut cursor)?;
|
||||
let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);
|
||||
|
||||
edits_list.push((old_text, new_text));
|
||||
}
|
||||
|
||||
Ok(XmlEdits {
|
||||
file_path,
|
||||
replacements: edits_list,
|
||||
})
|
||||
}
|
||||
|
||||
/// Removes at most one `\n` from the front and at most one `\n` from the
/// back of `input`; any further newlines are kept.
fn trim_surrounding_newlines(input: &str) -> &str {
    let without_leading = match input.strip_prefix('\n') {
        Some(rest) => rest,
        None => input,
    };
    match without_leading.strip_suffix('\n') {
        Some(rest) => rest,
        None => without_leading,
    }
}
|
||||
|
||||
fn find_tag_open<'a>(
|
||||
input: &'a str,
|
||||
cursor: &mut usize,
|
||||
expected_tag: &str,
|
||||
) -> Result<Option<(usize, &'a str)>> {
|
||||
let mut search_pos = *cursor;
|
||||
|
||||
while search_pos < input.len() {
|
||||
let Some(tag_start) = input[search_pos..].find("<") else {
|
||||
break;
|
||||
};
|
||||
let tag_start = search_pos + tag_start;
|
||||
if !input[tag_start + 1..].starts_with(expected_tag) {
|
||||
search_pos = search_pos + tag_start + 1;
|
||||
continue;
|
||||
};
|
||||
|
||||
let after_tag_name = tag_start + expected_tag.len() + 1;
|
||||
let close_bracket = input[after_tag_name..]
|
||||
.find('>')
|
||||
.with_context(|| format!("missing > after <{}", expected_tag))?;
|
||||
let attrs_end = after_tag_name + close_bracket;
|
||||
let body_start = attrs_end + 1;
|
||||
|
||||
let attributes = input[after_tag_name..attrs_end].trim();
|
||||
*cursor = body_start;
|
||||
|
||||
return Ok(Some((body_start, attributes)));
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
fn find_tag_close(input: &str, cursor: &mut usize) -> Result<usize> {
|
||||
let mut depth = 1;
|
||||
let mut search_pos = *cursor;
|
||||
|
||||
while search_pos < input.len() && depth > 0 {
|
||||
let Some(bracket_offset) = input[search_pos..].find('<') else {
|
||||
break;
|
||||
};
|
||||
let bracket_pos = search_pos + bracket_offset;
|
||||
|
||||
if input[bracket_pos..].starts_with("</")
|
||||
&& let Some(close_end) = input[bracket_pos + 2..].find('>')
|
||||
{
|
||||
let close_start = bracket_pos + 2;
|
||||
let tag_name = input[close_start..close_start + close_end].trim();
|
||||
|
||||
if XML_TAGS.contains(&tag_name) {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
*cursor = close_start + close_end + 1;
|
||||
return Ok(bracket_pos);
|
||||
}
|
||||
}
|
||||
search_pos = close_start + close_end + 1;
|
||||
continue;
|
||||
} else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') {
|
||||
let close_bracket_pos = bracket_pos + close_bracket_offset;
|
||||
let tag_name = &input[bracket_pos + 1..close_bracket_pos].trim();
|
||||
if XML_TAGS.contains(&tag_name) {
|
||||
depth += 1;
|
||||
}
|
||||
}
|
||||
|
||||
search_pos = bracket_pos + 1;
|
||||
}
|
||||
|
||||
anyhow::bail!("no closing tag found")
|
||||
}
|
||||
|
||||
/// Cost of pairing a query line with a buffer line that is only fuzzily equal.
const REPLACEMENT_COST: u32 = 1;
/// Cost of stepping over a buffer line without consuming a query line.
const INSERTION_COST: u32 = 3;
/// Cost of stepping over a query line without consuming a buffer line.
const DELETION_COST: u32 = 10;
|
||||
|
||||
/// A fuzzy matcher that can process text chunks incrementally
|
||||
/// and return the best match found so far at each step.
|
||||
struct FuzzyMatcher<'a> {
|
||||
snapshot: &'a BufferSnapshot,
|
||||
query_lines: Vec<&'a str>,
|
||||
matrix: SearchMatrix,
|
||||
}
|
||||
|
||||
impl<'a> FuzzyMatcher<'a> {
|
||||
fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
|
||||
let query_lines = old_text.lines().collect();
|
||||
Self {
|
||||
snapshot,
|
||||
query_lines,
|
||||
matrix: SearchMatrix::new(0),
|
||||
}
|
||||
}
|
||||
|
||||
fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
|
||||
let point_range = range.to_point(&self.snapshot);
|
||||
let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;
|
||||
|
||||
self.matrix
|
||||
.reset(self.query_lines.len() + 1, buffer_line_count + 1);
|
||||
let query_line_count = self.query_lines.len();
|
||||
|
||||
for row in 0..query_line_count {
|
||||
let query_line = self.query_lines[row].trim();
|
||||
let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
|
||||
|
||||
self.matrix.set(
|
||||
row + 1,
|
||||
0,
|
||||
SearchState::new(leading_deletion_cost, SearchDirection::Up),
|
||||
);
|
||||
|
||||
let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();
|
||||
|
||||
let mut col = 0;
|
||||
while let Some(buffer_line) = buffer_lines.next() {
|
||||
let buffer_line = buffer_line.trim();
|
||||
let up = SearchState::new(
|
||||
self.matrix
|
||||
.get(row, col + 1)
|
||||
.cost
|
||||
.saturating_add(DELETION_COST),
|
||||
SearchDirection::Up,
|
||||
);
|
||||
let left = SearchState::new(
|
||||
self.matrix
|
||||
.get(row + 1, col)
|
||||
.cost
|
||||
.saturating_add(INSERTION_COST),
|
||||
SearchDirection::Left,
|
||||
);
|
||||
let diagonal = SearchState::new(
|
||||
if query_line == buffer_line {
|
||||
self.matrix.get(row, col).cost
|
||||
} else if fuzzy_eq(query_line, buffer_line) {
|
||||
self.matrix.get(row, col).cost + REPLACEMENT_COST
|
||||
} else {
|
||||
self.matrix
|
||||
.get(row, col)
|
||||
.cost
|
||||
.saturating_add(DELETION_COST + INSERTION_COST)
|
||||
},
|
||||
SearchDirection::Diagonal,
|
||||
);
|
||||
self.matrix
|
||||
.set(row + 1, col + 1, up.min(left).min(diagonal));
|
||||
col += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Find all matches with the best cost
|
||||
let mut best_cost = u32::MAX;
|
||||
let mut matches_with_best_cost = Vec::new();
|
||||
|
||||
for col in 1..=buffer_line_count {
|
||||
let cost = self.matrix.get(query_line_count, col).cost;
|
||||
if cost < best_cost {
|
||||
best_cost = cost;
|
||||
matches_with_best_cost.clear();
|
||||
matches_with_best_cost.push(col as u32);
|
||||
} else if cost == best_cost {
|
||||
matches_with_best_cost.push(col as u32);
|
||||
}
|
||||
}
|
||||
|
||||
// Find ranges for the matches
|
||||
for &match_end_col in &matches_with_best_cost {
|
||||
let mut matched_lines = 0;
|
||||
let mut query_row = query_line_count;
|
||||
let mut match_start_col = match_end_col;
|
||||
while query_row > 0 && match_start_col > 0 {
|
||||
let current = self.matrix.get(query_row, match_start_col as usize);
|
||||
match current.direction {
|
||||
SearchDirection::Diagonal => {
|
||||
query_row -= 1;
|
||||
match_start_col -= 1;
|
||||
matched_lines += 1;
|
||||
}
|
||||
SearchDirection::Up => {
|
||||
query_row -= 1;
|
||||
}
|
||||
SearchDirection::Left => {
|
||||
match_start_col -= 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let buffer_row_start = match_start_col + point_range.start.row;
|
||||
let buffer_row_end = match_end_col + point_range.start.row;
|
||||
|
||||
let matched_buffer_row_count = buffer_row_end - buffer_row_start;
|
||||
let matched_ratio = matched_lines as f32
|
||||
/ (matched_buffer_row_count as f32).max(query_line_count as f32);
|
||||
if matched_ratio >= 0.8 {
|
||||
let buffer_start_ix = self
|
||||
.snapshot
|
||||
.point_to_offset(Point::new(buffer_row_start, 0));
|
||||
let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
|
||||
buffer_row_end - 1,
|
||||
self.snapshot.line_len(buffer_row_end - 1),
|
||||
));
|
||||
return Some((best_cost, buffer_start_ix..buffer_end_ix));
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn fuzzy_eq(left: &str, right: &str) -> bool {
|
||||
const THRESHOLD: f64 = 0.8;
|
||||
|
||||
let min_levenshtein = left.len().abs_diff(right.len());
|
||||
let min_normalized_levenshtein =
|
||||
1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
|
||||
if min_normalized_levenshtein < THRESHOLD {
|
||||
return false;
|
||||
}
|
||||
|
||||
strsim::normalized_levenshtein(left, right) >= THRESHOLD
|
||||
}
|
||||
|
||||
/// The neighbouring matrix cell a search state was derived from.
///
/// The declaration order is significant: the derived `Ord` ranks
/// `Up < Left < Diagonal`, and that ordering participates in any comparison
/// of values that embed a direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
    /// Derived from the cell in the previous row.
    Up,
    /// Derived from the cell in the previous column.
    Left,
    /// Derived from the cell in the previous row and previous column.
    Diagonal,
}
|
||||
|
||||
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
|
||||
struct SearchState {
|
||||
cost: u32,
|
||||
direction: SearchDirection,
|
||||
}
|
||||
|
||||
impl SearchState {
|
||||
fn new(cost: u32, direction: SearchDirection) -> Self {
|
||||
Self { cost, direction }
|
||||
}
|
||||
}
|
||||
|
||||
struct SearchMatrix {
|
||||
cols: usize,
|
||||
rows: usize,
|
||||
data: Vec<SearchState>,
|
||||
}
|
||||
|
||||
impl SearchMatrix {
|
||||
fn new(cols: usize) -> Self {
|
||||
SearchMatrix {
|
||||
cols,
|
||||
rows: 0,
|
||||
data: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn reset(&mut self, rows: usize, cols: usize) {
|
||||
self.rows = rows;
|
||||
self.cols = cols;
|
||||
self.data
|
||||
.fill(SearchState::new(0, SearchDirection::Diagonal));
|
||||
self.data.resize(
|
||||
self.rows * self.cols,
|
||||
SearchState::new(0, SearchDirection::Diagonal),
|
||||
);
|
||||
}
|
||||
|
||||
fn get(&self, row: usize, col: usize) -> SearchState {
|
||||
debug_assert!(row < self.rows);
|
||||
debug_assert!(col < self.cols);
|
||||
self.data[row * self.cols + col]
|
||||
}
|
||||
|
||||
fn set(&mut self, row: usize, col: usize, state: SearchState) {
|
||||
debug_assert!(row < self.rows && col < self.cols);
|
||||
self.data[row * self.cols + col] = state;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use gpui::TestAppContext;
|
||||
use indoc::indoc;
|
||||
use language::Point;
|
||||
use project::{FakeFs, Project};
|
||||
use serde_json::json;
|
||||
use settings::SettingsStore;
|
||||
use util::path;
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
old content
|
||||
</old_text>
|
||||
<new_text>
|
||||
new content
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "test.rs");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "old content");
|
||||
assert_eq!(result.replacements[0].1, "new content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_with_wrong_closing_tags() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
old content
|
||||
</new_text>
|
||||
<new_text>
|
||||
new content
|
||||
</old_text>
|
||||
</ edits >
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "test.rs");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "old content");
|
||||
assert_eq!(result.replacements[0].1, "new content");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_with_xml_like_content() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="component.tsx">
|
||||
<old_text>
|
||||
<foo><bar></bar></foo>
|
||||
</old_text>
|
||||
<new_text>
|
||||
<foo><bar><baz></baz></bar></foo>
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "component.tsx");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "<foo><bar></bar></foo>");
|
||||
assert_eq!(
|
||||
result.replacements[0].1,
|
||||
"<foo><bar><baz></baz></bar></foo>"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_with_conflicting_content() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="component.tsx">
|
||||
<old_text>
|
||||
<new_text></new_text>
|
||||
</old_text>
|
||||
<new_text>
|
||||
<old_text></old_text>
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "component.tsx");
|
||||
assert_eq!(result.replacements.len(), 1);
|
||||
assert_eq!(result.replacements[0].0, "<new_text></new_text>");
|
||||
assert_eq!(result.replacements[0].1, "<old_text></old_text>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_multiple_pairs() {
|
||||
let input = indoc! {r#"
|
||||
Some reasoning before edits. Lots of thinking going on here
|
||||
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
first old
|
||||
</old_text>
|
||||
<new_text>
|
||||
first new
|
||||
</new_text>
|
||||
<old_text>
|
||||
second old
|
||||
</edits>
|
||||
<new_text>
|
||||
second new
|
||||
</old_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let result = extract_xml_replacements(input).unwrap();
|
||||
assert_eq!(result.file_path, "test.rs");
|
||||
assert_eq!(result.replacements.len(), 2);
|
||||
assert_eq!(result.replacements[0].0, "first old");
|
||||
assert_eq!(result.replacements[0].1, "first new");
|
||||
assert_eq!(result.replacements[1].0, "second old");
|
||||
assert_eq!(result.replacements[1].1, "second new");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_xml_edits_unexpected_eof() {
|
||||
let input = indoc! {r#"
|
||||
<edits path="test.rs">
|
||||
<old_text>
|
||||
first old
|
||||
</
|
||||
"#};
|
||||
|
||||
extract_xml_replacements(input).expect_err("Unexpected end of file");
|
||||
}
|
||||
|
||||
#[gpui::test]
|
||||
async fn test_parse_xml_edits(cx: &mut TestAppContext) {
|
||||
let fs = init_test(cx);
|
||||
|
||||
let buffer_1_text = indoc! {r#"
|
||||
one two three four
|
||||
five six seven eight
|
||||
nine ten eleven twelve
|
||||
thirteen fourteen fifteen
|
||||
sixteen seventeen eighteen
|
||||
"#};
|
||||
|
||||
fs.insert_tree(
|
||||
path!("/root"),
|
||||
json!({
|
||||
"file1": buffer_1_text,
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_local_buffer(path!("/root/file1"), cx)
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
|
||||
|
||||
let edits = indoc! {r#"
|
||||
<edits path="root/file1">
|
||||
<old_text>
|
||||
nine ten eleven twelve
|
||||
</old_text>
|
||||
<new_text>
|
||||
nine TEN eleven twelve!
|
||||
</new_text>
|
||||
</edits>
|
||||
"#};
|
||||
|
||||
let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
|
||||
let (buffer, edits) = parse_xml_edits(edits, |_path| {
|
||||
Some((&buffer_snapshot, included_ranges.as_slice()))
|
||||
})
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let edits = edits
|
||||
.into_iter()
|
||||
.map(|(range, text)| (range.to_point(&buffer), text))
|
||||
.collect::<Vec<_>>();
|
||||
assert_eq!(
|
||||
edits,
|
||||
&[
|
||||
(Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
|
||||
(Point::new(2, 22)..Point::new(2, 22), "!".into())
|
||||
]
|
||||
);
|
||||
}
|
||||
|
||||
fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
|
||||
cx.update(|cx| {
|
||||
let settings_store = SettingsStore::test(cx);
|
||||
cx.set_global(settings_store);
|
||||
});
|
||||
|
||||
FakeFs::new(cx.background_executor.clone())
|
||||
}
|
||||
}
|
||||
@@ -125,14 +125,15 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx)
|
||||
&& let BufferEditPrediction::Local { prediction } = current
|
||||
&& prediction.interpolate(buffer.read(cx)).is_some()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
self.store.update(cx, |store, cx| {
|
||||
if let Some(current) =
|
||||
store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
|
||||
&& let BufferEditPrediction::Local { prediction } = current
|
||||
&& prediction.interpolate(buffer.read(cx)).is_some()
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
store.refresh_context(&self.project, &buffer, cursor_position, cx);
|
||||
store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
|
||||
});
|
||||
@@ -171,69 +172,68 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
|
||||
cursor_position: language::Anchor,
|
||||
cx: &mut Context<Self>,
|
||||
) -> Option<edit_prediction_types::EditPrediction> {
|
||||
let prediction =
|
||||
self.store
|
||||
.read(cx)
|
||||
.current_prediction_for_buffer(buffer, &self.project, cx)?;
|
||||
self.store.update(cx, |store, cx| {
|
||||
let prediction =
|
||||
store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
|
||||
|
||||
let prediction = match prediction {
|
||||
BufferEditPrediction::Local { prediction } => prediction,
|
||||
BufferEditPrediction::Jump { prediction } => {
|
||||
return Some(edit_prediction_types::EditPrediction::Jump {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
snapshot: prediction.snapshot.clone(),
|
||||
target: prediction.edits.first().unwrap().0.start,
|
||||
});
|
||||
}
|
||||
};
|
||||
let prediction = match prediction {
|
||||
BufferEditPrediction::Local { prediction } => prediction,
|
||||
BufferEditPrediction::Jump { prediction } => {
|
||||
return Some(edit_prediction_types::EditPrediction::Jump {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
snapshot: prediction.snapshot.clone(),
|
||||
target: prediction.edits.first().unwrap().0.start,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let buffer = buffer.read(cx);
|
||||
let snapshot = buffer.snapshot();
|
||||
let buffer = buffer.read(cx);
|
||||
let snapshot = buffer.snapshot();
|
||||
|
||||
let Some(edits) = prediction.interpolate(&snapshot) else {
|
||||
self.store.update(cx, |store, _cx| {
|
||||
let Some(edits) = prediction.interpolate(&snapshot) else {
|
||||
store.reject_current_prediction(
|
||||
EditPredictionRejectReason::InterpolatedEmpty,
|
||||
&self.project,
|
||||
);
|
||||
});
|
||||
return None;
|
||||
};
|
||||
return None;
|
||||
};
|
||||
|
||||
let cursor_row = cursor_position.to_point(&snapshot).row;
|
||||
let (closest_edit_ix, (closest_edit_range, _)) =
|
||||
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
|
||||
let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
|
||||
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
|
||||
cmp::min(distance_from_start, distance_from_end)
|
||||
})?;
|
||||
let cursor_row = cursor_position.to_point(&snapshot).row;
|
||||
let (closest_edit_ix, (closest_edit_range, _)) =
|
||||
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
|
||||
let distance_from_start =
|
||||
cursor_row.abs_diff(range.start.to_point(&snapshot).row);
|
||||
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
|
||||
cmp::min(distance_from_start, distance_from_end)
|
||||
})?;
|
||||
|
||||
let mut edit_start_ix = closest_edit_ix;
|
||||
for (range, _) in edits[..edit_start_ix].iter().rev() {
|
||||
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
|
||||
- range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_start_ix -= 1;
|
||||
} else {
|
||||
break;
|
||||
let mut edit_start_ix = closest_edit_ix;
|
||||
for (range, _) in edits[..edit_start_ix].iter().rev() {
|
||||
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
|
||||
- range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_start_ix -= 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut edit_end_ix = closest_edit_ix + 1;
|
||||
for (range, _) in &edits[edit_end_ix..] {
|
||||
let distance_from_closest_edit =
|
||||
range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_end_ix += 1;
|
||||
} else {
|
||||
break;
|
||||
let mut edit_end_ix = closest_edit_ix + 1;
|
||||
for (range, _) in &edits[edit_end_ix..] {
|
||||
let distance_from_closest_edit = range.start.to_point(buffer).row
|
||||
- closest_edit_range.end.to_point(&snapshot).row;
|
||||
if distance_from_closest_edit <= 1 {
|
||||
edit_end_ix += 1;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(edit_prediction_types::EditPrediction::Local {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
edit_preview: Some(prediction.edit_preview.clone()),
|
||||
Some(edit_prediction_types::EditPrediction::Local {
|
||||
id: Some(prediction.id.to_string().into()),
|
||||
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
|
||||
edit_preview: Some(prediction.edit_preview.clone()),
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +1,23 @@
|
||||
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
|
||||
|
||||
use crate::{
|
||||
EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
|
||||
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
|
||||
EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
|
||||
cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
|
||||
prediction::{EditPredictionInputs, EditPredictionResult},
|
||||
prediction::EditPredictionResult,
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use cloud_llm_client::{
|
||||
PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
|
||||
predict_edits_v3::Event,
|
||||
};
|
||||
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
|
||||
use language::{
|
||||
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
|
||||
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
|
||||
};
|
||||
use project::{Project, ProjectPath};
|
||||
use release_channel::AppVersion;
|
||||
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
|
||||
use zeta_prompt::{Event, ZetaPromptInput};
|
||||
|
||||
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
|
||||
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
|
||||
@@ -29,24 +30,27 @@ pub(crate) const MAX_EVENT_TOKENS: usize = 500;
|
||||
|
||||
pub(crate) fn request_prediction_with_zeta1(
|
||||
store: &mut EditPredictionStore,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
snapshot: BufferSnapshot,
|
||||
position: language::Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
trigger: PredictEditsRequestTrigger,
|
||||
EditPredictionModelInput {
|
||||
project,
|
||||
buffer,
|
||||
snapshot,
|
||||
position,
|
||||
events,
|
||||
trigger,
|
||||
debug_tx,
|
||||
..
|
||||
}: EditPredictionModelInput,
|
||||
cx: &mut Context<EditPredictionStore>,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let buffer = buffer.clone();
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
let client = store.client.clone();
|
||||
let llm_token = store.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
|
||||
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
|
||||
let can_collect_file = store.can_collect_file(project, file, cx);
|
||||
let can_collect_file = store.can_collect_file(&project, file, cx);
|
||||
let git_info = if can_collect_file {
|
||||
git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
|
||||
git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@@ -120,33 +124,33 @@ pub(crate) fn request_prediction_with_zeta1(
|
||||
)
|
||||
.await;
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
let context_start_offset = context_range.start.to_offset(&snapshot);
|
||||
let editable_offset_range = editable_range.to_offset(&snapshot);
|
||||
|
||||
let inputs = ZetaPromptInput {
|
||||
events: included_events.into(),
|
||||
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
|
||||
path: full_path.clone(),
|
||||
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
|
||||
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
|
||||
start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
|
||||
text: snapshot
|
||||
.text_for_range(context_range)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
}],
|
||||
}],
|
||||
cursor_point: cloud_llm_client::predict_edits_v3::Point {
|
||||
column: cursor_point.column,
|
||||
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
|
||||
},
|
||||
related_files: vec![].into(),
|
||||
cursor_path: full_path,
|
||||
cursor_excerpt: snapshot
|
||||
.text_for_range(context_range)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
editable_range_in_excerpt: (editable_range.start - context_start_offset)
|
||||
..(editable_offset_range.end - context_start_offset),
|
||||
cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
|
||||
};
|
||||
|
||||
// let response = perform_predict_edits(PerformPredictEditsParams {
|
||||
// client,
|
||||
// llm_token,
|
||||
// app_version,
|
||||
// body,
|
||||
// })
|
||||
// .await;
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionStarted(
|
||||
EditPredictionStartedDebugEvent {
|
||||
buffer: buffer.downgrade(),
|
||||
prompt: Some(serde_json::to_string(&inputs).unwrap()),
|
||||
position,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let (response, usage) = match response {
|
||||
Ok(response) => response,
|
||||
@@ -189,6 +193,18 @@ pub(crate) fn request_prediction_with_zeta1(
|
||||
.ok();
|
||||
}
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionFinished(
|
||||
EditPredictionFinishedDebugEvent {
|
||||
buffer: buffer.downgrade(),
|
||||
model_output: Some(response.output_excerpt.clone()),
|
||||
position,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let edit_prediction = process_completion_response(
|
||||
response,
|
||||
buffer,
|
||||
@@ -226,7 +242,7 @@ fn process_completion_response(
|
||||
buffer: Entity<Buffer>,
|
||||
snapshot: &BufferSnapshot,
|
||||
editable_range: Range<usize>,
|
||||
inputs: EditPredictionInputs,
|
||||
inputs: ZetaPromptInput,
|
||||
buffer_snapshotted_at: Instant,
|
||||
received_response_at: Instant,
|
||||
cx: &AsyncApp,
|
||||
|
||||
@@ -3,46 +3,39 @@ use crate::EvalCacheEntryKind;
|
||||
use crate::open_ai_response::text_from_response;
|
||||
use crate::prediction::EditPredictionResult;
|
||||
use crate::{
|
||||
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
|
||||
EditPredictionRequestedDebugEvent, EditPredictionStore,
|
||||
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, EditPredictionId,
|
||||
EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
|
||||
};
|
||||
use anyhow::{Result, anyhow, bail};
|
||||
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
|
||||
use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
|
||||
use cloud_zeta2_prompt::CURSOR_MARKER;
|
||||
use edit_prediction_context::{EditPredictionExcerpt, Line};
|
||||
use edit_prediction_context::{RelatedExcerpt, RelatedFile};
|
||||
use futures::channel::oneshot;
|
||||
use gpui::{Entity, Task, prelude::*};
|
||||
use language::{Anchor, BufferSnapshot};
|
||||
use language::{Buffer, Point, ToOffset as _, ToPoint};
|
||||
use project::{Project, ProjectItem as _};
|
||||
use anyhow::{Result, anyhow};
|
||||
use cloud_llm_client::EditPredictionRejectReason;
|
||||
use gpui::{Task, prelude::*};
|
||||
use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
|
||||
use release_channel::AppVersion;
|
||||
use std::{
|
||||
env,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use std::{path::Path, sync::Arc, time::Instant};
|
||||
use zeta_prompt::CURSOR_MARKER;
|
||||
use zeta_prompt::format_zeta_prompt;
|
||||
|
||||
const MAX_CONTEXT_TOKENS: usize = 150;
|
||||
const MAX_REWRITE_TOKENS: usize = 350;
|
||||
|
||||
pub fn request_prediction_with_zeta2(
|
||||
store: &mut EditPredictionStore,
|
||||
project: &Entity<Project>,
|
||||
active_buffer: &Entity<Buffer>,
|
||||
active_snapshot: BufferSnapshot,
|
||||
position: Anchor,
|
||||
events: Vec<Arc<Event>>,
|
||||
mut included_files: Vec<RelatedFile>,
|
||||
trigger: PredictEditsRequestTrigger,
|
||||
EditPredictionModelInput {
|
||||
buffer,
|
||||
snapshot,
|
||||
position,
|
||||
related_files,
|
||||
events,
|
||||
debug_tx,
|
||||
..
|
||||
}: EditPredictionModelInput,
|
||||
cx: &mut Context<EditPredictionStore>,
|
||||
) -> Task<Result<Option<EditPredictionResult>>> {
|
||||
let options = store.options.clone();
|
||||
let buffer_snapshotted_at = Instant::now();
|
||||
|
||||
let Some((excerpt_path, active_project_path)) = active_snapshot
|
||||
let Some(excerpt_path) = snapshot
|
||||
.file()
|
||||
.map(|file| -> Arc<Path> { file.full_path(cx).into() })
|
||||
.zip(active_buffer.read(cx).project_path(cx))
|
||||
else {
|
||||
return Task::ready(Err(anyhow!("No file path for excerpt")));
|
||||
};
|
||||
@@ -50,148 +43,35 @@ pub fn request_prediction_with_zeta2(
|
||||
let client = store.client.clone();
|
||||
let llm_token = store.llm_token.clone();
|
||||
let app_version = AppVersion::global(cx);
|
||||
let debug_tx = store.debug_tx.clone();
|
||||
|
||||
let file = active_buffer.read(cx).file();
|
||||
|
||||
let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
|
||||
|
||||
// TODO data collection
|
||||
let can_collect_data = file
|
||||
.as_ref()
|
||||
.map_or(false, |file| store.can_collect_file(project, file, cx));
|
||||
|
||||
#[cfg(feature = "eval-support")]
|
||||
let eval_cache = store.eval_cache.clone();
|
||||
|
||||
let request_task = cx.background_spawn({
|
||||
let active_buffer = active_buffer.clone();
|
||||
async move {
|
||||
let cursor_offset = position.to_offset(&active_snapshot);
|
||||
let cursor_point = cursor_offset.to_point(&active_snapshot);
|
||||
|
||||
let before_retrieval = Instant::now();
|
||||
|
||||
let excerpt_options = options.context;
|
||||
|
||||
let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
|
||||
cursor_point,
|
||||
&active_snapshot,
|
||||
&excerpt_options,
|
||||
) else {
|
||||
return Ok((None, None));
|
||||
};
|
||||
|
||||
let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
|
||||
..active_snapshot.anchor_before(excerpt.range.end);
|
||||
let related_excerpt = RelatedExcerpt {
|
||||
anchor_range: excerpt_anchor_range.clone(),
|
||||
point_range: Point::new(excerpt.line_range.start.0, 0)
|
||||
..Point::new(excerpt.line_range.end.0, 0),
|
||||
text: active_snapshot.as_rope().slice(excerpt.range),
|
||||
};
|
||||
|
||||
if let Some(buffer_ix) = included_files
|
||||
.iter()
|
||||
.position(|file| file.buffer.entity_id() == active_buffer.entity_id())
|
||||
{
|
||||
let file = &mut included_files[buffer_ix];
|
||||
file.excerpts.push(related_excerpt);
|
||||
file.merge_excerpts();
|
||||
let last_ix = included_files.len() - 1;
|
||||
included_files.swap(buffer_ix, last_ix);
|
||||
} else {
|
||||
let active_file = RelatedFile {
|
||||
path: active_project_path,
|
||||
buffer: active_buffer.downgrade(),
|
||||
excerpts: vec![related_excerpt],
|
||||
max_row: active_snapshot.max_point().row,
|
||||
};
|
||||
included_files.push(active_file);
|
||||
}
|
||||
|
||||
let included_files = included_files
|
||||
.iter()
|
||||
.map(|related_file| predict_edits_v3::RelatedFile {
|
||||
path: Arc::from(related_file.path.path.as_std_path()),
|
||||
max_row: Line(related_file.max_row),
|
||||
excerpts: related_file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| predict_edits_v3::Excerpt {
|
||||
start_line: Line(excerpt.point_range.start.row),
|
||||
text: excerpt.text.to_string().into(),
|
||||
})
|
||||
.collect(),
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let cloud_request = predict_edits_v3::PredictEditsRequest {
|
||||
excerpt_path,
|
||||
excerpt: String::new(),
|
||||
excerpt_line_range: Line(0)..Line(0),
|
||||
excerpt_range: 0..0,
|
||||
cursor_point: predict_edits_v3::Point {
|
||||
line: predict_edits_v3::Line(cursor_point.row),
|
||||
column: cursor_point.column,
|
||||
},
|
||||
related_files: included_files,
|
||||
let cursor_offset = position.to_offset(&snapshot);
|
||||
let (editable_offset_range, prompt_input) = zeta2_prompt_input(
|
||||
&snapshot,
|
||||
related_files,
|
||||
events,
|
||||
can_collect_data,
|
||||
debug_info: debug_tx.is_some(),
|
||||
prompt_max_bytes: Some(options.max_prompt_bytes),
|
||||
prompt_format: options.prompt_format,
|
||||
excerpt_parent: None,
|
||||
git_info: None,
|
||||
trigger,
|
||||
};
|
||||
excerpt_path,
|
||||
cursor_offset,
|
||||
);
|
||||
|
||||
let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
|
||||
|
||||
let inputs = EditPredictionInputs {
|
||||
included_files: cloud_request.related_files,
|
||||
events: cloud_request.events,
|
||||
cursor_point: cloud_request.cursor_point,
|
||||
cursor_path: cloud_request.excerpt_path,
|
||||
};
|
||||
|
||||
let retrieval_time = Instant::now() - before_retrieval;
|
||||
|
||||
let debug_response_tx = if let Some(debug_tx) = &debug_tx {
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
let prompt = format_zeta_prompt(&prompt_input);
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionRequested(
|
||||
EditPredictionRequestedDebugEvent {
|
||||
inputs: inputs.clone(),
|
||||
retrieval_time,
|
||||
buffer: active_buffer.downgrade(),
|
||||
local_prompt: match prompt_result.as_ref() {
|
||||
Ok(prompt) => Ok(prompt.clone()),
|
||||
Err(err) => Err(err.to_string()),
|
||||
},
|
||||
.unbounded_send(DebugEvent::EditPredictionStarted(
|
||||
EditPredictionStartedDebugEvent {
|
||||
buffer: buffer.downgrade(),
|
||||
prompt: Some(prompt.clone()),
|
||||
position,
|
||||
response_rx,
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
Some(response_tx)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
|
||||
if let Some(debug_response_tx) = debug_response_tx {
|
||||
debug_response_tx
|
||||
.send((Err("Request skipped".to_string()), Duration::ZERO))
|
||||
.ok();
|
||||
}
|
||||
anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
|
||||
}
|
||||
|
||||
let prompt = prompt_result?;
|
||||
let generation_params =
|
||||
cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
|
||||
let request = open_ai::Request {
|
||||
model: EDIT_PREDICTIONS_MODEL_ID.clone(),
|
||||
messages: vec![open_ai::RequestMessage::User {
|
||||
@@ -199,8 +79,8 @@ pub fn request_prediction_with_zeta2(
|
||||
}],
|
||||
stream: false,
|
||||
max_completion_tokens: None,
|
||||
stop: generation_params.stop.unwrap_or_default(),
|
||||
temperature: generation_params.temperature.or(Some(0.7)),
|
||||
stop: Default::default(),
|
||||
temperature: Default::default(),
|
||||
tool_choice: None,
|
||||
parallel_tool_calls: None,
|
||||
tools: vec![],
|
||||
@@ -210,7 +90,6 @@ pub fn request_prediction_with_zeta2(
|
||||
|
||||
log::trace!("Sending edit prediction request");
|
||||
|
||||
let before_request = Instant::now();
|
||||
let response = EditPredictionStore::send_raw_llm_request(
|
||||
request,
|
||||
client,
|
||||
@@ -223,68 +102,53 @@ pub fn request_prediction_with_zeta2(
|
||||
)
|
||||
.await;
|
||||
let received_response_at = Instant::now();
|
||||
let request_time = received_response_at - before_request;
|
||||
|
||||
log::trace!("Got edit prediction response");
|
||||
|
||||
if let Some(debug_response_tx) = debug_response_tx {
|
||||
debug_response_tx
|
||||
.send((
|
||||
response
|
||||
.as_ref()
|
||||
.map_err(|err| err.to_string())
|
||||
.map(|response| response.0.clone()),
|
||||
request_time,
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
let (res, usage) = response?;
|
||||
let request_id = EditPredictionId(res.id.clone().into());
|
||||
let Some(mut output_text) = text_from_response(res) else {
|
||||
return Ok((Some((request_id, None)), usage));
|
||||
};
|
||||
|
||||
if let Some(debug_tx) = &debug_tx {
|
||||
debug_tx
|
||||
.unbounded_send(DebugEvent::EditPredictionFinished(
|
||||
EditPredictionFinishedDebugEvent {
|
||||
buffer: buffer.downgrade(),
|
||||
position,
|
||||
model_output: Some(output_text.clone()),
|
||||
},
|
||||
))
|
||||
.ok();
|
||||
}
|
||||
|
||||
if output_text.contains(CURSOR_MARKER) {
|
||||
log::trace!("Stripping out {CURSOR_MARKER} from response");
|
||||
output_text = output_text.replace(CURSOR_MARKER, "");
|
||||
}
|
||||
|
||||
let get_buffer_from_context = |path: &Path| {
|
||||
if Some(path) == active_file_full_path.as_deref() {
|
||||
Some((
|
||||
&active_snapshot,
|
||||
std::slice::from_ref(&excerpt_anchor_range),
|
||||
))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let (_, edits) = match options.prompt_format {
|
||||
PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
|
||||
if output_text.contains("--- a/\n+++ b/\nNo edits") {
|
||||
let edits = vec![];
|
||||
(&active_snapshot, edits)
|
||||
} else {
|
||||
crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
|
||||
}
|
||||
}
|
||||
PromptFormat::OldTextNewText => {
|
||||
crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
|
||||
}
|
||||
_ => {
|
||||
bail!("unsupported prompt format {}", options.prompt_format)
|
||||
}
|
||||
};
|
||||
let old_text = snapshot
|
||||
.text_for_range(editable_offset_range.clone())
|
||||
.collect::<String>();
|
||||
let edits: Vec<_> = language::text_diff(&old_text, &output_text)
|
||||
.into_iter()
|
||||
.map(|(range, text)| {
|
||||
(
|
||||
snapshot.anchor_after(editable_offset_range.start + range.start)
|
||||
..snapshot.anchor_before(editable_offset_range.start + range.end),
|
||||
text,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
anyhow::Ok((
|
||||
Some((
|
||||
request_id,
|
||||
Some((
|
||||
inputs,
|
||||
active_buffer,
|
||||
active_snapshot.clone(),
|
||||
prompt_input,
|
||||
buffer,
|
||||
snapshot.clone(),
|
||||
edits,
|
||||
received_response_at,
|
||||
)),
|
||||
@@ -325,3 +189,40 @@ pub fn request_prediction_with_zeta2(
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn zeta2_prompt_input(
|
||||
snapshot: &language::BufferSnapshot,
|
||||
related_files: Arc<[zeta_prompt::RelatedFile]>,
|
||||
events: Vec<Arc<zeta_prompt::Event>>,
|
||||
excerpt_path: Arc<Path>,
|
||||
cursor_offset: usize,
|
||||
) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
|
||||
let cursor_point = cursor_offset.to_point(snapshot);
|
||||
|
||||
let (editable_range, context_range) =
|
||||
crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
|
||||
cursor_point,
|
||||
snapshot,
|
||||
MAX_CONTEXT_TOKENS,
|
||||
MAX_REWRITE_TOKENS,
|
||||
);
|
||||
|
||||
let context_start_offset = context_range.start.to_offset(snapshot);
|
||||
let editable_offset_range = editable_range.to_offset(snapshot);
|
||||
let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
|
||||
let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
|
||||
..(editable_offset_range.end - context_start_offset);
|
||||
|
||||
let prompt_input = zeta_prompt::ZetaPromptInput {
|
||||
cursor_path: excerpt_path,
|
||||
cursor_excerpt: snapshot
|
||||
.text_for_range(context_range)
|
||||
.collect::<String>()
|
||||
.into(),
|
||||
editable_range_in_excerpt,
|
||||
cursor_offset_in_excerpt,
|
||||
events,
|
||||
related_files,
|
||||
};
|
||||
(editable_offset_range, prompt_input)
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
|
||||
workspace = true
|
||||
|
||||
[[bin]]
|
||||
name = "ep_cli"
|
||||
name = "ep"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
@@ -20,10 +20,9 @@ chrono.workspace = true
|
||||
clap.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace= true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
collections.workspace = true
|
||||
debug_adapter_extension.workspace = true
|
||||
edit_prediction_context.workspace = true
|
||||
dirs.workspace = true
|
||||
extension.workspace = true
|
||||
fs.workspace = true
|
||||
futures.workspace = true
|
||||
@@ -51,12 +50,21 @@ smol.workspace = true
|
||||
sqlez.workspace = true
|
||||
sqlez_macros.workspace = true
|
||||
terminal_view.workspace = true
|
||||
toml.workspace = true
|
||||
util.workspace = true
|
||||
watch.workspace = true
|
||||
edit_prediction = { workspace = true, features = ["eval-support"] }
|
||||
wasmtime.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
zlog.workspace = true
|
||||
|
||||
# Wasmtime is included as a dependency in order to enable the same
|
||||
# features that are enabled in Zed.
|
||||
#
|
||||
# If we don't enable these features we get crashes when creating
|
||||
# a Tree-sitter WasmStore.
|
||||
[package.metadata.cargo-machete]
|
||||
ignored = ["wasmtime"]
|
||||
|
||||
[dev-dependencies]
|
||||
indoc.workspace = true
|
||||
gpui = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -5,11 +5,13 @@ use anthropic::{
|
||||
use anyhow::Result;
|
||||
use http_client::HttpClient;
|
||||
use indoc::indoc;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use sqlez::bindable::Bind;
|
||||
use sqlez::bindable::StaticColumnCount;
|
||||
use sqlez_macros::sql;
|
||||
use std::hash::Hash;
|
||||
use std::hash::Hasher;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub struct PlainLlmClient {
|
||||
@@ -18,7 +20,8 @@ pub struct PlainLlmClient {
|
||||
}
|
||||
|
||||
impl PlainLlmClient {
|
||||
fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
fn new() -> Result<Self> {
|
||||
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
|
||||
Ok(Self {
|
||||
@@ -29,12 +32,12 @@ impl PlainLlmClient {
|
||||
|
||||
async fn generate(
|
||||
&self,
|
||||
model: String,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<AnthropicResponse> {
|
||||
let request = AnthropicRequest {
|
||||
model,
|
||||
model: model.to_string(),
|
||||
max_tokens,
|
||||
messages,
|
||||
tools: Vec::new(),
|
||||
@@ -105,11 +108,12 @@ struct SerializableMessage {
|
||||
}
|
||||
|
||||
impl BatchingLlmClient {
|
||||
fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
fn new(cache_path: &Path) -> Result<Self> {
|
||||
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
|
||||
let api_key = std::env::var("ANTHROPIC_API_KEY")
|
||||
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
|
||||
|
||||
let connection = sqlez::connection::Connection::open_file(&cache_path);
|
||||
let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
|
||||
let mut statement = sqlez::statement::Statement::prepare(
|
||||
&connection,
|
||||
indoc! {"
|
||||
@@ -182,16 +186,16 @@ impl BatchingLlmClient {
|
||||
|
||||
async fn generate(
|
||||
&self,
|
||||
model: String,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<Option<AnthropicResponse>> {
|
||||
let response = self.lookup(&model, max_tokens, &messages)?;
|
||||
let response = self.lookup(model, max_tokens, &messages)?;
|
||||
if let Some(response) = response {
|
||||
return Ok(Some(response));
|
||||
}
|
||||
|
||||
self.mark_for_batch(&model, max_tokens, &messages)?;
|
||||
self.mark_for_batch(model, max_tokens, &messages)?;
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
@@ -258,7 +262,7 @@ impl BatchingLlmClient {
|
||||
}
|
||||
}
|
||||
}
|
||||
log::info!("Uploaded {} successful requests", success_count);
|
||||
log::info!("Downloaded {} successful requests", success_count);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -363,23 +367,20 @@ fn message_content_to_string(content: &[RequestContent]) -> String {
|
||||
.join("\n")
|
||||
}
|
||||
|
||||
pub enum LlmClient {
|
||||
pub enum AnthropicClient {
|
||||
// No batching
|
||||
Plain(PlainLlmClient),
|
||||
Batch(BatchingLlmClient),
|
||||
Dummy,
|
||||
}
|
||||
|
||||
impl LlmClient {
|
||||
pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
Ok(Self::Plain(PlainLlmClient::new(http_client)?))
|
||||
impl AnthropicClient {
|
||||
pub fn plain() -> Result<Self> {
|
||||
Ok(Self::Plain(PlainLlmClient::new()?))
|
||||
}
|
||||
|
||||
pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
|
||||
Ok(Self::Batch(BatchingLlmClient::new(
|
||||
cache_path,
|
||||
http_client,
|
||||
)?))
|
||||
pub fn batch(cache_path: &Path) -> Result<Self> {
|
||||
Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
@@ -389,29 +390,29 @@ impl LlmClient {
|
||||
|
||||
pub async fn generate(
|
||||
&self,
|
||||
model: String,
|
||||
model: &str,
|
||||
max_tokens: u64,
|
||||
messages: Vec<Message>,
|
||||
) -> Result<Option<AnthropicResponse>> {
|
||||
match self {
|
||||
LlmClient::Plain(plain_llm_client) => plain_llm_client
|
||||
AnthropicClient::Plain(plain_llm_client) => plain_llm_client
|
||||
.generate(model, max_tokens, messages)
|
||||
.await
|
||||
.map(Some),
|
||||
LlmClient::Batch(batching_llm_client) => {
|
||||
AnthropicClient::Batch(batching_llm_client) => {
|
||||
batching_llm_client
|
||||
.generate(model, max_tokens, messages)
|
||||
.await
|
||||
}
|
||||
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn sync_batches(&self) -> Result<()> {
|
||||
match self {
|
||||
LlmClient::Plain(_) => Ok(()),
|
||||
LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
|
||||
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
AnthropicClient::Plain(_) => Ok(()),
|
||||
AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
|
||||
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,641 +0,0 @@
|
||||
use crate::metrics::{self, Scores};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
io::{IsTerminal, Write},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use anyhow::Result;
|
||||
use edit_prediction::{EditPredictionStore, udiff::DiffLine};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use project::Project;
|
||||
use util::ResultExt as _;
|
||||
|
||||
use crate::{
|
||||
EvaluateArguments, PredictionOptions,
|
||||
example::{Example, NamedExample},
|
||||
headless::ZetaCliAppState,
|
||||
paths::print_run_data_dir,
|
||||
predict::{PredictionDetails, perform_predict, setup_store},
|
||||
};
|
||||
|
||||
/// Per-execution data kept alongside the scores so that repeated runs can be
/// grouped and reported.
#[derive(Debug)]
pub(crate) struct ExecutionData {
    /// Zero-padded repetition index when the example is run repeatedly,
    /// otherwise the example's name.
    execution_id: String,
    /// The diff produced by the prediction.
    diff: String,
    /// Contents of the run's `prediction_response.md`, or empty if that file
    /// could not be read.
    reasoning: String,
}
|
||||
|
||||
pub async fn run_evaluate(
|
||||
args: EvaluateArguments,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
if args.example_paths.is_empty() {
|
||||
eprintln!("No examples provided");
|
||||
return;
|
||||
}
|
||||
|
||||
let all_tasks = args.example_paths.into_iter().map(|path| {
|
||||
let options = args.options.clone();
|
||||
let app_state = app_state.clone();
|
||||
let example = NamedExample::load(&path).expect("Failed to load example");
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let project = example.setup_project(&app_state, cx).await.unwrap();
|
||||
|
||||
let providers = (0..args.repetitions)
|
||||
.map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
|
||||
|
||||
let tasks = providers
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(move |(repetition_ix, store)| {
|
||||
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
|
||||
let example = example.clone();
|
||||
let project = project.clone();
|
||||
let options = options.clone();
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
let name = example.name.clone();
|
||||
run_evaluate_one(
|
||||
example,
|
||||
repetition_ix,
|
||||
project,
|
||||
store,
|
||||
options,
|
||||
!args.skip_prediction,
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
.map_err(|err| (err, name, repetition_ix))
|
||||
})
|
||||
});
|
||||
futures::future::join_all(tasks).await
|
||||
})
|
||||
});
|
||||
let all_results = futures::future::join_all(all_tasks).await;
|
||||
|
||||
write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
|
||||
if let Some(mut output_file) =
|
||||
std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
|
||||
{
|
||||
write_aggregated_scores(&mut output_file, &all_results).log_err();
|
||||
};
|
||||
|
||||
if args.repetitions > 1 {
|
||||
if let Err(e) = write_bucketed_analysis(&all_results) {
|
||||
eprintln!("Failed to write bucketed analysis: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
|
||||
}
|
||||
|
||||
fn write_aggregated_scores(
|
||||
w: &mut impl std::io::Write,
|
||||
all_results: &Vec<
|
||||
Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
|
||||
>,
|
||||
) -> Result<()> {
|
||||
let mut successful = Vec::new();
|
||||
let mut failed_count = 0;
|
||||
|
||||
for result in all_results.iter().flatten() {
|
||||
match result {
|
||||
Ok((eval_result, _execution_data)) => successful.push(eval_result),
|
||||
Err((err, name, repetition_ix)) => {
|
||||
if failed_count == 0 {
|
||||
writeln!(w, "## Errors\n")?;
|
||||
}
|
||||
|
||||
failed_count += 1;
|
||||
writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if successful.len() > 1 {
|
||||
let edit_scores = successful
|
||||
.iter()
|
||||
.filter_map(|r| r.edit_scores.clone())
|
||||
.collect::<Vec<_>>();
|
||||
let has_edit_predictions = edit_scores.len() > 0;
|
||||
let aggregated_result = EvaluationResult {
|
||||
context_scores: Scores::aggregate(successful.iter().map(|r| &r.context_scores)),
|
||||
edit_scores: has_edit_predictions.then(|| EditScores::aggregate(&edit_scores)),
|
||||
prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
|
||||
generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
|
||||
/ successful.len(),
|
||||
};
|
||||
|
||||
writeln!(w, "\n{}", "-".repeat(80))?;
|
||||
writeln!(w, "\n## TOTAL SCORES")?;
|
||||
writeln!(w, "{:#}", aggregated_result)?;
|
||||
}
|
||||
|
||||
if successful.len() + failed_count > 1 {
|
||||
writeln!(
|
||||
w,
|
||||
"\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
|
||||
successful.len(),
|
||||
successful.len() + failed_count,
|
||||
(successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn run_evaluate_one(
|
||||
example: NamedExample,
|
||||
repetition_ix: Option<u16>,
|
||||
project: Entity<Project>,
|
||||
store: Entity<EditPredictionStore>,
|
||||
prediction_options: PredictionOptions,
|
||||
predict: bool,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(EvaluationResult, ExecutionData)> {
|
||||
let predict_result = perform_predict(
|
||||
example.clone(),
|
||||
project,
|
||||
store,
|
||||
repetition_ix,
|
||||
prediction_options,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let evaluation_result = evaluate(&example.example, &predict_result, predict);
|
||||
|
||||
if repetition_ix.is_none() {
|
||||
write_eval_result(
|
||||
&example,
|
||||
&predict_result,
|
||||
&evaluation_result,
|
||||
&mut std::io::stdout(),
|
||||
std::io::stdout().is_terminal(),
|
||||
predict,
|
||||
)?;
|
||||
}
|
||||
|
||||
if let Some(mut results_file) =
|
||||
std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
|
||||
{
|
||||
write_eval_result(
|
||||
&example,
|
||||
&predict_result,
|
||||
&evaluation_result,
|
||||
&mut results_file,
|
||||
false,
|
||||
predict,
|
||||
)
|
||||
.log_err();
|
||||
}
|
||||
|
||||
let execution_data = ExecutionData {
|
||||
execution_id: if let Some(rep_ix) = repetition_ix {
|
||||
format!("{:03}", rep_ix)
|
||||
} else {
|
||||
example.name.clone()
|
||||
},
|
||||
diff: predict_result.diff.clone(),
|
||||
reasoning: std::fs::read_to_string(
|
||||
predict_result
|
||||
.run_example_dir
|
||||
.join("prediction_response.md"),
|
||||
)
|
||||
.unwrap_or_default(),
|
||||
};
|
||||
|
||||
anyhow::Ok((evaluation_result, execution_data))
|
||||
}
|
||||
|
||||
fn write_eval_result(
|
||||
example: &NamedExample,
|
||||
predictions: &PredictionDetails,
|
||||
evaluation_result: &EvaluationResult,
|
||||
out: &mut impl Write,
|
||||
use_color: bool,
|
||||
predict: bool,
|
||||
) -> Result<()> {
|
||||
if predict {
|
||||
writeln!(
|
||||
out,
|
||||
"## Expected edit prediction:\n\n```diff\n{}\n```\n",
|
||||
compare_diffs(
|
||||
&example.example.expected_patch,
|
||||
&predictions.diff,
|
||||
use_color
|
||||
)
|
||||
)?;
|
||||
writeln!(
|
||||
out,
|
||||
"## Actual edit prediction:\n\n```diff\n{}\n```\n",
|
||||
compare_diffs(
|
||||
&predictions.diff,
|
||||
&example.example.expected_patch,
|
||||
use_color
|
||||
)
|
||||
)?;
|
||||
}
|
||||
|
||||
writeln!(out, "{:#}", evaluation_result)?;
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Clone)]
|
||||
pub struct EditScores {
|
||||
pub line_match: Scores,
|
||||
pub chr_f: f64,
|
||||
}
|
||||
|
||||
impl EditScores {
|
||||
pub fn aggregate(scores: &[EditScores]) -> EditScores {
|
||||
let line_match = Scores::aggregate(scores.iter().map(|s| &s.line_match));
|
||||
let chr_f = scores.iter().map(|s| s.chr_f).sum::<f64>() / scores.len() as f64;
|
||||
|
||||
EditScores { line_match, chr_f }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct EvaluationResult {
|
||||
pub edit_scores: Option<EditScores>,
|
||||
pub context_scores: Scores,
|
||||
pub prompt_len: usize,
|
||||
pub generated_len: usize,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for EvaluationResult {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
if f.alternate() {
|
||||
self.fmt_table(f)
|
||||
} else {
|
||||
self.fmt_markdown(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl EvaluationResult {
|
||||
fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
r#"
|
||||
### Context Scores
|
||||
{}
|
||||
"#,
|
||||
self.context_scores.to_markdown(),
|
||||
)?;
|
||||
if let Some(scores) = &self.edit_scores {
|
||||
write!(
|
||||
f,
|
||||
r#"
|
||||
### Edit Prediction Scores
|
||||
{}"#,
|
||||
scores.line_match.to_markdown()
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
writeln!(f, "#### Prompt Statistics")?;
|
||||
writeln!(f, "─────────────────────────")?;
|
||||
writeln!(f, "Prompt_len Generated_len")?;
|
||||
writeln!(f, "─────────────────────────")?;
|
||||
writeln!(f, "{:<11} {:<14}", self.prompt_len, self.generated_len,)?;
|
||||
writeln!(f)?;
|
||||
writeln!(f)?;
|
||||
writeln!(f, "#### Performance Scores")?;
|
||||
writeln!(
|
||||
f,
|
||||
"──────────────────────────────────────────────────────────────────"
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
" TP FP FN Precision Recall F1"
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
"──────────────────────────────────────────────────────────────────"
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
"Context Retrieval {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
|
||||
self.context_scores.true_positives,
|
||||
self.context_scores.false_positives,
|
||||
self.context_scores.false_negatives,
|
||||
self.context_scores.precision() * 100.0,
|
||||
self.context_scores.recall() * 100.0,
|
||||
self.context_scores.f1_score() * 100.0
|
||||
)?;
|
||||
if let Some(edit_scores) = &self.edit_scores {
|
||||
let line_match = &edit_scores.line_match;
|
||||
writeln!(f, "Edit Prediction")?;
|
||||
writeln!(
|
||||
f,
|
||||
" ├─ exact lines {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
|
||||
line_match.true_positives,
|
||||
line_match.false_positives,
|
||||
line_match.false_negatives,
|
||||
line_match.precision() * 100.0,
|
||||
line_match.recall() * 100.0,
|
||||
line_match.f1_score() * 100.0
|
||||
)?;
|
||||
writeln!(
|
||||
f,
|
||||
" └─ diff chrF {:<6} {:<6} {:<6} {:>8} {:>8} {:>6.2}",
|
||||
"-", "-", "-", "-", "-", edit_scores.chr_f
|
||||
)?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
|
||||
let mut eval_result = EvaluationResult {
|
||||
prompt_len: preds.prompt_len,
|
||||
generated_len: preds.generated_len,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
if predict {
|
||||
// todo: alternatives for patches
|
||||
let expected_patch = example
|
||||
.expected_patch
|
||||
.lines()
|
||||
.map(DiffLine::parse)
|
||||
.collect::<Vec<_>>();
|
||||
let actual_patch = preds.diff.lines().map(DiffLine::parse).collect::<Vec<_>>();
|
||||
|
||||
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
|
||||
let chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch);
|
||||
|
||||
eval_result.edit_scores = Some(EditScores { line_match, chr_f });
|
||||
}
|
||||
|
||||
eval_result
|
||||
}
|
||||
|
||||
/// Return annotated `patch_a` so that:
|
||||
/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
|
||||
/// Additions and deletions that are present in `patch_b` will be highlighted in green.
|
||||
pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
|
||||
let green = if use_color { "\x1b[32m✓ " } else { "" };
|
||||
let red = if use_color { "\x1b[31m✗ " } else { "" };
|
||||
let neutral = if use_color { " " } else { "" };
|
||||
let reset = if use_color { "\x1b[0m" } else { "" };
|
||||
let lines_a = patch_a.lines().map(DiffLine::parse);
|
||||
let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
|
||||
|
||||
let annotated = lines_a
|
||||
.map(|line| match line {
|
||||
DiffLine::Addition(_) | DiffLine::Deletion(_) => {
|
||||
if lines_b.contains(&line) {
|
||||
format!("{green}{line}{reset}")
|
||||
} else {
|
||||
format!("{red}{line}{reset}")
|
||||
}
|
||||
}
|
||||
_ => format!("{neutral}{line}{reset}"),
|
||||
})
|
||||
.collect::<Vec<String>>();
|
||||
|
||||
annotated.join("\n")
|
||||
}
|
||||
|
||||
fn write_bucketed_analysis(
|
||||
all_results: &Vec<
|
||||
Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
|
||||
>,
|
||||
) -> Result<()> {
|
||||
#[derive(Debug)]
|
||||
struct EditBucket {
|
||||
diff: String,
|
||||
is_correct: bool,
|
||||
execution_indices: Vec<String>,
|
||||
reasoning_samples: Vec<String>,
|
||||
}
|
||||
|
||||
let mut total_executions = 0;
|
||||
let mut empty_predictions = Vec::new();
|
||||
let mut errors = Vec::new();
|
||||
|
||||
let mut buckets: HashMap<String, EditBucket> = HashMap::new();
|
||||
|
||||
for result in all_results.iter().flatten() {
|
||||
total_executions += 1;
|
||||
|
||||
let (evaluation_result, execution_data) = match result {
|
||||
Ok((eval_result, execution_data)) => {
|
||||
if execution_data.diff.is_empty() {
|
||||
empty_predictions.push(execution_data);
|
||||
continue;
|
||||
}
|
||||
(eval_result, execution_data)
|
||||
}
|
||||
Err(err) => {
|
||||
errors.push(err);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
buckets
|
||||
.entry(execution_data.diff.clone())
|
||||
.and_modify(|bucket| {
|
||||
bucket
|
||||
.execution_indices
|
||||
.push(execution_data.execution_id.clone());
|
||||
bucket
|
||||
.reasoning_samples
|
||||
.push(execution_data.reasoning.clone());
|
||||
})
|
||||
.or_insert_with(|| EditBucket {
|
||||
diff: execution_data.diff.clone(),
|
||||
is_correct: {
|
||||
evaluation_result
|
||||
.edit_scores
|
||||
.as_ref()
|
||||
.map_or(false, |edit_scores| {
|
||||
edit_scores.line_match.false_positives == 0
|
||||
&& edit_scores.line_match.false_negatives == 0
|
||||
&& edit_scores.line_match.true_positives > 0
|
||||
})
|
||||
},
|
||||
execution_indices: vec![execution_data.execution_id.clone()],
|
||||
reasoning_samples: vec![execution_data.reasoning.clone()],
|
||||
});
|
||||
}
|
||||
|
||||
let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
|
||||
sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
|
||||
(true, false) => std::cmp::Ordering::Less,
|
||||
(false, true) => std::cmp::Ordering::Greater,
|
||||
_ => b.execution_indices.len().cmp(&a.execution_indices.len()),
|
||||
});
|
||||
|
||||
let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
|
||||
let mut output = std::fs::File::create(&output_path)?;
|
||||
|
||||
writeln!(output, "# Bucketed Edit Analysis\n")?;
|
||||
|
||||
writeln!(output, "## Summary\n")?;
|
||||
writeln!(output, "- **Total executions**: {}", total_executions)?;
|
||||
|
||||
let correct_count: usize = sorted_buckets
|
||||
.iter()
|
||||
.filter(|b| b.is_correct)
|
||||
.map(|b| b.execution_indices.len())
|
||||
.sum();
|
||||
|
||||
let incorrect_count: usize = sorted_buckets
|
||||
.iter()
|
||||
.filter(|b| !b.is_correct)
|
||||
.map(|b| b.execution_indices.len())
|
||||
.sum();
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"- **Correct predictions**: {} ({:.1}%)",
|
||||
correct_count,
|
||||
(correct_count as f64 / total_executions as f64) * 100.0
|
||||
)?;
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"- **Incorrect predictions**: {} ({:.1}%)",
|
||||
incorrect_count,
|
||||
(incorrect_count as f64 / total_executions as f64) * 100.0
|
||||
)?;
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"- **No Predictions**: {} ({:.1}%)",
|
||||
empty_predictions.len(),
|
||||
(empty_predictions.len() as f64 / total_executions as f64) * 100.0
|
||||
)?;
|
||||
|
||||
let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
|
||||
writeln!(
|
||||
output,
|
||||
"- **Unique incorrect edit patterns**: {}\n",
|
||||
unique_incorrect
|
||||
)?;
|
||||
|
||||
writeln!(output, "---\n")?;
|
||||
|
||||
for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
|
||||
if idx == 0 {
|
||||
writeln!(
|
||||
output,
|
||||
"## Correct Predictions ({} occurrences)\n",
|
||||
bucket.execution_indices.len()
|
||||
)?;
|
||||
}
|
||||
|
||||
writeln!(output, "**Predicted Edit:**\n")?;
|
||||
writeln!(output, "```diff")?;
|
||||
writeln!(output, "{}", bucket.diff)?;
|
||||
writeln!(output, "```\n")?;
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"**Executions:** {}\n",
|
||||
bucket.execution_indices.join(", ")
|
||||
)?;
|
||||
writeln!(output, "---\n")?;
|
||||
}
|
||||
|
||||
for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
|
||||
writeln!(
|
||||
output,
|
||||
"## Incorrect Prediction #{} ({} occurrences)\n",
|
||||
idx + 1,
|
||||
bucket.execution_indices.len()
|
||||
)?;
|
||||
|
||||
writeln!(output, "**Predicted Edit:**\n")?;
|
||||
writeln!(output, "```diff")?;
|
||||
writeln!(output, "{}", bucket.diff)?;
|
||||
writeln!(output, "```\n")?;
|
||||
|
||||
writeln!(
|
||||
output,
|
||||
"**Executions:** {}\n",
|
||||
bucket.execution_indices.join(", ")
|
||||
)?;
|
||||
|
||||
for (exec_id, reasoning) in bucket
|
||||
.execution_indices
|
||||
.iter()
|
||||
.zip(bucket.reasoning_samples.iter())
|
||||
{
|
||||
writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
|
||||
}
|
||||
|
||||
writeln!(output, "\n---\n")?;
|
||||
}
|
||||
|
||||
if !empty_predictions.is_empty() {
|
||||
writeln!(
|
||||
output,
|
||||
"## No Predictions ({} occurrences)\n",
|
||||
empty_predictions.len()
|
||||
)?;
|
||||
|
||||
for execution_data in &empty_predictions {
|
||||
writeln!(
|
||||
output,
|
||||
"{}",
|
||||
fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
|
||||
)?;
|
||||
}
|
||||
writeln!(output, "\n---\n")?;
|
||||
}
|
||||
|
||||
if !errors.is_empty() {
|
||||
writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
|
||||
|
||||
for (err, name, repetition_ix) in &errors {
|
||||
writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
|
||||
}
|
||||
writeln!(output, "\n---\n")?;
|
||||
}
|
||||
|
||||
fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
|
||||
let exec_content = format!(
|
||||
"\n### Execution {} `{}/{}/prediction_response.md`{}",
|
||||
exec_id,
|
||||
crate::paths::RUN_DIR.display(),
|
||||
exec_id,
|
||||
indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
|
||||
);
|
||||
indent_text(&exec_content, 2)
|
||||
}
|
||||
|
||||
fn indent_text(text: &str, spaces: usize) -> String {
|
||||
let indent = " ".repeat(spaces);
|
||||
text.lines()
|
||||
.collect::<Vec<_>>()
|
||||
.join(&format!("\n{}", indent))
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
|
||||
let err = format!("{err:?}")
|
||||
.replace("<edits", "```xml\n<edits")
|
||||
.replace("</edits>", "</edits>\n```");
|
||||
format!(
|
||||
"### ERROR {name}{}\n\n{err}\n",
|
||||
repetition_ix
|
||||
.map(|ix| format!(" [RUN {ix:03}]"))
|
||||
.unwrap_or_default()
|
||||
)
|
||||
}
|
||||
@@ -1,59 +1,103 @@
|
||||
use crate::{
|
||||
PredictionProvider, PromptFormat,
|
||||
metrics::ClassificationMetrics,
|
||||
paths::{REPOS_DIR, WORKTREES_DIR},
|
||||
};
|
||||
use anyhow::{Context as _, Result};
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
use gpui::Entity;
|
||||
use http_client::Url;
|
||||
use language::{Anchor, Buffer};
|
||||
use project::Project;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
borrow::Cow,
|
||||
cell::RefCell,
|
||||
fmt::{self, Display},
|
||||
fs,
|
||||
hash::Hash,
|
||||
hash::Hasher,
|
||||
io::Write,
|
||||
io::{Read, Write},
|
||||
mem,
|
||||
path::{Path, PathBuf},
|
||||
sync::{Arc, OnceLock},
|
||||
};
|
||||
use zeta_prompt::RelatedFile;
|
||||
|
||||
use crate::headless::ZetaCliAppState;
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use clap::ValueEnum;
|
||||
use cloud_zeta2_prompt::CURSOR_MARKER;
|
||||
use collections::HashMap;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
use futures::{
|
||||
AsyncWriteExt as _,
|
||||
lock::{Mutex, OwnedMutexGuard},
|
||||
};
|
||||
use futures::{FutureExt as _, future::Shared};
|
||||
use gpui::{AsyncApp, Entity, Task, http_client::Url};
|
||||
use language::{Anchor, Buffer};
|
||||
use project::{Project, ProjectPath};
|
||||
use pulldown_cmark::CowStr;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
|
||||
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
|
||||
|
||||
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
|
||||
const EDIT_HISTORY_HEADING: &str = "Edit History";
|
||||
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
|
||||
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
|
||||
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
|
||||
const REPOSITORY_URL_FIELD: &str = "repository_url";
|
||||
const REVISION_FIELD: &str = "revision";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NamedExample {
|
||||
pub name: String,
|
||||
pub example: Example,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Example {
|
||||
#[serde(default)]
|
||||
pub name: String,
|
||||
pub repository_url: String,
|
||||
pub revision: String,
|
||||
pub uncommitted_diff: String,
|
||||
pub cursor_path: PathBuf,
|
||||
pub cursor_path: Arc<Path>,
|
||||
pub cursor_position: String,
|
||||
pub edit_history: String,
|
||||
pub expected_patch: String,
|
||||
|
||||
/// The full content of the file where an edit is being predicted, and the
|
||||
/// actual cursor offset.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub buffer: Option<ExampleBuffer>,
|
||||
|
||||
/// The context retrieved for the prediction. This requires the worktree to
|
||||
/// be loaded and the language server to be started.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub context: Option<ExampleContext>,
|
||||
|
||||
/// The input and expected output from the edit prediction model.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt: Option<ExamplePrompt>,
|
||||
|
||||
/// The actual predictions from the model.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub predictions: Vec<ExamplePrediction>,
|
||||
|
||||
/// The scores, for how well the actual predictions match the expected
|
||||
/// predictions.
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub score: Vec<ExampleScore>,
|
||||
|
||||
/// The application state used to process this example.
|
||||
#[serde(skip)]
|
||||
pub state: Option<ExampleState>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ExampleState {
|
||||
pub project: Entity<Project>,
|
||||
pub buffer: Entity<Buffer>,
|
||||
pub cursor_position: Anchor,
|
||||
pub _open_buffers: OpenedBuffers,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleContext {
|
||||
pub files: Arc<[RelatedFile]>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleBuffer {
|
||||
pub content: String,
|
||||
pub cursor_row: u32,
|
||||
pub cursor_column: u32,
|
||||
pub cursor_offset: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExamplePrompt {
|
||||
pub input: String,
|
||||
pub expected_output: String,
|
||||
pub format: PromptFormat,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExamplePrediction {
|
||||
pub actual_patch: String,
|
||||
pub actual_output: String,
|
||||
pub provider: PredictionProvider,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ExampleScore {
|
||||
pub delta_chr_f: f32,
|
||||
pub line_match: ClassificationMetrics,
|
||||
}
|
||||
|
||||
impl Example {
|
||||
@@ -90,485 +134,244 @@ impl Example {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
|
||||
let (repo_owner, repo_name) = self.repo_name()?;
|
||||
pub fn worktree_path(&self) -> PathBuf {
|
||||
WORKTREES_DIR
|
||||
.join(&self.name)
|
||||
.join(self.repo_name().unwrap().1.as_ref())
|
||||
}
|
||||
|
||||
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let repo_lock = lock_repo(&repo_dir).await;
|
||||
pub fn repo_path(&self) -> PathBuf {
|
||||
let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
|
||||
REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
|
||||
}
|
||||
}
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
fs::create_dir_all(&repo_dir)?;
|
||||
run_git(&repo_dir, &["init"]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &self.repository_url],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
|
||||
let mut examples = Vec::new();
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
let revision = run_git(
|
||||
&repo_dir,
|
||||
&["rev-parse", &format!("{}^{{commit}}", self.revision)],
|
||||
)
|
||||
.await;
|
||||
let revision = if let Ok(revision) = revision {
|
||||
revision
|
||||
let stdin_path: PathBuf = PathBuf::from("-");
|
||||
|
||||
let inputs = if inputs.is_empty() {
|
||||
&[stdin_path]
|
||||
} else {
|
||||
inputs
|
||||
};
|
||||
|
||||
for path in inputs {
|
||||
let is_stdin = path.as_path() == Path::new("-");
|
||||
let content = if is_stdin {
|
||||
let mut buffer = String::new();
|
||||
std::io::stdin()
|
||||
.read_to_string(&mut buffer)
|
||||
.expect("Failed to read from stdin");
|
||||
buffer
|
||||
} else {
|
||||
if run_git(
|
||||
&repo_dir,
|
||||
&["fetch", "--depth", "1", "origin", &self.revision],
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await?;
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
|
||||
if revision != self.revision {
|
||||
run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
|
||||
}
|
||||
revision
|
||||
std::fs::read_to_string(path)
|
||||
.unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
|
||||
};
|
||||
let filename = path.file_stem().unwrap().to_string_lossy().to_string();
|
||||
let ext = if !is_stdin {
|
||||
path.extension()
|
||||
.map(|ext| ext.to_string_lossy().to_string())
|
||||
.unwrap_or_else(|| panic!("{} should have an extension", path.display()))
|
||||
} else {
|
||||
"jsonl".to_string()
|
||||
};
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["worktree", "add", "-f", &worktree_path_string, &file_name],
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
drop(repo_lock);
|
||||
|
||||
// Apply the uncommitted diff for this example.
|
||||
if !self.uncommitted_diff.is_empty() {
|
||||
let mut apply_process = smol::process::Command::new("git")
|
||||
.current_dir(&worktree_path)
|
||||
.args(&["apply", "-"])
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.spawn()?;
|
||||
|
||||
let mut stdin = apply_process.stdin.take().unwrap();
|
||||
stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
|
||||
stdin.close().await?;
|
||||
drop(stdin);
|
||||
|
||||
let apply_result = apply_process.output().await?;
|
||||
if !apply_result.status.success() {
|
||||
anyhow::bail!(
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
match ext.as_ref() {
|
||||
"json" => {
|
||||
let mut example =
|
||||
serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
|
||||
panic!("Failed to parse example file: {}\n{error}", path.display())
|
||||
});
|
||||
if example.name.is_empty() {
|
||||
example.name = filename;
|
||||
}
|
||||
examples.push(example);
|
||||
}
|
||||
"jsonl" => examples.extend(
|
||||
content
|
||||
.lines()
|
||||
.enumerate()
|
||||
.map(|(line_ix, line)| {
|
||||
let mut example =
|
||||
serde_json::from_str::<Example>(line).unwrap_or_else(|_| {
|
||||
panic!(
|
||||
"Failed to parse example on {}:{}",
|
||||
path.display(),
|
||||
line_ix + 1
|
||||
)
|
||||
});
|
||||
if example.name.is_empty() {
|
||||
example.name = format!("{filename}-{line_ix}")
|
||||
}
|
||||
example
|
||||
})
|
||||
.collect::<Vec<Example>>(),
|
||||
),
|
||||
"md" => {
|
||||
examples.push(parse_markdown_example(filename, &content).unwrap());
|
||||
}
|
||||
ext => {
|
||||
panic!("{} has invalid example extension `{ext}`", path.display())
|
||||
}
|
||||
}
|
||||
|
||||
Ok(worktree_path)
|
||||
}
|
||||
examples
|
||||
}
|
||||
|
||||
pub fn unique_name(&self) -> String {
|
||||
let mut hasher = std::hash::DefaultHasher::new();
|
||||
self.hash(&mut hasher);
|
||||
let disambiguator = hasher.finish();
|
||||
let hash = format!("{:04x}", disambiguator);
|
||||
format!("{}_{}", &self.revision[..8], &hash[..4])
|
||||
pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
|
||||
let mut content = String::new();
|
||||
for example in examples {
|
||||
let line = serde_json::to_string(example).unwrap();
|
||||
content.push_str(&line);
|
||||
content.push('\n');
|
||||
}
|
||||
if let Some(output_path) = output_path {
|
||||
std::fs::write(output_path, content).expect("Failed to write examples");
|
||||
} else {
|
||||
std::io::stdout().write_all(&content.as_bytes()).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub type ActualExcerpt = Excerpt;
|
||||
fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
|
||||
use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Excerpt {
|
||||
pub path: PathBuf,
|
||||
pub text: String,
|
||||
}
|
||||
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
|
||||
const EDIT_HISTORY_HEADING: &str = "Edit History";
|
||||
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
|
||||
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
|
||||
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
|
||||
const REPOSITORY_URL_FIELD: &str = "repository_url";
|
||||
const REVISION_FIELD: &str = "revision";
|
||||
|
||||
#[derive(ValueEnum, Debug, Clone)]
|
||||
pub enum ExampleFormat {
|
||||
Json,
|
||||
Toml,
|
||||
Md,
|
||||
}
|
||||
let parser = Parser::new(input);
|
||||
|
||||
impl NamedExample {
|
||||
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref();
|
||||
let content = std::fs::read_to_string(path)?;
|
||||
let ext = path.extension();
|
||||
let mut example = Example {
|
||||
name: id,
|
||||
repository_url: String::new(),
|
||||
revision: String::new(),
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: PathBuf::new().into(),
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patch: String::new(),
|
||||
buffer: None,
|
||||
context: None,
|
||||
prompt: None,
|
||||
predictions: Vec::new(),
|
||||
score: Vec::new(),
|
||||
state: None,
|
||||
};
|
||||
|
||||
match ext.and_then(|s| s.to_str()) {
|
||||
Some("json") => Ok(Self {
|
||||
name: path.file_stem().unwrap_or_default().display().to_string(),
|
||||
example: serde_json::from_str(&content)?,
|
||||
}),
|
||||
Some("toml") => Ok(Self {
|
||||
name: path.file_stem().unwrap_or_default().display().to_string(),
|
||||
example: toml::from_str(&content)?,
|
||||
}),
|
||||
Some("md") => Self::parse_md(&content),
|
||||
Some(_) => {
|
||||
anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
|
||||
}
|
||||
None => {
|
||||
anyhow::bail!(
|
||||
"Failed to determine example type since the file does not have an extension."
|
||||
);
|
||||
}
|
||||
}
|
||||
let mut name = String::new();
|
||||
let mut text = String::new();
|
||||
let mut block_info: CowStr = "".into();
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum Section {
|
||||
UncommittedDiff,
|
||||
EditHistory,
|
||||
CursorPosition,
|
||||
ExpectedExcerpts,
|
||||
ExpectedPatch,
|
||||
Other,
|
||||
}
|
||||
|
||||
pub fn parse_md(input: &str) -> Result<Self> {
|
||||
use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
|
||||
let mut current_section = Section::Other;
|
||||
|
||||
let parser = Parser::new(input);
|
||||
for event in parser {
|
||||
match event {
|
||||
Event::Text(line) => {
|
||||
text.push_str(&line);
|
||||
|
||||
let mut named = NamedExample {
|
||||
name: String::new(),
|
||||
example: Example {
|
||||
repository_url: String::new(),
|
||||
revision: String::new(),
|
||||
uncommitted_diff: String::new(),
|
||||
cursor_path: PathBuf::new(),
|
||||
cursor_position: String::new(),
|
||||
edit_history: String::new(),
|
||||
expected_patch: String::new(),
|
||||
},
|
||||
};
|
||||
|
||||
let mut text = String::new();
|
||||
let mut block_info: CowStr = "".into();
|
||||
|
||||
#[derive(PartialEq)]
|
||||
enum Section {
|
||||
UncommittedDiff,
|
||||
EditHistory,
|
||||
CursorPosition,
|
||||
ExpectedExcerpts,
|
||||
ExpectedPatch,
|
||||
Other,
|
||||
}
|
||||
|
||||
let mut current_section = Section::Other;
|
||||
|
||||
for event in parser {
|
||||
match event {
|
||||
Event::Text(line) => {
|
||||
text.push_str(&line);
|
||||
|
||||
if !named.name.is_empty()
|
||||
&& current_section == Section::Other
|
||||
// in h1 section
|
||||
&& let Some((field, value)) = line.split_once('=')
|
||||
{
|
||||
match field.trim() {
|
||||
REPOSITORY_URL_FIELD => {
|
||||
named.example.repository_url = value.trim().to_string();
|
||||
}
|
||||
REVISION_FIELD => {
|
||||
named.example.revision = value.trim().to_string();
|
||||
}
|
||||
_ => {}
|
||||
if let Some((field, value)) = line.split_once('=') {
|
||||
match field.trim() {
|
||||
REPOSITORY_URL_FIELD => {
|
||||
example.repository_url = value.trim().to_string();
|
||||
}
|
||||
REVISION_FIELD => {
|
||||
example.revision = value.trim().to_string();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
|
||||
if !named.name.is_empty() {
|
||||
anyhow::bail!(
|
||||
"Found multiple H1 headings. There should only be one with the name of the example."
|
||||
);
|
||||
}
|
||||
named.name = mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
|
||||
let title = mem::take(&mut text);
|
||||
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
|
||||
Section::UncommittedDiff
|
||||
} else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
|
||||
Section::EditHistory
|
||||
} else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
|
||||
Section::CursorPosition
|
||||
} else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
|
||||
Section::ExpectedPatch
|
||||
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
|
||||
Section::ExpectedExcerpts
|
||||
} else {
|
||||
Section::Other
|
||||
};
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(level)) => {
|
||||
anyhow::bail!("Unexpected heading level: {level}");
|
||||
}
|
||||
Event::Start(Tag::CodeBlock(kind)) => {
|
||||
match kind {
|
||||
CodeBlockKind::Fenced(info) => {
|
||||
block_info = info;
|
||||
}
|
||||
CodeBlockKind::Indented => {
|
||||
anyhow::bail!("Unexpected indented codeblock");
|
||||
}
|
||||
};
|
||||
}
|
||||
Event::Start(_) => {
|
||||
text.clear();
|
||||
block_info = "".into();
|
||||
}
|
||||
Event::End(TagEnd::CodeBlock) => {
|
||||
let block_info = block_info.trim();
|
||||
match current_section {
|
||||
Section::UncommittedDiff => {
|
||||
named.example.uncommitted_diff = mem::take(&mut text);
|
||||
}
|
||||
Section::EditHistory => {
|
||||
named.example.edit_history.push_str(&mem::take(&mut text));
|
||||
}
|
||||
Section::CursorPosition => {
|
||||
named.example.cursor_path = block_info.into();
|
||||
named.example.cursor_position = mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedExcerpts => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedPatch => {
|
||||
named.example.expected_patch = mem::take(&mut text);
|
||||
}
|
||||
Section::Other => {}
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
if named.example.cursor_path.as_path() == Path::new("")
|
||||
|| named.example.cursor_position.is_empty()
|
||||
{
|
||||
anyhow::bail!("Missing cursor position codeblock");
|
||||
}
|
||||
|
||||
Ok(named)
|
||||
}
|
||||
|
||||
pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
|
||||
match format {
|
||||
ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
|
||||
ExampleFormat::Toml => {
|
||||
Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
|
||||
if !name.is_empty() {
|
||||
anyhow::bail!(
|
||||
"Found multiple H1 headings. There should only be one with the name of the example."
|
||||
);
|
||||
}
|
||||
name = mem::take(&mut text);
|
||||
}
|
||||
ExampleFormat::Md => Ok(write!(out, "{}", self)?),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn setup_project(
|
||||
&self,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<Project>> {
|
||||
let worktree_path = self.setup_worktree().await?;
|
||||
|
||||
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
|
||||
|
||||
AUTHENTICATED
|
||||
.get_or_init(|| {
|
||||
let client = app_state.client.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
client
|
||||
.sign_in_with_optional_connect(true, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
})
|
||||
.shared()
|
||||
})
|
||||
.clone()
|
||||
.await;
|
||||
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
let worktree = project
|
||||
.update(cx, |project, cx| {
|
||||
project.create_worktree(&worktree_path, true, cx)
|
||||
})?
|
||||
.await?;
|
||||
worktree
|
||||
.read_with(cx, |worktree, _cx| {
|
||||
worktree.as_local().unwrap().scan_complete()
|
||||
})?
|
||||
.await;
|
||||
|
||||
anyhow::Ok(project)
|
||||
}
|
||||
|
||||
pub async fn setup_worktree(&self) -> Result<PathBuf> {
|
||||
self.example.setup_worktree(self.file_name()).await
|
||||
}
|
||||
|
||||
pub fn file_name(&self) -> String {
|
||||
self.name
|
||||
.chars()
|
||||
.map(|c| {
|
||||
if c.is_whitespace() {
|
||||
'-'
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
|
||||
let title = mem::take(&mut text);
|
||||
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
|
||||
Section::UncommittedDiff
|
||||
} else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
|
||||
Section::EditHistory
|
||||
} else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
|
||||
Section::CursorPosition
|
||||
} else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
|
||||
Section::ExpectedPatch
|
||||
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
|
||||
Section::ExpectedExcerpts
|
||||
} else {
|
||||
c.to_ascii_lowercase()
|
||||
Section::Other
|
||||
};
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Event::End(TagEnd::Heading(level)) => {
|
||||
anyhow::bail!("Unexpected heading level: {level}");
|
||||
}
|
||||
Event::Start(Tag::CodeBlock(kind)) => {
|
||||
match kind {
|
||||
CodeBlockKind::Fenced(info) => {
|
||||
block_info = info;
|
||||
}
|
||||
CodeBlockKind::Indented => {
|
||||
anyhow::bail!("Unexpected indented codeblock");
|
||||
}
|
||||
};
|
||||
}
|
||||
Event::Start(_) => {
|
||||
text.clear();
|
||||
block_info = "".into();
|
||||
}
|
||||
Event::End(TagEnd::CodeBlock) => {
|
||||
let block_info = block_info.trim();
|
||||
match current_section {
|
||||
Section::UncommittedDiff => {
|
||||
example.uncommitted_diff = mem::take(&mut text);
|
||||
}
|
||||
Section::EditHistory => {
|
||||
example.edit_history.push_str(&mem::take(&mut text));
|
||||
}
|
||||
Section::CursorPosition => {
|
||||
example.cursor_path = Path::new(block_info).into();
|
||||
example.cursor_position = mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedExcerpts => {
|
||||
mem::take(&mut text);
|
||||
}
|
||||
Section::ExpectedPatch => {
|
||||
example.expected_patch = mem::take(&mut text);
|
||||
}
|
||||
Section::Other => {}
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn cursor_position(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(Entity<Buffer>, Anchor)> {
|
||||
let worktree = project.read_with(cx, |project, cx| {
|
||||
project.visible_worktrees(cx).next().unwrap()
|
||||
})?;
|
||||
let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(
|
||||
ProjectPath {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: cursor_path,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
let cursor_offset_within_excerpt = self
|
||||
.example
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.ok_or_else(|| anyhow!("missing cursor marker"))?;
|
||||
let mut cursor_excerpt = self.example.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
"",
|
||||
);
|
||||
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
|
||||
let mut matches = text.match_indices(&cursor_excerpt);
|
||||
let Some((excerpt_offset, _)) = matches.next() else {
|
||||
anyhow::bail!(
|
||||
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
|
||||
);
|
||||
};
|
||||
assert!(matches.next().is_none());
|
||||
|
||||
Ok(excerpt_offset)
|
||||
})??;
|
||||
|
||||
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
|
||||
let cursor_anchor =
|
||||
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
|
||||
Ok((cursor_buffer, cursor_anchor))
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn apply_edit_history(
|
||||
&self,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<OpenedBuffers<'_>> {
|
||||
edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
|
||||
}
|
||||
}
|
||||
|
||||
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||
let output = smol::process::Command::new("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
args.join(" "),
|
||||
repo_path.display(),
|
||||
output.status,
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
);
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
}
|
||||
|
||||
impl Display for NamedExample {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "# {}\n\n", self.name)?;
|
||||
write!(
|
||||
f,
|
||||
"{REPOSITORY_URL_FIELD} = {}\n",
|
||||
self.example.repository_url
|
||||
)?;
|
||||
write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
|
||||
|
||||
write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
|
||||
write!(f, "`````diff\n")?;
|
||||
write!(f, "{}", self.example.uncommitted_diff)?;
|
||||
write!(f, "`````\n")?;
|
||||
|
||||
if !self.example.edit_history.is_empty() {
|
||||
write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
write!(
|
||||
f,
|
||||
"## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
|
||||
self.example.cursor_path.display(),
|
||||
self.example.cursor_position
|
||||
)?;
|
||||
write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
|
||||
|
||||
if !self.example.expected_patch.is_empty() {
|
||||
write!(
|
||||
f,
|
||||
"\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
|
||||
self.example.expected_patch
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
|
||||
anyhow::bail!("Missing cursor position codeblock");
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
|
||||
REPO_LOCKS
|
||||
.with(|cell| {
|
||||
cell.borrow_mut()
|
||||
.entry(path.as_ref().to_path_buf())
|
||||
.or_default()
|
||||
.clone()
|
||||
})
|
||||
.lock_owned()
|
||||
.await
|
||||
Ok(example)
|
||||
}
|
||||
|
||||
280
crates/edit_prediction_cli/src/format_prompt.rs
Normal file
280
crates/edit_prediction_cli/src/format_prompt.rs
Normal file
@@ -0,0 +1,280 @@
|
||||
use crate::{
|
||||
PromptFormat,
|
||||
example::{Example, ExamplePrompt},
|
||||
headless::EpAppState,
|
||||
retrieve_context::run_context_retrieval,
|
||||
};
|
||||
use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
|
||||
use gpui::AsyncApp;
|
||||
use std::sync::Arc;
|
||||
use zeta_prompt::format_zeta_prompt;
|
||||
|
||||
pub async fn run_format_prompt(
|
||||
example: &mut Example,
|
||||
prompt_format: PromptFormat,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
run_context_retrieval(example, app_state, cx.clone()).await;
|
||||
|
||||
let prompt = match prompt_format {
|
||||
PromptFormat::Teacher => TeacherPrompt::format(example),
|
||||
PromptFormat::Zeta2 => {
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let snapshot = state
|
||||
.buffer
|
||||
.read_with(&cx, |buffer, _| buffer.snapshot())
|
||||
.unwrap();
|
||||
let project = state.project.clone();
|
||||
let (_, input) = ep_store
|
||||
.update(&mut cx, |ep_store, _cx| {
|
||||
zeta2_prompt_input(
|
||||
&snapshot,
|
||||
example.context.as_ref().unwrap().files.clone(),
|
||||
ep_store.edit_history_for_project(&project),
|
||||
example.cursor_path.clone(),
|
||||
example.buffer.as_ref().unwrap().cursor_offset,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
format_zeta_prompt(&input)
|
||||
}
|
||||
};
|
||||
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
input: prompt,
|
||||
expected_output: example.expected_patch.clone(), // TODO
|
||||
format: prompt_format,
|
||||
});
|
||||
}
|
||||
|
||||
/// Builds the model-input prompt text for an [`Example`].
pub trait PromptFormatter {
|
||||
/// Returns the prompt string generated from `example`.
fn format(example: &Example) -> String;
|
||||
}
|
||||
|
||||
/// Converts a raw model response for an [`Example`] back into a patch.
pub trait PromptParser {
|
||||
/// Returns the prediction as a unified diff patch, given the raw LLM `response`.
|
||||
fn parse(example: &Example, response: &str) -> String;
|
||||
}
|
||||
|
||||
pub struct TeacherPrompt;
|
||||
|
||||
impl PromptFormatter for TeacherPrompt {
|
||||
fn format(example: &Example) -> String {
|
||||
let edit_history = Self::format_edit_history(&example.edit_history);
|
||||
let context = Self::format_context(example);
|
||||
let editable_region = Self::format_editable_region(example);
|
||||
|
||||
let prompt = Self::PROMPT
|
||||
.replace("{{context}}", &context)
|
||||
.replace("{{edit_history}}", &edit_history)
|
||||
.replace("{{editable_region}}", &editable_region);
|
||||
|
||||
prompt
|
||||
}
|
||||
}
|
||||
|
||||
impl TeacherPrompt {
|
||||
const PROMPT: &str = include_str!("teacher.prompt.md");
|
||||
pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
|
||||
pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
|
||||
|
||||
/// Truncate edit history to this number of last lines
|
||||
const MAX_HISTORY_LINES: usize = 128;
|
||||
|
||||
fn format_edit_history(edit_history: &str) -> String {
|
||||
// Strip comments ("garbage lines") from edit history
|
||||
let lines = edit_history
|
||||
.lines()
|
||||
.filter(|&s| Self::is_udiff_content_line(s))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
|
||||
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
|
||||
} else {
|
||||
&lines
|
||||
};
|
||||
|
||||
if history_lines.is_empty() {
|
||||
return "(No edit history)".to_string();
|
||||
}
|
||||
|
||||
history_lines.join("\n")
|
||||
}
|
||||
|
||||
fn format_context(example: &Example) -> String {
|
||||
if example.context.is_none() {
|
||||
panic!("Missing context retriever step");
|
||||
}
|
||||
|
||||
let mut prompt = String::new();
|
||||
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
|
||||
|
||||
prompt
|
||||
}
|
||||
|
||||
fn format_editable_region(example: &Example) -> String {
|
||||
let mut result = String::new();
|
||||
|
||||
let path_str = example.cursor_path.to_string_lossy();
|
||||
result.push_str(&format!("`````path=\"{path_str}\"\n"));
|
||||
result.push_str(Self::EDITABLE_REGION_START);
|
||||
|
||||
// TODO: control number of lines around cursor
|
||||
result.push_str(&example.cursor_position);
|
||||
if !example.cursor_position.ends_with('\n') {
|
||||
result.push('\n');
|
||||
}
|
||||
|
||||
result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
|
||||
result.push_str("`````");
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
fn extract_editable_region(text: &str) -> String {
|
||||
let start = text
|
||||
.find(Self::EDITABLE_REGION_START)
|
||||
.map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
|
||||
let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
|
||||
|
||||
let region = &text[start..end];
|
||||
|
||||
region.replace("<|user_cursor|>", "")
|
||||
}
|
||||
|
||||
fn is_udiff_content_line(s: &str) -> bool {
|
||||
s.starts_with("-")
|
||||
|| s.starts_with("+")
|
||||
|| s.starts_with(" ")
|
||||
|| s.starts_with("---")
|
||||
|| s.starts_with("+++")
|
||||
|| s.starts_with("@@")
|
||||
}
|
||||
}
|
||||
|
||||
impl PromptParser for TeacherPrompt {
|
||||
fn parse(example: &Example, response: &str) -> String {
|
||||
// Ideally, we should always be able to find cursor position in the retrieved context.
|
||||
// In reality, sometimes we don't find it for these reasons:
|
||||
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
|
||||
// (can be fixed by getting cursor coordinates at the load_example stage)
|
||||
// 2. Context retriever just didn't include cursor line.
|
||||
//
|
||||
// In that case, fallback to using `cursor_position` as excerpt.
|
||||
let cursor_file = &example
|
||||
.buffer
|
||||
.as_ref()
|
||||
.expect("`buffer` should be filled in in the context collection step")
|
||||
.content;
|
||||
|
||||
// Extract updated (new) editable region from the model response
|
||||
let new_editable_region = extract_last_codeblock(response);
|
||||
|
||||
// Reconstruct old editable region we sent to the model
|
||||
let old_editable_region = Self::format_editable_region(example);
|
||||
let old_editable_region = Self::extract_editable_region(&old_editable_region);
|
||||
if !cursor_file.contains(&old_editable_region) {
|
||||
panic!("Something's wrong: editable_region is not found in the cursor file")
|
||||
}
|
||||
|
||||
// Apply editable region to a larger context and compute diff.
|
||||
// This is needed to get a better context lines around the editable region
|
||||
let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
|
||||
let diff = language::unified_diff(&cursor_file, &edited_file);
|
||||
|
||||
let diff = indoc::formatdoc! {"
|
||||
--- a/{path}
|
||||
+++ b/{path}
|
||||
{diff}
|
||||
",
|
||||
path = example.cursor_path.to_string_lossy(),
|
||||
diff = diff,
|
||||
};
|
||||
|
||||
diff
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the contents of the last fenced code block in `text`.
///
/// A fence is a run of three or more backticks. The rest of the opening line (the info
/// string) is skipped, and the block ends at the next occurrence of as many consecutive
/// backticks as the opening fence has. The single newline in front of the closing fence,
/// if present, is not part of the returned contents.
///
/// Empty blocks yield an empty string. If `text` contains no complete fenced block, the
/// whole input is returned unchanged.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut last_block = None;
    let mut search_start = 0;

    while let Some(offset) = text[search_start..].find("```") {
        let fence_start = search_start + offset;

        // The opening fence may be longer than three backticks; the closing fence has to
        // match its full length.
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&byte| byte == b'`')
            .count();
        let closing_fence = "`".repeat(fence_len);

        // Skip the info string: contents start on the line after the opening fence.
        let after_fence = fence_start + fence_len;
        let Some(newline) = text[after_fence..].find('\n') else {
            break;
        };
        let content_start = after_fence + newline + 1;

        let Some(content_len) = text[content_start..].find(&closing_fence) else {
            break;
        };
        let content_end = content_start + content_len;

        // Strip only the newline that separates the contents from the closing fence.
        // Slicing at the fence itself keeps empty blocks and fences that directly follow
        // the contents from producing an invalid or truncated range.
        let content = &text[content_start..content_end];
        let content = content.strip_suffix('\n').unwrap_or(content);
        last_block = Some(content.to_string());

        search_start = content_end + fence_len;
    }

    last_block.unwrap_or_else(|| text.to_string())
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The contents of the final fenced block win, regardless of fence length
    /// or info string.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};

        assert_eq!(extract_last_codeblock(response), "last block");
    }

    /// Only the lines between the editable-region markers are returned.
    #[test]
    fn test_extract_editable_region() {
        let prompt = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let expected = indoc::indoc! {"
            one
            two three

            "};

        let region = TeacherPrompt::extract_editable_region(prompt);

        assert_eq!(region, expected);
    }
}
|
||||
@@ -16,7 +16,7 @@ use std::sync::Arc;
|
||||
use util::ResultExt as _;
|
||||
|
||||
/// Headless subset of `workspace::AppState`.
|
||||
pub struct ZetaCliAppState {
|
||||
pub struct EpAppState {
|
||||
pub languages: Arc<LanguageRegistry>,
|
||||
pub client: Arc<Client>,
|
||||
pub user_store: Entity<UserStore>,
|
||||
@@ -25,7 +25,7 @@ pub struct ZetaCliAppState {
|
||||
}
|
||||
|
||||
// TODO: dedupe with crates/eval/src/eval.rs
|
||||
pub fn init(cx: &mut App) -> ZetaCliAppState {
|
||||
pub fn init(cx: &mut App) -> EpAppState {
|
||||
let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
|
||||
|
||||
let app_version = AppVersion::load(
|
||||
@@ -112,7 +112,7 @@ pub fn init(cx: &mut App) -> ZetaCliAppState {
|
||||
prompt_store::init(cx);
|
||||
terminal_view::init(cx);
|
||||
|
||||
ZetaCliAppState {
|
||||
EpAppState {
|
||||
languages,
|
||||
client,
|
||||
user_store,
|
||||
|
||||
320
crates/edit_prediction_cli/src/load_project.rs
Normal file
320
crates/edit_prediction_cli/src/load_project.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
use crate::{
|
||||
example::{Example, ExampleBuffer, ExampleState},
|
||||
headless::EpAppState,
|
||||
};
|
||||
use anyhow::{Result, anyhow};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
use futures::{
|
||||
AsyncWriteExt as _,
|
||||
lock::{Mutex, OwnedMutexGuard},
|
||||
};
|
||||
use gpui::{AsyncApp, Entity};
|
||||
use language::{Anchor, Buffer, ToOffset, ToPoint};
|
||||
use project::buffer_store::BufferStoreEvent;
|
||||
use project::{Project, ProjectPath};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
fs,
|
||||
path::{Path, PathBuf},
|
||||
sync::Arc,
|
||||
};
|
||||
use util::{paths::PathStyle, rel_path::RelPath};
|
||||
use zeta_prompt::CURSOR_MARKER;
|
||||
|
||||
pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
|
||||
if example.state.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
let project = setup_project(example, &app_state, &mut cx).await;
|
||||
let buffer_store = project
|
||||
.read_with(&cx, |project, _| project.buffer_store().clone())
|
||||
.unwrap();
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
|
||||
cx.subscribe(&buffer_store, {
|
||||
let project = project.clone();
|
||||
move |_, event, cx| match event {
|
||||
BufferStoreEvent::BufferAdded(buffer) => {
|
||||
ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})
|
||||
.unwrap()
|
||||
.detach();
|
||||
|
||||
let _open_buffers = apply_edit_history(example, &project, &mut cx)
|
||||
.await
|
||||
.unwrap();
|
||||
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
|
||||
example.buffer = buffer
|
||||
.read_with(&cx, |buffer, _cx| {
|
||||
let cursor_point = cursor_position.to_point(&buffer);
|
||||
Some(ExampleBuffer {
|
||||
content: buffer.text(),
|
||||
cursor_row: cursor_point.row,
|
||||
cursor_column: cursor_point.column,
|
||||
cursor_offset: cursor_position.to_offset(&buffer),
|
||||
})
|
||||
})
|
||||
.unwrap();
|
||||
example.state = Some(ExampleState {
|
||||
buffer,
|
||||
project,
|
||||
cursor_position,
|
||||
_open_buffers,
|
||||
});
|
||||
}
|
||||
|
||||
async fn cursor_position(
|
||||
example: &Example,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> (Entity<Buffer>, Anchor) {
|
||||
let worktree = project
|
||||
.read_with(cx, |project, cx| {
|
||||
project.visible_worktrees(cx).next().unwrap()
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
|
||||
.unwrap()
|
||||
.into_arc();
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(
|
||||
ProjectPath {
|
||||
worktree_id: worktree.read(cx).id(),
|
||||
path: cursor_path,
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
let cursor_offset_within_excerpt = example
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.ok_or_else(|| anyhow!("missing cursor marker"))
|
||||
.unwrap();
|
||||
let mut cursor_excerpt = example.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
"",
|
||||
);
|
||||
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
|
||||
let text = buffer.text();
|
||||
|
||||
let mut matches = text.match_indices(&cursor_excerpt);
|
||||
let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
|
||||
panic!(
|
||||
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
|
||||
);
|
||||
});
|
||||
assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
|
||||
excerpt_offset
|
||||
}).unwrap();
|
||||
|
||||
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
|
||||
let cursor_anchor = cursor_buffer
|
||||
.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
|
||||
.unwrap();
|
||||
|
||||
(cursor_buffer, cursor_anchor)
|
||||
}
|
||||
|
||||
async fn setup_project(
|
||||
example: &mut Example,
|
||||
app_state: &Arc<EpAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Entity<Project> {
|
||||
setup_worktree(example).await;
|
||||
|
||||
let project = cx
|
||||
.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
let worktree = project
|
||||
.update(cx, |project, cx| {
|
||||
project.create_worktree(&example.worktree_path(), true, cx)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
worktree
|
||||
.read_with(cx, |worktree, _cx| {
|
||||
worktree.as_local().unwrap().scan_complete()
|
||||
})
|
||||
.unwrap()
|
||||
.await;
|
||||
project
|
||||
}
|
||||
|
||||
pub async fn setup_worktree(example: &Example) {
|
||||
let repo_dir = example.repo_path();
|
||||
let repo_lock = lock_repo(&repo_dir).await;
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
fs::create_dir_all(&repo_dir).unwrap();
|
||||
run_git(&repo_dir, &["init"]).await.unwrap();
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &example.repository_url],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
let revision = run_git(
|
||||
&repo_dir,
|
||||
&["rev-parse", &format!("{}^{{commit}}", example.revision)],
|
||||
)
|
||||
.await;
|
||||
let revision = if let Ok(revision) = revision {
|
||||
revision
|
||||
} else {
|
||||
if run_git(
|
||||
&repo_dir,
|
||||
&["fetch", "--depth", "1", "origin", &example.revision],
|
||||
)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
|
||||
.await
|
||||
.unwrap();
|
||||
if revision != example.revision {
|
||||
run_git(&repo_dir, &["tag", &example.revision, &revision])
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
revision
|
||||
};
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
let worktree_path = example.worktree_path();
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"])
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"])
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()])
|
||||
.await
|
||||
.unwrap();
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["branch", "-f", &example.name, revision.as_str()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&[
|
||||
"worktree",
|
||||
"add",
|
||||
"-f",
|
||||
&worktree_path_string,
|
||||
&example.name,
|
||||
],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
drop(repo_lock);
|
||||
|
||||
// Apply the uncommitted diff for this example.
|
||||
if !example.uncommitted_diff.is_empty() {
|
||||
let mut apply_process = smol::process::Command::new("git")
|
||||
.current_dir(&worktree_path)
|
||||
.args(&["apply", "-"])
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.spawn()
|
||||
.unwrap();
|
||||
|
||||
let mut stdin = apply_process.stdin.take().unwrap();
|
||||
stdin
|
||||
.write_all(example.uncommitted_diff.as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
stdin.close().await.unwrap();
|
||||
drop(stdin);
|
||||
|
||||
let apply_result = apply_process.output().await.unwrap();
|
||||
if !apply_result.status.success() {
|
||||
panic!(
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn apply_edit_history(
|
||||
example: &Example,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<OpenedBuffers> {
|
||||
edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
|
||||
}
|
||||
|
||||
thread_local! {
|
||||
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
|
||||
REPO_LOCKS
|
||||
.with(|cell| {
|
||||
cell.borrow_mut()
|
||||
.entry(path.as_ref().to_path_buf())
|
||||
.or_default()
|
||||
.clone()
|
||||
})
|
||||
.lock_owned()
|
||||
.await
|
||||
}
|
||||
|
||||
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
|
||||
let output = smol::process::Command::new("git")
|
||||
.current_dir(repo_path)
|
||||
.args(args)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
args.join(" "),
|
||||
repo_path.display(),
|
||||
output.status,
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
);
|
||||
Ok(String::from_utf8(output.stdout)?.trim().to_string())
|
||||
}
|
||||
@@ -1,522 +1,196 @@
|
||||
mod evaluate;
|
||||
mod anthropic_client;
|
||||
mod example;
|
||||
mod format_prompt;
|
||||
mod headless;
|
||||
mod load_project;
|
||||
mod metrics;
|
||||
mod paths;
|
||||
mod predict;
|
||||
mod source_location;
|
||||
mod training;
|
||||
mod util;
|
||||
mod retrieve_context;
|
||||
mod score;
|
||||
|
||||
use crate::{
|
||||
evaluate::run_evaluate,
|
||||
example::{ExampleFormat, NamedExample},
|
||||
headless::ZetaCliAppState,
|
||||
predict::run_predict,
|
||||
source_location::SourceLocation,
|
||||
training::{context::ContextType, distill::run_distill},
|
||||
util::{open_buffer, open_buffer_with_language_server},
|
||||
};
|
||||
use ::util::{ResultExt, paths::PathStyle};
|
||||
use anyhow::{Result, anyhow};
|
||||
use clap::{Args, Parser, Subcommand, ValueEnum};
|
||||
use cloud_llm_client::predict_edits_v3;
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use edit_prediction_context::EditPredictionExcerptOptions;
|
||||
use gpui::{Application, AsyncApp, Entity, prelude::*};
|
||||
use language::{Bias, Buffer, BufferSnapshot, Point};
|
||||
use metrics::delta_chr_f;
|
||||
use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
|
||||
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use gpui::Application;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use std::io::{self};
|
||||
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
use crate::example::{read_examples, write_examples};
|
||||
use crate::format_prompt::run_format_prompt;
|
||||
use crate::load_project::run_load_project;
|
||||
use crate::predict::run_prediction;
|
||||
use crate::retrieve_context::run_context_retrieval;
|
||||
use crate::score::run_scoring;
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(name = "zeta")]
|
||||
struct ZetaCliArgs {
|
||||
#[command(name = "ep")]
|
||||
struct EpArgs {
|
||||
#[arg(long, default_value_t = false)]
|
||||
printenv: bool,
|
||||
#[clap(long, default_value_t = 10)]
|
||||
max_parallelism: usize,
|
||||
#[command(subcommand)]
|
||||
command: Option<Command>,
|
||||
#[clap(global = true)]
|
||||
inputs: Vec<PathBuf>,
|
||||
#[arg(long, short, global = true)]
|
||||
output: Option<PathBuf>,
|
||||
#[arg(long, short, global = true)]
|
||||
in_place: bool,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum Command {
|
||||
Context(ContextArgs),
|
||||
Predict(PredictArguments),
|
||||
Eval(EvaluateArguments),
|
||||
Distill(DistillArguments),
|
||||
ConvertExample {
|
||||
path: PathBuf,
|
||||
#[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
|
||||
output_format: ExampleFormat,
|
||||
},
|
||||
Score {
|
||||
golden_patch: PathBuf,
|
||||
actual_patch: PathBuf,
|
||||
},
|
||||
/// Parse markdown examples and output a combined .jsonl file
|
||||
ParseExample,
|
||||
/// Create git worktrees for each example and load file contents
|
||||
LoadBuffer,
|
||||
/// Retrieve context for input examples.
|
||||
Context,
|
||||
/// Generate a prompt string for a specific model
|
||||
FormatPrompt(FormatPromptArgs),
|
||||
/// Runs edit prediction
|
||||
Predict(PredictArgs),
|
||||
/// Computes a score based on actual and expected patches
|
||||
Score(PredictArgs),
|
||||
/// Print aggregated scores
|
||||
Eval(PredictArgs),
|
||||
/// Remove git repositories and worktrees
|
||||
Clean,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
struct ContextArgs {
|
||||
#[arg(long)]
|
||||
provider: ContextProvider,
|
||||
#[arg(long)]
|
||||
worktree: PathBuf,
|
||||
#[arg(long)]
|
||||
cursor: SourceLocation,
|
||||
#[arg(long)]
|
||||
use_language_server: bool,
|
||||
#[arg(long)]
|
||||
edit_history: Option<FileOrStdin>,
|
||||
#[clap(flatten)]
|
||||
zeta2_args: Zeta2Args,
|
||||
struct FormatPromptArgs {
|
||||
#[clap(long)]
|
||||
prompt_format: PromptFormat,
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
|
||||
enum ContextProvider {
|
||||
Zeta1,
|
||||
#[default]
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
|
||||
enum PromptFormat {
|
||||
Teacher,
|
||||
Zeta2,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Args)]
|
||||
struct Zeta2Args {
|
||||
#[arg(long, default_value_t = 8192)]
|
||||
max_prompt_bytes: usize,
|
||||
#[arg(long, default_value_t = 2048)]
|
||||
max_excerpt_bytes: usize,
|
||||
#[arg(long, default_value_t = 1024)]
|
||||
min_excerpt_bytes: usize,
|
||||
#[arg(long, default_value_t = 0.66)]
|
||||
target_before_cursor_over_total_bytes: f32,
|
||||
#[arg(long, default_value_t = 1024)]
|
||||
max_diagnostic_bytes: usize,
|
||||
#[arg(long, value_enum, default_value_t = PromptFormat::default())]
|
||||
prompt_format: PromptFormat,
|
||||
#[arg(long, value_enum, default_value_t = Default::default())]
|
||||
output_format: OutputFormat,
|
||||
#[arg(long, default_value_t = 42)]
|
||||
file_indexing_parallelism: usize,
|
||||
#[arg(long, default_value_t = false)]
|
||||
disable_imports_gathering: bool,
|
||||
#[arg(long, default_value_t = u8::MAX)]
|
||||
max_retrieved_definitions: u8,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct PredictArguments {
|
||||
#[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
|
||||
format: PredictionsOutputFormat,
|
||||
example_path: PathBuf,
|
||||
#[clap(flatten)]
|
||||
options: PredictionOptions,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct DistillArguments {
|
||||
split_commit_dataset: PathBuf,
|
||||
#[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
|
||||
context_type: ContextType,
|
||||
#[clap(long)]
|
||||
batch: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Args)]
|
||||
pub struct PredictionOptions {
|
||||
#[clap(flatten)]
|
||||
zeta2: Zeta2Args,
|
||||
struct PredictArgs {
|
||||
#[clap(long)]
|
||||
provider: PredictionProvider,
|
||||
#[clap(long, value_enum, default_value_t = CacheMode::default())]
|
||||
cache: CacheMode,
|
||||
#[clap(long, default_value_t = 1)]
|
||||
repetitions: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
|
||||
pub enum CacheMode {
|
||||
/// Use cached LLM requests and responses, except when multiple repetitions are requested
|
||||
#[default]
|
||||
Auto,
|
||||
/// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
|
||||
#[value(alias = "request")]
|
||||
Requests,
|
||||
/// Ignore existing cache entries for both LLM and search.
|
||||
Skip,
|
||||
/// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
|
||||
/// Useful for reproducing results and fixing bugs outside of search queries
|
||||
Force,
|
||||
}
|
||||
|
||||
impl CacheMode {
|
||||
fn use_cached_llm_responses(&self) -> bool {
|
||||
self.assert_not_auto();
|
||||
matches!(self, CacheMode::Requests | CacheMode::Force)
|
||||
}
|
||||
|
||||
fn use_cached_search_results(&self) -> bool {
|
||||
self.assert_not_auto();
|
||||
matches!(self, CacheMode::Force)
|
||||
}
|
||||
|
||||
fn assert_not_auto(&self) {
|
||||
assert_ne!(
|
||||
*self,
|
||||
CacheMode::Auto,
|
||||
"Cache mode should not be auto at this point!"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Debug, Clone)]
|
||||
pub enum PredictionsOutputFormat {
|
||||
Json,
|
||||
Md,
|
||||
Diff,
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
pub struct EvaluateArguments {
|
||||
example_paths: Vec<PathBuf>,
|
||||
#[clap(flatten)]
|
||||
options: PredictionOptions,
|
||||
#[clap(short, long, default_value_t = 1, alias = "repeat")]
|
||||
repetitions: u16,
|
||||
#[arg(long)]
|
||||
skip_prediction: bool,
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
|
||||
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
|
||||
enum PredictionProvider {
|
||||
Zeta1,
|
||||
#[default]
|
||||
Zeta2,
|
||||
Sweep,
|
||||
Mercury,
|
||||
Zeta1,
|
||||
Zeta2,
|
||||
Teacher,
|
||||
}
|
||||
|
||||
fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
|
||||
edit_prediction::ZetaOptions {
|
||||
context: EditPredictionExcerptOptions {
|
||||
max_bytes: args.max_excerpt_bytes,
|
||||
min_bytes: args.min_excerpt_bytes,
|
||||
target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
|
||||
},
|
||||
max_prompt_bytes: args.max_prompt_bytes,
|
||||
prompt_format: args.prompt_format.into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
|
||||
enum PromptFormat {
|
||||
OnlySnippets,
|
||||
#[default]
|
||||
OldTextNewText,
|
||||
Minimal,
|
||||
MinimalQwen,
|
||||
SeedCoder1120,
|
||||
}
|
||||
|
||||
impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
|
||||
fn into(self) -> predict_edits_v3::PromptFormat {
|
||||
match self {
|
||||
Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
|
||||
Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
|
||||
Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
|
||||
Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
|
||||
Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(clap::ValueEnum, Default, Debug, Clone)]
|
||||
enum OutputFormat {
|
||||
#[default]
|
||||
Prompt,
|
||||
Request,
|
||||
Full,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
enum FileOrStdin {
|
||||
File(PathBuf),
|
||||
Stdin,
|
||||
}
|
||||
|
||||
impl FileOrStdin {
|
||||
async fn read_to_string(&self) -> Result<String, std::io::Error> {
|
||||
match self {
|
||||
FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
|
||||
FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for FileOrStdin {
|
||||
type Err = <PathBuf as FromStr>::Err;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"-" => Ok(Self::Stdin),
|
||||
_ => Ok(Self::File(PathBuf::from_str(s)?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct LoadedContext {
|
||||
full_path_str: String,
|
||||
snapshot: BufferSnapshot,
|
||||
clipped_cursor: Point,
|
||||
worktree: Entity<Worktree>,
|
||||
project: Entity<Project>,
|
||||
buffer: Entity<Buffer>,
|
||||
lsp_open_handle: Option<OpenLspBufferHandle>,
|
||||
}
|
||||
|
||||
async fn load_context(
|
||||
args: &ContextArgs,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<LoadedContext> {
|
||||
let ContextArgs {
|
||||
worktree: worktree_path,
|
||||
cursor,
|
||||
use_language_server,
|
||||
..
|
||||
} = args;
|
||||
|
||||
let worktree_path = worktree_path.canonicalize()?;
|
||||
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
let worktree = project
|
||||
.update(cx, |project, cx| {
|
||||
project.create_worktree(&worktree_path, true, cx)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let mut ready_languages = HashSet::default();
|
||||
let (lsp_open_handle, buffer) = if *use_language_server {
|
||||
let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
|
||||
project.clone(),
|
||||
worktree.clone(),
|
||||
cursor.path.clone(),
|
||||
&mut ready_languages,
|
||||
cx,
|
||||
)
|
||||
.await?;
|
||||
(Some(lsp_open_handle), buffer)
|
||||
} else {
|
||||
let buffer =
|
||||
open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
|
||||
(None, buffer)
|
||||
};
|
||||
|
||||
let full_path_str = worktree
|
||||
.read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
|
||||
.display(PathStyle::local())
|
||||
.to_string();
|
||||
|
||||
let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
|
||||
let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
|
||||
if clipped_cursor != cursor.point {
|
||||
let max_row = snapshot.max_point().row;
|
||||
if cursor.point.row < max_row {
|
||||
return Err(anyhow!(
|
||||
"Cursor position {:?} is out of bounds (line length is {})",
|
||||
cursor.point,
|
||||
snapshot.line_len(cursor.point.row)
|
||||
));
|
||||
impl EpArgs {
|
||||
fn output_path(&self) -> Option<PathBuf> {
|
||||
if self.in_place {
|
||||
if self.inputs.len() == 1 {
|
||||
self.inputs.first().cloned()
|
||||
} else {
|
||||
panic!("--in-place requires exactly one input file")
|
||||
}
|
||||
} else {
|
||||
return Err(anyhow!(
|
||||
"Cursor position {:?} is out of bounds (max row is {})",
|
||||
cursor.point,
|
||||
max_row
|
||||
));
|
||||
self.output.clone()
|
||||
}
|
||||
}
|
||||
|
||||
Ok(LoadedContext {
|
||||
full_path_str,
|
||||
snapshot,
|
||||
clipped_cursor,
|
||||
worktree,
|
||||
project,
|
||||
buffer,
|
||||
lsp_open_handle,
|
||||
})
|
||||
}
|
||||
|
||||
async fn zeta2_context(
|
||||
args: ContextArgs,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<String> {
|
||||
let LoadedContext {
|
||||
worktree,
|
||||
project,
|
||||
buffer,
|
||||
clipped_cursor,
|
||||
lsp_open_handle: _handle,
|
||||
..
|
||||
} = load_context(&args, app_state, cx).await?;
|
||||
|
||||
// wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
|
||||
// the whole worktree.
|
||||
worktree
|
||||
.read_with(cx, |worktree, _cx| {
|
||||
worktree.as_local().unwrap().scan_complete()
|
||||
})?
|
||||
.await;
|
||||
let output = cx
|
||||
.update(|cx| {
|
||||
let store = cx.new(|cx| {
|
||||
edit_prediction::EditPredictionStore::new(
|
||||
app_state.client.clone(),
|
||||
app_state.user_store.clone(),
|
||||
cx,
|
||||
)
|
||||
});
|
||||
store.update(cx, |store, cx| {
|
||||
store.set_options(zeta2_args_to_options(&args.zeta2_args));
|
||||
store.register_buffer(&buffer, &project, cx);
|
||||
});
|
||||
cx.spawn(async move |cx| {
|
||||
let updates_rx = store.update(cx, |store, cx| {
|
||||
let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
|
||||
store.set_use_context(true);
|
||||
store.refresh_context(&project, &buffer, cursor, cx);
|
||||
store.project_context_updates(&project).unwrap()
|
||||
})?;
|
||||
|
||||
updates_rx.recv().await.ok();
|
||||
|
||||
let context = store.update(cx, |store, cx| {
|
||||
store.context_for_project(&project, cx).to_vec()
|
||||
})?;
|
||||
|
||||
anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
|
||||
})
|
||||
})?
|
||||
.await?;
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
async fn zeta1_context(
|
||||
args: ContextArgs,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<edit_prediction::zeta1::GatherContextOutput> {
|
||||
let LoadedContext {
|
||||
full_path_str,
|
||||
snapshot,
|
||||
clipped_cursor,
|
||||
..
|
||||
} = load_context(&args, app_state, cx).await?;
|
||||
|
||||
let events = match args.edit_history {
|
||||
Some(events) => events.read_to_string().await?,
|
||||
None => String::new(),
|
||||
};
|
||||
|
||||
let prompt_for_events = move || (events, 0);
|
||||
cx.update(|cx| {
|
||||
edit_prediction::zeta1::gather_context(
|
||||
full_path_str,
|
||||
&snapshot,
|
||||
clipped_cursor,
|
||||
prompt_for_events,
|
||||
cloud_llm_client::PredictEditsRequestTrigger::Cli,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await
|
||||
}
|
||||
|
||||
fn main() {
|
||||
zlog::init();
|
||||
zlog::init_output_stderr();
|
||||
let args = ZetaCliArgs::parse();
|
||||
let args = EpArgs::parse();
|
||||
|
||||
if args.printenv {
|
||||
::util::shell_env::print_env();
|
||||
return;
|
||||
}
|
||||
|
||||
let output = args.output_path();
|
||||
let command = match args.command {
|
||||
Some(cmd) => cmd,
|
||||
None => {
|
||||
EpArgs::command().print_help().unwrap();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
match &command {
|
||||
Command::Clean => {
|
||||
std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
|
||||
return;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
let mut examples = read_examples(&args.inputs);
|
||||
let http_client = Arc::new(ReqwestClient::new());
|
||||
let app = Application::headless().with_http_client(http_client);
|
||||
|
||||
app.run(move |cx| {
|
||||
let app_state = Arc::new(headless::init(cx));
|
||||
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
match args.command {
|
||||
None => {
|
||||
if args.printenv {
|
||||
::util::shell_env::print_env();
|
||||
} else {
|
||||
panic!("Expected a command");
|
||||
}
|
||||
}
|
||||
Some(Command::Context(context_args)) => {
|
||||
let result = match context_args.provider {
|
||||
ContextProvider::Zeta1 => {
|
||||
let context =
|
||||
zeta1_context(context_args, &app_state, cx).await.unwrap();
|
||||
serde_json::to_string_pretty(&context.body).unwrap()
|
||||
match &command {
|
||||
Command::Predict(args) => predict::sync_batches(&args.provider).await,
|
||||
_ => (),
|
||||
};
|
||||
|
||||
for data in examples.chunks_mut(args.max_parallelism) {
|
||||
let mut futures = Vec::new();
|
||||
for example in data.iter_mut() {
|
||||
let cx = cx.clone();
|
||||
let app_state = app_state.clone();
|
||||
futures.push(async {
|
||||
match &command {
|
||||
Command::ParseExample => {}
|
||||
Command::LoadBuffer => {
|
||||
run_load_project(example, app_state.clone(), cx).await;
|
||||
}
|
||||
Command::Context => {
|
||||
run_context_retrieval(example, app_state, cx).await;
|
||||
}
|
||||
Command::FormatPrompt(args) => {
|
||||
run_format_prompt(example, args.prompt_format, app_state, cx).await;
|
||||
}
|
||||
Command::Predict(args) => {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
app_state.clone(),
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Command::Score(args) | Command::Eval(args) => {
|
||||
run_scoring(example, &args, app_state, cx).await;
|
||||
}
|
||||
Command::Clean => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
ContextProvider::Zeta2 => {
|
||||
zeta2_context(context_args, &app_state, cx).await.unwrap()
|
||||
}
|
||||
};
|
||||
println!("{}", result);
|
||||
});
|
||||
}
|
||||
Some(Command::Predict(arguments)) => {
|
||||
run_predict(arguments, &app_state, cx).await;
|
||||
}
|
||||
Some(Command::Eval(arguments)) => {
|
||||
run_evaluate(arguments, &app_state, cx).await;
|
||||
}
|
||||
Some(Command::Distill(arguments)) => {
|
||||
let _guard = cx
|
||||
.update(|cx| gpui_tokio::Tokio::handle(cx))
|
||||
.unwrap()
|
||||
.enter();
|
||||
run_distill(arguments).await.log_err();
|
||||
}
|
||||
Some(Command::ConvertExample {
|
||||
path,
|
||||
output_format,
|
||||
}) => {
|
||||
let example = NamedExample::load(path).unwrap();
|
||||
example.write(output_format, io::stdout()).unwrap();
|
||||
}
|
||||
Some(Command::Score {
|
||||
golden_patch,
|
||||
actual_patch,
|
||||
}) => {
|
||||
let golden_content = std::fs::read_to_string(golden_patch).unwrap();
|
||||
let actual_content = std::fs::read_to_string(actual_patch).unwrap();
|
||||
futures::future::join_all(futures).await;
|
||||
}
|
||||
|
||||
let golden_diff: Vec<DiffLine> = golden_content
|
||||
.lines()
|
||||
.map(|line| DiffLine::parse(line))
|
||||
.collect();
|
||||
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
|
||||
write_examples(&examples, output.as_ref());
|
||||
}
|
||||
|
||||
let actual_diff: Vec<DiffLine> = actual_content
|
||||
.lines()
|
||||
.map(|line| DiffLine::parse(line))
|
||||
.collect();
|
||||
|
||||
let score = delta_chr_f(&golden_diff, &actual_diff);
|
||||
println!("{:.2}", score);
|
||||
}
|
||||
Some(Command::Clean) => {
|
||||
std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
|
||||
}
|
||||
match &command {
|
||||
Command::Predict(args) => predict::sync_batches(&args.provider).await,
|
||||
Command::Eval(_) => score::print_report(&examples),
|
||||
_ => (),
|
||||
};
|
||||
|
||||
let _ = cx.update(|cx| cx.quit());
|
||||
|
||||
@@ -1,30 +1,34 @@
|
||||
use collections::{HashMap, HashSet};
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
type Counts = HashMap<String, usize>;
|
||||
type CountsDelta = HashMap<String, isize>;
|
||||
|
||||
#[derive(Default, Debug, Clone)]
|
||||
pub struct Scores {
|
||||
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ClassificationMetrics {
|
||||
pub true_positives: usize,
|
||||
pub false_positives: usize,
|
||||
pub false_negatives: usize,
|
||||
}
|
||||
|
||||
impl Scores {
|
||||
pub fn from_sets(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
|
||||
impl ClassificationMetrics {
|
||||
pub fn from_sets(
|
||||
expected: &HashSet<String>,
|
||||
actual: &HashSet<String>,
|
||||
) -> ClassificationMetrics {
|
||||
let true_positives = expected.intersection(actual).count();
|
||||
let false_positives = actual.difference(expected).count();
|
||||
let false_negatives = expected.difference(actual).count();
|
||||
|
||||
Scores {
|
||||
ClassificationMetrics {
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_counts(expected: &Counts, actual: &Counts) -> Scores {
|
||||
pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
|
||||
let mut true_positives = 0;
|
||||
let mut false_positives = 0;
|
||||
let mut false_negatives = 0;
|
||||
@@ -45,32 +49,16 @@ impl Scores {
|
||||
}
|
||||
}
|
||||
|
||||
Scores {
|
||||
ClassificationMetrics {
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_markdown(&self) -> String {
|
||||
format!(
|
||||
"
|
||||
Precision : {:.4}
|
||||
Recall : {:.4}
|
||||
F1 Score : {:.4}
|
||||
True Positives : {}
|
||||
False Positives : {}
|
||||
False Negatives : {}",
|
||||
self.precision(),
|
||||
self.recall(),
|
||||
self.f1_score(),
|
||||
self.true_positives,
|
||||
self.false_positives,
|
||||
self.false_negatives
|
||||
)
|
||||
}
|
||||
|
||||
pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
|
||||
pub fn aggregate<'a>(
|
||||
scores: impl Iterator<Item = &'a ClassificationMetrics>,
|
||||
) -> ClassificationMetrics {
|
||||
let mut true_positives = 0;
|
||||
let mut false_positives = 0;
|
||||
let mut false_negatives = 0;
|
||||
@@ -81,7 +69,7 @@ False Negatives : {}",
|
||||
false_negatives += score.false_negatives;
|
||||
}
|
||||
|
||||
Scores {
|
||||
ClassificationMetrics {
|
||||
true_positives,
|
||||
false_positives,
|
||||
false_negatives,
|
||||
@@ -115,7 +103,10 @@ False Negatives : {}",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine]) -> Scores {
|
||||
pub fn line_match_score(
|
||||
expected_patch: &[DiffLine],
|
||||
actual_patch: &[DiffLine],
|
||||
) -> ClassificationMetrics {
|
||||
let expected_change_lines = expected_patch
|
||||
.iter()
|
||||
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
|
||||
@@ -128,7 +119,7 @@ pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine])
|
||||
.map(|line| line.to_string())
|
||||
.collect();
|
||||
|
||||
Scores::from_sets(&expected_change_lines, &actual_change_lines)
|
||||
ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
|
||||
}
|
||||
|
||||
enum ChrfWhitespace {
|
||||
@@ -204,7 +195,7 @@ pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
|
||||
let expected_counts = ngram_delta_to_counts(&expected_delta);
|
||||
let actual_counts = ngram_delta_to_counts(&actual_delta);
|
||||
|
||||
let score = Scores::from_counts(&expected_counts, &actual_counts);
|
||||
let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
|
||||
total_precision += score.precision();
|
||||
total_recall += score.recall();
|
||||
}
|
||||
|
||||
@@ -1,57 +1,25 @@
|
||||
use std::{env, path::PathBuf, sync::LazyLock};
|
||||
use std::{
|
||||
path::{Path, PathBuf},
|
||||
sync::LazyLock,
|
||||
};
|
||||
|
||||
pub static TARGET_ZETA_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| env::current_dir().unwrap().join("target/zeta"));
|
||||
pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("cache"));
|
||||
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("repos"));
|
||||
pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("worktrees"));
|
||||
pub static DATA_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
|
||||
let dir = dirs::home_dir().unwrap().join(".zed_ep");
|
||||
ensure_dir(&dir)
|
||||
});
|
||||
pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("cache")));
|
||||
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("repos")));
|
||||
pub static WORKTREES_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| ensure_dir(&DATA_DIR.join("worktrees")));
|
||||
pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
|
||||
TARGET_ZETA_DIR
|
||||
DATA_DIR
|
||||
.join("runs")
|
||||
.join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
|
||||
});
|
||||
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| TARGET_ZETA_DIR.join("latest"));
|
||||
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
|
||||
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
|
||||
|
||||
pub fn print_run_data_dir(deep: bool, use_color: bool) {
|
||||
println!("\n## Run Data\n");
|
||||
let mut files = Vec::new();
|
||||
|
||||
let current_dir = std::env::current_dir().unwrap();
|
||||
for file in std::fs::read_dir(&*RUN_DIR).unwrap() {
|
||||
let file = file.unwrap();
|
||||
if file.file_type().unwrap().is_dir() && deep {
|
||||
for file in std::fs::read_dir(file.path()).unwrap() {
|
||||
let path = file.unwrap().path();
|
||||
let path = path.strip_prefix(&current_dir).unwrap_or(&path);
|
||||
files.push(format!(
|
||||
"- {}/{}{}{}",
|
||||
path.parent().unwrap().display(),
|
||||
if use_color { "\x1b[34m" } else { "" },
|
||||
path.file_name().unwrap().display(),
|
||||
if use_color { "\x1b[0m" } else { "" },
|
||||
));
|
||||
}
|
||||
} else {
|
||||
let path = file.path();
|
||||
let path = path.strip_prefix(&current_dir).unwrap_or(&path);
|
||||
files.push(format!(
|
||||
"- {}/{}{}{}",
|
||||
path.parent().unwrap().display(),
|
||||
if use_color { "\x1b[34m" } else { "" },
|
||||
path.file_name().unwrap().display(),
|
||||
if use_color { "\x1b[0m" } else { "" }
|
||||
));
|
||||
}
|
||||
}
|
||||
files.sort();
|
||||
|
||||
for file in files {
|
||||
println!("{}", file);
|
||||
}
|
||||
|
||||
println!(
|
||||
"\n💡 Tip of the day: {} always points to the latest run\n",
|
||||
LATEST_EXAMPLE_RUN_DIR.display()
|
||||
);
|
||||
fn ensure_dir(path: &Path) -> PathBuf {
|
||||
std::fs::create_dir_all(path).expect("Failed to create directory");
|
||||
path.to_path_buf()
|
||||
}
|
||||
|
||||
@@ -1,374 +1,271 @@
|
||||
use crate::example::{ActualExcerpt, NamedExample};
|
||||
use crate::headless::ZetaCliAppState;
|
||||
use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
|
||||
use crate::{
|
||||
CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
|
||||
PredictionProvider, PromptFormat,
|
||||
anthropic_client::AnthropicClient,
|
||||
example::{Example, ExamplePrediction},
|
||||
format_prompt::{PromptParser, TeacherPrompt, run_format_prompt},
|
||||
headless::EpAppState,
|
||||
load_project::run_load_project,
|
||||
paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
|
||||
retrieve_context::run_context_retrieval,
|
||||
};
|
||||
use edit_prediction::{DebugEvent, EditPredictionStore};
|
||||
use futures::{FutureExt as _, StreamExt as _, future::Shared};
|
||||
use gpui::{AppContext as _, AsyncApp, Task};
|
||||
use std::{
|
||||
fs,
|
||||
sync::{
|
||||
Arc, Mutex, OnceLock,
|
||||
atomic::{AtomicUsize, Ordering::SeqCst},
|
||||
},
|
||||
};
|
||||
use ::serde::Serialize;
|
||||
use anyhow::{Context, Result, anyhow};
|
||||
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
|
||||
use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
|
||||
use futures::StreamExt as _;
|
||||
use gpui::{AppContext, AsyncApp, Entity};
|
||||
use project::Project;
|
||||
use project::buffer_store::BufferStoreEvent;
|
||||
use serde::Deserialize;
|
||||
use std::fs;
|
||||
use std::io::{IsTerminal, Write};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
pub async fn run_predict(
|
||||
args: PredictArguments,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
pub async fn run_prediction(
|
||||
example: &mut Example,
|
||||
provider: Option<PredictionProvider>,
|
||||
repetition_count: usize,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
let example = NamedExample::load(args.example_path).unwrap();
|
||||
let project = example.setup_project(app_state, cx).await.unwrap();
|
||||
let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
|
||||
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
|
||||
let result = perform_predict(example, project, store, None, args.options, cx)
|
||||
.await
|
||||
if !example.predictions.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await;
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await;
|
||||
|
||||
let provider = provider.unwrap();
|
||||
|
||||
if matches!(provider, PredictionProvider::Teacher) {
|
||||
if example.prompt.is_none() {
|
||||
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
|
||||
}
|
||||
|
||||
let batched = true;
|
||||
return predict_anthropic(example, repetition_count, batched).await;
|
||||
}
|
||||
|
||||
if matches!(
|
||||
provider,
|
||||
PredictionProvider::Zeta1 | PredictionProvider::Zeta2
|
||||
) {
|
||||
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
|
||||
AUTHENTICATED
|
||||
.get_or_init(|| {
|
||||
let client = app_state.client.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
client
|
||||
.sign_in_with_optional_connect(true, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
})
|
||||
.shared()
|
||||
})
|
||||
.clone()
|
||||
.await;
|
||||
}
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
result.write(args.format, std::io::stdout()).unwrap();
|
||||
|
||||
print_run_data_dir(true, std::io::stdout().is_terminal());
|
||||
}
|
||||
ep_store
|
||||
.update(&mut cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
|
||||
PredictionProvider::Teacher => unreachable!(),
|
||||
};
|
||||
store.set_edit_prediction_model(model);
|
||||
})
|
||||
.unwrap();
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let run_dir = RUN_DIR.join(&example.name);
|
||||
|
||||
pub fn setup_store(
|
||||
provider: PredictionProvider,
|
||||
project: &Entity<Project>,
|
||||
app_state: &Arc<ZetaCliAppState>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<Entity<EditPredictionStore>> {
|
||||
let store = cx.new(|cx| {
|
||||
edit_prediction::EditPredictionStore::new(
|
||||
app_state.client.clone(),
|
||||
app_state.user_store.clone(),
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
let updated_example = Arc::new(Mutex::new(example.clone()));
|
||||
let current_run_ix = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
store.update(cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
};
|
||||
store.set_edit_prediction_model(model);
|
||||
})?;
|
||||
let mut debug_rx = ep_store
|
||||
.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
|
||||
.unwrap();
|
||||
let debug_task = cx.background_spawn({
|
||||
let updated_example = updated_example.clone();
|
||||
let current_run_ix = current_run_ix.clone();
|
||||
let run_dir = run_dir.clone();
|
||||
async move {
|
||||
while let Some(event) = debug_rx.next().await {
|
||||
let run_ix = current_run_ix.load(SeqCst);
|
||||
let mut updated_example = updated_example.lock().unwrap();
|
||||
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
let run_dir = if repetition_count > 1 {
|
||||
run_dir.join(format!("{:03}", run_ix))
|
||||
} else {
|
||||
run_dir.clone()
|
||||
};
|
||||
|
||||
cx.subscribe(&buffer_store, {
|
||||
let project = project.clone();
|
||||
let store = store.clone();
|
||||
move |_, event, cx| match event {
|
||||
BufferStoreEvent::BufferAdded(buffer) => {
|
||||
store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})?
|
||||
.detach();
|
||||
match event {
|
||||
DebugEvent::EditPredictionStarted(request) => {
|
||||
assert_eq!(updated_example.predictions.len(), run_ix + 1);
|
||||
|
||||
anyhow::Ok(store)
|
||||
}
|
||||
|
||||
pub async fn perform_predict(
|
||||
example: NamedExample,
|
||||
project: Entity<Project>,
|
||||
store: Entity<EditPredictionStore>,
|
||||
repetition_ix: Option<u16>,
|
||||
options: PredictionOptions,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<PredictionDetails> {
|
||||
let mut cache_mode = options.cache;
|
||||
if repetition_ix.is_some() {
|
||||
if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
|
||||
panic!("Repetitions are not supported in Auto cache mode");
|
||||
} else {
|
||||
cache_mode = CacheMode::Skip;
|
||||
}
|
||||
} else if cache_mode == CacheMode::Auto {
|
||||
cache_mode = CacheMode::Requests;
|
||||
}
|
||||
|
||||
let mut example_run_dir = RUN_DIR.join(&example.file_name());
|
||||
if let Some(repetition_ix) = repetition_ix {
|
||||
example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
|
||||
}
|
||||
fs::create_dir_all(&example_run_dir)?;
|
||||
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
|
||||
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
|
||||
.context("creating latest link")?;
|
||||
|
||||
#[cfg(windows)]
|
||||
std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
|
||||
.context("creating latest link")?;
|
||||
|
||||
store.update(cx, |store, _cx| {
|
||||
store.with_eval_cache(Arc::new(RunCache {
|
||||
example_run_dir: example_run_dir.clone(),
|
||||
cache_mode,
|
||||
}));
|
||||
})?;
|
||||
|
||||
let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
|
||||
|
||||
let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
|
||||
|
||||
let prompt_format = options.zeta2.prompt_format;
|
||||
|
||||
store.update(cx, |store, _cx| {
|
||||
let mut options = store.options().clone();
|
||||
options.prompt_format = prompt_format.into();
|
||||
store.set_options(options);
|
||||
})?;
|
||||
|
||||
let mut debug_task = gpui::Task::ready(Ok(()));
|
||||
|
||||
if options.provider == crate::PredictionProvider::Zeta2 {
|
||||
let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
|
||||
|
||||
debug_task = cx.background_spawn({
|
||||
let result = result.clone();
|
||||
async move {
|
||||
let mut start_time = None;
|
||||
let mut retrieval_finished_at = None;
|
||||
while let Some(event) = debug_rx.next().await {
|
||||
match event {
|
||||
edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
|
||||
start_time = Some(info.timestamp);
|
||||
fs::write(
|
||||
example_run_dir.join("search_prompt.md"),
|
||||
&info.search_prompt,
|
||||
)?;
|
||||
if let Some(prompt) = request.prompt {
|
||||
fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
|
||||
}
|
||||
edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
|
||||
retrieval_finished_at = Some(info.timestamp);
|
||||
for (key, value) in &info.metadata {
|
||||
if *key == "search_queries" {
|
||||
fs::write(
|
||||
example_run_dir.join("search_queries.json"),
|
||||
value.as_bytes(),
|
||||
)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
DebugEvent::EditPredictionFinished(request) => {
|
||||
assert_eq!(updated_example.predictions.len(), run_ix + 1);
|
||||
|
||||
if let Some(output) = request.model_output {
|
||||
fs::write(run_dir.join("prediction_response.md"), &output)?;
|
||||
updated_example
|
||||
.predictions
|
||||
.last_mut()
|
||||
.unwrap()
|
||||
.actual_output = output;
|
||||
}
|
||||
edit_prediction::DebugEvent::EditPredictionRequested(request) => {
|
||||
let prediction_started_at = Instant::now();
|
||||
start_time.get_or_insert(prediction_started_at);
|
||||
let prompt = request.local_prompt.unwrap_or_default();
|
||||
fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
|
||||
|
||||
{
|
||||
let mut result = result.lock().unwrap();
|
||||
result.prompt_len = prompt.chars().count();
|
||||
|
||||
for included_file in request.inputs.included_files {
|
||||
let insertions =
|
||||
vec![(request.inputs.cursor_point, CURSOR_MARKER)];
|
||||
result.excerpts.extend(included_file.excerpts.iter().map(
|
||||
|excerpt| ActualExcerpt {
|
||||
path: included_file.path.components().skip(1).collect(),
|
||||
text: String::from(excerpt.text.as_ref()),
|
||||
},
|
||||
));
|
||||
write_codeblock(
|
||||
&included_file.path,
|
||||
included_file.excerpts.iter(),
|
||||
if included_file.path == request.inputs.cursor_path {
|
||||
&insertions
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
included_file.max_row,
|
||||
false,
|
||||
&mut result.excerpts_text,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let response =
|
||||
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
|
||||
let response =
|
||||
edit_prediction::open_ai_response::text_from_response(response)
|
||||
.unwrap_or_default();
|
||||
let prediction_finished_at = Instant::now();
|
||||
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
|
||||
|
||||
let mut result = result.lock().unwrap();
|
||||
result.generated_len = response.chars().count();
|
||||
result.retrieval_time =
|
||||
retrieval_finished_at.unwrap() - start_time.unwrap();
|
||||
result.prediction_time = prediction_finished_at - prediction_started_at;
|
||||
result.total_time = prediction_finished_at - start_time.unwrap();
|
||||
|
||||
if run_ix >= repetition_count {
|
||||
break;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
anyhow::Ok(())
|
||||
}
|
||||
});
|
||||
anyhow::Ok(())
|
||||
}
|
||||
});
|
||||
|
||||
store.update(cx, |store, cx| {
|
||||
store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
|
||||
})?;
|
||||
}
|
||||
|
||||
let prediction = store
|
||||
.update(cx, |store, cx| {
|
||||
store.request_prediction(
|
||||
&project,
|
||||
&cursor_buffer,
|
||||
cursor_anchor,
|
||||
cloud_llm_client::PredictEditsRequestTrigger::Cli,
|
||||
cx,
|
||||
)
|
||||
})?
|
||||
.await?;
|
||||
|
||||
debug_task.await?;
|
||||
|
||||
let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
|
||||
|
||||
result.diff = prediction
|
||||
.and_then(|prediction| {
|
||||
let prediction = prediction.prediction.ok()?;
|
||||
prediction.edit_preview.as_unified_diff(&prediction.edits)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
anyhow::Ok(result)
|
||||
}
|
||||
|
||||
struct RunCache {
|
||||
cache_mode: CacheMode,
|
||||
example_run_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl RunCache {
|
||||
fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
|
||||
CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
|
||||
}
|
||||
|
||||
fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
|
||||
CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
|
||||
}
|
||||
|
||||
fn link_to_run(&self, key: &EvalCacheKey) {
|
||||
let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
|
||||
fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
|
||||
|
||||
let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
|
||||
fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
impl EvalCache for RunCache {
|
||||
fn read(&self, key: EvalCacheKey) -> Option<String> {
|
||||
let path = RunCache::output_cache_path(&key);
|
||||
|
||||
if path.exists() {
|
||||
let use_cache = match key.0 {
|
||||
EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
|
||||
EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
|
||||
self.cache_mode.use_cached_llm_responses()
|
||||
}
|
||||
};
|
||||
if use_cache {
|
||||
log::info!("Using cache entry: {}", path.display());
|
||||
self.link_to_run(&key);
|
||||
Some(fs::read_to_string(path).unwrap())
|
||||
} else {
|
||||
log::trace!("Skipping cached entry: {}", path.display());
|
||||
None
|
||||
}
|
||||
} else if matches!(self.cache_mode, CacheMode::Force) {
|
||||
panic!(
|
||||
"No cached entry found for {:?}. Run without `--cache force` at least once.",
|
||||
key.0
|
||||
);
|
||||
for ix in 0..repetition_count {
|
||||
current_run_ix.store(ix, SeqCst);
|
||||
let run_dir = if repetition_count > 1 {
|
||||
run_dir.join(format!("{:03}", ix))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
|
||||
fs::create_dir_all(&*CACHE_DIR).unwrap();
|
||||
|
||||
let input_path = RunCache::input_cache_path(&key);
|
||||
fs::write(&input_path, input).unwrap();
|
||||
|
||||
let output_path = RunCache::output_cache_path(&key);
|
||||
log::trace!("Writing cache entry: {}", output_path.display());
|
||||
fs::write(&output_path, output).unwrap();
|
||||
|
||||
self.link_to_run(&key);
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct PredictionDetails {
|
||||
pub diff: String,
|
||||
pub excerpts: Vec<ActualExcerpt>,
|
||||
pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
|
||||
pub retrieval_time: Duration,
|
||||
pub prediction_time: Duration,
|
||||
pub total_time: Duration,
|
||||
pub run_example_dir: PathBuf,
|
||||
pub prompt_len: usize,
|
||||
pub generated_len: usize,
|
||||
}
|
||||
|
||||
impl PredictionDetails {
|
||||
pub fn new(run_example_dir: PathBuf) -> Self {
|
||||
Self {
|
||||
diff: Default::default(),
|
||||
excerpts: Default::default(),
|
||||
excerpts_text: Default::default(),
|
||||
retrieval_time: Default::default(),
|
||||
prediction_time: Default::default(),
|
||||
total_time: Default::default(),
|
||||
run_example_dir,
|
||||
prompt_len: 0,
|
||||
generated_len: 0,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
|
||||
let formatted = match format {
|
||||
PredictionsOutputFormat::Md => self.to_markdown(),
|
||||
PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
|
||||
PredictionsOutputFormat::Diff => self.diff.clone(),
|
||||
run_dir.clone()
|
||||
};
|
||||
|
||||
Ok(out.write_all(formatted.as_bytes())?)
|
||||
fs::create_dir_all(&run_dir).unwrap();
|
||||
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
|
||||
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
}
|
||||
#[cfg(unix)]
|
||||
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
#[cfg(windows)]
|
||||
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
|
||||
updated_example
|
||||
.lock()
|
||||
.unwrap()
|
||||
.predictions
|
||||
.push(ExamplePrediction {
|
||||
actual_patch: String::new(),
|
||||
actual_output: String::new(),
|
||||
provider,
|
||||
});
|
||||
|
||||
let prediction = ep_store
|
||||
.update(&mut cx, |store, cx| {
|
||||
store.request_prediction(
|
||||
&state.project,
|
||||
&state.buffer,
|
||||
state.cursor_position,
|
||||
cloud_llm_client::PredictEditsRequestTrigger::Cli,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
updated_example
|
||||
.lock()
|
||||
.unwrap()
|
||||
.predictions
|
||||
.last_mut()
|
||||
.unwrap()
|
||||
.actual_patch = prediction
|
||||
.and_then(|prediction| {
|
||||
let prediction = prediction.prediction.ok()?;
|
||||
prediction.edit_preview.as_unified_diff(&prediction.edits)
|
||||
})
|
||||
.unwrap_or_default();
|
||||
}
|
||||
|
||||
pub fn to_markdown(&self) -> String {
|
||||
format!(
|
||||
"## Excerpts\n\n\
|
||||
{}\n\n\
|
||||
## Prediction\n\n\
|
||||
{}\n\n\
|
||||
## Time\n\n\
|
||||
Retrieval: {}ms\n\
|
||||
Prediction: {}ms\n\n\
|
||||
Total: {}ms\n",
|
||||
self.excerpts_text,
|
||||
self.diff,
|
||||
self.retrieval_time.as_millis(),
|
||||
self.prediction_time.as_millis(),
|
||||
self.total_time.as_millis(),
|
||||
)
|
||||
ep_store
|
||||
.update(&mut cx, |store, _| {
|
||||
store.remove_project(&state.project);
|
||||
})
|
||||
.unwrap();
|
||||
debug_task.await.unwrap();
|
||||
|
||||
*example = Arc::into_inner(updated_example)
|
||||
.unwrap()
|
||||
.into_inner()
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
|
||||
let llm_model_name = "claude-sonnet-4-5";
|
||||
let max_tokens = 16384;
|
||||
let llm_client = if batched {
|
||||
AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
|
||||
} else {
|
||||
AnthropicClient::plain()
|
||||
};
|
||||
let llm_client = llm_client.expect("Failed to create LLM client");
|
||||
|
||||
let prompt = example
|
||||
.prompt
|
||||
.as_ref()
|
||||
.unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
|
||||
|
||||
let messages = vec![anthropic::Message {
|
||||
role: anthropic::Role::User,
|
||||
content: vec![anthropic::RequestContent::Text {
|
||||
text: prompt.input.clone(),
|
||||
cache_control: None,
|
||||
}],
|
||||
}];
|
||||
|
||||
let Some(response) = llm_client
|
||||
.generate(llm_model_name, max_tokens, messages)
|
||||
.await
|
||||
.unwrap()
|
||||
else {
|
||||
// Request stashed for batched processing
|
||||
return;
|
||||
};
|
||||
|
||||
let actual_output = response
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(|content| match content {
|
||||
anthropic::ResponseContent::Text { text } => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
let actual_patch = TeacherPrompt::parse(example, &actual_output);
|
||||
|
||||
let prediction = ExamplePrediction {
|
||||
actual_patch,
|
||||
actual_output,
|
||||
provider: PredictionProvider::Teacher,
|
||||
};
|
||||
|
||||
example.predictions.push(prediction);
|
||||
}
|
||||
|
||||
pub async fn sync_batches(provider: &PredictionProvider) {
|
||||
match provider {
|
||||
PredictionProvider::Teacher => {
|
||||
let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
|
||||
let llm_client =
|
||||
AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
|
||||
llm_client
|
||||
.sync_batches()
|
||||
.await
|
||||
.expect("Failed to sync batches");
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,106 +1,136 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use futures::channel::mpsc;
|
||||
use futures::{FutureExt as _, StreamExt as _};
|
||||
use crate::{
|
||||
example::{Example, ExampleContext},
|
||||
headless::EpAppState,
|
||||
load_project::run_load_project,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use collections::HashSet;
|
||||
use edit_prediction::{DebugEvent, EditPredictionStore};
|
||||
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
|
||||
use gpui::{AsyncApp, Entity, Task};
|
||||
use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
|
||||
use project::lsp_store::OpenLspBufferHandle;
|
||||
use project::{Project, ProjectPath, Worktree};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use util::rel_path::RelPath;
|
||||
use language::{Buffer, LanguageNotFound};
|
||||
use project::Project;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
pub fn open_buffer(
|
||||
project: Entity<Project>,
|
||||
worktree: Entity<Worktree>,
|
||||
path: Arc<RelPath>,
|
||||
cx: &AsyncApp,
|
||||
) -> Task<Result<Entity<Buffer>>> {
|
||||
cx.spawn(async move |cx| {
|
||||
let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
|
||||
worktree_id: worktree.id(),
|
||||
path,
|
||||
})?;
|
||||
pub async fn run_context_retrieval(
|
||||
example: &mut Example,
|
||||
app_state: Arc<EpAppState>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
if example.context.is_some() {
|
||||
return;
|
||||
}
|
||||
|
||||
let buffer = project
|
||||
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
|
||||
.await?;
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await;
|
||||
|
||||
let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
|
||||
while *parse_status.borrow() != ParseStatus::Idle {
|
||||
parse_status.changed().await?;
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let project = state.project.clone();
|
||||
|
||||
let _lsp_handle = project
|
||||
.update(&mut cx, |project, cx| {
|
||||
project.register_buffer_with_language_servers(&state.buffer, cx)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await;
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
|
||||
let mut events = ep_store
|
||||
.update(&mut cx, |store, cx| {
|
||||
store.register_buffer(&state.buffer, &project, cx);
|
||||
store.set_use_context(true);
|
||||
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
|
||||
store.debug_info(&project, cx)
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
DebugEvent::ContextRetrievalFinished(_) => {
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(buffer)
|
||||
})
|
||||
let context_files = ep_store
|
||||
.update(&mut cx, |store, cx| store.context_for_project(&project, cx))
|
||||
.unwrap();
|
||||
|
||||
example.context = Some(ExampleContext {
|
||||
files: context_files,
|
||||
});
|
||||
}
|
||||
|
||||
pub async fn open_buffer_with_language_server(
|
||||
project: Entity<Project>,
|
||||
worktree: Entity<Worktree>,
|
||||
path: Arc<RelPath>,
|
||||
ready_languages: &mut HashSet<LanguageId>,
|
||||
async fn wait_for_language_server_to_start(
|
||||
example: &Example,
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
|
||||
let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
|
||||
|
||||
let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
|
||||
(
|
||||
project.register_buffer_with_language_servers(&buffer, cx),
|
||||
project.path_style(cx),
|
||||
)
|
||||
})?;
|
||||
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
|
||||
) {
|
||||
let language_registry = project
|
||||
.read_with(cx, |project, _| project.languages().clone())
|
||||
.unwrap();
|
||||
let result = language_registry
|
||||
.load_language_for_file_path(path.as_std_path())
|
||||
.load_language_for_file_path(&example.cursor_path)
|
||||
.await;
|
||||
|
||||
if let Err(error) = result
|
||||
&& !error.is::<LanguageNotFound>()
|
||||
{
|
||||
anyhow::bail!(error);
|
||||
panic!("Failed to load language for file path: {}", error);
|
||||
}
|
||||
|
||||
let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
|
||||
buffer.language().map(|language| language.id())
|
||||
})?
|
||||
let Some(language_id) = buffer
|
||||
.read_with(cx, |buffer, _cx| {
|
||||
buffer.language().map(|language| language.id())
|
||||
})
|
||||
.unwrap()
|
||||
else {
|
||||
return Err(anyhow!("No language for {}", path.display(path_style)));
|
||||
panic!("No language for {:?}", example.cursor_path);
|
||||
};
|
||||
|
||||
let log_prefix = format!("{} | ", path.display(path_style));
|
||||
let mut ready_languages = HashSet::default();
|
||||
let log_prefix = format!("{} | ", example.name);
|
||||
if !ready_languages.contains(&language_id) {
|
||||
wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
|
||||
wait_for_lang_server(&project, &buffer, log_prefix, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
ready_languages.insert(language_id);
|
||||
}
|
||||
|
||||
let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
|
||||
let lsp_store = project
|
||||
.read_with(cx, |project, _cx| project.lsp_store())
|
||||
.unwrap();
|
||||
|
||||
// hacky wait for buffer to be registered with the language server
|
||||
for _ in 0..100 {
|
||||
let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
lsp_store
|
||||
.language_servers_for_local_buffer(&buffer, cx)
|
||||
.next()
|
||||
.map(|(_, language_server)| language_server.server_id())
|
||||
if lsp_store
|
||||
.update(cx, |lsp_store, cx| {
|
||||
buffer.update(cx, |buffer, cx| {
|
||||
lsp_store
|
||||
.language_servers_for_local_buffer(&buffer, cx)
|
||||
.next()
|
||||
.map(|(_, language_server)| language_server.server_id())
|
||||
})
|
||||
})
|
||||
})?
|
||||
else {
|
||||
.unwrap()
|
||||
.is_some()
|
||||
{
|
||||
return;
|
||||
} else {
|
||||
cx.background_executor()
|
||||
.timer(Duration::from_millis(10))
|
||||
.await;
|
||||
continue;
|
||||
};
|
||||
|
||||
return Ok((lsp_open_handle, language_server_id, buffer));
|
||||
}
|
||||
}
|
||||
|
||||
return Err(anyhow!("No language server found for buffer"));
|
||||
panic!("No language server found for buffer");
|
||||
}
|
||||
|
||||
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
|
||||
pub fn wait_for_lang_server(
|
||||
project: &Entity<Project>,
|
||||
buffer: &Entity<Buffer>,
|
||||
119
crates/edit_prediction_cli/src/score.rs
Normal file
119
crates/edit_prediction_cli/src/score.rs
Normal file
@@ -0,0 +1,119 @@
|
||||
use crate::{
|
||||
PredictArgs,
|
||||
example::{Example, ExampleScore},
|
||||
headless::EpAppState,
|
||||
metrics::{self, ClassificationMetrics},
|
||||
predict::run_prediction,
|
||||
};
|
||||
use edit_prediction::udiff::DiffLine;
|
||||
use gpui::AsyncApp;
|
||||
use std::sync::Arc;
|
||||
|
||||
pub async fn run_scoring(
|
||||
example: &mut Example,
|
||||
args: &PredictArgs,
|
||||
app_state: Arc<EpAppState>,
|
||||
cx: AsyncApp,
|
||||
) {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
app_state,
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
|
||||
let expected_patch = parse_patch(&example.expected_patch);
|
||||
|
||||
let mut scores = vec![];
|
||||
|
||||
for pred in &example.predictions {
|
||||
let actual_patch = parse_patch(&pred.actual_patch);
|
||||
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
|
||||
let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
|
||||
|
||||
scores.push(ExampleScore {
|
||||
delta_chr_f,
|
||||
line_match,
|
||||
});
|
||||
}
|
||||
|
||||
example.score = scores;
|
||||
}
|
||||
|
||||
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
|
||||
patch.lines().map(DiffLine::parse).collect()
|
||||
}
|
||||
|
||||
pub fn print_report(examples: &[Example]) {
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
eprintln!(
|
||||
"{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
|
||||
"Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
|
||||
);
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
|
||||
let mut all_line_match_scores = Vec::new();
|
||||
let mut all_delta_chr_f_scores = Vec::new();
|
||||
|
||||
for example in examples {
|
||||
for score in example.score.iter() {
|
||||
let line_match = &score.line_match;
|
||||
|
||||
eprintln!(
|
||||
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
|
||||
truncate_name(&example.name, 30),
|
||||
line_match.true_positives,
|
||||
line_match.false_positives,
|
||||
line_match.false_negatives,
|
||||
line_match.precision() * 100.0,
|
||||
line_match.recall() * 100.0,
|
||||
line_match.f1_score() * 100.0,
|
||||
score.delta_chr_f
|
||||
);
|
||||
|
||||
all_line_match_scores.push(line_match.clone());
|
||||
all_delta_chr_f_scores.push(score.delta_chr_f);
|
||||
}
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
|
||||
if !all_line_match_scores.is_empty() {
|
||||
let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
|
||||
let avg_delta_chr_f: f32 =
|
||||
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
|
||||
|
||||
eprintln!(
|
||||
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
|
||||
"TOTAL",
|
||||
total_line_match.true_positives,
|
||||
total_line_match.false_positives,
|
||||
total_line_match.false_negatives,
|
||||
total_line_match.precision() * 100.0,
|
||||
total_line_match.recall() * 100.0,
|
||||
total_line_match.f1_score() * 100.0,
|
||||
avg_delta_chr_f
|
||||
);
|
||||
eprintln!(
|
||||
"──────────────────────────────────────────────────────────────────────────────────────"
|
||||
);
|
||||
}
|
||||
|
||||
eprintln!("\n");
|
||||
}
|
||||
|
||||
/// Shortens `name` to at most `max_len` characters for table display.
///
/// Names that already fit are returned unchanged. Longer names keep their
/// first `max_len - 3` characters followed by `"..."`. When `max_len` is too
/// small to hold the ellipsis (`max_len < 3`), the name is simply cut to
/// `max_len` characters.
///
/// Lengths are measured in `char`s rather than bytes: formatting widths pad by
/// characters, and slicing a `&str` at an arbitrary byte offset panics when the
/// offset falls inside a multi-byte character.
fn truncate_name(name: &str, max_len: usize) -> String {
    const ELLIPSIS: &str = "...";

    if name.chars().count() <= max_len {
        return name.to_string();
    }

    // Not enough room for the ellipsis: avoid the `max_len - 3` underflow.
    if max_len < ELLIPSIS.len() {
        return name.chars().take(max_len).collect();
    }

    let kept = max_len - ELLIPSIS.len();
    let mut truncated: String = name.chars().take(kept).collect();
    truncated.push_str(ELLIPSIS);
    truncated
}
|
||||
@@ -1,70 +0,0 @@
|
||||
use std::{fmt, fmt::Display, path::Path, str::FromStr, sync::Arc};
|
||||
|
||||
use ::util::{paths::PathStyle, rel_path::RelPath};
|
||||
use anyhow::{Result, anyhow};
|
||||
use language::Point;
|
||||
use serde::{Deserialize, Deserializer, Serialize, Serializer};
|
||||
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct SourceLocation {
|
||||
pub path: Arc<RelPath>,
|
||||
pub point: Point,
|
||||
}
|
||||
|
||||
impl Serialize for SourceLocation {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
serializer.serialize_str(&self.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for SourceLocation {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
let s = String::deserialize(deserializer)?;
|
||||
s.parse().map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for SourceLocation {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(
|
||||
f,
|
||||
"{}:{}:{}",
|
||||
self.path.display(PathStyle::Posix),
|
||||
self.point.row + 1,
|
||||
self.point.column + 1
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for SourceLocation {
|
||||
type Err = anyhow::Error;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self> {
|
||||
let parts: Vec<&str> = s.split(':').collect();
|
||||
if parts.len() != 3 {
|
||||
return Err(anyhow!(
|
||||
"Invalid source location. Expected 'file.rs:line:column', got '{}'",
|
||||
s
|
||||
));
|
||||
}
|
||||
|
||||
let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
|
||||
let line: u32 = parts[1]
|
||||
.parse()
|
||||
.map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
|
||||
let column: u32 = parts[2]
|
||||
.parse()
|
||||
.map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
|
||||
|
||||
// Convert from 1-based to 0-based indexing
|
||||
let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
|
||||
|
||||
Ok(SourceLocation { path, point })
|
||||
}
|
||||
}
|
||||
@@ -46,3 +46,7 @@ Output example:
|
||||
## Code Context
|
||||
|
||||
{{context}}
|
||||
|
||||
## Editable region
|
||||
|
||||
{{editable_region}}
|
||||
@@ -1,89 +0,0 @@
|
||||
use std::path::Path;
|
||||
|
||||
use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
|
||||
|
||||
#[derive(Debug, Clone, Default, clap::ValueEnum)]
|
||||
pub enum ContextType {
|
||||
#[default]
|
||||
CurrentFile,
|
||||
}
|
||||
|
||||
const MAX_CONTEXT_SIZE: usize = 32768;
|
||||
|
||||
pub fn collect_context(
|
||||
context_type: &ContextType,
|
||||
worktree_dir: &Path,
|
||||
cursor: SourceLocation,
|
||||
) -> String {
|
||||
let context = match context_type {
|
||||
ContextType::CurrentFile => {
|
||||
let file_path = worktree_dir.join(cursor.path.as_std_path());
|
||||
let context = std::fs::read_to_string(&file_path).unwrap_or_default();
|
||||
|
||||
let context = add_special_tags(&context, worktree_dir, cursor);
|
||||
context
|
||||
}
|
||||
};
|
||||
|
||||
let region_end_offset = context.find(TeacherModel::REGION_END);
|
||||
|
||||
if context.len() <= MAX_CONTEXT_SIZE {
|
||||
return context;
|
||||
}
|
||||
|
||||
if let Some(region_end_offset) = region_end_offset
|
||||
&& region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
|
||||
{
|
||||
let to_truncate = context.len() - MAX_CONTEXT_SIZE;
|
||||
format!(
|
||||
"[...{} bytes truncated]\n{}\n",
|
||||
to_truncate,
|
||||
&context[to_truncate..]
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"{}\n[...{} bytes truncated]\n",
|
||||
&context[..MAX_CONTEXT_SIZE],
|
||||
context.len() - MAX_CONTEXT_SIZE
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Add <|editable_region_start/end|> tags
|
||||
fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
|
||||
let path = worktree_dir.join(cursor.path.as_std_path());
|
||||
let file = std::fs::read_to_string(&path).unwrap_or_default();
|
||||
let lines = file.lines().collect::<Vec<_>>();
|
||||
let cursor_row = cursor.point.row as usize;
|
||||
let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
|
||||
let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
|
||||
|
||||
let snippet = lines[start_line..end_line].join("\n");
|
||||
|
||||
if context.contains(&snippet) {
|
||||
let mut cursor_line = lines[cursor_row].to_string();
|
||||
cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
|
||||
|
||||
let mut snippet_with_tags_lines = vec![];
|
||||
snippet_with_tags_lines.push(TeacherModel::REGION_START);
|
||||
snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
|
||||
snippet_with_tags_lines.push(&cursor_line);
|
||||
snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
|
||||
snippet_with_tags_lines.push(TeacherModel::REGION_END);
|
||||
let snippet_with_tags = snippet_with_tags_lines.join("\n");
|
||||
|
||||
context.replace(&snippet, &snippet_with_tags)
|
||||
} else {
|
||||
log::warn!(
|
||||
"Can't find area around the cursor in the context; proceeding without special tags"
|
||||
);
|
||||
context.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn strip_special_tags(context: &str) -> String {
|
||||
context
|
||||
.replace(TeacherModel::REGION_START, "")
|
||||
.replace(TeacherModel::REGION_END, "")
|
||||
.replace(TeacherModel::USER_CURSOR, "")
|
||||
}
|
||||
@@ -1,94 +0,0 @@
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::{
|
||||
DistillArguments,
|
||||
example::Example,
|
||||
source_location::SourceLocation,
|
||||
training::{
|
||||
context::ContextType,
|
||||
llm_client::LlmClient,
|
||||
teacher::{TeacherModel, TeacherOutput},
|
||||
},
|
||||
};
|
||||
use anyhow::Result;
|
||||
use reqwest_client::ReqwestClient;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SplitCommit {
|
||||
repo_url: String,
|
||||
commit_sha: String,
|
||||
edit_history: String,
|
||||
expected_patch: String,
|
||||
cursor_position: String,
|
||||
}
|
||||
|
||||
pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
|
||||
let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
|
||||
.expect("Failed to read split commit dataset")
|
||||
.lines()
|
||||
.map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
|
||||
.collect();
|
||||
|
||||
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
|
||||
|
||||
let llm_client = if let Some(cache_path) = arguments.batch {
|
||||
LlmClient::batch(&cache_path, http_client)?
|
||||
} else {
|
||||
LlmClient::plain(http_client)?
|
||||
};
|
||||
|
||||
let mut teacher = TeacherModel::new(
|
||||
"claude-sonnet-4-5".to_string(),
|
||||
ContextType::CurrentFile,
|
||||
llm_client,
|
||||
);
|
||||
|
||||
let mut num_marked_for_batching = 0;
|
||||
|
||||
for commit in split_commits {
|
||||
if let Some(distilled) = distill_one(&mut teacher, commit).await? {
|
||||
println!("{}", serde_json::to_string(&distilled)?);
|
||||
} else {
|
||||
if num_marked_for_batching == 0 {
|
||||
log::warn!("Marked for batching");
|
||||
}
|
||||
num_marked_for_batching += 1;
|
||||
}
|
||||
}
|
||||
|
||||
eprintln!(
|
||||
"{} requests are marked for batching",
|
||||
num_marked_for_batching
|
||||
);
|
||||
let llm_client = teacher.client;
|
||||
llm_client.sync_batches().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn distill_one(
|
||||
teacher: &mut TeacherModel,
|
||||
commit: SplitCommit,
|
||||
) -> Result<Option<TeacherOutput>> {
|
||||
let cursor: SourceLocation = commit
|
||||
.cursor_position
|
||||
.parse()
|
||||
.expect("Failed to parse cursor position");
|
||||
|
||||
let path = cursor.path.to_rel_path_buf();
|
||||
|
||||
let example = Example {
|
||||
repository_url: commit.repo_url,
|
||||
revision: commit.commit_sha,
|
||||
uncommitted_diff: commit.edit_history.clone(),
|
||||
cursor_path: path.as_std_path().to_path_buf(),
|
||||
cursor_position: commit.cursor_position,
|
||||
edit_history: commit.edit_history, // todo: trim
|
||||
expected_patch: commit.expected_patch,
|
||||
};
|
||||
|
||||
let prediction = teacher.predict(example).await;
|
||||
|
||||
prediction
|
||||
}
|
||||
@@ -1,4 +0,0 @@
|
||||
pub mod context;
|
||||
pub mod distill;
|
||||
pub mod llm_client;
|
||||
pub mod teacher;
|
||||
@@ -1,266 +0,0 @@
|
||||
use crate::{
|
||||
example::Example,
|
||||
source_location::SourceLocation,
|
||||
training::{
|
||||
context::{ContextType, collect_context, strip_special_tags},
|
||||
llm_client::LlmClient,
|
||||
},
|
||||
};
|
||||
use anthropic::{Message, RequestContent, ResponseContent, Role};
|
||||
use anyhow::Result;
|
||||
|
||||
pub struct TeacherModel {
|
||||
pub llm_name: String,
|
||||
pub context: ContextType,
|
||||
pub client: LlmClient,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Serialize)]
|
||||
pub struct TeacherOutput {
|
||||
parsed_output: String,
|
||||
prompt: String,
|
||||
raw_llm_response: String,
|
||||
context: String,
|
||||
diff: String,
|
||||
}
|
||||
|
||||
impl TeacherModel {
|
||||
const PROMPT: &str = include_str!("teacher.prompt.md");
|
||||
pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
|
||||
pub(crate) const REGION_END: &str = "<|editable_region_end|>";
|
||||
pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
|
||||
|
||||
/// Number of lines to include before the cursor position
|
||||
pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
|
||||
|
||||
/// Number of lines to include after the cursor position
|
||||
pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
|
||||
|
||||
/// Truncate edit history to this number of last lines
|
||||
const MAX_HISTORY_LINES: usize = 128;
|
||||
|
||||
pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
|
||||
TeacherModel {
|
||||
llm_name,
|
||||
context,
|
||||
client,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
|
||||
let name = input.unique_name();
|
||||
let worktree_dir = input.setup_worktree(name).await?;
|
||||
let cursor: SourceLocation = input
|
||||
.cursor_position
|
||||
.parse()
|
||||
.expect("Failed to parse cursor position");
|
||||
|
||||
let context = collect_context(&self.context, &worktree_dir, cursor.clone());
|
||||
let edit_history = Self::format_edit_history(&input.edit_history);
|
||||
|
||||
let prompt = Self::PROMPT
|
||||
.replace("{{context}}", &context)
|
||||
.replace("{{edit_history}}", &edit_history);
|
||||
|
||||
let messages = vec![Message {
|
||||
role: Role::User,
|
||||
content: vec![RequestContent::Text {
|
||||
text: prompt.clone(),
|
||||
cache_control: None,
|
||||
}],
|
||||
}];
|
||||
|
||||
let Some(response) = self
|
||||
.client
|
||||
.generate(self.llm_name.clone(), 16384, messages)
|
||||
.await?
|
||||
else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let response_text = response
|
||||
.content
|
||||
.into_iter()
|
||||
.filter_map(|content| match content {
|
||||
ResponseContent::Text { text } => Some(text),
|
||||
_ => None,
|
||||
})
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
let parsed_output = self.parse_response(&response_text);
|
||||
|
||||
let original_editable_region = Self::extract_editable_region(&context);
|
||||
let context_after_edit = context.replace(&original_editable_region, &parsed_output);
|
||||
let context_after_edit = strip_special_tags(&context_after_edit);
|
||||
let context_before_edit = strip_special_tags(&context);
|
||||
let diff = language::unified_diff(&context_before_edit, &context_after_edit);
|
||||
|
||||
// zeta distill --batch batch_results.txt
|
||||
// zeta distill
|
||||
// 1. Run `zeta distill <2000 examples <- all examples>` for the first time
|
||||
// - store LLM requests in a batch, don't actual send the request
|
||||
// - send the batch (2000 requests) after all inputs are processed
|
||||
// 2. `zeta send-batches`
|
||||
// - upload the batch to Anthropic
|
||||
|
||||
// https://platform.claude.com/docs/en/build-with-claude/batch-processing
|
||||
// https://crates.io/crates/anthropic-sdk-rust
|
||||
|
||||
// - poll for results
|
||||
// - when ready, store results in cache (a database)
|
||||
// 3. `zeta distill` again
|
||||
// - use the cached results this time
|
||||
|
||||
Ok(Some(TeacherOutput {
|
||||
parsed_output,
|
||||
prompt,
|
||||
raw_llm_response: response_text,
|
||||
context,
|
||||
diff,
|
||||
}))
|
||||
}
|
||||
|
||||
fn parse_response(&self, content: &str) -> String {
|
||||
let codeblock = Self::extract_last_codeblock(content);
|
||||
let editable_region = Self::extract_editable_region(&codeblock);
|
||||
|
||||
editable_region
|
||||
}
|
||||
|
||||
/// Extract content from the last code-fenced block if any, or else return content as is
///
/// A fence is a run of three or more backticks; its block ends at the next
/// occurrence of a backtick run of the same length. The remainder of the
/// opening fence line (e.g. a language tag such as `rust`) is not part of the
/// content, and the single newline preceding the closing fence is dropped.
/// A fence opened and closed on one line yields the text between the fences.
/// An unterminated fence is ignored.
fn extract_last_codeblock(text: &str) -> String {
    let mut last_block = None;
    let mut search_start = 0;

    while let Some(found) = text[search_start..].find("```") {
        let fence_start = search_start + found;
        // Backticks are ASCII, so counting bytes keeps all offsets on char boundaries.
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let fence_end = fence_start + fence_len;
        let fence = &text[fence_start..fence_end];

        let after_fence = &text[fence_end..];
        let Some(close_offset) = after_fence.find(fence) else {
            break;
        };
        let inner = &after_fence[..close_offset];

        // Skip the rest of the opening line (info string), if there is one.
        let body = match inner.find('\n') {
            Some(newline) => &inner[newline + 1..],
            None => inner,
        };
        let body = body.strip_suffix('\n').unwrap_or(body);

        last_block = Some(body.to_string());
        search_start = fence_end + close_offset + fence_len;
    }

    last_block.unwrap_or_else(|| text.to_string())
}
|
||||
|
||||
fn extract_editable_region(text: &str) -> String {
|
||||
let start = text
|
||||
.find(Self::REGION_START)
|
||||
.map_or(0, |pos| pos + Self::REGION_START.len());
|
||||
let end = text.find(Self::REGION_END).unwrap_or(text.len());
|
||||
|
||||
text[start..end].to_string()
|
||||
}
|
||||
|
||||
/// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
|
||||
fn format_edit_history(edit_history: &str) -> String {
|
||||
let lines = edit_history
|
||||
.lines()
|
||||
.filter(|&s| Self::is_content_line(s))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
|
||||
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
|
||||
} else {
|
||||
&lines
|
||||
};
|
||||
history_lines.join("\n")
|
||||
}
|
||||
|
||||
/// Returns `true` when `s` starts with one of the unified-diff prefixes
/// `-`, `+`, space, `---`, `+++` or `@@`.
fn is_content_line(s: &str) -> bool {
    const PREFIXES: [&str; 6] = ["-", "+", " ", "---", "+++", "@@"];
    PREFIXES.into_iter().any(|prefix| s.starts_with(prefix))
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_response() {
|
||||
let teacher = TeacherModel::new(
|
||||
"test".to_string(),
|
||||
ContextType::CurrentFile,
|
||||
LlmClient::dummy(),
|
||||
);
|
||||
let response = "This is a test response.";
|
||||
let parsed = teacher.parse_response(response);
|
||||
assert_eq!(parsed, response.to_string());
|
||||
|
||||
let response = indoc::indoc! {"
|
||||
Some thinking
|
||||
|
||||
`````
|
||||
actual response
|
||||
`````
|
||||
"};
|
||||
let parsed = teacher.parse_response(response);
|
||||
assert_eq!(parsed, "actual response");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_last_code_block() {
|
||||
let text = indoc::indoc! {"
|
||||
Some thinking
|
||||
|
||||
```
|
||||
first block
|
||||
```
|
||||
|
||||
`````
|
||||
last block
|
||||
`````
|
||||
"};
|
||||
let last_block = TeacherModel::extract_last_codeblock(text);
|
||||
assert_eq!(last_block, "last block");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_extract_editable_region() {
|
||||
let teacher = TeacherModel::new(
|
||||
"test".to_string(),
|
||||
ContextType::CurrentFile,
|
||||
LlmClient::dummy(),
|
||||
);
|
||||
let response = indoc::indoc! {"
|
||||
some lines
|
||||
are
|
||||
here
|
||||
<|editable_region_start|>
|
||||
one
|
||||
two three
|
||||
|
||||
<|editable_region_end|>
|
||||
more
|
||||
lines here
|
||||
"};
|
||||
let parsed = teacher.parse_response(response);
|
||||
assert_eq!(
|
||||
parsed,
|
||||
indoc::indoc! {"
|
||||
one
|
||||
two three
|
||||
|
||||
"}
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -26,6 +26,7 @@ serde.workspace = true
|
||||
smallvec.workspace = true
|
||||
tree-sitter.workspace = true
|
||||
util.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
env_logger.workspace = true
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use crate::RelatedExcerpt;
|
||||
use language::{BufferSnapshot, OffsetRangeExt as _, Point};
|
||||
use std::ops::Range;
|
||||
use zeta_prompt::RelatedExcerpt;
|
||||
|
||||
#[cfg(not(test))]
|
||||
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
|
||||
@@ -76,14 +76,9 @@ pub fn assemble_excerpts(
|
||||
|
||||
input_ranges
|
||||
.into_iter()
|
||||
.map(|range| {
|
||||
let offset_range = range.to_offset(buffer);
|
||||
RelatedExcerpt {
|
||||
point_range: range,
|
||||
anchor_range: buffer.anchor_before(offset_range.start)
|
||||
..buffer.anchor_after(offset_range.end),
|
||||
text: buffer.as_rope().slice(offset_range),
|
||||
}
|
||||
.map(|range| RelatedExcerpt {
|
||||
row_range: range.start.row..range.end.row,
|
||||
text: buffer.text_for_range(range).collect(),
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
@@ -3,13 +3,13 @@ use anyhow::Result;
|
||||
use collections::HashMap;
|
||||
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
|
||||
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
|
||||
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
|
||||
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
|
||||
use project::{LocationLink, Project, ProjectPath};
|
||||
use serde::{Serialize, Serializer};
|
||||
use smallvec::SmallVec;
|
||||
use std::{
|
||||
collections::hash_map,
|
||||
ops::Range,
|
||||
path::Path,
|
||||
sync::Arc,
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
@@ -24,12 +24,14 @@ mod fake_definition_lsp;
|
||||
|
||||
pub use cloud_llm_client::predict_edits_v3::Line;
|
||||
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
|
||||
pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
|
||||
|
||||
const IDENTIFIER_LINE_COUNT: u32 = 3;
|
||||
|
||||
pub struct RelatedExcerptStore {
|
||||
project: WeakEntity<Project>,
|
||||
related_files: Vec<RelatedFile>,
|
||||
related_files: Arc<[RelatedFile]>,
|
||||
related_file_buffers: Vec<Entity<Buffer>>,
|
||||
cache: HashMap<Identifier, Arc<CacheEntry>>,
|
||||
update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
|
||||
identifier_line_count: u32,
|
||||
@@ -68,82 +70,6 @@ struct CachedDefinition {
|
||||
anchor_range: Range<Anchor>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct RelatedFile {
|
||||
#[serde(serialize_with = "serialize_project_path")]
|
||||
pub path: ProjectPath,
|
||||
#[serde(skip)]
|
||||
pub buffer: WeakEntity<Buffer>,
|
||||
pub excerpts: Vec<RelatedExcerpt>,
|
||||
pub max_row: u32,
|
||||
}
|
||||
|
||||
impl RelatedFile {
|
||||
pub fn merge_excerpts(&mut self) {
|
||||
self.excerpts.sort_unstable_by(|a, b| {
|
||||
a.point_range
|
||||
.start
|
||||
.cmp(&b.point_range.start)
|
||||
.then(b.point_range.end.cmp(&a.point_range.end))
|
||||
});
|
||||
|
||||
let mut index = 1;
|
||||
while index < self.excerpts.len() {
|
||||
if self.excerpts[index - 1]
|
||||
.point_range
|
||||
.end
|
||||
.cmp(&self.excerpts[index].point_range.start)
|
||||
.is_ge()
|
||||
{
|
||||
let removed = self.excerpts.remove(index);
|
||||
if removed
|
||||
.point_range
|
||||
.end
|
||||
.cmp(&self.excerpts[index - 1].point_range.end)
|
||||
.is_gt()
|
||||
{
|
||||
self.excerpts[index - 1].point_range.end = removed.point_range.end;
|
||||
self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
|
||||
}
|
||||
} else {
|
||||
index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize)]
|
||||
pub struct RelatedExcerpt {
|
||||
#[serde(skip)]
|
||||
pub anchor_range: Range<Anchor>,
|
||||
#[serde(serialize_with = "serialize_point_range")]
|
||||
pub point_range: Range<Point>,
|
||||
#[serde(serialize_with = "serialize_rope")]
|
||||
pub text: Rope,
|
||||
}
|
||||
|
||||
fn serialize_project_path<S: Serializer>(
|
||||
project_path: &ProjectPath,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error> {
|
||||
project_path.path.serialize(serializer)
|
||||
}
|
||||
|
||||
fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
|
||||
rope.to_string().serialize(serializer)
|
||||
}
|
||||
|
||||
fn serialize_point_range<S: Serializer>(
|
||||
range: &Range<Point>,
|
||||
serializer: S,
|
||||
) -> Result<S::Ok, S::Error> {
|
||||
[
|
||||
[range.start.row, range.start.column],
|
||||
[range.end.row, range.end.column],
|
||||
]
|
||||
.serialize(serializer)
|
||||
}
|
||||
|
||||
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
|
||||
|
||||
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
|
||||
@@ -179,7 +105,8 @@ impl RelatedExcerptStore {
|
||||
RelatedExcerptStore {
|
||||
project: project.downgrade(),
|
||||
update_tx,
|
||||
related_files: Vec::new(),
|
||||
related_files: Vec::new().into(),
|
||||
related_file_buffers: Vec::new(),
|
||||
cache: Default::default(),
|
||||
identifier_line_count: IDENTIFIER_LINE_COUNT,
|
||||
}
|
||||
@@ -193,8 +120,21 @@ impl RelatedExcerptStore {
|
||||
self.update_tx.unbounded_send((buffer, position)).ok();
|
||||
}
|
||||
|
||||
pub fn related_files(&self) -> &[RelatedFile] {
|
||||
&self.related_files
|
||||
pub fn related_files(&self) -> Arc<[RelatedFile]> {
|
||||
self.related_files.clone()
|
||||
}
|
||||
|
||||
pub fn related_files_with_buffers(
|
||||
&self,
|
||||
) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
|
||||
self.related_files
|
||||
.iter()
|
||||
.cloned()
|
||||
.zip(self.related_file_buffers.iter().cloned())
|
||||
}
|
||||
|
||||
pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
|
||||
self.related_files = files.into();
|
||||
}
|
||||
|
||||
async fn fetch_excerpts(
|
||||
@@ -297,7 +237,8 @@ impl RelatedExcerptStore {
|
||||
}
|
||||
mean_definition_latency /= cache_miss_count.max(1) as u32;
|
||||
|
||||
let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
|
||||
let (new_cache, related_files, related_file_buffers) =
|
||||
rebuild_related_files(&project, new_cache, cx).await?;
|
||||
|
||||
if let Some(file) = &file {
|
||||
log::debug!(
|
||||
@@ -309,7 +250,8 @@ impl RelatedExcerptStore {
|
||||
|
||||
this.update(cx, |this, cx| {
|
||||
this.cache = new_cache;
|
||||
this.related_files = related_files;
|
||||
this.related_files = related_files.into();
|
||||
this.related_file_buffers = related_file_buffers;
|
||||
cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
|
||||
cache_hit_count,
|
||||
cache_miss_count,
|
||||
@@ -323,10 +265,16 @@ impl RelatedExcerptStore {
|
||||
}
|
||||
|
||||
async fn rebuild_related_files(
|
||||
project: &Entity<Project>,
|
||||
new_entries: HashMap<Identifier, Arc<CacheEntry>>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
|
||||
) -> Result<(
|
||||
HashMap<Identifier, Arc<CacheEntry>>,
|
||||
Vec<RelatedFile>,
|
||||
Vec<Entity<Buffer>>,
|
||||
)> {
|
||||
let mut snapshots = HashMap::default();
|
||||
let mut worktree_root_names = HashMap::default();
|
||||
for entry in new_entries.values() {
|
||||
for definition in &entry.definitions {
|
||||
if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
|
||||
@@ -340,12 +288,22 @@ async fn rebuild_related_files(
|
||||
.read_with(cx, |buffer, _| buffer.snapshot())?,
|
||||
);
|
||||
}
|
||||
let worktree_id = definition.path.worktree_id;
|
||||
if let hash_map::Entry::Vacant(e) =
|
||||
worktree_root_names.entry(definition.path.worktree_id)
|
||||
{
|
||||
project.read_with(cx, |project, cx| {
|
||||
if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
|
||||
e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
|
||||
}
|
||||
})?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(cx
|
||||
.background_spawn(async move {
|
||||
let mut files = Vec::<RelatedFile>::new();
|
||||
let mut files = Vec::new();
|
||||
let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
|
||||
let mut paths_by_buffer = HashMap::default();
|
||||
for entry in new_entries.values() {
|
||||
@@ -369,16 +327,31 @@ async fn rebuild_related_files(
|
||||
continue;
|
||||
};
|
||||
let excerpts = assemble_excerpts(snapshot, ranges);
|
||||
files.push(RelatedFile {
|
||||
path: project_path.clone(),
|
||||
buffer: buffer.downgrade(),
|
||||
excerpts,
|
||||
max_row: snapshot.max_point().row,
|
||||
});
|
||||
let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let path = Path::new(&format!(
|
||||
"{}/{}",
|
||||
root_name,
|
||||
project_path.path.as_unix_str()
|
||||
))
|
||||
.into();
|
||||
|
||||
files.push((
|
||||
buffer,
|
||||
RelatedFile {
|
||||
path,
|
||||
excerpts,
|
||||
max_row: snapshot.max_point().row,
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
files.sort_by_key(|file| file.path.clone());
|
||||
(new_entries, files)
|
||||
files.sort_by_key(|(_, file)| file.path.clone());
|
||||
let (related_buffers, related_files) = files.into_iter().unzip();
|
||||
|
||||
(new_entries, related_files, related_buffers)
|
||||
})
|
||||
.await)
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
|
||||
&excerpts,
|
||||
&[
|
||||
(
|
||||
"src/company.rs",
|
||||
"root/src/company.rs",
|
||||
&[indoc! {"
|
||||
pub struct Company {
|
||||
owner: Arc<Person>,
|
||||
@@ -56,7 +56,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
|
||||
}"}],
|
||||
),
|
||||
(
|
||||
"src/main.rs",
|
||||
"root/src/main.rs",
|
||||
&[
|
||||
indoc! {"
|
||||
pub struct Session {
|
||||
@@ -71,7 +71,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
|
||||
],
|
||||
),
|
||||
(
|
||||
"src/person.rs",
|
||||
"root/src/person.rs",
|
||||
&[
|
||||
indoc! {"
|
||||
impl Person {
|
||||
@@ -446,7 +446,7 @@ fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &
|
||||
.iter()
|
||||
.map(|excerpt| excerpt.text.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
(file.path.path.as_unix_str(), excerpts)
|
||||
(file.path.to_str().unwrap(), excerpts)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let expected_excerpts = expected_files
|
||||
@@ -492,10 +492,10 @@ fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
|
||||
if excerpt.text.is_empty() {
|
||||
continue;
|
||||
}
|
||||
if current_row < excerpt.point_range.start.row {
|
||||
if current_row < excerpt.row_range.start {
|
||||
writeln!(&mut output, "…").unwrap();
|
||||
}
|
||||
current_row = excerpt.point_range.start.row;
|
||||
current_row = excerpt.row_range.start;
|
||||
|
||||
for line in excerpt.text.to_string().lines() {
|
||||
output.push_str(line);
|
||||
|
||||
@@ -17,7 +17,6 @@ anyhow.workspace = true
|
||||
buffer_diff.workspace = true
|
||||
client.workspace = true
|
||||
cloud_llm_client.workspace = true
|
||||
cloud_zeta2_prompt.workspace = true
|
||||
codestral.workspace = true
|
||||
command_palette_hooks.workspace = true
|
||||
copilot.workspace = true
|
||||
@@ -46,6 +45,7 @@ ui_input.workspace = true
|
||||
util.workspace = true
|
||||
workspace.workspace = true
|
||||
zed_actions.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
|
||||
[dev-dependencies]
|
||||
copilot = { workspace = true, features = ["test-support"] }
|
||||
|
||||
@@ -17,7 +17,7 @@ use gpui::{
|
||||
};
|
||||
use multi_buffer::MultiBuffer;
|
||||
use project::Project;
|
||||
use text::OffsetRangeExt;
|
||||
use text::Point;
|
||||
use ui::{
|
||||
ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
|
||||
StyledTypography as _, h_flex, v_flex,
|
||||
@@ -66,7 +66,7 @@ impl EditPredictionContextView {
|
||||
) -> Self {
|
||||
let store = EditPredictionStore::global(client, user_store, cx);
|
||||
|
||||
let mut debug_rx = store.update(cx, |store, _| store.debug_info());
|
||||
let mut debug_rx = store.update(cx, |store, cx| store.debug_info(&project, cx));
|
||||
let _update_task = cx.spawn_in(window, async move |this, cx| {
|
||||
while let Some(event) = debug_rx.next().await {
|
||||
this.update_in(cx, |this, window, cx| {
|
||||
@@ -103,7 +103,8 @@ impl EditPredictionContextView {
|
||||
self.handle_context_retrieval_finished(info, window, cx);
|
||||
}
|
||||
}
|
||||
DebugEvent::EditPredictionRequested(_) => {}
|
||||
DebugEvent::EditPredictionStarted(_) => {}
|
||||
DebugEvent::EditPredictionFinished(_) => {}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -152,12 +153,11 @@ impl EditPredictionContextView {
|
||||
run.finished_at = Some(info.timestamp);
|
||||
run.metadata = info.metadata;
|
||||
|
||||
let project = self.project.clone();
|
||||
let related_files = self
|
||||
.store
|
||||
.read(cx)
|
||||
.context_for_project(&self.project, cx)
|
||||
.to_vec();
|
||||
.context_for_project_with_buffers(&self.project, cx)
|
||||
.map_or(Vec::new(), |files| files.collect());
|
||||
|
||||
let editor = run.editor.clone();
|
||||
let multibuffer = run.editor.read(cx).buffer().clone();
|
||||
@@ -168,33 +168,14 @@ impl EditPredictionContextView {
|
||||
|
||||
cx.spawn_in(window, async move |this, cx| {
|
||||
let mut paths = Vec::new();
|
||||
for related_file in related_files {
|
||||
let (buffer, point_ranges): (_, Vec<_>) =
|
||||
if let Some(buffer) = related_file.buffer.upgrade() {
|
||||
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
|
||||
|
||||
(
|
||||
buffer,
|
||||
related_file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
|
||||
.collect(),
|
||||
)
|
||||
} else {
|
||||
(
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.open_buffer(related_file.path.clone(), cx)
|
||||
})?
|
||||
.await?,
|
||||
related_file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| excerpt.point_range.clone())
|
||||
.collect(),
|
||||
)
|
||||
};
|
||||
for (related_file, buffer) in related_files {
|
||||
let point_ranges = related_file
|
||||
.excerpts
|
||||
.iter()
|
||||
.map(|excerpt| {
|
||||
Point::new(excerpt.row_range.start, 0)..Point::new(excerpt.row_range.end, 0)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
cx.update(|_, cx| {
|
||||
let path = PathKey::for_buffer(&buffer, cx);
|
||||
paths.push((path, buffer, point_ranges));
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
|
||||
use cloud_zeta2_prompt::write_codeblock;
|
||||
use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
|
||||
use editor::{Editor, ExcerptRange, MultiBuffer};
|
||||
use feature_flags::FeatureFlag;
|
||||
@@ -362,14 +361,14 @@ impl RatePredictionsModal {
|
||||
write!(&mut formatted_inputs, "## Events\n\n").unwrap();
|
||||
|
||||
for event in &prediction.inputs.events {
|
||||
write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap();
|
||||
formatted_inputs.push_str("```diff\n");
|
||||
zeta_prompt::write_event(&mut formatted_inputs, event.as_ref());
|
||||
formatted_inputs.push_str("```\n\n");
|
||||
}
|
||||
|
||||
write!(&mut formatted_inputs, "## Included files\n\n").unwrap();
|
||||
|
||||
for included_file in &prediction.inputs.included_files {
|
||||
let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")];
|
||||
write!(&mut formatted_inputs, "## Related files\n\n").unwrap();
|
||||
|
||||
for included_file in prediction.inputs.related_files.as_ref() {
|
||||
write!(
|
||||
&mut formatted_inputs,
|
||||
"### {}\n\n",
|
||||
@@ -377,20 +376,28 @@ impl RatePredictionsModal {
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
write_codeblock(
|
||||
&included_file.path,
|
||||
&included_file.excerpts,
|
||||
if included_file.path == prediction.inputs.cursor_path {
|
||||
cursor_insertions.as_slice()
|
||||
} else {
|
||||
&[]
|
||||
},
|
||||
included_file.max_row,
|
||||
false,
|
||||
&mut formatted_inputs,
|
||||
);
|
||||
for excerpt in included_file.excerpts.iter() {
|
||||
write!(
|
||||
&mut formatted_inputs,
|
||||
"```{}\n{}\n```\n",
|
||||
included_file.path.display(),
|
||||
excerpt.text
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
write!(&mut formatted_inputs, "## Cursor Excerpt\n\n").unwrap();
|
||||
|
||||
writeln!(
|
||||
&mut formatted_inputs,
|
||||
"```{}\n{}<CURSOR>{}\n```\n",
|
||||
prediction.inputs.cursor_path.display(),
|
||||
&prediction.inputs.cursor_excerpt[..prediction.inputs.cursor_offset_in_excerpt],
|
||||
&prediction.inputs.cursor_excerpt[prediction.inputs.cursor_offset_in_excerpt..],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
self.active_prediction = Some(ActivePrediction {
|
||||
prediction,
|
||||
feedback_editor: cx.new(|cx| {
|
||||
|
||||
@@ -20973,9 +20973,22 @@ impl Editor {
|
||||
buffer_ranges.last()
|
||||
}?;
|
||||
|
||||
let selection = text::ToPoint::to_point(&range.start, buffer).row
|
||||
..text::ToPoint::to_point(&range.end, buffer).row;
|
||||
Some((multi_buffer.buffer(buffer.remote_id()).unwrap(), selection))
|
||||
let start_row_in_buffer = text::ToPoint::to_point(&range.start, buffer).row;
|
||||
let end_row_in_buffer = text::ToPoint::to_point(&range.end, buffer).row;
|
||||
|
||||
let Some(buffer_diff) = multi_buffer.diff_for(buffer.remote_id()) else {
|
||||
let selection = start_row_in_buffer..end_row_in_buffer;
|
||||
|
||||
return Some((multi_buffer.buffer(buffer.remote_id()).unwrap(), selection));
|
||||
};
|
||||
|
||||
let buffer_diff_snapshot = buffer_diff.read(cx).snapshot(cx);
|
||||
|
||||
Some((
|
||||
multi_buffer.buffer(buffer.remote_id()).unwrap(),
|
||||
buffer_diff_snapshot.row_to_base_text_row(start_row_in_buffer, buffer)
|
||||
..buffer_diff_snapshot.row_to_base_text_row(end_row_in_buffer, buffer),
|
||||
))
|
||||
});
|
||||
|
||||
let Some((buffer, selection)) = buffer_and_selection else {
|
||||
|
||||
@@ -27701,6 +27701,7 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input("x", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
- [ ] Item 1
|
||||
- [ ] Item 1.a
|
||||
@@ -27716,8 +27717,7 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
|
||||
- [ ] Item 1.a
|
||||
- [x] Item 2
|
||||
- [x] Item 2.a
|
||||
- [x] Item 2.bˇ
|
||||
"
|
||||
- [x] Item 2.bˇ"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.newline(&Newline, window, cx);
|
||||
@@ -27728,34 +27728,41 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
|
||||
- [x] Item 2
|
||||
- [x] Item 2.a
|
||||
- [x] Item 2.b
|
||||
ˇ
|
||||
"
|
||||
ˇ"
|
||||
});
|
||||
|
||||
// Case 3: Test adding a new nested list item preserves indent
|
||||
cx.set_state(&indoc! {"
|
||||
- [ ] Item 1
|
||||
- [ ] Item 1.a
|
||||
- [x] Item 2
|
||||
- [x] Item 2.a
|
||||
- [x] Item 2.b
|
||||
ˇ"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input("-", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
- [ ] Item 1
|
||||
- [ ] Item 1.a
|
||||
- [x] Item 2
|
||||
- [x] Item 2.a
|
||||
- [x] Item 2.b
|
||||
-ˇ
|
||||
"
|
||||
-ˇ"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input(" [x] Item 2.c", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
- [ ] Item 1
|
||||
- [ ] Item 1.a
|
||||
- [x] Item 2
|
||||
- [x] Item 2.a
|
||||
- [x] Item 2.b
|
||||
- [x] Item 2.cˇ
|
||||
"
|
||||
- [x] Item 2.cˇ"
|
||||
});
|
||||
|
||||
// Case 4: Test adding new line after nested ordered list preserves indent of previous line
|
||||
@@ -27764,8 +27771,7 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
|
||||
1. Item 1.a
|
||||
2. Item 2
|
||||
1. Item 2.a
|
||||
2. Item 2.bˇ
|
||||
"
|
||||
2. Item 2.bˇ"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.newline(&Newline, window, cx);
|
||||
@@ -27776,60 +27782,81 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
|
||||
2. Item 2
|
||||
1. Item 2.a
|
||||
2. Item 2.b
|
||||
ˇ
|
||||
"
|
||||
ˇ"
|
||||
});
|
||||
|
||||
// Case 5: Adding new ordered list item preserves indent
|
||||
cx.set_state(indoc! {"
|
||||
1. Item 1
|
||||
1. Item 1.a
|
||||
2. Item 2
|
||||
1. Item 2.a
|
||||
2. Item 2.b
|
||||
ˇ"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input("3", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
1. Item 1
|
||||
1. Item 1.a
|
||||
2. Item 2
|
||||
1. Item 2.a
|
||||
2. Item 2.b
|
||||
3ˇ
|
||||
"
|
||||
3ˇ"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input(".", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
1. Item 1
|
||||
1. Item 1.a
|
||||
2. Item 2
|
||||
1. Item 2.a
|
||||
2. Item 2.b
|
||||
3.ˇ
|
||||
"
|
||||
3.ˇ"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input(" Item 2.c", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
1. Item 1
|
||||
1. Item 1.a
|
||||
2. Item 2
|
||||
1. Item 2.a
|
||||
2. Item 2.b
|
||||
3. Item 2.cˇ
|
||||
"
|
||||
3. Item 2.cˇ"
|
||||
});
|
||||
|
||||
// Case 6: Test adding a new unordered list item preserves indent of previous line
|
||||
cx.set_state(indoc! {"
|
||||
- Item 1
|
||||
- Item 1.a
|
||||
- Item 1.a
|
||||
ˇ"});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.handle_input("-", window, cx);
|
||||
});
|
||||
cx.run_until_parked();
|
||||
cx.assert_editor_state(indoc! {"
|
||||
- Item 1
|
||||
- Item 1.a
|
||||
- Item 1.a
|
||||
-ˇ"});
|
||||
|
||||
// Case 7: Test blockquote newline preserves something
|
||||
cx.set_state(indoc! {"
|
||||
> Item 1ˇ
|
||||
"
|
||||
> Item 1ˇ"
|
||||
});
|
||||
cx.update_editor(|editor, window, cx| {
|
||||
editor.newline(&Newline, window, cx);
|
||||
});
|
||||
cx.assert_editor_state(indoc! {"
|
||||
> Item 1
|
||||
ˇ
|
||||
"
|
||||
ˇ"
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ impl FeatureFlag for InlineAssistantV2FeatureFlag {
|
||||
const NAME: &'static str = "inline-assistant-v2";
|
||||
|
||||
fn enabled_for_staff() -> bool {
|
||||
false
|
||||
// false
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -232,14 +232,12 @@ impl From<Oid> for usize {
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum RunHook {
|
||||
PreCommit,
|
||||
PrePush,
|
||||
}
|
||||
|
||||
impl RunHook {
|
||||
pub fn as_str(&self) -> &str {
|
||||
match self {
|
||||
Self::PreCommit => "pre-commit",
|
||||
Self::PrePush => "pre-push",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,7 +248,6 @@ impl RunHook {
|
||||
pub fn from_proto(value: i32) -> Option<Self> {
|
||||
match value {
|
||||
0 => Some(Self::PreCommit),
|
||||
1 => Some(Self::PrePush),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -652,6 +652,7 @@ pub struct RealGitRepository {
|
||||
pub repository: Arc<Mutex<git2::Repository>>,
|
||||
pub system_git_binary_path: Option<PathBuf>,
|
||||
pub any_git_binary_path: PathBuf,
|
||||
any_git_binary_help_output: Arc<Mutex<Option<SharedString>>>,
|
||||
executor: BackgroundExecutor,
|
||||
}
|
||||
|
||||
@@ -670,6 +671,7 @@ impl RealGitRepository {
|
||||
system_git_binary_path,
|
||||
any_git_binary_path,
|
||||
executor,
|
||||
any_git_binary_help_output: Arc::new(Mutex::new(None)),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -680,6 +682,27 @@ impl RealGitRepository {
|
||||
.context("failed to read git work directory")
|
||||
.map(Path::to_path_buf)
|
||||
}
|
||||
|
||||
async fn any_git_binary_help_output(&self) -> SharedString {
|
||||
if let Some(output) = self.any_git_binary_help_output.lock().clone() {
|
||||
return output;
|
||||
}
|
||||
let git_binary_path = self.any_git_binary_path.clone();
|
||||
let executor = self.executor.clone();
|
||||
let working_directory = self.working_directory();
|
||||
let output: SharedString = self
|
||||
.executor
|
||||
.spawn(async move {
|
||||
GitBinary::new(git_binary_path, working_directory?, executor)
|
||||
.run(["help", "-a"])
|
||||
.await
|
||||
})
|
||||
.await
|
||||
.unwrap_or_default()
|
||||
.into();
|
||||
*self.any_git_binary_help_output.lock() = Some(output.clone());
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@@ -2290,48 +2313,47 @@ impl GitRepository for RealGitRepository {
|
||||
env: Arc<HashMap<String, String>>,
|
||||
) -> BoxFuture<'_, Result<()>> {
|
||||
let working_directory = self.working_directory();
|
||||
let repository = self.repository.clone();
|
||||
let git_binary_path = self.any_git_binary_path.clone();
|
||||
let executor = self.executor.clone();
|
||||
self.executor
|
||||
.spawn(async move {
|
||||
let working_directory = working_directory?;
|
||||
let git = GitBinary::new(git_binary_path, working_directory.clone(), executor)
|
||||
.envs(HashMap::clone(&env));
|
||||
let help_output = self.any_git_binary_help_output();
|
||||
|
||||
let output = git.run(&["help", "-a"]).await?;
|
||||
if !output.lines().any(|line| line.trim().starts_with("hook ")) {
|
||||
log::warn!(
|
||||
"git hook command not available, running the {} hook manually",
|
||||
hook.as_str()
|
||||
);
|
||||
// Note: Do not spawn these commands on the background thread, as this causes some git hooks to hang.
|
||||
async move {
|
||||
let working_directory = working_directory?;
|
||||
if !help_output
|
||||
.await
|
||||
.lines()
|
||||
.any(|line| line.trim().starts_with("hook "))
|
||||
{
|
||||
let hook_abs_path = repository.lock().path().join("hooks").join(hook.as_str());
|
||||
if hook_abs_path.is_file() {
|
||||
let output = new_smol_command(&hook_abs_path)
|
||||
.envs(env.iter())
|
||||
.current_dir(&working_directory)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
let hook_abs_path = working_directory
|
||||
.join(".git")
|
||||
.join("hooks")
|
||||
.join(hook.as_str());
|
||||
if hook_abs_path.is_file() {
|
||||
let output = new_smol_command(&hook_abs_path)
|
||||
.envs(env.iter())
|
||||
.current_dir(&working_directory)
|
||||
.output()
|
||||
.await?;
|
||||
|
||||
anyhow::ensure!(
|
||||
output.status.success(),
|
||||
"{} hook failed:\n{}",
|
||||
hook.as_str(),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
if !output.status.success() {
|
||||
return Err(GitBinaryCommandError {
|
||||
stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
|
||||
stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
|
||||
status: output.status,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
|
||||
.await?;
|
||||
Ok(())
|
||||
})
|
||||
.boxed()
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let git = GitBinary::new(git_binary_path, working_directory, executor)
|
||||
.envs(HashMap::clone(&env));
|
||||
git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant};
|
||||
use async_task::Runnable;
|
||||
use futures::channel::mpsc;
|
||||
use parking_lot::{Condvar, Mutex};
|
||||
use smol::prelude::*;
|
||||
use std::{
|
||||
fmt::Debug,
|
||||
@@ -154,6 +155,57 @@ impl BackgroundExecutor {
|
||||
self.spawn_internal::<R>(Box::pin(future), None)
|
||||
}
|
||||
|
||||
/// Enqueues the given future to be run to completion on a background thread, blocking the current task on it.
|
||||
///
|
||||
/// This allows spawning background work that borrows from its scope. Note that the supplied future will run to
|
||||
/// completion before the current task is resumed, even if the current task is slated for cancellation.
|
||||
pub async fn await_on_background<R>(&self, future: impl Future<Output = R> + Send) -> R
|
||||
where
|
||||
R: Send,
|
||||
{
|
||||
// We need to ensure that cancellation of the parent task does not drop the environment
|
||||
// before our own task has completed or been cancelled.
|
||||
struct NotifyOnDrop<'a>(&'a (Condvar, Mutex<bool>));
|
||||
|
||||
impl Drop for NotifyOnDrop<'_> {
|
||||
fn drop(&mut self) {
|
||||
*self.0.1.lock() = true;
|
||||
self.0.0.notify_all();
|
||||
}
|
||||
}
|
||||
|
||||
struct WaitOnDrop<'a>(&'a (Condvar, Mutex<bool>));
|
||||
|
||||
impl Drop for WaitOnDrop<'_> {
|
||||
fn drop(&mut self) {
|
||||
let mut done = self.0.1.lock();
|
||||
if !*done {
|
||||
self.0.0.wait(&mut done);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let dispatcher = self.dispatcher.clone();
|
||||
let location = core::panic::Location::caller();
|
||||
|
||||
let pair = &(Condvar::new(), Mutex::new(false));
|
||||
let _wait_guard = WaitOnDrop(pair);
|
||||
|
||||
let (runnable, task) = unsafe {
|
||||
async_task::Builder::new()
|
||||
.metadata(RunnableMeta { location })
|
||||
.spawn_unchecked(
|
||||
move |_| async {
|
||||
let _notify_guard = NotifyOnDrop(pair);
|
||||
future.await
|
||||
},
|
||||
move |runnable| dispatcher.dispatch(RunnableVariant::Meta(runnable), None),
|
||||
)
|
||||
};
|
||||
runnable.schedule();
|
||||
task.await
|
||||
}
|
||||
|
||||
/// Enqueues the given future to be run to completion on a background thread.
|
||||
/// The given label can be used to control the priority of the task in tests.
|
||||
#[track_caller]
|
||||
|
||||
@@ -289,6 +289,13 @@ pub trait PlatformDisplay: Send + Sync + Debug {
|
||||
/// Get the bounds for this display
|
||||
fn bounds(&self) -> Bounds<Pixels>;
|
||||
|
||||
/// Get the visible bounds for this display, excluding taskbar/dock areas.
|
||||
/// This is the usable area where windows can be placed without being obscured.
|
||||
/// Defaults to the full display bounds if not overridden.
|
||||
fn visible_bounds(&self) -> Bounds<Pixels> {
|
||||
self.bounds()
|
||||
}
|
||||
|
||||
/// Get the default bounds for this display to place a window
|
||||
fn default_bounds(&self) -> Bounds<Pixels> {
|
||||
let bounds = self.bounds();
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
use crate::{Bounds, DisplayId, Pixels, PlatformDisplay, px, size};
|
||||
use crate::{Bounds, DisplayId, Pixels, PlatformDisplay, point, px, size};
|
||||
use anyhow::Result;
|
||||
use cocoa::{
|
||||
appkit::NSScreen,
|
||||
base::{id, nil},
|
||||
foundation::{NSDictionary, NSString},
|
||||
foundation::{NSArray, NSDictionary, NSString},
|
||||
};
|
||||
use core_foundation::uuid::{CFUUIDGetUUIDBytes, CFUUIDRef};
|
||||
use core_graphics::display::{CGDirectDisplayID, CGDisplayBounds, CGGetActiveDisplayList};
|
||||
@@ -114,4 +114,53 @@ impl PlatformDisplay for MacDisplay {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn visible_bounds(&self) -> Bounds<Pixels> {
|
||||
unsafe {
|
||||
let dominated_screen = self.get_nsscreen();
|
||||
|
||||
if dominated_screen == nil {
|
||||
return self.bounds();
|
||||
}
|
||||
|
||||
let screen_frame = NSScreen::frame(dominated_screen);
|
||||
let visible_frame = NSScreen::visibleFrame(dominated_screen);
|
||||
|
||||
// Convert from bottom-left origin (AppKit) to top-left origin
|
||||
let origin_y =
|
||||
screen_frame.size.height - visible_frame.origin.y - visible_frame.size.height
|
||||
+ screen_frame.origin.y;
|
||||
|
||||
Bounds {
|
||||
origin: point(
|
||||
px(visible_frame.origin.x as f32 - screen_frame.origin.x as f32),
|
||||
px(origin_y as f32),
|
||||
),
|
||||
size: size(
|
||||
px(visible_frame.size.width as f32),
|
||||
px(visible_frame.size.height as f32),
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MacDisplay {
|
||||
/// Find the NSScreen corresponding to this display
|
||||
unsafe fn get_nsscreen(&self) -> id {
|
||||
let screens = unsafe { NSScreen::screens(nil) };
|
||||
let count = unsafe { NSArray::count(screens) };
|
||||
let screen_number_key: id = unsafe { NSString::alloc(nil).init_str("NSScreenNumber") };
|
||||
|
||||
for i in 0..count {
|
||||
let screen = unsafe { NSArray::objectAtIndex(screens, i) };
|
||||
let device_description = unsafe { NSScreen::deviceDescription(screen) };
|
||||
let screen_number = unsafe { device_description.objectForKey_(screen_number_key) };
|
||||
let screen_id: CGDirectDisplayID = msg_send![screen_number, unsignedIntegerValue];
|
||||
if screen_id == self.0 {
|
||||
return screen;
|
||||
}
|
||||
}
|
||||
nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ pub(crate) struct WindowsDisplay {
|
||||
pub display_id: DisplayId,
|
||||
scale_factor: f32,
|
||||
bounds: Bounds<Pixels>,
|
||||
visible_bounds: Bounds<Pixels>,
|
||||
physical_bounds: Bounds<DevicePixels>,
|
||||
uuid: Uuid,
|
||||
}
|
||||
@@ -36,6 +37,7 @@ impl WindowsDisplay {
|
||||
let screen = available_monitors().into_iter().nth(display_id.0 as _)?;
|
||||
let info = get_monitor_info(screen).log_err()?;
|
||||
let monitor_size = info.monitorInfo.rcMonitor;
|
||||
let work_area = info.monitorInfo.rcWork;
|
||||
let uuid = generate_uuid(&info.szDevice);
|
||||
let scale_factor = get_scale_factor_for_monitor(screen).log_err()?;
|
||||
let physical_size = size(
|
||||
@@ -55,6 +57,14 @@ impl WindowsDisplay {
|
||||
),
|
||||
size: physical_size.to_pixels(scale_factor),
|
||||
},
|
||||
visible_bounds: Bounds {
|
||||
origin: logical_point(work_area.left as f32, work_area.top as f32, scale_factor),
|
||||
size: size(
|
||||
(work_area.right - work_area.left) as f32 / scale_factor,
|
||||
(work_area.bottom - work_area.top) as f32 / scale_factor,
|
||||
)
|
||||
.map(crate::px),
|
||||
},
|
||||
physical_bounds: Bounds {
|
||||
origin: point(monitor_size.left.into(), monitor_size.top.into()),
|
||||
size: physical_size,
|
||||
@@ -66,6 +76,7 @@ impl WindowsDisplay {
|
||||
pub fn new_with_handle(monitor: HMONITOR) -> anyhow::Result<Self> {
|
||||
let info = get_monitor_info(monitor)?;
|
||||
let monitor_size = info.monitorInfo.rcMonitor;
|
||||
let work_area = info.monitorInfo.rcWork;
|
||||
let uuid = generate_uuid(&info.szDevice);
|
||||
let display_id = available_monitors()
|
||||
.iter()
|
||||
@@ -89,6 +100,14 @@ impl WindowsDisplay {
|
||||
),
|
||||
size: physical_size.to_pixels(scale_factor),
|
||||
},
|
||||
visible_bounds: Bounds {
|
||||
origin: logical_point(work_area.left as f32, work_area.top as f32, scale_factor),
|
||||
size: size(
|
||||
(work_area.right - work_area.left) as f32 / scale_factor,
|
||||
(work_area.bottom - work_area.top) as f32 / scale_factor,
|
||||
)
|
||||
.map(crate::px),
|
||||
},
|
||||
physical_bounds: Bounds {
|
||||
origin: point(monitor_size.left.into(), monitor_size.top.into()),
|
||||
size: physical_size,
|
||||
@@ -100,6 +119,7 @@ impl WindowsDisplay {
|
||||
fn new_with_handle_and_id(handle: HMONITOR, display_id: DisplayId) -> anyhow::Result<Self> {
|
||||
let info = get_monitor_info(handle)?;
|
||||
let monitor_size = info.monitorInfo.rcMonitor;
|
||||
let work_area = info.monitorInfo.rcWork;
|
||||
let uuid = generate_uuid(&info.szDevice);
|
||||
let scale_factor = get_scale_factor_for_monitor(handle)?;
|
||||
let physical_size = size(
|
||||
@@ -119,6 +139,14 @@ impl WindowsDisplay {
|
||||
),
|
||||
size: physical_size.to_pixels(scale_factor),
|
||||
},
|
||||
visible_bounds: Bounds {
|
||||
origin: logical_point(work_area.left as f32, work_area.top as f32, scale_factor),
|
||||
size: size(
|
||||
(work_area.right - work_area.left) as f32 / scale_factor,
|
||||
(work_area.bottom - work_area.top) as f32 / scale_factor,
|
||||
)
|
||||
.map(crate::px),
|
||||
},
|
||||
physical_bounds: Bounds {
|
||||
origin: point(monitor_size.left.into(), monitor_size.top.into()),
|
||||
size: physical_size,
|
||||
@@ -193,6 +221,10 @@ impl PlatformDisplay for WindowsDisplay {
|
||||
fn bounds(&self) -> Bounds<Pixels> {
|
||||
self.bounds
|
||||
}
|
||||
|
||||
fn visible_bounds(&self) -> Bounds<Pixels> {
|
||||
self.visible_bounds
|
||||
}
|
||||
}
|
||||
|
||||
fn available_monitors() -> SmallVec<[HMONITOR; 4]> {
|
||||
|
||||
@@ -918,86 +918,69 @@ pub(crate) struct ElementStateBox {
|
||||
pub(crate) type_name: &'static str,
|
||||
}
|
||||
|
||||
fn default_bounds(display_id: Option<DisplayId>, cx: &mut App) -> Bounds<Pixels> {
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
const CASCADE_OFFSET: f32 = 25.0;
|
||||
fn default_bounds(display_id: Option<DisplayId>, cx: &mut App) -> WindowBounds {
|
||||
// TODO, BUG: if you open a window with the currently active window
|
||||
// on the stack, this will erroneously fall back to `None`
|
||||
//
|
||||
// TODO these should be the initial window bounds not considering maximized/fullscreen
|
||||
let active_window_bounds = cx
|
||||
.active_window()
|
||||
.and_then(|w| w.update(cx, |_, window, _| window.window_bounds()).ok());
|
||||
|
||||
let display = display_id
|
||||
.map(|id| cx.find_display(id))
|
||||
.unwrap_or_else(|| cx.primary_display());
|
||||
const CASCADE_OFFSET: f32 = 25.0;
|
||||
|
||||
let display_bounds = display
|
||||
.as_ref()
|
||||
.map(|d| d.bounds())
|
||||
.unwrap_or_else(|| Bounds::new(point(px(0.), px(0.)), DEFAULT_WINDOW_SIZE));
|
||||
let display = display_id
|
||||
.map(|id| cx.find_display(id))
|
||||
.unwrap_or_else(|| cx.primary_display());
|
||||
|
||||
// TODO, BUG: if you open a window with the currently active window
|
||||
// on the stack, this will erroneously select the 'unwrap_or_else'
|
||||
// code path
|
||||
let (base_origin, base_size) = cx
|
||||
.active_window()
|
||||
.and_then(|w| {
|
||||
w.update(cx, |_, window, _| {
|
||||
let bounds = window.bounds();
|
||||
(bounds.origin, bounds.size)
|
||||
})
|
||||
.ok()
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
let default_bounds = display
|
||||
.as_ref()
|
||||
.map(|d| d.default_bounds())
|
||||
.unwrap_or_else(|| Bounds::new(point(px(0.), px(0.)), DEFAULT_WINDOW_SIZE));
|
||||
(default_bounds.origin, default_bounds.size)
|
||||
});
|
||||
let default_placement = || Bounds::new(point(px(0.), px(0.)), DEFAULT_WINDOW_SIZE);
|
||||
|
||||
let cascade_offset = point(px(CASCADE_OFFSET), px(CASCADE_OFFSET));
|
||||
let proposed_origin = base_origin + cascade_offset;
|
||||
let proposed_bounds = Bounds::new(proposed_origin, base_size);
|
||||
// Use visible_bounds to exclude taskbar/dock areas
|
||||
let display_bounds = display
|
||||
.as_ref()
|
||||
.map(|d| d.visible_bounds())
|
||||
.unwrap_or_else(default_placement);
|
||||
|
||||
let display_right = display_bounds.origin.x + display_bounds.size.width;
|
||||
let display_bottom = display_bounds.origin.y + display_bounds.size.height;
|
||||
let window_right = proposed_bounds.origin.x + proposed_bounds.size.width;
|
||||
let window_bottom = proposed_bounds.origin.y + proposed_bounds.size.height;
|
||||
let (
|
||||
Bounds {
|
||||
origin: base_origin,
|
||||
size: base_size,
|
||||
},
|
||||
window_bounds_ctor,
|
||||
): (_, fn(Bounds<Pixels>) -> WindowBounds) = match active_window_bounds {
|
||||
Some(bounds) => match bounds {
|
||||
WindowBounds::Windowed(bounds) => (bounds, WindowBounds::Windowed),
|
||||
WindowBounds::Maximized(bounds) => (bounds, WindowBounds::Maximized),
|
||||
WindowBounds::Fullscreen(bounds) => (bounds, WindowBounds::Fullscreen),
|
||||
},
|
||||
None => (
|
||||
display
|
||||
.as_ref()
|
||||
.map(|d| d.default_bounds())
|
||||
.unwrap_or_else(default_placement),
|
||||
WindowBounds::Windowed,
|
||||
),
|
||||
};
|
||||
|
||||
let fits_horizontally = window_right <= display_right;
|
||||
let fits_vertically = window_bottom <= display_bottom;
|
||||
let cascade_offset = point(px(CASCADE_OFFSET), px(CASCADE_OFFSET));
|
||||
let proposed_origin = base_origin + cascade_offset;
|
||||
let proposed_bounds = Bounds::new(proposed_origin, base_size);
|
||||
|
||||
let final_origin = match (fits_horizontally, fits_vertically) {
|
||||
(true, true) => proposed_origin,
|
||||
(false, true) => point(display_bounds.origin.x, base_origin.y),
|
||||
(true, false) => point(base_origin.x, display_bounds.origin.y),
|
||||
(false, false) => display_bounds.origin,
|
||||
};
|
||||
let display_right = display_bounds.origin.x + display_bounds.size.width;
|
||||
let display_bottom = display_bounds.origin.y + display_bounds.size.height;
|
||||
let window_right = proposed_bounds.origin.x + proposed_bounds.size.width;
|
||||
let window_bottom = proposed_bounds.origin.y + proposed_bounds.size.height;
|
||||
|
||||
Bounds::new(final_origin, base_size)
|
||||
}
|
||||
let fits_horizontally = window_right <= display_right;
|
||||
let fits_vertically = window_bottom <= display_bottom;
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
const DEFAULT_WINDOW_OFFSET: Point<Pixels> = point(px(0.), px(35.));
|
||||
|
||||
// TODO, BUG: if you open a window with the currently active window
|
||||
// on the stack, this will erroneously select the 'unwrap_or_else'
|
||||
// code path
|
||||
cx.active_window()
|
||||
.and_then(|w| w.update(cx, |_, window, _| window.bounds()).ok())
|
||||
.map(|mut bounds| {
|
||||
bounds.origin += DEFAULT_WINDOW_OFFSET;
|
||||
bounds
|
||||
})
|
||||
.unwrap_or_else(|| {
|
||||
let display = display_id
|
||||
.map(|id| cx.find_display(id))
|
||||
.unwrap_or_else(|| cx.primary_display());
|
||||
|
||||
display
|
||||
.as_ref()
|
||||
.map(|display| display.default_bounds())
|
||||
.unwrap_or_else(|| Bounds::new(point(px(0.), px(0.)), DEFAULT_WINDOW_SIZE))
|
||||
})
|
||||
}
|
||||
let final_origin = match (fits_horizontally, fits_vertically) {
|
||||
(true, true) => proposed_origin,
|
||||
(false, true) => point(display_bounds.origin.x, base_origin.y),
|
||||
(true, false) => point(base_origin.x, display_bounds.origin.y),
|
||||
(false, false) => display_bounds.origin,
|
||||
};
|
||||
window_bounds_ctor(Bounds::new(final_origin, base_size))
|
||||
}
|
||||
|
||||
impl Window {
|
||||
@@ -1024,13 +1007,11 @@ impl Window {
|
||||
tabbing_identifier,
|
||||
} = options;
|
||||
|
||||
let bounds = window_bounds
|
||||
.map(|bounds| bounds.get_bounds())
|
||||
.unwrap_or_else(|| default_bounds(display_id, cx));
|
||||
let window_bounds = window_bounds.unwrap_or_else(|| default_bounds(display_id, cx));
|
||||
let mut platform_window = cx.platform.open_window(
|
||||
handle,
|
||||
WindowParams {
|
||||
bounds,
|
||||
bounds: window_bounds.get_bounds(),
|
||||
titlebar,
|
||||
kind,
|
||||
is_movable,
|
||||
@@ -1071,12 +1052,10 @@ impl Window {
|
||||
.request_decorations(window_decorations.unwrap_or(WindowDecorations::Server));
|
||||
platform_window.set_background_appearance(window_background);
|
||||
|
||||
if let Some(ref window_open_state) = window_bounds {
|
||||
match window_open_state {
|
||||
WindowBounds::Fullscreen(_) => platform_window.toggle_fullscreen(),
|
||||
WindowBounds::Maximized(_) => platform_window.zoom(),
|
||||
WindowBounds::Windowed(_) => {}
|
||||
}
|
||||
match window_bounds {
|
||||
WindowBounds::Fullscreen(_) => platform_window.toggle_fullscreen(),
|
||||
WindowBounds::Maximized(_) => platform_window.zoom(),
|
||||
WindowBounds::Windowed(_) => {}
|
||||
}
|
||||
|
||||
platform_window.on_close(Box::new({
|
||||
@@ -1518,7 +1497,8 @@ impl Window {
|
||||
style
|
||||
}
|
||||
|
||||
/// Check if the platform window is maximized
|
||||
/// Check if the platform window is maximized.
|
||||
///
|
||||
/// On some platforms (namely Windows) this is different than the bounds being the size of the display
|
||||
pub fn is_maximized(&self) -> bool {
|
||||
self.platform_window.is_maximized()
|
||||
|
||||
@@ -535,7 +535,7 @@ pub trait LspInstaller {
|
||||
_version: &Self::BinaryVersion,
|
||||
_container_dir: &PathBuf,
|
||||
_delegate: &dyn LspAdapterDelegate,
|
||||
) -> impl Future<Output = Option<LanguageServerBinary>> {
|
||||
) -> impl Send + Future<Output = Option<LanguageServerBinary>> {
|
||||
async { None }
|
||||
}
|
||||
|
||||
@@ -544,7 +544,7 @@ pub trait LspInstaller {
|
||||
latest_version: Self::BinaryVersion,
|
||||
container_dir: PathBuf,
|
||||
delegate: &dyn LspAdapterDelegate,
|
||||
) -> impl Future<Output = Result<LanguageServerBinary>>;
|
||||
) -> impl Send + Future<Output = Result<LanguageServerBinary>>;
|
||||
|
||||
fn cached_server_binary(
|
||||
&self,
|
||||
@@ -575,6 +575,7 @@ pub trait DynLspInstaller {
|
||||
#[async_trait(?Send)]
|
||||
impl<LI, BinaryVersion> DynLspInstaller for LI
|
||||
where
|
||||
BinaryVersion: Send + Sync,
|
||||
LI: LspInstaller<BinaryVersion = BinaryVersion> + LspAdapter,
|
||||
{
|
||||
async fn try_fetch_server_binary(
|
||||
@@ -593,8 +594,13 @@ where
|
||||
.fetch_latest_server_version(delegate.as_ref(), pre_release, cx)
|
||||
.await?;
|
||||
|
||||
if let Some(binary) = self
|
||||
.check_if_version_installed(&latest_version, &container_dir, delegate.as_ref())
|
||||
if let Some(binary) = cx
|
||||
.background_executor()
|
||||
.await_on_background(self.check_if_version_installed(
|
||||
&latest_version,
|
||||
&container_dir,
|
||||
delegate.as_ref(),
|
||||
))
|
||||
.await
|
||||
{
|
||||
log::debug!("language server {:?} is already installed", name.0);
|
||||
@@ -603,8 +609,13 @@ where
|
||||
} else {
|
||||
log::debug!("downloading language server {:?}", name.0);
|
||||
delegate.update_status(name.clone(), BinaryStatus::Downloading);
|
||||
let binary = self
|
||||
.fetch_server_binary(latest_version, container_dir, delegate.as_ref())
|
||||
let binary = cx
|
||||
.background_executor()
|
||||
.await_on_background(self.fetch_server_binary(
|
||||
latest_version,
|
||||
container_dir,
|
||||
delegate.as_ref(),
|
||||
))
|
||||
.await;
|
||||
|
||||
delegate.update_status(name.clone(), BinaryStatus::None);
|
||||
|
||||
@@ -2,6 +2,40 @@
|
||||
|
||||
(identifier) @variable
|
||||
|
||||
(call_expression
|
||||
function: (member_expression
|
||||
object: (identifier) @type.builtin
|
||||
(#any-of?
|
||||
@type.builtin
|
||||
"Promise"
|
||||
"Array"
|
||||
"Object"
|
||||
"Map"
|
||||
"Set"
|
||||
"WeakMap"
|
||||
"WeakSet"
|
||||
"Date"
|
||||
"Error"
|
||||
"TypeError"
|
||||
"RangeError"
|
||||
"SyntaxError"
|
||||
"ReferenceError"
|
||||
"EvalError"
|
||||
"URIError"
|
||||
"RegExp"
|
||||
"Function"
|
||||
"Number"
|
||||
"String"
|
||||
"Boolean"
|
||||
"Symbol"
|
||||
"BigInt"
|
||||
"Proxy"
|
||||
"ArrayBuffer"
|
||||
"DataView"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
; Properties
|
||||
|
||||
(property_identifier) @property
|
||||
@@ -18,6 +52,12 @@
|
||||
function: (member_expression
|
||||
property: [(property_identifier) (private_property_identifier)] @function.method))
|
||||
|
||||
(new_expression
|
||||
constructor: (identifier) @type)
|
||||
|
||||
(nested_type_identifier
|
||||
module: (identifier) @type)
|
||||
|
||||
; Function and method definitions
|
||||
|
||||
(function_expression
|
||||
@@ -47,10 +87,45 @@
|
||||
left: (identifier) @function
|
||||
right: [(function_expression) (arrow_function)])
|
||||
|
||||
; Special identifiers
|
||||
; Parameters
|
||||
|
||||
(required_parameter
|
||||
(identifier) @variable.parameter)
|
||||
|
||||
(required_parameter
|
||||
(_
|
||||
([
|
||||
(identifier)
|
||||
(shorthand_property_identifier_pattern)
|
||||
]) @variable.parameter))
|
||||
|
||||
(optional_parameter
|
||||
(identifier) @variable.parameter)
|
||||
|
||||
(optional_parameter
|
||||
(_
|
||||
([
|
||||
(identifier)
|
||||
(shorthand_property_identifier_pattern)
|
||||
]) @variable.parameter))
|
||||
|
||||
(catch_clause
|
||||
parameter: (identifier) @variable.parameter)
|
||||
|
||||
(index_signature
|
||||
name: (identifier) @variable.parameter)
|
||||
|
||||
(arrow_function
|
||||
parameter: (identifier) @variable.parameter)
|
||||
|
||||
; Special identifiers
|
||||
;
|
||||
(class_declaration
|
||||
(type_identifier) @type.class)
|
||||
|
||||
(extends_clause
|
||||
value: (identifier) @type.class)
|
||||
|
||||
((identifier) @type
|
||||
(#match? @type "^[A-Z]"))
|
||||
(type_identifier) @type
|
||||
(predefined_type) @type.builtin
|
||||
|
||||
@@ -251,6 +326,34 @@
|
||||
(jsx_closing_element (identifier) @tag.jsx (#match? @tag.jsx "^[a-z][^.]*$"))
|
||||
(jsx_self_closing_element (identifier) @tag.jsx (#match? @tag.jsx "^[a-z][^.]*$"))
|
||||
|
||||
(jsx_opening_element
|
||||
[
|
||||
(identifier) @type
|
||||
(member_expression
|
||||
object: (identifier) @type
|
||||
property: (property_identifier) @type
|
||||
)
|
||||
]
|
||||
)
|
||||
(jsx_closing_element
|
||||
[
|
||||
(identifier) @type
|
||||
(member_expression
|
||||
object: (identifier) @type
|
||||
property: (property_identifier) @type
|
||||
)
|
||||
]
|
||||
)
|
||||
(jsx_self_closing_element
|
||||
[
|
||||
(identifier) @type
|
||||
(member_expression
|
||||
object: (identifier) @type
|
||||
property: (property_identifier) @type
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
(jsx_attribute (property_identifier) @attribute.jsx)
|
||||
(jsx_opening_element (["<" ">"]) @punctuation.bracket.jsx)
|
||||
(jsx_closing_element (["</" ">"]) @punctuation.bracket.jsx)
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
(tag_name) @keyword.jsdoc
|
||||
(type) @type.jsdoc
|
||||
(identifier) @variable.jsdoc
|
||||
|
||||
@@ -24,5 +24,9 @@ rewrap_prefixes = [
|
||||
auto_indent_on_paste = false
|
||||
auto_indent_using_last_non_empty_line = false
|
||||
tab_size = 2
|
||||
decrease_indent_pattern = "^.*$"
|
||||
decrease_indent_patterns = [
|
||||
{ pattern = "^\\s*-", valid_after = ["list_item"] },
|
||||
{ pattern = "^\\s*\\d", valid_after = ["list_item"] },
|
||||
{ pattern = "^\\s*", valid_after = ["list_item"] },
|
||||
]
|
||||
prettier_parser_name = "markdown"
|
||||
|
||||
@@ -1 +1,3 @@
|
||||
(list (list_item) @indent)
|
||||
|
||||
(list_item) @start.list_item
|
||||
|
||||
@@ -2,6 +2,40 @@
|
||||
|
||||
(identifier) @variable
|
||||
|
||||
(call_expression
|
||||
function: (member_expression
|
||||
object: (identifier) @type.builtin
|
||||
(#any-of?
|
||||
@type.builtin
|
||||
"Promise"
|
||||
"Array"
|
||||
"Object"
|
||||
"Map"
|
||||
"Set"
|
||||
"WeakMap"
|
||||
"WeakSet"
|
||||
"Date"
|
||||
"Error"
|
||||
"TypeError"
|
||||
"RangeError"
|
||||
"SyntaxError"
|
||||
"ReferenceError"
|
||||
"EvalError"
|
||||
"URIError"
|
||||
"RegExp"
|
||||
"Function"
|
||||
"Number"
|
||||
"String"
|
||||
"Boolean"
|
||||
"Symbol"
|
||||
"BigInt"
|
||||
"Proxy"
|
||||
"ArrayBuffer"
|
||||
"DataView"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
; Properties
|
||||
|
||||
(property_identifier) @property
|
||||
@@ -18,6 +52,12 @@
|
||||
function: (member_expression
|
||||
property: [(property_identifier) (private_property_identifier)] @function.method))
|
||||
|
||||
(new_expression
|
||||
constructor: (identifier) @type)
|
||||
|
||||
(nested_type_identifier
|
||||
module: (identifier) @type)
|
||||
|
||||
; Function and method definitions
|
||||
|
||||
(function_expression
|
||||
@@ -47,13 +87,68 @@
|
||||
left: (identifier) @function
|
||||
right: [(function_expression) (arrow_function)])
|
||||
|
||||
; Parameters
|
||||
|
||||
(required_parameter
|
||||
(identifier) @variable.parameter)
|
||||
|
||||
(required_parameter
|
||||
(_
|
||||
([
|
||||
(identifier)
|
||||
(shorthand_property_identifier_pattern)
|
||||
]) @variable.parameter))
|
||||
|
||||
(optional_parameter
|
||||
(identifier) @variable.parameter)
|
||||
|
||||
(optional_parameter
|
||||
(_
|
||||
([
|
||||
(identifier)
|
||||
(shorthand_property_identifier_pattern)
|
||||
]) @variable.parameter))
|
||||
|
||||
(catch_clause
|
||||
parameter: (identifier) @variable.parameter)
|
||||
|
||||
(index_signature
|
||||
name: (identifier) @variable.parameter)
|
||||
|
||||
(arrow_function
|
||||
parameter: (identifier) @variable.parameter)
|
||||
|
||||
(type_predicate
|
||||
name: (identifier) @variable.parameter)
|
||||
|
||||
; Special identifiers
|
||||
|
||||
((identifier) @type
|
||||
(#match? @type "^[A-Z]"))
|
||||
(type_annotation) @type
|
||||
(type_identifier) @type
|
||||
(predefined_type) @type.builtin
|
||||
|
||||
(type_alias_declaration
|
||||
(type_identifier) @type)
|
||||
|
||||
(type_alias_declaration
|
||||
value: (_
|
||||
(type_identifier) @type))
|
||||
|
||||
(interface_declaration
|
||||
(type_identifier) @type)
|
||||
|
||||
(class_declaration
|
||||
(type_identifier) @type.class)
|
||||
|
||||
(extends_clause
|
||||
value: (identifier) @type.class)
|
||||
|
||||
(extends_type_clause
|
||||
type: (type_identifier) @type)
|
||||
|
||||
(implements_clause
|
||||
(type_identifier) @type)
|
||||
|
||||
([
|
||||
(identifier)
|
||||
(shorthand_property_identifier)
|
||||
@@ -231,8 +326,42 @@
|
||||
"<" @punctuation.bracket
|
||||
">" @punctuation.bracket)
|
||||
|
||||
(type_parameters
|
||||
"<" @punctuation.bracket
|
||||
">" @punctuation.bracket)
|
||||
|
||||
(decorator "@" @punctuation.special)
|
||||
|
||||
(union_type
|
||||
("|") @punctuation.special)
|
||||
|
||||
(intersection_type
|
||||
("&") @punctuation.special)
|
||||
|
||||
(type_annotation
|
||||
(":") @punctuation.special)
|
||||
|
||||
(index_signature
|
||||
(":") @punctuation.special)
|
||||
|
||||
(type_predicate_annotation
|
||||
(":") @punctuation.special)
|
||||
|
||||
(public_field_definition
|
||||
("?") @punctuation.special)
|
||||
|
||||
(property_signature
|
||||
("?") @punctuation.special)
|
||||
|
||||
(method_signature
|
||||
("?") @punctuation.special)
|
||||
|
||||
(optional_parameter
|
||||
([
|
||||
"?"
|
||||
":"
|
||||
]) @punctuation.special)
|
||||
|
||||
; Keywords
|
||||
|
||||
[ "abstract"
|
||||
@@ -257,6 +386,34 @@
|
||||
(jsx_closing_element (identifier) @tag.jsx (#match? @tag.jsx "^[a-z][^.]*$"))
|
||||
(jsx_self_closing_element (identifier) @tag.jsx (#match? @tag.jsx "^[a-z][^.]*$"))
|
||||
|
||||
(jsx_opening_element
|
||||
[
|
||||
(identifier) @type
|
||||
(member_expression
|
||||
object: (identifier) @type
|
||||
property: (property_identifier) @type
|
||||
)
|
||||
]
|
||||
)
|
||||
(jsx_closing_element
|
||||
[
|
||||
(identifier) @type
|
||||
(member_expression
|
||||
object: (identifier) @type
|
||||
property: (property_identifier) @type
|
||||
)
|
||||
]
|
||||
)
|
||||
(jsx_self_closing_element
|
||||
[
|
||||
(identifier) @type
|
||||
(member_expression
|
||||
object: (identifier) @type
|
||||
property: (property_identifier) @type
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
(jsx_attribute (property_identifier) @attribute.jsx)
|
||||
(jsx_opening_element (["<" ">"]) @punctuation.bracket.jsx)
|
||||
(jsx_closing_element (["</" ">"]) @punctuation.bracket.jsx)
|
||||
|
||||
@@ -2,13 +2,69 @@
|
||||
|
||||
(identifier) @variable
|
||||
|
||||
(call_expression
|
||||
function: (member_expression
|
||||
object: (identifier) @type.builtin
|
||||
(#any-of?
|
||||
@type.builtin
|
||||
"Promise"
|
||||
"Array"
|
||||
"Object"
|
||||
"Map"
|
||||
"Set"
|
||||
"WeakMap"
|
||||
"WeakSet"
|
||||
"Date"
|
||||
"Error"
|
||||
"TypeError"
|
||||
"RangeError"
|
||||
"SyntaxError"
|
||||
"ReferenceError"
|
||||
"EvalError"
|
||||
"URIError"
|
||||
"RegExp"
|
||||
"Function"
|
||||
"Number"
|
||||
"String"
|
||||
"Boolean"
|
||||
"Symbol"
|
||||
"BigInt"
|
||||
"Proxy"
|
||||
"ArrayBuffer"
|
||||
"DataView"
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
; Special identifiers
|
||||
|
||||
((identifier) @type
|
||||
(#match? @type "^[A-Z]"))
|
||||
(type_annotation) @type
|
||||
|
||||
(type_identifier) @type
|
||||
(predefined_type) @type.builtin
|
||||
|
||||
(type_alias_declaration
|
||||
(type_identifier) @type)
|
||||
|
||||
(type_alias_declaration
|
||||
value: (_
|
||||
(type_identifier) @type))
|
||||
|
||||
(interface_declaration
|
||||
(type_identifier) @type)
|
||||
|
||||
(class_declaration
|
||||
(type_identifier) @type.class)
|
||||
|
||||
(extends_clause
|
||||
value: (identifier) @type.class)
|
||||
|
||||
(extends_type_clause
|
||||
type: (type_identifier) @type)
|
||||
|
||||
(implements_clause
|
||||
(type_identifier) @type)
|
||||
|
||||
;; Enables ts-pretty-errors
|
||||
;; The Lsp returns "snippets" of typescript, which are not valid typescript in totality,
|
||||
;; but should still be highlighted
|
||||
@@ -83,6 +139,12 @@
|
||||
function: (member_expression
|
||||
property: [(property_identifier) (private_property_identifier)] @function.method))
|
||||
|
||||
(new_expression
|
||||
constructor: (identifier) @type)
|
||||
|
||||
(nested_type_identifier
|
||||
module: (identifier) @type)
|
||||
|
||||
; Function and method definitions
|
||||
|
||||
(function_expression
|
||||
@@ -114,6 +176,40 @@
|
||||
|
||||
(arrow_function) @function
|
||||
|
||||
; Parameters
|
||||
|
||||
(required_parameter
|
||||
(identifier) @variable.parameter)
|
||||
|
||||
(required_parameter
|
||||
(_
|
||||
([
|
||||
(identifier)
|
||||
(shorthand_property_identifier_pattern)
|
||||
]) @variable.parameter))
|
||||
|
||||
(optional_parameter
|
||||
(identifier) @variable.parameter)
|
||||
|
||||
(optional_parameter
|
||||
(_
|
||||
([
|
||||
(identifier)
|
||||
(shorthand_property_identifier_pattern)
|
||||
]) @variable.parameter))
|
||||
|
||||
(catch_clause
|
||||
parameter: (identifier) @variable.parameter)
|
||||
|
||||
(index_signature
|
||||
name: (identifier) @variable.parameter)
|
||||
|
||||
(arrow_function
|
||||
parameter: (identifier) @variable.parameter)
|
||||
|
||||
(type_predicate
|
||||
name: (identifier) @variable.parameter)
|
||||
|
||||
; Literals
|
||||
|
||||
(this) @variable.special
|
||||
@@ -244,8 +340,42 @@
|
||||
"<" @punctuation.bracket
|
||||
">" @punctuation.bracket)
|
||||
|
||||
(type_parameters
|
||||
"<" @punctuation.bracket
|
||||
">" @punctuation.bracket)
|
||||
|
||||
(decorator "@" @punctuation.special)
|
||||
|
||||
(union_type
|
||||
("|") @punctuation.special)
|
||||
|
||||
(intersection_type
|
||||
("&") @punctuation.special)
|
||||
|
||||
(type_annotation
|
||||
(":") @punctuation.special)
|
||||
|
||||
(index_signature
|
||||
(":") @punctuation.special)
|
||||
|
||||
(type_predicate_annotation
|
||||
(":") @punctuation.special)
|
||||
|
||||
(public_field_definition
|
||||
("?") @punctuation.special)
|
||||
|
||||
(property_signature
|
||||
("?") @punctuation.special)
|
||||
|
||||
(method_signature
|
||||
("?") @punctuation.special)
|
||||
|
||||
(optional_parameter
|
||||
([
|
||||
"?"
|
||||
":"
|
||||
]) @punctuation.special)
|
||||
|
||||
; Keywords
|
||||
|
||||
[
|
||||
|
||||
@@ -314,7 +314,7 @@ pub struct AdapterServerCapabilities {
|
||||
|
||||
impl LanguageServer {
|
||||
/// Starts a language server process.
|
||||
pub fn new(
|
||||
pub async fn new(
|
||||
stderr_capture: Arc<Mutex<Option<String>>>,
|
||||
server_id: LanguageServerId,
|
||||
server_name: LanguageServerName,
|
||||
@@ -331,26 +331,30 @@ impl LanguageServer {
|
||||
};
|
||||
let root_uri = Uri::from_file_path(&working_dir)
|
||||
.map_err(|()| anyhow!("{working_dir:?} is not a valid URI"))?;
|
||||
|
||||
log::info!(
|
||||
"starting language server process. binary path: {:?}, working directory: {:?}, args: {:?}",
|
||||
"starting language server process. binary path: \
|
||||
{:?}, working directory: {:?}, args: {:?}",
|
||||
binary.path,
|
||||
working_dir,
|
||||
&binary.arguments
|
||||
);
|
||||
|
||||
let mut command = util::command::new_smol_command(&binary.path);
|
||||
command
|
||||
.current_dir(working_dir)
|
||||
.args(&binary.arguments)
|
||||
.envs(binary.env.clone().unwrap_or_default())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.kill_on_drop(true);
|
||||
let mut server = command
|
||||
.spawn()
|
||||
.with_context(|| format!("failed to spawn command {command:?}",))?;
|
||||
let mut server = cx
|
||||
.background_executor()
|
||||
.await_on_background(async {
|
||||
let mut command = util::command::new_smol_command(&binary.path);
|
||||
command
|
||||
.current_dir(working_dir)
|
||||
.args(&binary.arguments)
|
||||
.envs(binary.env.clone().unwrap_or_default())
|
||||
.stdin(Stdio::piped())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::piped())
|
||||
.kill_on_drop(true);
|
||||
command
|
||||
.spawn()
|
||||
.with_context(|| format!("failed to spawn command {command:?}",))
|
||||
})
|
||||
.await?;
|
||||
|
||||
let stdin = server.stdin.take().unwrap();
|
||||
let stdout = server.stdout.take().unwrap();
|
||||
|
||||
@@ -320,6 +320,7 @@ impl Prettier {
|
||||
Default::default(),
|
||||
&mut cx,
|
||||
)
|
||||
.await
|
||||
.context("prettier server creation")?;
|
||||
|
||||
let server = cx
|
||||
|
||||
@@ -4692,11 +4692,9 @@ impl Repository {
|
||||
});
|
||||
|
||||
let this = cx.weak_entity();
|
||||
let rx = self.run_hook(RunHook::PrePush, cx);
|
||||
self.send_job(
|
||||
Some(format!("git push {} {} {}", args, remote, branch).into()),
|
||||
move |git_repo, mut cx| async move {
|
||||
rx.await??;
|
||||
match git_repo {
|
||||
RepositoryState::Local(LocalRepositoryState {
|
||||
backend,
|
||||
|
||||
@@ -422,6 +422,7 @@ impl LocalLspStore {
|
||||
Some(pending_workspace_folders),
|
||||
cx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
@@ -580,7 +580,7 @@ message GitCreateWorktree {
|
||||
message RunGitHook {
|
||||
enum GitHook {
|
||||
PRE_COMMIT = 0;
|
||||
PRE_PUSH = 1;
|
||||
reserved 1;
|
||||
}
|
||||
|
||||
uint64 project_id = 1;
|
||||
|
||||
@@ -132,7 +132,7 @@ async fn spawn_and_read_fd(
|
||||
#[cfg(windows)]
|
||||
async fn capture_windows(
|
||||
shell_path: &Path,
|
||||
_args: &[String],
|
||||
args: &[String],
|
||||
directory: &Path,
|
||||
) -> Result<collections::HashMap<String, String>> {
|
||||
use std::process::Stdio;
|
||||
@@ -141,17 +141,17 @@ async fn capture_windows(
|
||||
std::env::current_exe().context("Failed to determine current zed executable path.")?;
|
||||
|
||||
let shell_kind = ShellKind::new(shell_path, true);
|
||||
if let ShellKind::Csh | ShellKind::Tcsh | ShellKind::Rc | ShellKind::Fish | ShellKind::Xonsh =
|
||||
shell_kind
|
||||
{
|
||||
return Err(anyhow::anyhow!("unsupported shell kind"));
|
||||
}
|
||||
let mut cmd = crate::command::new_smol_command(shell_path);
|
||||
cmd.args(args);
|
||||
let cmd = match shell_kind {
|
||||
ShellKind::Csh | ShellKind::Tcsh | ShellKind::Rc | ShellKind::Fish | ShellKind::Xonsh => {
|
||||
unreachable!()
|
||||
}
|
||||
ShellKind::Posix => cmd.args([
|
||||
ShellKind::Csh
|
||||
| ShellKind::Tcsh
|
||||
| ShellKind::Rc
|
||||
| ShellKind::Fish
|
||||
| ShellKind::Xonsh
|
||||
| ShellKind::Posix => cmd.args([
|
||||
"-l",
|
||||
"-i",
|
||||
"-c",
|
||||
&format!(
|
||||
"cd '{}'; '{}' --printenv",
|
||||
|
||||
15
crates/zeta_prompt/Cargo.toml
Normal file
15
crates/zeta_prompt/Cargo.toml
Normal file
@@ -0,0 +1,15 @@
|
||||
[package]
|
||||
name = "zeta_prompt"
|
||||
version = "0.1.0"
|
||||
publish.workspace = true
|
||||
edition.workspace = true
|
||||
license = "GPL-3.0-or-later"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[lib]
|
||||
path = "src/zeta_prompt.rs"
|
||||
|
||||
[dependencies]
|
||||
serde.workspace = true
|
||||
165
crates/zeta_prompt/src/zeta_prompt.rs
Normal file
165
crates/zeta_prompt/src/zeta_prompt.rs
Normal file
@@ -0,0 +1,165 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write;
|
||||
use std::ops::Range;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Marker text that stands in for the user's cursor inside a formatted prompt.
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ZetaPromptInput {
|
||||
pub cursor_path: Arc<Path>,
|
||||
pub cursor_excerpt: Arc<str>,
|
||||
pub editable_range_in_excerpt: Range<usize>,
|
||||
pub cursor_offset_in_excerpt: usize,
|
||||
pub events: Vec<Arc<Event>>,
|
||||
pub related_files: Arc<[RelatedFile]>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
#[serde(tag = "event")]
|
||||
pub enum Event {
|
||||
BufferChange {
|
||||
path: Arc<Path>,
|
||||
old_path: Arc<Path>,
|
||||
diff: String,
|
||||
predicted: bool,
|
||||
in_open_source_repo: bool,
|
||||
},
|
||||
}
|
||||
|
||||
pub fn write_event(prompt: &mut String, event: &Event) {
|
||||
fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
|
||||
for component in path.components() {
|
||||
prompt.push('/');
|
||||
write!(prompt, "{}", component.as_os_str().display()).ok();
|
||||
}
|
||||
}
|
||||
match event {
|
||||
Event::BufferChange {
|
||||
path,
|
||||
old_path,
|
||||
diff,
|
||||
predicted,
|
||||
in_open_source_repo: _,
|
||||
} => {
|
||||
if *predicted {
|
||||
prompt.push_str("// User accepted prediction:\n");
|
||||
}
|
||||
prompt.push_str("--- a");
|
||||
write_path_as_unix_str(prompt, old_path.as_ref());
|
||||
prompt.push_str("\n+++ b");
|
||||
write_path_as_unix_str(prompt, path.as_ref());
|
||||
prompt.push('\n');
|
||||
prompt.push_str(diff);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct RelatedFile {
|
||||
pub path: Arc<Path>,
|
||||
pub max_row: u32,
|
||||
pub excerpts: Vec<RelatedExcerpt>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct RelatedExcerpt {
|
||||
pub row_range: Range<u32>,
|
||||
pub text: String,
|
||||
}
|
||||
|
||||
pub fn format_zeta_prompt(input: &ZetaPromptInput) -> String {
|
||||
let mut prompt = String::new();
|
||||
write_related_files(&mut prompt, &input.related_files);
|
||||
write_edit_history_section(&mut prompt, input);
|
||||
write_cursor_excerpt_section(&mut prompt, input);
|
||||
prompt
|
||||
}
|
||||
|
||||
pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
|
||||
push_delimited(prompt, "related_files", &[], |prompt| {
|
||||
for file in related_files {
|
||||
let path_str = file.path.to_string_lossy();
|
||||
push_delimited(prompt, "related_file", &[("path", &path_str)], |prompt| {
|
||||
for excerpt in &file.excerpts {
|
||||
push_delimited(
|
||||
prompt,
|
||||
"related_excerpt",
|
||||
&[(
|
||||
"lines",
|
||||
&format!(
|
||||
"{}-{}",
|
||||
excerpt.row_range.start + 1,
|
||||
excerpt.row_range.end + 1
|
||||
),
|
||||
)],
|
||||
|prompt| {
|
||||
prompt.push_str(&excerpt.text);
|
||||
prompt.push('\n');
|
||||
},
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
|
||||
push_delimited(prompt, "edit_history", &[], |prompt| {
|
||||
if input.events.is_empty() {
|
||||
prompt.push_str("(No edit history)");
|
||||
} else {
|
||||
for event in &input.events {
|
||||
write_event(prompt, event);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
|
||||
push_delimited(prompt, "cursor_excerpt", &[], |prompt| {
|
||||
let path_str = input.cursor_path.to_string_lossy();
|
||||
push_delimited(prompt, "file", &[("path", &path_str)], |prompt| {
|
||||
prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
|
||||
push_delimited(prompt, "editable_region", &[], |prompt| {
|
||||
prompt.push_str(
|
||||
&input.cursor_excerpt
|
||||
[input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
|
||||
);
|
||||
prompt.push_str(CURSOR_MARKER);
|
||||
prompt.push_str(
|
||||
&input.cursor_excerpt
|
||||
[input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
|
||||
);
|
||||
});
|
||||
prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
/// Appends an XML-like element named `tag` to `prompt`, letting `cb` write its content.
///
/// The opening tag always starts on its own line: a newline is pushed first unless
/// `prompt` already ends with one (this includes an empty `prompt`, which therefore
/// receives a leading newline). Each `(name, value)` pair of `arguments` is written as
/// ` name="value"` inside the opening tag; values are inserted as-is, without escaping.
/// After `cb` returns, a newline is added if the content did not end with one, followed
/// by the closing tag and a trailing newline.
fn push_delimited(
    prompt: &mut String,
    tag: &'static str,
    arguments: &[(&str, &str)],
    cb: impl FnOnce(&mut String),
) {
    /// Ensures the next write starts on a fresh line.
    fn ensure_trailing_newline(prompt: &mut String) {
        if !prompt.ends_with('\n') {
            prompt.push('\n');
        }
    }

    ensure_trailing_newline(prompt);
    prompt.push('<');
    prompt.push_str(tag);
    for (arg_name, arg_value) in arguments {
        // Writing into a `String` cannot fail, so the result is deliberately dropped.
        write!(prompt, " {}=\"{}\"", arg_name, arg_value).ok();
    }
    prompt.push_str(">\n");

    cb(prompt);

    ensure_trailing_newline(prompt);
    prompt.push_str("</");
    prompt.push_str(tag);
    prompt.push_str(">\n");
}
|
||||
@@ -368,6 +368,8 @@ pub(crate) fn check_postgres_and_protobuf_migrations() -> NamedJob {
|
||||
.runs_on(runners::LINUX_DEFAULT)
|
||||
.add_env(("GIT_AUTHOR_NAME", "Protobuf Action"))
|
||||
.add_env(("GIT_AUTHOR_EMAIL", "ci@zed.dev"))
|
||||
.add_env(("GIT_COMMITTER_NAME", "Protobuf Action"))
|
||||
.add_env(("GIT_COMMITTER_EMAIL", "ci@zed.dev"))
|
||||
.add_step(steps::checkout_repo().with(("fetch-depth", 0))) // fetch full history
|
||||
.add_step(remove_untracked_files())
|
||||
.add_step(ensure_fresh_merge())
|
||||
|
||||
Reference in New Issue
Block a user