Compare commits

21 Commits

Author SHA1 Message Date
Michael Benfield
aef1106e7a cleanup 2025-12-11 10:52:32 -08:00
Michael Benfield
605c594181 better handling of model explanation 2025-12-11 10:52:32 -08:00
Michael Benfield
f332b08a34 tmp 2025-12-11 10:52:32 -08:00
Michael Benfield
cfb61624c3 tmp 2025-12-11 10:52:32 -08:00
Michael Benfield
f79dd8784b tmp 2025-12-11 10:52:32 -08:00
Michael Benfield
0c5bebae93 tool use conversion to streaming in progress
Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
2025-12-11 10:52:32 -08:00
Agus Zubiaga
2098b67304 edit prediction: Respect enabled settings when refreshing from diagnostics (#44640)
Release Notes:

- N/A
2025-12-11 17:39:57 +00:00
Lukas Wirth
5a6198cc39 language: Spawn language servers on background threads (#44631)
Closes https://github.com/zed-industries/zed/issues/39056

Leverages a new `await_on_background` API that spawns the future on a
background thread but blocks the current task, allowing it to borrow from
the surrounding scope.
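
As a rough analogy only (this is not the `await_on_background` API itself, whose signature is not shown here), the standard library's `std::thread::scope` has the same shape: work runs on another thread, the caller blocks until it finishes, and the worker may therefore borrow from the caller's stack. The helper names `count_non_empty` and `run_on_background` are hypothetical.

```rust
use std::thread;

/// Hypothetical stand-in for the work done while starting a language server.
fn count_non_empty(lines: &[String]) -> usize {
    lines.iter().filter(|l| !l.trim().is_empty()).count()
}

/// Runs `count_non_empty` on another thread while borrowing `lines`
/// from the caller, and blocks until the result is ready.
fn run_on_background(lines: &[String]) -> usize {
    thread::scope(|s| {
        // The spawned closure only borrows `lines`. This compiles because
        // the scope joins every spawned thread before `scope` returns.
        let handle = s.spawn(|| count_non_empty(lines));
        handle.join().expect("background thread panicked")
    })
}
```

Calling `run_on_background(&lines)` leaves `lines` owned by the caller afterwards, which is the property the commit message describes: no `'static` bound and no cloning is needed to move the work off the current thread.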

Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-12-11 17:23:27 +00:00
Siame Rafiq
cda78c12ab git: Make permalinks aware of current diffs (#41915)
Addressing #22546, we want git permalinks to be aware of the current
changes within the buffer.

This change calculates how many lines have been added/deleted between
the start and end of the selection and uses those values to offset the
selection.

This is done within `Editor::get_permalink_to_line` so that it can be
passed to any git_store.

Example:

<img width="284" height="316" alt="image"
src="https://github.com/user-attachments/assets/268043a0-2fc8-41c1-b094-d650fd4e0ae0"
/>

Where this selection's permalink would previously return L3-L9, it now
returns L2-L7.
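
The offset calculation described above can be sketched as follows. This is a minimal illustration, not the code in `Editor::get_permalink_to_line`: the `Hunk` type, its fields, and both function names are assumptions, and the sketch only handles hunks that end at or before the line being mapped.

```rust
/// A changed region of the buffer relative to the committed file.
/// Hypothetical type, introduced only for this sketch.
struct Hunk {
    /// 1-based buffer line at which the hunk starts.
    buffer_start: u32,
    /// Lines present in the buffer but not in the committed file.
    added: u32,
    /// Lines present in the committed file but not in the buffer.
    deleted: u32,
}

/// Maps a buffer line to the matching line in the committed file by
/// undoing every hunk that ends at or before that line.
fn to_committed_line(line: u32, hunks: &[Hunk]) -> u32 {
    let offset: i64 = hunks
        .iter()
        .filter(|h| h.buffer_start + h.added <= line)
        .map(|h| i64::from(h.deleted) - i64::from(h.added))
        .sum();
    (i64::from(line) + offset).max(1) as u32
}

/// Offsets both ends of a selection independently, so hunks inside the
/// selection shift only its end.
fn permalink_range(selection: (u32, u32), hunks: &[Hunk]) -> (u32, u32) {
    (
        to_committed_line(selection.0, hunks),
        to_committed_line(selection.1, hunks),
    )
}
```

One hypothetical arrangement that reproduces the L3-L9 to L2-L7 mapping is a single added line at buffer line 1 and another at buffer line 5: the first shifts both ends by one, the second shifts only the end.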

Release Notes:

- git: make permalinks aware of current diffs

Closes #22546

---

This is my first PR into the zed repository so very happy for any
feedback on how I've implemented this. Thanks!
2025-12-11 10:53:20 -05:00
Smit Barmase
f4378672b8 editor: Fix auto-indent cases in Markdown (#44616)
Builds on https://github.com/zed-industries/zed/pull/40794 and
https://github.com/zed-industries/zed/pull/44381

- Fixes creating a new line inside a nested list so that the cursor is
placed correctly under that nested list item.
- Fixes typing a new list item at the expected indent so that it no
longer auto-indents or outdents incorrectly.

Release Notes:

- Fixed an issue in Markdown where new list items weren’t respecting the
expected indentation on type.
2025-12-11 21:14:15 +05:30
Yara 🏳️‍⚧️
ecb8d3d4dd Revert "Multiple priority scheduler" (#44637)
Reverts zed-industries/zed#44575
2025-12-11 16:16:43 +01:00
localcc
95dbc0efc2 Multiple priority scheduler (#44575)
Improves the scheduler by allowing tasks to have a set priority which
will significantly improve responsiveness.
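
The general idea of priority-aware scheduling can be sketched with a binary heap. This is not the scheduler from the PR (which was reverted in #44637); the `Priority` levels, the `PriorityQueue` type, and the use of string labels as tasks are all assumptions made for illustration.

```rust
use std::cmp::Reverse;
use std::collections::BinaryHeap;

/// Hypothetical priority levels; later variants compare as greater.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum Priority {
    Low,
    Medium,
    High,
}

/// A queue that hands out the highest-priority task first and keeps
/// first-in, first-out order among tasks of equal priority.
struct PriorityQueue {
    heap: BinaryHeap<(Priority, Reverse<u64>, &'static str)>,
    next_seq: u64,
}

impl PriorityQueue {
    fn new() -> Self {
        Self { heap: BinaryHeap::new(), next_seq: 0 }
    }

    fn push(&mut self, priority: Priority, task: &'static str) {
        // `Reverse` makes the earlier sequence number win a priority tie.
        self.heap.push((priority, Reverse(self.next_seq), task));
        self.next_seq += 1;
    }

    fn pop(&mut self) -> Option<&'static str> {
        self.heap.pop().map(|(_, _, task)| task)
    }
}
```

With this ordering, latency-sensitive work queued behind a long backlog of low-priority tasks is still picked up next, which is the responsiveness benefit the commit message refers to.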

Release notes:

- N/A

---------

Co-authored-by: Yara <git@yara.blue>
2025-12-11 13:22:39 +00:00
Gaauwe Rombouts
8572c19a02 Improve TS/TSX/JS syntax highlighting for parameters, types, and punctuation (#44532)
Relands https://github.com/zed-industries/zed/pull/43437

Release Notes:

- Refined syntax highlighting in JavaScript and TypeScript for better
visual distinction of types, parameters, and JSDoc elements

---------

Co-authored-by: MrSubidubi <dev@bahn.sh>
Co-authored-by: Clay Tercek <30105080+claytercek@users.noreply.github.com>
2025-12-11 12:02:28 +01:00
Lukas Wirth
045c14593f util: Honor shell args for shell env fetching on windows (#44615)
Closes https://github.com/zed-industries/zed/issues/40464

Release Notes:

- Fixed shell environment fetching on Windows discarding arguments
specified in settings
2025-12-11 10:34:37 +00:00
Lukas Wirth
0ff3b68a5e windows: Fix incorrect cursor insertion keybinds (#44608)
Release Notes:

- N/A *or* Added/Fixed/Improved ...
2025-12-11 09:38:44 +00:00
Lukas Wirth
a6b9524d78 gpui: Retain maximized and fullscreen state for new windows derived from previous windows (#44605)
Release Notes:

- Fixed new windows underflowing the taskbar on Windows
- Improved new windows spawned from maximized or fullscreened windows by
copying the maximized and fullscreened states
2025-12-11 09:38:38 +00:00
CharlesChen0823
7ed5d42696 git: Fix git hook hang with prek (#44212)
Fix git hook hang when used with `prek`. See the
[comments](https://github.com/zed-industries/zed/issues/44057#issuecomment-3606837089);
this is easy to test, but use a release build, because a debug build
sometimes does not hang.

The issue has existed for a long time (see issue #37293). It was fixed in
#42239 but broken again in #43285, so I referenced the implementation in
#42239, and the code works again.

I must admit that I really don't know what happened or why this code works,
but it does.

Release Notes:

- N/A

---------

Co-authored-by: Cole Miller <cole@zed.dev>
2025-12-11 03:17:13 +00:00
Max Brunsfeld
25d74480aa Rework edit prediction CLI (#44562)
This PR restructures the commands of the Edit Prediction CLI (now called
`ep`), to support some flows that are important for the training
process:
* generating zeta2 prompt and expected output, without running
predictions
* scoring outputs that are generated by a system other than the
production code (to evaluate the model during training)

To achieve this, we've restructured the CLI commands so that they all
take as input, and produce as output, a consistent, uniform data format:
a set of one or more `Example` structs, expressible either as the
original markdown format, or as JSON Lines. The `Example` struct
starts with the basic fields that are in human-readable eval format, but
contains a number of optional fields that are filled in by different
steps in the processing pipeline (`context`, `predict`, `format-prompt`,
and `score`).
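
The "one uniform record, progressively filled in" design described above can be sketched as follows. Everything here is a simplified assumption: the real `Example` struct's fields are not shown in this description, so the field names, the step functions, and the exact-match scoring are hypothetical, and serialization to markdown or JSON Lines is left out.

```rust
/// Hypothetical, simplified stand-in for the `Example` struct.
#[derive(Debug, Default, Clone, PartialEq)]
struct Example {
    // Basic fields, present from the start.
    name: String,
    cursor_excerpt: String,
    // Optional fields, each filled in by one pipeline step.
    context: Option<String>,
    prompt: Option<String>,
    prediction: Option<String>,
    score: Option<f32>,
}

/// `context` step: attaches retrieved context (faked here).
fn context(mut ex: Example) -> Example {
    ex.context = Some(format!("related to: {}", ex.cursor_excerpt));
    ex
}

/// `format-prompt` step: builds the prompt without running a prediction.
fn format_prompt(mut ex: Example) -> Example {
    let context = ex.context.clone().unwrap_or_default();
    ex.prompt = Some(format!("{context}\n{}", ex.cursor_excerpt));
    ex
}

/// `score` step: scores whatever prediction is present, regardless of
/// which system produced it. Exact match is a placeholder metric.
fn score(mut ex: Example, expected: &str) -> Example {
    ex.score = ex
        .prediction
        .as_deref()
        .map(|p| if p == expected { 1.0 } else { 0.0 });
    ex
}
```

Because every step takes and returns the same type, a prompt can be generated without ever calling a predict step, and a prediction written by an external system can be dropped into `prediction` and scored, which are the two flows listed above.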

### To do

* [x] Adjust the teacher model output parsing to use the full buffer
contents
* [x] Move udiff to cli
* [x] Align `format-prompt` with Zeta2's production code
* [x] Change score output to assume same provider
* [x] Move pretty reporting to `eval` command
* [x] Store cursor point in addition to cursor offset
* [x] Rename `edit_prediction_cli2` -> `edit_prediction_cli` (nuke the
old one)

Release Notes:

- N/A

---------

Co-authored-by: Oleksiy Syvokon <oleksiy@zed.dev>
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Ben Kunkle <ben@zed.dev>
2025-12-10 17:36:51 -08:00
Cole Miller
37077a8ebb git: Avoid calling git help -a on every commit (#44586)
Updates #43993 

Release Notes:

- N/A
2025-12-11 01:03:35 +00:00
Finn Evers
7c4a85f5f1 ci: Explicitly set git committer information in protobuf check (#44582)
This should hopefully fix the flakes for good.

Release Notes:

- N/A
2025-12-10 23:35:02 +00:00
Cole Miller
d21628c349 Revert "Increase askpass timeout for git operations (#42946)" (#44578)
This reverts commit a74aac88c9.

cc @11happy, we need to do a bit more than just running `git hook
pre-push` before pushing, as described
[here](https://github.com/zed-industries/zed/pull/42946#issuecomment-3550570438).
Right now this is also running the pre-push hook twice.

Release Notes:

- N/A
2025-12-10 18:07:01 -05:00
77 changed files with 3709 additions and 5273 deletions

View File

@@ -497,6 +497,8 @@ jobs:
env:
GIT_AUTHOR_NAME: Protobuf Action
GIT_AUTHOR_EMAIL: ci@zed.dev
GIT_COMMITTER_NAME: Protobuf Action
GIT_COMMITTER_EMAIL: ci@zed.dev
steps:
- name: steps::checkout_repo
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683

30
Cargo.lock generated
View File

@@ -3111,16 +3111,6 @@ dependencies = [
"uuid",
]
[[package]]
name = "cloud_zeta2_prompt"
version = "0.1.0"
dependencies = [
"anyhow",
"cloud_llm_client",
"indoc",
"serde",
]
[[package]]
name = "cmake"
version = "0.1.54"
@@ -5119,7 +5109,6 @@ dependencies = [
"clock",
"cloud_api_types",
"cloud_llm_client",
"cloud_zeta2_prompt",
"collections",
"copilot",
"credentials_provider",
@@ -5150,8 +5139,6 @@ dependencies = [
"serde",
"serde_json",
"settings",
"smol",
"strsim",
"strum 0.27.2",
"telemetry",
"telemetry_events",
@@ -5162,6 +5149,7 @@ dependencies = [
"workspace",
"worktree",
"zed_actions",
"zeta_prompt",
"zlog",
]
@@ -5175,11 +5163,10 @@ dependencies = [
"clap",
"client",
"cloud_llm_client",
"cloud_zeta2_prompt",
"collections",
"debug_adapter_extension",
"dirs 4.0.0",
"edit_prediction",
"edit_prediction_context",
"extension",
"fs",
"futures 0.3.31",
@@ -5209,9 +5196,10 @@ dependencies = [
"sqlez",
"sqlez_macros",
"terminal_view",
"toml 0.8.23",
"util",
"wasmtime",
"watch",
"zeta_prompt",
"zlog",
]
@@ -5239,6 +5227,7 @@ dependencies = [
"text",
"tree-sitter",
"util",
"zeta_prompt",
"zlog",
]
@@ -5260,7 +5249,6 @@ dependencies = [
"buffer_diff",
"client",
"cloud_llm_client",
"cloud_zeta2_prompt",
"codestral",
"command_palette_hooks",
"copilot",
@@ -5291,6 +5279,7 @@ dependencies = [
"util",
"workspace",
"zed_actions",
"zeta_prompt",
]
[[package]]
@@ -20933,6 +20922,13 @@ dependencies = [
"syn 2.0.106",
]
[[package]]
name = "zeta_prompt"
version = "0.1.0"
dependencies = [
"serde",
]
[[package]]
name = "zip"
version = "0.6.6"

View File

@@ -32,7 +32,6 @@ members = [
"crates/cloud_api_client",
"crates/cloud_api_types",
"crates/cloud_llm_client",
"crates/cloud_zeta2_prompt",
"crates/collab",
"crates/collab_ui",
"crates/collections",
@@ -202,6 +201,7 @@ members = [
"crates/zed_actions",
"crates/zed_env_vars",
"crates/edit_prediction_cli",
"crates/zeta_prompt",
"crates/zlog",
"crates/zlog_settings",
"crates/ztracing",
@@ -266,7 +266,6 @@ clock = { path = "crates/clock" }
cloud_api_client = { path = "crates/cloud_api_client" }
cloud_api_types = { path = "crates/cloud_api_types" }
cloud_llm_client = { path = "crates/cloud_llm_client" }
cloud_zeta2_prompt = { path = "crates/cloud_zeta2_prompt" }
collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections", version = "0.1.0" }
command_palette = { path = "crates/command_palette" }
@@ -425,6 +424,7 @@ zed = { path = "crates/zed" }
zed_actions = { path = "crates/zed_actions" }
zed_env_vars = { path = "crates/zed_env_vars" }
edit_prediction = { path = "crates/edit_prediction" }
zeta_prompt = { path = "crates/zeta_prompt" }
zlog = { path = "crates/zlog" }
zlog_settings = { path = "crates/zlog_settings" }
ztracing = { path = "crates/ztracing" }
@@ -631,7 +631,7 @@ shellexpand = "2.1.0"
shlex = "1.3.0"
simplelog = "0.12.2"
slotmap = "1.0.6"
smallvec = { version = "1.6", features = ["union"] }
smallvec = { version = "1.6", features = ["union", "const_new"] }
smol = "2.0"
sqlformat = "0.2"
stacksafe = "0.1"
@@ -657,6 +657,7 @@ time = { version = "0.3", features = [
tiny_http = "0.8"
tokio = { version = "1" }
tokio-tungstenite = { version = "0.26", features = ["__rustls-tls"] }
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io", "tokio"] }
toml = "0.8"
toml_edit = { version = "0.22", default-features = false, features = ["display", "parse", "serde"] }
tower-http = "0.4.4"

View File

@@ -489,8 +489,8 @@
"bindings": {
"ctrl-[": "editor::Outdent",
"ctrl-]": "editor::Indent",
"ctrl-shift-alt-up": ["editor::AddSelectionAbove", { "skip_soft_wrap": true }], // Insert Cursor Above
"ctrl-shift-alt-down": ["editor::AddSelectionBelow", { "skip_soft_wrap": true }], // Insert Cursor Below
"ctrl-alt-up": ["editor::AddSelectionAbove", { "skip_soft_wrap": true }], // Insert Cursor Above
"ctrl-alt-down": ["editor::AddSelectionBelow", { "skip_soft_wrap": true }], // Insert Cursor Below
"ctrl-shift-k": "editor::DeleteLine",
"alt-up": "editor::MoveLineUp",
"alt-down": "editor::MoveLineDown",

View File

@@ -1,6 +1,7 @@
use crate::{context::LoadedContext, inline_prompt_editor::CodegenStatus};
use agent_settings::AgentSettings;
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
use cloud_llm_client::CompletionIntent;
use collections::HashSet;
@@ -11,12 +12,14 @@ use futures::{
channel::mpsc,
future::{LocalBoxFuture, Shared},
join,
stream::BoxStream,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, EventEmitter, Subscription, Task};
use language::{Buffer, IndentKind, Point, TransactionId, line_diff};
use language_model::{
LanguageModel, LanguageModelCompletionError, LanguageModelRegistry, LanguageModelRequest,
LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelTextStream, Role,
LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
LanguageModelRequestTool, LanguageModelTextStream, LanguageModelToolUse, Role, TokenUsage,
report_assistant_event,
};
use multi_buffer::MultiBufferRow;
@@ -46,6 +49,7 @@ pub struct FailureMessageInput {
/// A brief message to the user explaining why you're unable to fulfill the request or to ask a question about the request.
///
/// The message may use markdown formatting if you wish.
#[serde(default)]
pub message: String,
}
@@ -56,9 +60,11 @@ pub struct RewriteSectionInput {
///
/// The description may use markdown formatting if you wish.
/// This is optional - if the edit is simple or obvious, you should leave it empty.
#[serde(default)]
pub description: String,
/// The text to replace the section with.
#[serde(default)]
pub replacement_text: String,
}
@@ -400,9 +406,15 @@ impl CodegenAlternative {
if cx.has_flag::<InlineAssistantV2FeatureFlag>() {
let request = self.build_request(&model, user_prompt, context_task, cx)?;
let tool_use =
cx.spawn(async move |_, cx| model.stream_completion_tool(request.await, cx).await);
self.handle_tool_use(telemetry_id, provider_id.to_string(), api_key, tool_use, cx);
let completion_events =
cx.spawn(async move |_, cx| model.stream_completion(request.await, cx).await);
self.generation = self.handle_completion(
telemetry_id,
provider_id.to_string(),
api_key,
completion_events,
cx,
);
} else {
let stream: LocalBoxFuture<Result<LanguageModelTextStream>> =
if user_prompt.trim().to_lowercase() == "delete" {
@@ -414,7 +426,8 @@ impl CodegenAlternative {
})
.boxed_local()
};
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
self.generation =
self.handle_stream(telemetry_id, provider_id.to_string(), api_key, stream, cx);
}
Ok(())
@@ -603,7 +616,7 @@ impl CodegenAlternative {
model_api_key: Option<String>,
stream: impl 'static + Future<Output = Result<LanguageModelTextStream>>,
cx: &mut Context<Self>,
) {
) -> Task<()> {
let start_time = Instant::now();
// Make a new snapshot and re-resolve anchor in case the document was modified.
@@ -659,7 +672,8 @@ impl CodegenAlternative {
let completion = Arc::new(Mutex::new(String::new()));
let completion_clone = completion.clone();
self.generation = cx.spawn(async move |codegen, cx| {
cx.notify();
cx.spawn(async move |codegen, cx| {
let stream = stream.await;
let token_usage = stream
@@ -685,6 +699,7 @@ impl CodegenAlternative {
stream?.stream.map_err(|error| error.into()),
);
futures::pin_mut!(chunks);
let mut diff = StreamingDiff::new(selected_text.to_string());
let mut line_diff = LineDiff::default();
@@ -694,6 +709,7 @@ impl CodegenAlternative {
let mut first_line = true;
while let Some(chunk) = chunks.next().await {
dbg!(&chunk);
if response_latency.is_none() {
response_latency = Some(request_start.elapsed());
}
@@ -876,8 +892,7 @@ impl CodegenAlternative {
cx.notify();
})
.ok();
});
cx.notify();
})
}
pub fn current_completion(&self) -> Option<String> {
@@ -1060,21 +1075,29 @@ impl CodegenAlternative {
})
}
fn handle_tool_use(
fn handle_completion(
&mut self,
_telemetry_id: String,
_provider_id: String,
_api_key: Option<String>,
tool_use: impl 'static
+ Future<
Output = Result<language_model::LanguageModelToolUse, LanguageModelCompletionError>,
telemetry_id: String,
provider_id: String,
api_key: Option<String>,
completion_stream: Task<
Result<
BoxStream<
'static,
Result<LanguageModelCompletionEvent, LanguageModelCompletionError>,
>,
LanguageModelCompletionError,
>,
>,
cx: &mut Context<Self>,
) {
) -> Task<()> {
self.diff = Diff::default();
self.status = CodegenStatus::Pending;
self.generation = cx.spawn(async move |codegen, cx| {
cx.notify();
// Leaving this in generation so that STOP equivalent events are respected even
// while we're still pre-processing the completion event
cx.spawn(async move |codegen, cx| {
let finish_with_status = |status: CodegenStatus, cx: &mut AsyncApp| {
let _ = codegen.update(cx, |this, cx| {
this.status = status;
@@ -1083,76 +1106,177 @@ impl CodegenAlternative {
});
};
let tool_use = tool_use.await;
let mut completion_events = match completion_stream.await {
Ok(events) => events,
Err(err) => {
finish_with_status(CodegenStatus::Error(err.into()), cx);
return;
}
};
match tool_use {
Ok(tool_use) if tool_use.name.as_ref() == "rewrite_section" => {
// Parse the input JSON into RewriteSectionInput
match serde_json::from_value::<RewriteSectionInput>(tool_use.input) {
Ok(input) => {
// Store the description if non-empty
let description = if !input.description.trim().is_empty() {
Some(input.description.clone())
} else {
None
let chars_read_so_far = Arc::new(Mutex::new(0usize));
let tool_to_text_and_message =
move |tool_use: LanguageModelToolUse| -> (Option<String>, Option<String>) {
let mut chars_read_so_far = chars_read_so_far.lock();
match tool_use.name.as_ref() {
"rewrite_section" => {
let Ok(mut input) =
serde_json::from_value::<RewriteSectionInput>(tool_use.input)
else {
return (None, None);
};
let value = input.replacement_text[*chars_read_so_far..].to_string();
*chars_read_so_far = input.replacement_text.len();
(Some(value), Some(std::mem::take(&mut input.description)))
}
"failure_message" => {
let Ok(mut input) =
serde_json::from_value::<FailureMessageInput>(tool_use.input)
else {
return (None, None);
};
(None, Some(std::mem::take(&mut input.message)))
}
_ => (None, None),
}
};
// Apply the replacement text to the buffer and compute diff
let batch_diff_task = codegen
.update(cx, |this, cx| {
this.model_explanation = description.map(Into::into);
let range = this.range.clone();
this.apply_edits(
std::iter::once((range, input.replacement_text)),
cx,
);
this.reapply_batch_diff(cx)
})
.ok();
let mut message_id = None;
let mut first_text = None;
let last_token_usage = Arc::new(Mutex::new(TokenUsage::default()));
let total_text = Arc::new(Mutex::new(String::new()));
// Wait for the diff computation to complete
if let Some(diff_task) = batch_diff_task {
diff_task.await;
loop {
if let Some(first_event) = completion_events.next().await {
dbg!(&first_event);
match first_event {
Ok(LanguageModelCompletionEvent::StartMessage { message_id: id }) => {
message_id = Some(id);
}
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
if matches!(
tool_use.name.as_ref(),
"rewrite_section" | "failure_message"
) =>
{
let is_complete = tool_use.is_input_complete;
let (text, message) = tool_to_text_and_message(tool_use);
// Only update the model explanation if the tool use is complete.
// Otherwise the UI element bounces around as it's updated.
if is_complete {
let _ = codegen.update(cx, |this, _cx| {
this.model_explanation = message.map(Into::into);
});
}
finish_with_status(CodegenStatus::Done, cx);
return;
first_text = text;
if first_text.is_some() {
break;
}
}
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
*last_token_usage.lock() = token_usage;
}
Ok(LanguageModelCompletionEvent::Text(text)) => {
let mut lock = total_text.lock();
lock.push_str(&text);
}
Ok(e) => {
log::warn!("Unexpected event: {:?}", e);
break;
}
Err(e) => {
finish_with_status(CodegenStatus::Error(e.into()), cx);
return;
break;
}
}
}
Ok(tool_use) if tool_use.name.as_ref() == "failure_message" => {
// Handle failure message tool use
match serde_json::from_value::<FailureMessageInput>(tool_use.input) {
Ok(input) => {
let _ = codegen.update(cx, |this, _cx| {
// Store the failure message as the tool description
this.model_explanation = Some(input.message.into());
});
finish_with_status(CodegenStatus::Done, cx);
return;
}
Err(e) => {
finish_with_status(CodegenStatus::Error(e.into()), cx);
return;
}
}
}
Ok(_tool_use) => {
// Unexpected tool.
finish_with_status(CodegenStatus::Done, cx);
return;
}
Err(e) => {
finish_with_status(CodegenStatus::Error(e.into()), cx);
return;
}
}
});
cx.notify();
let Some(first_text) = first_text else {
finish_with_status(CodegenStatus::Done, cx);
return;
};
let (message_tx, mut message_rx) = futures::channel::mpsc::unbounded();
cx.spawn({
let codegen = codegen.clone();
async move |cx| {
while let Some(message) = message_rx.next().await {
let _ = codegen.update(cx, |this, _cx| {
this.model_explanation = message;
});
}
}
})
.detach();
let move_last_token_usage = last_token_usage.clone();
let text_stream = Box::pin(futures::stream::once(async { Ok(first_text) }).chain(
completion_events.filter_map(move |e| {
let tool_to_text_and_message = tool_to_text_and_message.clone();
let last_token_usage = move_last_token_usage.clone();
let total_text = total_text.clone();
let mut message_tx = message_tx.clone();
async move {
match e {
Ok(LanguageModelCompletionEvent::ToolUse(tool_use))
if matches!(
tool_use.name.as_ref(),
"rewrite_section" | "failure_message"
) =>
{
let is_complete = tool_use.is_input_complete;
let (text, message) = tool_to_text_and_message(tool_use);
if is_complete {
// Again only send the message when complete to not get a bouncing UI element.
let _ = message_tx.send(message.map(Into::into)).await;
}
text.map(Ok)
}
Ok(LanguageModelCompletionEvent::UsageUpdate(token_usage)) => {
*last_token_usage.lock() = token_usage;
None
}
Ok(LanguageModelCompletionEvent::Text(text)) => {
let mut lock = total_text.lock();
lock.push_str(&text);
None
}
Ok(LanguageModelCompletionEvent::Stop(_reason)) => None,
e => {
log::error!("UNEXPECTED EVENT {:?}", e);
None
}
}
}
}),
));
let language_model_text_stream = LanguageModelTextStream {
message_id: message_id,
stream: text_stream,
last_token_usage,
};
let Some(task) = codegen
.update(cx, move |codegen, cx| {
codegen.handle_stream(
telemetry_id,
provider_id,
api_key,
async { Ok(language_model_text_stream) },
cx,
)
})
.ok()
else {
return;
};
task.await;
})
}
}
@@ -1679,7 +1803,7 @@ mod tests {
) -> mpsc::UnboundedSender<String> {
let (chunks_tx, chunks_rx) = mpsc::unbounded();
codegen.update(cx, |codegen, cx| {
codegen.handle_stream(
codegen.generation = codegen.handle_stream(
String::new(),
String::new(),
None,
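
The `chars_read_so_far` bookkeeping in `handle_completion` above appears to rely on each streamed `ToolUse` event carrying the whole `replacement_text` accumulated so far, so that only the unseen suffix is forwarded to the text stream. A minimal standalone sketch of that idea follows; `DeltaTracker` and `next_delta` are hypothetical names, and unlike the diff this version uses `str::get` so that a shorter snapshot or a split multi-byte character yields `None` instead of a panic.

```rust
/// Tracks how much of a growing text has already been forwarded.
/// Hypothetical helper, not part of the diff above.
struct DeltaTracker {
    /// Number of bytes (not characters) already emitted.
    bytes_read: usize,
}

impl DeltaTracker {
    fn new() -> Self {
        Self { bytes_read: 0 }
    }

    /// Takes the full text seen so far and returns only the new suffix.
    /// Returns `None` when there is nothing new, or when the snapshot
    /// cannot be sliced at the remembered offset.
    fn next_delta(&mut self, snapshot: &str) -> Option<String> {
        let delta = snapshot.get(self.bytes_read..)?.to_string();
        self.bytes_read = snapshot.len();
        if delta.is_empty() { None } else { Some(delta) }
    }
}
```

Feeding the tracker `"fn"`, then `"fn main"`, then `"fn main() {}"` yields `"fn"`, `" main"`, and `"() {}"`, so a downstream consumer that expects incremental chunks sees each piece exactly once.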

View File

@@ -1455,60 +1455,8 @@ impl InlineAssistant {
let old_snapshot = codegen.snapshot(cx);
let old_buffer = codegen.old_buffer(cx);
let deleted_row_ranges = codegen.diff(cx).deleted_row_ranges.clone();
// let model_explanation = codegen.model_explanation(cx);
editor.update(cx, |editor, cx| {
// Update tool description block
// if let Some(description) = model_explanation {
// if let Some(block_id) = decorations.model_explanation {
// editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
// let new_block_id = editor.insert_blocks(
// [BlockProperties {
// style: BlockStyle::Flex,
// placement: BlockPlacement::Below(assist.range.end),
// height: Some(1),
// render: Arc::new({
// let description = description.clone();
// move |cx| {
// div()
// .w_full()
// .py_1()
// .px_2()
// .bg(cx.theme().colors().editor_background)
// .border_y_1()
// .border_color(cx.theme().status().info_border)
// .child(
// Label::new(description.clone())
// .color(Color::Muted)
// .size(LabelSize::Small),
// )
// .into_any_element()
// }
// }),
// priority: 0,
// }],
// None,
// cx,
// );
// decorations.model_explanation = new_block_id.into_iter().next();
// }
// } else if let Some(block_id) = decorations.model_explanation {
// // Hide the block if there's no description
// editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
// let new_block_id = editor.insert_blocks(
// [BlockProperties {
// style: BlockStyle::Flex,
// placement: BlockPlacement::Below(assist.range.end),
// height: Some(0),
// render: Arc::new(|_cx| div().into_any_element()),
// priority: 0,
// }],
// None,
// cx,
// );
// decorations.model_explanation = new_block_id.into_iter().next();
// }
let old_blocks = mem::take(&mut decorations.removed_line_block_ids);
editor.remove_blocks(old_blocks, None, cx);

View File

@@ -429,6 +429,19 @@ impl Model {
let mut headers = vec![];
match self {
Self::ClaudeOpus4
| Self::ClaudeOpus4_1
| Self::ClaudeOpus4_5
| Self::ClaudeSonnet4
| Self::ClaudeSonnet4_5
| Self::ClaudeOpus4Thinking
| Self::ClaudeOpus4_1Thinking
| Self::ClaudeOpus4_5Thinking
| Self::ClaudeSonnet4Thinking
| Self::ClaudeSonnet4_5Thinking => {
// Fine-grained tool streaming for newer models
headers.push("fine-grained-tool-streaming-2025-05-14".to_string());
}
Self::Claude3_7Sonnet | Self::Claude3_7SonnetThinking => {
// Try beta token-efficient tool use (supported in Claude 3.7 Sonnet only)
// https://docs.anthropic.com/en/docs/build-with-claude/tool-use/token-efficient-tool-use

View File

@@ -53,7 +53,7 @@ text.workspace = true
thiserror.workspace = true
time.workspace = true
tiny_http.workspace = true
tokio-socks = { version = "0.5.2", default-features = false, features = ["futures-io"] }
tokio-socks.workspace = true
tokio.workspace = true
url.workspace = true
util.workspace = true

View File

@@ -1,18 +0,0 @@
[package]
name = "cloud_zeta2_prompt"
version = "0.1.0"
publish.workspace = true
edition.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/cloud_zeta2_prompt.rs"
[dependencies]
anyhow.workspace = true
cloud_llm_client.workspace = true
indoc.workspace = true
serde.workspace = true

View File

@@ -1,485 +0,0 @@
use anyhow::Result;
use cloud_llm_client::predict_edits_v3::{
self, DiffPathFmt, Event, Excerpt, Line, Point, PromptFormat, RelatedFile,
};
use indoc::indoc;
use std::cmp;
use std::fmt::Write;
use std::path::Path;
use std::sync::Arc;
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
const STUDENT_MODEL_INSTRUCTIONS: &str = indoc! {r#"
You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.
## Edit History
"#};
const MINIMAL_PROMPT_REMINDER: &str = indoc! {"
---
Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
Do not include the cursor marker in your output.
If you're editing multiple files, be sure to reflect filename in the hunk's header.
"};
const XML_TAGS_INSTRUCTIONS: &str = indoc! {r#"
# Instructions
You are an edit prediction agent in a code editor.
Analyze the history of edits made by the user in order to infer what they are currently trying to accomplish.
Then complete the remainder of the current change if it is incomplete, or predict the next edit the user intends to make.
Always continue along the user's current trajectory, rather than changing course.
## Output Format
You should briefly explain your understanding of the user's overall goal in one sentence, then explain what the next change
along the users current trajectory will be in another, and finally specify the next edit using the following XML-like format:
<edits path="my-project/src/myapp/cli.py">
<old_text>
OLD TEXT 1 HERE
</old_text>
<new_text>
NEW TEXT 1 HERE
</new_text>
<old_text>
OLD TEXT 1 HERE
</old_text>
<new_text>
NEW TEXT 1 HERE
</new_text>
</edits>
- Specify the file to edit using the `path` attribute.
- Use `<old_text>` and `<new_text>` tags to replace content
- `<old_text>` must exactly match existing file content, including indentation
- `<old_text>` cannot be empty
- Do not escape quotes, newlines, or other characters within tags
- Always close all tags properly
- Don't include the <|user_cursor|> marker in your output.
## Edit History
"#};
const OLD_TEXT_NEW_TEXT_REMINDER: &str = indoc! {r#"
---
Remember that the edits in the edit history have already been applied.
"#};
pub fn build_prompt(request: &predict_edits_v3::PredictEditsRequest) -> Result<String> {
let prompt_data = PromptData {
events: request.events.clone(),
cursor_point: request.cursor_point,
cursor_path: request.excerpt_path.clone(),
included_files: request.related_files.clone(),
};
match request.prompt_format {
PromptFormat::MinimalQwen => {
return Ok(MinimalQwenPrompt.render(&prompt_data));
}
PromptFormat::SeedCoder1120 => {
return Ok(SeedCoder1120Prompt.render(&prompt_data));
}
_ => (),
};
let insertions = match request.prompt_format {
PromptFormat::Minimal | PromptFormat::OldTextNewText => {
vec![(request.cursor_point, CURSOR_MARKER)]
}
PromptFormat::OnlySnippets => vec![],
PromptFormat::MinimalQwen => unreachable!(),
PromptFormat::SeedCoder1120 => unreachable!(),
};
let mut prompt = match request.prompt_format {
PromptFormat::OldTextNewText => XML_TAGS_INSTRUCTIONS.to_string(),
PromptFormat::OnlySnippets => String::new(),
PromptFormat::Minimal => STUDENT_MODEL_INSTRUCTIONS.to_string(),
PromptFormat::MinimalQwen => unreachable!(),
PromptFormat::SeedCoder1120 => unreachable!(),
};
if request.events.is_empty() {
prompt.push_str("(No edit history)\n\n");
} else {
let edit_preamble = if request.prompt_format == PromptFormat::Minimal {
"The following are the latest edits made by the user, from earlier to later.\n\n"
} else {
"Here are the latest edits made by the user, from earlier to later.\n\n"
};
prompt.push_str(edit_preamble);
push_events(&mut prompt, &request.events);
}
let excerpts_preamble = match request.prompt_format {
PromptFormat::Minimal => indoc! {"
## Part of the file under the cursor
(The cursor marker <|user_cursor|> indicates the current user cursor position.
The file is in current state, edits from edit history has been applied.
We only show part of the file around the cursor.
You can only edit exactly this part of the file.
We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.)
"},
PromptFormat::OldTextNewText => indoc! {"
## Code Excerpts
Here is some excerpts of code that you should take into account to predict the next edit.
The cursor position is marked by `<|user_cursor|>` as it stands after the last edit in the history.
In addition other excerpts are included to better understand what the edit will be, including the declaration
or references of symbols around the cursor, or other similar code snippets that may need to be updated
following patterns that appear in the edit history.
Consider each of them carefully in relation to the edit history, and that the user may not have navigated
to the next place they want to edit yet.
Lines starting with `…` indicate omitted line ranges. These may appear inside multi-line code constructs.
"},
PromptFormat::OnlySnippets | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
indoc! {"
## Code Excerpts
The cursor marker <|user_cursor|> indicates the current user cursor position.
The file is in current state, edits from edit history have been applied.
"}
}
};
prompt.push_str(excerpts_preamble);
prompt.push('\n');
let include_line_numbers = matches!(request.prompt_format, PromptFormat::Minimal);
for related_file in &request.related_files {
if request.prompt_format == PromptFormat::Minimal {
write_codeblock_with_filename(
&related_file.path,
&related_file.excerpts,
if related_file.path == request.excerpt_path {
&insertions
} else {
&[]
},
related_file.max_row,
include_line_numbers,
&mut prompt,
);
} else {
write_codeblock(
&related_file.path,
&related_file.excerpts,
if related_file.path == request.excerpt_path {
&insertions
} else {
&[]
},
related_file.max_row,
include_line_numbers,
&mut prompt,
);
}
}
match request.prompt_format {
PromptFormat::OldTextNewText => {
prompt.push_str(OLD_TEXT_NEW_TEXT_REMINDER);
}
PromptFormat::Minimal => {
prompt.push_str(MINIMAL_PROMPT_REMINDER);
}
_ => {}
}
Ok(prompt)
}
pub fn generation_params(prompt_format: PromptFormat) -> GenerationParams {
match prompt_format {
PromptFormat::SeedCoder1120 => SeedCoder1120Prompt::generation_params(),
_ => GenerationParams::default(),
}
}
pub fn write_codeblock<'a>(
path: &Path,
excerpts: impl IntoIterator<Item = &'a Excerpt>,
sorted_insertions: &[(Point, &str)],
file_line_count: Line,
include_line_numbers: bool,
output: &'a mut String,
) {
writeln!(output, "`````{}", DiffPathFmt(path)).unwrap();
write_excerpts(
excerpts,
sorted_insertions,
file_line_count,
include_line_numbers,
output,
);
write!(output, "`````\n\n").unwrap();
}
fn write_codeblock_with_filename<'a>(
path: &Path,
excerpts: impl IntoIterator<Item = &'a Excerpt>,
sorted_insertions: &[(Point, &str)],
file_line_count: Line,
include_line_numbers: bool,
output: &'a mut String,
) {
writeln!(output, "`````filename={}", DiffPathFmt(path)).unwrap();
write_excerpts(
excerpts,
sorted_insertions,
file_line_count,
include_line_numbers,
output,
);
write!(output, "`````\n\n").unwrap();
}
pub fn write_excerpts<'a>(
excerpts: impl IntoIterator<Item = &'a Excerpt>,
sorted_insertions: &[(Point, &str)],
file_line_count: Line,
include_line_numbers: bool,
output: &mut String,
) {
let mut current_row = Line(0);
let mut sorted_insertions = sorted_insertions.iter().peekable();
for excerpt in excerpts {
if excerpt.start_line > current_row {
writeln!(output, "").unwrap();
}
if excerpt.text.is_empty() {
return;
}
current_row = excerpt.start_line;
for mut line in excerpt.text.lines() {
if include_line_numbers {
write!(output, "{}|", current_row.0 + 1).unwrap();
}
while let Some((insertion_location, insertion_marker)) = sorted_insertions.peek() {
match current_row.cmp(&insertion_location.line) {
cmp::Ordering::Equal => {
let (prefix, suffix) = line.split_at(insertion_location.column as usize);
output.push_str(prefix);
output.push_str(insertion_marker);
line = suffix;
sorted_insertions.next();
}
cmp::Ordering::Less => break,
cmp::Ordering::Greater => {
sorted_insertions.next();
break;
}
}
}
output.push_str(line);
output.push('\n');
current_row.0 += 1;
}
}
if current_row < file_line_count {
writeln!(output, "").unwrap();
}
}
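The marker-splicing loop in `write_excerpts` can be hard to follow in flattened form, so here is a minimal standalone sketch of the same idea. It assumes a simplified `Point` with `line` and `column` fields and a hypothetical function name `render_with_markers`; neither is the crate's real API. Unlike the code above, it clamps the column to the line length and tracks how much of the line has already been consumed, which is an added safety assumption and not something the original does.

```rust
use std::cmp::Ordering;
use std::fmt::Write;

/// Hypothetical stand-in for the `Point` type used above.
#[derive(Clone, Copy)]
struct Point {
    line: u32,
    column: u32,
}

/// Renders `text` (whose first line is row `start_line`) and splices each
/// marker in at its (line, column) position. `insertions` must be sorted.
fn render_with_markers(
    text: &str,
    start_line: u32,
    insertions: &[(Point, &str)],
    include_line_numbers: bool,
) -> String {
    let mut output = String::new();
    let mut insertions = insertions.iter().peekable();
    let mut row = start_line;
    for mut line in text.lines() {
        if include_line_numbers {
            // Line numbers are 1-based in the rendered prompt.
            write!(output, "{}|", row + 1).unwrap();
        }
        // Bytes of the original line already written before `line`.
        let mut consumed = 0usize;
        while let Some((point, marker)) = insertions.peek() {
            match row.cmp(&point.line) {
                Ordering::Equal => {
                    // Assumption: columns are byte offsets; clamp and back up
                    // to a char boundary so `split_at` cannot panic.
                    let mut col = (point.column as usize)
                        .saturating_sub(consumed)
                        .min(line.len());
                    while !line.is_char_boundary(col) {
                        col -= 1;
                    }
                    let (prefix, suffix) = line.split_at(col);
                    output.push_str(prefix);
                    output.push_str(marker);
                    consumed += col;
                    line = suffix;
                    insertions.next();
                }
                // The next marker belongs to a later row.
                Ordering::Less => break,
                // The marker's row was not rendered; skip it.
                Ordering::Greater => {
                    insertions.next();
                }
            }
        }
        output.push_str(line);
        output.push('\n');
        row += 1;
    }
    output
}
```

Calling it with a single `<|user_cursor|>` marker at row 1, column 4 reproduces the shape of the numbered excerpts that the prompt preamble describes.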
pub fn push_events(output: &mut String, events: &[Arc<predict_edits_v3::Event>]) {
if events.is_empty() {
return;
};
writeln!(output, "`````diff").unwrap();
for event in events {
writeln!(output, "{}", event).unwrap();
}
writeln!(output, "`````\n").unwrap();
}
struct PromptData {
events: Vec<Arc<Event>>,
cursor_point: Point,
cursor_path: Arc<Path>, // TODO: make a common struct with cursor_point
included_files: Vec<RelatedFile>,
}
#[derive(Default)]
pub struct GenerationParams {
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub stop: Option<Vec<String>>,
}
trait PromptFormatter {
fn render(&self, data: &PromptData) -> String;
fn generation_params() -> GenerationParams {
GenerationParams::default()
}
}
struct MinimalQwenPrompt;
impl PromptFormatter for MinimalQwenPrompt {
fn render(&self, data: &PromptData) -> String {
let edit_history = self.fmt_edit_history(data);
let context = self.fmt_context(data);
format!(
"{instructions}\n\n{edit_history}\n\n{context}",
instructions = MinimalQwenPrompt::INSTRUCTIONS,
edit_history = edit_history,
context = context
)
}
}
impl MinimalQwenPrompt {
const INSTRUCTIONS: &str = "You are a code completion assistant that analyzes edit history to identify and systematically complete incomplete refactorings or patterns across the entire codebase.\n";
fn fmt_edit_history(&self, data: &PromptData) -> String {
if data.events.is_empty() {
"(No edit history)\n\n".to_string()
} else {
let mut events_str = String::new();
push_events(&mut events_str, &data.events);
format!(
"The following are the latest edits made by the user, from earlier to later.\n\n{}",
events_str
)
}
}
fn fmt_context(&self, data: &PromptData) -> String {
let mut context = String::new();
let include_line_numbers = true;
for related_file in &data.included_files {
writeln!(context, "<|file_sep|>{}", DiffPathFmt(&related_file.path)).unwrap();
if related_file.path == data.cursor_path {
write!(context, "<|fim_prefix|>").unwrap();
write_excerpts(
&related_file.excerpts,
&[(data.cursor_point, "<|fim_suffix|>")],
related_file.max_row,
include_line_numbers,
&mut context,
);
writeln!(context, "<|fim_middle|>").unwrap();
} else {
write_excerpts(
&related_file.excerpts,
&[],
related_file.max_row,
include_line_numbers,
&mut context,
);
}
}
context
}
}
struct SeedCoder1120Prompt;
impl PromptFormatter for SeedCoder1120Prompt {
fn render(&self, data: &PromptData) -> String {
let edit_history = self.fmt_edit_history(data);
let context = self.fmt_context(data);
format!(
"# Edit History:\n{edit_history}\n\n{context}",
edit_history = edit_history,
context = context
)
}
fn generation_params() -> GenerationParams {
GenerationParams {
temperature: Some(0.2),
top_p: Some(0.9),
stop: Some(vec!["<[end_of_sentence]>".into()]),
}
}
}
impl SeedCoder1120Prompt {
fn fmt_edit_history(&self, data: &PromptData) -> String {
if data.events.is_empty() {
"(No edit history)\n\n".to_string()
} else {
let mut events_str = String::new();
push_events(&mut events_str, &data.events);
events_str
}
}
fn fmt_context(&self, data: &PromptData) -> String {
let mut context = String::new();
let include_line_numbers = true;
for related_file in &data.included_files {
writeln!(context, "# Path: {}\n", DiffPathFmt(&related_file.path)).unwrap();
if related_file.path == data.cursor_path {
let fim_prompt = self.fmt_fim(&related_file, data.cursor_point);
context.push_str(&fim_prompt);
} else {
write_excerpts(
&related_file.excerpts,
&[],
related_file.max_row,
include_line_numbers,
&mut context,
);
}
}
context
}
fn fmt_fim(&self, file: &RelatedFile, cursor_point: Point) -> String {
let mut buf = String::new();
const FIM_SUFFIX: &str = "<[fim-suffix]>";
const FIM_PREFIX: &str = "<[fim-prefix]>";
const FIM_MIDDLE: &str = "<[fim-middle]>";
write!(buf, "{}", FIM_PREFIX).unwrap();
write_excerpts(
&file.excerpts,
&[(cursor_point, FIM_SUFFIX)],
file.max_row,
true,
&mut buf,
);
// Swap prefix and suffix parts
let index = buf.find(FIM_SUFFIX).unwrap();
let prefix = &buf[..index];
let suffix = &buf[index..];
format!("{}{}{}", suffix, prefix, FIM_MIDDLE)
}
}
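The "swap prefix and suffix parts" step in `fmt_fim` reorders the rendered cursor file so that the text after the cursor comes first. The sketch below isolates that step. The marker strings are copied from the snippet above; the function names `render_cursor_file` and `to_suffix_first` are hypothetical, and returning `None` when the suffix marker is missing (instead of calling `unwrap`) is an assumption made for illustration.

```rust
const FIM_PREFIX: &str = "<[fim-prefix]>";
const FIM_SUFFIX: &str = "<[fim-suffix]>";
const FIM_MIDDLE: &str = "<[fim-middle]>";

/// Builds the prefix-first form: prefix marker, text before the cursor,
/// suffix marker, text after the cursor.
fn render_cursor_file(before_cursor: &str, after_cursor: &str) -> String {
    format!("{}{}{}{}", FIM_PREFIX, before_cursor, FIM_SUFFIX, after_cursor)
}

/// Moves everything from the suffix marker onward to the front and appends
/// the middle marker. Returns `None` if the suffix marker is absent.
fn to_suffix_first(rendered: &str) -> Option<String> {
    let index = rendered.find(FIM_SUFFIX)?;
    let (prefix_part, suffix_part) = rendered.split_at(index);
    Some(format!("{}{}{}", suffix_part, prefix_part, FIM_MIDDLE))
}
```

With `before_cursor = "let x = "` and `after_cursor = "1;\n"`, the result starts with the suffix marker and ends with the middle marker, which is where the model is expected to continue generating.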


@@ -516,7 +516,8 @@ impl Copilot {
None,
Default::default(),
cx,
)?;
)
.await?;
server
.on_notification::<StatusNotification, _>(|_, _| { /* Silence the notification */ })


@@ -1045,54 +1045,47 @@ async fn heuristic_syntactic_expand(
let node_range = node_start..node_end;
let row_count = node_end.row - node_start.row + 1;
let mut ancestor_range = None;
let reached_outline_node = cx.background_executor().scoped({
let node_range = node_range.clone();
let outline_range = outline_range.clone();
let ancestor_range = &mut ancestor_range;
|scope| {
scope.spawn(async move {
// Stop if we've exceeded the row count or reached an outline node. Then, find the interval
// of node children which contains the query range. For example, this allows just returning
// the header of a declaration rather than the entire declaration.
if row_count > max_row_count || outline_range == Some(node_range.clone()) {
let mut cursor = node.walk();
let mut included_child_start = None;
let mut included_child_end = None;
let mut previous_end = node_start;
if cursor.goto_first_child() {
loop {
let child_node = cursor.node();
let child_range =
previous_end..Point::from_ts_point(child_node.end_position());
if included_child_start.is_none()
&& child_range.contains(&input_range.start)
{
included_child_start = Some(child_range.start);
}
if child_range.contains(&input_range.end) {
included_child_end = Some(child_range.end);
}
previous_end = child_range.end;
if !cursor.goto_next_sibling() {
break;
}
cx.background_executor()
.await_on_background(async {
// Stop if we've exceeded the row count or reached an outline node. Then, find the interval
// of node children which contains the query range. For example, this allows just returning
// the header of a declaration rather than the entire declaration.
if row_count > max_row_count || outline_range == Some(node_range.clone()) {
let mut cursor = node.walk();
let mut included_child_start = None;
let mut included_child_end = None;
let mut previous_end = node_start;
if cursor.goto_first_child() {
loop {
let child_node = cursor.node();
let child_range =
previous_end..Point::from_ts_point(child_node.end_position());
if included_child_start.is_none()
&& child_range.contains(&input_range.start)
{
included_child_start = Some(child_range.start);
}
if child_range.contains(&input_range.end) {
included_child_end = Some(child_range.end);
}
previous_end = child_range.end;
if !cursor.goto_next_sibling() {
break;
}
}
let end = included_child_end.unwrap_or(node_range.end);
if let Some(start) = included_child_start {
let row_count = end.row - start.row;
if row_count < max_row_count {
*ancestor_range =
Some(Some(RangeInclusive::new(start.row, end.row)));
return;
}
}
*ancestor_range = Some(None);
}
})
}
});
reached_outline_node.await;
let end = included_child_end.unwrap_or(node_range.end);
if let Some(start) = included_child_start {
let row_count = end.row - start.row;
if row_count < max_row_count {
ancestor_range = Some(Some(RangeInclusive::new(start.row, end.row)));
return;
}
}
ancestor_range = Some(None);
}
})
.await;
if let Some(node) = ancestor_range {
return node;
}


@@ -21,7 +21,6 @@ arrayvec.workspace = true
brotli.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
collections.workspace = true
copilot.workspace = true
credentials_provider.workspace = true
@@ -50,8 +49,6 @@ semver.workspace = true
serde.workspace = true
serde_json.workspace = true
settings.workspace = true
smol.workspace = true
strsim.workspace = true
strum.workspace = true
telemetry.workspace = true
telemetry_events.workspace = true
@@ -62,6 +59,7 @@ uuid.workspace = true
workspace.workspace = true
worktree.workspace = true
zed_actions.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
clock = { workspace = true, features = ["test-support"] }


@@ -1,14 +1,13 @@
use anyhow::Result;
use arrayvec::ArrayVec;
use client::{Client, EditPredictionUsage, UserStore};
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
use cloud_llm_client::predict_edits_v3::{self, PromptFormat};
use cloud_llm_client::{
AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, EditPredictionRejectReason,
EditPredictionRejection, MAX_EDIT_PREDICTION_REJECTIONS_PER_REQUEST,
MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsRequestTrigger, RejectEditPredictionsBodyRef,
ZED_VERSION_HEADER_NAME,
};
use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES;
use collections::{HashMap, HashSet};
use db::kvp::{Dismissable, KEY_VALUE_STORE};
use edit_prediction_context::EditPredictionExcerptOptions;
@@ -16,10 +15,7 @@ use edit_prediction_context::{RelatedExcerptStore, RelatedExcerptStoreEvent, Rel
use feature_flags::{FeatureFlag, FeatureFlagAppExt as _};
use futures::{
AsyncReadExt as _, FutureExt as _, StreamExt as _,
channel::{
mpsc::{self, UnboundedReceiver},
oneshot,
},
channel::mpsc::{self, UnboundedReceiver},
select_biased,
};
use gpui::BackgroundExecutor;
@@ -58,8 +54,10 @@ mod onboarding_modal;
pub mod open_ai_response;
mod prediction;
pub mod sweep_ai;
#[cfg(any(test, feature = "test-support", feature = "eval-support"))]
pub mod udiff;
mod xml_edits;
mod zed_edit_prediction_delegate;
pub mod zeta1;
pub mod zeta2;
@@ -72,7 +70,6 @@ use crate::mercury::Mercury;
use crate::onboarding_modal::ZedPredictModal;
pub use crate::prediction::EditPrediction;
pub use crate::prediction::EditPredictionId;
pub use crate::prediction::EditPredictionInputs;
use crate::prediction::EditPredictionResult;
pub use crate::sweep_ai::SweepAi;
pub use telemetry_events::EditPredictionRating;
@@ -112,7 +109,6 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions {
min_bytes: 128,
target_before_cursor_over_total_bytes: 0.5,
},
max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES,
prompt_format: PromptFormat::DEFAULT,
};
@@ -162,7 +158,6 @@ pub struct EditPredictionStore {
use_context: bool,
options: ZetaOptions,
update_required: bool,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
#[cfg(feature = "eval-support")]
eval_cache: Option<Arc<dyn EvalCache>>,
edit_prediction_model: EditPredictionModel,
@@ -183,10 +178,22 @@ pub enum EditPredictionModel {
Mercury,
}
pub struct EditPredictionModelInput {
project: Entity<Project>,
buffer: Entity<Buffer>,
snapshot: BufferSnapshot,
position: Anchor,
events: Vec<Arc<zeta_prompt::Event>>,
related_files: Arc<[RelatedFile]>,
recent_paths: VecDeque<ProjectPath>,
trigger: PredictEditsRequestTrigger,
diagnostic_search_range: Range<Point>,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
}
#[derive(Debug, Clone, PartialEq)]
pub struct ZetaOptions {
pub context: EditPredictionExcerptOptions,
pub max_prompt_bytes: usize,
pub prompt_format: predict_edits_v3::PromptFormat,
}
@@ -194,7 +201,8 @@ pub struct ZetaOptions {
pub enum DebugEvent {
ContextRetrievalStarted(ContextRetrievalStartedDebugEvent),
ContextRetrievalFinished(ContextRetrievalFinishedDebugEvent),
EditPredictionRequested(EditPredictionRequestedDebugEvent),
EditPredictionStarted(EditPredictionStartedDebugEvent),
EditPredictionFinished(EditPredictionFinishedDebugEvent),
}
#[derive(Debug)]
@@ -212,27 +220,30 @@ pub struct ContextRetrievalFinishedDebugEvent {
}
#[derive(Debug)]
pub struct EditPredictionRequestedDebugEvent {
pub inputs: EditPredictionInputs,
pub retrieval_time: Duration,
pub struct EditPredictionStartedDebugEvent {
pub buffer: WeakEntity<Buffer>,
pub position: Anchor,
pub local_prompt: Result<String, String>,
pub response_rx: oneshot::Receiver<(Result<open_ai::Response, String>, Duration)>,
pub prompt: Option<String>,
}
#[derive(Debug)]
pub struct EditPredictionFinishedDebugEvent {
pub buffer: WeakEntity<Buffer>,
pub position: Anchor,
pub model_output: Option<String>,
}
pub type RequestDebugInfo = predict_edits_v3::DebugInfo;
struct ProjectState {
events: VecDeque<Arc<cloud_llm_client::predict_edits_v3::Event>>,
events: VecDeque<Arc<zeta_prompt::Event>>,
last_event: Option<LastEvent>,
recent_paths: VecDeque<ProjectPath>,
registered_buffers: HashMap<gpui::EntityId, RegisteredBuffer>,
current_prediction: Option<CurrentEditPrediction>,
next_pending_prediction_id: usize,
pending_predictions: ArrayVec<PendingPrediction, 2>,
context_updates_tx: smol::channel::Sender<()>,
context_updates_rx: smol::channel::Receiver<()>,
debug_tx: Option<mpsc::UnboundedSender<DebugEvent>>,
last_prediction_refresh: Option<(EntityId, Instant)>,
cancelled_predictions: HashSet<usize>,
context: Entity<RelatedExcerptStore>,
@@ -241,7 +252,7 @@ struct ProjectState {
}
impl ProjectState {
pub fn events(&self, cx: &App) -> Vec<Arc<cloud_llm_client::predict_edits_v3::Event>> {
pub fn events(&self, cx: &App) -> Vec<Arc<zeta_prompt::Event>> {
self.events
.iter()
.cloned()
@@ -272,6 +283,18 @@ impl ProjectState {
})
.detach()
}
fn active_buffer(
&self,
project: &Entity<Project>,
cx: &App,
) -> Option<(Entity<Buffer>, Option<Anchor>)> {
let project = project.read(cx);
let active_path = project.path_for_entry(project.active_entry()?, cx)?;
let active_buffer = project.buffer_store().read(cx).get_by_path(&active_path)?;
let registered_buffer = self.registered_buffers.get(&active_buffer.entity_id())?;
Some((active_buffer, registered_buffer.last_position))
}
}
#[derive(Debug, Clone)]
@@ -362,6 +385,7 @@ impl std::ops::Deref for BufferEditPrediction<'_> {
struct RegisteredBuffer {
snapshot: BufferSnapshot,
last_position: Option<Anchor>,
_subscriptions: [gpui::Subscription; 2],
}
@@ -376,7 +400,7 @@ impl LastEvent {
&self,
license_detection_watchers: &HashMap<WorktreeId, Rc<LicenseDetectionWatcher>>,
cx: &App,
) -> Option<Arc<predict_edits_v3::Event>> {
) -> Option<Arc<zeta_prompt::Event>> {
let path = buffer_path_with_id_fallback(&self.new_snapshot, cx);
let old_path = buffer_path_with_id_fallback(&self.old_snapshot, cx);
@@ -396,7 +420,7 @@ impl LastEvent {
if path == old_path && diff.is_empty() {
None
} else {
Some(Arc::new(predict_edits_v3::Event::BufferChange {
Some(Arc::new(zeta_prompt::Event::BufferChange {
old_path,
path,
diff,
@@ -481,7 +505,6 @@ impl EditPredictionStore {
},
),
update_required: false,
debug_tx: None,
#[cfg(feature = "eval-support")]
eval_cache: None,
edit_prediction_model: EditPredictionModel::Zeta2,
@@ -536,12 +559,6 @@ impl EditPredictionStore {
self.eval_cache = Some(cache);
}
pub fn debug_info(&mut self) -> mpsc::UnboundedReceiver<DebugEvent> {
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
self.debug_tx = Some(debug_watch_tx);
debug_watch_rx
}
pub fn options(&self) -> &ZetaOptions {
&self.options
}
@@ -560,15 +577,35 @@ impl EditPredictionStore {
}
}
pub fn edit_history_for_project(
&self,
project: &Entity<Project>,
) -> Vec<Arc<zeta_prompt::Event>> {
self.projects
.get(&project.entity_id())
.map(|project_state| project_state.events.iter().cloned().collect())
.unwrap_or_default()
}
pub fn context_for_project<'a>(
&'a self,
project: &Entity<Project>,
cx: &'a App,
) -> &'a [RelatedFile] {
) -> Arc<[RelatedFile]> {
self.projects
.get(&project.entity_id())
.map(|project| project.context.read(cx).related_files())
.unwrap_or(&[])
.unwrap_or_else(|| vec![].into())
}
pub fn context_for_project_with_buffers<'a>(
&'a self,
project: &Entity<Project>,
cx: &'a App,
) -> Option<impl 'a + Iterator<Item = (RelatedFile, Entity<Buffer>)>> {
self.projects
.get(&project.entity_id())
.map(|project| project.context.read(cx).related_files_with_buffers())
}
pub fn usage(&self, cx: &App) -> Option<EditPredictionUsage> {
@@ -599,85 +636,21 @@ impl EditPredictionStore {
cx: &mut Context<Self>,
) -> &mut ProjectState {
let entity_id = project.entity_id();
let (context_updates_tx, context_updates_rx) = smol::channel::unbounded();
self.projects
.entry(entity_id)
.or_insert_with(|| ProjectState {
context: {
let related_excerpt_store = cx.new(|cx| RelatedExcerptStore::new(project, cx));
cx.subscribe(
&related_excerpt_store,
move |this, _, event, _| match event {
RelatedExcerptStoreEvent::StartedRefresh => {
if let Some(debug_tx) = this.debug_tx.clone() {
debug_tx
.unbounded_send(DebugEvent::ContextRetrievalStarted(
ContextRetrievalStartedDebugEvent {
project_entity_id: entity_id,
timestamp: Instant::now(),
search_prompt: String::new(),
},
))
.ok();
}
}
RelatedExcerptStoreEvent::FinishedRefresh {
cache_hit_count,
cache_miss_count,
mean_definition_latency,
max_definition_latency,
} => {
if let Some(debug_tx) = this.debug_tx.clone() {
debug_tx
.unbounded_send(DebugEvent::ContextRetrievalFinished(
ContextRetrievalFinishedDebugEvent {
project_entity_id: entity_id,
timestamp: Instant::now(),
metadata: vec![
(
"Cache Hits",
format!(
"{}/{}",
cache_hit_count,
cache_hit_count + cache_miss_count
)
.into(),
),
(
"Max LSP Time",
format!(
"{} ms",
max_definition_latency.as_millis()
)
.into(),
),
(
"Mean LSP Time",
format!(
"{} ms",
mean_definition_latency.as_millis()
)
.into(),
),
],
},
))
.ok();
}
if let Some(project_state) = this.projects.get(&entity_id) {
project_state.context_updates_tx.send_blocking(()).ok();
}
}
},
)
cx.subscribe(&related_excerpt_store, move |this, _, event, _| {
this.handle_excerpt_store_event(entity_id, event);
})
.detach();
related_excerpt_store
},
events: VecDeque::new(),
last_event: None,
recent_paths: VecDeque::new(),
context_updates_rx,
context_updates_tx,
debug_tx: None,
registered_buffers: HashMap::default(),
current_prediction: None,
cancelled_predictions: HashSet::default(),
@@ -689,12 +662,79 @@ impl EditPredictionStore {
})
}
pub fn project_context_updates(
&self,
pub fn remove_project(&mut self, project: &Entity<Project>) {
self.projects.remove(&project.entity_id());
}
fn handle_excerpt_store_event(
&mut self,
project_entity_id: EntityId,
event: &RelatedExcerptStoreEvent,
) {
if let Some(project_state) = self.projects.get(&project_entity_id) {
if let Some(debug_tx) = project_state.debug_tx.clone() {
match event {
RelatedExcerptStoreEvent::StartedRefresh => {
debug_tx
.unbounded_send(DebugEvent::ContextRetrievalStarted(
ContextRetrievalStartedDebugEvent {
project_entity_id,
timestamp: Instant::now(),
search_prompt: String::new(),
},
))
.ok();
}
RelatedExcerptStoreEvent::FinishedRefresh {
cache_hit_count,
cache_miss_count,
mean_definition_latency,
max_definition_latency,
} => {
debug_tx
.unbounded_send(DebugEvent::ContextRetrievalFinished(
ContextRetrievalFinishedDebugEvent {
project_entity_id,
timestamp: Instant::now(),
metadata: vec![
(
"Cache Hits",
format!(
"{}/{}",
cache_hit_count,
cache_hit_count + cache_miss_count
)
.into(),
),
(
"Max LSP Time",
format!("{} ms", max_definition_latency.as_millis())
.into(),
),
(
"Mean LSP Time",
format!("{} ms", mean_definition_latency.as_millis())
.into(),
),
],
},
))
.ok();
}
}
}
}
}
pub fn debug_info(
&mut self,
project: &Entity<Project>,
) -> Option<smol::channel::Receiver<()>> {
let project_state = self.projects.get(&project.entity_id())?;
Some(project_state.context_updates_rx.clone())
cx: &mut Context<Self>,
) -> mpsc::UnboundedReceiver<DebugEvent> {
let project_state = self.get_or_init_project(project, cx);
let (debug_watch_tx, debug_watch_rx) = mpsc::unbounded();
project_state.debug_tx = Some(debug_watch_tx);
debug_watch_rx
}
fn handle_project_event(
@@ -768,6 +808,7 @@ impl EditPredictionStore {
let project_entity_id = project.entity_id();
entry.insert(RegisteredBuffer {
snapshot,
last_position: None,
_subscriptions: [
cx.subscribe(buffer, {
let project = project.downgrade();
@@ -855,13 +896,21 @@ impl EditPredictionStore {
});
}
fn current_prediction_for_buffer(
&self,
fn prediction_at(
&mut self,
buffer: &Entity<Buffer>,
position: Option<language::Anchor>,
project: &Entity<Project>,
cx: &App,
) -> Option<BufferEditPrediction<'_>> {
let project_state = self.projects.get(&project.entity_id())?;
let project_state = self.projects.get_mut(&project.entity_id())?;
if let Some(position) = position
&& let Some(buffer) = project_state
.registered_buffers
.get_mut(&buffer.entity_id())
{
buffer.last_position = Some(position);
}
let CurrentEditPrediction {
requested_by,
@@ -1104,12 +1153,21 @@ impl EditPredictionStore {
};
self.queue_prediction_refresh(project.clone(), project.entity_id(), cx, move |this, cx| {
let Some(open_buffer_task) = project
.update(cx, |project, cx| {
project
.active_entry()
.and_then(|entry| project.path_for_entry(entry, cx))
.map(|path| project.open_buffer(path, cx))
let Some((active_buffer, snapshot, cursor_point)) = this
.read_with(cx, |this, cx| {
let project_state = this.projects.get(&project.entity_id())?;
let (buffer, position) = project_state.active_buffer(&project, cx)?;
let snapshot = buffer.read(cx).snapshot();
if !Self::predictions_enabled_at(&snapshot, position, cx) {
return None;
}
let cursor_point = position
.map(|pos| pos.to_point(&snapshot))
.unwrap_or_default();
Some((buffer, snapshot, cursor_point))
})
.log_err()
.flatten()
@@ -1118,14 +1176,11 @@ impl EditPredictionStore {
};
cx.spawn(async move |cx| {
let active_buffer = open_buffer_task.await?;
let snapshot = active_buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
let Some((jump_buffer, jump_position)) = Self::next_diagnostic_location(
active_buffer,
&snapshot,
Default::default(),
Default::default(),
cursor_point,
&project,
cx,
)
@@ -1170,6 +1225,37 @@ impl EditPredictionStore {
});
}
fn predictions_enabled_at(
snapshot: &BufferSnapshot,
position: Option<language::Anchor>,
cx: &App,
) -> bool {
let file = snapshot.file();
let all_settings = all_language_settings(file, cx);
if !all_settings.show_edit_predictions(snapshot.language(), cx)
|| file.is_some_and(|file| !all_settings.edit_predictions_enabled_for_file(file, cx))
{
return false;
}
if let Some(last_position) = position {
let settings = snapshot.settings_at(last_position, cx);
if !settings.edit_predictions_disabled_in.is_empty()
&& let Some(scope) = snapshot.language_scope_at(last_position)
&& let Some(scope_name) = scope.override_name()
&& settings
.edit_predictions_disabled_in
.iter()
.any(|s| s == scope_name)
{
return false;
}
}
true
}
#[cfg(not(test))]
pub const THROTTLE_TIMEOUT: Duration = Duration::from_millis(300);
#[cfg(test)]
@@ -1348,6 +1434,7 @@ impl EditPredictionStore {
let project_state = self.projects.get(&project.entity_id()).unwrap();
let events = project_state.events(cx);
let has_events = !events.is_empty();
let debug_tx = project_state.debug_tx.clone();
let snapshot = active_buffer.read(cx).snapshot();
let cursor_point = position.to_point(&snapshot);
@@ -1357,55 +1444,29 @@ impl EditPredictionStore {
Point::new(diagnostic_search_start, 0)..Point::new(diagnostic_search_end, 0);
let related_files = if self.use_context {
self.context_for_project(&project, cx).to_vec()
self.context_for_project(&project, cx)
} else {
Vec::new()
Vec::new().into()
};
let inputs = EditPredictionModelInput {
project: project.clone(),
buffer: active_buffer.clone(),
snapshot: snapshot.clone(),
position,
events,
related_files,
recent_paths: project_state.recent_paths.clone(),
trigger,
diagnostic_search_range: diagnostic_search_range.clone(),
debug_tx,
};
let task = match self.edit_prediction_model {
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(
self,
&project,
&active_buffer,
snapshot.clone(),
position,
events,
trigger,
cx,
),
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(
self,
&project,
&active_buffer,
snapshot.clone(),
position,
events,
related_files,
trigger,
cx,
),
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(
&project,
&active_buffer,
snapshot.clone(),
position,
events,
&project_state.recent_paths,
related_files,
diagnostic_search_range.clone(),
cx,
),
EditPredictionModel::Mercury => self.mercury.request_prediction(
&project,
&active_buffer,
snapshot.clone(),
position,
events,
&project_state.recent_paths,
related_files,
diagnostic_search_range.clone(),
cx,
),
EditPredictionModel::Zeta1 => zeta1::request_prediction_with_zeta1(self, inputs, cx),
EditPredictionModel::Zeta2 => zeta2::request_prediction_with_zeta2(self, inputs, cx),
EditPredictionModel::Sweep => self.sweep_ai.request_prediction_with_sweep(inputs, cx),
EditPredictionModel::Mercury => self.mercury.request_prediction(inputs, cx),
};
cx.spawn(async move |this, cx| {
@@ -1706,6 +1767,20 @@ impl EditPredictionStore {
}
}
#[cfg(feature = "eval-support")]
pub fn set_context_for_buffer(
&mut self,
project: &Entity<Project>,
related_files: Vec<RelatedFile>,
cx: &mut Context<Self>,
) {
self.get_or_init_project(project, cx)
.context
.update(cx, |store, _| {
store.set_related_files(related_files);
});
}
fn is_file_open_source(
&self,
project: &Entity<Project>,
@@ -1729,14 +1804,14 @@ impl EditPredictionStore {
self.data_collection_choice.is_enabled() && self.is_file_open_source(project, file, cx)
}
fn can_collect_events(&self, events: &[Arc<Event>]) -> bool {
fn can_collect_events(&self, events: &[Arc<zeta_prompt::Event>]) -> bool {
if !self.data_collection_choice.is_enabled() {
return false;
}
events.iter().all(|event| {
matches!(
event.as_ref(),
Event::BufferChange {
zeta_prompt::Event::BufferChange {
in_open_source_repo: true,
..
}


@@ -1,5 +1,5 @@
use super::*;
use crate::zeta1::MAX_EVENT_TOKENS;
use crate::{udiff::apply_diff_to_string, zeta1::MAX_EVENT_TOKENS};
use client::{UserStore, test::FakeServer};
use clock::{FakeSystemClock, ReplicaId};
use cloud_api_types::{CreateLlmTokenResponse, LlmToken};
@@ -7,7 +7,6 @@ use cloud_llm_client::{
EditPredictionRejectReason, EditPredictionRejection, PredictEditsBody, PredictEditsResponse,
RejectEditPredictionsBody,
};
use edit_prediction_context::Line;
use futures::{
AsyncReadExt, StreamExt,
channel::{mpsc, oneshot},
@@ -28,6 +27,7 @@ use settings::SettingsStore;
use std::{path::Path, sync::Arc, time::Duration};
use util::{path, rel_path::rel_path};
use uuid::Uuid;
use zeta_prompt::ZetaPromptInput;
use crate::{BufferEditPrediction, EditPredictionId, EditPredictionStore, REJECT_REQUEST_DEBOUNCE};
@@ -45,10 +45,6 @@ async fn test_current_state(cx: &mut TestAppContext) {
.await;
let project = Project::test(fs, vec![path!("/root").as_ref()], cx).await;
ep_store.update(cx, |ep_store, cx| {
ep_store.register_project(&project, cx);
});
let buffer1 = project
.update(cx, |project, cx| {
let path = project.find_project_path(path!("/root/1.txt"), cx).unwrap();
@@ -60,30 +56,38 @@ async fn test_current_state(cx: &mut TestAppContext) {
let snapshot1 = buffer1.read_with(cx, |buffer, _cx| buffer.snapshot());
let position = snapshot1.anchor_before(language::Point::new(1, 3));
ep_store.update(cx, |ep_store, cx| {
ep_store.register_project(&project, cx);
ep_store.register_buffer(&buffer1, &project, cx);
});
// Prediction for current file
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer1.clone(), position, cx)
});
let (_request, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
.send(model_response(indoc! {r"
--- a/root/1.txt
+++ b/root/1.txt
@@ ... @@
Hello!
-How
+How are you?
Bye
"}))
.send(model_response(
request,
indoc! {r"
--- a/root/1.txt
+++ b/root/1.txt
@@ ... @@
Hello!
-How
+How are you?
Bye
"},
))
.unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
let prediction = ep_store
.current_prediction_for_buffer(&buffer1, &project, cx)
.prediction_at(&buffer1, None, &project, cx)
.unwrap();
assert_matches!(prediction, BufferEditPrediction::Local { .. });
});
@@ -120,22 +124,26 @@ async fn test_current_state(cx: &mut TestAppContext) {
});
});
let (_request, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
respond_tx
.send(model_response(indoc! {r#"
--- a/root/2.txt
+++ b/root/2.txt
Hola!
-Como
+Como estas?
Adios
"#}))
.send(model_response(
request,
indoc! {r#"
--- a/root/2.txt
+++ b/root/2.txt
@@ ... @@
Hola!
-Como
+Como estas?
Adios
"#},
))
.unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
let prediction = ep_store
.current_prediction_for_buffer(&buffer1, &project, cx)
.prediction_at(&buffer1, None, &project, cx)
.unwrap();
assert_matches!(
prediction,
@@ -151,9 +159,9 @@ async fn test_current_state(cx: &mut TestAppContext) {
.await
.unwrap();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
let prediction = ep_store
.current_prediction_for_buffer(&buffer2, &project, cx)
.prediction_at(&buffer2, None, &project, cx)
.unwrap();
assert_matches!(prediction, BufferEditPrediction::Local { .. });
});
@@ -186,7 +194,7 @@ async fn test_simple_request(cx: &mut TestAppContext) {
ep_store.request_prediction(&project, &buffer, position, Default::default(), cx)
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
// TODO Put back when we have a structured request again
// assert_eq!(
@@ -202,15 +210,18 @@ async fn test_simple_request(cx: &mut TestAppContext) {
// );
respond_tx
.send(model_response(indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"}))
.send(model_response(
request,
indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"},
))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -276,15 +287,18 @@ async fn test_request_events(cx: &mut TestAppContext) {
);
respond_tx
.send(model_response(indoc! {r#"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"#}))
.send(model_response(
request,
indoc! {r#"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are you?
Bye
"#},
))
.unwrap();
let prediction = prediction_task.await.unwrap().unwrap().prediction.unwrap();
@@ -324,27 +338,17 @@ async fn test_empty_prediction(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
const NO_OP_DIFF: &str = indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How
Bye
"};
let (_, respond_tx) = requests.predict.next().await.unwrap();
let response = model_response(NO_OP_DIFF);
let (request, respond_tx) = requests.predict.next().await.unwrap();
let response = model_response(request, "");
let id = response.id.clone();
respond_tx.send(response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
assert!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.is_none()
);
});
@@ -389,22 +393,22 @@ async fn test_interpolated_empty(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
buffer.update(cx, |buffer, cx| {
buffer.set_text("Hello!\nHow are you?\nBye", cx);
});
let response = model_response(SIMPLE_DIFF);
let response = model_response(request, SIMPLE_DIFF);
let id = response.id.clone();
respond_tx.send(response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
assert!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.is_none()
);
});
@@ -459,17 +463,17 @@ async fn test_replace_current(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let first_response = model_response(SIMPLE_DIFF);
let (request, respond_tx) = requests.predict.next().await.unwrap();
let first_response = model_response(request, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_tx.send(first_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -482,18 +486,18 @@ async fn test_replace_current(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let second_response = model_response(SIMPLE_DIFF);
let (request, respond_tx) = requests.predict.next().await.unwrap();
let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_tx.send(second_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// second replaces first
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -541,17 +545,17 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let first_response = model_response(SIMPLE_DIFF);
let (request, respond_tx) = requests.predict.next().await.unwrap();
let first_response = model_response(request, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_tx.send(first_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -564,27 +568,30 @@ async fn test_current_preferred(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_tx) = requests.predict.next().await.unwrap();
let (request, respond_tx) = requests.predict.next().await.unwrap();
// worse than current prediction
let second_response = model_response(indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are
Bye
"});
let second_response = model_response(
request,
indoc! { r"
--- a/root/foo.md
+++ b/root/foo.md
@@ ... @@
Hello!
-How
+How are
Bye
"},
);
let second_id = second_response.id.clone();
respond_tx.send(second_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// first is preferred over second
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -633,29 +640,29 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_first) = requests.predict.next().await.unwrap();
let (request1, respond_first) = requests.predict.next().await.unwrap();
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_second) = requests.predict.next().await.unwrap();
let (request, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle
cx.run_until_parked();
// second responds first
let second_response = model_response(SIMPLE_DIFF);
let second_response = model_response(request, SIMPLE_DIFF);
let second_id = second_response.id.clone();
respond_second.send(second_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// current prediction is second
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -663,17 +670,17 @@ async fn test_cancel_earlier_pending_requests(cx: &mut TestAppContext) {
);
});
let first_response = model_response(SIMPLE_DIFF);
let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// current prediction is still second, since first was cancelled
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -724,13 +731,13 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_first) = requests.predict.next().await.unwrap();
let (request1, respond_first) = requests.predict.next().await.unwrap();
ep_store.update(cx, |ep_store, cx| {
ep_store.refresh_prediction_from_buffer(project.clone(), buffer.clone(), position, cx);
});
let (_, respond_second) = requests.predict.next().await.unwrap();
let (request2, respond_second) = requests.predict.next().await.unwrap();
// wait for throttle, so requests are sent
cx.run_until_parked();
@@ -754,19 +761,19 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
// wait for throttle
cx.run_until_parked();
let (_, respond_third) = requests.predict.next().await.unwrap();
let (request3, respond_third) = requests.predict.next().await.unwrap();
let first_response = model_response(SIMPLE_DIFF);
let first_response = model_response(request1, SIMPLE_DIFF);
let first_id = first_response.id.clone();
respond_first.send(first_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// current prediction is first
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -774,17 +781,17 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
let cancelled_response = model_response(SIMPLE_DIFF);
let cancelled_response = model_response(request2, SIMPLE_DIFF);
let cancelled_id = cancelled_response.id.clone();
respond_second.send(cancelled_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// current prediction is still first, since second was cancelled
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -792,17 +799,17 @@ async fn test_cancel_second_on_third_request(cx: &mut TestAppContext) {
);
});
let third_response = model_response(SIMPLE_DIFF);
let third_response = model_response(request3, SIMPLE_DIFF);
let third_response_id = third_response.id.clone();
respond_third.send(third_response).unwrap();
cx.run_until_parked();
ep_store.read_with(cx, |ep_store, cx| {
ep_store.update(cx, |ep_store, cx| {
// third completes and replaces first
assert_eq!(
ep_store
.current_prediction_for_buffer(&buffer, &project, cx)
.prediction_at(&buffer, None, &project, cx)
.unwrap()
.id
.0,
@@ -1036,7 +1043,24 @@ async fn test_rejections_flushing(cx: &mut TestAppContext) {
// );
// }
fn model_response(text: &str) -> open_ai::Response {
// Generate a model response that would apply the given diff to the active file.
fn model_response(request: open_ai::Request, diff_to_apply: &str) -> open_ai::Response {
let prompt = match &request.messages[0] {
open_ai::RequestMessage::User {
content: open_ai::MessageContent::Plain(content),
} => content,
_ => panic!("unexpected request {request:?}"),
};
let open = "<editable_region>\n";
let close = "</editable_region>";
let cursor = "<|user_cursor|>";
let start_ix = open.len() + prompt.find(open).unwrap();
let end_ix = start_ix + &prompt[start_ix..].find(close).unwrap();
let excerpt = prompt[start_ix..end_ix].replace(cursor, "");
let new_excerpt = apply_diff_to_string(diff_to_apply, &excerpt).unwrap();
open_ai::Response {
id: Uuid::new_v4().to_string(),
object: "response".into(),
@@ -1045,7 +1069,7 @@ fn model_response(text: &str) -> open_ai::Response {
choices: vec![open_ai::Choice {
index: 0,
message: open_ai::RequestMessage::Assistant {
content: Some(open_ai::MessageContent::Plain(text.to_string())),
content: Some(open_ai::MessageContent::Plain(new_excerpt)),
tool_calls: vec![],
},
finish_reason: None,
@@ -1160,20 +1184,19 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
.read(|cx| buffer.read(cx).preview_edits(edits.clone(), cx))
.await;
let completion = EditPrediction {
let prediction = EditPrediction {
edits,
edit_preview,
buffer: buffer.clone(),
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
id: EditPredictionId("the-id".into()),
inputs: EditPredictionInputs {
inputs: ZetaPromptInput {
events: Default::default(),
included_files: Default::default(),
cursor_point: cloud_llm_client::predict_edits_v3::Point {
line: Line(0),
column: 0,
},
related_files: Default::default(),
cursor_path: Path::new("").into(),
cursor_excerpt: "".into(),
editable_range_in_excerpt: 0..0,
cursor_offset_in_excerpt: 0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),
@@ -1182,7 +1205,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
cx.update(|cx| {
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1192,7 +1215,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1202,7 +1225,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.undo(cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1212,7 +1235,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(2..5, "R")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1222,7 +1245,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(3..3, "E")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1232,7 +1255,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..4, "M")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1242,7 +1265,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(4..5, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1252,7 +1275,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
buffer.update(cx, |buffer, cx| buffer.edit([(8..10, "")], None, cx));
assert_eq!(
from_completion_edits(
&completion.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&prediction.interpolate(&buffer.read(cx).snapshot()).unwrap(),
&buffer,
cx
),
@@ -1260,7 +1283,7 @@ async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
);
buffer.update(cx, |buffer, cx| buffer.edit([(4..6, "")], None, cx));
assert_eq!(completion.interpolate(&buffer.read(cx).snapshot()), None);
assert_eq!(prediction.interpolate(&buffer.read(cx).snapshot()), None);
})
}
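
Note on the reworked `model_response` test helper above: it locates the editable region inside the prompt by searching for marker strings, strips the cursor marker, and applies the diff to what remains. The following is a minimal standalone sketch of just the extraction step, using only the standard library. The function name `extract_editable_region` is hypothetical, and returning `Option` instead of calling `unwrap()` is a choice made for this sketch, not something the diff does.

```rust
/// Returns the text between the `<editable_region>\n` and `</editable_region>`
/// markers with the cursor marker removed, or `None` if a marker is missing.
/// (Hypothetical helper; the marker strings are the ones used in the diff.)
fn extract_editable_region(prompt: &str) -> Option<String> {
    const OPEN: &str = "<editable_region>\n";
    const CLOSE: &str = "</editable_region>";
    const CURSOR: &str = "<|user_cursor|>";

    // Start just past the opening marker.
    let start = prompt.find(OPEN)? + OPEN.len();
    // Search for the closing marker only in the text after the opening one.
    let end = start + prompt[start..].find(CLOSE)?;
    Some(prompt[start..end].replace(CURSOR, ""))
}

fn main() {
    let prompt =
        "header\n<editable_region>\nHello!\nHow<|user_cursor|>\nBye\n</editable_region>\nfooter";
    let excerpt = extract_editable_region(prompt).expect("markers present");
    assert_eq!(excerpt, "Hello!\nHow\nBye\n");
}
```

Searching for the closing marker in `prompt[start..]` rather than in the whole prompt avoids matching a closing marker that appears before the opening one.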


@@ -1,20 +1,17 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
use language::{Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _};
use project::{Project, ProjectPath};
use std::{
collections::VecDeque, fmt::Write as _, mem, ops::Range, path::Path, sync::Arc, time::Instant,
};
use language::{OffsetRangeExt as _, ToOffset, ToPoint as _};
use std::{mem, ops::Range, path::Path, sync::Arc, time::Instant};
use zeta_prompt::ZetaPromptInput;
use crate::{
EditPredictionId, EditPredictionInputs, open_ai_response::text_from_response,
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
EditPredictionStartedDebugEvent, open_ai_response::text_from_response,
prediction::EditPredictionResult,
};
@@ -38,16 +35,17 @@ impl Mercury {
store_api_token_in_keychain(api_token, cx)
}
pub fn request_prediction(
pub(crate) fn request_prediction(
&self,
_project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
_recent_paths: &VecDeque<ProjectPath>,
related_files: Vec<RelatedFile>,
_diagnostic_search_range: Range<Point>,
EditPredictionModelInput {
buffer,
snapshot,
position,
events,
related_files,
debug_tx,
..
}: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
@@ -62,6 +60,7 @@ impl Mercury {
let http_client = cx.http_client();
let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
let active_buffer = buffer.clone();
let result = cx.background_spawn(async move {
let (editable_range, context_range) =
@@ -72,39 +71,39 @@ impl Mercury {
MAX_REWRITE_TOKENS,
);
let offset_range = editable_range.to_offset(&snapshot);
let prompt = build_prompt(
&events,
&related_files,
&snapshot,
full_path.as_ref(),
cursor_point,
editable_range,
context_range.clone(),
);
let context_offset_range = context_range.to_offset(&snapshot);
let inputs = EditPredictionInputs {
events: events,
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(
context_range.start.row,
),
text: snapshot
.text_for_range(context_range.clone())
.collect::<String>()
.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
let editable_offset_range = editable_range.to_offset(&snapshot);
let inputs = zeta_prompt::ZetaPromptInput {
events,
related_files,
cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot)
- context_range.start.to_offset(&snapshot),
cursor_path: full_path.clone(),
cursor_excerpt: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
editable_range_in_excerpt: (editable_offset_range.start
- context_offset_range.start)
..(editable_offset_range.end - context_offset_range.start),
};
let prompt = build_prompt(&inputs);
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionStarted(
EditPredictionStartedDebugEvent {
buffer: active_buffer.downgrade(),
prompt: Some(prompt.clone()),
position,
},
))
.ok();
}
let request_body = open_ai::Request {
model: "mercury-coder".into(),
messages: vec![open_ai::RequestMessage::User {
@@ -160,6 +159,18 @@ impl Mercury {
let id = mem::take(&mut response.id);
let response_str = text_from_response(response).unwrap_or_default();
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
EditPredictionFinishedDebugEvent {
buffer: active_buffer.downgrade(),
model_output: Some(response_str.clone()),
position,
},
))
.ok();
}
let response_str = response_str.strip_prefix("```\n").unwrap_or(&response_str);
let response_str = response_str.strip_suffix("\n```").unwrap_or(&response_str);
@@ -168,15 +179,16 @@ impl Mercury {
if response_str != NO_PREDICTION_OUTPUT {
let old_text = snapshot
.text_for_range(offset_range.clone())
.text_for_range(editable_offset_range.clone())
.collect::<String>();
edits.extend(
language::text_diff(&old_text, &response_str)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(offset_range.start + range.start)
..snapshot.anchor_before(offset_range.start + range.end),
snapshot.anchor_after(editable_offset_range.start + range.start)
..snapshot
.anchor_before(editable_offset_range.start + range.end),
text,
)
}),
@@ -186,8 +198,6 @@ impl Mercury {
anyhow::Ok((id, edits, snapshot, response_received_at, inputs))
});
let buffer = active_buffer.clone();
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) =
result.await.context("Mercury edit prediction failed")?;
@@ -208,15 +218,7 @@ impl Mercury {
}
}
fn build_prompt(
events: &[Arc<Event>],
related_files: &[RelatedFile],
cursor_buffer: &BufferSnapshot,
cursor_buffer_path: &Path,
cursor_point: Point,
editable_range: Range<Point>,
context_range: Range<Point>,
) -> String {
fn build_prompt(inputs: &ZetaPromptInput) -> String {
const RECENTLY_VIEWED_SNIPPETS_START: &str = "<|recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPETS_END: &str = "<|/recently_viewed_code_snippets|>\n";
const RECENTLY_VIEWED_SNIPPET_START: &str = "<|recently_viewed_code_snippet|>\n";
@@ -237,14 +239,14 @@ fn build_prompt(
&mut prompt,
RECENTLY_VIEWED_SNIPPETS_START..RECENTLY_VIEWED_SNIPPETS_END,
|prompt| {
for related_file in related_files {
for related_file in inputs.related_files.iter() {
for related_excerpt in &related_file.excerpts {
push_delimited(
prompt,
RECENTLY_VIEWED_SNIPPET_START..RECENTLY_VIEWED_SNIPPET_END,
|prompt| {
prompt.push_str(CODE_SNIPPET_FILE_PATH_PREFIX);
prompt.push_str(related_file.path.path.as_unix_str());
prompt.push_str(related_file.path.to_string_lossy().as_ref());
prompt.push('\n');
prompt.push_str(&related_excerpt.text.to_string());
},
@@ -259,21 +261,22 @@ fn build_prompt(
CURRENT_FILE_CONTENT_START..CURRENT_FILE_CONTENT_END,
|prompt| {
prompt.push_str(CURRENT_FILE_PATH_PREFIX);
prompt.push_str(cursor_buffer_path.as_os_str().to_string_lossy().as_ref());
prompt.push_str(inputs.cursor_path.as_os_str().to_string_lossy().as_ref());
prompt.push('\n');
let prefix_range = context_range.start..editable_range.start;
let suffix_range = editable_range.end..context_range.end;
prompt.extend(cursor_buffer.text_for_range(prefix_range));
prompt.push_str(&inputs.cursor_excerpt[0..inputs.editable_range_in_excerpt.start]);
push_delimited(prompt, CODE_TO_EDIT_START..CODE_TO_EDIT_END, |prompt| {
let range_before_cursor = editable_range.start..cursor_point;
let range_after_cursor = cursor_point..editable_range.end;
prompt.extend(cursor_buffer.text_for_range(range_before_cursor));
prompt.push_str(
&inputs.cursor_excerpt
[inputs.editable_range_in_excerpt.start..inputs.cursor_offset_in_excerpt],
);
prompt.push_str(CURSOR_TAG);
prompt.extend(cursor_buffer.text_for_range(range_after_cursor));
prompt.push_str(
&inputs.cursor_excerpt
[inputs.cursor_offset_in_excerpt..inputs.editable_range_in_excerpt.end],
);
});
prompt.extend(cursor_buffer.text_for_range(suffix_range));
prompt.push_str(&inputs.cursor_excerpt[inputs.editable_range_in_excerpt.end..]);
},
);
@@ -281,8 +284,8 @@ fn build_prompt(
&mut prompt,
EDIT_DIFF_HISTORY_START..EDIT_DIFF_HISTORY_END,
|prompt| {
for event in events {
writeln!(prompt, "{event}").unwrap();
for event in inputs.events.iter() {
zeta_prompt::write_event(prompt, &event);
}
},
);
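
Note on the new `build_prompt(inputs: &ZetaPromptInput)` above: the prompt is now assembled by slicing a single excerpt string with byte offsets relative to that excerpt, instead of reading ranges from a buffer snapshot. The sketch below illustrates the slicing under stated assumptions: `ExcerptInput` is a hypothetical stand-in that carries only the three excerpt-relative fields seen in the diff, and the delimiter strings passed in `main` are placeholders, not the real prompt tags.

```rust
use std::ops::Range;

/// Hypothetical stand-in for the excerpt-relative fields of `ZetaPromptInput`.
struct ExcerptInput {
    cursor_excerpt: String,
    editable_range_in_excerpt: Range<usize>,
    cursor_offset_in_excerpt: usize,
}

/// Renders prefix, delimited editable region (with cursor marker), and suffix.
/// Assumes start <= cursor <= end <= excerpt length, and that every offset
/// falls on a UTF-8 character boundary; otherwise the slicing panics.
fn render_excerpt(input: &ExcerptInput, open: &str, close: &str, cursor: &str) -> String {
    let text = &input.cursor_excerpt;
    let editable = &input.editable_range_in_excerpt;
    let cursor_ix = input.cursor_offset_in_excerpt;

    let mut out = String::new();
    out.push_str(&text[..editable.start]); // context before the editable region
    out.push_str(open);
    out.push_str(&text[editable.start..cursor_ix]); // editable text before the cursor
    out.push_str(cursor);
    out.push_str(&text[cursor_ix..editable.end]); // editable text after the cursor
    out.push_str(close);
    out.push_str(&text[editable.end..]); // context after the editable region
    out
}

fn main() {
    let input = ExcerptInput {
        cursor_excerpt: "Hello!\nHow\nBye\n".to_string(),
        editable_range_in_excerpt: 7..11,
        cursor_offset_in_excerpt: 10,
    };
    let rendered = render_excerpt(&input, "<edit>", "</edit>", "<cursor>");
    assert_eq!(rendered, "Hello!\n<edit>How<cursor>\n</edit>Bye\n");
}
```

Because all three offsets are relative to the excerpt, the prompt can be rebuilt from the stored inputs alone, without access to the original buffer.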


@@ -1,6 +1,5 @@
use std::{
ops::Range,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
@@ -9,7 +8,7 @@ use cloud_llm_client::EditPredictionRejectReason;
use edit_prediction_types::interpolate_edits;
use gpui::{AsyncApp, Entity, SharedString};
use language::{Anchor, Buffer, BufferSnapshot, EditPreview, TextBufferSnapshot};
use serde::Serialize;
use zeta_prompt::ZetaPromptInput;
#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)]
pub struct EditPredictionId(pub SharedString);
@@ -40,7 +39,7 @@ impl EditPredictionResult {
edits: Arc<[(Range<Anchor>, Arc<str>)]>,
buffer_snapshotted_at: Instant,
response_received_at: Instant,
inputs: EditPredictionInputs,
inputs: ZetaPromptInput,
cx: &mut AsyncApp,
) -> Self {
if edits.is_empty() {
@@ -94,15 +93,7 @@ pub struct EditPrediction {
pub buffer: Entity<Buffer>,
pub buffer_snapshotted_at: Instant,
pub response_received_at: Instant,
pub inputs: EditPredictionInputs,
}
#[derive(Debug, Clone, Serialize)]
pub struct EditPredictionInputs {
pub events: Vec<Arc<cloud_llm_client::predict_edits_v3::Event>>,
pub included_files: Vec<cloud_llm_client::predict_edits_v3::RelatedFile>,
pub cursor_point: cloud_llm_client::predict_edits_v3::Point,
pub cursor_path: Arc<Path>,
pub inputs: zeta_prompt::ZetaPromptInput,
}
impl EditPrediction {
@@ -133,9 +124,12 @@ impl std::fmt::Debug for EditPrediction {
#[cfg(test)]
mod tests {
use std::path::Path;
use super::*;
use gpui::{App, Entity, TestAppContext, prelude::*};
use language::{Buffer, ToOffset as _};
use zeta_prompt::ZetaPromptInput;
#[gpui::test]
async fn test_edit_prediction_basic_interpolation(cx: &mut TestAppContext) {
@@ -154,14 +148,13 @@ mod tests {
snapshot: cx.read(|cx| buffer.read(cx).snapshot()),
buffer: buffer.clone(),
edit_preview,
inputs: EditPredictionInputs {
inputs: ZetaPromptInput {
events: vec![],
included_files: vec![],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
line: cloud_llm_client::predict_edits_v3::Line(0),
column: 0,
},
related_files: vec![].into(),
cursor_path: Path::new("path.txt").into(),
cursor_offset_in_excerpt: 0,
cursor_excerpt: "".into(),
editable_range_in_excerpt: 0..0,
},
buffer_snapshotted_at: Instant::now(),
response_received_at: Instant::now(),


@@ -1,26 +1,21 @@
use anyhow::{Context as _, Result};
use cloud_llm_client::predict_edits_v3::Event;
use credentials_provider::CredentialsProvider;
use edit_prediction_context::RelatedFile;
use futures::{AsyncReadExt as _, FutureExt, future::Shared};
use gpui::{
App, AppContext as _, Entity, Task,
App, AppContext as _, Task,
http_client::{self, AsyncBody, Method},
};
use language::{Buffer, BufferSnapshot, Point, ToOffset as _, ToPoint as _};
use language::{Point, ToOffset as _};
use lsp::DiagnosticSeverity;
use project::{Project, ProjectPath};
use serde::{Deserialize, Serialize};
use std::{
collections::VecDeque,
fmt::{self, Write as _},
ops::Range,
path::Path,
sync::Arc,
time::Instant,
};
use crate::{EditPredictionId, EditPredictionInputs, prediction::EditPredictionResult};
use crate::{EditPredictionId, EditPredictionModelInput, prediction::EditPredictionResult};
const SWEEP_API_URL: &str = "https://autocomplete.sweep.dev/backend/next_edit_autocomplete";
@@ -44,40 +39,34 @@ impl SweepAi {
pub fn request_prediction_with_sweep(
&self,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
recent_paths: &VecDeque<ProjectPath>,
related_files: Vec<RelatedFile>,
diagnostic_search_range: Range<Point>,
inputs: EditPredictionModelInput,
cx: &mut App,
) -> Task<Result<Option<EditPredictionResult>>> {
let debug_info = self.debug_info.clone();
let Some(api_token) = self.api_token.clone().now_or_never().flatten() else {
return Task::ready(Ok(None));
};
let full_path: Arc<Path> = snapshot
let full_path: Arc<Path> = inputs
.snapshot
.file()
.map(|file| file.full_path(cx))
.unwrap_or_else(|| "untitled".into())
.into();
let project_file = project::File::from_dyn(snapshot.file());
let project_file = project::File::from_dyn(inputs.snapshot.file());
let repo_name = project_file
.map(|file| file.worktree.read(cx).root_name_str())
.unwrap_or("untitled")
.into();
let offset = position.to_offset(&snapshot);
let offset = inputs.position.to_offset(&inputs.snapshot);
let recent_buffers = recent_paths.iter().cloned();
let recent_buffers = inputs.recent_paths.iter().cloned();
let http_client = cx.http_client();
let recent_buffer_snapshots = recent_buffers
.filter_map(|project_path| {
let buffer = project.read(cx).get_open_buffer(&project_path, cx)?;
if active_buffer == &buffer {
let buffer = inputs.project.read(cx).get_open_buffer(&project_path, cx)?;
if inputs.buffer == buffer {
None
} else {
Some(buffer.read(cx).snapshot())
@@ -86,14 +75,13 @@ impl SweepAi {
.take(3)
.collect::<Vec<_>>();
let cursor_point = position.to_point(&snapshot);
let buffer_snapshotted_at = Instant::now();
let result = cx.background_spawn(async move {
let text = snapshot.text();
let text = inputs.snapshot.text();
let mut recent_changes = String::new();
for event in &events {
for event in &inputs.events {
write_event(event.as_ref(), &mut recent_changes).unwrap();
}
@@ -122,20 +110,23 @@ impl SweepAi {
})
.collect::<Vec<_>>();
let retrieval_chunks = related_files
let retrieval_chunks = inputs
.related_files
.iter()
.flat_map(|related_file| {
related_file.excerpts.iter().map(|excerpt| FileChunk {
file_path: related_file.path.path.as_unix_str().to_string(),
start_line: excerpt.point_range.start.row as usize,
end_line: excerpt.point_range.end.row as usize,
file_path: related_file.path.to_string_lossy().to_string(),
start_line: excerpt.row_range.start as usize,
end_line: excerpt.row_range.end as usize,
content: excerpt.text.to_string(),
timestamp: None,
})
})
.collect();
let diagnostic_entries = snapshot.diagnostics_in_range(diagnostic_search_range, false);
let diagnostic_entries = inputs
.snapshot
.diagnostics_in_range(inputs.diagnostic_search_range, false);
let mut diagnostic_content = String::new();
let mut diagnostic_count = 0;
@@ -195,21 +186,14 @@ impl SweepAi {
serde_json::to_writer(writer, &request_body)?;
let body: AsyncBody = buf.into();
let inputs = EditPredictionInputs {
events,
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(0),
text: request_body.file_contents.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
let ep_inputs = zeta_prompt::ZetaPromptInput {
events: inputs.events,
related_files: inputs.related_files.clone(),
cursor_path: full_path.clone(),
cursor_excerpt: request_body.file_contents.into(),
// we actually don't know
editable_range_in_excerpt: 0..inputs.snapshot.len(),
cursor_offset_in_excerpt: request_body.cursor_position,
};
let request = http_client::Request::builder()
@@ -237,15 +221,20 @@ impl SweepAi {
let response: AutocompleteResponse = serde_json::from_slice(&body)?;
let old_text = snapshot
let old_text = inputs
.snapshot
.text_for_range(response.start_index..response.end_index)
.collect::<String>();
let edits = language::text_diff(&old_text, &response.completion)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(response.start_index + range.start)
..snapshot.anchor_before(response.start_index + range.end),
inputs
.snapshot
.anchor_after(response.start_index + range.start)
..inputs
.snapshot
.anchor_before(response.start_index + range.end),
text,
)
})
@@ -254,13 +243,13 @@ impl SweepAi {
anyhow::Ok((
response.autocomplete_id,
edits,
snapshot,
inputs.snapshot,
response_received_at,
inputs,
ep_inputs,
))
});
let buffer = active_buffer.clone();
let buffer = inputs.buffer.clone();
cx.spawn(async move |cx| {
let (id, edits, old_snapshot, response_received_at, inputs) = result.await?;
@@ -403,12 +392,9 @@ struct AdditionalCompletion {
pub finish_reason: Option<String>,
}
fn write_event(
event: &cloud_llm_client::predict_edits_v3::Event,
f: &mut impl fmt::Write,
) -> fmt::Result {
fn write_event(event: &zeta_prompt::Event, f: &mut impl fmt::Write) -> fmt::Result {
match event {
cloud_llm_client::predict_edits_v3::Event::BufferChange {
zeta_prompt::Event::BufferChange {
old_path,
path,
diff,


@@ -14,68 +14,18 @@ use anyhow::anyhow;
use collections::HashMap;
use gpui::AsyncApp;
use gpui::Entity;
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, TextBufferSnapshot};
use language::{Anchor, Buffer, OffsetRangeExt as _, TextBufferSnapshot};
use project::Project;
pub async fn parse_diff<'a>(
diff_str: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
let mut diff = DiffParser::new(diff_str);
let mut edited_buffer = None;
let mut edits = Vec::new();
while let Some(event) = diff.next()? {
match event {
DiffEvent::Hunk {
path: file_path,
hunk,
} => {
let (buffer, ranges) = match edited_buffer {
None => {
edited_buffer = get_buffer(&Path::new(file_path.as_ref()));
edited_buffer
.as_ref()
.context("Model tried to edit a file that wasn't included")?
}
Some(ref current) => current,
};
edits.extend(
resolve_hunk_edits_in_buffer(hunk, &buffer.text, ranges)
.with_context(|| format!("Diff:\n{diff_str}"))?,
);
}
DiffEvent::FileEnd { renamed_to } => {
let (buffer, _) = edited_buffer
.take()
.context("Got a FileEnd event before an Hunk event")?;
if renamed_to.is_some() {
anyhow::bail!("edit predictions cannot rename files");
}
if diff.next()?.is_some() {
anyhow::bail!("Edited more than one file");
}
return Ok((buffer, edits));
}
}
}
Err(anyhow::anyhow!("No EOF"))
}
#[derive(Debug)]
pub struct OpenedBuffers<'a>(#[allow(unused)] HashMap<Cow<'a, str>, Entity<Buffer>>);
#[derive(Clone, Debug)]
pub struct OpenedBuffers(#[allow(unused)] HashMap<String, Entity<Buffer>>);
#[must_use]
pub async fn apply_diff<'a>(
diff_str: &'a str,
pub async fn apply_diff(
diff_str: &str,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'a>> {
) -> Result<OpenedBuffers> {
let mut included_files = HashMap::default();
for line in diff_str.lines() {
@@ -94,7 +44,7 @@ pub async fn apply_diff<'a>(
})??
.await?;
included_files.insert(path, buffer);
included_files.insert(path.to_string(), buffer);
}
}
@@ -113,7 +63,7 @@ pub async fn apply_diff<'a>(
let (buffer, ranges) = match current_file {
None => {
let buffer = included_files
.get_mut(&file_path)
.get_mut(file_path.as_ref())
.expect("Opened all files in diff");
current_file = Some((buffer, ranges.as_slice()));
@@ -167,6 +117,29 @@ pub async fn apply_diff<'a>(
Ok(OpenedBuffers(included_files))
}
pub fn apply_diff_to_string(diff_str: &str, text: &str) -> Result<String> {
let mut diff = DiffParser::new(diff_str);
let mut text = text.to_string();
while let Some(event) = diff.next()? {
match event {
DiffEvent::Hunk { hunk, .. } => {
let hunk_offset = text
.find(&hunk.context)
.ok_or_else(|| anyhow!("couldn't resolve hunk {:?}", hunk.context))?;
for edit in hunk.edits.iter().rev() {
let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
text.replace_range(range, &edit.text);
}
}
DiffEvent::FileEnd { .. } => {}
}
}
Ok(text)
}
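
The string-based path above locates each hunk by searching for its context text and then applies the hunk's edits from last to first. A minimal, self-contained sketch of that idea follows. `Hunk`, `Edit`, and `apply_hunk` are hypothetical stand-ins, not the parser's real types, and the sketch assumes the edits are sorted by position and do not overlap.

```rust
use std::ops::Range;

/// Hypothetical stand-in for one edit inside a hunk.
/// `range` is a byte range relative to the start of the hunk's context.
struct Edit {
    range: Range<usize>,
    text: String,
}

/// Hypothetical stand-in for a parsed hunk: the old text plus its edits.
struct Hunk {
    context: String,
    edits: Vec<Edit>,
}

/// Finds the first occurrence of the hunk's context in `text` and applies
/// the edits in place. Assumes edits are sorted and non-overlapping.
fn apply_hunk(text: &mut String, hunk: &Hunk) -> Result<(), String> {
    let hunk_offset = text
        .find(&hunk.context)
        .ok_or_else(|| format!("couldn't resolve hunk {:?}", hunk.context))?;
    // Walk the edits back to front so that a replacement never shifts the
    // offsets of the edits that are still waiting to be applied.
    for edit in hunk.edits.iter().rev() {
        let range = (hunk_offset + edit.range.start)..(hunk_offset + edit.range.end);
        text.replace_range(range, &edit.text);
    }
    Ok(())
}
```

Because only the first occurrence of the context is used, a context that appears more than once is resolved to the earliest match rather than rejected as ambiguous.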
struct PatchFile<'a> {
old_path: Cow<'a, str>,
new_path: Cow<'a, str>,
@@ -492,7 +465,6 @@ mod tests {
use super::*;
use gpui::TestAppContext;
use indoc::indoc;
use language::Point;
use pretty_assertions::assert_eq;
use project::{FakeFs, Project};
use serde_json::json;
@@ -817,137 +789,6 @@ mod tests {
});
}
#[gpui::test]
async fn test_apply_diff_non_unique(cx: &mut TestAppContext) {
let fs = init_test(cx);
let buffer_1_text = indoc! {r#"
one
two
three
four
five
one
two
three
four
five
"# };
fs.insert_tree(
path!("/root"),
json!({
"file1": buffer_1_text,
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/file1"), cx)
})
.await
.unwrap();
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let diff = indoc! {r#"
--- a/root/file1
+++ b/root/file1
one
two
-three
+3
four
five
"#};
let final_text = indoc! {r#"
one
two
three
four
five
one
two
3
four
five
"#};
apply_diff(diff, &project, &mut cx.to_async())
.await
.expect_err("Non-unique edits should fail");
let ranges = [buffer_snapshot.anchor_before(Point::new(1, 0))
..buffer_snapshot.anchor_after(buffer_snapshot.max_point())];
let (edited_snapshot, edits) = parse_diff(diff, |_path| Some((&buffer_snapshot, &ranges)))
.await
.unwrap();
assert_eq!(edited_snapshot.remote_id(), buffer_snapshot.remote_id());
buffer.update(cx, |buffer, cx| {
buffer.edit(edits, None, cx);
assert_eq!(buffer.text(), final_text);
});
}
#[gpui::test]
async fn test_parse_diff_with_edits_within_line(cx: &mut TestAppContext) {
let fs = init_test(cx);
let buffer_1_text = indoc! {r#"
one two three four
five six seven eight
nine ten eleven twelve
"# };
fs.insert_tree(
path!("/root"),
json!({
"file1": buffer_1_text,
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/file1"), cx)
})
.await
.unwrap();
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let diff = indoc! {r#"
--- a/root/file1
+++ b/root/file1
one two three four
-five six seven eight
+five SIX seven eight!
nine ten eleven twelve
"#};
let (buffer, edits) = parse_diff(diff, |_path| {
Some((&buffer_snapshot, &[(Anchor::MIN..Anchor::MAX)] as &[_]))
})
.await
.unwrap();
let edits = edits
.into_iter()
.map(|(range, text)| (range.to_point(&buffer), text))
.collect::<Vec<_>>();
assert_eq!(
edits,
&[
(Point::new(1, 5)..Point::new(1, 8), "SIX".into()),
(Point::new(1, 20)..Point::new(1, 20), "!".into())
]
);
}
#[gpui::test]
async fn test_apply_diff_unique_via_previous_context(cx: &mut TestAppContext) {
let fs = init_test(cx);


@@ -1,637 +0,0 @@
use anyhow::{Context as _, Result};
use language::{Anchor, BufferSnapshot, OffsetRangeExt as _, Point};
use std::{cmp, ops::Range, path::Path, sync::Arc};
const EDITS_TAG_NAME: &'static str = "edits";
const OLD_TEXT_TAG_NAME: &'static str = "old_text";
const NEW_TEXT_TAG_NAME: &'static str = "new_text";
const XML_TAGS: &[&str] = &[EDITS_TAG_NAME, OLD_TEXT_TAG_NAME, NEW_TEXT_TAG_NAME];
pub async fn parse_xml_edits<'a>(
input: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
parse_xml_edits_inner(input, get_buffer)
.await
.with_context(|| format!("Failed to parse XML edits:\n{input}"))
}
async fn parse_xml_edits_inner<'a>(
input: &'a str,
get_buffer: impl Fn(&Path) -> Option<(&'a BufferSnapshot, &'a [Range<Anchor>])> + Send,
) -> Result<(&'a BufferSnapshot, Vec<(Range<Anchor>, Arc<str>)>)> {
let xml_edits = extract_xml_replacements(input)?;
let (buffer, context_ranges) = get_buffer(xml_edits.file_path.as_ref())
.with_context(|| format!("no buffer for file {}", xml_edits.file_path))?;
let mut all_edits = vec![];
for (old_text, new_text) in xml_edits.replacements {
let match_range = fuzzy_match_in_ranges(old_text, buffer, context_ranges)?;
let matched_old_text = buffer
.text_for_range(match_range.clone())
.collect::<String>();
let edits_within_hunk = language::text_diff(&matched_old_text, new_text);
all_edits.extend(
edits_within_hunk
.into_iter()
.map(move |(inner_range, inner_text)| {
(
buffer.anchor_after(match_range.start + inner_range.start)
..buffer.anchor_before(match_range.start + inner_range.end),
inner_text,
)
}),
);
}
Ok((buffer, all_edits))
}
fn fuzzy_match_in_ranges(
old_text: &str,
buffer: &BufferSnapshot,
context_ranges: &[Range<Anchor>],
) -> Result<Range<usize>> {
let mut state = FuzzyMatcher::new(buffer, old_text);
let mut best_match = None;
let mut tie_match_range = None;
for range in context_ranges {
let best_match_cost = best_match.as_ref().map(|(score, _)| *score);
match (best_match_cost, state.match_range(range.to_offset(buffer))) {
(Some(lowest_cost), Some((new_cost, new_range))) => {
if new_cost == lowest_cost {
tie_match_range = Some(new_range);
} else if new_cost < lowest_cost {
tie_match_range.take();
best_match = Some((new_cost, new_range));
}
}
(None, Some(new_match)) => {
best_match = Some(new_match);
}
(None, None) | (Some(_), None) => {}
};
}
if let Some((_, best_match_range)) = best_match {
if let Some(tie_match_range) = tie_match_range {
anyhow::bail!(
"Multiple ambiguous matches:\n{:?}:\n{}\n\n{:?}:\n{}",
best_match_range.clone(),
buffer.text_for_range(best_match_range).collect::<String>(),
tie_match_range.clone(),
buffer.text_for_range(tie_match_range).collect::<String>()
);
}
return Ok(best_match_range);
}
anyhow::bail!(
"Failed to fuzzy match `old_text`:\n{}\nin:\n```\n{}\n```",
old_text,
context_ranges
.iter()
.map(|range| buffer.text_for_range(range.clone()).collect::<String>())
.collect::<Vec<String>>()
.join("```\n```")
);
}
#[derive(Debug)]
struct XmlEdits<'a> {
file_path: &'a str,
/// Vec of (old_text, new_text) pairs
replacements: Vec<(&'a str, &'a str)>,
}
fn extract_xml_replacements(input: &str) -> Result<XmlEdits<'_>> {
let mut cursor = 0;
let (edits_body_start, edits_attrs) =
find_tag_open(input, &mut cursor, EDITS_TAG_NAME)?.context("No edits tag found")?;
let file_path = edits_attrs
.trim_start()
.strip_prefix("path")
.context("no path attribute on edits tag")?
.trim_end()
.strip_prefix('=')
.context("no value for path attribute")?
.trim()
.trim_start_matches('"')
.trim_end_matches('"');
cursor = edits_body_start;
let mut edits_list = Vec::new();
while let Some((old_body_start, _)) = find_tag_open(input, &mut cursor, OLD_TEXT_TAG_NAME)? {
let old_body_end = find_tag_close(input, &mut cursor)?;
let old_text = trim_surrounding_newlines(&input[old_body_start..old_body_end]);
let (new_body_start, _) = find_tag_open(input, &mut cursor, NEW_TEXT_TAG_NAME)?
.context("no new_text tag following old_text")?;
let new_body_end = find_tag_close(input, &mut cursor)?;
let new_text = trim_surrounding_newlines(&input[new_body_start..new_body_end]);
edits_list.push((old_text, new_text));
}
Ok(XmlEdits {
file_path,
replacements: edits_list,
})
}
/// Trims a single leading and trailing newline
fn trim_surrounding_newlines(input: &str) -> &str {
let start = input.strip_prefix('\n').unwrap_or(input);
let end = start.strip_suffix('\n').unwrap_or(start);
end
}
fn find_tag_open<'a>(
input: &'a str,
cursor: &mut usize,
expected_tag: &str,
) -> Result<Option<(usize, &'a str)>> {
let mut search_pos = *cursor;
while search_pos < input.len() {
let Some(tag_start) = input[search_pos..].find("<") else {
break;
};
let tag_start = search_pos + tag_start;
if !input[tag_start + 1..].starts_with(expected_tag) {
search_pos = tag_start + 1;
continue;
};
let after_tag_name = tag_start + expected_tag.len() + 1;
let close_bracket = input[after_tag_name..]
.find('>')
.with_context(|| format!("missing > after <{}", expected_tag))?;
let attrs_end = after_tag_name + close_bracket;
let body_start = attrs_end + 1;
let attributes = input[after_tag_name..attrs_end].trim();
*cursor = body_start;
return Ok(Some((body_start, attributes)));
}
Ok(None)
}
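
The cursor-based scan in `find_tag_open` can be illustrated with a simplified, standalone variant. This is a sketch under assumptions, not the removed implementation: it returns `Option` instead of `Result` (a missing `>` simply yields `None`), and it advances to `tag_start + 1` after a non-matching `<`, since `tag_start` is already an absolute offset. The name `find_tag_open_sketch` is hypothetical.

```rust
/// Scans `input` from `*cursor` for the next `<expected_tag ...>`.
/// On success returns (offset just past `>`, trimmed attribute text)
/// and moves the cursor to that offset.
fn find_tag_open_sketch<'a>(
    input: &'a str,
    cursor: &mut usize,
    expected_tag: &str,
) -> Option<(usize, &'a str)> {
    let mut search_pos = *cursor;
    while search_pos < input.len() {
        // Absolute position of the next `<`, or stop if there is none.
        let tag_start = search_pos + input[search_pos..].find('<')?;
        if !input[tag_start + 1..].starts_with(expected_tag) {
            // Not the tag we want: resume right after this `<`.
            search_pos = tag_start + 1;
            continue;
        }
        let after_tag_name = tag_start + 1 + expected_tag.len();
        let close_bracket = input[after_tag_name..].find('>')?;
        let attrs_end = after_tag_name + close_bracket;
        let body_start = attrs_end + 1;
        *cursor = body_start;
        return Some((body_start, input[after_tag_name..attrs_end].trim()));
    }
    None
}
```

Like the original, the prefix check with `starts_with` also accepts tags whose name merely begins with `expected_tag`, so a stricter parser would verify the character that follows the name.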
fn find_tag_close(input: &str, cursor: &mut usize) -> Result<usize> {
let mut depth = 1;
let mut search_pos = *cursor;
while search_pos < input.len() && depth > 0 {
let Some(bracket_offset) = input[search_pos..].find('<') else {
break;
};
let bracket_pos = search_pos + bracket_offset;
if input[bracket_pos..].starts_with("</")
&& let Some(close_end) = input[bracket_pos + 2..].find('>')
{
let close_start = bracket_pos + 2;
let tag_name = input[close_start..close_start + close_end].trim();
if XML_TAGS.contains(&tag_name) {
depth -= 1;
if depth == 0 {
*cursor = close_start + close_end + 1;
return Ok(bracket_pos);
}
}
search_pos = close_start + close_end + 1;
continue;
} else if let Some(close_bracket_offset) = input[bracket_pos..].find('>') {
let close_bracket_pos = bracket_pos + close_bracket_offset;
let tag_name = &input[bracket_pos + 1..close_bracket_pos].trim();
if XML_TAGS.contains(&tag_name) {
depth += 1;
}
}
search_pos = bracket_pos + 1;
}
anyhow::bail!("no closing tag found")
}
const REPLACEMENT_COST: u32 = 1;
const INSERTION_COST: u32 = 3;
const DELETION_COST: u32 = 10;
/// A fuzzy matcher that can process text chunks incrementally
/// and return the best match found so far at each step.
struct FuzzyMatcher<'a> {
snapshot: &'a BufferSnapshot,
query_lines: Vec<&'a str>,
matrix: SearchMatrix,
}
impl<'a> FuzzyMatcher<'a> {
fn new(snapshot: &'a BufferSnapshot, old_text: &'a str) -> Self {
let query_lines = old_text.lines().collect();
Self {
snapshot,
query_lines,
matrix: SearchMatrix::new(0),
}
}
fn match_range(&mut self, range: Range<usize>) -> Option<(u32, Range<usize>)> {
let point_range = range.to_point(&self.snapshot);
let buffer_line_count = (point_range.end.row - point_range.start.row + 1) as usize;
self.matrix
.reset(self.query_lines.len() + 1, buffer_line_count + 1);
let query_line_count = self.query_lines.len();
for row in 0..query_line_count {
let query_line = self.query_lines[row].trim();
let leading_deletion_cost = (row + 1) as u32 * DELETION_COST;
self.matrix.set(
row + 1,
0,
SearchState::new(leading_deletion_cost, SearchDirection::Up),
);
let mut buffer_lines = self.snapshot.text_for_range(range.clone()).lines();
let mut col = 0;
while let Some(buffer_line) = buffer_lines.next() {
let buffer_line = buffer_line.trim();
let up = SearchState::new(
self.matrix
.get(row, col + 1)
.cost
.saturating_add(DELETION_COST),
SearchDirection::Up,
);
let left = SearchState::new(
self.matrix
.get(row + 1, col)
.cost
.saturating_add(INSERTION_COST),
SearchDirection::Left,
);
let diagonal = SearchState::new(
if query_line == buffer_line {
self.matrix.get(row, col).cost
} else if fuzzy_eq(query_line, buffer_line) {
self.matrix.get(row, col).cost + REPLACEMENT_COST
} else {
self.matrix
.get(row, col)
.cost
.saturating_add(DELETION_COST + INSERTION_COST)
},
SearchDirection::Diagonal,
);
self.matrix
.set(row + 1, col + 1, up.min(left).min(diagonal));
col += 1;
}
}
// Find all matches with the best cost
let mut best_cost = u32::MAX;
let mut matches_with_best_cost = Vec::new();
for col in 1..=buffer_line_count {
let cost = self.matrix.get(query_line_count, col).cost;
if cost < best_cost {
best_cost = cost;
matches_with_best_cost.clear();
matches_with_best_cost.push(col as u32);
} else if cost == best_cost {
matches_with_best_cost.push(col as u32);
}
}
// Find ranges for the matches
for &match_end_col in &matches_with_best_cost {
let mut matched_lines = 0;
let mut query_row = query_line_count;
let mut match_start_col = match_end_col;
while query_row > 0 && match_start_col > 0 {
let current = self.matrix.get(query_row, match_start_col as usize);
match current.direction {
SearchDirection::Diagonal => {
query_row -= 1;
match_start_col -= 1;
matched_lines += 1;
}
SearchDirection::Up => {
query_row -= 1;
}
SearchDirection::Left => {
match_start_col -= 1;
}
}
}
let buffer_row_start = match_start_col + point_range.start.row;
let buffer_row_end = match_end_col + point_range.start.row;
let matched_buffer_row_count = buffer_row_end - buffer_row_start;
let matched_ratio = matched_lines as f32
/ (matched_buffer_row_count as f32).max(query_line_count as f32);
if matched_ratio >= 0.8 {
let buffer_start_ix = self
.snapshot
.point_to_offset(Point::new(buffer_row_start, 0));
let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
buffer_row_end - 1,
self.snapshot.line_len(buffer_row_end - 1),
));
return Some((best_cost, buffer_start_ix..buffer_end_ix));
}
}
None
}
}
fn fuzzy_eq(left: &str, right: &str) -> bool {
const THRESHOLD: f64 = 0.8;
let min_levenshtein = left.len().abs_diff(right.len());
let min_normalized_levenshtein =
1. - (min_levenshtein as f64 / cmp::max(left.len(), right.len()) as f64);
if min_normalized_levenshtein < THRESHOLD {
return false;
}
strsim::normalized_levenshtein(left, right) >= THRESHOLD
}
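
`fuzzy_eq` relies on a cheap bound: the Levenshtein distance between two strings is at least the difference of their lengths, so the similarity can be rejected early without running the full algorithm. The sketch below shows that pre-check with a hand-written distance function instead of the `strsim` crate. `levenshtein` and `fuzzy_eq_sketch` are illustrative names, lengths are counted in `char`s, and two empty strings are treated as equal by assumption.

```rust
/// Plain dynamic-programming Levenshtein distance over `char`s,
/// keeping only the previous row of the table.
fn levenshtein(a: &str, b: &str) -> usize {
    let b_chars: Vec<char> = b.chars().collect();
    let mut prev: Vec<usize> = (0..=b_chars.len()).collect();
    for (i, ca) in a.chars().enumerate() {
        let mut cur = Vec::with_capacity(b_chars.len() + 1);
        cur.push(i + 1);
        for (j, cb) in b_chars.iter().enumerate() {
            let substitution = prev[j] + usize::from(ca != *cb);
            let deletion = prev[j + 1] + 1;
            let insertion = cur[j] + 1;
            cur.push(substitution.min(deletion).min(insertion));
        }
        prev = cur;
    }
    prev[b_chars.len()]
}

/// Returns true when the normalized similarity is at least 0.8.
fn fuzzy_eq_sketch(left: &str, right: &str) -> bool {
    const THRESHOLD: f64 = 0.8;
    let left_len = left.chars().count();
    let right_len = right.chars().count();
    let max_len = left_len.max(right_len);
    if max_len == 0 {
        return true; // assumption: two empty lines count as equal
    }
    // The distance can never be smaller than the length difference, so
    // this is an upper bound on the similarity the full check could find.
    let best_possible = 1.0 - left_len.abs_diff(right_len) as f64 / max_len as f64;
    if best_possible < THRESHOLD {
        return false;
    }
    1.0 - levenshtein(left, right) as f64 / max_len as f64 >= THRESHOLD
}
```

The early return costs only two length computations, which matters because the matcher compares every query line against every buffer line in the range.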
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
enum SearchDirection {
Up,
Left,
Diagonal,
}
#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord)]
struct SearchState {
cost: u32,
direction: SearchDirection,
}
impl SearchState {
fn new(cost: u32, direction: SearchDirection) -> Self {
Self { cost, direction }
}
}
struct SearchMatrix {
cols: usize,
rows: usize,
data: Vec<SearchState>,
}
impl SearchMatrix {
fn new(cols: usize) -> Self {
SearchMatrix {
cols,
rows: 0,
data: Vec::new(),
}
}
fn reset(&mut self, rows: usize, cols: usize) {
self.rows = rows;
self.cols = cols;
self.data
.fill(SearchState::new(0, SearchDirection::Diagonal));
self.data.resize(
self.rows * self.cols,
SearchState::new(0, SearchDirection::Diagonal),
);
}
fn get(&self, row: usize, col: usize) -> SearchState {
debug_assert!(row < self.rows);
debug_assert!(col < self.cols);
self.data[row * self.cols + col]
}
fn set(&mut self, row: usize, col: usize, state: SearchState) {
debug_assert!(row < self.rows && col < self.cols);
self.data[row * self.cols + col] = state;
}
}
#[cfg(test)]
mod tests {
use super::*;
use gpui::TestAppContext;
use indoc::indoc;
use language::Point;
use project::{FakeFs, Project};
use serde_json::json;
use settings::SettingsStore;
use util::path;
#[test]
fn test_extract_xml_edits() {
let input = indoc! {r#"
<edits path="test.rs">
<old_text>
old content
</old_text>
<new_text>
new content
</new_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "test.rs");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "old content");
assert_eq!(result.replacements[0].1, "new content");
}
#[test]
fn test_extract_xml_edits_with_wrong_closing_tags() {
let input = indoc! {r#"
<edits path="test.rs">
<old_text>
old content
</new_text>
<new_text>
new content
</old_text>
</ edits >
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "test.rs");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "old content");
assert_eq!(result.replacements[0].1, "new content");
}
#[test]
fn test_extract_xml_edits_with_xml_like_content() {
let input = indoc! {r#"
<edits path="component.tsx">
<old_text>
<foo><bar></bar></foo>
</old_text>
<new_text>
<foo><bar><baz></baz></bar></foo>
</new_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "component.tsx");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "<foo><bar></bar></foo>");
assert_eq!(
result.replacements[0].1,
"<foo><bar><baz></baz></bar></foo>"
);
}
#[test]
fn test_extract_xml_edits_with_conflicting_content() {
let input = indoc! {r#"
<edits path="component.tsx">
<old_text>
<new_text></new_text>
</old_text>
<new_text>
<old_text></old_text>
</new_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "component.tsx");
assert_eq!(result.replacements.len(), 1);
assert_eq!(result.replacements[0].0, "<new_text></new_text>");
assert_eq!(result.replacements[0].1, "<old_text></old_text>");
}
#[test]
fn test_extract_xml_edits_multiple_pairs() {
let input = indoc! {r#"
Some reasoning before edits. Lots of thinking going on here
<edits path="test.rs">
<old_text>
first old
</old_text>
<new_text>
first new
</new_text>
<old_text>
second old
</edits>
<new_text>
second new
</old_text>
</edits>
"#};
let result = extract_xml_replacements(input).unwrap();
assert_eq!(result.file_path, "test.rs");
assert_eq!(result.replacements.len(), 2);
assert_eq!(result.replacements[0].0, "first old");
assert_eq!(result.replacements[0].1, "first new");
assert_eq!(result.replacements[1].0, "second old");
assert_eq!(result.replacements[1].1, "second new");
}
#[test]
fn test_extract_xml_edits_unexpected_eof() {
let input = indoc! {r#"
<edits path="test.rs">
<old_text>
first old
</
"#};
extract_xml_replacements(input).expect_err("Unexpected end of file");
}
#[gpui::test]
async fn test_parse_xml_edits(cx: &mut TestAppContext) {
let fs = init_test(cx);
let buffer_1_text = indoc! {r#"
one two three four
five six seven eight
nine ten eleven twelve
thirteen fourteen fifteen
sixteen seventeen eighteen
"#};
fs.insert_tree(
path!("/root"),
json!({
"file1": buffer_1_text,
}),
)
.await;
let project = Project::test(fs, [path!("/root").as_ref()], cx).await;
let buffer = project
.update(cx, |project, cx| {
project.open_local_buffer(path!("/root/file1"), cx)
})
.await
.unwrap();
let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
let edits = indoc! {r#"
<edits path="root/file1">
<old_text>
nine ten eleven twelve
</old_text>
<new_text>
nine TEN eleven twelve!
</new_text>
</edits>
"#};
let included_ranges = [(buffer_snapshot.anchor_before(Point::new(1, 0))..Anchor::MAX)];
let (buffer, edits) = parse_xml_edits(edits, |_path| {
Some((&buffer_snapshot, included_ranges.as_slice()))
})
.await
.unwrap();
let edits = edits
.into_iter()
.map(|(range, text)| (range.to_point(&buffer), text))
.collect::<Vec<_>>();
assert_eq!(
edits,
&[
(Point::new(2, 5)..Point::new(2, 8), "TEN".into()),
(Point::new(2, 22)..Point::new(2, 22), "!".into())
]
);
}
fn init_test(cx: &mut TestAppContext) -> Arc<FakeFs> {
cx.update(|cx| {
let settings_store = SettingsStore::test(cx);
cx.set_global(settings_store);
});
FakeFs::new(cx.background_executor.clone())
}
}


@@ -125,14 +125,15 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
return;
}
if let Some(current) = store.current_prediction_for_buffer(&buffer, &self.project, cx)
&& let BufferEditPrediction::Local { prediction } = current
&& prediction.interpolate(buffer.read(cx)).is_some()
{
return;
}
self.store.update(cx, |store, cx| {
if let Some(current) =
store.prediction_at(&buffer, Some(cursor_position), &self.project, cx)
&& let BufferEditPrediction::Local { prediction } = current
&& prediction.interpolate(buffer.read(cx)).is_some()
{
return;
}
store.refresh_context(&self.project, &buffer, cursor_position, cx);
store.refresh_prediction_from_buffer(self.project.clone(), buffer, cursor_position, cx)
});
@@ -171,69 +172,68 @@ impl EditPredictionDelegate for ZedEditPredictionDelegate {
cursor_position: language::Anchor,
cx: &mut Context<Self>,
) -> Option<edit_prediction_types::EditPrediction> {
let prediction =
self.store
.read(cx)
.current_prediction_for_buffer(buffer, &self.project, cx)?;
self.store.update(cx, |store, cx| {
let prediction =
store.prediction_at(buffer, Some(cursor_position), &self.project, cx)?;
let prediction = match prediction {
BufferEditPrediction::Local { prediction } => prediction,
BufferEditPrediction::Jump { prediction } => {
return Some(edit_prediction_types::EditPrediction::Jump {
id: Some(prediction.id.to_string().into()),
snapshot: prediction.snapshot.clone(),
target: prediction.edits.first().unwrap().0.start,
});
}
};
let prediction = match prediction {
BufferEditPrediction::Local { prediction } => prediction,
BufferEditPrediction::Jump { prediction } => {
return Some(edit_prediction_types::EditPrediction::Jump {
id: Some(prediction.id.to_string().into()),
snapshot: prediction.snapshot.clone(),
target: prediction.edits.first().unwrap().0.start,
});
}
};
let buffer = buffer.read(cx);
let snapshot = buffer.snapshot();
let buffer = buffer.read(cx);
let snapshot = buffer.snapshot();
let Some(edits) = prediction.interpolate(&snapshot) else {
self.store.update(cx, |store, _cx| {
let Some(edits) = prediction.interpolate(&snapshot) else {
store.reject_current_prediction(
EditPredictionRejectReason::InterpolatedEmpty,
&self.project,
);
});
return None;
};
return None;
};
let cursor_row = cursor_position.to_point(&snapshot).row;
let (closest_edit_ix, (closest_edit_range, _)) =
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
let distance_from_start = cursor_row.abs_diff(range.start.to_point(&snapshot).row);
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
cmp::min(distance_from_start, distance_from_end)
})?;
let cursor_row = cursor_position.to_point(&snapshot).row;
let (closest_edit_ix, (closest_edit_range, _)) =
edits.iter().enumerate().min_by_key(|(_, (range, _))| {
let distance_from_start =
cursor_row.abs_diff(range.start.to_point(&snapshot).row);
let distance_from_end = cursor_row.abs_diff(range.end.to_point(&snapshot).row);
cmp::min(distance_from_start, distance_from_end)
})?;
let mut edit_start_ix = closest_edit_ix;
for (range, _) in edits[..edit_start_ix].iter().rev() {
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
- range.end.to_point(&snapshot).row;
if distance_from_closest_edit <= 1 {
edit_start_ix -= 1;
} else {
break;
let mut edit_start_ix = closest_edit_ix;
for (range, _) in edits[..edit_start_ix].iter().rev() {
let distance_from_closest_edit = closest_edit_range.start.to_point(&snapshot).row
- range.end.to_point(&snapshot).row;
if distance_from_closest_edit <= 1 {
edit_start_ix -= 1;
} else {
break;
}
}
}
let mut edit_end_ix = closest_edit_ix + 1;
for (range, _) in &edits[edit_end_ix..] {
let distance_from_closest_edit =
range.start.to_point(buffer).row - closest_edit_range.end.to_point(&snapshot).row;
if distance_from_closest_edit <= 1 {
edit_end_ix += 1;
} else {
break;
let mut edit_end_ix = closest_edit_ix + 1;
for (range, _) in &edits[edit_end_ix..] {
let distance_from_closest_edit = range.start.to_point(buffer).row
- closest_edit_range.end.to_point(&snapshot).row;
if distance_from_closest_edit <= 1 {
edit_end_ix += 1;
} else {
break;
}
}
}
Some(edit_prediction_types::EditPrediction::Local {
id: Some(prediction.id.to_string().into()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
edit_preview: Some(prediction.edit_preview.clone()),
Some(edit_prediction_types::EditPrediction::Local {
id: Some(prediction.id.to_string().into()),
edits: edits[edit_start_ix..edit_end_ix].to_vec(),
edit_preview: Some(prediction.edit_preview.clone()),
})
})
}
}


@@ -1,22 +1,23 @@
use std::{fmt::Write, ops::Range, path::Path, sync::Arc, time::Instant};
use crate::{
EditPredictionId, EditPredictionStore, ZedUpdateRequiredError,
DebugEvent, EditPredictionFinishedDebugEvent, EditPredictionId, EditPredictionModelInput,
EditPredictionStartedDebugEvent, EditPredictionStore, ZedUpdateRequiredError,
cursor_excerpt::{editable_and_context_ranges_for_cursor_position, guess_token_count},
prediction::{EditPredictionInputs, EditPredictionResult},
prediction::EditPredictionResult,
};
use anyhow::{Context as _, Result};
use cloud_llm_client::{
PredictEditsBody, PredictEditsGitInfo, PredictEditsRequestTrigger, PredictEditsResponse,
predict_edits_v3::Event,
};
use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task};
use language::{
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToPoint as _, text_diff,
Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset, ToPoint as _, text_diff,
};
use project::{Project, ProjectPath};
use release_channel::AppVersion;
use workspace::notifications::{ErrorMessagePrompt, NotificationId, show_app_notification};
use zeta_prompt::{Event, ZetaPromptInput};
const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
const START_OF_FILE_MARKER: &str = "<|start_of_file|>";
@@ -29,24 +30,27 @@ pub(crate) const MAX_EVENT_TOKENS: usize = 500;
pub(crate) fn request_prediction_with_zeta1(
store: &mut EditPredictionStore,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
snapshot: BufferSnapshot,
position: language::Anchor,
events: Vec<Arc<Event>>,
trigger: PredictEditsRequestTrigger,
EditPredictionModelInput {
project,
buffer,
snapshot,
position,
events,
trigger,
debug_tx,
..
}: EditPredictionModelInput,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let buffer = buffer.clone();
let buffer_snapshotted_at = Instant::now();
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
let (git_info, can_collect_file) = if let Some(file) = snapshot.file() {
let can_collect_file = store.can_collect_file(project, file, cx);
let can_collect_file = store.can_collect_file(&project, file, cx);
let git_info = if can_collect_file {
git_info_for_file(project, &ProjectPath::from_file(file.as_ref(), cx), cx)
git_info_for_file(&project, &ProjectPath::from_file(file.as_ref(), cx), cx)
} else {
None
};
@@ -120,33 +124,33 @@ pub(crate) fn request_prediction_with_zeta1(
)
.await;
let inputs = EditPredictionInputs {
let context_start_offset = context_range.start.to_offset(&snapshot);
let editable_offset_range = editable_range.to_offset(&snapshot);
let inputs = ZetaPromptInput {
events: included_events.into(),
included_files: vec![cloud_llm_client::predict_edits_v3::RelatedFile {
path: full_path.clone(),
max_row: cloud_llm_client::predict_edits_v3::Line(snapshot.max_point().row),
excerpts: vec![cloud_llm_client::predict_edits_v3::Excerpt {
start_line: cloud_llm_client::predict_edits_v3::Line(context_range.start.row),
text: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
}],
}],
cursor_point: cloud_llm_client::predict_edits_v3::Point {
column: cursor_point.column,
line: cloud_llm_client::predict_edits_v3::Line(cursor_point.row),
},
related_files: vec![].into(),
cursor_path: full_path,
cursor_excerpt: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
editable_range_in_excerpt: (editable_range.start - context_start_offset)
..(editable_offset_range.end - context_start_offset),
cursor_offset_in_excerpt: cursor_point.to_offset(&snapshot) - context_start_offset,
};
// let response = perform_predict_edits(PerformPredictEditsParams {
// client,
// llm_token,
// app_version,
// body,
// })
// .await;
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionStarted(
EditPredictionStartedDebugEvent {
buffer: buffer.downgrade(),
prompt: Some(serde_json::to_string(&inputs).unwrap()),
position,
},
))
.ok();
}
let (response, usage) = match response {
Ok(response) => response,
@@ -189,6 +193,18 @@ pub(crate) fn request_prediction_with_zeta1(
.ok();
}
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
EditPredictionFinishedDebugEvent {
buffer: buffer.downgrade(),
model_output: Some(response.output_excerpt.clone()),
position,
},
))
.ok();
}
let edit_prediction = process_completion_response(
response,
buffer,
@@ -226,7 +242,7 @@ fn process_completion_response(
buffer: Entity<Buffer>,
snapshot: &BufferSnapshot,
editable_range: Range<usize>,
inputs: EditPredictionInputs,
inputs: ZetaPromptInput,
buffer_snapshotted_at: Instant,
received_response_at: Instant,
cx: &AsyncApp,


@@ -3,46 +3,39 @@ use crate::EvalCacheEntryKind;
use crate::open_ai_response::text_from_response;
use crate::prediction::EditPredictionResult;
use crate::{
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionId, EditPredictionInputs,
EditPredictionRequestedDebugEvent, EditPredictionStore,
DebugEvent, EDIT_PREDICTIONS_MODEL_ID, EditPredictionFinishedDebugEvent, EditPredictionId,
EditPredictionModelInput, EditPredictionStartedDebugEvent, EditPredictionStore,
};
use anyhow::{Result, anyhow, bail};
use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat};
use cloud_llm_client::{EditPredictionRejectReason, PredictEditsRequestTrigger};
use cloud_zeta2_prompt::CURSOR_MARKER;
use edit_prediction_context::{EditPredictionExcerpt, Line};
use edit_prediction_context::{RelatedExcerpt, RelatedFile};
use futures::channel::oneshot;
use gpui::{Entity, Task, prelude::*};
use language::{Anchor, BufferSnapshot};
use language::{Buffer, Point, ToOffset as _, ToPoint};
use project::{Project, ProjectItem as _};
use anyhow::{Result, anyhow};
use cloud_llm_client::EditPredictionRejectReason;
use gpui::{Task, prelude::*};
use language::{OffsetRangeExt as _, ToOffset as _, ToPoint};
use release_channel::AppVersion;
use std::{
env,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
use std::{path::Path, sync::Arc, time::Instant};
use zeta_prompt::CURSOR_MARKER;
use zeta_prompt::format_zeta_prompt;
const MAX_CONTEXT_TOKENS: usize = 150;
const MAX_REWRITE_TOKENS: usize = 350;
pub fn request_prediction_with_zeta2(
store: &mut EditPredictionStore,
project: &Entity<Project>,
active_buffer: &Entity<Buffer>,
active_snapshot: BufferSnapshot,
position: Anchor,
events: Vec<Arc<Event>>,
mut included_files: Vec<RelatedFile>,
trigger: PredictEditsRequestTrigger,
EditPredictionModelInput {
buffer,
snapshot,
position,
related_files,
events,
debug_tx,
..
}: EditPredictionModelInput,
cx: &mut Context<EditPredictionStore>,
) -> Task<Result<Option<EditPredictionResult>>> {
let options = store.options.clone();
let buffer_snapshotted_at = Instant::now();
let Some((excerpt_path, active_project_path)) = active_snapshot
let Some(excerpt_path) = snapshot
.file()
.map(|file| -> Arc<Path> { file.full_path(cx).into() })
.zip(active_buffer.read(cx).project_path(cx))
else {
return Task::ready(Err(anyhow!("No file path for excerpt")));
};
@@ -50,148 +43,35 @@ pub fn request_prediction_with_zeta2(
let client = store.client.clone();
let llm_token = store.llm_token.clone();
let app_version = AppVersion::global(cx);
let debug_tx = store.debug_tx.clone();
let file = active_buffer.read(cx).file();
let active_file_full_path = file.as_ref().map(|f| f.full_path(cx));
// TODO data collection
let can_collect_data = file
.as_ref()
.map_or(false, |file| store.can_collect_file(project, file, cx));
#[cfg(feature = "eval-support")]
let eval_cache = store.eval_cache.clone();
let request_task = cx.background_spawn({
let active_buffer = active_buffer.clone();
async move {
let cursor_offset = position.to_offset(&active_snapshot);
let cursor_point = cursor_offset.to_point(&active_snapshot);
let before_retrieval = Instant::now();
let excerpt_options = options.context;
let Some(excerpt) = EditPredictionExcerpt::select_from_buffer(
cursor_point,
&active_snapshot,
&excerpt_options,
) else {
return Ok((None, None));
};
let excerpt_anchor_range = active_snapshot.anchor_after(excerpt.range.start)
..active_snapshot.anchor_before(excerpt.range.end);
let related_excerpt = RelatedExcerpt {
anchor_range: excerpt_anchor_range.clone(),
point_range: Point::new(excerpt.line_range.start.0, 0)
..Point::new(excerpt.line_range.end.0, 0),
text: active_snapshot.as_rope().slice(excerpt.range),
};
if let Some(buffer_ix) = included_files
.iter()
.position(|file| file.buffer.entity_id() == active_buffer.entity_id())
{
let file = &mut included_files[buffer_ix];
file.excerpts.push(related_excerpt);
file.merge_excerpts();
let last_ix = included_files.len() - 1;
included_files.swap(buffer_ix, last_ix);
} else {
let active_file = RelatedFile {
path: active_project_path,
buffer: active_buffer.downgrade(),
excerpts: vec![related_excerpt],
max_row: active_snapshot.max_point().row,
};
included_files.push(active_file);
}
let included_files = included_files
.iter()
.map(|related_file| predict_edits_v3::RelatedFile {
path: Arc::from(related_file.path.path.as_std_path()),
max_row: Line(related_file.max_row),
excerpts: related_file
.excerpts
.iter()
.map(|excerpt| predict_edits_v3::Excerpt {
start_line: Line(excerpt.point_range.start.row),
text: excerpt.text.to_string().into(),
})
.collect(),
})
.collect::<Vec<_>>();
let cloud_request = predict_edits_v3::PredictEditsRequest {
excerpt_path,
excerpt: String::new(),
excerpt_line_range: Line(0)..Line(0),
excerpt_range: 0..0,
cursor_point: predict_edits_v3::Point {
line: predict_edits_v3::Line(cursor_point.row),
column: cursor_point.column,
},
related_files: included_files,
let cursor_offset = position.to_offset(&snapshot);
let (editable_offset_range, prompt_input) = zeta2_prompt_input(
&snapshot,
related_files,
events,
can_collect_data,
debug_info: debug_tx.is_some(),
prompt_max_bytes: Some(options.max_prompt_bytes),
prompt_format: options.prompt_format,
excerpt_parent: None,
git_info: None,
trigger,
};
excerpt_path,
cursor_offset,
);
let prompt_result = cloud_zeta2_prompt::build_prompt(&cloud_request);
let inputs = EditPredictionInputs {
included_files: cloud_request.related_files,
events: cloud_request.events,
cursor_point: cloud_request.cursor_point,
cursor_path: cloud_request.excerpt_path,
};
let retrieval_time = Instant::now() - before_retrieval;
let debug_response_tx = if let Some(debug_tx) = &debug_tx {
let (response_tx, response_rx) = oneshot::channel();
let prompt = format_zeta_prompt(&prompt_input);
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionRequested(
EditPredictionRequestedDebugEvent {
inputs: inputs.clone(),
retrieval_time,
buffer: active_buffer.downgrade(),
local_prompt: match prompt_result.as_ref() {
Ok(prompt) => Ok(prompt.clone()),
Err(err) => Err(err.to_string()),
},
.unbounded_send(DebugEvent::EditPredictionStarted(
EditPredictionStartedDebugEvent {
buffer: buffer.downgrade(),
prompt: Some(prompt.clone()),
position,
response_rx,
},
))
.ok();
Some(response_tx)
} else {
None
};
if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() {
if let Some(debug_response_tx) = debug_response_tx {
debug_response_tx
.send((Err("Request skipped".to_string()), Duration::ZERO))
.ok();
}
anyhow::bail!("Skipping request because ZED_ZETA2_SKIP_REQUEST is set")
}
let prompt = prompt_result?;
let generation_params =
cloud_zeta2_prompt::generation_params(cloud_request.prompt_format);
let request = open_ai::Request {
model: EDIT_PREDICTIONS_MODEL_ID.clone(),
messages: vec![open_ai::RequestMessage::User {
@@ -199,8 +79,8 @@ pub fn request_prediction_with_zeta2(
}],
stream: false,
max_completion_tokens: None,
stop: generation_params.stop.unwrap_or_default(),
temperature: generation_params.temperature.or(Some(0.7)),
stop: Default::default(),
temperature: Default::default(),
tool_choice: None,
parallel_tool_calls: None,
tools: vec![],
@@ -210,7 +90,6 @@ pub fn request_prediction_with_zeta2(
log::trace!("Sending edit prediction request");
let before_request = Instant::now();
let response = EditPredictionStore::send_raw_llm_request(
request,
client,
@@ -223,68 +102,53 @@ pub fn request_prediction_with_zeta2(
)
.await;
let received_response_at = Instant::now();
let request_time = received_response_at - before_request;
log::trace!("Got edit prediction response");
if let Some(debug_response_tx) = debug_response_tx {
debug_response_tx
.send((
response
.as_ref()
.map_err(|err| err.to_string())
.map(|response| response.0.clone()),
request_time,
))
.ok();
}
let (res, usage) = response?;
let request_id = EditPredictionId(res.id.clone().into());
let Some(mut output_text) = text_from_response(res) else {
return Ok((Some((request_id, None)), usage));
};
if let Some(debug_tx) = &debug_tx {
debug_tx
.unbounded_send(DebugEvent::EditPredictionFinished(
EditPredictionFinishedDebugEvent {
buffer: buffer.downgrade(),
position,
model_output: Some(output_text.clone()),
},
))
.ok();
}
if output_text.contains(CURSOR_MARKER) {
log::trace!("Stripping out {CURSOR_MARKER} from response");
output_text = output_text.replace(CURSOR_MARKER, "");
}
let get_buffer_from_context = |path: &Path| {
if Some(path) == active_file_full_path.as_deref() {
Some((
&active_snapshot,
std::slice::from_ref(&excerpt_anchor_range),
))
} else {
None
}
};
let (_, edits) = match options.prompt_format {
PromptFormat::Minimal | PromptFormat::MinimalQwen | PromptFormat::SeedCoder1120 => {
if output_text.contains("--- a/\n+++ b/\nNo edits") {
let edits = vec![];
(&active_snapshot, edits)
} else {
crate::udiff::parse_diff(&output_text, get_buffer_from_context).await?
}
}
PromptFormat::OldTextNewText => {
crate::xml_edits::parse_xml_edits(&output_text, get_buffer_from_context).await?
}
_ => {
bail!("unsupported prompt format {}", options.prompt_format)
}
};
let old_text = snapshot
.text_for_range(editable_offset_range.clone())
.collect::<String>();
let edits: Vec<_> = language::text_diff(&old_text, &output_text)
.into_iter()
.map(|(range, text)| {
(
snapshot.anchor_after(editable_offset_range.start + range.start)
..snapshot.anchor_before(editable_offset_range.start + range.end),
text,
)
})
.collect();
anyhow::Ok((
Some((
request_id,
Some((
inputs,
active_buffer,
active_snapshot.clone(),
prompt_input,
buffer,
snapshot.clone(),
edits,
received_response_at,
)),
@@ -325,3 +189,40 @@ pub fn request_prediction_with_zeta2(
))
})
}
pub fn zeta2_prompt_input(
snapshot: &language::BufferSnapshot,
related_files: Arc<[zeta_prompt::RelatedFile]>,
events: Vec<Arc<zeta_prompt::Event>>,
excerpt_path: Arc<Path>,
cursor_offset: usize,
) -> (std::ops::Range<usize>, zeta_prompt::ZetaPromptInput) {
let cursor_point = cursor_offset.to_point(snapshot);
let (editable_range, context_range) =
crate::cursor_excerpt::editable_and_context_ranges_for_cursor_position(
cursor_point,
snapshot,
MAX_CONTEXT_TOKENS,
MAX_REWRITE_TOKENS,
);
let context_start_offset = context_range.start.to_offset(snapshot);
let editable_offset_range = editable_range.to_offset(snapshot);
let cursor_offset_in_excerpt = cursor_offset - context_start_offset;
let editable_range_in_excerpt = (editable_offset_range.start - context_start_offset)
..(editable_offset_range.end - context_start_offset);
let prompt_input = zeta_prompt::ZetaPromptInput {
cursor_path: excerpt_path,
cursor_excerpt: snapshot
.text_for_range(context_range)
.collect::<String>()
.into(),
editable_range_in_excerpt,
cursor_offset_in_excerpt,
events,
related_files,
};
(editable_offset_range, prompt_input)
}
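
The offset bookkeeping in `zeta2_prompt_input`, and the mapping of `text_diff` ranges back onto the buffer, can be shown in isolation. This is a minimal sketch that uses plain `usize` ranges instead of buffer anchors. The helper names `to_excerpt_relative` and `to_buffer_absolute` are hypothetical and do not exist in the crate. It assumes the context range fully contains the editable range, so the subtractions cannot underflow.

```rust
use std::ops::Range;

/// Hypothetical helper: rebase a buffer-absolute range so that it is
/// relative to the start of the context excerpt sent in the prompt.
/// Assumes `range.start >= context_start`.
fn to_excerpt_relative(range: Range<usize>, context_start: usize) -> Range<usize> {
    (range.start - context_start)..(range.end - context_start)
}

/// Hypothetical helper: shift an edit range that is relative to the
/// editable region (as produced by diffing old text against model output)
/// back into buffer-absolute offsets.
fn to_buffer_absolute(edit: Range<usize>, editable_start: usize) -> Range<usize> {
    (editable_start + edit.start)..(editable_start + edit.end)
}
```

For example, with a context starting at offset 100 and an editable region at 120..180, the prompt carries the editable region as 20..80, and an edit at 5..9 within the editable text lands at 125..129 in the buffer.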


@@ -9,7 +9,7 @@ license = "GPL-3.0-or-later"
workspace = true
[[bin]]
name = "ep_cli"
name = "ep"
path = "src/main.rs"
[dependencies]
@@ -20,10 +20,9 @@ chrono.workspace = true
clap.workspace = true
client.workspace = true
cloud_llm_client.workspace= true
cloud_zeta2_prompt.workspace = true
collections.workspace = true
debug_adapter_extension.workspace = true
edit_prediction_context.workspace = true
dirs.workspace = true
extension.workspace = true
fs.workspace = true
futures.workspace = true
@@ -51,12 +50,21 @@ smol.workspace = true
sqlez.workspace = true
sqlez_macros.workspace = true
terminal_view.workspace = true
toml.workspace = true
util.workspace = true
watch.workspace = true
edit_prediction = { workspace = true, features = ["eval-support"] }
wasmtime.workspace = true
zeta_prompt.workspace = true
zlog.workspace = true
# Wasmtime is included as a dependency in order to enable the same
# features that are enabled in Zed.
#
# If we don't enable these features we get crashes when creating
# a Tree-sitter WasmStore.
[package.metadata.cargo-machete]
ignored = ["wasmtime"]
[dev-dependencies]
indoc.workspace = true
gpui = { workspace = true, features = ["test-support"] }


@@ -5,11 +5,13 @@ use anthropic::{
use anyhow::Result;
use http_client::HttpClient;
use indoc::indoc;
use reqwest_client::ReqwestClient;
use sqlez::bindable::Bind;
use sqlez::bindable::StaticColumnCount;
use sqlez_macros::sql;
use std::hash::Hash;
use std::hash::Hasher;
use std::path::Path;
use std::sync::Arc;
pub struct PlainLlmClient {
@@ -18,7 +20,8 @@ pub struct PlainLlmClient {
}
impl PlainLlmClient {
fn new(http_client: Arc<dyn HttpClient>) -> Result<Self> {
fn new() -> Result<Self> {
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
Ok(Self {
@@ -29,12 +32,12 @@ impl PlainLlmClient {
async fn generate(
&self,
model: String,
model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<AnthropicResponse> {
let request = AnthropicRequest {
model,
model: model.to_string(),
max_tokens,
messages,
tools: Vec::new(),
@@ -105,11 +108,12 @@ struct SerializableMessage {
}
impl BatchingLlmClient {
fn new(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
fn new(cache_path: &Path) -> Result<Self> {
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let api_key = std::env::var("ANTHROPIC_API_KEY")
.map_err(|_| anyhow::anyhow!("ANTHROPIC_API_KEY environment variable not set"))?;
let connection = sqlez::connection::Connection::open_file(&cache_path);
let connection = sqlez::connection::Connection::open_file(&cache_path.to_str().unwrap());
let mut statement = sqlez::statement::Statement::prepare(
&connection,
indoc! {"
@@ -182,16 +186,16 @@ impl BatchingLlmClient {
async fn generate(
&self,
model: String,
model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<Option<AnthropicResponse>> {
let response = self.lookup(&model, max_tokens, &messages)?;
let response = self.lookup(model, max_tokens, &messages)?;
if let Some(response) = response {
return Ok(Some(response));
}
self.mark_for_batch(&model, max_tokens, &messages)?;
self.mark_for_batch(model, max_tokens, &messages)?;
Ok(None)
}
@@ -258,7 +262,7 @@ impl BatchingLlmClient {
}
}
}
log::info!("Uploaded {} successful requests", success_count);
log::info!("Downloaded {} successful requests", success_count);
}
}
@@ -363,23 +367,20 @@ fn message_content_to_string(content: &[RequestContent]) -> String {
.join("\n")
}
pub enum LlmClient {
pub enum AnthropicClient {
// No batching
Plain(PlainLlmClient),
Batch(BatchingLlmClient),
Dummy,
}
impl LlmClient {
pub fn plain(http_client: Arc<dyn HttpClient>) -> Result<Self> {
Ok(Self::Plain(PlainLlmClient::new(http_client)?))
impl AnthropicClient {
pub fn plain() -> Result<Self> {
Ok(Self::Plain(PlainLlmClient::new()?))
}
pub fn batch(cache_path: &str, http_client: Arc<dyn HttpClient>) -> Result<Self> {
Ok(Self::Batch(BatchingLlmClient::new(
cache_path,
http_client,
)?))
pub fn batch(cache_path: &Path) -> Result<Self> {
Ok(Self::Batch(BatchingLlmClient::new(cache_path)?))
}
#[allow(dead_code)]
@@ -389,29 +390,29 @@ impl LlmClient {
pub async fn generate(
&self,
model: String,
model: &str,
max_tokens: u64,
messages: Vec<Message>,
) -> Result<Option<AnthropicResponse>> {
match self {
LlmClient::Plain(plain_llm_client) => plain_llm_client
AnthropicClient::Plain(plain_llm_client) => plain_llm_client
.generate(model, max_tokens, messages)
.await
.map(Some),
LlmClient::Batch(batching_llm_client) => {
AnthropicClient::Batch(batching_llm_client) => {
batching_llm_client
.generate(model, max_tokens, messages)
.await
}
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
}
}
pub async fn sync_batches(&self) -> Result<()> {
match self {
LlmClient::Plain(_) => Ok(()),
LlmClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
LlmClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
AnthropicClient::Plain(_) => Ok(()),
AnthropicClient::Batch(batching_llm_client) => batching_llm_client.sync_batches().await,
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
}
}
}
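
The batching path above (`lookup`, then `mark_for_batch`, then `Ok(None)`) follows a cache-or-defer pattern. The sketch below illustrates that control flow only. `BatchCache` is a hypothetical in-memory stand-in: the real client persists requests in SQLite through `sqlez` and talks to the Anthropic API, none of which is modeled here. Deduplicating pending keys is an assumption of this sketch, not something the diff shows.

```rust
use std::collections::HashMap;

/// Hypothetical in-memory stand-in for the on-disk request cache.
struct BatchCache {
    responses: HashMap<String, String>,
    pending: Vec<String>,
}

impl BatchCache {
    fn new() -> Self {
        Self {
            responses: HashMap::new(),
            pending: Vec::new(),
        }
    }

    /// Return the cached response if one exists; otherwise queue the
    /// request for the next batch and return `None`.
    fn generate(&mut self, request_key: &str) -> Option<String> {
        if let Some(response) = self.responses.get(request_key) {
            return Some(response.clone());
        }
        // Assumption: a key is queued at most once.
        if !self.pending.iter().any(|key| key == request_key) {
            self.pending.push(request_key.to_string());
        }
        None
    }

    /// Resolve every queued request with `resolve` and store the results,
    /// standing in for uploading a batch and downloading its responses.
    fn sync_batches(&mut self, resolve: impl Fn(&str) -> String) {
        for key in self.pending.drain(..) {
            let response = resolve(&key);
            self.responses.insert(key, response);
        }
    }
}
```

A caller would see `None` on the first `generate` call, run `sync_batches`, and then get the stored response on a second call with the same key.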


@@ -1,641 +0,0 @@
use crate::metrics::{self, Scores};
use std::{
collections::HashMap,
io::{IsTerminal, Write},
sync::Arc,
};
use anyhow::Result;
use edit_prediction::{EditPredictionStore, udiff::DiffLine};
use gpui::{AsyncApp, Entity};
use project::Project;
use util::ResultExt as _;
use crate::{
EvaluateArguments, PredictionOptions,
example::{Example, NamedExample},
headless::ZetaCliAppState,
paths::print_run_data_dir,
predict::{PredictionDetails, perform_predict, setup_store},
};
#[derive(Debug)]
pub(crate) struct ExecutionData {
execution_id: String,
diff: String,
reasoning: String,
}
pub async fn run_evaluate(
args: EvaluateArguments,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) {
if args.example_paths.is_empty() {
eprintln!("No examples provided");
return;
}
let all_tasks = args.example_paths.into_iter().map(|path| {
let options = args.options.clone();
let app_state = app_state.clone();
let example = NamedExample::load(&path).expect("Failed to load example");
cx.spawn(async move |cx| {
let project = example.setup_project(&app_state, cx).await.unwrap();
let providers = (0..args.repetitions)
.map(|_| setup_store(args.options.provider, &project, &app_state, cx).unwrap())
.collect::<Vec<_>>();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
let tasks = providers
.into_iter()
.enumerate()
.map(move |(repetition_ix, store)| {
let repetition_ix = (args.repetitions > 1).then(|| repetition_ix as u16);
let example = example.clone();
let project = project.clone();
let options = options.clone();
cx.spawn(async move |cx| {
let name = example.name.clone();
run_evaluate_one(
example,
repetition_ix,
project,
store,
options,
!args.skip_prediction,
cx,
)
.await
.map_err(|err| (err, name, repetition_ix))
})
});
futures::future::join_all(tasks).await
})
});
let all_results = futures::future::join_all(all_tasks).await;
write_aggregated_scores(&mut std::io::stdout(), &all_results).unwrap();
if let Some(mut output_file) =
std::fs::File::create(crate::paths::RUN_DIR.join("aggregated_results.md")).log_err()
{
write_aggregated_scores(&mut output_file, &all_results).log_err();
};
if args.repetitions > 1 {
if let Err(e) = write_bucketed_analysis(&all_results) {
eprintln!("Failed to write bucketed analysis: {:?}", e);
}
}
print_run_data_dir(args.repetitions == 1, std::io::stdout().is_terminal());
}
fn write_aggregated_scores(
w: &mut impl std::io::Write,
all_results: &Vec<
Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
>,
) -> Result<()> {
let mut successful = Vec::new();
let mut failed_count = 0;
for result in all_results.iter().flatten() {
match result {
Ok((eval_result, _execution_data)) => successful.push(eval_result),
Err((err, name, repetition_ix)) => {
if failed_count == 0 {
writeln!(w, "## Errors\n")?;
}
failed_count += 1;
writeln!(w, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
}
}
}
if successful.len() > 1 {
let edit_scores = successful
.iter()
.filter_map(|r| r.edit_scores.clone())
.collect::<Vec<_>>();
let has_edit_predictions = edit_scores.len() > 0;
let aggregated_result = EvaluationResult {
context_scores: Scores::aggregate(successful.iter().map(|r| &r.context_scores)),
edit_scores: has_edit_predictions.then(|| EditScores::aggregate(&edit_scores)),
prompt_len: successful.iter().map(|r| r.prompt_len).sum::<usize>() / successful.len(),
generated_len: successful.iter().map(|r| r.generated_len).sum::<usize>()
/ successful.len(),
};
writeln!(w, "\n{}", "-".repeat(80))?;
writeln!(w, "\n## TOTAL SCORES")?;
writeln!(w, "{:#}", aggregated_result)?;
}
if successful.len() + failed_count > 1 {
writeln!(
w,
"\nCongratulations! {}/{} ({:.2}%) of runs weren't outright failures 🎉",
successful.len(),
successful.len() + failed_count,
(successful.len() as f64 / (successful.len() + failed_count) as f64) * 100.0
)?;
}
Ok(())
}
pub async fn run_evaluate_one(
example: NamedExample,
repetition_ix: Option<u16>,
project: Entity<Project>,
store: Entity<EditPredictionStore>,
prediction_options: PredictionOptions,
predict: bool,
cx: &mut AsyncApp,
) -> Result<(EvaluationResult, ExecutionData)> {
let predict_result = perform_predict(
example.clone(),
project,
store,
repetition_ix,
prediction_options,
cx,
)
.await?;
let evaluation_result = evaluate(&example.example, &predict_result, predict);
if repetition_ix.is_none() {
write_eval_result(
&example,
&predict_result,
&evaluation_result,
&mut std::io::stdout(),
std::io::stdout().is_terminal(),
predict,
)?;
}
if let Some(mut results_file) =
std::fs::File::create(predict_result.run_example_dir.join("results.md")).log_err()
{
write_eval_result(
&example,
&predict_result,
&evaluation_result,
&mut results_file,
false,
predict,
)
.log_err();
}
let execution_data = ExecutionData {
execution_id: if let Some(rep_ix) = repetition_ix {
format!("{:03}", rep_ix)
} else {
example.name.clone()
},
diff: predict_result.diff.clone(),
reasoning: std::fs::read_to_string(
predict_result
.run_example_dir
.join("prediction_response.md"),
)
.unwrap_or_default(),
};
anyhow::Ok((evaluation_result, execution_data))
}
fn write_eval_result(
example: &NamedExample,
predictions: &PredictionDetails,
evaluation_result: &EvaluationResult,
out: &mut impl Write,
use_color: bool,
predict: bool,
) -> Result<()> {
if predict {
writeln!(
out,
"## Expected edit prediction:\n\n```diff\n{}\n```\n",
compare_diffs(
&example.example.expected_patch,
&predictions.diff,
use_color
)
)?;
writeln!(
out,
"## Actual edit prediction:\n\n```diff\n{}\n```\n",
compare_diffs(
&predictions.diff,
&example.example.expected_patch,
use_color
)
)?;
}
writeln!(out, "{:#}", evaluation_result)?;
anyhow::Ok(())
}
#[derive(Debug, Default, Clone)]
pub struct EditScores {
pub line_match: Scores,
pub chr_f: f64,
}
impl EditScores {
pub fn aggregate(scores: &[EditScores]) -> EditScores {
let line_match = Scores::aggregate(scores.iter().map(|s| &s.line_match));
let chr_f = scores.iter().map(|s| s.chr_f).sum::<f64>() / scores.len() as f64;
EditScores { line_match, chr_f }
}
}
#[derive(Debug, Default)]
pub struct EvaluationResult {
pub edit_scores: Option<EditScores>,
pub context_scores: Scores,
pub prompt_len: usize,
pub generated_len: usize,
}
impl std::fmt::Display for EvaluationResult {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if f.alternate() {
self.fmt_table(f)
} else {
self.fmt_markdown(f)
}
}
}
impl EvaluationResult {
fn fmt_markdown(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(
f,
r#"
### Context Scores
{}
"#,
self.context_scores.to_markdown(),
)?;
if let Some(scores) = &self.edit_scores {
write!(
f,
r#"
### Edit Prediction Scores
{}"#,
scores.line_match.to_markdown()
)?;
}
Ok(())
}
fn fmt_table(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
writeln!(f, "#### Prompt Statistics")?;
writeln!(f, "─────────────────────────")?;
writeln!(f, "Prompt_len Generated_len")?;
writeln!(f, "─────────────────────────")?;
writeln!(f, "{:<11} {:<14}", self.prompt_len, self.generated_len,)?;
writeln!(f)?;
writeln!(f)?;
writeln!(f, "#### Performance Scores")?;
writeln!(
f,
"──────────────────────────────────────────────────────────────────"
)?;
writeln!(
f,
" TP FP FN Precision Recall F1"
)?;
writeln!(
f,
"──────────────────────────────────────────────────────────────────"
)?;
writeln!(
f,
"Context Retrieval {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
self.context_scores.true_positives,
self.context_scores.false_positives,
self.context_scores.false_negatives,
self.context_scores.precision() * 100.0,
self.context_scores.recall() * 100.0,
self.context_scores.f1_score() * 100.0
)?;
if let Some(edit_scores) = &self.edit_scores {
let line_match = &edit_scores.line_match;
writeln!(f, "Edit Prediction")?;
writeln!(
f,
" ├─ exact lines {:<6} {:<6} {:<6} {:>8.2} {:>7.2} {:>6.2}",
line_match.true_positives,
line_match.false_positives,
line_match.false_negatives,
line_match.precision() * 100.0,
line_match.recall() * 100.0,
line_match.f1_score() * 100.0
)?;
writeln!(
f,
" └─ diff chrF {:<6} {:<6} {:<6} {:>8} {:>8} {:>6.2}",
"-", "-", "-", "-", "-", edit_scores.chr_f
)?;
}
Ok(())
}
}
fn evaluate(example: &Example, preds: &PredictionDetails, predict: bool) -> EvaluationResult {
let mut eval_result = EvaluationResult {
prompt_len: preds.prompt_len,
generated_len: preds.generated_len,
..Default::default()
};
if predict {
// todo: alternatives for patches
let expected_patch = example
.expected_patch
.lines()
.map(DiffLine::parse)
.collect::<Vec<_>>();
let actual_patch = preds.diff.lines().map(DiffLine::parse).collect::<Vec<_>>();
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
let chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch);
eval_result.edit_scores = Some(EditScores { line_match, chr_f });
}
eval_result
}
/// Return annotated `patch_a` so that:
/// Additions and deletions that are not present in `patch_b` will be highlighted in red.
/// Additions and deletions that are present in `patch_b` will be highlighted in green.
pub fn compare_diffs(patch_a: &str, patch_b: &str, use_color: bool) -> String {
let green = if use_color { "\x1b[32m✓ " } else { "" };
let red = if use_color { "\x1b[31m✗ " } else { "" };
let neutral = if use_color { " " } else { "" };
let reset = if use_color { "\x1b[0m" } else { "" };
let lines_a = patch_a.lines().map(DiffLine::parse);
let lines_b: Vec<_> = patch_b.lines().map(DiffLine::parse).collect();
let annotated = lines_a
.map(|line| match line {
DiffLine::Addition(_) | DiffLine::Deletion(_) => {
if lines_b.contains(&line) {
format!("{green}{line}{reset}")
} else {
format!("{red}{line}{reset}")
}
}
_ => format!("{neutral}{line}{reset}"),
})
.collect::<Vec<String>>();
annotated.join("\n")
}
fn write_bucketed_analysis(
all_results: &Vec<
Vec<Result<(EvaluationResult, ExecutionData), (anyhow::Error, String, Option<u16>)>>,
>,
) -> Result<()> {
#[derive(Debug)]
struct EditBucket {
diff: String,
is_correct: bool,
execution_indices: Vec<String>,
reasoning_samples: Vec<String>,
}
let mut total_executions = 0;
let mut empty_predictions = Vec::new();
let mut errors = Vec::new();
let mut buckets: HashMap<String, EditBucket> = HashMap::new();
for result in all_results.iter().flatten() {
total_executions += 1;
let (evaluation_result, execution_data) = match result {
Ok((eval_result, execution_data)) => {
if execution_data.diff.is_empty() {
empty_predictions.push(execution_data);
continue;
}
(eval_result, execution_data)
}
Err(err) => {
errors.push(err);
continue;
}
};
buckets
.entry(execution_data.diff.clone())
.and_modify(|bucket| {
bucket
.execution_indices
.push(execution_data.execution_id.clone());
bucket
.reasoning_samples
.push(execution_data.reasoning.clone());
})
.or_insert_with(|| EditBucket {
diff: execution_data.diff.clone(),
is_correct: {
evaluation_result
.edit_scores
.as_ref()
.map_or(false, |edit_scores| {
edit_scores.line_match.false_positives == 0
&& edit_scores.line_match.false_negatives == 0
&& edit_scores.line_match.true_positives > 0
})
},
execution_indices: vec![execution_data.execution_id.clone()],
reasoning_samples: vec![execution_data.reasoning.clone()],
});
}
let mut sorted_buckets = buckets.into_values().collect::<Vec<_>>();
sorted_buckets.sort_by(|a, b| match (a.is_correct, b.is_correct) {
(true, false) => std::cmp::Ordering::Less,
(false, true) => std::cmp::Ordering::Greater,
_ => b.execution_indices.len().cmp(&a.execution_indices.len()),
});
let output_path = crate::paths::RUN_DIR.join("bucketed_analysis.md");
let mut output = std::fs::File::create(&output_path)?;
writeln!(output, "# Bucketed Edit Analysis\n")?;
writeln!(output, "## Summary\n")?;
writeln!(output, "- **Total executions**: {}", total_executions)?;
let correct_count: usize = sorted_buckets
.iter()
.filter(|b| b.is_correct)
.map(|b| b.execution_indices.len())
.sum();
let incorrect_count: usize = sorted_buckets
.iter()
.filter(|b| !b.is_correct)
.map(|b| b.execution_indices.len())
.sum();
writeln!(
output,
"- **Correct predictions**: {} ({:.1}%)",
correct_count,
(correct_count as f64 / total_executions as f64) * 100.0
)?;
writeln!(
output,
"- **Incorrect predictions**: {} ({:.1}%)",
incorrect_count,
(incorrect_count as f64 / total_executions as f64) * 100.0
)?;
writeln!(
output,
"- **No Predictions**: {} ({:.1}%)",
empty_predictions.len(),
(empty_predictions.len() as f64 / total_executions as f64) * 100.0
)?;
let unique_incorrect = sorted_buckets.iter().filter(|b| !b.is_correct).count();
writeln!(
output,
"- **Unique incorrect edit patterns**: {}\n",
unique_incorrect
)?;
writeln!(output, "---\n")?;
for (idx, bucket) in sorted_buckets.iter().filter(|b| b.is_correct).enumerate() {
if idx == 0 {
writeln!(
output,
"## Correct Predictions ({} occurrences)\n",
bucket.execution_indices.len()
)?;
}
writeln!(output, "**Predicted Edit:**\n")?;
writeln!(output, "```diff")?;
writeln!(output, "{}", bucket.diff)?;
writeln!(output, "```\n")?;
writeln!(
output,
"**Executions:** {}\n",
bucket.execution_indices.join(", ")
)?;
writeln!(output, "---\n")?;
}
for (idx, bucket) in sorted_buckets.iter().filter(|b| !b.is_correct).enumerate() {
writeln!(
output,
"## Incorrect Prediction #{} ({} occurrences)\n",
idx + 1,
bucket.execution_indices.len()
)?;
writeln!(output, "**Predicted Edit:**\n")?;
writeln!(output, "```diff")?;
writeln!(output, "{}", bucket.diff)?;
writeln!(output, "```\n")?;
writeln!(
output,
"**Executions:** {}\n",
bucket.execution_indices.join(", ")
)?;
for (exec_id, reasoning) in bucket
.execution_indices
.iter()
.zip(bucket.reasoning_samples.iter())
{
writeln!(output, "{}", fmt_execution(exec_id, reasoning))?;
}
writeln!(output, "\n---\n")?;
}
if !empty_predictions.is_empty() {
writeln!(
output,
"## No Predictions ({} occurrences)\n",
empty_predictions.len()
)?;
for execution_data in &empty_predictions {
writeln!(
output,
"{}",
fmt_execution(&execution_data.execution_id, &execution_data.reasoning)
)?;
}
writeln!(output, "\n---\n")?;
}
if !errors.is_empty() {
writeln!(output, "## Errors ({} occurrences)\n", errors.len())?;
for (err, name, repetition_ix) in &errors {
writeln!(output, "{}", fmt_evaluation_error(err, name, repetition_ix))?;
}
writeln!(output, "\n---\n")?;
}
fn fmt_execution(exec_id: &str, reasoning: &str) -> String {
let exec_content = format!(
"\n### Execution {} `{}/{}/prediction_response.md`{}",
exec_id,
crate::paths::RUN_DIR.display(),
exec_id,
indent_text(&format!("\n\n```\n{}\n```\n", reasoning,), 2)
);
indent_text(&exec_content, 2)
}
fn indent_text(text: &str, spaces: usize) -> String {
let indent = " ".repeat(spaces);
text.lines()
.collect::<Vec<_>>()
.join(&format!("\n{}", indent))
}
Ok(())
}
fn fmt_evaluation_error(err: &anyhow::Error, name: &str, repetition_ix: &Option<u16>) -> String {
let err = format!("{err:?}")
.replace("<edits", "```xml\n<edits")
.replace("</edits>", "</edits>\n```");
format!(
"### ERROR {name}{}\n\n{err}\n",
repetition_ix
.map(|ix| format!(" [RUN {ix:03}]"))
.unwrap_or_default()
)
}
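
The removed file reports precision, recall and F1 from true-positive, false-positive and false-negative counts, and `write_bucketed_analysis` treats a prediction as correct when it has no false positives, no false negatives and at least one true positive. The implementation of `Scores` is not part of this diff, so the following is a sketch using the standard definitions. The `Counts` type and the choice to return 0.0 for empty denominators are assumptions.

```rust
/// Hypothetical stand-in for the `Scores` type; only the counts are modeled.
struct Counts {
    true_positives: usize,
    false_positives: usize,
    false_negatives: usize,
}

impl Counts {
    /// TP / (TP + FP); 0.0 when nothing was predicted (assumption).
    fn precision(&self) -> f64 {
        let denominator = self.true_positives + self.false_positives;
        if denominator == 0 {
            0.0
        } else {
            self.true_positives as f64 / denominator as f64
        }
    }

    /// TP / (TP + FN); 0.0 when nothing was expected (assumption).
    fn recall(&self) -> f64 {
        let denominator = self.true_positives + self.false_negatives;
        if denominator == 0 {
            0.0
        } else {
            self.true_positives as f64 / denominator as f64
        }
    }

    /// Harmonic mean of precision and recall.
    fn f1_score(&self) -> f64 {
        let (p, r) = (self.precision(), self.recall());
        if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
    }

    /// Mirrors the `is_correct` condition used when bucketing predictions.
    fn is_exact_match(&self) -> bool {
        self.false_positives == 0 && self.false_negatives == 0 && self.true_positives > 0
    }
}
```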


@@ -1,59 +1,103 @@
use crate::{
PredictionProvider, PromptFormat,
metrics::ClassificationMetrics,
paths::{REPOS_DIR, WORKTREES_DIR},
};
use anyhow::{Context as _, Result};
use edit_prediction::udiff::OpenedBuffers;
use gpui::Entity;
use http_client::Url;
use language::{Anchor, Buffer};
use project::Project;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::{
borrow::Cow,
cell::RefCell,
fmt::{self, Display},
fs,
hash::Hash,
hash::Hasher,
io::Write,
io::{Read, Write},
mem,
path::{Path, PathBuf},
sync::{Arc, OnceLock},
};
use zeta_prompt::RelatedFile;
use crate::headless::ZetaCliAppState;
use anyhow::{Context as _, Result, anyhow};
use clap::ValueEnum;
use cloud_zeta2_prompt::CURSOR_MARKER;
use collections::HashMap;
use edit_prediction::udiff::OpenedBuffers;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
};
use futures::{FutureExt as _, future::Shared};
use gpui::{AsyncApp, Entity, Task, http_client::Url};
use language::{Anchor, Buffer};
use project::{Project, ProjectPath};
use pulldown_cmark::CowStr;
use serde::{Deserialize, Serialize};
use util::{paths::PathStyle, rel_path::RelPath};
use crate::paths::{REPOS_DIR, WORKTREES_DIR};
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
#[derive(Debug, Clone)]
pub struct NamedExample {
pub name: String,
pub example: Example,
}
#[derive(Clone, Debug, Hash, Serialize, Deserialize)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Example {
#[serde(default)]
pub name: String,
pub repository_url: String,
pub revision: String,
pub uncommitted_diff: String,
pub cursor_path: PathBuf,
pub cursor_path: Arc<Path>,
pub cursor_position: String,
pub edit_history: String,
pub expected_patch: String,
/// The full content of the file where an edit is being predicted, and the
/// actual cursor offset.
#[serde(skip_serializing_if = "Option::is_none")]
pub buffer: Option<ExampleBuffer>,
/// The context retrieved for the prediction. This requires the worktree to
/// be loaded and the language server to be started.
#[serde(skip_serializing_if = "Option::is_none")]
pub context: Option<ExampleContext>,
/// The input and expected output from the edit prediction model.
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt: Option<ExamplePrompt>,
/// The actual predictions from the model.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub predictions: Vec<ExamplePrediction>,
/// The scores, for how well the actual predictions match the expected
/// predictions.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub score: Vec<ExampleScore>,
/// The application state used to process this example.
#[serde(skip)]
pub state: Option<ExampleState>,
}
#[derive(Clone, Debug)]
pub struct ExampleState {
pub project: Entity<Project>,
pub buffer: Entity<Buffer>,
pub cursor_position: Anchor,
pub _open_buffers: OpenedBuffers,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleContext {
pub files: Arc<[RelatedFile]>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleBuffer {
pub content: String,
pub cursor_row: u32,
pub cursor_column: u32,
pub cursor_offset: usize,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrompt {
pub input: String,
pub expected_output: String,
pub format: PromptFormat,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExamplePrediction {
pub actual_patch: String,
pub actual_output: String,
pub provider: PredictionProvider,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ExampleScore {
pub delta_chr_f: f32,
pub line_match: ClassificationMetrics,
}
impl Example {
@@ -90,485 +134,244 @@ impl Example {
}
}
pub async fn setup_worktree(&self, file_name: String) -> Result<PathBuf> {
let (repo_owner, repo_name) = self.repo_name()?;
pub fn worktree_path(&self) -> PathBuf {
WORKTREES_DIR
.join(&self.name)
.join(self.repo_name().unwrap().1.as_ref())
}
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
let repo_lock = lock_repo(&repo_dir).await;
pub fn repo_path(&self) -> PathBuf {
let (repo_owner, repo_name) = self.repo_name().expect("failed to get repo name");
REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref())
}
}
if !repo_dir.is_dir() {
fs::create_dir_all(&repo_dir)?;
run_git(&repo_dir, &["init"]).await?;
run_git(
&repo_dir,
&["remote", "add", "origin", &self.repository_url],
)
.await?;
}
pub fn read_examples(inputs: &[PathBuf]) -> Vec<Example> {
let mut examples = Vec::new();
// Resolve the example to a revision, fetching it if needed.
let revision = run_git(
&repo_dir,
&["rev-parse", &format!("{}^{{commit}}", self.revision)],
)
.await;
let revision = if let Ok(revision) = revision {
revision
let stdin_path: PathBuf = PathBuf::from("-");
let inputs = if inputs.is_empty() {
&[stdin_path]
} else {
inputs
};
for path in inputs {
let is_stdin = path.as_path() == Path::new("-");
let content = if is_stdin {
let mut buffer = String::new();
std::io::stdin()
.read_to_string(&mut buffer)
.expect("Failed to read from stdin");
buffer
} else {
if run_git(
&repo_dir,
&["fetch", "--depth", "1", "origin", &self.revision],
)
.await
.is_err()
{
run_git(&repo_dir, &["fetch", "origin"]).await?;
}
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
if revision != self.revision {
run_git(&repo_dir, &["tag", &self.revision, &revision]).await?;
}
revision
std::fs::read_to_string(path)
.unwrap_or_else(|_| panic!("Failed to read path: {:?}", &path))
};
let filename = path.file_stem().unwrap().to_string_lossy().to_string();
let ext = if !is_stdin {
path.extension()
.map(|ext| ext.to_string_lossy().to_string())
.unwrap_or_else(|| panic!("{} should have an extension", path.display()))
} else {
"jsonl".to_string()
};
// Create the worktree for this example if needed.
let worktree_path = WORKTREES_DIR.join(&file_name).join(repo_name.as_ref());
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
} else {
let worktree_path_string = worktree_path.to_string_lossy();
run_git(&repo_dir, &["branch", "-f", &file_name, revision.as_str()]).await?;
run_git(
&repo_dir,
&["worktree", "add", "-f", &worktree_path_string, &file_name],
)
.await?;
}
drop(repo_lock);
// Apply the uncommitted diff for this example.
if !self.uncommitted_diff.is_empty() {
let mut apply_process = smol::process::Command::new("git")
.current_dir(&worktree_path)
.args(&["apply", "-"])
.stdin(std::process::Stdio::piped())
.spawn()?;
let mut stdin = apply_process.stdin.take().unwrap();
stdin.write_all(self.uncommitted_diff.as_bytes()).await?;
stdin.close().await?;
drop(stdin);
let apply_result = apply_process.output().await?;
if !apply_result.status.success() {
anyhow::bail!(
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
match ext.as_ref() {
"json" => {
let mut example =
serde_json::from_str::<Example>(&content).unwrap_or_else(|error| {
panic!("Failed to parse example file: {}\n{error}", path.display())
});
if example.name.is_empty() {
example.name = filename;
}
examples.push(example);
}
"jsonl" => examples.extend(
content
.lines()
.enumerate()
.map(|(line_ix, line)| {
let mut example =
serde_json::from_str::<Example>(line).unwrap_or_else(|_| {
panic!(
"Failed to parse example on {}:{}",
path.display(),
line_ix + 1
)
});
if example.name.is_empty() {
example.name = format!("{filename}-{line_ix}")
}
example
})
.collect::<Vec<Example>>(),
),
"md" => {
examples.push(parse_markdown_example(filename, &content).unwrap());
}
ext => {
panic!("{} has invalid example extension `{ext}`", path.display())
}
}
Ok(worktree_path)
}
examples
}
pub fn unique_name(&self) -> String {
let mut hasher = std::hash::DefaultHasher::new();
self.hash(&mut hasher);
let disambiguator = hasher.finish();
let hash = format!("{:04x}", disambiguator);
format!("{}_{}", &self.revision[..8], &hash[..4])
pub fn write_examples(examples: &[Example], output_path: Option<&PathBuf>) {
let mut content = String::new();
for example in examples {
let line = serde_json::to_string(example).unwrap();
content.push_str(&line);
content.push('\n');
}
if let Some(output_path) = output_path {
std::fs::write(output_path, content).expect("Failed to write examples");
} else {
std::io::stdout().write_all(&content.as_bytes()).unwrap();
}
}
pub type ActualExcerpt = Excerpt;
fn parse_markdown_example(id: String, input: &str) -> Result<Example> {
use pulldown_cmark::{CodeBlockKind, CowStr, Event, HeadingLevel, Parser, Tag, TagEnd};
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Excerpt {
pub path: PathBuf,
pub text: String,
}
const UNCOMMITTED_DIFF_HEADING: &str = "Uncommitted Diff";
const EDIT_HISTORY_HEADING: &str = "Edit History";
const CURSOR_POSITION_HEADING: &str = "Cursor Position";
const EXPECTED_PATCH_HEADING: &str = "Expected Patch";
const EXPECTED_CONTEXT_HEADING: &str = "Expected Context";
const REPOSITORY_URL_FIELD: &str = "repository_url";
const REVISION_FIELD: &str = "revision";
#[derive(ValueEnum, Debug, Clone)]
pub enum ExampleFormat {
Json,
Toml,
Md,
}
let parser = Parser::new(input);
impl NamedExample {
pub fn load(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let content = std::fs::read_to_string(path)?;
let ext = path.extension();
let mut example = Example {
name: id,
repository_url: String::new(),
revision: String::new(),
uncommitted_diff: String::new(),
cursor_path: PathBuf::new().into(),
cursor_position: String::new(),
edit_history: String::new(),
expected_patch: String::new(),
buffer: None,
context: None,
prompt: None,
predictions: Vec::new(),
score: Vec::new(),
state: None,
};
match ext.and_then(|s| s.to_str()) {
Some("json") => Ok(Self {
name: path.file_stem().unwrap_or_default().display().to_string(),
example: serde_json::from_str(&content)?,
}),
Some("toml") => Ok(Self {
name: path.file_stem().unwrap_or_default().display().to_string(),
example: toml::from_str(&content)?,
}),
Some("md") => Self::parse_md(&content),
Some(_) => {
anyhow::bail!("Unrecognized example extension: {}", ext.unwrap().display());
}
None => {
anyhow::bail!(
"Failed to determine example type since the file does not have an extension."
);
}
}
let mut name = String::new();
let mut text = String::new();
let mut block_info: CowStr = "".into();
#[derive(PartialEq)]
enum Section {
UncommittedDiff,
EditHistory,
CursorPosition,
ExpectedExcerpts,
ExpectedPatch,
Other,
}
pub fn parse_md(input: &str) -> Result<Self> {
use pulldown_cmark::{CodeBlockKind, Event, HeadingLevel, Parser, Tag, TagEnd};
let mut current_section = Section::Other;
let parser = Parser::new(input);
for event in parser {
match event {
Event::Text(line) => {
text.push_str(&line);
let mut named = NamedExample {
name: String::new(),
example: Example {
repository_url: String::new(),
revision: String::new(),
uncommitted_diff: String::new(),
cursor_path: PathBuf::new(),
cursor_position: String::new(),
edit_history: String::new(),
expected_patch: String::new(),
},
};
let mut text = String::new();
let mut block_info: CowStr = "".into();
#[derive(PartialEq)]
enum Section {
UncommittedDiff,
EditHistory,
CursorPosition,
ExpectedExcerpts,
ExpectedPatch,
Other,
}
let mut current_section = Section::Other;
for event in parser {
match event {
Event::Text(line) => {
text.push_str(&line);
if !named.name.is_empty()
&& current_section == Section::Other
// in h1 section
&& let Some((field, value)) = line.split_once('=')
{
match field.trim() {
REPOSITORY_URL_FIELD => {
named.example.repository_url = value.trim().to_string();
}
REVISION_FIELD => {
named.example.revision = value.trim().to_string();
}
_ => {}
if let Some((field, value)) = line.split_once('=') {
match field.trim() {
REPOSITORY_URL_FIELD => {
example.repository_url = value.trim().to_string();
}
REVISION_FIELD => {
example.revision = value.trim().to_string();
}
_ => {}
}
}
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
if !named.name.is_empty() {
anyhow::bail!(
"Found multiple H1 headings. There should only be one with the name of the example."
);
}
named.name = mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
let title = mem::take(&mut text);
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
Section::UncommittedDiff
} else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
Section::EditHistory
} else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
Section::CursorPosition
} else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
Section::ExpectedPatch
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
Section::ExpectedExcerpts
} else {
Section::Other
};
}
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
mem::take(&mut text);
}
Event::End(TagEnd::Heading(level)) => {
anyhow::bail!("Unexpected heading level: {level}");
}
Event::Start(Tag::CodeBlock(kind)) => {
match kind {
CodeBlockKind::Fenced(info) => {
block_info = info;
}
CodeBlockKind::Indented => {
anyhow::bail!("Unexpected indented codeblock");
}
};
}
Event::Start(_) => {
text.clear();
block_info = "".into();
}
Event::End(TagEnd::CodeBlock) => {
let block_info = block_info.trim();
match current_section {
Section::UncommittedDiff => {
named.example.uncommitted_diff = mem::take(&mut text);
}
Section::EditHistory => {
named.example.edit_history.push_str(&mem::take(&mut text));
}
Section::CursorPosition => {
named.example.cursor_path = block_info.into();
named.example.cursor_position = mem::take(&mut text);
}
Section::ExpectedExcerpts => {
mem::take(&mut text);
}
Section::ExpectedPatch => {
named.example.expected_patch = mem::take(&mut text);
}
Section::Other => {}
}
}
_ => {}
}
}
if named.example.cursor_path.as_path() == Path::new("")
|| named.example.cursor_position.is_empty()
{
anyhow::bail!("Missing cursor position codeblock");
}
Ok(named)
}
pub fn write(&self, format: ExampleFormat, mut out: impl Write) -> Result<()> {
match format {
ExampleFormat::Json => Ok(serde_json::to_writer(out, &self.example)?),
ExampleFormat::Toml => {
Ok(out.write_all(toml::to_string_pretty(&self.example)?.as_bytes())?)
Event::End(TagEnd::Heading(HeadingLevel::H1)) => {
if !name.is_empty() {
anyhow::bail!(
"Found multiple H1 headings. There should only be one with the name of the example."
);
}
name = mem::take(&mut text);
}
ExampleFormat::Md => Ok(write!(out, "{}", self)?),
}
}
pub async fn setup_project(
&self,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<Entity<Project>> {
let worktree_path = self.setup_worktree().await?;
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
AUTHENTICATED
.get_or_init(|| {
let client = app_state.client.clone();
cx.spawn(async move |cx| {
client
.sign_in_with_optional_connect(true, cx)
.await
.unwrap();
})
.shared()
})
.clone()
.await;
let project = cx.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})?;
let worktree = project
.update(cx, |project, cx| {
project.create_worktree(&worktree_path, true, cx)
})?
.await?;
worktree
.read_with(cx, |worktree, _cx| {
worktree.as_local().unwrap().scan_complete()
})?
.await;
anyhow::Ok(project)
}
pub async fn setup_worktree(&self) -> Result<PathBuf> {
self.example.setup_worktree(self.file_name()).await
}
pub fn file_name(&self) -> String {
self.name
.chars()
.map(|c| {
if c.is_whitespace() {
'-'
Event::End(TagEnd::Heading(HeadingLevel::H2)) => {
let title = mem::take(&mut text);
current_section = if title.eq_ignore_ascii_case(UNCOMMITTED_DIFF_HEADING) {
Section::UncommittedDiff
} else if title.eq_ignore_ascii_case(EDIT_HISTORY_HEADING) {
Section::EditHistory
} else if title.eq_ignore_ascii_case(CURSOR_POSITION_HEADING) {
Section::CursorPosition
} else if title.eq_ignore_ascii_case(EXPECTED_PATCH_HEADING) {
Section::ExpectedPatch
} else if title.eq_ignore_ascii_case(EXPECTED_CONTEXT_HEADING) {
Section::ExpectedExcerpts
} else {
c.to_ascii_lowercase()
Section::Other
};
}
Event::End(TagEnd::Heading(HeadingLevel::H3)) => {
mem::take(&mut text);
}
Event::End(TagEnd::Heading(HeadingLevel::H4)) => {
mem::take(&mut text);
}
Event::End(TagEnd::Heading(level)) => {
anyhow::bail!("Unexpected heading level: {level}");
}
Event::Start(Tag::CodeBlock(kind)) => {
match kind {
CodeBlockKind::Fenced(info) => {
block_info = info;
}
CodeBlockKind::Indented => {
anyhow::bail!("Unexpected indented codeblock");
}
};
}
Event::Start(_) => {
text.clear();
block_info = "".into();
}
Event::End(TagEnd::CodeBlock) => {
let block_info = block_info.trim();
match current_section {
Section::UncommittedDiff => {
example.uncommitted_diff = mem::take(&mut text);
}
Section::EditHistory => {
example.edit_history.push_str(&mem::take(&mut text));
}
Section::CursorPosition => {
example.cursor_path = Path::new(block_info).into();
example.cursor_position = mem::take(&mut text);
}
Section::ExpectedExcerpts => {
mem::take(&mut text);
}
Section::ExpectedPatch => {
example.expected_patch = mem::take(&mut text);
}
Section::Other => {}
}
})
.collect()
}
pub async fn cursor_position(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<(Entity<Buffer>, Anchor)> {
let worktree = project.read_with(cx, |project, cx| {
project.visible_worktrees(cx).next().unwrap()
})?;
let cursor_path = RelPath::new(&self.example.cursor_path, PathStyle::Posix)?.into_arc();
let cursor_buffer = project
.update(cx, |project, cx| {
project.open_buffer(
ProjectPath {
worktree_id: worktree.read(cx).id(),
path: cursor_path,
},
cx,
)
})?
.await?;
let cursor_offset_within_excerpt = self
.example
.cursor_position
.find(CURSOR_MARKER)
.ok_or_else(|| anyhow!("missing cursor marker"))?;
let mut cursor_excerpt = self.example.cursor_position.clone();
cursor_excerpt.replace_range(
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
"",
);
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let mut matches = text.match_indices(&cursor_excerpt);
let Some((excerpt_offset, _)) = matches.next() else {
anyhow::bail!(
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
);
};
assert!(matches.next().is_none());
Ok(excerpt_offset)
})??;
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
let cursor_anchor =
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
Ok((cursor_buffer, cursor_anchor))
}
#[must_use]
pub async fn apply_edit_history(
&self,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers<'_>> {
edit_prediction::udiff::apply_diff(&self.example.edit_history, project, cx).await
}
}
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = smol::process::Command::new("git")
.current_dir(repo_path)
.args(args)
.output()
.await?;
anyhow::ensure!(
output.status.success(),
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
args.join(" "),
repo_path.display(),
output.status,
String::from_utf8_lossy(&output.stderr),
String::from_utf8_lossy(&output.stdout),
);
Ok(String::from_utf8(output.stdout)?.trim().to_string())
}
impl Display for NamedExample {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "# {}\n\n", self.name)?;
write!(
f,
"{REPOSITORY_URL_FIELD} = {}\n",
self.example.repository_url
)?;
write!(f, "{REVISION_FIELD} = {}\n\n", self.example.revision)?;
write!(f, "## {UNCOMMITTED_DIFF_HEADING}\n\n")?;
write!(f, "`````diff\n")?;
write!(f, "{}", self.example.uncommitted_diff)?;
write!(f, "`````\n")?;
if !self.example.edit_history.is_empty() {
write!(f, "`````diff\n{}`````\n", self.example.edit_history)?;
}
_ => {}
}
write!(
f,
"## {CURSOR_POSITION_HEADING}\n\n`````{}\n{}`````\n",
self.example.cursor_path.display(),
self.example.cursor_position
)?;
write!(f, "## {EDIT_HISTORY_HEADING}\n\n")?;
if !self.example.expected_patch.is_empty() {
write!(
f,
"\n## {EXPECTED_PATCH_HEADING}\n\n`````diff\n{}`````\n",
self.example.expected_patch
)?;
}
Ok(())
}
}
if example.cursor_path.as_ref() == Path::new("") || example.cursor_position.is_empty() {
anyhow::bail!("Missing cursor position codeblock");
}
thread_local! {
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
}
#[must_use]
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
REPO_LOCKS
.with(|cell| {
cell.borrow_mut()
.entry(path.as_ref().to_path_buf())
.or_default()
.clone()
})
.lock_owned()
.await
Ok(example)
}


@@ -0,0 +1,280 @@
use crate::{
PromptFormat,
example::{Example, ExamplePrompt},
headless::EpAppState,
retrieve_context::run_context_retrieval,
};
use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
use gpui::AsyncApp;
use std::sync::Arc;
use zeta_prompt::format_zeta_prompt;
pub async fn run_format_prompt(
example: &mut Example,
prompt_format: PromptFormat,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) {
run_context_retrieval(example, app_state, cx.clone()).await;
let prompt = match prompt_format {
PromptFormat::Teacher => TeacherPrompt::format(example),
PromptFormat::Zeta2 => {
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
let state = example.state.as_ref().unwrap();
let snapshot = state
.buffer
.read_with(&cx, |buffer, _| buffer.snapshot())
.unwrap();
let project = state.project.clone();
let (_, input) = ep_store
.update(&mut cx, |ep_store, _cx| {
zeta2_prompt_input(
&snapshot,
example.context.as_ref().unwrap().files.clone(),
ep_store.edit_history_for_project(&project),
example.cursor_path.clone(),
example.buffer.as_ref().unwrap().cursor_offset,
)
})
.unwrap();
format_zeta_prompt(&input)
}
};
example.prompt = Some(ExamplePrompt {
input: prompt,
expected_output: example.expected_patch.clone(), // TODO
format: prompt_format,
});
}
pub trait PromptFormatter {
fn format(example: &Example) -> String;
}
pub trait PromptParser {
/// Returns the prediction as a unified diff patch, given the raw LLM response.
fn parse(example: &Example, response: &str) -> String;
}
pub struct TeacherPrompt;
impl PromptFormatter for TeacherPrompt {
fn format(example: &Example) -> String {
let edit_history = Self::format_edit_history(&example.edit_history);
let context = Self::format_context(example);
let editable_region = Self::format_editable_region(example);
let prompt = Self::PROMPT
.replace("{{context}}", &context)
.replace("{{edit_history}}", &edit_history)
.replace("{{editable_region}}", &editable_region);
prompt
}
}
impl TeacherPrompt {
const PROMPT: &str = include_str!("teacher.prompt.md");
pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
/// Truncate edit history to at most this many trailing lines.
const MAX_HISTORY_LINES: usize = 128;
fn format_edit_history(edit_history: &str) -> String {
// Strip comments ("garbage lines") from edit history
let lines = edit_history
.lines()
.filter(|&s| Self::is_udiff_content_line(s))
.collect::<Vec<_>>();
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
} else {
&lines
};
if history_lines.is_empty() {
return "(No edit history)".to_string();
}
history_lines.join("\n")
}
fn format_context(example: &Example) -> String {
if example.context.is_none() {
panic!("Missing context retriever step");
}
let mut prompt = String::new();
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
prompt
}
fn format_editable_region(example: &Example) -> String {
let mut result = String::new();
let path_str = example.cursor_path.to_string_lossy();
result.push_str(&format!("`````path=\"{path_str}\"\n"));
result.push_str(Self::EDITABLE_REGION_START);
// TODO: control number of lines around cursor
result.push_str(&example.cursor_position);
if !example.cursor_position.ends_with('\n') {
result.push('\n');
}
result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
result.push_str("`````");
result
}
fn extract_editable_region(text: &str) -> String {
let start = text
.find(Self::EDITABLE_REGION_START)
.map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
let region = &text[start..end];
region.replace("<|user_cursor|>", "")
}
fn is_udiff_content_line(s: &str) -> bool {
s.starts_with("-")
|| s.starts_with("+")
|| s.starts_with(" ")
|| s.starts_with("---")
|| s.starts_with("+++")
|| s.starts_with("@@")
}
}
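Review note: the filter-then-truncate behaviour of `format_edit_history` can be sketched in isolation as follows. This is a minimal standalone sketch, not the code in this PR; `tail_of_edit_history` is a hypothetical name, and the line limit is passed as a parameter here only to make the behaviour easy to exercise. Since every line starting with `---` or `+++` also starts with `-` or `+`, the sketch omits those two checks.

```rust
/// Keeps only lines that look like unified-diff content.
/// Lines starting with "---" or "+++" are already covered by the
/// single-character '-' and '+' checks.
fn is_udiff_content_line(line: &str) -> bool {
    line.starts_with('-')
        || line.starts_with('+')
        || line.starts_with(' ')
        || line.starts_with("@@")
}

/// Hypothetical helper: drops non-diff lines, then keeps at most the last
/// `max_lines` lines of what remains.
fn tail_of_edit_history(edit_history: &str, max_lines: usize) -> String {
    let lines: Vec<&str> = edit_history
        .lines()
        .filter(|line| is_udiff_content_line(line))
        .collect();

    // `saturating_sub` avoids underflow when there are fewer lines than the limit.
    let start = lines.len().saturating_sub(max_lines);
    let kept = &lines[start..];

    if kept.is_empty() {
        return "(No edit history)".to_string();
    }
    kept.join("\n")
}
```

Using `saturating_sub` would remove the need for the explicit length comparison; whether that reads better than the current `if`/`else` is a matter of taste.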
impl PromptParser for TeacherPrompt {
fn parse(example: &Example, response: &str) -> String {
// Ideally, we should always be able to find cursor position in the retrieved context.
// In reality, sometimes we don't find it for these reasons:
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
// (can be fixed by getting cursor coordinates at the load_example stage)
// 2. Context retriever just didn't include cursor line.
//
// In that case, fall back to using `cursor_position` as the excerpt.
let cursor_file = &example
.buffer
.as_ref()
.expect("`buffer` should be filled in in the context collection step")
.content;
// Extract updated (new) editable region from the model response
let new_editable_region = extract_last_codeblock(response);
// Reconstruct old editable region we sent to the model
let old_editable_region = Self::format_editable_region(example);
let old_editable_region = Self::extract_editable_region(&old_editable_region);
if !cursor_file.contains(&old_editable_region) {
panic!("Something's wrong: editable_region is not found in the cursor file")
}
// Apply editable region to a larger context and compute diff.
// This is needed to get better context lines around the editable region.
let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
let diff = language::unified_diff(&cursor_file, &edited_file);
let diff = indoc::formatdoc! {"
--- a/{path}
+++ b/{path}
{diff}
",
path = example.cursor_path.to_string_lossy(),
diff = diff,
};
diff
}
}
fn extract_last_codeblock(text: &str) -> String {
let mut last_block = None;
let mut search_start = 0;
while let Some(start) = text[search_start..].find("```") {
let start = start + search_start;
let bytes = text.as_bytes();
let mut backtick_end = start;
while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
backtick_end += 1;
}
let backtick_count = backtick_end - start;
let closing_backticks = "`".repeat(backtick_count);
while backtick_end < bytes.len() && bytes[backtick_end] != b'\n' {
backtick_end += 1;
}
if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
last_block = Some(code_block.to_string());
search_start = backtick_end + end_pos + backtick_count;
} else {
break;
}
}
last_block.unwrap_or_else(|| text.to_string())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_extract_last_code_block() {
let text = indoc::indoc! {"
Some thinking
```
first block
```
`````path='something' lines=1:2
last block
`````
"};
let last_block = extract_last_codeblock(text);
assert_eq!(last_block, "last block");
}
#[test]
fn test_extract_editable_region() {
let text = indoc::indoc! {"
some lines
are
here
<|editable_region_start|>
one
two three
<|editable_region_end|>
more
lines here
"};
let parsed = TeacherPrompt::extract_editable_region(text);
assert_eq!(
parsed,
indoc::indoc! {"
one
two three
"}
);
}
}
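Review note: `extract_last_codeblock` slices `backtick_end + 1..backtick_end + end_pos - 1`, which appears to panic when a fenced block is empty (an opening fence immediately followed by its closing fence), because the range start then exceeds its end. A line-based variant avoids the byte arithmetic. The sketch below is an alternative under stated assumptions, not the implementation in this PR: `last_fenced_block` is a hypothetical name, and it assumes that a closing fence sits on its own line and is at least as long as the opening fence.

```rust
/// Hypothetical line-based variant: returns the body of the last complete
/// fenced code block in `text`, or the whole text if there is none.
fn last_fenced_block(text: &str) -> String {
    let mut last = None;
    let mut lines = text.lines();

    while let Some(line) = lines.next() {
        // An opening fence is a run of three or more backticks, optionally
        // followed by an info string such as `path='x'`.
        let fence_len = line.trim_start().chars().take_while(|&c| c == '`').count();
        if fence_len < 3 {
            continue;
        }

        // Collect body lines until a closing fence: a line made only of
        // backticks, at least as many as the opening fence.
        let mut body: Vec<&str> = Vec::new();
        let mut closed = false;
        for inner in lines.by_ref() {
            let trimmed = inner.trim();
            if trimmed.len() >= fence_len && trimmed.chars().all(|c| c == '`') {
                closed = true;
                break;
            }
            body.push(inner);
        }

        // Unterminated blocks are ignored rather than treated as a match.
        if closed {
            last = Some(body.join("\n"));
        }
    }

    last.unwrap_or_else(|| text.to_string())
}
```

With this shape an empty block yields an empty string, and a shorter fence nested inside a longer one is kept as part of the body.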


@@ -16,7 +16,7 @@ use std::sync::Arc;
use util::ResultExt as _;
/// Headless subset of `workspace::AppState`.
pub struct ZetaCliAppState {
pub struct EpAppState {
pub languages: Arc<LanguageRegistry>,
pub client: Arc<Client>,
pub user_store: Entity<UserStore>,
@@ -25,7 +25,7 @@ pub struct ZetaCliAppState {
}
// TODO: dedupe with crates/eval/src/eval.rs
pub fn init(cx: &mut App) -> ZetaCliAppState {
pub fn init(cx: &mut App) -> EpAppState {
let app_commit_sha = option_env!("ZED_COMMIT_SHA").map(|s| AppCommitSha::new(s.to_owned()));
let app_version = AppVersion::load(
@@ -112,7 +112,7 @@ pub fn init(cx: &mut App) -> ZetaCliAppState {
prompt_store::init(cx);
terminal_view::init(cx);
ZetaCliAppState {
EpAppState {
languages,
client,
user_store,


@@ -0,0 +1,320 @@
use crate::{
example::{Example, ExampleBuffer, ExampleState},
headless::EpAppState,
};
use anyhow::{Result, anyhow};
use collections::HashMap;
use edit_prediction::EditPredictionStore;
use edit_prediction::udiff::OpenedBuffers;
use futures::{
AsyncWriteExt as _,
lock::{Mutex, OwnedMutexGuard},
};
use gpui::{AsyncApp, Entity};
use language::{Anchor, Buffer, ToOffset, ToPoint};
use project::buffer_store::BufferStoreEvent;
use project::{Project, ProjectPath};
use std::{
cell::RefCell,
fs,
path::{Path, PathBuf},
sync::Arc,
};
use util::{paths::PathStyle, rel_path::RelPath};
use zeta_prompt::CURSOR_MARKER;
pub async fn run_load_project(example: &mut Example, app_state: Arc<EpAppState>, mut cx: AsyncApp) {
if example.state.is_some() {
return;
}
let project = setup_project(example, &app_state, &mut cx).await;
let buffer_store = project
.read_with(&cx, |project, _| project.buffer_store().clone())
.unwrap();
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
cx.subscribe(&buffer_store, {
let project = project.clone();
move |_, event, cx| match event {
BufferStoreEvent::BufferAdded(buffer) => {
ep_store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
}
_ => {}
}
})
.unwrap()
.detach();
let _open_buffers = apply_edit_history(example, &project, &mut cx)
.await
.unwrap();
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
example.buffer = buffer
.read_with(&cx, |buffer, _cx| {
let cursor_point = cursor_position.to_point(&buffer);
Some(ExampleBuffer {
content: buffer.text(),
cursor_row: cursor_point.row,
cursor_column: cursor_point.column,
cursor_offset: cursor_position.to_offset(&buffer),
})
})
.unwrap();
example.state = Some(ExampleState {
buffer,
project,
cursor_position,
_open_buffers,
});
}
async fn cursor_position(
example: &Example,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> (Entity<Buffer>, Anchor) {
let worktree = project
.read_with(cx, |project, cx| {
project.visible_worktrees(cx).next().unwrap()
})
.unwrap();
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
.unwrap()
.into_arc();
let cursor_buffer = project
.update(cx, |project, cx| {
project.open_buffer(
ProjectPath {
worktree_id: worktree.read(cx).id(),
path: cursor_path,
},
cx,
)
})
.unwrap()
.await
.unwrap();
let cursor_offset_within_excerpt = example
.cursor_position
.find(CURSOR_MARKER)
.ok_or_else(|| anyhow!("missing cursor marker"))
.unwrap();
let mut cursor_excerpt = example.cursor_position.clone();
cursor_excerpt.replace_range(
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
"",
);
let excerpt_offset = cursor_buffer.read_with(cx, |buffer, _cx| {
let text = buffer.text();
let mut matches = text.match_indices(&cursor_excerpt);
let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
panic!(
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Cursor excerpt did not exist in buffer."
);
});
assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
excerpt_offset
}).unwrap();
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
let cursor_anchor = cursor_buffer
.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
.unwrap();
(cursor_buffer, cursor_anchor)
}
async fn setup_project(
example: &mut Example,
app_state: &Arc<EpAppState>,
cx: &mut AsyncApp,
) -> Entity<Project> {
setup_worktree(example).await;
let project = cx
.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})
.unwrap();
let worktree = project
.update(cx, |project, cx| {
project.create_worktree(&example.worktree_path(), true, cx)
})
.unwrap()
.await
.unwrap();
worktree
.read_with(cx, |worktree, _cx| {
worktree.as_local().unwrap().scan_complete()
})
.unwrap()
.await;
project
}
pub async fn setup_worktree(example: &Example) {
let repo_dir = example.repo_path();
let repo_lock = lock_repo(&repo_dir).await;
if !repo_dir.is_dir() {
fs::create_dir_all(&repo_dir).unwrap();
run_git(&repo_dir, &["init"]).await.unwrap();
run_git(
&repo_dir,
&["remote", "add", "origin", &example.repository_url],
)
.await
.unwrap();
}
// Resolve the example to a revision, fetching it if needed.
let revision = run_git(
&repo_dir,
&["rev-parse", &format!("{}^{{commit}}", example.revision)],
)
.await;
let revision = if let Ok(revision) = revision {
revision
} else {
if run_git(
&repo_dir,
&["fetch", "--depth", "1", "origin", &example.revision],
)
.await
.is_err()
{
run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
}
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
.await
.unwrap();
if revision != example.revision {
run_git(&repo_dir, &["tag", &example.revision, &revision])
.await
.unwrap();
}
revision
};
// Create the worktree for this example if needed.
let worktree_path = example.worktree_path();
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"])
.await
.unwrap();
run_git(&worktree_path, &["reset", "--hard", "HEAD"])
.await
.unwrap();
run_git(&worktree_path, &["checkout", revision.as_str()])
.await
.unwrap();
} else {
let worktree_path_string = worktree_path.to_string_lossy();
run_git(
&repo_dir,
&["branch", "-f", &example.name, revision.as_str()],
)
.await
.unwrap();
run_git(
&repo_dir,
&[
"worktree",
"add",
"-f",
&worktree_path_string,
&example.name,
],
)
.await
.unwrap();
}
drop(repo_lock);
// Apply the uncommitted diff for this example.
if !example.uncommitted_diff.is_empty() {
let mut apply_process = smol::process::Command::new("git")
.current_dir(&worktree_path)
.args(&["apply", "-"])
.stdin(std::process::Stdio::piped())
.spawn()
.unwrap();
let mut stdin = apply_process.stdin.take().unwrap();
stdin
.write_all(example.uncommitted_diff.as_bytes())
.await
.unwrap();
stdin.close().await.unwrap();
drop(stdin);
let apply_result = apply_process.output().await.unwrap();
if !apply_result.status.success() {
panic!(
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
}
}
}
async fn apply_edit_history(
example: &Example,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> Result<OpenedBuffers> {
edit_prediction::udiff::apply_diff(&example.edit_history, project, cx).await
}
thread_local! {
static REPO_LOCKS: RefCell<HashMap<PathBuf, Arc<Mutex<()>>>> = RefCell::new(HashMap::default());
}
#[must_use]
pub async fn lock_repo(path: impl AsRef<Path>) -> OwnedMutexGuard<()> {
REPO_LOCKS
.with(|cell| {
cell.borrow_mut()
.entry(path.as_ref().to_path_buf())
.or_default()
.clone()
})
.lock_owned()
.await
}
async fn run_git(repo_path: &Path, args: &[&str]) -> Result<String> {
let output = smol::process::Command::new("git")
.current_dir(repo_path)
.args(args)
.output()
.await?;
anyhow::ensure!(
output.status.success(),
"`git {}` within `{}` failed with status: {}\nstderr:\n{}\nstdout:\n{}",
args.join(" "),
repo_path.display(),
output.status,
String::from_utf8_lossy(&output.stderr),
String::from_utf8_lossy(&output.stdout),
);
Ok(String::from_utf8(output.stdout)?.trim().to_string())
}
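Review note: the string logic in `cursor_position` (find the marker in the excerpt, strip it, locate the excerpt exactly once in the buffer, then add the two offsets) can be exercised without a project or buffer entity. The sketch below is a standalone illustration, not the code in this PR: `locate_cursor` is a hypothetical name, it returns a `Result` instead of panicking, and the marker value is an assumption taken from the `"<|user_cursor|>"` literal used in `extract_editable_region`, since the definition of `zeta_prompt::CURSOR_MARKER` is not shown in this diff.

```rust
// Assumed marker value; the real constant lives in `zeta_prompt`.
const CURSOR_MARKER: &str = "<|user_cursor|>";

/// Hypothetical helper: returns the byte offset of the cursor within
/// `buffer_text`, given an excerpt that contains the cursor marker.
fn locate_cursor(buffer_text: &str, excerpt_with_marker: &str) -> Result<usize, String> {
    // Offset of the marker inside the excerpt.
    let marker_offset = excerpt_with_marker
        .find(CURSOR_MARKER)
        .ok_or_else(|| "missing cursor marker".to_string())?;

    // Remove the marker so the excerpt can be matched against the buffer.
    let mut excerpt = excerpt_with_marker.to_string();
    excerpt.replace_range(marker_offset..marker_offset + CURSOR_MARKER.len(), "");

    // The excerpt must occur exactly once, otherwise the position is ambiguous.
    let mut matches = buffer_text.match_indices(excerpt.as_str());
    let (excerpt_offset, _) = matches
        .next()
        .ok_or_else(|| "cursor excerpt did not exist in buffer".to_string())?;
    if matches.next().is_some() {
        return Err("more than one cursor position match found".to_string());
    }

    Ok(excerpt_offset + marker_offset)
}
```

Returning an error for the ambiguous case, instead of the `assert!` used above, would let a batch run skip a bad example and continue.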


@@ -1,522 +1,196 @@
mod evaluate;
mod anthropic_client;
mod example;
mod format_prompt;
mod headless;
mod load_project;
mod metrics;
mod paths;
mod predict;
mod source_location;
mod training;
mod util;
mod retrieve_context;
mod score;
use crate::{
evaluate::run_evaluate,
example::{ExampleFormat, NamedExample},
headless::ZetaCliAppState,
predict::run_predict,
source_location::SourceLocation,
training::{context::ContextType, distill::run_distill},
util::{open_buffer, open_buffer_with_language_server},
};
use ::util::{ResultExt, paths::PathStyle};
use anyhow::{Result, anyhow};
use clap::{Args, Parser, Subcommand, ValueEnum};
use cloud_llm_client::predict_edits_v3;
use edit_prediction::udiff::DiffLine;
use edit_prediction_context::EditPredictionExcerptOptions;
use gpui::{Application, AsyncApp, Entity, prelude::*};
use language::{Bias, Buffer, BufferSnapshot, Point};
use metrics::delta_chr_f;
use project::{Project, Worktree, lsp_store::OpenLspBufferHandle};
use clap::{Args, CommandFactory, Parser, Subcommand, ValueEnum};
use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use std::io::{self};
use std::{collections::HashSet, path::PathBuf, str::FromStr, sync::Arc};
use serde::{Deserialize, Serialize};
use std::{path::PathBuf, sync::Arc};
use crate::example::{read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::predict::run_prediction;
use crate::retrieve_context::run_context_retrieval;
use crate::score::run_scoring;
#[derive(Parser, Debug)]
#[command(name = "zeta")]
struct ZetaCliArgs {
#[command(name = "ep")]
struct EpArgs {
#[arg(long, default_value_t = false)]
printenv: bool,
#[clap(long, default_value_t = 10)]
max_parallelism: usize,
#[command(subcommand)]
command: Option<Command>,
#[clap(global = true)]
inputs: Vec<PathBuf>,
#[arg(long, short, global = true)]
output: Option<PathBuf>,
#[arg(long, short, global = true)]
in_place: bool,
}
#[derive(Subcommand, Debug)]
enum Command {
Context(ContextArgs),
Predict(PredictArguments),
Eval(EvaluateArguments),
Distill(DistillArguments),
ConvertExample {
path: PathBuf,
#[arg(long, value_enum, default_value_t = ExampleFormat::Md)]
output_format: ExampleFormat,
},
Score {
golden_patch: PathBuf,
actual_patch: PathBuf,
},
/// Parse markdown examples and output a combined .jsonl file
ParseExample,
/// Create git worktrees for each example and load file contents
LoadBuffer,
/// Retrieve context for input examples.
Context,
/// Generate a prompt string for a specific model
FormatPrompt(FormatPromptArgs),
/// Runs edit prediction
Predict(PredictArgs),
/// Computes a score based on actual and expected patches
Score(PredictArgs),
/// Print aggregated scores
Eval(PredictArgs),
/// Remove git repositories and worktrees
Clean,
}
#[derive(Debug, Args)]
struct ContextArgs {
#[arg(long)]
provider: ContextProvider,
#[arg(long)]
worktree: PathBuf,
#[arg(long)]
cursor: SourceLocation,
#[arg(long)]
use_language_server: bool,
#[arg(long)]
edit_history: Option<FileOrStdin>,
#[clap(flatten)]
zeta2_args: Zeta2Args,
struct FormatPromptArgs {
#[clap(long)]
prompt_format: PromptFormat,
}
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum ContextProvider {
Zeta1,
#[default]
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PromptFormat {
Teacher,
Zeta2,
}
#[derive(Clone, Debug, Args)]
struct Zeta2Args {
#[arg(long, default_value_t = 8192)]
max_prompt_bytes: usize,
#[arg(long, default_value_t = 2048)]
max_excerpt_bytes: usize,
#[arg(long, default_value_t = 1024)]
min_excerpt_bytes: usize,
#[arg(long, default_value_t = 0.66)]
target_before_cursor_over_total_bytes: f32,
#[arg(long, default_value_t = 1024)]
max_diagnostic_bytes: usize,
#[arg(long, value_enum, default_value_t = PromptFormat::default())]
prompt_format: PromptFormat,
#[arg(long, value_enum, default_value_t = Default::default())]
output_format: OutputFormat,
#[arg(long, default_value_t = 42)]
file_indexing_parallelism: usize,
#[arg(long, default_value_t = false)]
disable_imports_gathering: bool,
#[arg(long, default_value_t = u8::MAX)]
max_retrieved_definitions: u8,
}
#[derive(Debug, Args)]
pub struct PredictArguments {
#[clap(long, short, value_enum, default_value_t = PredictionsOutputFormat::Md)]
format: PredictionsOutputFormat,
example_path: PathBuf,
#[clap(flatten)]
options: PredictionOptions,
}
#[derive(Debug, Args)]
pub struct DistillArguments {
split_commit_dataset: PathBuf,
#[clap(long, value_enum, default_value_t = ContextType::CurrentFile)]
context_type: ContextType,
#[clap(long)]
batch: Option<String>,
}
#[derive(Clone, Debug, Args)]
pub struct PredictionOptions {
#[clap(flatten)]
zeta2: Zeta2Args,
struct PredictArgs {
#[clap(long)]
provider: PredictionProvider,
#[clap(long, value_enum, default_value_t = CacheMode::default())]
cache: CacheMode,
#[clap(long, default_value_t = 1)]
repetitions: usize,
}
#[derive(Debug, ValueEnum, Default, Clone, Copy, PartialEq)]
pub enum CacheMode {
/// Use cached LLM requests and responses, except when multiple repetitions are requested
#[default]
Auto,
/// Use cached LLM requests and responses, based on the hash of the prompt and the endpoint.
#[value(alias = "request")]
Requests,
/// Ignore existing cache entries for both LLM and search.
Skip,
/// Use cached LLM responses AND search results for full determinism. Fails if they haven't been cached yet.
/// Useful for reproducing results and fixing bugs outside of search queries
Force,
}
impl CacheMode {
fn use_cached_llm_responses(&self) -> bool {
self.assert_not_auto();
matches!(self, CacheMode::Requests | CacheMode::Force)
}
fn use_cached_search_results(&self) -> bool {
self.assert_not_auto();
matches!(self, CacheMode::Force)
}
fn assert_not_auto(&self) {
assert_ne!(
*self,
CacheMode::Auto,
"Cache mode should not be auto at this point!"
);
}
}
#[derive(clap::ValueEnum, Debug, Clone)]
pub enum PredictionsOutputFormat {
Json,
Md,
Diff,
}
#[derive(Debug, Args)]
pub struct EvaluateArguments {
example_paths: Vec<PathBuf>,
#[clap(flatten)]
options: PredictionOptions,
#[clap(short, long, default_value_t = 1, alias = "repeat")]
repetitions: u16,
#[arg(long)]
skip_prediction: bool,
}
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy, PartialEq)]
#[derive(Clone, Copy, Debug, ValueEnum, Serialize, Deserialize)]
enum PredictionProvider {
Zeta1,
#[default]
Zeta2,
Sweep,
Mercury,
Zeta1,
Zeta2,
Teacher,
}
fn zeta2_args_to_options(args: &Zeta2Args) -> edit_prediction::ZetaOptions {
edit_prediction::ZetaOptions {
context: EditPredictionExcerptOptions {
max_bytes: args.max_excerpt_bytes,
min_bytes: args.min_excerpt_bytes,
target_before_cursor_over_total_bytes: args.target_before_cursor_over_total_bytes,
},
max_prompt_bytes: args.max_prompt_bytes,
prompt_format: args.prompt_format.into(),
}
}
#[derive(clap::ValueEnum, Default, Debug, Clone, Copy)]
enum PromptFormat {
OnlySnippets,
#[default]
OldTextNewText,
Minimal,
MinimalQwen,
SeedCoder1120,
}
impl Into<predict_edits_v3::PromptFormat> for PromptFormat {
fn into(self) -> predict_edits_v3::PromptFormat {
match self {
Self::OnlySnippets => predict_edits_v3::PromptFormat::OnlySnippets,
Self::OldTextNewText => predict_edits_v3::PromptFormat::OldTextNewText,
Self::Minimal => predict_edits_v3::PromptFormat::Minimal,
Self::MinimalQwen => predict_edits_v3::PromptFormat::MinimalQwen,
Self::SeedCoder1120 => predict_edits_v3::PromptFormat::SeedCoder1120,
}
}
}
#[derive(clap::ValueEnum, Default, Debug, Clone)]
enum OutputFormat {
#[default]
Prompt,
Request,
Full,
}
#[derive(Debug, Clone)]
enum FileOrStdin {
File(PathBuf),
Stdin,
}
impl FileOrStdin {
async fn read_to_string(&self) -> Result<String, std::io::Error> {
match self {
FileOrStdin::File(path) => smol::fs::read_to_string(path).await,
FileOrStdin::Stdin => smol::unblock(|| std::io::read_to_string(std::io::stdin())).await,
}
}
}
impl FromStr for FileOrStdin {
type Err = <PathBuf as FromStr>::Err;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"-" => Ok(Self::Stdin),
_ => Ok(Self::File(PathBuf::from_str(s)?)),
}
}
}
struct LoadedContext {
full_path_str: String,
snapshot: BufferSnapshot,
clipped_cursor: Point,
worktree: Entity<Worktree>,
project: Entity<Project>,
buffer: Entity<Buffer>,
lsp_open_handle: Option<OpenLspBufferHandle>,
}
async fn load_context(
args: &ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<LoadedContext> {
let ContextArgs {
worktree: worktree_path,
cursor,
use_language_server,
..
} = args;
let worktree_path = worktree_path.canonicalize()?;
let project = cx.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})?;
let worktree = project
.update(cx, |project, cx| {
project.create_worktree(&worktree_path, true, cx)
})?
.await?;
let mut ready_languages = HashSet::default();
let (lsp_open_handle, buffer) = if *use_language_server {
let (lsp_open_handle, _, buffer) = open_buffer_with_language_server(
project.clone(),
worktree.clone(),
cursor.path.clone(),
&mut ready_languages,
cx,
)
.await?;
(Some(lsp_open_handle), buffer)
} else {
let buffer =
open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?;
(None, buffer)
};
let full_path_str = worktree
.read_with(cx, |worktree, _| worktree.root_name().join(&cursor.path))?
.display(PathStyle::local())
.to_string();
let snapshot = cx.update(|cx| buffer.read(cx).snapshot())?;
let clipped_cursor = snapshot.clip_point(cursor.point, Bias::Left);
if clipped_cursor != cursor.point {
let max_row = snapshot.max_point().row;
if cursor.point.row < max_row {
return Err(anyhow!(
"Cursor position {:?} is out of bounds (line length is {})",
cursor.point,
snapshot.line_len(cursor.point.row)
));
impl EpArgs {
fn output_path(&self) -> Option<PathBuf> {
if self.in_place {
if self.inputs.len() == 1 {
self.inputs.first().cloned()
} else {
panic!("--in-place requires exactly one input file")
}
} else {
return Err(anyhow!(
"Cursor position {:?} is out of bounds (max row is {})",
cursor.point,
max_row
));
self.output.clone()
}
}
Ok(LoadedContext {
full_path_str,
snapshot,
clipped_cursor,
worktree,
project,
buffer,
lsp_open_handle,
})
}
async fn zeta2_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<String> {
let LoadedContext {
worktree,
project,
buffer,
clipped_cursor,
lsp_open_handle: _handle,
..
} = load_context(&args, app_state, cx).await?;
// wait for worktree scan before starting zeta2 so that wait_for_initial_indexing waits for
// the whole worktree.
worktree
.read_with(cx, |worktree, _cx| {
worktree.as_local().unwrap().scan_complete()
})?
.await;
let output = cx
.update(|cx| {
let store = cx.new(|cx| {
edit_prediction::EditPredictionStore::new(
app_state.client.clone(),
app_state.user_store.clone(),
cx,
)
});
store.update(cx, |store, cx| {
store.set_options(zeta2_args_to_options(&args.zeta2_args));
store.register_buffer(&buffer, &project, cx);
});
cx.spawn(async move |cx| {
let updates_rx = store.update(cx, |store, cx| {
let cursor = buffer.read(cx).snapshot().anchor_before(clipped_cursor);
store.set_use_context(true);
store.refresh_context(&project, &buffer, cursor, cx);
store.project_context_updates(&project).unwrap()
})?;
updates_rx.recv().await.ok();
let context = store.update(cx, |store, cx| {
store.context_for_project(&project, cx).to_vec()
})?;
anyhow::Ok(serde_json::to_string_pretty(&context).unwrap())
})
})?
.await?;
Ok(output)
}
async fn zeta1_context(
args: ContextArgs,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<edit_prediction::zeta1::GatherContextOutput> {
let LoadedContext {
full_path_str,
snapshot,
clipped_cursor,
..
} = load_context(&args, app_state, cx).await?;
let events = match args.edit_history {
Some(events) => events.read_to_string().await?,
None => String::new(),
};
let prompt_for_events = move || (events, 0);
cx.update(|cx| {
edit_prediction::zeta1::gather_context(
full_path_str,
&snapshot,
clipped_cursor,
prompt_for_events,
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
})?
.await
}
fn main() {
zlog::init();
zlog::init_output_stderr();
let args = ZetaCliArgs::parse();
let args = EpArgs::parse();
if args.printenv {
::util::shell_env::print_env();
return;
}
let output = args.output_path();
let command = match args.command {
Some(cmd) => cmd,
None => {
EpArgs::command().print_help().unwrap();
return;
}
};
match &command {
Command::Clean => {
std::fs::remove_dir_all(&*paths::DATA_DIR).unwrap();
return;
}
_ => {}
}
let mut examples = read_examples(&args.inputs);
let http_client = Arc::new(ReqwestClient::new());
let app = Application::headless().with_http_client(http_client);
app.run(move |cx| {
let app_state = Arc::new(headless::init(cx));
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
cx.spawn(async move |cx| {
match args.command {
None => {
if args.printenv {
::util::shell_env::print_env();
} else {
panic!("Expected a command");
}
}
Some(Command::Context(context_args)) => {
let result = match context_args.provider {
ContextProvider::Zeta1 => {
let context =
zeta1_context(context_args, &app_state, cx).await.unwrap();
serde_json::to_string_pretty(&context.body).unwrap()
match &command {
Command::Predict(args) => predict::sync_batches(&args.provider).await,
_ => (),
};
for data in examples.chunks_mut(args.max_parallelism) {
let mut futures = Vec::new();
for example in data.iter_mut() {
let cx = cx.clone();
let app_state = app_state.clone();
futures.push(async {
match &command {
Command::ParseExample => {}
Command::LoadBuffer => {
run_load_project(example, app_state.clone(), cx).await;
}
Command::Context => {
run_context_retrieval(example, app_state, cx).await;
}
Command::FormatPrompt(args) => {
run_format_prompt(example, args.prompt_format, app_state, cx).await;
}
Command::Predict(args) => {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state.clone(),
cx,
)
.await;
}
Command::Score(args) | Command::Eval(args) => {
run_scoring(example, &args, app_state, cx).await;
}
Command::Clean => {
unreachable!()
}
}
ContextProvider::Zeta2 => {
zeta2_context(context_args, &app_state, cx).await.unwrap()
}
};
println!("{}", result);
});
}
Some(Command::Predict(arguments)) => {
run_predict(arguments, &app_state, cx).await;
}
Some(Command::Eval(arguments)) => {
run_evaluate(arguments, &app_state, cx).await;
}
Some(Command::Distill(arguments)) => {
let _guard = cx
.update(|cx| gpui_tokio::Tokio::handle(cx))
.unwrap()
.enter();
run_distill(arguments).await.log_err();
}
Some(Command::ConvertExample {
path,
output_format,
}) => {
let example = NamedExample::load(path).unwrap();
example.write(output_format, io::stdout()).unwrap();
}
Some(Command::Score {
golden_patch,
actual_patch,
}) => {
let golden_content = std::fs::read_to_string(golden_patch).unwrap();
let actual_content = std::fs::read_to_string(actual_patch).unwrap();
futures::future::join_all(futures).await;
}
let golden_diff: Vec<DiffLine> = golden_content
.lines()
.map(|line| DiffLine::parse(line))
.collect();
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
write_examples(&examples, output.as_ref());
}
let actual_diff: Vec<DiffLine> = actual_content
.lines()
.map(|line| DiffLine::parse(line))
.collect();
let score = delta_chr_f(&golden_diff, &actual_diff);
println!("{:.2}", score);
}
Some(Command::Clean) => {
std::fs::remove_dir_all(&*crate::paths::TARGET_ZETA_DIR).unwrap()
}
match &command {
Command::Predict(args) => predict::sync_batches(&args.provider).await,
Command::Eval(_) => score::print_report(&examples),
_ => (),
};
let _ = cx.update(|cx| cx.quit());
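
Reviewer note: the new `EpArgs::output_path` resolves where results are written: with `--in-place` the single input file is overwritten, otherwise the optional `--output` path is used. The sketch below restates that rule as a free function so it can be exercised without clap. It is a variation, not the code in the diff: `resolve_output` is a hypothetical name, and it returns an error where the diff panics.

```rust
use std::path::PathBuf;

/// Decides where results should be written (hypothetical helper, not part of the diff).
///
/// - `in_place == false`: use `output` as given (`None` means no explicit output file).
/// - `in_place == true`: overwrite the single input file; any other input count is an error.
pub fn resolve_output(
    inputs: &[PathBuf],
    output: Option<PathBuf>,
    in_place: bool,
) -> Result<Option<PathBuf>, String> {
    if !in_place {
        return Ok(output);
    }
    match inputs {
        // Exactly one input: that file is both source and destination.
        [only] => Ok(Some(only.clone())),
        // The diff panics here; this sketch reports the same message as an error.
        _ => Err("--in-place requires exactly one input file".to_string()),
    }
}
```

As in the diff, `--in-place` takes precedence over `--output` when both are given.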


@@ -1,30 +1,34 @@
use collections::{HashMap, HashSet};
use edit_prediction::udiff::DiffLine;
use serde::{Deserialize, Serialize};
type Counts = HashMap<String, usize>;
type CountsDelta = HashMap<String, isize>;
#[derive(Default, Debug, Clone)]
pub struct Scores {
#[derive(Default, Debug, Clone, Serialize, Deserialize)]
pub struct ClassificationMetrics {
pub true_positives: usize,
pub false_positives: usize,
pub false_negatives: usize,
}
impl Scores {
pub fn from_sets(expected: &HashSet<String>, actual: &HashSet<String>) -> Scores {
impl ClassificationMetrics {
pub fn from_sets(
expected: &HashSet<String>,
actual: &HashSet<String>,
) -> ClassificationMetrics {
let true_positives = expected.intersection(actual).count();
let false_positives = actual.difference(expected).count();
let false_negatives = expected.difference(actual).count();
Scores {
ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
}
}
pub fn from_counts(expected: &Counts, actual: &Counts) -> Scores {
pub fn from_counts(expected: &Counts, actual: &Counts) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -45,32 +49,16 @@ impl Scores {
}
}
Scores {
ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
}
}
pub fn to_markdown(&self) -> String {
format!(
"
Precision : {:.4}
Recall : {:.4}
F1 Score : {:.4}
True Positives : {}
False Positives : {}
False Negatives : {}",
self.precision(),
self.recall(),
self.f1_score(),
self.true_positives,
self.false_positives,
self.false_negatives
)
}
pub fn aggregate<'a>(scores: impl Iterator<Item = &'a Scores>) -> Scores {
pub fn aggregate<'a>(
scores: impl Iterator<Item = &'a ClassificationMetrics>,
) -> ClassificationMetrics {
let mut true_positives = 0;
let mut false_positives = 0;
let mut false_negatives = 0;
@@ -81,7 +69,7 @@ False Negatives : {}",
false_negatives += score.false_negatives;
}
Scores {
ClassificationMetrics {
true_positives,
false_positives,
false_negatives,
@@ -115,7 +103,10 @@ False Negatives : {}",
}
}
pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine]) -> Scores {
pub fn line_match_score(
expected_patch: &[DiffLine],
actual_patch: &[DiffLine],
) -> ClassificationMetrics {
let expected_change_lines = expected_patch
.iter()
.filter(|line| matches!(line, DiffLine::Addition(_) | DiffLine::Deletion(_)))
@@ -128,7 +119,7 @@ pub fn line_match_score(expected_patch: &[DiffLine], actual_patch: &[DiffLine])
.map(|line| line.to_string())
.collect();
Scores::from_sets(&expected_change_lines, &actual_change_lines)
ClassificationMetrics::from_sets(&expected_change_lines, &actual_change_lines)
}
enum ChrfWhitespace {
@@ -204,7 +195,7 @@ pub fn delta_chr_f(expected: &[DiffLine], actual: &[DiffLine]) -> f64 {
let expected_counts = ngram_delta_to_counts(&expected_delta);
let actual_counts = ngram_delta_to_counts(&actual_delta);
let score = Scores::from_counts(&expected_counts, &actual_counts);
let score = ClassificationMetrics::from_counts(&expected_counts, &actual_counts);
total_precision += score.precision();
total_recall += score.recall();
}
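
Reviewer note: the renamed `ClassificationMetrics` stores true positives, false positives, and false negatives, and callers such as `delta_chr_f` read `precision()` and `recall()` from it. Those method bodies are outside the hunks shown, so the sketch below uses the standard definitions and assumes a zero denominator yields 0.0. It also uses `std::collections::HashSet` rather than the project's `collections` crate, and `set_of` is a hypothetical helper for building inputs.

```rust
use std::collections::HashSet;

#[derive(Default, Debug, Clone, PartialEq)]
pub struct ClassificationMetrics {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

impl ClassificationMetrics {
    /// Counts matches between the expected and actual sets, as in the diff.
    pub fn from_sets(expected: &HashSet<String>, actual: &HashSet<String>) -> Self {
        Self {
            true_positives: expected.intersection(actual).count(),
            false_positives: actual.difference(expected).count(),
            false_negatives: expected.difference(actual).count(),
        }
    }

    /// TP / (TP + FP). Assumption: returns 0.0 when nothing was predicted.
    pub fn precision(&self) -> f64 {
        let denom = self.true_positives + self.false_positives;
        if denom == 0 { 0.0 } else { self.true_positives as f64 / denom as f64 }
    }

    /// TP / (TP + FN). Assumption: returns 0.0 when nothing was expected.
    pub fn recall(&self) -> f64 {
        let denom = self.true_positives + self.false_negatives;
        if denom == 0 { 0.0 } else { self.true_positives as f64 / denom as f64 }
    }

    /// Harmonic mean of precision and recall, 0.0 when both are zero.
    pub fn f1_score(&self) -> f64 {
        let (p, r) = (self.precision(), self.recall());
        if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
    }
}

/// Hypothetical convenience for building a set of owned strings.
pub fn set_of(items: &[&str]) -> HashSet<String> {
    items.iter().map(|s| s.to_string()).collect()
}
```

For line-level scoring, the expected and actual sets would hold the stringified addition and deletion lines of each patch, as `line_match_score` does above.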


@@ -1,57 +1,25 @@
use std::{env, path::PathBuf, sync::LazyLock};
use std::{
path::{Path, PathBuf},
sync::LazyLock,
};
pub static TARGET_ZETA_DIR: LazyLock<PathBuf> =
LazyLock::new(|| env::current_dir().unwrap().join("target/zeta"));
pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("cache"));
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("repos"));
pub static WORKTREES_DIR: LazyLock<PathBuf> = LazyLock::new(|| TARGET_ZETA_DIR.join("worktrees"));
pub static DATA_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
let dir = dirs::home_dir().unwrap().join(".zed_ep");
ensure_dir(&dir)
});
pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("cache")));
pub static REPOS_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("repos")));
pub static WORKTREES_DIR: LazyLock<PathBuf> =
LazyLock::new(|| ensure_dir(&DATA_DIR.join("worktrees")));
pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
TARGET_ZETA_DIR
DATA_DIR
.join("runs")
.join(chrono::Local::now().format("%d-%m-%y-%H_%M_%S").to_string())
});
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> =
LazyLock::new(|| TARGET_ZETA_DIR.join("latest"));
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
pub fn print_run_data_dir(deep: bool, use_color: bool) {
println!("\n## Run Data\n");
let mut files = Vec::new();
let current_dir = std::env::current_dir().unwrap();
for file in std::fs::read_dir(&*RUN_DIR).unwrap() {
let file = file.unwrap();
if file.file_type().unwrap().is_dir() && deep {
for file in std::fs::read_dir(file.path()).unwrap() {
let path = file.unwrap().path();
let path = path.strip_prefix(&current_dir).unwrap_or(&path);
files.push(format!(
"- {}/{}{}{}",
path.parent().unwrap().display(),
if use_color { "\x1b[34m" } else { "" },
path.file_name().unwrap().display(),
if use_color { "\x1b[0m" } else { "" },
));
}
} else {
let path = file.path();
let path = path.strip_prefix(&current_dir).unwrap_or(&path);
files.push(format!(
"- {}/{}{}{}",
path.parent().unwrap().display(),
if use_color { "\x1b[34m" } else { "" },
path.file_name().unwrap().display(),
if use_color { "\x1b[0m" } else { "" }
));
}
}
files.sort();
for file in files {
println!("{}", file);
}
println!(
"\n💡 Tip of the day: {} always points to the latest run\n",
LATEST_EXAMPLE_RUN_DIR.display()
);
fn ensure_dir(path: &Path) -> PathBuf {
std::fs::create_dir_all(path).expect("Failed to create directory");
path.to_path_buf()
}
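
Reviewer note: the new paths module wraps each directory in a `LazyLock` whose initializer calls `ensure_dir`, so a directory is created the first time its static is dereferenced. A minimal sketch of that pattern follows. It assumes Rust 1.80 or later for `std::sync::LazyLock`, and it roots the tree under the system temp directory with a hypothetical `ep_paths_sketch` name instead of the `~/.zed_ep` location used in the diff, so it needs no extra crates.

```rust
use std::path::{Path, PathBuf};
use std::sync::LazyLock;

// Hypothetical base directory for this sketch; the diff uses the home directory instead.
pub static DATA_DIR: LazyLock<PathBuf> =
    LazyLock::new(|| ensure_dir(&std::env::temp_dir().join("ep_paths_sketch")));

// Dereferencing CACHE_DIR first forces DATA_DIR, so parents are created before children.
pub static CACHE_DIR: LazyLock<PathBuf> = LazyLock::new(|| ensure_dir(&DATA_DIR.join("cache")));

/// Creates `path` (and any missing parents) and returns it as an owned `PathBuf`.
fn ensure_dir(path: &Path) -> PathBuf {
    std::fs::create_dir_all(path).expect("Failed to create directory");
    path.to_path_buf()
}
```

One consequence of this design is that merely reading a path static has a filesystem side effect. In the diff, `RUN_DIR` and `LATEST_EXAMPLE_RUN_DIR` are not wrapped in `ensure_dir`, so they are not created this way.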


@@ -1,374 +1,271 @@
use crate::example::{ActualExcerpt, NamedExample};
use crate::headless::ZetaCliAppState;
use crate::paths::{CACHE_DIR, LATEST_EXAMPLE_RUN_DIR, RUN_DIR, print_run_data_dir};
use crate::{
CacheMode, PredictArguments, PredictionOptions, PredictionProvider, PredictionsOutputFormat,
PredictionProvider, PromptFormat,
anthropic_client::AnthropicClient,
example::{Example, ExamplePrediction},
format_prompt::{PromptParser, TeacherPrompt, run_format_prompt},
headless::EpAppState,
load_project::run_load_project,
paths::{LATEST_EXAMPLE_RUN_DIR, RUN_DIR},
retrieve_context::run_context_retrieval,
};
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
use std::{
fs,
sync::{
Arc, Mutex, OnceLock,
atomic::{AtomicUsize, Ordering::SeqCst},
},
};
use ::serde::Serialize;
use anyhow::{Context, Result, anyhow};
use cloud_zeta2_prompt::{CURSOR_MARKER, write_codeblock};
use edit_prediction::{EditPredictionStore, EvalCache, EvalCacheEntryKind, EvalCacheKey};
use futures::StreamExt as _;
use gpui::{AppContext, AsyncApp, Entity};
use project::Project;
use project::buffer_store::BufferStoreEvent;
use serde::Deserialize;
use std::fs;
use std::io::{IsTerminal, Write};
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::{Duration, Instant};
pub async fn run_predict(
args: PredictArguments,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
pub async fn run_prediction(
example: &mut Example,
provider: Option<PredictionProvider>,
repetition_count: usize,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) {
let example = NamedExample::load(args.example_path).unwrap();
let project = example.setup_project(app_state, cx).await.unwrap();
let store = setup_store(args.options.provider, &project, app_state, cx).unwrap();
let _edited_buffers = example.apply_edit_history(&project, cx).await.unwrap();
let result = perform_predict(example, project, store, None, args.options, cx)
.await
if !example.predictions.is_empty() {
return;
}
run_load_project(example, app_state.clone(), cx.clone()).await;
run_context_retrieval(example, app_state.clone(), cx.clone()).await;
let provider = provider.unwrap();
if matches!(provider, PredictionProvider::Teacher) {
if example.prompt.is_none() {
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await;
}
let batched = true;
return predict_anthropic(example, repetition_count, batched).await;
}
if matches!(
provider,
PredictionProvider::Zeta1 | PredictionProvider::Zeta2
) {
static AUTHENTICATED: OnceLock<Shared<Task<()>>> = OnceLock::new();
AUTHENTICATED
.get_or_init(|| {
let client = app_state.client.clone();
cx.spawn(async move |cx| {
client
.sign_in_with_optional_connect(true, cx)
.await
.unwrap();
})
.shared()
})
.clone()
.await;
}
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
result.write(args.format, std::io::stdout()).unwrap();
print_run_data_dir(true, std::io::stdout().is_terminal());
}
ep_store
.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher => unreachable!(),
};
store.set_edit_prediction_model(model);
})
.unwrap();
let state = example.state.as_ref().unwrap();
let run_dir = RUN_DIR.join(&example.name);
pub fn setup_store(
provider: PredictionProvider,
project: &Entity<Project>,
app_state: &Arc<ZetaCliAppState>,
cx: &mut AsyncApp,
) -> Result<Entity<EditPredictionStore>> {
let store = cx.new(|cx| {
edit_prediction::EditPredictionStore::new(
app_state.client.clone(),
app_state.user_store.clone(),
cx,
)
})?;
let updated_example = Arc::new(Mutex::new(example.clone()));
let current_run_ix = Arc::new(AtomicUsize::new(0));
store.update(cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
};
store.set_edit_prediction_model(model);
})?;
let mut debug_rx = ep_store
.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
.unwrap();
let debug_task = cx.background_spawn({
let updated_example = updated_example.clone();
let current_run_ix = current_run_ix.clone();
let run_dir = run_dir.clone();
async move {
while let Some(event) = debug_rx.next().await {
let run_ix = current_run_ix.load(SeqCst);
let mut updated_example = updated_example.lock().unwrap();
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
let run_dir = if repetition_count > 1 {
run_dir.join(format!("{:03}", run_ix))
} else {
run_dir.clone()
};
cx.subscribe(&buffer_store, {
let project = project.clone();
let store = store.clone();
move |_, event, cx| match event {
BufferStoreEvent::BufferAdded(buffer) => {
store.update(cx, |store, cx| store.register_buffer(&buffer, &project, cx));
}
_ => {}
}
})?
.detach();
match event {
DebugEvent::EditPredictionStarted(request) => {
assert_eq!(updated_example.predictions.len(), run_ix + 1);
anyhow::Ok(store)
}
pub async fn perform_predict(
example: NamedExample,
project: Entity<Project>,
store: Entity<EditPredictionStore>,
repetition_ix: Option<u16>,
options: PredictionOptions,
cx: &mut AsyncApp,
) -> Result<PredictionDetails> {
let mut cache_mode = options.cache;
if repetition_ix.is_some() {
if cache_mode != CacheMode::Auto && cache_mode != CacheMode::Skip {
panic!("Repetitions are not supported in Auto cache mode");
} else {
cache_mode = CacheMode::Skip;
}
} else if cache_mode == CacheMode::Auto {
cache_mode = CacheMode::Requests;
}
let mut example_run_dir = RUN_DIR.join(&example.file_name());
if let Some(repetition_ix) = repetition_ix {
example_run_dir = example_run_dir.join(format!("{:03}", repetition_ix));
}
fs::create_dir_all(&example_run_dir)?;
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
}
#[cfg(unix)]
std::os::unix::fs::symlink(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
.context("creating latest link")?;
#[cfg(windows)]
std::os::windows::fs::symlink_dir(&example_run_dir, &*LATEST_EXAMPLE_RUN_DIR)
.context("creating latest link")?;
store.update(cx, |store, _cx| {
store.with_eval_cache(Arc::new(RunCache {
example_run_dir: example_run_dir.clone(),
cache_mode,
}));
})?;
let (cursor_buffer, cursor_anchor) = example.cursor_position(&project, cx).await?;
let result = Arc::new(Mutex::new(PredictionDetails::new(example_run_dir.clone())));
let prompt_format = options.zeta2.prompt_format;
store.update(cx, |store, _cx| {
let mut options = store.options().clone();
options.prompt_format = prompt_format.into();
store.set_options(options);
})?;
let mut debug_task = gpui::Task::ready(Ok(()));
if options.provider == crate::PredictionProvider::Zeta2 {
let mut debug_rx = store.update(cx, |store, _| store.debug_info())?;
debug_task = cx.background_spawn({
let result = result.clone();
async move {
let mut start_time = None;
let mut retrieval_finished_at = None;
while let Some(event) = debug_rx.next().await {
match event {
edit_prediction::DebugEvent::ContextRetrievalStarted(info) => {
start_time = Some(info.timestamp);
fs::write(
example_run_dir.join("search_prompt.md"),
&info.search_prompt,
)?;
if let Some(prompt) = request.prompt {
fs::write(run_dir.join("prediction_prompt.md"), &prompt)?;
}
edit_prediction::DebugEvent::ContextRetrievalFinished(info) => {
retrieval_finished_at = Some(info.timestamp);
for (key, value) in &info.metadata {
if *key == "search_queries" {
fs::write(
example_run_dir.join("search_queries.json"),
value.as_bytes(),
)?;
}
}
}
DebugEvent::EditPredictionFinished(request) => {
assert_eq!(updated_example.predictions.len(), run_ix + 1);
if let Some(output) = request.model_output {
fs::write(run_dir.join("prediction_response.md"), &output)?;
updated_example
.predictions
.last_mut()
.unwrap()
.actual_output = output;
}
edit_prediction::DebugEvent::EditPredictionRequested(request) => {
let prediction_started_at = Instant::now();
start_time.get_or_insert(prediction_started_at);
let prompt = request.local_prompt.unwrap_or_default();
fs::write(example_run_dir.join("prediction_prompt.md"), &prompt)?;
{
let mut result = result.lock().unwrap();
result.prompt_len = prompt.chars().count();
for included_file in request.inputs.included_files {
let insertions =
vec![(request.inputs.cursor_point, CURSOR_MARKER)];
result.excerpts.extend(included_file.excerpts.iter().map(
|excerpt| ActualExcerpt {
path: included_file.path.components().skip(1).collect(),
text: String::from(excerpt.text.as_ref()),
},
));
write_codeblock(
&included_file.path,
included_file.excerpts.iter(),
if included_file.path == request.inputs.cursor_path {
&insertions
} else {
&[]
},
included_file.max_row,
false,
&mut result.excerpts_text,
);
}
}
let response =
request.response_rx.await?.0.map_err(|err| anyhow!(err))?;
let response =
edit_prediction::open_ai_response::text_from_response(response)
.unwrap_or_default();
let prediction_finished_at = Instant::now();
fs::write(example_run_dir.join("prediction_response.md"), &response)?;
let mut result = result.lock().unwrap();
result.generated_len = response.chars().count();
result.retrieval_time =
retrieval_finished_at.unwrap() - start_time.unwrap();
result.prediction_time = prediction_finished_at - prediction_started_at;
result.total_time = prediction_finished_at - start_time.unwrap();
if run_ix >= repetition_count {
break;
}
}
_ => {}
}
anyhow::Ok(())
}
});
anyhow::Ok(())
}
});
store.update(cx, |store, cx| {
store.refresh_context(&project, &cursor_buffer, cursor_anchor, cx)
})?;
}
let prediction = store
.update(cx, |store, cx| {
store.request_prediction(
&project,
&cursor_buffer,
cursor_anchor,
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
})?
.await?;
debug_task.await?;
let mut result = Arc::into_inner(result).unwrap().into_inner().unwrap();
result.diff = prediction
.and_then(|prediction| {
let prediction = prediction.prediction.ok()?;
prediction.edit_preview.as_unified_diff(&prediction.edits)
})
.unwrap_or_default();
anyhow::Ok(result)
}
struct RunCache {
cache_mode: CacheMode,
example_run_dir: PathBuf,
}
impl RunCache {
fn output_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
CACHE_DIR.join(format!("{kind}_out_{key:x}.json",))
}
fn input_cache_path((kind, key): &EvalCacheKey) -> PathBuf {
CACHE_DIR.join(format!("{kind}_in_{key:x}.json",))
}
fn link_to_run(&self, key: &EvalCacheKey) {
let output_link_path = self.example_run_dir.join(format!("{}_out.json", key.0));
fs::hard_link(Self::output_cache_path(key), &output_link_path).unwrap();
let input_link_path = self.example_run_dir.join(format!("{}_in.json", key.0));
fs::hard_link(Self::input_cache_path(key), &input_link_path).unwrap();
}
}
impl EvalCache for RunCache {
fn read(&self, key: EvalCacheKey) -> Option<String> {
let path = RunCache::output_cache_path(&key);
if path.exists() {
let use_cache = match key.0 {
EvalCacheEntryKind::Search => self.cache_mode.use_cached_search_results(),
EvalCacheEntryKind::Context | EvalCacheEntryKind::Prediction => {
self.cache_mode.use_cached_llm_responses()
}
};
if use_cache {
log::info!("Using cache entry: {}", path.display());
self.link_to_run(&key);
Some(fs::read_to_string(path).unwrap())
} else {
log::trace!("Skipping cached entry: {}", path.display());
None
}
} else if matches!(self.cache_mode, CacheMode::Force) {
panic!(
"No cached entry found for {:?}. Run without `--cache force` at least once.",
key.0
);
for ix in 0..repetition_count {
current_run_ix.store(ix, SeqCst);
let run_dir = if repetition_count > 1 {
run_dir.join(format!("{:03}", ix))
} else {
None
}
}
fn write(&self, key: EvalCacheKey, input: &str, output: &str) {
fs::create_dir_all(&*CACHE_DIR).unwrap();
let input_path = RunCache::input_cache_path(&key);
fs::write(&input_path, input).unwrap();
let output_path = RunCache::output_cache_path(&key);
log::trace!("Writing cache entry: {}", output_path.display());
fs::write(&output_path, output).unwrap();
self.link_to_run(&key);
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PredictionDetails {
pub diff: String,
pub excerpts: Vec<ActualExcerpt>,
pub excerpts_text: String, // TODO: contains the worktree root path. Drop this field and compute it on the fly
pub retrieval_time: Duration,
pub prediction_time: Duration,
pub total_time: Duration,
pub run_example_dir: PathBuf,
pub prompt_len: usize,
pub generated_len: usize,
}
impl PredictionDetails {
pub fn new(run_example_dir: PathBuf) -> Self {
Self {
diff: Default::default(),
excerpts: Default::default(),
excerpts_text: Default::default(),
retrieval_time: Default::default(),
prediction_time: Default::default(),
total_time: Default::default(),
run_example_dir,
prompt_len: 0,
generated_len: 0,
}
}
pub fn write(&self, format: PredictionsOutputFormat, mut out: impl Write) -> Result<()> {
let formatted = match format {
PredictionsOutputFormat::Md => self.to_markdown(),
PredictionsOutputFormat::Json => serde_json::to_string_pretty(self)?,
PredictionsOutputFormat::Diff => self.diff.clone(),
run_dir.clone()
};
Ok(out.write_all(formatted.as_bytes())?)
fs::create_dir_all(&run_dir).unwrap();
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
}
#[cfg(unix)]
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
#[cfg(windows)]
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
updated_example
.lock()
.unwrap()
.predictions
.push(ExamplePrediction {
actual_patch: String::new(),
actual_output: String::new(),
provider,
});
let prediction = ep_store
.update(&mut cx, |store, cx| {
store.request_prediction(
&state.project,
&state.buffer,
state.cursor_position,
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
})
.unwrap()
.await
.unwrap();
updated_example
.lock()
.unwrap()
.predictions
.last_mut()
.unwrap()
.actual_patch = prediction
.and_then(|prediction| {
let prediction = prediction.prediction.ok()?;
prediction.edit_preview.as_unified_diff(&prediction.edits)
})
.unwrap_or_default();
}
pub fn to_markdown(&self) -> String {
format!(
"## Excerpts\n\n\
{}\n\n\
## Prediction\n\n\
{}\n\n\
## Time\n\n\
Retrieval: {}ms\n\
Prediction: {}ms\n\n\
Total: {}ms\n",
self.excerpts_text,
self.diff,
self.retrieval_time.as_millis(),
self.prediction_time.as_millis(),
self.total_time.as_millis(),
)
ep_store
.update(&mut cx, |store, _| {
store.remove_project(&state.project);
})
.unwrap();
debug_task.await.unwrap();
*example = Arc::into_inner(updated_example)
.unwrap()
.into_inner()
.unwrap();
}
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
let llm_model_name = "claude-sonnet-4-5";
let max_tokens = 16384;
let llm_client = if batched {
AnthropicClient::batch(&crate::paths::LLM_CACHE_DB.as_ref())
} else {
AnthropicClient::plain()
};
let llm_client = llm_client.expect("Failed to create LLM client");
let prompt = example
.prompt
.as_ref()
.unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
let messages = vec![anthropic::Message {
role: anthropic::Role::User,
content: vec![anthropic::RequestContent::Text {
text: prompt.input.clone(),
cache_control: None,
}],
}];
let Some(response) = llm_client
.generate(llm_model_name, max_tokens, messages)
.await
.unwrap()
else {
// Request stashed for batched processing
return;
};
let actual_output = response
.content
.into_iter()
.filter_map(|content| match content {
anthropic::ResponseContent::Text { text } => Some(text),
_ => None,
})
.collect::<Vec<String>>()
.join("\n");
let actual_patch = TeacherPrompt::parse(example, &actual_output);
let prediction = ExamplePrediction {
actual_patch,
actual_output,
provider: PredictionProvider::Teacher,
};
example.predictions.push(prediction);
}
pub async fn sync_batches(provider: &PredictionProvider) {
match provider {
PredictionProvider::Teacher => {
let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
let llm_client =
AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
llm_client
.sync_batches()
.await
.expect("Failed to sync batches");
}
_ => (),
}
}

@@ -1,106 +1,136 @@
use anyhow::{Result, anyhow};
use futures::channel::mpsc;
use futures::{FutureExt as _, StreamExt as _};
use crate::{
example::{Example, ExampleContext},
headless::EpAppState,
load_project::run_load_project,
};
use anyhow::Result;
use collections::HashSet;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
use gpui::{AsyncApp, Entity, Task};
use language::{Buffer, LanguageId, LanguageNotFound, LanguageServerId, ParseStatus};
use project::lsp_store::OpenLspBufferHandle;
use project::{Project, ProjectPath, Worktree};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Duration;
use util::rel_path::RelPath;
use language::{Buffer, LanguageNotFound};
use project::Project;
use std::{sync::Arc, time::Duration};
pub fn open_buffer(
project: Entity<Project>,
worktree: Entity<Worktree>,
path: Arc<RelPath>,
cx: &AsyncApp,
) -> Task<Result<Entity<Buffer>>> {
cx.spawn(async move |cx| {
let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
worktree_id: worktree.id(),
path,
})?;
pub async fn run_context_retrieval(
example: &mut Example,
app_state: Arc<EpAppState>,
mut cx: AsyncApp,
) {
if example.context.is_some() {
return;
}
let buffer = project
.update(cx, |project, cx| project.open_buffer(project_path, cx))?
.await?;
run_load_project(example, app_state.clone(), cx.clone()).await;
let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
while *parse_status.borrow() != ParseStatus::Idle {
parse_status.changed().await?;
let state = example.state.as_ref().unwrap();
let project = state.project.clone();
let _lsp_handle = project
.update(&mut cx, |project, cx| {
project.register_buffer_with_language_servers(&state.buffer, cx)
})
.unwrap();
wait_for_language_server_to_start(example, &project, &state.buffer, &mut cx).await;
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
let mut events = ep_store
.update(&mut cx, |store, cx| {
store.register_buffer(&state.buffer, &project, cx);
store.set_use_context(true);
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
store.debug_info(&project, cx)
})
.unwrap();
while let Some(event) = events.next().await {
match event {
DebugEvent::ContextRetrievalFinished(_) => {
break;
}
_ => {}
}
}
Ok(buffer)
})
let context_files = ep_store
.update(&mut cx, |store, cx| store.context_for_project(&project, cx))
.unwrap();
example.context = Some(ExampleContext {
files: context_files,
});
}
pub async fn open_buffer_with_language_server(
project: Entity<Project>,
worktree: Entity<Worktree>,
path: Arc<RelPath>,
ready_languages: &mut HashSet<LanguageId>,
async fn wait_for_language_server_to_start(
example: &Example,
project: &Entity<Project>,
buffer: &Entity<Buffer>,
cx: &mut AsyncApp,
) -> Result<(OpenLspBufferHandle, LanguageServerId, Entity<Buffer>)> {
let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
(
project.register_buffer_with_language_servers(&buffer, cx),
project.path_style(cx),
)
})?;
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
) {
let language_registry = project
.read_with(cx, |project, _| project.languages().clone())
.unwrap();
let result = language_registry
.load_language_for_file_path(path.as_std_path())
.load_language_for_file_path(&example.cursor_path)
.await;
if let Err(error) = result
&& !error.is::<LanguageNotFound>()
{
anyhow::bail!(error);
panic!("Failed to load language for file path: {}", error);
}
let Some(language_id) = buffer.read_with(cx, |buffer, _cx| {
buffer.language().map(|language| language.id())
})?
let Some(language_id) = buffer
.read_with(cx, |buffer, _cx| {
buffer.language().map(|language| language.id())
})
.unwrap()
else {
return Err(anyhow!("No language for {}", path.display(path_style)));
panic!("No language for {:?}", example.cursor_path);
};
let log_prefix = format!("{} | ", path.display(path_style));
let mut ready_languages = HashSet::default();
let log_prefix = format!("{} | ", example.name);
if !ready_languages.contains(&language_id) {
wait_for_lang_server(&project, &buffer, log_prefix, cx).await?;
wait_for_lang_server(&project, &buffer, log_prefix, cx)
.await
.unwrap();
ready_languages.insert(language_id);
}
let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
let lsp_store = project
.read_with(cx, |project, _cx| project.lsp_store())
.unwrap();
// hacky wait for buffer to be registered with the language server
for _ in 0..100 {
let Some(language_server_id) = lsp_store.update(cx, |lsp_store, cx| {
buffer.update(cx, |buffer, cx| {
lsp_store
.language_servers_for_local_buffer(&buffer, cx)
.next()
.map(|(_, language_server)| language_server.server_id())
if lsp_store
.update(cx, |lsp_store, cx| {
buffer.update(cx, |buffer, cx| {
lsp_store
.language_servers_for_local_buffer(&buffer, cx)
.next()
.map(|(_, language_server)| language_server.server_id())
})
})
})?
else {
.unwrap()
.is_some()
{
return;
} else {
cx.background_executor()
.timer(Duration::from_millis(10))
.await;
continue;
};
return Ok((lsp_open_handle, language_server_id, buffer));
}
}
return Err(anyhow!("No language server found for buffer"));
panic!("No language server found for buffer");
}
// TODO: Dedupe with similar function in crates/eval/src/instance.rs
pub fn wait_for_lang_server(
project: &Entity<Project>,
buffer: &Entity<Buffer>,

@@ -0,0 +1,119 @@
use crate::{
PredictArgs,
example::{Example, ExampleScore},
headless::EpAppState,
metrics::{self, ClassificationMetrics},
predict::run_prediction,
};
use edit_prediction::udiff::DiffLine;
use gpui::AsyncApp;
use std::sync::Arc;
pub async fn run_scoring(
example: &mut Example,
args: &PredictArgs,
app_state: Arc<EpAppState>,
cx: AsyncApp,
) {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state,
cx,
)
.await;
let expected_patch = parse_patch(&example.expected_patch);
let mut scores = vec![];
for pred in &example.predictions {
let actual_patch = parse_patch(&pred.actual_patch);
let line_match = metrics::line_match_score(&expected_patch, &actual_patch);
let delta_chr_f = metrics::delta_chr_f(&expected_patch, &actual_patch) as f32;
scores.push(ExampleScore {
delta_chr_f,
line_match,
});
}
example.score = scores;
}
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
patch.lines().map(DiffLine::parse).collect()
}
pub fn print_report(examples: &[Example]) {
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
eprintln!(
"{:<30} {:>4} {:>4} {:>4} {:>10} {:>8} {:>8} {:>10}",
"Example name", "TP", "FP", "FN", "Precision", "Recall", "F1", "DeltaChrF"
);
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
let mut all_line_match_scores = Vec::new();
let mut all_delta_chr_f_scores = Vec::new();
for example in examples {
for score in example.score.iter() {
let line_match = &score.line_match;
eprintln!(
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
truncate_name(&example.name, 30),
line_match.true_positives,
line_match.false_positives,
line_match.false_negatives,
line_match.precision() * 100.0,
line_match.recall() * 100.0,
line_match.f1_score() * 100.0,
score.delta_chr_f
);
all_line_match_scores.push(line_match.clone());
all_delta_chr_f_scores.push(score.delta_chr_f);
}
}
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
if !all_line_match_scores.is_empty() {
let total_line_match = ClassificationMetrics::aggregate(all_line_match_scores.iter());
let avg_delta_chr_f: f32 =
all_delta_chr_f_scores.iter().sum::<f32>() / all_delta_chr_f_scores.len() as f32;
eprintln!(
"{:<30} {:>4} {:>4} {:>4} {:>9.2}% {:>7.2}% {:>7.2}% {:>9.2}",
"TOTAL",
total_line_match.true_positives,
total_line_match.false_positives,
total_line_match.false_negatives,
total_line_match.precision() * 100.0,
total_line_match.recall() * 100.0,
total_line_match.f1_score() * 100.0,
avg_delta_chr_f
);
eprintln!(
"──────────────────────────────────────────────────────────────────────────────────────"
);
}
eprintln!("\n");
}
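fn truncate_name(name: &str, max_len: usize) -> String {
// Count characters rather than bytes so a multi-byte name is never sliced mid-character.
if name.chars().count() <= max_len {
name.to_string()
} else {
let truncated: String = name.chars().take(max_len.saturating_sub(3)).collect();
format!("{truncated}...")
}
}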
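The report above prints precision, recall, and F1 per example, and its TOTAL row aggregates the line-match scores before printing the same ratios. A minimal sketch of that arithmetic follows. It assumes a hypothetical `Counts` struct, because the real `ClassificationMetrics` type is not shown in this diff. Summing the raw counts in `aggregate` and returning 0.0 for an empty denominator are both assumptions.

```rust
/// Hypothetical stand-in for the counts carried by a line-match score.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Counts {
    pub true_positives: usize,
    pub false_positives: usize,
    pub false_negatives: usize,
}

/// Divide two counts, returning 0.0 for an empty denominator (an assumption).
fn ratio(numerator: usize, denominator: usize) -> f64 {
    if denominator == 0 {
        0.0
    } else {
        numerator as f64 / denominator as f64
    }
}

impl Counts {
    /// TP / (TP + FP): how many predicted lines were expected.
    pub fn precision(&self) -> f64 {
        ratio(self.true_positives, self.true_positives + self.false_positives)
    }

    /// TP / (TP + FN): how many expected lines were predicted.
    pub fn recall(&self) -> f64 {
        ratio(self.true_positives, self.true_positives + self.false_negatives)
    }

    /// Harmonic mean of precision and recall.
    pub fn f1_score(&self) -> f64 {
        let (p, r) = (self.precision(), self.recall());
        if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
    }

    /// Micro-average: sum the raw counts, then compute ratios on the sum.
    pub fn aggregate<'a>(items: impl IntoIterator<Item = &'a Counts>) -> Counts {
        items.into_iter().fold(Counts::default(), |acc, c| Counts {
            true_positives: acc.true_positives + c.true_positives,
            false_positives: acc.false_positives + c.false_positives,
            false_negatives: acc.false_negatives + c.false_negatives,
        })
    }
}
```

Summing counts before dividing weights each example by how many lines it touches. Averaging per-example ratios would instead weight every example equally.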

@@ -1,70 +0,0 @@
use std::{fmt, fmt::Display, path::Path, str::FromStr, sync::Arc};
use ::util::{paths::PathStyle, rel_path::RelPath};
use anyhow::{Result, anyhow};
use language::Point;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct SourceLocation {
pub path: Arc<RelPath>,
pub point: Point,
}
impl Serialize for SourceLocation {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_string())
}
}
impl<'de> Deserialize<'de> for SourceLocation {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let s = String::deserialize(deserializer)?;
s.parse().map_err(serde::de::Error::custom)
}
}
impl Display for SourceLocation {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"{}:{}:{}",
self.path.display(PathStyle::Posix),
self.point.row + 1,
self.point.column + 1
)
}
}
impl FromStr for SourceLocation {
type Err = anyhow::Error;
fn from_str(s: &str) -> Result<Self> {
let parts: Vec<&str> = s.split(':').collect();
if parts.len() != 3 {
return Err(anyhow!(
"Invalid source location. Expected 'file.rs:line:column', got '{}'",
s
));
}
let path = RelPath::new(Path::new(&parts[0]), PathStyle::local())?.into_arc();
let line: u32 = parts[1]
.parse()
.map_err(|_| anyhow!("Invalid line number: '{}'", parts[1]))?;
let column: u32 = parts[2]
.parse()
.map_err(|_| anyhow!("Invalid column number: '{}'", parts[2]))?;
// Convert from 1-based to 0-based indexing
let point = Point::new(line.saturating_sub(1), column.saturating_sub(1));
Ok(SourceLocation { path, point })
}
}
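The removed `SourceLocation` type serializes as `file.rs:line:column` with 1-based numbers and stores a 0-based point. A std-only sketch of that round trip follows. `Location`, `parse_location`, and `format_location` are hypothetical names, and a plain `String` stands in for `Arc<RelPath>`. Unlike the original, which splits on every `:` and requires exactly three parts, this variant splits from the right so a path containing `:` survives.

```rust
/// Hypothetical, std-only stand-in for `SourceLocation`; row and column are 0-based.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Location {
    pub path: String,
    pub row: u32,
    pub column: u32,
}

/// Parse `path:line:column`, where line and column are 1-based in the text form.
pub fn parse_location(s: &str) -> Result<Location, String> {
    let invalid = || format!("Invalid source location. Expected 'file.rs:line:column', got '{s}'");
    // Split from the right so only the last two ':' act as separators.
    let mut parts = s.rsplitn(3, ':');
    let column = parts.next().ok_or_else(invalid)?;
    let line = parts.next().ok_or_else(invalid)?;
    let path = parts.next().ok_or_else(invalid)?;
    let line: u32 = line
        .parse()
        .map_err(|_| format!("Invalid line number: '{line}'"))?;
    let column: u32 = column
        .parse()
        .map_err(|_| format!("Invalid column number: '{column}'"))?;
    Ok(Location {
        path: path.to_string(),
        // Convert from 1-based to 0-based indexing.
        row: line.saturating_sub(1),
        column: column.saturating_sub(1),
    })
}

/// Inverse of `parse_location`: emit 1-based line and column.
pub fn format_location(location: &Location) -> String {
    format!("{}:{}:{}", location.path, location.row + 1, location.column + 1)
}
```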

@@ -46,3 +46,7 @@ Output example:
## Code Context
{{context}}
## Editable region
{{editable_region}}

@@ -1,89 +0,0 @@
use std::path::Path;
use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
#[derive(Debug, Clone, Default, clap::ValueEnum)]
pub enum ContextType {
#[default]
CurrentFile,
}
const MAX_CONTEXT_SIZE: usize = 32768;
pub fn collect_context(
context_type: &ContextType,
worktree_dir: &Path,
cursor: SourceLocation,
) -> String {
let context = match context_type {
ContextType::CurrentFile => {
let file_path = worktree_dir.join(cursor.path.as_std_path());
let context = std::fs::read_to_string(&file_path).unwrap_or_default();
let context = add_special_tags(&context, worktree_dir, cursor);
context
}
};
let region_end_offset = context.find(TeacherModel::REGION_END);
if context.len() <= MAX_CONTEXT_SIZE {
return context;
}
if let Some(region_end_offset) = region_end_offset
&& region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
{
let to_truncate = context.len() - MAX_CONTEXT_SIZE;
format!(
"[...{} bytes truncated]\n{}\n",
to_truncate,
&context[to_truncate..]
)
} else {
format!(
"{}\n[...{} bytes truncated]\n",
&context[..MAX_CONTEXT_SIZE],
context.len() - MAX_CONTEXT_SIZE
)
}
}
/// Add <|editable_region_start/end|> tags
fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
let path = worktree_dir.join(cursor.path.as_std_path());
let file = std::fs::read_to_string(&path).unwrap_or_default();
let lines = file.lines().collect::<Vec<_>>();
let cursor_row = cursor.point.row as usize;
let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
let snippet = lines[start_line..end_line].join("\n");
if context.contains(&snippet) {
let mut cursor_line = lines[cursor_row].to_string();
cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
let mut snippet_with_tags_lines = vec![];
snippet_with_tags_lines.push(TeacherModel::REGION_START);
snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
snippet_with_tags_lines.push(&cursor_line);
snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
snippet_with_tags_lines.push(TeacherModel::REGION_END);
let snippet_with_tags = snippet_with_tags_lines.join("\n");
context.replace(&snippet, &snippet_with_tags)
} else {
log::warn!(
"Can't find area around the cursor in the context; proceeding without special tags"
);
context.to_string()
}
}
pub fn strip_special_tags(context: &str) -> String {
context
.replace(TeacherModel::REGION_START, "")
.replace(TeacherModel::REGION_END, "")
.replace(TeacherModel::USER_CURSOR, "")
}
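`collect_context` caps the context at `MAX_CONTEXT_SIZE` bytes. It drops the head when keeping it would cut off the end of the editable region, and drops the tail otherwise. The sketch below mirrors that heuristic in a standalone function. `truncate_context` and its parameters are hypothetical names. Snapping the cut to a UTF-8 character boundary is an addition that the original byte slicing does not make.

```rust
/// Keep roughly `max_bytes` of `context`, choosing the end that preserves `marker`.
pub fn truncate_context(context: &str, marker: &str, max_bytes: usize) -> String {
    if context.len() <= max_bytes {
        return context.to_string();
    }
    // Byte offset just past the first occurrence of the marker, if any.
    let marker_end = context.find(marker).map(|pos| pos + marker.len());
    if marker_end.is_some_and(|end| end > max_bytes) {
        // Keeping the head would cut the marker off, so keep the tail instead.
        let mut start = context.len() - max_bytes;
        while !context.is_char_boundary(start) {
            start += 1;
        }
        format!("[...{} bytes truncated]\n{}\n", start, &context[start..])
    } else {
        // The marker fits in the head (or is absent), so drop the tail.
        let mut end = max_bytes;
        while !context.is_char_boundary(end) {
            end -= 1;
        }
        format!(
            "{}\n[...{} bytes truncated]\n",
            &context[..end],
            context.len() - end
        )
    }
}
```

As in the original, keeping the tail does not guarantee that the whole region survives. It only avoids cutting the region's end marker.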

@@ -1,94 +0,0 @@
use serde::Deserialize;
use std::sync::Arc;
use crate::{
DistillArguments,
example::Example,
source_location::SourceLocation,
training::{
context::ContextType,
llm_client::LlmClient,
teacher::{TeacherModel, TeacherOutput},
},
};
use anyhow::Result;
use reqwest_client::ReqwestClient;
#[derive(Debug, Deserialize)]
pub struct SplitCommit {
repo_url: String,
commit_sha: String,
edit_history: String,
expected_patch: String,
cursor_position: String,
}
pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
.expect("Failed to read split commit dataset")
.lines()
.map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
.collect();
let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
let llm_client = if let Some(cache_path) = arguments.batch {
LlmClient::batch(&cache_path, http_client)?
} else {
LlmClient::plain(http_client)?
};
let mut teacher = TeacherModel::new(
"claude-sonnet-4-5".to_string(),
ContextType::CurrentFile,
llm_client,
);
let mut num_marked_for_batching = 0;
for commit in split_commits {
if let Some(distilled) = distill_one(&mut teacher, commit).await? {
println!("{}", serde_json::to_string(&distilled)?);
} else {
if num_marked_for_batching == 0 {
log::warn!("Marked for batching");
}
num_marked_for_batching += 1;
}
}
eprintln!(
"{} requests are marked for batching",
num_marked_for_batching
);
let llm_client = teacher.client;
llm_client.sync_batches().await?;
Ok(())
}
pub async fn distill_one(
teacher: &mut TeacherModel,
commit: SplitCommit,
) -> Result<Option<TeacherOutput>> {
let cursor: SourceLocation = commit
.cursor_position
.parse()
.expect("Failed to parse cursor position");
let path = cursor.path.to_rel_path_buf();
let example = Example {
repository_url: commit.repo_url,
revision: commit.commit_sha,
uncommitted_diff: commit.edit_history.clone(),
cursor_path: path.as_std_path().to_path_buf(),
cursor_position: commit.cursor_position,
edit_history: commit.edit_history, // todo: trim
expected_patch: commit.expected_patch,
};
let prediction = teacher.predict(example).await;
prediction
}

@@ -1,4 +0,0 @@
pub mod context;
pub mod distill;
pub mod llm_client;
pub mod teacher;

@@ -1,266 +0,0 @@
use crate::{
example::Example,
source_location::SourceLocation,
training::{
context::{ContextType, collect_context, strip_special_tags},
llm_client::LlmClient,
},
};
use anthropic::{Message, RequestContent, ResponseContent, Role};
use anyhow::Result;
pub struct TeacherModel {
pub llm_name: String,
pub context: ContextType,
pub client: LlmClient,
}
#[derive(Debug, serde::Serialize)]
pub struct TeacherOutput {
parsed_output: String,
prompt: String,
raw_llm_response: String,
context: String,
diff: String,
}
impl TeacherModel {
const PROMPT: &str = include_str!("teacher.prompt.md");
pub(crate) const REGION_START: &str = "<|editable_region_start|>\n";
pub(crate) const REGION_END: &str = "<|editable_region_end|>";
pub(crate) const USER_CURSOR: &str = "<|user_cursor|>";
/// Number of lines to include before the cursor position
pub(crate) const LEFT_CONTEXT_SIZE: usize = 5;
/// Number of lines to include after the cursor position
pub(crate) const RIGHT_CONTEXT_SIZE: usize = 5;
/// Truncate edit history to this number of last lines
const MAX_HISTORY_LINES: usize = 128;
pub fn new(llm_name: String, context: ContextType, client: LlmClient) -> Self {
TeacherModel {
llm_name,
context,
client,
}
}
pub async fn predict(&self, input: Example) -> Result<Option<TeacherOutput>> {
let name = input.unique_name();
let worktree_dir = input.setup_worktree(name).await?;
let cursor: SourceLocation = input
.cursor_position
.parse()
.expect("Failed to parse cursor position");
let context = collect_context(&self.context, &worktree_dir, cursor.clone());
let edit_history = Self::format_edit_history(&input.edit_history);
let prompt = Self::PROMPT
.replace("{{context}}", &context)
.replace("{{edit_history}}", &edit_history);
let messages = vec![Message {
role: Role::User,
content: vec![RequestContent::Text {
text: prompt.clone(),
cache_control: None,
}],
}];
let Some(response) = self
.client
.generate(self.llm_name.clone(), 16384, messages)
.await?
else {
return Ok(None);
};
let response_text = response
.content
.into_iter()
.filter_map(|content| match content {
ResponseContent::Text { text } => Some(text),
_ => None,
})
.collect::<Vec<String>>()
.join("\n");
let parsed_output = self.parse_response(&response_text);
let original_editable_region = Self::extract_editable_region(&context);
let context_after_edit = context.replace(&original_editable_region, &parsed_output);
let context_after_edit = strip_special_tags(&context_after_edit);
let context_before_edit = strip_special_tags(&context);
let diff = language::unified_diff(&context_before_edit, &context_after_edit);
// zeta distill --batch batch_results.txt
// zeta distill
// 1. Run `zeta distill <2000 examples <- all examples>` for the first time
// - store LLM requests in a batch, don't actually send the request
// - send the batch (2000 requests) after all inputs are processed
// 2. `zeta send-batches`
// - upload the batch to Anthropic
// https://platform.claude.com/docs/en/build-with-claude/batch-processing
// https://crates.io/crates/anthropic-sdk-rust
// - poll for results
// - when ready, store results in cache (a database)
// 3. `zeta distill` again
// - use the cached results this time
Ok(Some(TeacherOutput {
parsed_output,
prompt,
raw_llm_response: response_text,
context,
diff,
}))
}
fn parse_response(&self, content: &str) -> String {
let codeblock = Self::extract_last_codeblock(content);
let editable_region = Self::extract_editable_region(&codeblock);
editable_region
}
/// Extract content from the last code-fenced block if any, or else return content as is
fn extract_last_codeblock(text: &str) -> String {
let mut last_block = None;
let mut search_start = 0;
while let Some(start) = text[search_start..].find("```") {
let start = start + search_start;
let bytes = text.as_bytes();
let mut backtick_end = start;
while backtick_end < bytes.len() && bytes[backtick_end] == b'`' {
backtick_end += 1;
}
let backtick_count = backtick_end - start;
let closing_backticks = "`".repeat(backtick_count);
if let Some(end_pos) = text[backtick_end..].find(&closing_backticks) {
let code_block = &text[backtick_end + 1..backtick_end + end_pos - 1];
last_block = Some(code_block.to_string());
search_start = backtick_end + end_pos + backtick_count;
} else {
break;
}
}
last_block.unwrap_or_else(|| text.to_string())
}
fn extract_editable_region(text: &str) -> String {
let start = text
.find(Self::REGION_START)
.map_or(0, |pos| pos + Self::REGION_START.len());
let end = text.find(Self::REGION_END).unwrap_or(text.len());
text[start..end].to_string()
}
/// Truncates edit history to a maximum length and removes comments (unified diff garbage lines)
fn format_edit_history(edit_history: &str) -> String {
let lines = edit_history
.lines()
.filter(|&s| Self::is_content_line(s))
.collect::<Vec<_>>();
let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
&lines[lines.len() - Self::MAX_HISTORY_LINES..]
} else {
&lines
};
history_lines.join("\n")
}
fn is_content_line(s: &str) -> bool {
s.starts_with("-")
|| s.starts_with("+")
|| s.starts_with(" ")
|| s.starts_with("---")
|| s.starts_with("+++")
|| s.starts_with("@@")
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_response() {
let teacher = TeacherModel::new(
"test".to_string(),
ContextType::CurrentFile,
LlmClient::dummy(),
);
let response = "This is a test response.";
let parsed = teacher.parse_response(response);
assert_eq!(parsed, response.to_string());
let response = indoc::indoc! {"
Some thinking
`````
actual response
`````
"};
let parsed = teacher.parse_response(response);
assert_eq!(parsed, "actual response");
}
#[test]
fn test_extract_last_code_block() {
let text = indoc::indoc! {"
Some thinking
```
first block
```
`````
last block
`````
"};
let last_block = TeacherModel::extract_last_codeblock(text);
assert_eq!(last_block, "last block");
}
#[test]
fn test_extract_editable_region() {
let teacher = TeacherModel::new(
"test".to_string(),
ContextType::CurrentFile,
LlmClient::dummy(),
);
let response = indoc::indoc! {"
some lines
are
here
<|editable_region_start|>
one
two three
<|editable_region_end|>
more
lines here
"};
let parsed = teacher.parse_response(response);
assert_eq!(
parsed,
indoc::indoc! {"
one
two three
"}
);
}
}
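The removed teacher parses a response in two steps: it takes the last fenced code block, then keeps only the text between the editable-region tags. The sketch below is a line-based variation on both steps, not the original byte-offset scan. `last_fenced_block` and `editable_region` are hypothetical free functions. Accepting a closing fence at least as long as the opening one, and searching for the end tag only after the start tag, are assumptions that differ slightly from the code above.

```rust
/// Return the body of the last fenced block, or the whole text if there is none.
pub fn last_fenced_block(text: &str) -> String {
    let mut last = None;
    // While inside a block: (length of the opening fence, lines collected so far).
    let mut open: Option<(usize, Vec<&str>)> = None;
    for line in text.lines() {
        let trimmed = line.trim_start();
        let ticks = trimmed.chars().take_while(|&c| c == '`').count();
        match open.take() {
            // A bare run of backticks at least as long as the opener closes the block.
            Some((len, body)) if ticks >= len && trimmed[ticks..].trim().is_empty() => {
                last = Some(body.join("\n"));
            }
            // Any other line inside a block belongs to its body.
            Some((len, mut body)) => {
                body.push(line);
                open = Some((len, body));
            }
            // Three or more backticks outside a block open a new one.
            None if ticks >= 3 => open = Some((ticks, Vec::new())),
            None => {}
        }
    }
    last.unwrap_or_else(|| text.to_string())
}

/// Return the text between the region tags; missing tags fall back to the text bounds.
pub fn editable_region(text: &str) -> &str {
    const REGION_START: &str = "<|editable_region_start|>\n";
    const REGION_END: &str = "<|editable_region_end|>";
    let start = text
        .find(REGION_START)
        .map_or(0, |pos| pos + REGION_START.len());
    // Search after `start` so an early end tag cannot produce an inverted range.
    let end = text[start..]
        .find(REGION_END)
        .map_or(text.len(), |pos| start + pos);
    &text[start..end]
}
```

Chaining the two, `editable_region(&last_fenced_block(response))`, reproduces the shape of `parse_response`.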

@@ -26,6 +26,7 @@ serde.workspace = true
smallvec.workspace = true
tree-sitter.workspace = true
util.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
env_logger.workspace = true

@@ -1,6 +1,6 @@
use crate::RelatedExcerpt;
use language::{BufferSnapshot, OffsetRangeExt as _, Point};
use std::ops::Range;
use zeta_prompt::RelatedExcerpt;
#[cfg(not(test))]
const MAX_OUTLINE_ITEM_BODY_SIZE: usize = 512;
@@ -76,14 +76,9 @@ pub fn assemble_excerpts(
input_ranges
.into_iter()
.map(|range| {
let offset_range = range.to_offset(buffer);
RelatedExcerpt {
point_range: range,
anchor_range: buffer.anchor_before(offset_range.start)
..buffer.anchor_after(offset_range.end),
text: buffer.as_rope().slice(offset_range),
}
.map(|range| RelatedExcerpt {
row_range: range.start.row..range.end.row,
text: buffer.text_for_range(range).collect(),
})
.collect()
}

@@ -3,13 +3,13 @@ use anyhow::Result;
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
use project::{LocationLink, Project, ProjectPath};
use serde::{Serialize, Serializer};
use smallvec::SmallVec;
use std::{
collections::hash_map,
ops::Range,
path::Path,
sync::Arc,
time::{Duration, Instant},
};
@@ -24,12 +24,14 @@ mod fake_definition_lsp;
pub use cloud_llm_client::predict_edits_v3::Line;
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
const IDENTIFIER_LINE_COUNT: u32 = 3;
pub struct RelatedExcerptStore {
project: WeakEntity<Project>,
related_files: Vec<RelatedFile>,
related_files: Arc<[RelatedFile]>,
related_file_buffers: Vec<Entity<Buffer>>,
cache: HashMap<Identifier, Arc<CacheEntry>>,
update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
identifier_line_count: u32,
@@ -68,82 +70,6 @@ struct CachedDefinition {
anchor_range: Range<Anchor>,
}
#[derive(Clone, Debug, Serialize)]
pub struct RelatedFile {
#[serde(serialize_with = "serialize_project_path")]
pub path: ProjectPath,
#[serde(skip)]
pub buffer: WeakEntity<Buffer>,
pub excerpts: Vec<RelatedExcerpt>,
pub max_row: u32,
}
impl RelatedFile {
pub fn merge_excerpts(&mut self) {
self.excerpts.sort_unstable_by(|a, b| {
a.point_range
.start
.cmp(&b.point_range.start)
.then(b.point_range.end.cmp(&a.point_range.end))
});
let mut index = 1;
while index < self.excerpts.len() {
if self.excerpts[index - 1]
.point_range
.end
.cmp(&self.excerpts[index].point_range.start)
.is_ge()
{
let removed = self.excerpts.remove(index);
if removed
.point_range
.end
.cmp(&self.excerpts[index - 1].point_range.end)
.is_gt()
{
self.excerpts[index - 1].point_range.end = removed.point_range.end;
self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
}
} else {
index += 1;
}
}
}
}
#[derive(Clone, Debug, Serialize)]
pub struct RelatedExcerpt {
#[serde(skip)]
pub anchor_range: Range<Anchor>,
#[serde(serialize_with = "serialize_point_range")]
pub point_range: Range<Point>,
#[serde(serialize_with = "serialize_rope")]
pub text: Rope,
}
fn serialize_project_path<S: Serializer>(
project_path: &ProjectPath,
serializer: S,
) -> Result<S::Ok, S::Error> {
project_path.path.serialize(serializer)
}
fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
rope.to_string().serialize(serializer)
}
fn serialize_point_range<S: Serializer>(
range: &Range<Point>,
serializer: S,
) -> Result<S::Ok, S::Error> {
[
[range.start.row, range.start.column],
[range.end.row, range.end.column],
]
.serialize(serializer)
}
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
@@ -179,7 +105,8 @@ impl RelatedExcerptStore {
RelatedExcerptStore {
project: project.downgrade(),
update_tx,
related_files: Vec::new(),
related_files: Vec::new().into(),
related_file_buffers: Vec::new(),
cache: Default::default(),
identifier_line_count: IDENTIFIER_LINE_COUNT,
}
@@ -193,8 +120,21 @@ impl RelatedExcerptStore {
self.update_tx.unbounded_send((buffer, position)).ok();
}
pub fn related_files(&self) -> &[RelatedFile] {
&self.related_files
pub fn related_files(&self) -> Arc<[RelatedFile]> {
self.related_files.clone()
}
pub fn related_files_with_buffers(
&self,
) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
self.related_files
.iter()
.cloned()
.zip(self.related_file_buffers.iter().cloned())
}
pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
self.related_files = files.into();
}
async fn fetch_excerpts(
@@ -297,7 +237,8 @@ impl RelatedExcerptStore {
}
mean_definition_latency /= cache_miss_count.max(1) as u32;
let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;
let (new_cache, related_files, related_file_buffers) =
rebuild_related_files(&project, new_cache, cx).await?;
if let Some(file) = &file {
log::debug!(
@@ -309,7 +250,8 @@ impl RelatedExcerptStore {
this.update(cx, |this, cx| {
this.cache = new_cache;
this.related_files = related_files;
this.related_files = related_files.into();
this.related_file_buffers = related_file_buffers;
cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
cache_hit_count,
cache_miss_count,
@@ -323,10 +265,16 @@ impl RelatedExcerptStore {
}
async fn rebuild_related_files(
project: &Entity<Project>,
new_entries: HashMap<Identifier, Arc<CacheEntry>>,
cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
) -> Result<(
HashMap<Identifier, Arc<CacheEntry>>,
Vec<RelatedFile>,
Vec<Entity<Buffer>>,
)> {
let mut snapshots = HashMap::default();
let mut worktree_root_names = HashMap::default();
for entry in new_entries.values() {
for definition in &entry.definitions {
if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
@@ -340,12 +288,22 @@ async fn rebuild_related_files(
.read_with(cx, |buffer, _| buffer.snapshot())?,
);
}
let worktree_id = definition.path.worktree_id;
if let hash_map::Entry::Vacant(e) =
worktree_root_names.entry(definition.path.worktree_id)
{
project.read_with(cx, |project, cx| {
if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
}
})?;
}
}
}
Ok(cx
.background_spawn(async move {
let mut files = Vec::<RelatedFile>::new();
let mut files = Vec::new();
let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
let mut paths_by_buffer = HashMap::default();
for entry in new_entries.values() {
@@ -369,16 +327,31 @@ async fn rebuild_related_files(
continue;
};
let excerpts = assemble_excerpts(snapshot, ranges);
files.push(RelatedFile {
path: project_path.clone(),
buffer: buffer.downgrade(),
excerpts,
max_row: snapshot.max_point().row,
});
let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
continue;
};
let path = Path::new(&format!(
"{}/{}",
root_name,
project_path.path.as_unix_str()
))
.into();
files.push((
buffer,
RelatedFile {
path,
excerpts,
max_row: snapshot.max_point().row,
},
));
}
files.sort_by_key(|file| file.path.clone());
(new_entries, files)
files.sort_by_key(|(_, file)| file.path.clone());
let (related_buffers, related_files) = files.into_iter().unzip();
(new_entries, related_files, related_buffers)
})
.await)
}


@@ -48,7 +48,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
&excerpts,
&[
(
"src/company.rs",
"root/src/company.rs",
&[indoc! {"
pub struct Company {
owner: Arc<Person>,
@@ -56,7 +56,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
}"}],
),
(
"src/main.rs",
"root/src/main.rs",
&[
indoc! {"
pub struct Session {
@@ -71,7 +71,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
],
),
(
"src/person.rs",
"root/src/person.rs",
&[
indoc! {"
impl Person {
@@ -446,7 +446,7 @@ fn assert_related_files(actual_files: &[RelatedFile], expected_files: &[(&str, &
.iter()
.map(|excerpt| excerpt.text.to_string())
.collect::<Vec<_>>();
(file.path.path.as_unix_str(), excerpts)
(file.path.to_str().unwrap(), excerpts)
})
.collect::<Vec<_>>();
let expected_excerpts = expected_files
@@ -492,10 +492,10 @@ fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
if excerpt.text.is_empty() {
continue;
}
if current_row < excerpt.point_range.start.row {
if current_row < excerpt.row_range.start {
writeln!(&mut output, "").unwrap();
}
current_row = excerpt.point_range.start.row;
current_row = excerpt.row_range.start;
for line in excerpt.text.to_string().lines() {
output.push_str(line);


@@ -17,7 +17,6 @@ anyhow.workspace = true
buffer_diff.workspace = true
client.workspace = true
cloud_llm_client.workspace = true
cloud_zeta2_prompt.workspace = true
codestral.workspace = true
command_palette_hooks.workspace = true
copilot.workspace = true
@@ -46,6 +45,7 @@ ui_input.workspace = true
util.workspace = true
workspace.workspace = true
zed_actions.workspace = true
zeta_prompt.workspace = true
[dev-dependencies]
copilot = { workspace = true, features = ["test-support"] }


@@ -17,7 +17,7 @@ use gpui::{
};
use multi_buffer::MultiBuffer;
use project::Project;
use text::OffsetRangeExt;
use text::Point;
use ui::{
ButtonCommon, Clickable, Disableable, FluentBuilder as _, IconButton, IconName,
StyledTypography as _, h_flex, v_flex,
@@ -66,7 +66,7 @@ impl EditPredictionContextView {
) -> Self {
let store = EditPredictionStore::global(client, user_store, cx);
let mut debug_rx = store.update(cx, |store, _| store.debug_info());
let mut debug_rx = store.update(cx, |store, cx| store.debug_info(&project, cx));
let _update_task = cx.spawn_in(window, async move |this, cx| {
while let Some(event) = debug_rx.next().await {
this.update_in(cx, |this, window, cx| {
@@ -103,7 +103,8 @@ impl EditPredictionContextView {
self.handle_context_retrieval_finished(info, window, cx);
}
}
DebugEvent::EditPredictionRequested(_) => {}
DebugEvent::EditPredictionStarted(_) => {}
DebugEvent::EditPredictionFinished(_) => {}
}
}
@@ -152,12 +153,11 @@ impl EditPredictionContextView {
run.finished_at = Some(info.timestamp);
run.metadata = info.metadata;
let project = self.project.clone();
let related_files = self
.store
.read(cx)
.context_for_project(&self.project, cx)
.to_vec();
.context_for_project_with_buffers(&self.project, cx)
.map_or(Vec::new(), |files| files.collect());
let editor = run.editor.clone();
let multibuffer = run.editor.read(cx).buffer().clone();
@@ -168,33 +168,14 @@ impl EditPredictionContextView {
cx.spawn_in(window, async move |this, cx| {
let mut paths = Vec::new();
for related_file in related_files {
let (buffer, point_ranges): (_, Vec<_>) =
if let Some(buffer) = related_file.buffer.upgrade() {
let snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot())?;
(
buffer,
related_file
.excerpts
.iter()
.map(|excerpt| excerpt.anchor_range.to_point(&snapshot))
.collect(),
)
} else {
(
project
.update(cx, |project, cx| {
project.open_buffer(related_file.path.clone(), cx)
})?
.await?,
related_file
.excerpts
.iter()
.map(|excerpt| excerpt.point_range.clone())
.collect(),
)
};
for (related_file, buffer) in related_files {
let point_ranges = related_file
.excerpts
.iter()
.map(|excerpt| {
Point::new(excerpt.row_range.start, 0)..Point::new(excerpt.row_range.end, 0)
})
.collect::<Vec<_>>();
cx.update(|_, cx| {
let path = PathKey::for_buffer(&buffer, cx);
paths.push((path, buffer, point_ranges));


@@ -1,5 +1,4 @@
use buffer_diff::{BufferDiff, BufferDiffSnapshot};
use cloud_zeta2_prompt::write_codeblock;
use edit_prediction::{EditPrediction, EditPredictionRating, EditPredictionStore};
use editor::{Editor, ExcerptRange, MultiBuffer};
use feature_flags::FeatureFlag;
@@ -362,14 +361,14 @@ impl RatePredictionsModal {
write!(&mut formatted_inputs, "## Events\n\n").unwrap();
for event in &prediction.inputs.events {
write!(&mut formatted_inputs, "```diff\n{event}```\n\n").unwrap();
formatted_inputs.push_str("```diff\n");
zeta_prompt::write_event(&mut formatted_inputs, event.as_ref());
formatted_inputs.push_str("```\n\n");
}
write!(&mut formatted_inputs, "## Included files\n\n").unwrap();
for included_file in &prediction.inputs.included_files {
let cursor_insertions = &[(prediction.inputs.cursor_point, "<|CURSOR|>")];
write!(&mut formatted_inputs, "## Related files\n\n").unwrap();
for included_file in prediction.inputs.related_files.as_ref() {
write!(
&mut formatted_inputs,
"### {}\n\n",
@@ -377,20 +376,28 @@ impl RatePredictionsModal {
)
.unwrap();
write_codeblock(
&included_file.path,
&included_file.excerpts,
if included_file.path == prediction.inputs.cursor_path {
cursor_insertions.as_slice()
} else {
&[]
},
included_file.max_row,
false,
&mut formatted_inputs,
);
for excerpt in included_file.excerpts.iter() {
write!(
&mut formatted_inputs,
"```{}\n{}\n```\n",
included_file.path.display(),
excerpt.text
)
.unwrap();
}
}
write!(&mut formatted_inputs, "## Cursor Excerpt\n\n").unwrap();
writeln!(
&mut formatted_inputs,
"```{}\n{}<CURSOR>{}\n```\n",
prediction.inputs.cursor_path.display(),
&prediction.inputs.cursor_excerpt[..prediction.inputs.cursor_offset_in_excerpt],
&prediction.inputs.cursor_excerpt[prediction.inputs.cursor_offset_in_excerpt..],
)
.unwrap();
self.active_prediction = Some(ActivePrediction {
prediction,
feedback_editor: cx.new(|cx| {


@@ -20973,9 +20973,22 @@ impl Editor {
buffer_ranges.last()
}?;
let selection = text::ToPoint::to_point(&range.start, buffer).row
..text::ToPoint::to_point(&range.end, buffer).row;
Some((multi_buffer.buffer(buffer.remote_id()).unwrap(), selection))
let start_row_in_buffer = text::ToPoint::to_point(&range.start, buffer).row;
let end_row_in_buffer = text::ToPoint::to_point(&range.end, buffer).row;
let Some(buffer_diff) = multi_buffer.diff_for(buffer.remote_id()) else {
let selection = start_row_in_buffer..end_row_in_buffer;
return Some((multi_buffer.buffer(buffer.remote_id()).unwrap(), selection));
};
let buffer_diff_snapshot = buffer_diff.read(cx).snapshot(cx);
Some((
multi_buffer.buffer(buffer.remote_id()).unwrap(),
buffer_diff_snapshot.row_to_base_text_row(start_row_in_buffer, buffer)
..buffer_diff_snapshot.row_to_base_text_row(end_row_in_buffer, buffer),
))
});
let Some((buffer, selection)) = buffer_and_selection else {


@@ -27701,6 +27701,7 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
cx.update_editor(|editor, window, cx| {
editor.handle_input("x", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
- [ ] Item 1
- [ ] Item 1.a
@@ -27716,8 +27717,7 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
- [ ] Item 1.a
- [x] Item 2
- [x] Item 2.a
- [x] Item 2.bˇ
"
- [x] Item 2.bˇ"
});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
@@ -27728,34 +27728,41 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
- [x] Item 2
- [x] Item 2.a
- [x] Item 2.b
ˇ
"
ˇ"
});
// Case 3: Test adding a new nested list item preserves indent
cx.set_state(&indoc! {"
- [ ] Item 1
- [ ] Item 1.a
- [x] Item 2
- [x] Item 2.a
- [x] Item 2.b
ˇ"
});
cx.update_editor(|editor, window, cx| {
editor.handle_input("-", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
- [ ] Item 1
- [ ] Item 1.a
- [x] Item 2
- [x] Item 2.a
- [x] Item 2.b
"
"
});
cx.update_editor(|editor, window, cx| {
editor.handle_input(" [x] Item 2.c", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
- [ ] Item 1
- [ ] Item 1.a
- [x] Item 2
- [x] Item 2.a
- [x] Item 2.b
- [x] Item 2.cˇ
"
- [x] Item 2.cˇ"
});
// Case 4: Test adding new line after nested ordered list preserves indent of previous line
@@ -27764,8 +27771,7 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
1. Item 1.a
2. Item 2
1. Item 2.a
2. Item 2.bˇ
"
2. Item 2.bˇ"
});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
@@ -27776,60 +27782,81 @@ async fn test_markdown_indents(cx: &mut gpui::TestAppContext) {
2. Item 2
1. Item 2.a
2. Item 2.b
ˇ
"
ˇ"
});
// Case 5: Adding new ordered list item preserves indent
cx.set_state(indoc! {"
1. Item 1
1. Item 1.a
2. Item 2
1. Item 2.a
2. Item 2.b
ˇ"
});
cx.update_editor(|editor, window, cx| {
editor.handle_input("3", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
1. Item 1
1. Item 1.a
2. Item 2
1. Item 2.a
2. Item 2.b
"
"
});
cx.update_editor(|editor, window, cx| {
editor.handle_input(".", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
1. Item 1
1. Item 1.a
2. Item 2
1. Item 2.a
2. Item 2.b
3.ˇ
"
3.ˇ"
});
cx.update_editor(|editor, window, cx| {
editor.handle_input(" Item 2.c", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
1. Item 1
1. Item 1.a
2. Item 2
1. Item 2.a
2. Item 2.b
3. Item 2.cˇ
"
3. Item 2.cˇ"
});
// Case 6: Test adding new line after nested ordered list preserves indent of previous line
cx.set_state(indoc! {"
- Item 1
- Item 1.a
- Item 1.a
ˇ"});
cx.update_editor(|editor, window, cx| {
editor.handle_input("-", window, cx);
});
cx.run_until_parked();
cx.assert_editor_state(indoc! {"
- Item 1
- Item 1.a
- Item 1.a
"});
// Case 7: Test blockquote newline preserves something
cx.set_state(indoc! {"
> Item 1ˇ
"
> Item 1ˇ"
});
cx.update_editor(|editor, window, cx| {
editor.newline(&Newline, window, cx);
});
cx.assert_editor_state(indoc! {"
> Item 1
ˇ
"
ˇ"
});
}


@@ -18,6 +18,7 @@ impl FeatureFlag for InlineAssistantV2FeatureFlag {
const NAME: &'static str = "inline-assistant-v2";
fn enabled_for_staff() -> bool {
false
// false
true
}
}


@@ -232,14 +232,12 @@ impl From<Oid> for usize {
#[derive(Copy, Clone, Debug)]
pub enum RunHook {
PreCommit,
PrePush,
}
impl RunHook {
pub fn as_str(&self) -> &str {
match self {
Self::PreCommit => "pre-commit",
Self::PrePush => "pre-push",
}
}
@@ -250,7 +248,6 @@ impl RunHook {
pub fn from_proto(value: i32) -> Option<Self> {
match value {
0 => Some(Self::PreCommit),
1 => Some(Self::PrePush),
_ => None,
}
}


@@ -652,6 +652,7 @@ pub struct RealGitRepository {
pub repository: Arc<Mutex<git2::Repository>>,
pub system_git_binary_path: Option<PathBuf>,
pub any_git_binary_path: PathBuf,
any_git_binary_help_output: Arc<Mutex<Option<SharedString>>>,
executor: BackgroundExecutor,
}
@@ -670,6 +671,7 @@ impl RealGitRepository {
system_git_binary_path,
any_git_binary_path,
executor,
any_git_binary_help_output: Arc::new(Mutex::new(None)),
})
}
@@ -680,6 +682,27 @@ impl RealGitRepository {
.context("failed to read git work directory")
.map(Path::to_path_buf)
}
async fn any_git_binary_help_output(&self) -> SharedString {
if let Some(output) = self.any_git_binary_help_output.lock().clone() {
return output;
}
let git_binary_path = self.any_git_binary_path.clone();
let executor = self.executor.clone();
let working_directory = self.working_directory();
let output: SharedString = self
.executor
.spawn(async move {
GitBinary::new(git_binary_path, working_directory?, executor)
.run(["help", "-a"])
.await
})
.await
.unwrap_or_default()
.into();
*self.any_git_binary_help_output.lock() = Some(output.clone());
output
}
}
#[derive(Clone, Debug)]
@@ -2290,48 +2313,47 @@ impl GitRepository for RealGitRepository {
env: Arc<HashMap<String, String>>,
) -> BoxFuture<'_, Result<()>> {
let working_directory = self.working_directory();
let repository = self.repository.clone();
let git_binary_path = self.any_git_binary_path.clone();
let executor = self.executor.clone();
self.executor
.spawn(async move {
let working_directory = working_directory?;
let git = GitBinary::new(git_binary_path, working_directory.clone(), executor)
.envs(HashMap::clone(&env));
let help_output = self.any_git_binary_help_output();
let output = git.run(&["help", "-a"]).await?;
if !output.lines().any(|line| line.trim().starts_with("hook ")) {
log::warn!(
"git hook command not available, running the {} hook manually",
hook.as_str()
);
// Note: Do not spawn these commands on the background thread, as this causes some git hooks to hang.
async move {
let working_directory = working_directory?;
if !help_output
.await
.lines()
.any(|line| line.trim().starts_with("hook "))
{
let hook_abs_path = repository.lock().path().join("hooks").join(hook.as_str());
if hook_abs_path.is_file() {
let output = new_smol_command(&hook_abs_path)
.envs(env.iter())
.current_dir(&working_directory)
.output()
.await?;
let hook_abs_path = working_directory
.join(".git")
.join("hooks")
.join(hook.as_str());
if hook_abs_path.is_file() {
let output = new_smol_command(&hook_abs_path)
.envs(env.iter())
.current_dir(&working_directory)
.output()
.await?;
anyhow::ensure!(
output.status.success(),
"{} hook failed:\n{}",
hook.as_str(),
String::from_utf8_lossy(&output.stderr)
);
if !output.status.success() {
return Err(GitBinaryCommandError {
stdout: String::from_utf8_lossy(&output.stdout).into_owned(),
stderr: String::from_utf8_lossy(&output.stderr).into_owned(),
status: output.status,
}
.into());
}
return Ok(());
}
git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
.await?;
Ok(())
})
.boxed()
return Ok(());
}
let git = GitBinary::new(git_binary_path, working_directory, executor)
.envs(HashMap::clone(&env));
git.run(&["hook", "run", "--ignore-missing", hook.as_str()])
.await?;
Ok(())
}
.boxed()
}
}
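The availability check in the hunk above scans the output of `git help -a` for a line that starts with `hook `. A minimal sketch of that predicate in isolation follows; the function name `supports_hook_command` is hypothetical and does not appear in the diff.

```rust
/// Returns true when the `git help -a` listing contains a `hook` command entry.
/// Hypothetical helper: it mirrors the inline check in the diff, nothing more.
pub fn supports_hook_command(help_output: &str) -> bool {
    help_output
        .lines()
        // Command entries are indented, so trim before comparing the prefix.
        .any(|line| line.trim().starts_with("hook "))
}
```

The trailing space in `"hook "` keeps a line that merely begins with the letters `hook` (for example a longer command name) from matching.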


@@ -1,6 +1,7 @@
use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant};
use async_task::Runnable;
use futures::channel::mpsc;
use parking_lot::{Condvar, Mutex};
use smol::prelude::*;
use std::{
fmt::Debug,
@@ -154,6 +155,57 @@ impl BackgroundExecutor {
self.spawn_internal::<R>(Box::pin(future), None)
}
/// Enqueues the given future to be run to completion on a background thread, blocking the current task until it finishes.
///
/// This allows spawning background work that borrows from the surrounding scope. Note that the supplied future will run to
/// completion before the current task is resumed, even if the current task is slated for cancellation.
pub async fn await_on_background<R>(&self, future: impl Future<Output = R> + Send) -> R
where
R: Send,
{
// We need to ensure that cancellation of the parent task does not drop the environment
// before our own task has completed or been cancelled.
struct NotifyOnDrop<'a>(&'a (Condvar, Mutex<bool>));
impl Drop for NotifyOnDrop<'_> {
fn drop(&mut self) {
*self.0.1.lock() = true;
self.0.0.notify_all();
}
}
struct WaitOnDrop<'a>(&'a (Condvar, Mutex<bool>));
impl Drop for WaitOnDrop<'_> {
fn drop(&mut self) {
let mut done = self.0.1.lock();
if !*done {
self.0.0.wait(&mut done);
}
}
}
let dispatcher = self.dispatcher.clone();
let location = core::panic::Location::caller();
let pair = &(Condvar::new(), Mutex::new(false));
let _wait_guard = WaitOnDrop(pair);
let (runnable, task) = unsafe {
async_task::Builder::new()
.metadata(RunnableMeta { location })
.spawn_unchecked(
move |_| async {
let _notify_guard = NotifyOnDrop(pair);
future.await
},
move |runnable| dispatcher.dispatch(RunnableVariant::Meta(runnable), None),
)
};
runnable.schedule();
task.await
}
/// Enqueues the given future to be run to completion on a background thread.
/// The given label can be used to control the priority of the task in tests.
#[track_caller]
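The `NotifyOnDrop` / `WaitOnDrop` pair in `await_on_background` above can be illustrated with the standard library alone. The sketch below is an assumption-laden simplification, not the gpui implementation: it uses `std::sync` instead of `parking_lot`, a plain thread instead of the dispatcher, and `Arc` with `'static` closures instead of borrowing from the caller's scope. The name `run_and_wait` is hypothetical.

```rust
use std::sync::{Arc, Condvar, Mutex};
use std::thread;

type Signal = Arc<(Mutex<bool>, Condvar)>;

/// Sets the flag and wakes all waiters when dropped.
struct NotifyOnDrop(Signal);

impl Drop for NotifyOnDrop {
    fn drop(&mut self) {
        let (done, condvar) = &*self.0;
        *done.lock().unwrap() = true;
        condvar.notify_all();
    }
}

/// Blocks inside `drop` until the flag has been set.
struct WaitOnDrop(Signal);

impl Drop for WaitOnDrop {
    fn drop(&mut self) {
        let (done, condvar) = &*self.0;
        let mut finished = done.lock().unwrap();
        // std's Condvar may wake spuriously, so re-check the flag in a loop.
        while !*finished {
            finished = condvar.wait(finished).unwrap();
        }
    }
}

/// Runs `work` on another thread and returns only after it has finished.
pub fn run_and_wait<R: Send + 'static>(work: impl FnOnce() -> R + Send + 'static) -> R {
    let signal: Signal = Arc::new((Mutex::new(false), Condvar::new()));
    let wait_guard = WaitOnDrop(signal.clone());
    let handle = thread::spawn(move || {
        // Dropped when the closure exits, whether normally or by unwinding.
        let _notify_guard = NotifyOnDrop(signal);
        work()
    });
    // Blocks until the worker has dropped its notify guard.
    drop(wait_guard);
    handle.join().expect("worker panicked")
}
```

Because the wait happens in a destructor, it also runs if the waiting side is torn down early, which is the property the diff relies on to keep borrowed data alive while the background work is still running.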


@@ -289,6 +289,13 @@ pub trait PlatformDisplay: Send + Sync + Debug {
/// Get the bounds for this display
fn bounds(&self) -> Bounds<Pixels>;
/// Get the visible bounds for this display, excluding taskbar/dock areas.
/// This is the usable area where windows can be placed without being obscured.
/// Defaults to the full display bounds if not overridden.
fn visible_bounds(&self) -> Bounds<Pixels> {
self.bounds()
}
/// Get the default bounds for this display to place a window
fn default_bounds(&self) -> Bounds<Pixels> {
let bounds = self.bounds();


@@ -1,9 +1,9 @@
use crate::{Bounds, DisplayId, Pixels, PlatformDisplay, px, size};
use crate::{Bounds, DisplayId, Pixels, PlatformDisplay, point, px, size};
use anyhow::Result;
use cocoa::{
appkit::NSScreen,
base::{id, nil},
foundation::{NSDictionary, NSString},
foundation::{NSArray, NSDictionary, NSString},
};
use core_foundation::uuid::{CFUUIDGetUUIDBytes, CFUUIDRef};
use core_graphics::display::{CGDirectDisplayID, CGDisplayBounds, CGGetActiveDisplayList};
@@ -114,4 +114,53 @@ impl PlatformDisplay for MacDisplay {
}
}
}
fn visible_bounds(&self) -> Bounds<Pixels> {
unsafe {
let dominated_screen = self.get_nsscreen();
if dominated_screen == nil {
return self.bounds();
}
let screen_frame = NSScreen::frame(dominated_screen);
let visible_frame = NSScreen::visibleFrame(dominated_screen);
// Convert from bottom-left origin (AppKit) to top-left origin
let origin_y =
screen_frame.size.height - visible_frame.origin.y - visible_frame.size.height
+ screen_frame.origin.y;
Bounds {
origin: point(
px(visible_frame.origin.x as f32 - screen_frame.origin.x as f32),
px(origin_y as f32),
),
size: size(
px(visible_frame.size.width as f32),
px(visible_frame.size.height as f32),
),
}
}
}
}
impl MacDisplay {
/// Find the NSScreen corresponding to this display
unsafe fn get_nsscreen(&self) -> id {
let screens = unsafe { NSScreen::screens(nil) };
let count = unsafe { NSArray::count(screens) };
let screen_number_key: id = unsafe { NSString::alloc(nil).init_str("NSScreenNumber") };
for i in 0..count {
let screen = unsafe { NSArray::objectAtIndex(screens, i) };
let device_description = unsafe { NSScreen::deviceDescription(screen) };
let screen_number = unsafe { device_description.objectForKey_(screen_number_key) };
let screen_id: CGDirectDisplayID = msg_send![screen_number, unsignedIntegerValue];
if screen_id == self.0 {
return screen;
}
}
nil
}
}
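The origin conversion in `visible_bounds` above can be checked with a worked example. The sketch below reproduces only the arithmetic from the diff on plain numbers; the `Rect` type and the `visible_top_left` function are hypothetical stand-ins for the AppKit and gpui types.

```rust
/// Plain rectangle used only for this illustration (not a gpui or AppKit type).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Rect {
    pub x: f64,
    pub y: f64,
    pub width: f64,
    pub height: f64,
}

/// Converts a visible frame given with a bottom-left origin into a rectangle
/// with a top-left origin, using the same terms as the diff.
pub fn visible_top_left(screen: Rect, visible: Rect) -> Rect {
    Rect {
        x: visible.x - screen.x,
        // Distance from the top edge of the screen to the top edge of the visible area.
        y: screen.height - visible.y - visible.height + screen.y,
        width: visible.width,
        height: visible.height,
    }
}
```

For an assumed 1440x900 screen at the origin with a 70-point dock at the bottom and a 25-point menu bar at the top, the visible frame is `(0, 70, 1440, 805)` in bottom-left coordinates, and the converted origin is `y = 900 - 70 - 805 + 0 = 25`, the height of the menu bar.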


@@ -23,6 +23,7 @@ pub(crate) struct WindowsDisplay {
pub display_id: DisplayId,
scale_factor: f32,
bounds: Bounds<Pixels>,
visible_bounds: Bounds<Pixels>,
physical_bounds: Bounds<DevicePixels>,
uuid: Uuid,
}
@@ -36,6 +37,7 @@ impl WindowsDisplay {
let screen = available_monitors().into_iter().nth(display_id.0 as _)?;
let info = get_monitor_info(screen).log_err()?;
let monitor_size = info.monitorInfo.rcMonitor;
let work_area = info.monitorInfo.rcWork;
let uuid = generate_uuid(&info.szDevice);
let scale_factor = get_scale_factor_for_monitor(screen).log_err()?;
let physical_size = size(
@@ -55,6 +57,14 @@ impl WindowsDisplay {
),
size: physical_size.to_pixels(scale_factor),
},
visible_bounds: Bounds {
origin: logical_point(work_area.left as f32, work_area.top as f32, scale_factor),
size: size(
(work_area.right - work_area.left) as f32 / scale_factor,
(work_area.bottom - work_area.top) as f32 / scale_factor,
)
.map(crate::px),
},
physical_bounds: Bounds {
origin: point(monitor_size.left.into(), monitor_size.top.into()),
size: physical_size,
@@ -66,6 +76,7 @@ impl WindowsDisplay {
pub fn new_with_handle(monitor: HMONITOR) -> anyhow::Result<Self> {
let info = get_monitor_info(monitor)?;
let monitor_size = info.monitorInfo.rcMonitor;
let work_area = info.monitorInfo.rcWork;
let uuid = generate_uuid(&info.szDevice);
let display_id = available_monitors()
.iter()
@@ -89,6 +100,14 @@ impl WindowsDisplay {
),
size: physical_size.to_pixels(scale_factor),
},
visible_bounds: Bounds {
origin: logical_point(work_area.left as f32, work_area.top as f32, scale_factor),
size: size(
(work_area.right - work_area.left) as f32 / scale_factor,
(work_area.bottom - work_area.top) as f32 / scale_factor,
)
.map(crate::px),
},
physical_bounds: Bounds {
origin: point(monitor_size.left.into(), monitor_size.top.into()),
size: physical_size,
@@ -100,6 +119,7 @@ impl WindowsDisplay {
fn new_with_handle_and_id(handle: HMONITOR, display_id: DisplayId) -> anyhow::Result<Self> {
let info = get_monitor_info(handle)?;
let monitor_size = info.monitorInfo.rcMonitor;
let work_area = info.monitorInfo.rcWork;
let uuid = generate_uuid(&info.szDevice);
let scale_factor = get_scale_factor_for_monitor(handle)?;
let physical_size = size(
@@ -119,6 +139,14 @@ impl WindowsDisplay {
),
size: physical_size.to_pixels(scale_factor),
},
visible_bounds: Bounds {
origin: logical_point(work_area.left as f32, work_area.top as f32, scale_factor),
size: size(
(work_area.right - work_area.left) as f32 / scale_factor,
(work_area.bottom - work_area.top) as f32 / scale_factor,
)
.map(crate::px),
},
physical_bounds: Bounds {
origin: point(monitor_size.left.into(), monitor_size.top.into()),
size: physical_size,
@@ -193,6 +221,10 @@ impl PlatformDisplay for WindowsDisplay {
fn bounds(&self) -> Bounds<Pixels> {
self.bounds
}
fn visible_bounds(&self) -> Bounds<Pixels> {
self.visible_bounds
}
}
fn available_monitors() -> SmallVec<[HMONITOR; 4]> {


@@ -918,86 +918,69 @@ pub(crate) struct ElementStateBox {
pub(crate) type_name: &'static str,
}
fn default_bounds(display_id: Option<DisplayId>, cx: &mut App) -> Bounds<Pixels> {
#[cfg(target_os = "macos")]
{
const CASCADE_OFFSET: f32 = 25.0;
fn default_bounds(display_id: Option<DisplayId>, cx: &mut App) -> WindowBounds {
// TODO, BUG: if you open a window with the currently active window
// on the stack, this will erroneously fall back to `None`
//
// TODO these should be the initial window bounds not considering maximized/fullscreen
let active_window_bounds = cx
.active_window()
.and_then(|w| w.update(cx, |_, window, _| window.window_bounds()).ok());
let display = display_id
.map(|id| cx.find_display(id))
.unwrap_or_else(|| cx.primary_display());
const CASCADE_OFFSET: f32 = 25.0;
let display_bounds = display
.as_ref()
.map(|d| d.bounds())
.unwrap_or_else(|| Bounds::new(point(px(0.), px(0.)), DEFAULT_WINDOW_SIZE));
let display = display_id
.map(|id| cx.find_display(id))
.unwrap_or_else(|| cx.primary_display());
// TODO, BUG: if you open a window with the currently active window
// on the stack, this will erroneously select the 'unwrap_or_else'
// code path
let (base_origin, base_size) = cx
.active_window()
.and_then(|w| {
w.update(cx, |_, window, _| {
let bounds = window.bounds();
(bounds.origin, bounds.size)
})
.ok()
})
.unwrap_or_else(|| {
let default_bounds = display
.as_ref()
.map(|d| d.default_bounds())
.unwrap_or_else(|| Bounds::new(point(px(0.), px(0.)), DEFAULT_WINDOW_SIZE));
(default_bounds.origin, default_bounds.size)
});
let default_placement = || Bounds::new(point(px(0.), px(0.)), DEFAULT_WINDOW_SIZE);
let cascade_offset = point(px(CASCADE_OFFSET), px(CASCADE_OFFSET));
let proposed_origin = base_origin + cascade_offset;
let proposed_bounds = Bounds::new(proposed_origin, base_size);
// Use visible_bounds to exclude taskbar/dock areas
let display_bounds = display
.as_ref()
.map(|d| d.visible_bounds())
.unwrap_or_else(default_placement);
let display_right = display_bounds.origin.x + display_bounds.size.width;
let display_bottom = display_bounds.origin.y + display_bounds.size.height;
let window_right = proposed_bounds.origin.x + proposed_bounds.size.width;
let window_bottom = proposed_bounds.origin.y + proposed_bounds.size.height;
let (
Bounds {
origin: base_origin,
size: base_size,
},
window_bounds_ctor,
): (_, fn(Bounds<Pixels>) -> WindowBounds) = match active_window_bounds {
Some(bounds) => match bounds {
WindowBounds::Windowed(bounds) => (bounds, WindowBounds::Windowed),
WindowBounds::Maximized(bounds) => (bounds, WindowBounds::Maximized),
WindowBounds::Fullscreen(bounds) => (bounds, WindowBounds::Fullscreen),
},
None => (
display
.as_ref()
.map(|d| d.default_bounds())
.unwrap_or_else(default_placement),
WindowBounds::Windowed,
),
};
let fits_horizontally = window_right <= display_right;
let fits_vertically = window_bottom <= display_bottom;
let cascade_offset = point(px(CASCADE_OFFSET), px(CASCADE_OFFSET));
let proposed_origin = base_origin + cascade_offset;
let proposed_bounds = Bounds::new(proposed_origin, base_size);
let final_origin = match (fits_horizontally, fits_vertically) {
(true, true) => proposed_origin,
(false, true) => point(display_bounds.origin.x, base_origin.y),
(true, false) => point(base_origin.x, display_bounds.origin.y),
(false, false) => display_bounds.origin,
};
let display_right = display_bounds.origin.x + display_bounds.size.width;
let display_bottom = display_bounds.origin.y + display_bounds.size.height;
let window_right = proposed_bounds.origin.x + proposed_bounds.size.width;
let window_bottom = proposed_bounds.origin.y + proposed_bounds.size.height;
Bounds::new(final_origin, base_size)
}
let fits_horizontally = window_right <= display_right;
let fits_vertically = window_bottom <= display_bottom;
#[cfg(not(target_os = "macos"))]
{
const DEFAULT_WINDOW_OFFSET: Point<Pixels> = point(px(0.), px(35.));
// TODO, BUG: if you open a window with the currently active window
// on the stack, this will erroneously select the 'unwrap_or_else'
// code path
cx.active_window()
.and_then(|w| w.update(cx, |_, window, _| window.bounds()).ok())
.map(|mut bounds| {
bounds.origin += DEFAULT_WINDOW_OFFSET;
bounds
})
.unwrap_or_else(|| {
let display = display_id
.map(|id| cx.find_display(id))
.unwrap_or_else(|| cx.primary_display());
display
.as_ref()
.map(|display| display.default_bounds())
.unwrap_or_else(|| Bounds::new(point(px(0.), px(0.)), DEFAULT_WINDOW_SIZE))
})
}
let final_origin = match (fits_horizontally, fits_vertically) {
(true, true) => proposed_origin,
(false, true) => point(display_bounds.origin.x, base_origin.y),
(true, false) => point(base_origin.x, display_bounds.origin.y),
(false, false) => display_bounds.origin,
};
window_bounds_ctor(Bounds::new(final_origin, base_size))
}
impl Window {
@@ -1024,13 +1007,11 @@ impl Window {
tabbing_identifier,
} = options;
let bounds = window_bounds
.map(|bounds| bounds.get_bounds())
.unwrap_or_else(|| default_bounds(display_id, cx));
let window_bounds = window_bounds.unwrap_or_else(|| default_bounds(display_id, cx));
let mut platform_window = cx.platform.open_window(
handle,
WindowParams {
bounds,
bounds: window_bounds.get_bounds(),
titlebar,
kind,
is_movable,
@@ -1071,12 +1052,10 @@ impl Window {
.request_decorations(window_decorations.unwrap_or(WindowDecorations::Server));
platform_window.set_background_appearance(window_background);
if let Some(ref window_open_state) = window_bounds {
match window_open_state {
WindowBounds::Fullscreen(_) => platform_window.toggle_fullscreen(),
WindowBounds::Maximized(_) => platform_window.zoom(),
WindowBounds::Windowed(_) => {}
}
match window_bounds {
WindowBounds::Fullscreen(_) => platform_window.toggle_fullscreen(),
WindowBounds::Maximized(_) => platform_window.zoom(),
WindowBounds::Windowed(_) => {}
}
platform_window.on_close(Box::new({
@@ -1518,7 +1497,8 @@ impl Window {
style
}
/// Check if the platform window is maximized
/// Check if the platform window is maximized.
///
/// On some platforms (namely Windows) this is different than the bounds being the size of the display
pub fn is_maximized(&self) -> bool {
self.platform_window.is_maximized()


@@ -535,7 +535,7 @@ pub trait LspInstaller {
_version: &Self::BinaryVersion,
_container_dir: &PathBuf,
_delegate: &dyn LspAdapterDelegate,
) -> impl Future<Output = Option<LanguageServerBinary>> {
) -> impl Send + Future<Output = Option<LanguageServerBinary>> {
async { None }
}
@@ -544,7 +544,7 @@ pub trait LspInstaller {
latest_version: Self::BinaryVersion,
container_dir: PathBuf,
delegate: &dyn LspAdapterDelegate,
) -> impl Future<Output = Result<LanguageServerBinary>>;
) -> impl Send + Future<Output = Result<LanguageServerBinary>>;
fn cached_server_binary(
&self,
@@ -575,6 +575,7 @@ pub trait DynLspInstaller {
#[async_trait(?Send)]
impl<LI, BinaryVersion> DynLspInstaller for LI
where
BinaryVersion: Send + Sync,
LI: LspInstaller<BinaryVersion = BinaryVersion> + LspAdapter,
{
async fn try_fetch_server_binary(
@@ -593,8 +594,13 @@ where
.fetch_latest_server_version(delegate.as_ref(), pre_release, cx)
.await?;
if let Some(binary) = self
.check_if_version_installed(&latest_version, &container_dir, delegate.as_ref())
if let Some(binary) = cx
.background_executor()
.await_on_background(self.check_if_version_installed(
&latest_version,
&container_dir,
delegate.as_ref(),
))
.await
{
log::debug!("language server {:?} is already installed", name.0);
@@ -603,8 +609,13 @@ where
} else {
log::debug!("downloading language server {:?}", name.0);
delegate.update_status(name.clone(), BinaryStatus::Downloading);
let binary = self
.fetch_server_binary(latest_version, container_dir, delegate.as_ref())
let binary = cx
.background_executor()
.await_on_background(self.fetch_server_binary(
latest_version,
container_dir,
delegate.as_ref(),
))
.await;
delegate.update_status(name.clone(), BinaryStatus::None);


@@ -2,6 +2,40 @@
(identifier) @variable
(call_expression
function: (member_expression
object: (identifier) @type.builtin
(#any-of?
@type.builtin
"Promise"
"Array"
"Object"
"Map"
"Set"
"WeakMap"
"WeakSet"
"Date"
"Error"
"TypeError"
"RangeError"
"SyntaxError"
"ReferenceError"
"EvalError"
"URIError"
"RegExp"
"Function"
"Number"
"String"
"Boolean"
"Symbol"
"BigInt"
"Proxy"
"ArrayBuffer"
"DataView"
)
)
)
; Properties
(property_identifier) @property
@@ -18,6 +52,12 @@
function: (member_expression
property: [(property_identifier) (private_property_identifier)] @function.method))
(new_expression
constructor: (identifier) @type)
(nested_type_identifier
module: (identifier) @type)
; Function and method definitions
(function_expression
@@ -47,10 +87,45 @@
left: (identifier) @function
right: [(function_expression) (arrow_function)])
; Special identifiers
; Parameters
(required_parameter
(identifier) @variable.parameter)
(required_parameter
(_
([
(identifier)
(shorthand_property_identifier_pattern)
]) @variable.parameter))
(optional_parameter
(identifier) @variable.parameter)
(optional_parameter
(_
([
(identifier)
(shorthand_property_identifier_pattern)
]) @variable.parameter))
(catch_clause
parameter: (identifier) @variable.parameter)
(index_signature
name: (identifier) @variable.parameter)
(arrow_function
parameter: (identifier) @variable.parameter)
; Special identifiers
;
(class_declaration
(type_identifier) @type.class)
(extends_clause
value: (identifier) @type.class)
((identifier) @type
(#match? @type "^[A-Z]"))
(type_identifier) @type
(predefined_type) @type.builtin
@@ -251,6 +326,34 @@
(jsx_closing_element (identifier) @tag.jsx (#match? @tag.jsx "^[a-z][^.]*$"))
(jsx_self_closing_element (identifier) @tag.jsx (#match? @tag.jsx "^[a-z][^.]*$"))
(jsx_opening_element
[
(identifier) @type
(member_expression
object: (identifier) @type
property: (property_identifier) @type
)
]
)
(jsx_closing_element
[
(identifier) @type
(member_expression
object: (identifier) @type
property: (property_identifier) @type
)
]
)
(jsx_self_closing_element
[
(identifier) @type
(member_expression
object: (identifier) @type
property: (property_identifier) @type
)
]
)
(jsx_attribute (property_identifier) @attribute.jsx)
(jsx_opening_element (["<" ">"]) @punctuation.bracket.jsx)
(jsx_closing_element (["</" ">"]) @punctuation.bracket.jsx)

@@ -1,2 +1,3 @@
(tag_name) @keyword.jsdoc
(type) @type.jsdoc
(identifier) @variable.jsdoc

@@ -24,5 +24,9 @@ rewrap_prefixes = [
auto_indent_on_paste = false
auto_indent_using_last_non_empty_line = false
tab_size = 2
-decrease_indent_pattern = "^.*$"
+decrease_indent_patterns = [
+{ pattern = "^\\s*-", valid_after = ["list_item"] },
+{ pattern = "^\\s*\\d", valid_after = ["list_item"] },
+{ pattern = "^\\s*", valid_after = ["list_item"] },
+]
prettier_parser_name = "markdown"

@@ -1 +1,3 @@
(list (list_item) @indent)
+(list_item) @start.list_item

@@ -2,6 +2,40 @@
(identifier) @variable
(call_expression
function: (member_expression
object: (identifier) @type.builtin
(#any-of?
@type.builtin
"Promise"
"Array"
"Object"
"Map"
"Set"
"WeakMap"
"WeakSet"
"Date"
"Error"
"TypeError"
"RangeError"
"SyntaxError"
"ReferenceError"
"EvalError"
"URIError"
"RegExp"
"Function"
"Number"
"String"
"Boolean"
"Symbol"
"BigInt"
"Proxy"
"ArrayBuffer"
"DataView"
)
)
)
; Properties
(property_identifier) @property
@@ -18,6 +52,12 @@
function: (member_expression
property: [(property_identifier) (private_property_identifier)] @function.method))
(new_expression
constructor: (identifier) @type)
(nested_type_identifier
module: (identifier) @type)
; Function and method definitions
(function_expression
@@ -47,13 +87,68 @@
left: (identifier) @function
right: [(function_expression) (arrow_function)])
; Parameters
(required_parameter
(identifier) @variable.parameter)
(required_parameter
(_
([
(identifier)
(shorthand_property_identifier_pattern)
]) @variable.parameter))
(optional_parameter
(identifier) @variable.parameter)
(optional_parameter
(_
([
(identifier)
(shorthand_property_identifier_pattern)
]) @variable.parameter))
(catch_clause
parameter: (identifier) @variable.parameter)
(index_signature
name: (identifier) @variable.parameter)
(arrow_function
parameter: (identifier) @variable.parameter)
(type_predicate
name: (identifier) @variable.parameter)
; Special identifiers
((identifier) @type
(#match? @type "^[A-Z]"))
(type_annotation) @type
(type_identifier) @type
(predefined_type) @type.builtin
(type_alias_declaration
(type_identifier) @type)
(type_alias_declaration
value: (_
(type_identifier) @type))
(interface_declaration
(type_identifier) @type)
(class_declaration
(type_identifier) @type.class)
(extends_clause
value: (identifier) @type.class)
(extends_type_clause
type: (type_identifier) @type)
(implements_clause
(type_identifier) @type)
([
(identifier)
(shorthand_property_identifier)
@@ -231,8 +326,42 @@
"<" @punctuation.bracket
">" @punctuation.bracket)
(type_parameters
"<" @punctuation.bracket
">" @punctuation.bracket)
(decorator "@" @punctuation.special)
(union_type
("|") @punctuation.special)
(intersection_type
("&") @punctuation.special)
(type_annotation
(":") @punctuation.special)
(index_signature
(":") @punctuation.special)
(type_predicate_annotation
(":") @punctuation.special)
(public_field_definition
("?") @punctuation.special)
(property_signature
("?") @punctuation.special)
(method_signature
("?") @punctuation.special)
(optional_parameter
([
"?"
":"
]) @punctuation.special)
; Keywords
[ "abstract"
@@ -257,6 +386,34 @@
(jsx_closing_element (identifier) @tag.jsx (#match? @tag.jsx "^[a-z][^.]*$"))
(jsx_self_closing_element (identifier) @tag.jsx (#match? @tag.jsx "^[a-z][^.]*$"))
(jsx_opening_element
[
(identifier) @type
(member_expression
object: (identifier) @type
property: (property_identifier) @type
)
]
)
(jsx_closing_element
[
(identifier) @type
(member_expression
object: (identifier) @type
property: (property_identifier) @type
)
]
)
(jsx_self_closing_element
[
(identifier) @type
(member_expression
object: (identifier) @type
property: (property_identifier) @type
)
]
)
(jsx_attribute (property_identifier) @attribute.jsx)
(jsx_opening_element (["<" ">"]) @punctuation.bracket.jsx)
(jsx_closing_element (["</" ">"]) @punctuation.bracket.jsx)

@@ -2,13 +2,69 @@
(identifier) @variable
(call_expression
function: (member_expression
object: (identifier) @type.builtin
(#any-of?
@type.builtin
"Promise"
"Array"
"Object"
"Map"
"Set"
"WeakMap"
"WeakSet"
"Date"
"Error"
"TypeError"
"RangeError"
"SyntaxError"
"ReferenceError"
"EvalError"
"URIError"
"RegExp"
"Function"
"Number"
"String"
"Boolean"
"Symbol"
"BigInt"
"Proxy"
"ArrayBuffer"
"DataView"
)
)
)
; Special identifiers
((identifier) @type
(#match? @type "^[A-Z]"))
(type_annotation) @type
(type_identifier) @type
(predefined_type) @type.builtin
(type_alias_declaration
(type_identifier) @type)
(type_alias_declaration
value: (_
(type_identifier) @type))
(interface_declaration
(type_identifier) @type)
(class_declaration
(type_identifier) @type.class)
(extends_clause
value: (identifier) @type.class)
(extends_type_clause
type: (type_identifier) @type)
(implements_clause
(type_identifier) @type)
;; Enables ts-pretty-errors
;; The Lsp returns "snippets" of typescript, which are not valid typescript in totality,
;; but should still be highlighted
@@ -83,6 +139,12 @@
function: (member_expression
property: [(property_identifier) (private_property_identifier)] @function.method))
(new_expression
constructor: (identifier) @type)
(nested_type_identifier
module: (identifier) @type)
; Function and method definitions
(function_expression
@@ -114,6 +176,40 @@
(arrow_function) @function
; Parameters
(required_parameter
(identifier) @variable.parameter)
(required_parameter
(_
([
(identifier)
(shorthand_property_identifier_pattern)
]) @variable.parameter))
(optional_parameter
(identifier) @variable.parameter)
(optional_parameter
(_
([
(identifier)
(shorthand_property_identifier_pattern)
]) @variable.parameter))
(catch_clause
parameter: (identifier) @variable.parameter)
(index_signature
name: (identifier) @variable.parameter)
(arrow_function
parameter: (identifier) @variable.parameter)
(type_predicate
name: (identifier) @variable.parameter)
; Literals
(this) @variable.special
@@ -244,8 +340,42 @@
"<" @punctuation.bracket
">" @punctuation.bracket)
(type_parameters
"<" @punctuation.bracket
">" @punctuation.bracket)
(decorator "@" @punctuation.special)
(union_type
("|") @punctuation.special)
(intersection_type
("&") @punctuation.special)
(type_annotation
(":") @punctuation.special)
(index_signature
(":") @punctuation.special)
(type_predicate_annotation
(":") @punctuation.special)
(public_field_definition
("?") @punctuation.special)
(property_signature
("?") @punctuation.special)
(method_signature
("?") @punctuation.special)
(optional_parameter
([
"?"
":"
]) @punctuation.special)
; Keywords
[

@@ -314,7 +314,7 @@ pub struct AdapterServerCapabilities {
impl LanguageServer {
/// Starts a language server process.
-pub fn new(
+pub async fn new(
stderr_capture: Arc<Mutex<Option<String>>>,
server_id: LanguageServerId,
server_name: LanguageServerName,
@@ -331,26 +331,30 @@ impl LanguageServer {
};
let root_uri = Uri::from_file_path(&working_dir)
.map_err(|()| anyhow!("{working_dir:?} is not a valid URI"))?;
log::info!(
-"starting language server process. binary path: {:?}, working directory: {:?}, args: {:?}",
+"starting language server process. binary path: \
+{:?}, working directory: {:?}, args: {:?}",
binary.path,
working_dir,
&binary.arguments
);
-let mut command = util::command::new_smol_command(&binary.path);
-command
-.current_dir(working_dir)
-.args(&binary.arguments)
-.envs(binary.env.clone().unwrap_or_default())
-.stdin(Stdio::piped())
-.stdout(Stdio::piped())
-.stderr(Stdio::piped())
-.kill_on_drop(true);
-let mut server = command
-.spawn()
-.with_context(|| format!("failed to spawn command {command:?}",))?;
+let mut server = cx
+.background_executor()
+.await_on_background(async {
+let mut command = util::command::new_smol_command(&binary.path);
+command
+.current_dir(working_dir)
+.args(&binary.arguments)
+.envs(binary.env.clone().unwrap_or_default())
+.stdin(Stdio::piped())
+.stdout(Stdio::piped())
+.stderr(Stdio::piped())
+.kill_on_drop(true);
+command
+.spawn()
+.with_context(|| format!("failed to spawn command {command:?}",))
+})
+.await?;
let stdin = server.stdin.take().unwrap();
let stdout = server.stdout.take().unwrap();

@@ -320,6 +320,7 @@ impl Prettier {
Default::default(),
&mut cx,
)
+.await
.context("prettier server creation")?;
let server = cx

@@ -4692,11 +4692,9 @@ impl Repository {
});
let this = cx.weak_entity();
-let rx = self.run_hook(RunHook::PrePush, cx);
self.send_job(
Some(format!("git push {} {} {}", args, remote, branch).into()),
move |git_repo, mut cx| async move {
-rx.await??;
match git_repo {
RepositoryState::Local(LocalRepositoryState {
backend,

@@ -422,6 +422,7 @@ impl LocalLspStore {
Some(pending_workspace_folders),
cx,
)
+.await
}
});

@@ -580,7 +580,7 @@ message GitCreateWorktree {
message RunGitHook {
enum GitHook {
PRE_COMMIT = 0;
-PRE_PUSH = 1;
+reserved 1;
}
uint64 project_id = 1;

@@ -132,7 +132,7 @@ async fn spawn_and_read_fd(
#[cfg(windows)]
async fn capture_windows(
shell_path: &Path,
-_args: &[String],
+args: &[String],
directory: &Path,
) -> Result<collections::HashMap<String, String>> {
use std::process::Stdio;
@@ -141,17 +141,17 @@ async fn capture_windows(
std::env::current_exe().context("Failed to determine current zed executable path.")?;
let shell_kind = ShellKind::new(shell_path, true);
-if let ShellKind::Csh | ShellKind::Tcsh | ShellKind::Rc | ShellKind::Fish | ShellKind::Xonsh =
-shell_kind
-{
-return Err(anyhow::anyhow!("unsupported shell kind"));
-}
let mut cmd = crate::command::new_smol_command(shell_path);
+cmd.args(args);
let cmd = match shell_kind {
-ShellKind::Csh | ShellKind::Tcsh | ShellKind::Rc | ShellKind::Fish | ShellKind::Xonsh => {
-unreachable!()
-}
-ShellKind::Posix => cmd.args([
+ShellKind::Csh
+| ShellKind::Tcsh
+| ShellKind::Rc
+| ShellKind::Fish
+| ShellKind::Xonsh
+| ShellKind::Posix => cmd.args([
+"-l",
+"-i",
"-c",
&format!(
"cd '{}'; '{}' --printenv",

@@ -0,0 +1,15 @@
[package]
name = "zeta_prompt"
version = "0.1.0"
publish.workspace = true
edition.workspace = true
license = "GPL-3.0-or-later"
[lints]
workspace = true
[lib]
path = "src/zeta_prompt.rs"
[dependencies]
serde.workspace = true

@@ -0,0 +1,165 @@
use serde::{Deserialize, Serialize};
use std::fmt::Write;
use std::ops::Range;
use std::path::Path;
use std::sync::Arc;
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ZetaPromptInput {
pub cursor_path: Arc<Path>,
pub cursor_excerpt: Arc<str>,
pub editable_range_in_excerpt: Range<usize>,
pub cursor_offset_in_excerpt: usize,
pub events: Vec<Arc<Event>>,
pub related_files: Arc<[RelatedFile]>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(tag = "event")]
pub enum Event {
BufferChange {
path: Arc<Path>,
old_path: Arc<Path>,
diff: String,
predicted: bool,
in_open_source_repo: bool,
},
}
pub fn write_event(prompt: &mut String, event: &Event) {
fn write_path_as_unix_str(prompt: &mut String, path: &Path) {
for component in path.components() {
prompt.push('/');
write!(prompt, "{}", component.as_os_str().display()).ok();
}
}
match event {
Event::BufferChange {
path,
old_path,
diff,
predicted,
in_open_source_repo: _,
} => {
if *predicted {
prompt.push_str("// User accepted prediction:\n");
}
prompt.push_str("--- a");
write_path_as_unix_str(prompt, old_path.as_ref());
prompt.push_str("\n+++ b");
write_path_as_unix_str(prompt, path.as_ref());
prompt.push('\n');
prompt.push_str(diff);
}
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedFile {
pub path: Arc<Path>,
pub max_row: u32,
pub excerpts: Vec<RelatedExcerpt>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RelatedExcerpt {
pub row_range: Range<u32>,
pub text: String,
}
pub fn format_zeta_prompt(input: &ZetaPromptInput) -> String {
let mut prompt = String::new();
write_related_files(&mut prompt, &input.related_files);
write_edit_history_section(&mut prompt, input);
write_cursor_excerpt_section(&mut prompt, input);
prompt
}
pub fn write_related_files(prompt: &mut String, related_files: &[RelatedFile]) {
push_delimited(prompt, "related_files", &[], |prompt| {
for file in related_files {
let path_str = file.path.to_string_lossy();
push_delimited(prompt, "related_file", &[("path", &path_str)], |prompt| {
for excerpt in &file.excerpts {
push_delimited(
prompt,
"related_excerpt",
&[(
"lines",
&format!(
"{}-{}",
excerpt.row_range.start + 1,
excerpt.row_range.end + 1
),
)],
|prompt| {
prompt.push_str(&excerpt.text);
prompt.push('\n');
},
);
}
});
}
});
}
fn write_edit_history_section(prompt: &mut String, input: &ZetaPromptInput) {
push_delimited(prompt, "edit_history", &[], |prompt| {
if input.events.is_empty() {
prompt.push_str("(No edit history)");
} else {
for event in &input.events {
write_event(prompt, event);
}
}
});
}
fn write_cursor_excerpt_section(prompt: &mut String, input: &ZetaPromptInput) {
push_delimited(prompt, "cursor_excerpt", &[], |prompt| {
let path_str = input.cursor_path.to_string_lossy();
push_delimited(prompt, "file", &[("path", &path_str)], |prompt| {
prompt.push_str(&input.cursor_excerpt[..input.editable_range_in_excerpt.start]);
push_delimited(prompt, "editable_region", &[], |prompt| {
prompt.push_str(
&input.cursor_excerpt
[input.editable_range_in_excerpt.start..input.cursor_offset_in_excerpt],
);
prompt.push_str(CURSOR_MARKER);
prompt.push_str(
&input.cursor_excerpt
[input.cursor_offset_in_excerpt..input.editable_range_in_excerpt.end],
);
});
prompt.push_str(&input.cursor_excerpt[input.editable_range_in_excerpt.end..]);
});
});
}
fn push_delimited(
prompt: &mut String,
tag: &'static str,
arguments: &[(&str, &str)],
cb: impl FnOnce(&mut String),
) {
if !prompt.ends_with("\n") {
prompt.push('\n');
}
prompt.push('<');
prompt.push_str(tag);
for (arg_name, arg_value) in arguments {
write!(prompt, " {}=\"{}\"", arg_name, arg_value).ok();
}
prompt.push_str(">\n");
cb(prompt);
if !prompt.ends_with('\n') {
prompt.push('\n');
}
prompt.push_str("</");
prompt.push_str(tag);
prompt.push_str(">\n");
}
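
Review note: the nesting behaviour of `push_delimited` is easier to see with a concrete call. The sketch below copies the helper as it appears in this file so that it compiles on its own. `render_example`, the path `src/main.rs`, and the `1-2` line range are hypothetical values chosen for illustration and are not part of the change.

```rust
use std::fmt::Write;

// Copied from the new zeta_prompt file in this diff so the sketch is self-contained.
fn push_delimited(
    prompt: &mut String,
    tag: &'static str,
    arguments: &[(&str, &str)],
    cb: impl FnOnce(&mut String),
) {
    // An opening tag always starts on a fresh line.
    if !prompt.ends_with('\n') {
        prompt.push('\n');
    }
    prompt.push('<');
    prompt.push_str(tag);
    for (arg_name, arg_value) in arguments {
        write!(prompt, " {}=\"{}\"", arg_name, arg_value).ok();
    }
    prompt.push_str(">\n");
    cb(prompt);
    // The closing tag also starts on a fresh line.
    if !prompt.ends_with('\n') {
        prompt.push('\n');
    }
    prompt.push_str("</");
    prompt.push_str(tag);
    prompt.push_str(">\n");
}

// Hypothetical driver (not in the diff): one excerpt nested inside one file tag.
fn render_example() -> String {
    let mut prompt = String::new();
    push_delimited(&mut prompt, "related_file", &[("path", "src/main.rs")], |prompt| {
        push_delimited(prompt, "related_excerpt", &[("lines", "1-2")], |prompt| {
            // No trailing newline here; the helper adds one before the closing tag.
            prompt.push_str("fn main() {}");
        });
    });
    prompt
}
```

Because an empty `String` does not end with a newline, the first call on an empty prompt emits a leading `\n` before the opening tag. That appears harmless for a model prompt, but it may matter to callers that compare the output byte for byte.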

@@ -368,6 +368,8 @@ pub(crate) fn check_postgres_and_protobuf_migrations() -> NamedJob {
.runs_on(runners::LINUX_DEFAULT)
.add_env(("GIT_AUTHOR_NAME", "Protobuf Action"))
.add_env(("GIT_AUTHOR_EMAIL", "ci@zed.dev"))
+.add_env(("GIT_COMMITTER_NAME", "Protobuf Action"))
+.add_env(("GIT_COMMITTER_EMAIL", "ci@zed.dev"))
.add_step(steps::checkout_repo().with(("fetch-depth", 0))) // fetch full history
.add_step(remove_untracked_files())
.add_step(ensure_fresh_merge())