Compare commits

...

10 Commits

Author SHA1 Message Date
Oleksiy Syvokon
610536201b Merge branch 'main' into ep-distill 2025-12-12 21:03:52 +02:00
Agus Zubiaga
60f4aa333b edit prediction cli: Improve error handling (#44718)
We were panicking whenever something went wrong with an example in the
CLI. This can be very disruptive when running many examples, and e.g. a
single request fails. Instead, if running more than one example, errors
will now be logged alongside instructions to explore and re-run the
example by itself.

<img width="1454" height="744" alt="CleanShot 2025-12-12 at 13 32 04@2x"
src="https://github.com/user-attachments/assets/87c59e64-08b9-4461-af5b-03af5de94152"></img>


You can still opt in to stop as soon as an error occurs with the new
`--failfast` argument.

Release Notes:

- N/A
2025-12-12 14:15:58 -03:00
localcc
a698f1bf63 Fix Bounds::contains (#44711)
Closes #11643 

Release Notes:

- Fixed double hover state on windows

Co-authored-by: Kirill Bulatov <mail4score@gmail.com>
2025-12-12 14:49:29 +00:00
localcc
636d11ebec Multiple priority scheduler (#44701)
Improves the scheduler by allowing tasks to have a set priority which
will significantly improve responsiveness.

Release notes:

- N/A

---------

Co-authored-by: Yara <git@yara.blue>
Co-authored-by: dvdsk <noreply@davidsk.dev>
2025-12-12 06:32:30 -08:00
Agus Zubiaga
4d0e760b04 edit prediction cli: Progress output cleanup (#44708)
- Limit status lines to 10 in case `max_parallelism` is specified with a
greater value
- Handle logging gracefully rather than writing over it when clearing
status lines

Release Notes:

- N/A
2025-12-12 14:03:08 +00:00
localcc
8bd4d866b9 Windows/send keystrokes (#44707)
Closes #41176 

Release Notes:

- Fixed SendKeystrokes mapping on windows

Co-authored-by: Kirill Bulatov <mail4score@gmail.com>
2025-12-12 05:51:11 -08:00
Oleksiy Syvokon
a2a96e4038 Merge branch 'main' into ep-distill 2025-12-12 11:28:15 +02:00
Oleksiy Syvokon
ec26556dab Parse expected output for the zeta2 prompt
Co-authored-by: Agus Zubiaga <agus@zed.dev>
    Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-12-11 20:13:44 +02:00
Oleksiy Syvokon
1a8d8e9572 Add ep distill command
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-12-11 19:36:30 +02:00
Oleksiy Syvokon
ab893ca754 ep_cli fixes, non-batched teacher, and other fixes
Co-authored-by: Agus Zubiaga <agus@zed.dev>
Co-authored-by: Max Brunsfeld <maxbrunsfeld@gmail.com>
2025-12-11 19:20:25 +02:00
36 changed files with 1750 additions and 589 deletions

2
Cargo.lock generated
View File

@@ -5201,7 +5201,6 @@ dependencies = [
"wasmtime",
"watch",
"zeta_prompt",
"zlog",
]
[[package]]
@@ -7240,6 +7239,7 @@ dependencies = [
"libc",
"log",
"lyon",
"mach2 0.5.0",
"media",
"metal",
"naga",

View File

@@ -56,7 +56,6 @@ watch.workspace = true
edit_prediction = { workspace = true, features = ["cli-support"] }
wasmtime.workspace = true
zeta_prompt.workspace = true
zlog.workspace = true
# Wasmtime is included as a dependency in order to enable the same
# features that are enabled in Zed.

View File

@@ -1,14 +1,22 @@
use anyhow::{Result, anyhow};
use std::mem;
use crate::example::Example;
pub async fn run_distill(example: &mut Example) {
let [prediction]: [_; 1] = mem::take(&mut example.predictions)
.try_into()
.expect("Run predict first with a single repetition");
pub async fn run_distill(example: &mut Example) -> Result<()> {
let [prediction]: [_; 1] =
mem::take(&mut example.predictions)
.try_into()
.map_err(|preds: Vec<_>| {
anyhow!(
"Example has {} predictions, but it should have exactly one",
preds.len()
)
})?;
example.expected_patch = prediction.actual_patch;
example.prompt = None;
example.predictions = Vec::new();
example.score = Vec::new();
Ok(())
}

View File

@@ -6,6 +6,7 @@ use crate::{
progress::{Progress, Step},
retrieve_context::run_context_retrieval,
};
use anyhow::{Context as _, Result, ensure};
use edit_prediction::{
EditPredictionStore,
zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
@@ -18,12 +19,11 @@ pub async fn run_format_prompt(
example: &mut Example,
prompt_format: PromptFormat,
app_state: Arc<EpAppState>,
progress: Arc<Progress>,
mut cx: AsyncApp,
) {
run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;
) -> Result<()> {
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
let _step_progress = progress.start(Step::FormatPrompt, &example.name);
let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
match prompt_format {
PromptFormat::Teacher => {
@@ -35,29 +35,33 @@ pub async fn run_format_prompt(
});
}
PromptFormat::Zeta2 => {
run_load_project(example, app_state, progress.clone(), cx.clone()).await;
run_load_project(example, app_state, cx.clone()).await?;
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
let ep_store = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})??;
let state = example.state.as_ref().unwrap();
let snapshot = state
.buffer
.read_with(&cx, |buffer, _| buffer.snapshot())
.unwrap();
let state = example.state.as_ref().context("state must be set")?;
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
let project = state.project.clone();
let (_, input) = ep_store
.update(&mut cx, |ep_store, _cx| {
zeta2_prompt_input(
&snapshot,
example.context.as_ref().unwrap().files.clone(),
ep_store.edit_history_for_project(&project),
example.cursor_path.clone(),
example.buffer.as_ref().unwrap().cursor_offset,
)
})
.unwrap();
let (_, input) = ep_store.update(&mut cx, |ep_store, _cx| {
anyhow::Ok(zeta2_prompt_input(
&snapshot,
example
.context
.as_ref()
.context("context must be set")?
.files
.clone(),
ep_store.edit_history_for_project(&project),
example.cursor_path.clone(),
example
.buffer
.as_ref()
.context("buffer must be set")?
.cursor_offset,
))
})??;
let prompt = format_zeta_prompt(&input);
let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
example.prompt = Some(ExamplePrompt {
@@ -67,6 +71,7 @@ pub async fn run_format_prompt(
});
}
};
Ok(())
}
pub struct TeacherPrompt;
@@ -92,7 +97,7 @@ impl TeacherPrompt {
prompt
}
pub fn parse(example: &Example, response: &str) -> String {
pub fn parse(example: &Example, response: &str) -> Result<String> {
// Ideally, we should always be able to find cursor position in the retrieved context.
// In reality, sometimes we don't find it for these reasons:
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
@@ -103,7 +108,7 @@ impl TeacherPrompt {
let cursor_file = &example
.buffer
.as_ref()
.expect("`buffer` should be filled in in the context collection step")
.context("`buffer` should be filled in in the context collection step")?
.content;
// Extract updated (new) editable region from the model response
@@ -112,9 +117,10 @@ impl TeacherPrompt {
// Reconstruct old editable region we sent to the model
let old_editable_region = Self::format_editable_region(example);
let old_editable_region = Self::extract_editable_region(&old_editable_region);
if !cursor_file.contains(&old_editable_region) {
panic!("Something's wrong: editable_region is not found in the cursor file")
}
ensure!(
cursor_file.contains(&old_editable_region),
"Something's wrong: editable_region is not found in the cursor file"
);
// Apply editable region to a larger context and compute diff.
// This is needed to get a better context lines around the editable region
@@ -129,7 +135,7 @@ impl TeacherPrompt {
diff = diff,
};
diff
Ok(diff)
}
fn format_edit_history(edit_history: &str) -> String {
@@ -153,9 +159,7 @@ impl TeacherPrompt {
}
fn format_context(example: &Example) -> String {
if example.context.is_none() {
panic!("Missing context retriever step");
}
assert!(example.context.is_some(), "Missing context retriever step");
let mut prompt = String::new();
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);

View File

@@ -4,7 +4,7 @@ use crate::{
paths::{REPOS_DIR, WORKTREES_DIR},
progress::{InfoStyle, Progress, Step, StepProgress},
};
use anyhow::{Result, anyhow};
use anyhow::{Context as _, Result};
use collections::HashMap;
use edit_prediction::EditPredictionStore;
use edit_prediction::udiff::OpenedBuffers;
@@ -28,40 +28,35 @@ use zeta_prompt::CURSOR_MARKER;
pub async fn run_load_project(
example: &mut Example,
app_state: Arc<EpAppState>,
progress: Arc<Progress>,
mut cx: AsyncApp,
) {
) -> Result<()> {
if example.state.is_some() {
return;
return Ok(());
}
let progress = progress.start(Step::LoadProject, &example.name);
let progress = Progress::global().start(Step::LoadProject, &example.name);
let project = setup_project(example, &app_state, &progress, &mut cx).await;
let project = setup_project(example, &app_state, &progress, &mut cx).await?;
let _open_buffers = apply_edit_history(example, &project, &mut cx)
.await
.unwrap();
let _open_buffers = apply_edit_history(example, &project, &mut cx).await?;
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
let (example_buffer, language_name) = buffer
.read_with(&cx, |buffer, _cx| {
let cursor_point = cursor_position.to_point(&buffer);
let language_name = buffer
.language()
.map(|l| l.name().to_string())
.unwrap_or_else(|| "Unknown".to_string());
(
ExampleBuffer {
content: buffer.text(),
cursor_row: cursor_point.row,
cursor_column: cursor_point.column,
cursor_offset: cursor_position.to_offset(&buffer),
},
language_name,
)
})
.unwrap();
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await?;
let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
let cursor_point = cursor_position.to_point(&buffer);
let language_name = buffer
.language()
.map(|l| l.name().to_string())
.unwrap_or_else(|| "Unknown".to_string());
(
ExampleBuffer {
content: buffer.text(),
cursor_row: cursor_point.row,
cursor_column: cursor_point.column,
cursor_offset: cursor_position.to_offset(&buffer),
},
language_name,
)
})?;
progress.set_info(language_name, InfoStyle::Normal);
@@ -72,16 +67,15 @@ pub async fn run_load_project(
cursor_position,
_open_buffers,
});
Ok(())
}
async fn cursor_position(
example: &Example,
project: &Entity<Project>,
cx: &mut AsyncApp,
) -> (Entity<Buffer>, Anchor) {
let language_registry = project
.read_with(cx, |project, _| project.languages().clone())
.unwrap();
) -> Result<(Entity<Buffer>, Anchor)> {
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
let result = language_registry
.load_language_for_file_path(&example.cursor_path)
.await;
@@ -89,17 +83,18 @@ async fn cursor_position(
if let Err(error) = result
&& !error.is::<LanguageNotFound>()
{
panic!("Failed to load language for file path: {}", error);
return Err(error);
}
let worktree = project
.read_with(cx, |project, cx| {
project.visible_worktrees(cx).next().unwrap()
})
.unwrap();
let worktree = project.read_with(cx, |project, cx| {
project
.visible_worktrees(cx)
.next()
.context("No visible worktrees")
})??;
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
.unwrap()
.context("Failed to create RelPath")?
.into_arc();
let cursor_buffer = project
.update(cx, |project, cx| {
@@ -110,15 +105,12 @@ async fn cursor_position(
},
cx,
)
})
.unwrap()
.await
.unwrap();
})?
.await?;
let cursor_offset_within_excerpt = example
.cursor_position
.find(CURSOR_MARKER)
.ok_or_else(|| anyhow!("missing cursor marker"))
.unwrap();
.context("missing cursor marker")?;
let mut cursor_excerpt = example.cursor_position.clone();
cursor_excerpt.replace_range(
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
@@ -128,90 +120,76 @@ async fn cursor_position(
let text = buffer.text();
let mut matches = text.match_indices(&cursor_excerpt);
let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
panic!(
let (excerpt_offset, _) = matches.next().with_context(|| {
format!(
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
example.name
);
});
assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
excerpt_offset
}).unwrap();
)
})?;
anyhow::ensure!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
Ok(excerpt_offset)
})??;
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
let cursor_anchor = cursor_buffer
.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
.unwrap();
let cursor_anchor =
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
(cursor_buffer, cursor_anchor)
Ok((cursor_buffer, cursor_anchor))
}
async fn setup_project(
example: &mut Example,
app_state: &Arc<EpAppState>,
step_progress: &Arc<StepProgress>,
step_progress: &StepProgress,
cx: &mut AsyncApp,
) -> Entity<Project> {
) -> Result<Entity<Project>> {
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
.update(|cx| EditPredictionStore::try_global(cx))?
.context("Store should be initialized at init")?;
let worktree_path = setup_worktree(example, step_progress).await;
let worktree_path = setup_worktree(example, step_progress).await?;
if let Some(project) = app_state.project_cache.get(&example.repository_url) {
ep_store
.update(cx, |ep_store, _| {
ep_store.clear_history_for_project(&project);
})
.unwrap();
let buffer_store = project
.read_with(cx, |project, _| project.buffer_store().clone())
.unwrap();
let buffers = buffer_store
.read_with(cx, |buffer_store, _| {
buffer_store.buffers().collect::<Vec<_>>()
})
.unwrap();
ep_store.update(cx, |ep_store, _| {
ep_store.clear_history_for_project(&project);
})?;
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
let buffers = buffer_store.read_with(cx, |buffer_store, _| {
buffer_store.buffers().collect::<Vec<_>>()
})?;
for buffer in buffers {
buffer
.update(cx, |buffer, cx| buffer.reload(cx))
.unwrap()
.update(cx, |buffer, cx| buffer.reload(cx))?
.await
.ok();
}
return project;
return Ok(project);
}
let project = cx
.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})
.unwrap();
let project = cx.update(|cx| {
Project::local(
app_state.client.clone(),
app_state.node_runtime.clone(),
app_state.user_store.clone(),
app_state.languages.clone(),
app_state.fs.clone(),
None,
cx,
)
})?;
project
.update(cx, |project, cx| {
project.disable_worktree_scanner(cx);
project.create_worktree(&worktree_path, true, cx)
})
.unwrap()
.await
.unwrap();
})?
.await?;
app_state
.project_cache
.insert(example.repository_url.clone(), project.clone());
let buffer_store = project
.read_with(cx, |project, _| project.buffer_store().clone())
.unwrap();
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
cx.subscribe(&buffer_store, {
let project = project.clone();
move |_, event, cx| match event {
@@ -220,15 +198,14 @@ async fn setup_project(
}
_ => {}
}
})
.unwrap()
})?
.detach();
project
Ok(project)
}
async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) -> PathBuf {
let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
let worktree_path = WORKTREES_DIR
.join(repo_owner.as_ref())
@@ -237,14 +214,13 @@ async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) ->
if !repo_dir.is_dir() {
step_progress.set_substatus(format!("cloning {}", repo_name));
fs::create_dir_all(&repo_dir).unwrap();
run_git(&repo_dir, &["init"]).await.unwrap();
fs::create_dir_all(&repo_dir)?;
run_git(&repo_dir, &["init"]).await?;
run_git(
&repo_dir,
&["remote", "add", "origin", &example.repository_url],
)
.await
.unwrap();
.await?;
}
// Resolve the example to a revision, fetching it if needed.
@@ -264,34 +240,25 @@ async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) ->
.await
.is_err()
{
run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
run_git(&repo_dir, &["fetch", "origin"]).await?;
}
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
.await
.unwrap();
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
revision
};
// Create the worktree for this example if needed.
step_progress.set_substatus("preparing worktree");
if worktree_path.is_dir() {
run_git(&worktree_path, &["clean", "--force", "-d"])
.await
.unwrap();
run_git(&worktree_path, &["reset", "--hard", "HEAD"])
.await
.unwrap();
run_git(&worktree_path, &["checkout", revision.as_str()])
.await
.unwrap();
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
} else {
let worktree_path_string = worktree_path.to_string_lossy();
run_git(
&repo_dir,
&["branch", "-f", &example.name, revision.as_str()],
)
.await
.unwrap();
.await?;
run_git(
&repo_dir,
&[
@@ -302,8 +269,7 @@ async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) ->
&example.name,
],
)
.await
.unwrap();
.await?;
}
drop(repo_lock);
@@ -314,30 +280,25 @@ async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) ->
.current_dir(&worktree_path)
.args(&["apply", "-"])
.stdin(std::process::Stdio::piped())
.spawn()
.unwrap();
.spawn()?;
let mut stdin = apply_process.stdin.take().unwrap();
stdin
.write_all(example.uncommitted_diff.as_bytes())
.await
.unwrap();
stdin.close().await.unwrap();
let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
stdin.write_all(example.uncommitted_diff.as_bytes()).await?;
stdin.close().await?;
drop(stdin);
let apply_result = apply_process.output().await.unwrap();
if !apply_result.status.success() {
panic!(
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
}
let apply_result = apply_process.output().await?;
anyhow::ensure!(
apply_result.status.success(),
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
apply_result.status,
String::from_utf8_lossy(&apply_result.stderr),
String::from_utf8_lossy(&apply_result.stdout),
);
}
step_progress.clear_substatus();
worktree_path
Ok(worktree_path)
}
async fn apply_edit_history(

View File

@@ -16,12 +16,14 @@ use edit_prediction::EditPredictionStore;
use gpui::Application;
use reqwest_client::ReqwestClient;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use std::{path::PathBuf, sync::Arc};
use crate::distill::run_distill;
use crate::example::{group_examples_by_repo, read_examples, write_examples};
use crate::format_prompt::run_format_prompt;
use crate::load_project::run_load_project;
use crate::paths::FAILED_EXAMPLES_DIR;
use crate::predict::run_prediction;
use crate::progress::Progress;
use crate::retrieve_context::run_context_retrieval;
@@ -32,7 +34,7 @@ use crate::score::run_scoring;
struct EpArgs {
#[arg(long, default_value_t = false)]
printenv: bool,
#[clap(long, default_value_t = 10)]
#[clap(long, default_value_t = 10, global = true)]
max_parallelism: usize,
#[command(subcommand)]
command: Option<Command>,
@@ -42,6 +44,8 @@ struct EpArgs {
output: Option<PathBuf>,
#[arg(long, short, global = true)]
in_place: bool,
#[arg(long, short, global = true)]
failfast: bool,
}
#[derive(Subcommand, Debug)]
@@ -67,6 +71,58 @@ enum Command {
Clean,
}
impl Display for Command {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Command::ParseExample => write!(f, "parse-example"),
Command::LoadProject => write!(f, "load-project"),
Command::Context => write!(f, "context"),
Command::FormatPrompt(format_prompt_args) => write!(
f,
"format-prompt --prompt-format={}",
format_prompt_args
.prompt_format
.to_possible_value()
.unwrap()
.get_name()
),
Command::Predict(predict_args) => {
write!(
f,
"predict --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
)
}
Command::Score(predict_args) => {
write!(
f,
"score --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
)
}
Command::Distill => write!(f, "distill"),
Command::Eval(predict_args) => write!(
f,
"eval --provider={:?}",
predict_args
.provider
.to_possible_value()
.unwrap()
.get_name()
),
Command::Clean => write!(f, "clean"),
}
}
}
#[derive(Debug, Args)]
struct FormatPromptArgs {
#[clap(long)]
@@ -112,8 +168,6 @@ impl EpArgs {
}
fn main() {
let _ = zlog::try_init(Some("error".into()));
zlog::init_output_stderr();
let args = EpArgs::parse();
if args.printenv {
@@ -147,92 +201,140 @@ fn main() {
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
cx.spawn(async move |cx| {
if let Command::Predict(args) = &command {
predict::sync_batches(&args.provider).await
};
let result = async {
if let Command::Predict(args) = &command {
predict::sync_batches(&args.provider).await?;
}
let total_examples = examples.len();
let progress = Progress::new(total_examples);
let total_examples = examples.len();
Progress::global().set_total_examples(total_examples);
let mut grouped_examples = group_examples_by_repo(&mut examples);
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
let mut grouped_examples = group_examples_by_repo(&mut examples);
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
for example_batch in example_batches {
let futures = example_batch.into_iter().map(|repo_examples| async {
for example in repo_examples.iter_mut() {
match &command {
Command::ParseExample => {}
Command::LoadProject => {
run_load_project(
example,
app_state.clone(),
progress.clone(),
cx.clone(),
)
.await;
for example_batch in example_batches {
let futures = example_batch.into_iter().map(|repo_examples| async {
for example in repo_examples.iter_mut() {
let result = async {
match &command {
Command::ParseExample => {}
Command::LoadProject => {
run_load_project(example, app_state.clone(), cx.clone())
.await?;
}
Command::Context => {
run_context_retrieval(
example,
app_state.clone(),
cx.clone(),
)
.await?;
}
Command::FormatPrompt(args) => {
run_format_prompt(
example,
args.prompt_format,
app_state.clone(),
cx.clone(),
)
.await?;
}
Command::Predict(args) => {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state.clone(),
cx.clone(),
)
.await?;
}
Command::Distill => {
run_distill(example).await?;
}
Command::Score(args) | Command::Eval(args) => {
run_scoring(example, &args, app_state.clone(), cx.clone())
.await?;
}
Command::Clean => {
unreachable!()
}
}
anyhow::Ok(())
}
Command::Context => {
run_context_retrieval(
example,
app_state.clone(),
progress.clone(),
cx.clone(),
)
.await;
}
Command::FormatPrompt(args) => {
run_format_prompt(
example,
args.prompt_format,
app_state.clone(),
progress.clone(),
cx.clone(),
)
.await;
}
Command::Predict(args) => {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state.clone(),
progress.clone(),
cx.clone(),
)
.await;
}
Command::Distill => {
run_distill(example).await;
}
Command::Score(args) | Command::Eval(args) => {
run_scoring(
example,
&args,
app_state.clone(),
progress.clone(),
cx.clone(),
)
.await;
}
Command::Clean => {
unreachable!()
.await;
if let Err(e) = result {
Progress::global().increment_failed();
let failed_example_path =
FAILED_EXAMPLES_DIR.join(format!("{}.json", example.name));
app_state
.fs
.write(
&failed_example_path,
&serde_json::to_vec_pretty(&example).unwrap(),
)
.await
.unwrap();
let err_path =
FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example.name));
app_state
.fs
.write(&err_path, e.to_string().as_bytes())
.await
.unwrap();
let msg = format!(
indoc::indoc! {"
While processing {}:
{:?}
Written to: \x1b[36m{}\x1b[0m
Explore this example data with:
fx \x1b[36m{}\x1b[0m
Re-run this example with:
cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
"},
example.name,
e,
err_path.display(),
failed_example_path.display(),
command,
failed_example_path.display(),
);
if args.failfast || total_examples == 1 {
Progress::global().finalize();
panic!("{}", msg);
} else {
log::error!("{}", msg);
}
}
}
}
});
futures::future::join_all(futures).await;
}
progress.clear();
});
futures::future::join_all(futures).await;
}
Progress::global().finalize();
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
write_examples(&examples, output.as_ref());
}
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
write_examples(&examples, output.as_ref());
}
match &command {
Command::Predict(args) => predict::sync_batches(&args.provider).await,
Command::Eval(_) => score::print_report(&examples),
_ => (),
};
match &command {
Command::Predict(args) => predict::sync_batches(&args.provider).await?,
Command::Eval(_) => score::print_report(&examples),
_ => (),
};
anyhow::Ok(())
}
.await;
if let Err(e) = result {
panic!("Fatal error: {:?}", e);
}
let _ = cx.update(|cx| cx.quit());
})

View File

@@ -18,6 +18,8 @@ pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
});
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
pub static FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed")));
fn ensure_dir(path: &Path) -> PathBuf {
std::fs::create_dir_all(path).expect("Failed to create directory");

View File

@@ -9,6 +9,7 @@ use crate::{
progress::{InfoStyle, Progress, Step},
retrieve_context::run_context_retrieval,
};
use anyhow::Context as _;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, future::Shared};
use gpui::{AppContext as _, AsyncApp, Task};
@@ -25,41 +26,33 @@ pub async fn run_prediction(
provider: Option<PredictionProvider>,
repetition_count: usize,
app_state: Arc<EpAppState>,
progress: Arc<Progress>,
mut cx: AsyncApp,
) {
) -> anyhow::Result<()> {
if !example.predictions.is_empty() {
return;
return Ok(());
}
let provider = provider.unwrap();
let provider = provider.context("provider is required")?;
run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
if matches!(
provider,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
) {
let _step_progress = progress.start(Step::Predict, &example.name);
let _step_progress = Progress::global().start(Step::Predict, &example.name);
if example.prompt.is_none() {
run_format_prompt(
example,
PromptFormat::Teacher,
app_state.clone(),
progress,
cx,
)
.await;
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
}
let batched = matches!(provider, PredictionProvider::Teacher);
return predict_anthropic(example, repetition_count, batched).await;
}
run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await;
run_load_project(example, app_state.clone(), cx.clone()).await?;
let _step_progress = progress.start(Step::Predict, &example.name);
let _step_progress = Progress::global().start(Step::Predict, &example.name);
if matches!(
provider,
@@ -70,10 +63,9 @@ pub async fn run_prediction(
.get_or_init(|| {
let client = app_state.client.clone();
cx.spawn(async move |cx| {
client
.sign_in_with_optional_connect(true, cx)
.await
.unwrap();
if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
eprintln!("Authentication failed: {}", e);
}
})
.shared()
})
@@ -81,33 +73,30 @@ pub async fn run_prediction(
.await;
}
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
let ep_store = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})??;
ep_store
.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
unreachable!()
}
};
store.set_edit_prediction_model(model);
})
.unwrap();
let state = example.state.as_ref().unwrap();
ep_store.update(&mut cx, |store, _cx| {
let model = match provider {
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
unreachable!()
}
};
store.set_edit_prediction_model(model);
})?;
let state = example.state.as_ref().context("state must be set")?;
let run_dir = RUN_DIR.join(&example.name);
let updated_example = Arc::new(Mutex::new(example.clone()));
let current_run_ix = Arc::new(AtomicUsize::new(0));
let mut debug_rx = ep_store
.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
.unwrap();
let mut debug_rx =
ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
let debug_task = cx.background_spawn({
let updated_example = updated_example.clone();
let current_run_ix = current_run_ix.clone();
@@ -161,14 +150,14 @@ pub async fn run_prediction(
run_dir.clone()
};
fs::create_dir_all(&run_dir).unwrap();
fs::create_dir_all(&run_dir)?;
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
}
#[cfg(unix)]
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
#[cfg(windows)]
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
updated_example
.lock()
@@ -189,10 +178,8 @@ pub async fn run_prediction(
cloud_llm_client::PredictEditsRequestTrigger::Cli,
cx,
)
})
.unwrap()
.await
.unwrap();
})?
.await?;
let actual_patch = prediction
.and_then(|prediction| {
@@ -221,20 +208,23 @@ pub async fn run_prediction(
}
}
ep_store
.update(&mut cx, |store, _| {
store.remove_project(&state.project);
})
.unwrap();
debug_task.await.unwrap();
ep_store.update(&mut cx, |store, _| {
store.remove_project(&state.project);
})?;
debug_task.await?;
*example = Arc::into_inner(updated_example)
.unwrap()
.ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
.into_inner()
.unwrap();
.map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
Ok(())
}
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
async fn predict_anthropic(
example: &mut Example,
_repetition_count: usize,
batched: bool,
) -> anyhow::Result<()> {
let llm_model_name = "claude-sonnet-4-5";
let max_tokens = 16384;
let llm_client = if batched {
@@ -242,12 +232,9 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
} else {
AnthropicClient::plain()
};
let llm_client = llm_client.expect("Failed to create LLM client");
let llm_client = llm_client.context("Failed to create LLM client")?;
let prompt = example
.prompt
.as_ref()
.unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
let prompt = example.prompt.as_ref().context("Prompt is required")?;
let messages = vec![anthropic::Message {
role: anthropic::Role::User,
@@ -259,11 +246,10 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
let Some(response) = llm_client
.generate(llm_model_name, max_tokens, messages)
.await
.unwrap()
.await?
else {
// Request stashed for batched processing
return;
return Ok(());
};
let actual_output = response
@@ -276,7 +262,7 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
.collect::<Vec<String>>()
.join("\n");
let actual_patch = TeacherPrompt::parse(example, &actual_output);
let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
let prediction = ExamplePrediction {
actual_patch,
@@ -285,19 +271,21 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
};
example.predictions.push(prediction);
Ok(())
}
pub async fn sync_batches(provider: &PredictionProvider) {
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
match provider {
PredictionProvider::Teacher => {
let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
let llm_client =
AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
llm_client
.sync_batches()
.await
.expect("Failed to sync batches");
.context("Failed to sync batches")?;
}
_ => (),
}
};
Ok(())
}

View File

@@ -2,10 +2,12 @@ use std::{
borrow::Cow,
collections::HashMap,
io::{IsTerminal, Write},
sync::{Arc, Mutex},
sync::{Arc, Mutex, OnceLock},
time::{Duration, Instant},
};
use log::{Level, Log, Metadata, Record};
pub struct Progress {
inner: Mutex<ProgressInner>,
}
@@ -18,6 +20,8 @@ struct ProgressInner {
max_example_name_len: usize,
status_lines_displayed: usize,
total_examples: usize,
failed_examples: usize,
last_line_is_logging: bool,
}
#[derive(Clone)]
@@ -72,70 +76,120 @@ impl Step {
}
}
const RIGHT_MARGIN: usize = 4;
static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
static LOGGER: ProgressLogger = ProgressLogger;
const MARGIN: usize = 4;
const MAX_STATUS_LINES: usize = 10;
impl Progress {
pub fn new(total_examples: usize) -> Arc<Self> {
Arc::new(Self {
inner: Mutex::new(ProgressInner {
completed: Vec::new(),
in_progress: HashMap::new(),
is_tty: std::io::stderr().is_terminal(),
terminal_width: get_terminal_width(),
max_example_name_len: 0,
status_lines_displayed: 0,
total_examples,
}),
})
/// Returns the global Progress instance, initializing it if necessary.
pub fn global() -> Arc<Progress> {
GLOBAL
.get_or_init(|| {
let progress = Arc::new(Self {
inner: Mutex::new(ProgressInner {
completed: Vec::new(),
in_progress: HashMap::new(),
is_tty: std::io::stderr().is_terminal(),
terminal_width: get_terminal_width(),
max_example_name_len: 0,
status_lines_displayed: 0,
total_examples: 0,
failed_examples: 0,
last_line_is_logging: false,
}),
});
let _ = log::set_logger(&LOGGER);
log::set_max_level(log::LevelFilter::Error);
progress
})
.clone()
}
pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> Arc<StepProgress> {
{
let mut inner = self.inner.lock().unwrap();
pub fn set_total_examples(&self, total: usize) {
let mut inner = self.inner.lock().unwrap();
inner.total_examples = total;
}
Self::clear_status_lines(&mut inner);
pub fn increment_failed(&self) {
let mut inner = self.inner.lock().unwrap();
inner.failed_examples += 1;
}
inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
/// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
/// This should be used for any output that needs to appear above the status lines.
fn log(&self, message: &str) {
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);
inner.in_progress.insert(
example_name.to_string(),
InProgressTask {
step,
started_at: Instant::now(),
substatus: None,
info: None,
},
);
Self::print_status_lines(&mut inner);
if !inner.last_line_is_logging {
let reset = "\x1b[0m";
let dim = "\x1b[2m";
let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
eprintln!("{dim}{divider}{reset}");
inner.last_line_is_logging = true;
}
Arc::new(StepProgress {
eprintln!("{}", message);
}
pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);
inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
inner.in_progress.insert(
example_name.to_string(),
InProgressTask {
step,
started_at: Instant::now(),
substatus: None,
info: None,
},
);
Self::print_status_lines(&mut inner);
StepProgress {
progress: self.clone(),
step,
example_name: example_name.to_string(),
})
}
}
pub fn finish(&self, step: Step, example_name: &str) {
fn finish(&self, step: Step, example_name: &str) {
let mut inner = self.inner.lock().unwrap();
let task = inner.in_progress.remove(example_name);
if let Some(task) = task {
if task.step == step {
inner.completed.push(CompletedTask {
step: task.step,
example_name: example_name.to_string(),
duration: task.started_at.elapsed(),
info: task.info,
});
let Some(task) = inner.in_progress.remove(example_name) else {
return;
};
Self::clear_status_lines(&mut inner);
Self::print_completed(&inner, inner.completed.last().unwrap());
Self::print_status_lines(&mut inner);
} else {
inner.in_progress.insert(example_name.to_string(), task);
}
if task.step == step {
inner.completed.push(CompletedTask {
step: task.step,
example_name: example_name.to_string(),
duration: task.started_at.elapsed(),
info: task.info,
});
Self::clear_status_lines(&mut inner);
Self::print_logging_closing_divider(&mut inner);
Self::print_completed(&inner, inner.completed.last().unwrap());
Self::print_status_lines(&mut inner);
} else {
inner.in_progress.insert(example_name.to_string(), task);
}
}
fn print_logging_closing_divider(inner: &mut ProgressInner) {
if inner.last_line_is_logging {
let reset = "\x1b[0m";
let dim = "\x1b[2m";
let divider = "".repeat(inner.terminal_width.saturating_sub(MARGIN));
eprintln!("{dim}{divider}{reset}");
inner.last_line_is_logging = false;
}
}
@@ -182,7 +236,7 @@ impl Progress {
let duration_with_margin = format!("{duration} ");
let padding_needed = inner
.terminal_width
.saturating_sub(RIGHT_MARGIN)
.saturating_sub(MARGIN)
.saturating_sub(duration_with_margin.len())
.saturating_sub(strip_ansi_len(&prefix));
let padding = " ".repeat(padding_needed);
@@ -216,27 +270,41 @@ impl Progress {
// Build the done/in-progress/total label
let done_count = inner.completed.len();
let in_progress_count = inner.in_progress.len();
let failed_count = inner.failed_examples;
let failed_label = if failed_count > 0 {
format!(" {} failed ", failed_count)
} else {
String::new()
};
let range_label = format!(
" {}/{}/{} ",
done_count, in_progress_count, inner.total_examples
);
// Print a divider line with range label aligned with timestamps
// Print a divider line with failed count on left, range label on right
let failed_visible_len = strip_ansi_len(&failed_label);
let range_visible_len = range_label.len();
let left_divider_len = inner
let middle_divider_len = inner
.terminal_width
.saturating_sub(RIGHT_MARGIN)
.saturating_sub(MARGIN * 2)
.saturating_sub(failed_visible_len)
.saturating_sub(range_visible_len);
let left_divider = "".repeat(left_divider_len);
let right_divider = "".repeat(RIGHT_MARGIN);
eprintln!("{dim}{left_divider}{reset}{range_label}{dim}{right_divider}{reset}");
let left_divider = "".repeat(MARGIN);
let middle_divider = "".repeat(middle_divider_len);
let right_divider = "".repeat(MARGIN);
eprintln!(
"{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
);
let mut tasks: Vec<_> = inner.in_progress.iter().collect();
tasks.sort_by_key(|(name, _)| *name);
let total_tasks = tasks.len();
let mut lines_printed = 0;
for (name, task) in tasks.iter() {
for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
let elapsed = format_duration(task.started_at.elapsed());
let substatus_part = task
.substatus
@@ -256,7 +324,7 @@ impl Progress {
let duration_with_margin = format!("{elapsed} ");
let padding_needed = inner
.terminal_width
.saturating_sub(RIGHT_MARGIN)
.saturating_sub(MARGIN)
.saturating_sub(duration_with_margin.len())
.saturating_sub(strip_ansi_len(&prefix));
let padding = " ".repeat(padding_needed);
@@ -265,13 +333,34 @@ impl Progress {
lines_printed += 1;
}
// Show "+N more" on its own line if there are more tasks
if total_tasks > MAX_STATUS_LINES {
let remaining = total_tasks - MAX_STATUS_LINES;
eprintln!("{:>12} +{remaining} more", "");
lines_printed += 1;
}
inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
let _ = std::io::stderr().flush();
}
pub fn clear(&self) {
pub fn finalize(&self) {
let mut inner = self.inner.lock().unwrap();
Self::clear_status_lines(&mut inner);
// Print summary if there were failures
if inner.failed_examples > 0 {
let total_processed = inner.completed.len() + inner.failed_examples;
let percentage = if total_processed > 0 {
inner.failed_examples as f64 / total_processed as f64 * 100.0
} else {
0.0
};
eprintln!(
"\n{} of {} examples failed ({:.1}%)",
inner.failed_examples, total_processed, percentage
);
}
}
}
@@ -314,6 +403,53 @@ impl Drop for StepProgress {
}
}
struct ProgressLogger;
impl Log for ProgressLogger {
fn enabled(&self, metadata: &Metadata) -> bool {
metadata.level() <= Level::Info
}
fn log(&self, record: &Record) {
if !self.enabled(record.metadata()) {
return;
}
let level_color = match record.level() {
Level::Error => "\x1b[31m",
Level::Warn => "\x1b[33m",
Level::Info => "\x1b[32m",
Level::Debug => "\x1b[34m",
Level::Trace => "\x1b[35m",
};
let reset = "\x1b[0m";
let bold = "\x1b[1m";
let level_label = match record.level() {
Level::Error => "Error",
Level::Warn => "Warn",
Level::Info => "Info",
Level::Debug => "Debug",
Level::Trace => "Trace",
};
let message = format!(
"{bold}{level_color}{level_label:>12}{reset} {}",
record.args()
);
if let Some(progress) = GLOBAL.get() {
progress.log(&message);
} else {
eprintln!("{}", message);
}
}
fn flush(&self) {
let _ = std::io::stderr().flush();
}
}
#[cfg(unix)]
fn get_terminal_width() -> usize {
unsafe {

View File

@@ -4,6 +4,7 @@ use crate::{
load_project::run_load_project,
progress::{InfoStyle, Progress, Step, StepProgress},
};
use anyhow::Context as _;
use collections::HashSet;
use edit_prediction::{DebugEvent, EditPredictionStore};
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
@@ -16,39 +17,36 @@ use std::time::Duration;
pub async fn run_context_retrieval(
example: &mut Example,
app_state: Arc<EpAppState>,
progress: Arc<Progress>,
mut cx: AsyncApp,
) {
) -> anyhow::Result<()> {
if example.context.is_some() {
return;
return Ok(());
}
run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await;
run_load_project(example, app_state.clone(), cx.clone()).await?;
let step_progress = progress.start(Step::Context, &example.name);
let step_progress: Arc<StepProgress> = Progress::global()
.start(Step::Context, &example.name)
.into();
let state = example.state.as_ref().unwrap();
let project = state.project.clone();
let _lsp_handle = project
.update(&mut cx, |project, cx| {
project.register_buffer_with_language_servers(&state.buffer, cx)
})
.unwrap();
wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await;
let _lsp_handle = project.update(&mut cx, |project, cx| {
project.register_buffer_with_language_servers(&state.buffer, cx)
})?;
wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
let ep_store = cx
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
.unwrap();
let ep_store = cx.update(|cx| {
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
})??;
let mut events = ep_store
.update(&mut cx, |store, cx| {
store.register_buffer(&state.buffer, &project, cx);
store.set_use_context(true);
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
store.debug_info(&project, cx)
})
.unwrap();
let mut events = ep_store.update(&mut cx, |store, cx| {
store.register_buffer(&state.buffer, &project, cx);
store.set_use_context(true);
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
store.debug_info(&project, cx)
})?;
while let Some(event) = events.next().await {
match event {
@@ -59,9 +57,8 @@ pub async fn run_context_retrieval(
}
}
let context_files = ep_store
.update(&mut cx, |store, cx| store.context_for_project(&project, cx))
.unwrap();
let context_files =
ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx))?;
let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
@@ -69,6 +66,7 @@ pub async fn run_context_retrieval(
example.context = Some(ExampleContext {
files: context_files,
});
Ok(())
}
async fn wait_for_language_servers_to_start(
@@ -76,10 +74,8 @@ async fn wait_for_language_servers_to_start(
buffer: &Entity<Buffer>,
step_progress: &Arc<StepProgress>,
cx: &mut AsyncApp,
) {
let lsp_store = project
.read_with(cx, |project, _| project.lsp_store())
.unwrap();
) -> anyhow::Result<()> {
let lsp_store = project.read_with(cx, |project, _| project.lsp_store())?;
let (language_server_ids, mut starting_language_server_ids) = buffer
.update(cx, |buffer, cx| {
@@ -122,7 +118,7 @@ async fn wait_for_language_servers_to_start(
}
},
_ = timeout.clone().fuse() => {
panic!("LSP wait timed out after 5 minutes");
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
}
}
}
@@ -131,8 +127,7 @@ async fn wait_for_language_servers_to_start(
if !language_server_ids.is_empty() {
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.unwrap()
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.detach();
}
@@ -174,10 +169,8 @@ async fn wait_for_language_servers_to_start(
];
project
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
.unwrap()
.await
.unwrap();
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
.await?;
let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
while !pending_language_server_ids.is_empty() {
@@ -188,11 +181,12 @@ async fn wait_for_language_servers_to_start(
}
},
_ = timeout.clone().fuse() => {
panic!("LSP wait timed out after 5 minutes");
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
}
}
}
drop(subscriptions);
step_progress.clear_substatus();
Ok(())
}

View File

@@ -14,20 +14,18 @@ pub async fn run_scoring(
example: &mut Example,
args: &PredictArgs,
app_state: Arc<EpAppState>,
progress: Arc<Progress>,
cx: AsyncApp,
) {
) -> anyhow::Result<()> {
run_prediction(
example,
Some(args.provider),
args.repetitions,
app_state,
progress.clone(),
cx,
)
.await;
.await?;
let _progress = progress.start(Step::Score, &example.name);
let _progress = Progress::global().start(Step::Score, &example.name);
let expected_patch = parse_patch(&example.expected_patch);
@@ -45,6 +43,7 @@ pub async fn run_scoring(
}
example.score = scores;
Ok(())
}
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {

View File

@@ -21,7 +21,6 @@ default = ["font-kit", "wayland", "x11", "windows-manifest"]
test-support = [
"leak-detection",
"collections/test-support",
"rand",
"util/test-support",
"http_client/test-support",
"wayland",
@@ -109,7 +108,7 @@ parking = "2.0.0"
parking_lot.workspace = true
postage.workspace = true
profiling.workspace = true
rand = { optional = true, workspace = true }
rand.workspace = true
raw-window-handle = "0.6"
refineable.workspace = true
resvg = { version = "0.45.0", default-features = false, features = [
@@ -158,8 +157,10 @@ media.workspace = true
objc.workspace = true
objc2 = { version = "0.6", optional = true }
objc2-metal = { version = "0.3", optional = true }
mach2.workspace = true
#TODO: replace with "objc2"
metal.workspace = true
flume = "0.11"
[target.'cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))'.dependencies]
pathfinder_geometry = "0.5"

View File

@@ -84,6 +84,8 @@ mod macos {
.allowlist_var("_dispatch_main_q")
.allowlist_var("_dispatch_source_type_data_add")
.allowlist_var("DISPATCH_QUEUE_PRIORITY_HIGH")
.allowlist_var("DISPATCH_QUEUE_PRIORITY_DEFAULT")
.allowlist_var("DISPATCH_QUEUE_PRIORITY_LOW")
.allowlist_var("DISPATCH_TIME_NOW")
.allowlist_function("dispatch_get_global_queue")
.allowlist_function("dispatch_async_f")

View File

@@ -38,10 +38,11 @@ use crate::{
AssetSource, BackgroundExecutor, Bounds, ClipboardItem, CursorStyle, DispatchPhase, DisplayId,
EventEmitter, FocusHandle, FocusMap, ForegroundExecutor, Global, KeyBinding, KeyContext,
Keymap, Keystroke, LayoutId, Menu, MenuItem, OwnedMenu, PathPromptOptions, Pixels, Platform,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, PromptBuilder,
PromptButton, PromptHandle, PromptLevel, Render, RenderImage, RenderablePromptHandle,
Reservation, ScreenCaptureSource, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
TextSystem, Window, WindowAppearance, WindowHandle, WindowId, WindowInvalidator,
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, Priority,
PromptBuilder, PromptButton, PromptHandle, PromptLevel, Render, RenderImage,
RenderablePromptHandle, Reservation, ScreenCaptureSource, SharedString, SubscriberSet,
Subscription, SvgRenderer, Task, TextSystem, Window, WindowAppearance, WindowHandle, WindowId,
WindowInvalidator,
colors::{Colors, GlobalColors},
current_platform, hash, init_app_menus,
};
@@ -1494,6 +1495,24 @@ impl App {
.spawn(async move { f(&mut cx).await })
}
/// Spawns the future returned by the given function on the main thread with
/// the given priority. The closure will be invoked with [AsyncApp], which
/// allows the application state to be accessed across await points.
pub fn spawn_with_priority<AsyncFn, R>(&self, priority: Priority, f: AsyncFn) -> Task<R>
where
AsyncFn: AsyncFnOnce(&mut AsyncApp) -> R + 'static,
R: 'static,
{
if self.quitting {
debug_panic!("Can't spawn on main thread after on_app_quit")
};
let mut cx = self.to_async();
self.foreground_executor
.spawn_with_priority(priority, async move { f(&mut cx).await })
}
/// Schedules the given function to be run at the end of the current effect cycle, allowing entities
/// that are currently on the stack to be returned to the app.
pub fn defer(&mut self, f: impl FnOnce(&mut App) + 'static) {

View File

@@ -1,7 +1,7 @@
use crate::{
AnyView, AnyWindowHandle, AppContext, AsyncApp, DispatchPhase, Effect, EntityId, EventEmitter,
FocusHandle, FocusOutEvent, Focusable, Global, KeystrokeObserver, Reservation, SubscriberSet,
Subscription, Task, WeakEntity, WeakFocusHandle, Window, WindowHandle,
FocusHandle, FocusOutEvent, Focusable, Global, KeystrokeObserver, Priority, Reservation,
SubscriberSet, Subscription, Task, WeakEntity, WeakFocusHandle, Window, WindowHandle,
};
use anyhow::Result;
use futures::FutureExt;
@@ -667,6 +667,25 @@ impl<'a, T: 'static> Context<'a, T> {
window.spawn(self, async move |cx| f(view, cx).await)
}
/// Schedule a future to be run asynchronously with the given priority.
/// The given callback is invoked with a [`WeakEntity<V>`] to avoid leaking the entity for a long-running process.
/// It's also given an [`AsyncWindowContext`], which can be used to access the state of the entity across await points.
/// The returned future will be polled on the main thread.
#[track_caller]
pub fn spawn_in_with_priority<AsyncFn, R>(
&self,
priority: Priority,
window: &Window,
f: AsyncFn,
) -> Task<R>
where
R: 'static,
AsyncFn: AsyncFnOnce(WeakEntity<T>, &mut AsyncWindowContext) -> R + 'static,
{
let view = self.weak_entity();
window.spawn_with_priority(priority, self, async move |cx| f(view, cx).await)
}
/// Register a callback to be invoked when the given global state changes.
pub fn observe_global_in<G: Global>(
&mut self,

View File

@@ -1,4 +1,4 @@
use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant};
use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant, TaskTiming, profiler};
use async_task::Runnable;
use futures::channel::mpsc;
use parking_lot::{Condvar, Mutex};
@@ -47,6 +47,52 @@ pub struct ForegroundExecutor {
not_send: PhantomData<Rc<()>>,
}
/// Realtime task priority
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum RealtimePriority {
/// Audio task
Audio,
/// Other realtime task
#[default]
Other,
}
/// Task priority
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum Priority {
/// Realtime priority
///
/// Spawning a task with this priority will spin it off on a separate thread dedicated just to that task.
Realtime(RealtimePriority),
/// High priority
///
/// Only use for tasks that are critical to the user experience / responsiveness of the editor.
High,
/// Medium priority, probably suits most of your use cases.
#[default]
Medium,
/// Low priority
///
/// Prioritize this for background work that can come in large quantities
/// to not starve the executor of resources for high priority tasks
Low,
}
impl Priority {
#[allow(dead_code)]
pub(crate) const fn probability(&self) -> u32 {
match self {
// realtime priorities are not considered for probability scheduling
Priority::Realtime(_) => 0,
Priority::High => 60,
Priority::Medium => 30,
Priority::Low => 10,
}
}
}
/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
@@ -152,7 +198,20 @@ impl BackgroundExecutor {
where
R: Send + 'static,
{
self.spawn_internal::<R>(Box::pin(future), None)
self.spawn_with_priority(Priority::default(), future)
}
/// Enqueues the given future to be run to completion on a background thread.
#[track_caller]
pub fn spawn_with_priority<R>(
&self,
priority: Priority,
future: impl Future<Output = R> + Send + 'static,
) -> Task<R>
where
R: Send + 'static,
{
self.spawn_internal::<R>(Box::pin(future), None, priority)
}
/// Enqueues the given future to be run to completion on a background thread and blocking the current task on it.
@@ -199,7 +258,13 @@ impl BackgroundExecutor {
let _notify_guard = NotifyOnDrop(pair);
future.await
},
move |runnable| dispatcher.dispatch(RunnableVariant::Meta(runnable), None),
move |runnable| {
dispatcher.dispatch(
RunnableVariant::Meta(runnable),
None,
Priority::default(),
)
},
)
};
runnable.schedule();
@@ -217,7 +282,7 @@ impl BackgroundExecutor {
where
R: Send + 'static,
{
self.spawn_internal::<R>(Box::pin(future), Some(label))
self.spawn_internal::<R>(Box::pin(future), Some(label), Priority::default())
}
#[track_caller]
@@ -225,15 +290,55 @@ impl BackgroundExecutor {
&self,
future: AnyFuture<R>,
label: Option<TaskLabel>,
priority: Priority,
) -> Task<R> {
let dispatcher = self.dispatcher.clone();
let location = core::panic::Location::caller();
let (runnable, task) = async_task::Builder::new()
.metadata(RunnableMeta { location })
.spawn(
move |_| future,
move |runnable| dispatcher.dispatch(RunnableVariant::Meta(runnable), label),
let (runnable, task) = if let Priority::Realtime(realtime) = priority {
let location = core::panic::Location::caller();
let (mut tx, rx) = flume::bounded::<Runnable<RunnableMeta>>(1);
dispatcher.spawn_realtime(
realtime,
Box::new(move || {
while let Ok(runnable) = rx.recv() {
let start = Instant::now();
let location = runnable.metadata().location;
let mut timing = TaskTiming {
location,
start,
end: None,
};
profiler::add_task_timing(timing);
runnable.run();
let end = Instant::now();
timing.end = Some(end);
profiler::add_task_timing(timing);
}
}),
);
async_task::Builder::new()
.metadata(RunnableMeta { location })
.spawn(
move |_| future,
move |runnable| {
let _ = tx.send(runnable);
},
)
} else {
let location = core::panic::Location::caller();
async_task::Builder::new()
.metadata(RunnableMeta { location })
.spawn(
move |_| future,
move |runnable| {
dispatcher.dispatch(RunnableVariant::Meta(runnable), label, priority)
},
)
};
runnable.schedule();
Task(TaskState::Spawned(task))
}
@@ -406,11 +511,28 @@ impl BackgroundExecutor {
where
F: FnOnce(&mut Scope<'scope>),
{
let mut scope = Scope::new(self.clone());
let mut scope = Scope::new(self.clone(), Priority::default());
(scheduler)(&mut scope);
let spawned = mem::take(&mut scope.futures)
.into_iter()
.map(|f| self.spawn(f))
.map(|f| self.spawn_with_priority(scope.priority, f))
.collect::<Vec<_>>();
for task in spawned {
task.await;
}
}
/// Scoped lets you start a number of tasks and waits
/// for all of them to complete before returning.
pub async fn scoped_priority<'scope, F>(&self, priority: Priority, scheduler: F)
where
F: FnOnce(&mut Scope<'scope>),
{
let mut scope = Scope::new(self.clone(), priority);
(scheduler)(&mut scope);
let spawned = mem::take(&mut scope.futures)
.into_iter()
.map(|f| self.spawn_with_priority(scope.priority, f))
.collect::<Vec<_>>();
for task in spawned {
task.await;
@@ -546,6 +668,19 @@ impl ForegroundExecutor {
/// Enqueues the given Task to run on the main thread at some point in the future.
#[track_caller]
pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
where
R: 'static,
{
self.spawn_with_priority(Priority::default(), future)
}
/// Enqueues the given Task to run on the main thread at some point in the future.
#[track_caller]
pub fn spawn_with_priority<R>(
&self,
priority: Priority,
future: impl Future<Output = R> + 'static,
) -> Task<R>
where
R: 'static,
{
@@ -557,16 +692,19 @@ impl ForegroundExecutor {
dispatcher: Arc<dyn PlatformDispatcher>,
future: AnyLocalFuture<R>,
location: &'static core::panic::Location<'static>,
priority: Priority,
) -> Task<R> {
let (runnable, task) = spawn_local_with_source_location(
future,
move |runnable| dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable)),
move |runnable| {
dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable), priority)
},
RunnableMeta { location },
);
runnable.schedule();
Task(TaskState::Spawned(task))
}
inner::<R>(dispatcher, Box::pin(future), location)
inner::<R>(dispatcher, Box::pin(future), location, priority)
}
}
@@ -642,6 +780,7 @@ where
/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
pub struct Scope<'a> {
executor: BackgroundExecutor,
priority: Priority,
futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
tx: Option<mpsc::Sender<()>>,
rx: mpsc::Receiver<()>,
@@ -649,10 +788,11 @@ pub struct Scope<'a> {
}
impl<'a> Scope<'a> {
fn new(executor: BackgroundExecutor) -> Self {
fn new(executor: BackgroundExecutor, priority: Priority) -> Self {
let (tx, rx) = mpsc::channel(1);
Self {
executor,
priority,
tx: Some(tx),
rx,
futures: Default::default(),

View File

@@ -1416,9 +1416,9 @@ where
/// ```
pub fn contains(&self, point: &Point<T>) -> bool {
point.x >= self.origin.x
&& point.x <= self.origin.x.clone() + self.size.width.clone()
&& point.x < self.origin.x.clone() + self.size.width.clone()
&& point.y >= self.origin.y
&& point.y <= self.origin.y.clone() + self.size.height.clone()
&& point.y < self.origin.y.clone() + self.size.height.clone()
}
/// Checks if this bounds is completely contained within another bounds.

View File

@@ -31,6 +31,8 @@ mod path_builder;
mod platform;
pub mod prelude;
mod profiler;
#[cfg(any(target_os = "windows", target_os = "linux"))]
mod queue;
mod scene;
mod shared_string;
mod shared_uri;
@@ -89,16 +91,20 @@ pub use keymap::*;
pub use path_builder::*;
pub use platform::*;
pub use profiler::*;
#[cfg(any(target_os = "windows", target_os = "linux"))]
pub(crate) use queue::{PriorityQueueReceiver, PriorityQueueSender};
pub use refineable::*;
pub use scene::*;
pub use shared_string::*;
pub use shared_uri::*;
pub use smol::Timer;
use std::{any::Any, future::Future};
pub use style::*;
pub use styled::*;
pub use subscription::*;
pub use svg_renderer::*;
pub(crate) use tab_stop::*;
use taffy::TaffyLayoutEngine;
pub use taffy::{AvailableSpace, LayoutId};
#[cfg(any(test, feature = "test-support"))]
pub use test::*;
@@ -109,9 +115,6 @@ pub use util::{FutureExt, Timeout, arc_cow::ArcCow};
pub use view::*;
pub use window::*;
use std::{any::Any, future::Future};
use taffy::TaffyLayoutEngine;
/// The context trait, allows the different contexts in GPUI to be used
/// interchangeably for certain operations.
pub trait AppContext {

View File

@@ -39,9 +39,10 @@ use crate::{
Action, AnyWindowHandle, App, AsyncWindowContext, BackgroundExecutor, Bounds,
DEFAULT_WINDOW_SIZE, DevicePixels, DispatchEventResult, Font, FontId, FontMetrics, FontRun,
ForegroundExecutor, GlyphId, GpuSpecs, ImageSource, Keymap, LineLayout, Pixels, PlatformInput,
Point, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, Scene, ShapedGlyph,
ShapedRun, SharedString, Size, SvgRenderer, SystemWindowTab, Task, TaskLabel, TaskTiming,
ThreadTaskTimings, Window, WindowControlArea, hash, point, px, size,
Point, Priority, RealtimePriority, RenderGlyphParams, RenderImage, RenderImageParams,
RenderSvgParams, Scene, ShapedGlyph, ShapedRun, SharedString, Size, SvgRenderer,
SystemWindowTab, Task, TaskLabel, TaskTiming, ThreadTaskTimings, Window, WindowControlArea,
hash, point, px, size,
};
use anyhow::Result;
use async_task::Runnable;
@@ -587,9 +588,10 @@ pub trait PlatformDispatcher: Send + Sync {
fn get_all_timings(&self) -> Vec<ThreadTaskTimings>;
fn get_current_thread_timings(&self) -> Vec<TaskTiming>;
fn is_main_thread(&self) -> bool;
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>);
fn dispatch_on_main_thread(&self, runnable: RunnableVariant);
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority);
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority);
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant);
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>);
fn now(&self) -> Instant {
Instant::now()

View File

@@ -1,9 +1,10 @@
use crate::{
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, RunnableVariant, THREAD_TIMINGS, TaskLabel,
TaskTiming, ThreadTaskTimings,
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, Priority, PriorityQueueReceiver,
PriorityQueueSender, RealtimePriority, RunnableVariant, THREAD_TIMINGS, TaskLabel, TaskTiming,
ThreadTaskTimings, profiler,
};
use calloop::{
EventLoop,
EventLoop, PostAction,
channel::{self, Sender},
timer::TimeoutAction,
};
@@ -19,9 +20,9 @@ struct TimerAfter {
}
pub(crate) struct LinuxDispatcher {
main_sender: Sender<RunnableVariant>,
main_sender: PriorityQueueCalloopSender<RunnableVariant>,
timer_sender: Sender<TimerAfter>,
background_sender: flume::Sender<RunnableVariant>,
background_sender: PriorityQueueSender<RunnableVariant>,
_background_threads: Vec<thread::JoinHandle<()>>,
main_thread_id: thread::ThreadId,
}
@@ -29,18 +30,20 @@ pub(crate) struct LinuxDispatcher {
const MIN_THREADS: usize = 2;
impl LinuxDispatcher {
pub fn new(main_sender: Sender<RunnableVariant>) -> Self {
let (background_sender, background_receiver) = flume::unbounded::<RunnableVariant>();
pub fn new(main_sender: PriorityQueueCalloopSender<RunnableVariant>) -> Self {
let (background_sender, background_receiver) = PriorityQueueReceiver::new();
let thread_count =
std::thread::available_parallelism().map_or(MIN_THREADS, |i| i.get().max(MIN_THREADS));
// These thread should really be lower prio then the foreground
// executor
let mut background_threads = (0..thread_count)
.map(|i| {
let receiver = background_receiver.clone();
let mut receiver = background_receiver.clone();
std::thread::Builder::new()
.name(format!("Worker-{i}"))
.spawn(move || {
for runnable in receiver {
for runnable in receiver.iter() {
let start = Instant::now();
let mut location = match runnable {
@@ -51,7 +54,7 @@ impl LinuxDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -63,7 +66,7 @@ impl LinuxDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -72,7 +75,7 @@ impl LinuxDispatcher {
let end = Instant::now();
location.end = Some(end);
Self::add_task_timing(location);
profiler::add_task_timing(location);
log::trace!(
"background thread {}: ran runnable. took: {:?}",
@@ -113,7 +116,7 @@ impl LinuxDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -124,7 +127,7 @@ impl LinuxDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -133,7 +136,7 @@ impl LinuxDispatcher {
let end = Instant::now();
timing.end = Some(end);
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
}
TimeoutAction::Drop
},
@@ -157,22 +160,6 @@ impl LinuxDispatcher {
main_thread_id: thread::current().id(),
}
}
pub(crate) fn add_task_timing(timing: TaskTiming) {
THREAD_TIMINGS.with(|timings| {
let mut timings = timings.lock();
let timings = &mut timings.timings;
if let Some(last_timing) = timings.iter_mut().rev().next() {
if last_timing.location == timing.location {
last_timing.end = timing.end;
return;
}
}
timings.push_back(timing);
});
}
}
impl PlatformDispatcher for LinuxDispatcher {
@@ -199,22 +186,26 @@ impl PlatformDispatcher for LinuxDispatcher {
thread::current().id() == self.main_thread_id
}
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>) {
self.background_sender.send(runnable).unwrap();
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>, priority: Priority) {
self.background_sender
.send(priority, runnable)
.unwrap_or_else(|_| panic!("blocking sender returned without value"));
}
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
self.main_sender.send(runnable).unwrap_or_else(|runnable| {
// NOTE: Runnable may wrap a Future that is !Send.
//
// This is usually safe because we only poll it on the main thread.
// However if the send fails, we know that:
// 1. main_receiver has been dropped (which implies the app is shutting down)
// 2. we are on a background thread.
// It is not safe to drop something !Send on the wrong thread, and
// the app will exit soon anyway, so we must forget the runnable.
std::mem::forget(runnable);
});
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
self.main_sender
.send(priority, runnable)
.unwrap_or_else(|runnable| {
// NOTE: Runnable may wrap a Future that is !Send.
//
// This is usually safe because we only poll it on the main thread.
// However if the send fails, we know that:
// 1. main_receiver has been dropped (which implies the app is shutting down)
// 2. we are on a background thread.
// It is not safe to drop something !Send on the wrong thread, and
// the app will exit soon anyway, so we must forget the runnable.
std::mem::forget(runnable);
});
}
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
@@ -222,4 +213,252 @@ impl PlatformDispatcher for LinuxDispatcher {
.send(TimerAfter { duration, runnable })
.ok();
}
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
std::thread::spawn(move || {
// SAFETY: always safe to call
let thread_id = unsafe { libc::pthread_self() };
let policy = match priority {
RealtimePriority::Audio => libc::SCHED_FIFO,
RealtimePriority::Other => libc::SCHED_RR,
};
let sched_priority = match priority {
RealtimePriority::Audio => 65,
RealtimePriority::Other => 45,
};
let sched_param = libc::sched_param { sched_priority };
// SAFETY: sched_param is a valid initialized structure
let result = unsafe { libc::pthread_setschedparam(thread_id, policy, &sched_param) };
if result != 0 {
log::warn!("failed to set realtime thread priority to {:?}", priority);
}
f();
});
}
}
pub struct PriorityQueueCalloopSender<T> {
sender: PriorityQueueSender<T>,
ping: calloop::ping::Ping,
}
impl<T> PriorityQueueCalloopSender<T> {
fn new(tx: PriorityQueueSender<T>, ping: calloop::ping::Ping) -> Self {
Self { sender: tx, ping }
}
fn send(&self, priority: Priority, item: T) -> Result<(), crate::queue::SendError<T>> {
let res = self.sender.send(priority, item);
if res.is_ok() {
self.ping.ping();
}
res
}
}
impl<T> Drop for PriorityQueueCalloopSender<T> {
fn drop(&mut self) {
self.ping.ping();
}
}
pub struct PriorityQueueCalloopReceiver<T> {
receiver: PriorityQueueReceiver<T>,
source: calloop::ping::PingSource,
ping: calloop::ping::Ping,
}
impl<T> PriorityQueueCalloopReceiver<T> {
pub fn new() -> (PriorityQueueCalloopSender<T>, Self) {
let (ping, source) = calloop::ping::make_ping().expect("Failed to create a Ping.");
let (tx, rx) = PriorityQueueReceiver::new();
(
PriorityQueueCalloopSender::new(tx, ping.clone()),
Self {
receiver: rx,
source,
ping,
},
)
}
}
use calloop::channel::Event;
#[derive(Debug)]
pub struct ChannelError(calloop::ping::PingError);
impl std::fmt::Display for ChannelError {
#[cfg_attr(feature = "nightly_coverage", coverage(off))]
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.0, f)
}
}
impl std::error::Error for ChannelError {
#[cfg_attr(feature = "nightly_coverage", coverage(off))]
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
Some(&self.0)
}
}
impl<T> calloop::EventSource for PriorityQueueCalloopReceiver<T> {
type Event = Event<T>;
type Metadata = ();
type Ret = ();
type Error = ChannelError;
fn process_events<F>(
&mut self,
readiness: calloop::Readiness,
token: calloop::Token,
mut callback: F,
) -> Result<calloop::PostAction, Self::Error>
where
F: FnMut(Self::Event, &mut Self::Metadata) -> Self::Ret,
{
let mut clear_readiness = false;
let mut disconnected = false;
let action = self
.source
.process_events(readiness, token, |(), &mut ()| {
let mut is_empty = true;
let mut receiver = self.receiver.clone();
for runnable in receiver.try_iter() {
match runnable {
Ok(r) => {
callback(Event::Msg(r), &mut ());
is_empty = false;
}
Err(_) => {
disconnected = true;
}
}
}
if disconnected {
callback(Event::Closed, &mut ());
}
if is_empty {
clear_readiness = true;
}
})
.map_err(ChannelError)?;
if disconnected {
Ok(PostAction::Remove)
} else if clear_readiness {
Ok(action)
} else {
// Re-notify the ping source so we can try again.
self.ping.ping();
Ok(PostAction::Continue)
}
}
fn register(
&mut self,
poll: &mut calloop::Poll,
token_factory: &mut calloop::TokenFactory,
) -> calloop::Result<()> {
self.source.register(poll, token_factory)
}
fn reregister(
&mut self,
poll: &mut calloop::Poll,
token_factory: &mut calloop::TokenFactory,
) -> calloop::Result<()> {
self.source.reregister(poll, token_factory)
}
fn unregister(&mut self, poll: &mut calloop::Poll) -> calloop::Result<()> {
self.source.unregister(poll)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn calloop_works() {
let mut event_loop = calloop::EventLoop::try_new().unwrap();
let handle = event_loop.handle();
let (tx, rx) = PriorityQueueCalloopReceiver::new();
struct Data {
got_msg: bool,
got_closed: bool,
}
let mut data = Data {
got_msg: false,
got_closed: false,
};
let _channel_token = handle
.insert_source(rx, move |evt, &mut (), data: &mut Data| match evt {
Event::Msg(()) => {
data.got_msg = true;
}
Event::Closed => {
data.got_closed = true;
}
})
.unwrap();
// nothing is sent, nothing is received
event_loop
.dispatch(Some(::std::time::Duration::ZERO), &mut data)
.unwrap();
assert!(!data.got_msg);
assert!(!data.got_closed);
// a message is send
tx.send(Priority::Medium, ()).unwrap();
event_loop
.dispatch(Some(::std::time::Duration::ZERO), &mut data)
.unwrap();
assert!(data.got_msg);
assert!(!data.got_closed);
// the sender is dropped
drop(tx);
event_loop
.dispatch(Some(::std::time::Duration::ZERO), &mut data)
.unwrap();
assert!(data.got_msg);
assert!(data.got_closed);
}
}
// running 1 test
// test platform::linux::dispatcher::tests::tomato ... FAILED
// failures:
// ---- platform::linux::dispatcher::tests::tomato stdout ----
// [crates/gpui/src/platform/linux/dispatcher.rs:262:9]
// returning 1 tasks to process
// [crates/gpui/src/platform/linux/dispatcher.rs:480:75] evt = Msg(
// (),
// )
// returning 0 tasks to process
// thread 'platform::linux::dispatcher::tests::tomato' (478301) panicked at crates/gpui/src/platform/linux/dispatcher.rs:515:9:
// assertion failed: data.got_closed
// note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace

View File

@@ -14,7 +14,7 @@ use std::{
};
use anyhow::{Context as _, anyhow};
use calloop::{LoopSignal, channel::Channel};
use calloop::LoopSignal;
use futures::channel::oneshot;
use util::ResultExt as _;
use util::command::{new_smol_command, new_std_command};
@@ -25,8 +25,8 @@ use crate::{
Action, AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DisplayId,
ForegroundExecutor, Keymap, LinuxDispatcher, Menu, MenuItem, OwnedMenu, PathPromptOptions,
Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper,
PlatformTextSystem, PlatformWindow, Point, Result, RunnableVariant, Task, WindowAppearance,
WindowParams, px,
PlatformTextSystem, PlatformWindow, Point, PriorityQueueCalloopReceiver, Result,
RunnableVariant, Task, WindowAppearance, WindowParams, px,
};
#[cfg(any(feature = "wayland", feature = "x11"))]
@@ -149,8 +149,8 @@ pub(crate) struct LinuxCommon {
}
impl LinuxCommon {
pub fn new(signal: LoopSignal) -> (Self, Channel<RunnableVariant>) {
let (main_sender, main_receiver) = calloop::channel::channel::<RunnableVariant>();
pub fn new(signal: LoopSignal) -> (Self, PriorityQueueCalloopReceiver<RunnableVariant>) {
let (main_sender, main_receiver) = PriorityQueueCalloopReceiver::new();
#[cfg(any(feature = "wayland", feature = "x11"))]
let text_system = Arc::new(crate::CosmicTextSystem::new());

View File

@@ -77,10 +77,10 @@ use crate::{
LinuxKeyboardLayout, Modifiers, ModifiersChangedEvent, MouseButton, MouseDownEvent,
MouseExitEvent, MouseMoveEvent, MouseUpEvent, NavigationDirection, Pixels, PlatformDisplay,
PlatformInput, PlatformKeyboardLayout, Point, ResultExt as _, SCROLL_LINES, ScrollDelta,
ScrollWheelEvent, Size, TouchPhase, WindowParams, point, px, size,
ScrollWheelEvent, Size, TouchPhase, WindowParams, point, profiler, px, size,
};
use crate::{
LinuxDispatcher, RunnableVariant, TaskTiming,
RunnableVariant, TaskTiming,
platform::{PlatformWindow, blade::BladeContext},
};
use crate::{
@@ -503,7 +503,7 @@ impl WaylandClient {
start,
end: None,
};
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -515,7 +515,7 @@ impl WaylandClient {
start,
end: None,
};
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -524,7 +524,7 @@ impl WaylandClient {
let end = Instant::now();
timing.end = Some(end);
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
});
}
}

View File

@@ -1,4 +1,4 @@
use crate::{Capslock, LinuxDispatcher, ResultExt as _, RunnableVariant, TaskTiming, xcb_flush};
use crate::{Capslock, ResultExt as _, RunnableVariant, TaskTiming, profiler, xcb_flush};
use anyhow::{Context as _, anyhow};
use ashpd::WindowIdentifier;
use calloop::{
@@ -322,7 +322,7 @@ impl X11Client {
start,
end: None,
};
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -334,7 +334,7 @@ impl X11Client {
start,
end: None,
};
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
timing
@@ -343,7 +343,7 @@ impl X11Client {
let end = Instant::now();
timing.end = Some(end);
LinuxDispatcher::add_task_timing(timing);
profiler::add_task_timing(timing);
});
}
}

View File

@@ -3,11 +3,22 @@
#![allow(non_snake_case)]
use crate::{
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, RunnableMeta, RunnableVariant, THREAD_TIMINGS,
TaskLabel, TaskTiming, ThreadTaskTimings,
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, Priority, RealtimePriority, RunnableMeta,
RunnableVariant, THREAD_TIMINGS, TaskLabel, TaskTiming, ThreadTaskTimings,
};
use anyhow::Context;
use async_task::Runnable;
use mach2::{
kern_return::KERN_SUCCESS,
mach_time::mach_timebase_info_data_t,
thread_policy::{
THREAD_EXTENDED_POLICY, THREAD_EXTENDED_POLICY_COUNT, THREAD_PRECEDENCE_POLICY,
THREAD_PRECEDENCE_POLICY_COUNT, THREAD_TIME_CONSTRAINT_POLICY,
THREAD_TIME_CONSTRAINT_POLICY_COUNT, thread_extended_policy_data_t,
thread_precedence_policy_data_t, thread_time_constraint_policy_data_t,
},
};
use objc::{
class, msg_send,
runtime::{BOOL, YES},
@@ -15,9 +26,11 @@ use objc::{
};
use std::{
ffi::c_void,
mem::MaybeUninit,
ptr::{NonNull, addr_of},
time::{Duration, Instant},
};
use util::ResultExt;
/// All items in the generated file are marked as pub, so we're gonna wrap it in a separate mod to prevent
/// these pub items from leaking into public API.
@@ -56,7 +69,7 @@ impl PlatformDispatcher for MacDispatcher {
is_main_thread == YES
}
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>) {
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>, priority: Priority) {
let (context, trampoline) = match runnable {
RunnableVariant::Meta(runnable) => (
runnable.into_raw().as_ptr() as *mut c_void,
@@ -67,16 +80,24 @@ impl PlatformDispatcher for MacDispatcher {
Some(trampoline_compat as unsafe extern "C" fn(*mut c_void)),
),
};
let queue_priority = match priority {
Priority::Realtime(_) => unreachable!(),
Priority::High => DISPATCH_QUEUE_PRIORITY_HIGH as isize,
Priority::Medium => DISPATCH_QUEUE_PRIORITY_DEFAULT as isize,
Priority::Low => DISPATCH_QUEUE_PRIORITY_LOW as isize,
};
unsafe {
dispatch_async_f(
dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH.try_into().unwrap(), 0),
dispatch_get_global_queue(queue_priority, 0),
context,
trampoline,
);
}
}
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: Priority) {
let (context, trampoline) = match runnable {
RunnableVariant::Meta(runnable) => (
runnable.into_raw().as_ptr() as *mut c_void,
@@ -110,6 +131,120 @@ impl PlatformDispatcher for MacDispatcher {
dispatch_after_f(when, queue, context, trampoline);
}
}
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
std::thread::spawn(move || {
match priority {
RealtimePriority::Audio => set_audio_thread_priority(),
RealtimePriority::Other => set_high_thread_priority(),
}
.context(format!("for priority {:?}", priority))
.log_err();
f();
});
}
}
fn set_high_thread_priority() -> anyhow::Result<()> {
// SAFETY: always safe to call
let thread_id = unsafe { libc::pthread_self() };
// SAFETY: all sched_param members are valid when initialized to zero.
let mut sched_param = unsafe { MaybeUninit::<libc::sched_param>::zeroed().assume_init() };
sched_param.sched_priority = 45;
let result = unsafe { libc::pthread_setschedparam(thread_id, libc::SCHED_FIFO, &sched_param) };
if result != 0 {
anyhow::bail!("failed to set realtime thread priority")
}
Ok(())
}
fn set_audio_thread_priority() -> anyhow::Result<()> {
// https://chromium.googlesource.com/chromium/chromium/+/master/base/threading/platform_thread_mac.mm#93
// SAFETY: always safe to call
let thread_id = unsafe { libc::pthread_self() };
// SAFETY: thread_id is a valid thread id
let thread_id = unsafe { libc::pthread_mach_thread_np(thread_id) };
// Fixed priority thread
let mut policy = thread_extended_policy_data_t { timeshare: 0 };
// SAFETY: thread_id is a valid thread id
// SAFETY: thread_extended_policy_data_t is passed as THREAD_EXTENDED_POLICY
let result = unsafe {
mach2::thread_policy::thread_policy_set(
thread_id,
THREAD_EXTENDED_POLICY,
&mut policy as *mut _ as *mut _,
THREAD_EXTENDED_POLICY_COUNT,
)
};
if result != KERN_SUCCESS {
anyhow::bail!("failed to set thread extended policy");
}
// relatively high priority
let mut precedence = thread_precedence_policy_data_t { importance: 63 };
// SAFETY: thread_id is a valid thread id
// SAFETY: thread_precedence_policy_data_t is passed as THREAD_PRECEDENCE_POLICY
let result = unsafe {
mach2::thread_policy::thread_policy_set(
thread_id,
THREAD_PRECEDENCE_POLICY,
&mut precedence as *mut _ as *mut _,
THREAD_PRECEDENCE_POLICY_COUNT,
)
};
if result != KERN_SUCCESS {
anyhow::bail!("failed to set thread precedence policy");
}
const GUARANTEED_AUDIO_DUTY_CYCLE: f32 = 0.75;
const MAX_AUDIO_DUTY_CYCLE: f32 = 0.85;
// ~128 frames @ 44.1KHz
const TIME_QUANTUM: f32 = 2.9;
const AUDIO_TIME_NEEDED: f32 = GUARANTEED_AUDIO_DUTY_CYCLE * TIME_QUANTUM;
const MAX_TIME_ALLOWED: f32 = MAX_AUDIO_DUTY_CYCLE * TIME_QUANTUM;
let mut timebase_info = mach_timebase_info_data_t { numer: 0, denom: 0 };
// SAFETY: timebase_info is a valid pointer to a mach_timebase_info_data_t struct
unsafe { mach2::mach_time::mach_timebase_info(&mut timebase_info) };
let ms_to_abs_time = ((timebase_info.denom as f32) / (timebase_info.numer as f32)) * 1000000f32;
let mut time_constraints = thread_time_constraint_policy_data_t {
period: (TIME_QUANTUM * ms_to_abs_time) as u32,
computation: (AUDIO_TIME_NEEDED * ms_to_abs_time) as u32,
constraint: (MAX_TIME_ALLOWED * ms_to_abs_time) as u32,
preemptible: 0,
};
// SAFETY: thread_id is a valid thread id
// SAFETY: thread_precedence_pthread_time_constraint_policy_data_t is passed as THREAD_TIME_CONSTRAINT_POLICY
let result = unsafe {
mach2::thread_policy::thread_policy_set(
thread_id,
THREAD_TIME_CONSTRAINT_POLICY,
&mut time_constraints as *mut _ as *mut _,
THREAD_TIME_CONSTRAINT_POLICY_COUNT,
)
};
if result != KERN_SUCCESS {
anyhow::bail!("failed to set thread time constraint policy");
}
Ok(())
}
extern "C" fn trampoline(runnable: *mut c_void) {

View File

@@ -1,4 +1,4 @@
use crate::{PlatformDispatcher, RunnableVariant, TaskLabel};
use crate::{PlatformDispatcher, Priority, RunnableVariant, TaskLabel};
use backtrace::Backtrace;
use collections::{HashMap, HashSet, VecDeque};
use parking::Unparker;
@@ -284,7 +284,7 @@ impl PlatformDispatcher for TestDispatcher {
state.start_time + state.time
}
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, _priority: Priority) {
{
let mut state = self.state.lock();
if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) {
@@ -296,7 +296,7 @@ impl PlatformDispatcher for TestDispatcher {
self.unpark_all();
}
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: Priority) {
self.state
.lock()
.foreground
@@ -318,4 +318,10 @@ impl PlatformDispatcher for TestDispatcher {
fn as_test(&self) -> Option<&TestDispatcher> {
Some(self)
}
fn spawn_realtime(&self, _priority: crate::RealtimePriority, f: Box<dyn FnOnce() + Send>) {
std::thread::spawn(move || {
f();
});
}
}

View File

@@ -4,24 +4,31 @@ use std::{
time::{Duration, Instant},
};
use flume::Sender;
use anyhow::Context;
use util::ResultExt;
use windows::{
System::Threading::{ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler},
System::Threading::{
ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
},
Win32::{
Foundation::{LPARAM, WPARAM},
System::Threading::{
GetCurrentThread, HIGH_PRIORITY_CLASS, SetPriorityClass, SetThreadPriority,
THREAD_PRIORITY_HIGHEST, THREAD_PRIORITY_TIME_CRITICAL,
},
UI::WindowsAndMessaging::PostMessageW,
},
};
use crate::{
GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, RunnableVariant, SafeHwnd, THREAD_TIMINGS,
TaskLabel, TaskTiming, ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, Priority, PriorityQueueSender,
RealtimePriority, RunnableVariant, SafeHwnd, THREAD_TIMINGS, TaskLabel, TaskTiming,
ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD, profiler,
};
pub(crate) struct WindowsDispatcher {
pub(crate) wake_posted: AtomicBool,
main_sender: Sender<RunnableVariant>,
main_sender: PriorityQueueSender<RunnableVariant>,
main_thread_id: ThreadId,
pub(crate) platform_window_handle: SafeHwnd,
validation_number: usize,
@@ -29,7 +36,7 @@ pub(crate) struct WindowsDispatcher {
impl WindowsDispatcher {
pub(crate) fn new(
main_sender: Sender<RunnableVariant>,
main_sender: PriorityQueueSender<RunnableVariant>,
platform_window_handle: HWND,
validation_number: usize,
) -> Self {
@@ -45,7 +52,7 @@ impl WindowsDispatcher {
}
}
fn dispatch_on_threadpool(&self, runnable: RunnableVariant) {
fn dispatch_on_threadpool(&self, priority: WorkItemPriority, runnable: RunnableVariant) {
let handler = {
let mut task_wrapper = Some(runnable);
WorkItemHandler::new(move |_| {
@@ -53,7 +60,8 @@ impl WindowsDispatcher {
Ok(())
})
};
ThreadPool::RunAsync(&handler).log_err();
ThreadPool::RunWithPriorityAsync(&handler, priority).log_err();
}
fn dispatch_on_threadpool_after(&self, runnable: RunnableVariant, duration: Duration) {
@@ -79,7 +87,7 @@ impl WindowsDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
@@ -91,7 +99,7 @@ impl WindowsDispatcher {
start,
end: None,
};
Self::add_task_timing(timing);
profiler::add_task_timing(timing);
runnable.run();
@@ -102,23 +110,7 @@ impl WindowsDispatcher {
let end = Instant::now();
timing.end = Some(end);
Self::add_task_timing(timing);
}
pub(crate) fn add_task_timing(timing: TaskTiming) {
THREAD_TIMINGS.with(|timings| {
let mut timings = timings.lock();
let timings = &mut timings.timings;
if let Some(last_timing) = timings.iter_mut().rev().next() {
if last_timing.location == timing.location {
last_timing.end = timing.end;
return;
}
}
timings.push_back(timing);
});
profiler::add_task_timing(timing);
}
}
@@ -146,15 +138,22 @@ impl PlatformDispatcher for WindowsDispatcher {
current().id() == self.main_thread_id
}
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
self.dispatch_on_threadpool(runnable);
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority) {
let priority = match priority {
Priority::Realtime(_) => unreachable!(),
Priority::High => WorkItemPriority::High,
Priority::Medium => WorkItemPriority::Normal,
Priority::Low => WorkItemPriority::Low,
};
self.dispatch_on_threadpool(priority, runnable);
if let Some(label) = label {
log::debug!("TaskLabel: {label:?}");
}
}
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
match self.main_sender.send(runnable) {
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
match self.main_sender.send(priority, runnable) {
Ok(_) => {
if !self.wake_posted.swap(true, Ordering::AcqRel) {
unsafe {
@@ -185,4 +184,28 @@ impl PlatformDispatcher for WindowsDispatcher {
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
self.dispatch_on_threadpool_after(runnable, duration);
}
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
std::thread::spawn(move || {
// SAFETY: always safe to call
let thread_handle = unsafe { GetCurrentThread() };
let thread_priority = match priority {
RealtimePriority::Audio => THREAD_PRIORITY_TIME_CRITICAL,
RealtimePriority::Other => THREAD_PRIORITY_HIGHEST,
};
// SAFETY: thread_handle is a valid handle to a thread
unsafe { SetPriorityClass(thread_handle, HIGH_PRIORITY_CLASS) }
.context("thread priority class")
.log_err();
// SAFETY: thread_handle is a valid handle to a thread
unsafe { SetThreadPriority(thread_handle, thread_priority) }
.context("thread priority")
.log_err();
f();
});
}
}

View File

@@ -243,7 +243,8 @@ impl WindowsWindowInner {
fn handle_timer_msg(&self, handle: HWND, wparam: WPARAM) -> Option<isize> {
if wparam.0 == SIZE_MOVE_LOOP_TIMER_ID {
for runnable in self.main_receiver.drain() {
let mut runnables = self.main_receiver.clone().try_iter();
while let Some(Ok(runnable)) = runnables.next() {
WindowsDispatcher::execute_runnable(runnable);
}
self.handle_paint_msg(handle)

View File

@@ -51,7 +51,7 @@ struct WindowsPlatformInner {
raw_window_handles: std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
// The below members will never change throughout the entire lifecycle of the app.
validation_number: usize,
main_receiver: flume::Receiver<RunnableVariant>,
main_receiver: PriorityQueueReceiver<RunnableVariant>,
dispatcher: Arc<WindowsDispatcher>,
}
@@ -98,7 +98,7 @@ impl WindowsPlatform {
OleInitialize(None).context("unable to initialize Windows OLE")?;
}
let directx_devices = DirectXDevices::new().context("Creating DirectX devices")?;
let (main_sender, main_receiver) = flume::unbounded::<RunnableVariant>();
let (main_sender, main_receiver) = PriorityQueueReceiver::new();
let validation_number = if usize::BITS == 64 {
rand::random::<u64>() as usize
} else {
@@ -857,22 +857,24 @@ impl WindowsPlatformInner {
}
break 'tasks;
}
match self.main_receiver.try_recv() {
Err(_) => break 'timeout_loop,
Ok(runnable) => WindowsDispatcher::execute_runnable(runnable),
let mut main_receiver = self.main_receiver.clone();
match main_receiver.try_pop() {
Ok(Some(runnable)) => WindowsDispatcher::execute_runnable(runnable),
_ => break 'timeout_loop,
}
}
// Someone could enqueue a Runnable here. The flag is still true, so they will not PostMessage.
// We need to check for those Runnables after we clear the flag.
self.dispatcher.wake_posted.store(false, Ordering::Release);
match self.main_receiver.try_recv() {
Err(_) => break 'tasks,
Ok(runnable) => {
let mut main_receiver = self.main_receiver.clone();
match main_receiver.try_pop() {
Ok(Some(runnable)) => {
self.dispatcher.wake_posted.store(true, Ordering::Release);
WindowsDispatcher::execute_runnable(runnable);
}
_ => break 'tasks,
}
}
@@ -934,7 +936,7 @@ pub(crate) struct WindowCreationInfo {
pub(crate) windows_version: WindowsVersion,
pub(crate) drop_target_helper: IDropTargetHelper,
pub(crate) validation_number: usize,
pub(crate) main_receiver: flume::Receiver<RunnableVariant>,
pub(crate) main_receiver: PriorityQueueReceiver<RunnableVariant>,
pub(crate) platform_window_handle: HWND,
pub(crate) disable_direct_composition: bool,
pub(crate) directx_devices: DirectXDevices,
@@ -947,8 +949,8 @@ struct PlatformWindowCreateContext {
inner: Option<Result<Rc<WindowsPlatformInner>>>,
raw_window_handles: std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
validation_number: usize,
main_sender: Option<flume::Sender<RunnableVariant>>,
main_receiver: Option<flume::Receiver<RunnableVariant>>,
main_sender: Option<PriorityQueueSender<RunnableVariant>>,
main_receiver: Option<PriorityQueueReceiver<RunnableVariant>>,
directx_devices: Option<DirectXDevices>,
dispatcher: Option<Arc<WindowsDispatcher>>,
}

View File

@@ -81,7 +81,7 @@ pub(crate) struct WindowsWindowInner {
pub(crate) executor: ForegroundExecutor,
pub(crate) windows_version: WindowsVersion,
pub(crate) validation_number: usize,
pub(crate) main_receiver: flume::Receiver<RunnableVariant>,
pub(crate) main_receiver: PriorityQueueReceiver<RunnableVariant>,
pub(crate) platform_window_handle: HWND,
}
@@ -362,7 +362,7 @@ struct WindowCreateContext {
windows_version: WindowsVersion,
drop_target_helper: IDropTargetHelper,
validation_number: usize,
main_receiver: flume::Receiver<RunnableVariant>,
main_receiver: PriorityQueueReceiver<RunnableVariant>,
platform_window_handle: HWND,
appearance: WindowAppearance,
disable_direct_composition: bool,

View File

@@ -216,3 +216,19 @@ impl Drop for ThreadTimings {
thread_timings.swap_remove(index);
}
}
pub(crate) fn add_task_timing(timing: TaskTiming) {
THREAD_TIMINGS.with(|timings| {
let mut timings = timings.lock();
let timings = &mut timings.timings;
if let Some(last_timing) = timings.iter_mut().rev().next() {
if last_timing.location == timing.location {
last_timing.end = timing.end;
return;
}
}
timings.push_back(timing);
});
}

329
crates/gpui/src/queue.rs Normal file
View File

@@ -0,0 +1,329 @@
use std::{
fmt,
iter::FusedIterator,
sync::{Arc, atomic::AtomicUsize},
};
use rand::{Rng, SeedableRng, rngs::SmallRng};
use crate::Priority;
struct PriorityQueues<T> {
high_priority: Vec<T>,
medium_priority: Vec<T>,
low_priority: Vec<T>,
}
impl<T> PriorityQueues<T> {
fn is_empty(&self) -> bool {
self.high_priority.is_empty()
&& self.medium_priority.is_empty()
&& self.low_priority.is_empty()
}
}
struct PriorityQueueState<T> {
queues: parking_lot::Mutex<PriorityQueues<T>>,
condvar: parking_lot::Condvar,
receiver_count: AtomicUsize,
sender_count: AtomicUsize,
}
impl<T> PriorityQueueState<T> {
fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
if self
.receiver_count
.load(std::sync::atomic::Ordering::Relaxed)
== 0
{
return Err(SendError(item));
}
let mut queues = self.queues.lock();
match priority {
Priority::Realtime(_) => unreachable!(),
Priority::High => queues.high_priority.push(item),
Priority::Medium => queues.medium_priority.push(item),
Priority::Low => queues.low_priority.push(item),
};
self.condvar.notify_one();
Ok(())
}
fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
let mut queues = self.queues.lock();
let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
if queues.is_empty() && sender_count == 0 {
return Err(crate::queue::RecvError);
}
// parking_lot doesn't do spurious wakeups so an if is fine
if queues.is_empty() {
self.condvar.wait(&mut queues);
}
Ok(queues)
}
fn try_recv<'a>(
&'a self,
) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
let mut queues = self.queues.lock();
let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
if queues.is_empty() && sender_count == 0 {
return Err(crate::queue::RecvError);
}
if queues.is_empty() {
Ok(None)
} else {
Ok(Some(queues))
}
}
}
pub(crate) struct PriorityQueueSender<T> {
state: Arc<PriorityQueueState<T>>,
}
impl<T> PriorityQueueSender<T> {
fn new(state: Arc<PriorityQueueState<T>>) -> Self {
Self { state }
}
pub(crate) fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
self.state.send(priority, item)?;
Ok(())
}
}
impl<T> Drop for PriorityQueueSender<T> {
fn drop(&mut self) {
self.state
.sender_count
.fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
}
}
pub(crate) struct PriorityQueueReceiver<T> {
state: Arc<PriorityQueueState<T>>,
rand: SmallRng,
disconnected: bool,
}
impl<T> Clone for PriorityQueueReceiver<T> {
fn clone(&self) -> Self {
self.state
.receiver_count
.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
Self {
state: Arc::clone(&self.state),
rand: SmallRng::seed_from_u64(0),
disconnected: self.disconnected,
}
}
}
pub(crate) struct SendError<T>(T);
impl<T: fmt::Debug> fmt::Debug for SendError<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_tuple("SendError").field(&self.0).finish()
}
}
#[derive(Debug)]
pub(crate) struct RecvError;
#[allow(dead_code)]
impl<T> PriorityQueueReceiver<T> {
    /// Creates a connected sender/receiver pair with empty priority queues
    /// and a single registered sender and receiver.
    pub(crate) fn new() -> (PriorityQueueSender<T>, Self) {
        let state = Arc::new(PriorityQueueState {
            queues: parking_lot::Mutex::new(PriorityQueues {
                high_priority: Vec::new(),
                medium_priority: Vec::new(),
                low_priority: Vec::new(),
            }),
            condvar: parking_lot::Condvar::new(),
            receiver_count: AtomicUsize::new(1),
            sender_count: AtomicUsize::new(1),
        });
        let sender = PriorityQueueSender::new(Arc::clone(&state));
        let receiver = PriorityQueueReceiver {
            state,
            // Fixed seed keeps the probabilistic draw order deterministic
            // across runs (the tests below rely on this).
            rand: SmallRng::seed_from_u64(0),
            disconnected: false,
        };
        (sender, receiver)
    }

    /// Tries to pop one element from the priority queue without blocking.
    ///
    /// This will early return if there are no elements in the queue.
    ///
    /// This method is best suited if you only intend to pop one element, for better performance
    /// on large queues see [`Self::try_iter`]
    ///
    /// # Errors
    ///
    /// If the sender was dropped
    pub(crate) fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
        self.pop_inner(false)
    }

    /// Pops an element from the priority queue blocking if necessary.
    ///
    /// This method is best suited if you only intend to pop one element, for better performance
    /// on large queues see [`Self::iter`]
    ///
    /// # Errors
    ///
    /// If the sender was dropped
    pub(crate) fn pop(&mut self) -> Result<T, RecvError> {
        // The blocking path only returns `Ok(None)` when every queue is
        // empty, which a successful blocking recv rules out.
        self.pop_inner(true)
            .map(|element| element.expect("blocking recv returned without an element"))
    }

    /// Returns an iterator over the elements of the queue
    /// this iterator will end when all elements have been consumed and will not wait for new ones.
    pub(crate) fn try_iter(self) -> TryIter<T> {
        TryIter {
            receiver: self,
            ended: false,
        }
    }

    /// Returns an iterator over the elements of the queue
    /// this iterator will wait for new elements if the queue is empty.
    pub(crate) fn iter(self) -> Iter<T> {
        Iter(self)
    }

    #[inline(always)]
    // Weighted sampling: the "loaded die from biased coins" construction from
    // https://www.keithschwarz.com/darts-dice-coins/
    fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
        use Priority as P;
        let mut queues = if block {
            self.state.recv()?
        } else {
            let Some(queues) = self.state.try_recv()? else {
                return Ok(None);
            };
            queues
        };
        // Each non-empty queue contributes its priority weight to the total
        // probability mass; empty queues take no part in the draw.
        let high = P::High.probability() * !queues.high_priority.is_empty() as u32;
        let medium = P::Medium.probability() * !queues.medium_priority.is_empty() as u32;
        let low = P::Low.probability() * !queues.low_priority.is_empty() as u32;
        let mut mass = high + medium + low;
        if !queues.high_priority.is_empty() {
            // Take a high-priority element with probability high/mass; on a
            // miss, remove its weight from the remaining mass.
            if self.rand.random_ratio(P::High.probability(), mass) {
                return Ok(queues.high_priority.pop());
            }
            mass -= P::High.probability();
        }
        if !queues.medium_priority.is_empty() {
            if self.rand.random_ratio(P::Medium.probability(), mass) {
                return Ok(queues.medium_priority.pop());
            }
            mass -= P::Medium.probability();
        }
        if !queues.low_priority.is_empty() {
            // When low is the last non-empty queue, `mass` equals its weight,
            // so this draw always succeeds.
            if self.rand.random_ratio(P::Low.probability(), mass) {
                return Ok(queues.low_priority.pop());
            }
        }
        // Reachable only when every queue was empty.
        Ok(None)
    }
}
impl<T> Drop for PriorityQueueReceiver<T> {
    fn drop(&mut self) {
        // Deregister this receiver from the shared state's live count.
        let receivers = &self.state.receiver_count;
        receivers.fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
    }
}
/// If None is returned the sender disconnected
pub(crate) struct Iter<T>(PriorityQueueReceiver<T>);

impl<T> Iterator for Iter<T> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        // Block until an element is available; a disconnect error (or an
        // empty blocking result) ends the iteration.
        match self.0.pop_inner(true) {
            Ok(element) => element,
            Err(RecvError) => None,
        }
    }
}

impl<T> FusedIterator for Iter<T> {}
/// If None is returned there are no more elements in the queue
pub(crate) struct TryIter<T> {
    receiver: PriorityQueueReceiver<T>,
    // Latched once the iterator has finished (queue drained or sender
    // dropped) so `next` keeps returning `None` afterwards.
    ended: bool,
}

impl<T> Iterator for TryIter<T> {
    type Item = Result<T, RecvError>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.ended {
            return None;
        }
        match self.receiver.pop_inner(false) {
            // An element was available.
            Ok(Some(element)) => Some(Ok(element)),
            // Queue drained: latch `ended` so we keep yielding `None` even if
            // elements arrive later — required by the `FusedIterator` impl
            // below (previously `ended` was left false here, so a fused
            // iterator could resume yielding after returning `None`).
            Ok(None) => {
                self.ended = true;
                None
            }
            // Sender dropped: report the error once, then end.
            Err(error) => {
                self.ended = true;
                Some(Err(error))
            }
        }
    }
}

impl<T> FusedIterator for TryIter<T> {}
#[cfg(test)]
mod tests {
    use collections::HashSet;

    use super::*;

    #[test]
    fn all_tasks_get_yielded() {
        // `rx` is consumed by `iter(self)`, so it needs no `mut` binding
        // (the previous `mut rx` triggered an unused_mut warning).
        let (tx, rx) = PriorityQueueReceiver::new();
        tx.send(Priority::Medium, 20).unwrap();
        tx.send(Priority::High, 30).unwrap();
        tx.send(Priority::Low, 10).unwrap();
        tx.send(Priority::Medium, 21).unwrap();
        tx.send(Priority::High, 31).unwrap();
        drop(tx);
        // Regardless of draw order, every queued element must be delivered
        // exactly once before the iterator ends.
        assert_eq!(
            rx.iter().collect::<HashSet<_>>(),
            [30, 31, 20, 21, 10].into_iter().collect::<HashSet<_>>()
        )
    }

    #[test]
    fn new_high_prio_task_get_scheduled_quickly() {
        let (tx, mut rx) = PriorityQueueReceiver::new();
        for _ in 0..100 {
            tx.send(Priority::Low, 1).unwrap();
        }
        assert_eq!(rx.pop().unwrap(), 1);
        // A freshly queued high-priority task should win the next draw even
        // with many low-priority tasks pending (deterministic because the
        // receiver's RNG is seeded with 0).
        tx.send(Priority::High, 3).unwrap();
        assert_eq!(rx.pop().unwrap(), 3);
        assert_eq!(rx.pop().unwrap(), 1);
    }
}

View File

@@ -9,14 +9,15 @@ use crate::{
KeyBinding, KeyContext, KeyDownEvent, KeyEvent, Keystroke, KeystrokeEvent, LayoutId,
LineLayoutIndex, Modifiers, ModifiersChangedEvent, MonochromeSprite, MouseButton, MouseEvent,
MouseMoveEvent, MouseUpEvent, Path, Pixels, PlatformAtlas, PlatformDisplay, PlatformInput,
PlatformInputHandler, PlatformWindow, Point, PolychromeSprite, PromptButton, PromptLevel, Quad,
Render, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, Replay, ResizeEdge,
SMOOTH_SVG_SCALE_FACTOR, SUBPIXEL_VARIANTS_X, SUBPIXEL_VARIANTS_Y, ScaledPixels, Scene, Shadow,
SharedString, Size, StrikethroughStyle, Style, SubscriberSet, Subscription, SystemWindowTab,
SystemWindowTabController, TabStopMap, TaffyLayoutEngine, Task, TextStyle, TextStyleRefinement,
TransformationMatrix, Underline, UnderlineStyle, WindowAppearance, WindowBackgroundAppearance,
WindowBounds, WindowControls, WindowDecorations, WindowOptions, WindowParams, WindowTextSystem,
point, prelude::*, px, rems, size, transparent_black,
PlatformInputHandler, PlatformWindow, Point, PolychromeSprite, Priority, PromptButton,
PromptLevel, Quad, Render, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams,
Replay, ResizeEdge, SMOOTH_SVG_SCALE_FACTOR, SUBPIXEL_VARIANTS_X, SUBPIXEL_VARIANTS_Y,
ScaledPixels, Scene, Shadow, SharedString, Size, StrikethroughStyle, Style, SubscriberSet,
Subscription, SystemWindowTab, SystemWindowTabController, TabStopMap, TaffyLayoutEngine, Task,
TextStyle, TextStyleRefinement, TransformationMatrix, Underline, UnderlineStyle,
WindowAppearance, WindowBackgroundAppearance, WindowBounds, WindowControls, WindowDecorations,
WindowOptions, WindowParams, WindowTextSystem, point, prelude::*, px, rems, size,
transparent_black,
};
use anyhow::{Context as _, Result, anyhow};
use collections::{FxHashMap, FxHashSet};
@@ -1725,6 +1726,27 @@ impl Window {
})
}
/// Spawn the future returned by the given closure on the application thread
/// pool, with the given priority. The closure is provided a handle to the
/// current window and an `AsyncWindowContext` for use within your future.
#[track_caller]
pub fn spawn_with_priority<AsyncFn, R>(
&self,
priority: Priority,
cx: &App,
f: AsyncFn,
) -> Task<R>
where
R: 'static,
AsyncFn: AsyncFnOnce(&mut AsyncWindowContext) -> R + 'static,
{
let handle = self.handle;
cx.spawn_with_priority(priority, async move |app| {
let mut async_window_cx = AsyncWindowContext::new_context(app.clone(), handle);
f(&mut async_window_cx).await
})
}
fn bounds_changed(&mut self, cx: &mut App) {
self.scale_factor = self.platform_window.scale_factor();
self.viewport_size = self.platform_window.content_size();

View File

@@ -12,7 +12,7 @@ mod session;
use std::{sync::Arc, time::Duration};
use async_dispatcher::{Dispatcher, Runnable, set_dispatcher};
use gpui::{App, PlatformDispatcher, RunnableVariant};
use gpui::{App, PlatformDispatcher, Priority, RunnableVariant};
use project::Fs;
pub use runtimelib::ExecutionState;
@@ -46,7 +46,7 @@ fn zed_dispatcher(cx: &mut App) -> impl Dispatcher {
impl Dispatcher for ZedDispatcher {
fn dispatch(&self, runnable: Runnable) {
self.dispatcher
.dispatch(RunnableVariant::Compat(runnable), None);
.dispatch(RunnableVariant::Compat(runnable), None, Priority::default());
}
fn dispatch_after(&self, duration: Duration, runnable: Runnable) {

View File

@@ -2452,6 +2452,12 @@ impl Workspace {
.0
.split(' ')
.flat_map(|k| Keystroke::parse(k).log_err())
.map(|k| {
cx.keyboard_mapper()
.map_key_equivalent(k, true)
.inner()
.clone()
})
.collect();
let _ = self.send_keystrokes_impl(keystrokes, window, cx);
}

View File

@@ -22,7 +22,8 @@ use git::{
COMMIT_MESSAGE, DOT_GIT, FSMONITOR_DAEMON, GITIGNORE, INDEX_LOCK, LFS_DIR, status::GitSummary,
};
use gpui::{
App, AppContext as _, AsyncApp, BackgroundExecutor, Context, Entity, EventEmitter, Task,
App, AppContext as _, AsyncApp, BackgroundExecutor, Context, Entity, EventEmitter, Priority,
Task,
};
use ignore::IgnoreStack;
use language::DiskState;
@@ -4144,7 +4145,7 @@ impl BackgroundScanner {
let progress_update_count = AtomicUsize::new(0);
self.executor
.scoped(|scope| {
.scoped_priority(Priority::Low, |scope| {
for _ in 0..self.executor.num_cpus() {
scope.spawn(async {
let mut last_progress_update_count = 0;

View File

@@ -52,6 +52,8 @@ extend-exclude = [
"crates/project_panel/benches/linux_repo_snapshot.txt",
# Some multibuffer test cases have word fragments that register as typos
"crates/multi_buffer/src/multi_buffer_tests.rs",
# macOS APIs
"crates/gpui/src/platform/mac/dispatcher.rs",
]
[default]