Compare commits
10 Commits
windows/se
...
ep-distill
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
610536201b | ||
|
|
60f4aa333b | ||
|
|
a698f1bf63 | ||
|
|
636d11ebec | ||
|
|
4d0e760b04 | ||
|
|
8bd4d866b9 | ||
|
|
a2a96e4038 | ||
|
|
ec26556dab | ||
|
|
1a8d8e9572 | ||
|
|
ab893ca754 |
2
Cargo.lock
generated
2
Cargo.lock
generated
@@ -5201,7 +5201,6 @@ dependencies = [
|
||||
"wasmtime",
|
||||
"watch",
|
||||
"zeta_prompt",
|
||||
"zlog",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -7240,6 +7239,7 @@ dependencies = [
|
||||
"libc",
|
||||
"log",
|
||||
"lyon",
|
||||
"mach2 0.5.0",
|
||||
"media",
|
||||
"metal",
|
||||
"naga",
|
||||
|
||||
@@ -56,7 +56,6 @@ watch.workspace = true
|
||||
edit_prediction = { workspace = true, features = ["cli-support"] }
|
||||
wasmtime.workspace = true
|
||||
zeta_prompt.workspace = true
|
||||
zlog.workspace = true
|
||||
|
||||
# Wasmtime is included as a dependency in order to enable the same
|
||||
# features that are enabled in Zed.
|
||||
|
||||
@@ -1,14 +1,22 @@
|
||||
use anyhow::{Result, anyhow};
|
||||
use std::mem;
|
||||
|
||||
use crate::example::Example;
|
||||
|
||||
pub async fn run_distill(example: &mut Example) {
|
||||
let [prediction]: [_; 1] = mem::take(&mut example.predictions)
|
||||
.try_into()
|
||||
.expect("Run predict first with a single repetition");
|
||||
pub async fn run_distill(example: &mut Example) -> Result<()> {
|
||||
let [prediction]: [_; 1] =
|
||||
mem::take(&mut example.predictions)
|
||||
.try_into()
|
||||
.map_err(|preds: Vec<_>| {
|
||||
anyhow!(
|
||||
"Example has {} predictions, but it should have exactly one",
|
||||
preds.len()
|
||||
)
|
||||
})?;
|
||||
|
||||
example.expected_patch = prediction.actual_patch;
|
||||
example.prompt = None;
|
||||
example.predictions = Vec::new();
|
||||
example.score = Vec::new();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@ use crate::{
|
||||
progress::{Progress, Step},
|
||||
retrieve_context::run_context_retrieval,
|
||||
};
|
||||
use anyhow::{Context as _, Result, ensure};
|
||||
use edit_prediction::{
|
||||
EditPredictionStore,
|
||||
zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
|
||||
@@ -18,12 +19,11 @@ pub async fn run_format_prompt(
|
||||
example: &mut Example,
|
||||
prompt_format: PromptFormat,
|
||||
app_state: Arc<EpAppState>,
|
||||
progress: Arc<Progress>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;
|
||||
) -> Result<()> {
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
let _step_progress = progress.start(Step::FormatPrompt, &example.name);
|
||||
let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
|
||||
|
||||
match prompt_format {
|
||||
PromptFormat::Teacher => {
|
||||
@@ -35,29 +35,33 @@ pub async fn run_format_prompt(
|
||||
});
|
||||
}
|
||||
PromptFormat::Zeta2 => {
|
||||
run_load_project(example, app_state, progress.clone(), cx.clone()).await;
|
||||
run_load_project(example, app_state, cx.clone()).await?;
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
let ep_store = cx.update(|cx| {
|
||||
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
|
||||
})??;
|
||||
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let snapshot = state
|
||||
.buffer
|
||||
.read_with(&cx, |buffer, _| buffer.snapshot())
|
||||
.unwrap();
|
||||
let state = example.state.as_ref().context("state must be set")?;
|
||||
let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
|
||||
let project = state.project.clone();
|
||||
let (_, input) = ep_store
|
||||
.update(&mut cx, |ep_store, _cx| {
|
||||
zeta2_prompt_input(
|
||||
&snapshot,
|
||||
example.context.as_ref().unwrap().files.clone(),
|
||||
ep_store.edit_history_for_project(&project),
|
||||
example.cursor_path.clone(),
|
||||
example.buffer.as_ref().unwrap().cursor_offset,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
let (_, input) = ep_store.update(&mut cx, |ep_store, _cx| {
|
||||
anyhow::Ok(zeta2_prompt_input(
|
||||
&snapshot,
|
||||
example
|
||||
.context
|
||||
.as_ref()
|
||||
.context("context must be set")?
|
||||
.files
|
||||
.clone(),
|
||||
ep_store.edit_history_for_project(&project),
|
||||
example.cursor_path.clone(),
|
||||
example
|
||||
.buffer
|
||||
.as_ref()
|
||||
.context("buffer must be set")?
|
||||
.cursor_offset,
|
||||
))
|
||||
})??;
|
||||
let prompt = format_zeta_prompt(&input);
|
||||
let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
|
||||
example.prompt = Some(ExamplePrompt {
|
||||
@@ -67,6 +71,7 @@ pub async fn run_format_prompt(
|
||||
});
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub struct TeacherPrompt;
|
||||
@@ -92,7 +97,7 @@ impl TeacherPrompt {
|
||||
prompt
|
||||
}
|
||||
|
||||
pub fn parse(example: &Example, response: &str) -> String {
|
||||
pub fn parse(example: &Example, response: &str) -> Result<String> {
|
||||
// Ideally, we should always be able to find cursor position in the retrieved context.
|
||||
// In reality, sometimes we don't find it for these reasons:
|
||||
// 1. `example.cursor_position` contains _more_ context than included in the retrieved context
|
||||
@@ -103,7 +108,7 @@ impl TeacherPrompt {
|
||||
let cursor_file = &example
|
||||
.buffer
|
||||
.as_ref()
|
||||
.expect("`buffer` should be filled in in the context collection step")
|
||||
.context("`buffer` should be filled in in the context collection step")?
|
||||
.content;
|
||||
|
||||
// Extract updated (new) editable region from the model response
|
||||
@@ -112,9 +117,10 @@ impl TeacherPrompt {
|
||||
// Reconstruct old editable region we sent to the model
|
||||
let old_editable_region = Self::format_editable_region(example);
|
||||
let old_editable_region = Self::extract_editable_region(&old_editable_region);
|
||||
if !cursor_file.contains(&old_editable_region) {
|
||||
panic!("Something's wrong: editable_region is not found in the cursor file")
|
||||
}
|
||||
ensure!(
|
||||
cursor_file.contains(&old_editable_region),
|
||||
"Something's wrong: editable_region is not found in the cursor file"
|
||||
);
|
||||
|
||||
// Apply editable region to a larger context and compute diff.
|
||||
// This is needed to get a better context lines around the editable region
|
||||
@@ -129,7 +135,7 @@ impl TeacherPrompt {
|
||||
diff = diff,
|
||||
};
|
||||
|
||||
diff
|
||||
Ok(diff)
|
||||
}
|
||||
|
||||
fn format_edit_history(edit_history: &str) -> String {
|
||||
@@ -153,9 +159,7 @@ impl TeacherPrompt {
|
||||
}
|
||||
|
||||
fn format_context(example: &Example) -> String {
|
||||
if example.context.is_none() {
|
||||
panic!("Missing context retriever step");
|
||||
}
|
||||
assert!(example.context.is_some(), "Missing context retriever step");
|
||||
|
||||
let mut prompt = String::new();
|
||||
zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
|
||||
|
||||
@@ -4,7 +4,7 @@ use crate::{
|
||||
paths::{REPOS_DIR, WORKTREES_DIR},
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
};
|
||||
use anyhow::{Result, anyhow};
|
||||
use anyhow::{Context as _, Result};
|
||||
use collections::HashMap;
|
||||
use edit_prediction::EditPredictionStore;
|
||||
use edit_prediction::udiff::OpenedBuffers;
|
||||
@@ -28,40 +28,35 @@ use zeta_prompt::CURSOR_MARKER;
|
||||
pub async fn run_load_project(
|
||||
example: &mut Example,
|
||||
app_state: Arc<EpAppState>,
|
||||
progress: Arc<Progress>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
) -> Result<()> {
|
||||
if example.state.is_some() {
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let progress = progress.start(Step::LoadProject, &example.name);
|
||||
let progress = Progress::global().start(Step::LoadProject, &example.name);
|
||||
|
||||
let project = setup_project(example, &app_state, &progress, &mut cx).await;
|
||||
let project = setup_project(example, &app_state, &progress, &mut cx).await?;
|
||||
|
||||
let _open_buffers = apply_edit_history(example, &project, &mut cx)
|
||||
.await
|
||||
.unwrap();
|
||||
let _open_buffers = apply_edit_history(example, &project, &mut cx).await?;
|
||||
|
||||
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await;
|
||||
let (example_buffer, language_name) = buffer
|
||||
.read_with(&cx, |buffer, _cx| {
|
||||
let cursor_point = cursor_position.to_point(&buffer);
|
||||
let language_name = buffer
|
||||
.language()
|
||||
.map(|l| l.name().to_string())
|
||||
.unwrap_or_else(|| "Unknown".to_string());
|
||||
(
|
||||
ExampleBuffer {
|
||||
content: buffer.text(),
|
||||
cursor_row: cursor_point.row,
|
||||
cursor_column: cursor_point.column,
|
||||
cursor_offset: cursor_position.to_offset(&buffer),
|
||||
},
|
||||
language_name,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
let (buffer, cursor_position) = cursor_position(example, &project, &mut cx).await?;
|
||||
let (example_buffer, language_name) = buffer.read_with(&cx, |buffer, _cx| {
|
||||
let cursor_point = cursor_position.to_point(&buffer);
|
||||
let language_name = buffer
|
||||
.language()
|
||||
.map(|l| l.name().to_string())
|
||||
.unwrap_or_else(|| "Unknown".to_string());
|
||||
(
|
||||
ExampleBuffer {
|
||||
content: buffer.text(),
|
||||
cursor_row: cursor_point.row,
|
||||
cursor_column: cursor_point.column,
|
||||
cursor_offset: cursor_position.to_offset(&buffer),
|
||||
},
|
||||
language_name,
|
||||
)
|
||||
})?;
|
||||
|
||||
progress.set_info(language_name, InfoStyle::Normal);
|
||||
|
||||
@@ -72,16 +67,15 @@ pub async fn run_load_project(
|
||||
cursor_position,
|
||||
_open_buffers,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn cursor_position(
|
||||
example: &Example,
|
||||
project: &Entity<Project>,
|
||||
cx: &mut AsyncApp,
|
||||
) -> (Entity<Buffer>, Anchor) {
|
||||
let language_registry = project
|
||||
.read_with(cx, |project, _| project.languages().clone())
|
||||
.unwrap();
|
||||
) -> Result<(Entity<Buffer>, Anchor)> {
|
||||
let language_registry = project.read_with(cx, |project, _| project.languages().clone())?;
|
||||
let result = language_registry
|
||||
.load_language_for_file_path(&example.cursor_path)
|
||||
.await;
|
||||
@@ -89,17 +83,18 @@ async fn cursor_position(
|
||||
if let Err(error) = result
|
||||
&& !error.is::<LanguageNotFound>()
|
||||
{
|
||||
panic!("Failed to load language for file path: {}", error);
|
||||
return Err(error);
|
||||
}
|
||||
|
||||
let worktree = project
|
||||
.read_with(cx, |project, cx| {
|
||||
project.visible_worktrees(cx).next().unwrap()
|
||||
})
|
||||
.unwrap();
|
||||
let worktree = project.read_with(cx, |project, cx| {
|
||||
project
|
||||
.visible_worktrees(cx)
|
||||
.next()
|
||||
.context("No visible worktrees")
|
||||
})??;
|
||||
|
||||
let cursor_path = RelPath::new(&example.cursor_path, PathStyle::Posix)
|
||||
.unwrap()
|
||||
.context("Failed to create RelPath")?
|
||||
.into_arc();
|
||||
let cursor_buffer = project
|
||||
.update(cx, |project, cx| {
|
||||
@@ -110,15 +105,12 @@ async fn cursor_position(
|
||||
},
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
})?
|
||||
.await?;
|
||||
let cursor_offset_within_excerpt = example
|
||||
.cursor_position
|
||||
.find(CURSOR_MARKER)
|
||||
.ok_or_else(|| anyhow!("missing cursor marker"))
|
||||
.unwrap();
|
||||
.context("missing cursor marker")?;
|
||||
let mut cursor_excerpt = example.cursor_position.clone();
|
||||
cursor_excerpt.replace_range(
|
||||
cursor_offset_within_excerpt..(cursor_offset_within_excerpt + CURSOR_MARKER.len()),
|
||||
@@ -128,90 +120,76 @@ async fn cursor_position(
|
||||
let text = buffer.text();
|
||||
|
||||
let mut matches = text.match_indices(&cursor_excerpt);
|
||||
let (excerpt_offset, _) = matches.next().unwrap_or_else(|| {
|
||||
panic!(
|
||||
let (excerpt_offset, _) = matches.next().with_context(|| {
|
||||
format!(
|
||||
"\nExcerpt:\n\n{cursor_excerpt}\nBuffer text:\n{text}\n.Example: {}\nCursor excerpt did not exist in buffer.",
|
||||
example.name
|
||||
);
|
||||
});
|
||||
assert!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
|
||||
excerpt_offset
|
||||
}).unwrap();
|
||||
)
|
||||
})?;
|
||||
anyhow::ensure!(matches.next().is_none(), "More than one cursor position match found for {}", &example.name);
|
||||
Ok(excerpt_offset)
|
||||
})??;
|
||||
|
||||
let cursor_offset = excerpt_offset + cursor_offset_within_excerpt;
|
||||
let cursor_anchor = cursor_buffer
|
||||
.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))
|
||||
.unwrap();
|
||||
let cursor_anchor =
|
||||
cursor_buffer.read_with(cx, |buffer, _| buffer.anchor_after(cursor_offset))?;
|
||||
|
||||
(cursor_buffer, cursor_anchor)
|
||||
Ok((cursor_buffer, cursor_anchor))
|
||||
}
|
||||
|
||||
async fn setup_project(
|
||||
example: &mut Example,
|
||||
app_state: &Arc<EpAppState>,
|
||||
step_progress: &Arc<StepProgress>,
|
||||
step_progress: &StepProgress,
|
||||
cx: &mut AsyncApp,
|
||||
) -> Entity<Project> {
|
||||
) -> Result<Entity<Project>> {
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
.update(|cx| EditPredictionStore::try_global(cx))?
|
||||
.context("Store should be initialized at init")?;
|
||||
|
||||
let worktree_path = setup_worktree(example, step_progress).await;
|
||||
let worktree_path = setup_worktree(example, step_progress).await?;
|
||||
|
||||
if let Some(project) = app_state.project_cache.get(&example.repository_url) {
|
||||
ep_store
|
||||
.update(cx, |ep_store, _| {
|
||||
ep_store.clear_history_for_project(&project);
|
||||
})
|
||||
.unwrap();
|
||||
let buffer_store = project
|
||||
.read_with(cx, |project, _| project.buffer_store().clone())
|
||||
.unwrap();
|
||||
let buffers = buffer_store
|
||||
.read_with(cx, |buffer_store, _| {
|
||||
buffer_store.buffers().collect::<Vec<_>>()
|
||||
})
|
||||
.unwrap();
|
||||
ep_store.update(cx, |ep_store, _| {
|
||||
ep_store.clear_history_for_project(&project);
|
||||
})?;
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
let buffers = buffer_store.read_with(cx, |buffer_store, _| {
|
||||
buffer_store.buffers().collect::<Vec<_>>()
|
||||
})?;
|
||||
for buffer in buffers {
|
||||
buffer
|
||||
.update(cx, |buffer, cx| buffer.reload(cx))
|
||||
.unwrap()
|
||||
.update(cx, |buffer, cx| buffer.reload(cx))?
|
||||
.await
|
||||
.ok();
|
||||
}
|
||||
return project;
|
||||
return Ok(project);
|
||||
}
|
||||
|
||||
let project = cx
|
||||
.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap();
|
||||
let project = cx.update(|cx| {
|
||||
Project::local(
|
||||
app_state.client.clone(),
|
||||
app_state.node_runtime.clone(),
|
||||
app_state.user_store.clone(),
|
||||
app_state.languages.clone(),
|
||||
app_state.fs.clone(),
|
||||
None,
|
||||
cx,
|
||||
)
|
||||
})?;
|
||||
|
||||
project
|
||||
.update(cx, |project, cx| {
|
||||
project.disable_worktree_scanner(cx);
|
||||
project.create_worktree(&worktree_path, true, cx)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
})?
|
||||
.await?;
|
||||
|
||||
app_state
|
||||
.project_cache
|
||||
.insert(example.repository_url.clone(), project.clone());
|
||||
|
||||
let buffer_store = project
|
||||
.read_with(cx, |project, _| project.buffer_store().clone())
|
||||
.unwrap();
|
||||
let buffer_store = project.read_with(cx, |project, _| project.buffer_store().clone())?;
|
||||
cx.subscribe(&buffer_store, {
|
||||
let project = project.clone();
|
||||
move |_, event, cx| match event {
|
||||
@@ -220,15 +198,14 @@ async fn setup_project(
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
})
|
||||
.unwrap()
|
||||
})?
|
||||
.detach();
|
||||
|
||||
project
|
||||
Ok(project)
|
||||
}
|
||||
|
||||
async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) -> PathBuf {
|
||||
let (repo_owner, repo_name) = example.repo_name().expect("failed to get repo name");
|
||||
async fn setup_worktree(example: &Example, step_progress: &StepProgress) -> Result<PathBuf> {
|
||||
let (repo_owner, repo_name) = example.repo_name().context("failed to get repo name")?;
|
||||
let repo_dir = REPOS_DIR.join(repo_owner.as_ref()).join(repo_name.as_ref());
|
||||
let worktree_path = WORKTREES_DIR
|
||||
.join(repo_owner.as_ref())
|
||||
@@ -237,14 +214,13 @@ async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) ->
|
||||
|
||||
if !repo_dir.is_dir() {
|
||||
step_progress.set_substatus(format!("cloning {}", repo_name));
|
||||
fs::create_dir_all(&repo_dir).unwrap();
|
||||
run_git(&repo_dir, &["init"]).await.unwrap();
|
||||
fs::create_dir_all(&repo_dir)?;
|
||||
run_git(&repo_dir, &["init"]).await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["remote", "add", "origin", &example.repository_url],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Resolve the example to a revision, fetching it if needed.
|
||||
@@ -264,34 +240,25 @@ async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) ->
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await.unwrap();
|
||||
run_git(&repo_dir, &["fetch", "origin"]).await?;
|
||||
}
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"])
|
||||
.await
|
||||
.unwrap();
|
||||
let revision = run_git(&repo_dir, &["rev-parse", "FETCH_HEAD"]).await?;
|
||||
revision
|
||||
};
|
||||
|
||||
// Create the worktree for this example if needed.
|
||||
step_progress.set_substatus("preparing worktree");
|
||||
if worktree_path.is_dir() {
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"])
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"])
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()])
|
||||
.await
|
||||
.unwrap();
|
||||
run_git(&worktree_path, &["clean", "--force", "-d"]).await?;
|
||||
run_git(&worktree_path, &["reset", "--hard", "HEAD"]).await?;
|
||||
run_git(&worktree_path, &["checkout", revision.as_str()]).await?;
|
||||
} else {
|
||||
let worktree_path_string = worktree_path.to_string_lossy();
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&["branch", "-f", &example.name, revision.as_str()],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
run_git(
|
||||
&repo_dir,
|
||||
&[
|
||||
@@ -302,8 +269,7 @@ async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) ->
|
||||
&example.name,
|
||||
],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
.await?;
|
||||
}
|
||||
drop(repo_lock);
|
||||
|
||||
@@ -314,30 +280,25 @@ async fn setup_worktree(example: &Example, step_progress: &Arc<StepProgress>) ->
|
||||
.current_dir(&worktree_path)
|
||||
.args(&["apply", "-"])
|
||||
.stdin(std::process::Stdio::piped())
|
||||
.spawn()
|
||||
.unwrap();
|
||||
.spawn()?;
|
||||
|
||||
let mut stdin = apply_process.stdin.take().unwrap();
|
||||
stdin
|
||||
.write_all(example.uncommitted_diff.as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
stdin.close().await.unwrap();
|
||||
let mut stdin = apply_process.stdin.take().context("Failed to get stdin")?;
|
||||
stdin.write_all(example.uncommitted_diff.as_bytes()).await?;
|
||||
stdin.close().await?;
|
||||
drop(stdin);
|
||||
|
||||
let apply_result = apply_process.output().await.unwrap();
|
||||
if !apply_result.status.success() {
|
||||
panic!(
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
}
|
||||
let apply_result = apply_process.output().await?;
|
||||
anyhow::ensure!(
|
||||
apply_result.status.success(),
|
||||
"Failed to apply uncommitted diff patch with status: {}\nstderr:\n{}\nstdout:\n{}",
|
||||
apply_result.status,
|
||||
String::from_utf8_lossy(&apply_result.stderr),
|
||||
String::from_utf8_lossy(&apply_result.stdout),
|
||||
);
|
||||
}
|
||||
|
||||
step_progress.clear_substatus();
|
||||
worktree_path
|
||||
Ok(worktree_path)
|
||||
}
|
||||
|
||||
async fn apply_edit_history(
|
||||
|
||||
@@ -16,12 +16,14 @@ use edit_prediction::EditPredictionStore;
|
||||
use gpui::Application;
|
||||
use reqwest_client::ReqwestClient;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Display;
|
||||
use std::{path::PathBuf, sync::Arc};
|
||||
|
||||
use crate::distill::run_distill;
|
||||
use crate::example::{group_examples_by_repo, read_examples, write_examples};
|
||||
use crate::format_prompt::run_format_prompt;
|
||||
use crate::load_project::run_load_project;
|
||||
use crate::paths::FAILED_EXAMPLES_DIR;
|
||||
use crate::predict::run_prediction;
|
||||
use crate::progress::Progress;
|
||||
use crate::retrieve_context::run_context_retrieval;
|
||||
@@ -32,7 +34,7 @@ use crate::score::run_scoring;
|
||||
struct EpArgs {
|
||||
#[arg(long, default_value_t = false)]
|
||||
printenv: bool,
|
||||
#[clap(long, default_value_t = 10)]
|
||||
#[clap(long, default_value_t = 10, global = true)]
|
||||
max_parallelism: usize,
|
||||
#[command(subcommand)]
|
||||
command: Option<Command>,
|
||||
@@ -42,6 +44,8 @@ struct EpArgs {
|
||||
output: Option<PathBuf>,
|
||||
#[arg(long, short, global = true)]
|
||||
in_place: bool,
|
||||
#[arg(long, short, global = true)]
|
||||
failfast: bool,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
@@ -67,6 +71,58 @@ enum Command {
|
||||
Clean,
|
||||
}
|
||||
|
||||
impl Display for Command {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Command::ParseExample => write!(f, "parse-example"),
|
||||
Command::LoadProject => write!(f, "load-project"),
|
||||
Command::Context => write!(f, "context"),
|
||||
Command::FormatPrompt(format_prompt_args) => write!(
|
||||
f,
|
||||
"format-prompt --prompt-format={}",
|
||||
format_prompt_args
|
||||
.prompt_format
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
),
|
||||
Command::Predict(predict_args) => {
|
||||
write!(
|
||||
f,
|
||||
"predict --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
)
|
||||
}
|
||||
Command::Score(predict_args) => {
|
||||
write!(
|
||||
f,
|
||||
"score --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
)
|
||||
}
|
||||
Command::Distill => write!(f, "distill"),
|
||||
Command::Eval(predict_args) => write!(
|
||||
f,
|
||||
"eval --provider={:?}",
|
||||
predict_args
|
||||
.provider
|
||||
.to_possible_value()
|
||||
.unwrap()
|
||||
.get_name()
|
||||
),
|
||||
Command::Clean => write!(f, "clean"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Args)]
|
||||
struct FormatPromptArgs {
|
||||
#[clap(long)]
|
||||
@@ -112,8 +168,6 @@ impl EpArgs {
|
||||
}
|
||||
|
||||
fn main() {
|
||||
let _ = zlog::try_init(Some("error".into()));
|
||||
zlog::init_output_stderr();
|
||||
let args = EpArgs::parse();
|
||||
|
||||
if args.printenv {
|
||||
@@ -147,92 +201,140 @@ fn main() {
|
||||
EditPredictionStore::global(&app_state.client, &app_state.user_store, cx);
|
||||
|
||||
cx.spawn(async move |cx| {
|
||||
if let Command::Predict(args) = &command {
|
||||
predict::sync_batches(&args.provider).await
|
||||
};
|
||||
let result = async {
|
||||
if let Command::Predict(args) = &command {
|
||||
predict::sync_batches(&args.provider).await?;
|
||||
}
|
||||
|
||||
let total_examples = examples.len();
|
||||
let progress = Progress::new(total_examples);
|
||||
let total_examples = examples.len();
|
||||
Progress::global().set_total_examples(total_examples);
|
||||
|
||||
let mut grouped_examples = group_examples_by_repo(&mut examples);
|
||||
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
|
||||
let mut grouped_examples = group_examples_by_repo(&mut examples);
|
||||
let example_batches = grouped_examples.chunks_mut(args.max_parallelism);
|
||||
|
||||
for example_batch in example_batches {
|
||||
let futures = example_batch.into_iter().map(|repo_examples| async {
|
||||
for example in repo_examples.iter_mut() {
|
||||
match &command {
|
||||
Command::ParseExample => {}
|
||||
Command::LoadProject => {
|
||||
run_load_project(
|
||||
example,
|
||||
app_state.clone(),
|
||||
progress.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await;
|
||||
for example_batch in example_batches {
|
||||
let futures = example_batch.into_iter().map(|repo_examples| async {
|
||||
for example in repo_examples.iter_mut() {
|
||||
let result = async {
|
||||
match &command {
|
||||
Command::ParseExample => {}
|
||||
Command::LoadProject => {
|
||||
run_load_project(example, app_state.clone(), cx.clone())
|
||||
.await?;
|
||||
}
|
||||
Command::Context => {
|
||||
run_context_retrieval(
|
||||
example,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Command::FormatPrompt(args) => {
|
||||
run_format_prompt(
|
||||
example,
|
||||
args.prompt_format,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Command::Predict(args) => {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
app_state.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
Command::Distill => {
|
||||
run_distill(example).await?;
|
||||
}
|
||||
Command::Score(args) | Command::Eval(args) => {
|
||||
run_scoring(example, &args, app_state.clone(), cx.clone())
|
||||
.await?;
|
||||
}
|
||||
Command::Clean => {
|
||||
unreachable!()
|
||||
}
|
||||
}
|
||||
anyhow::Ok(())
|
||||
}
|
||||
Command::Context => {
|
||||
run_context_retrieval(
|
||||
example,
|
||||
app_state.clone(),
|
||||
progress.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Command::FormatPrompt(args) => {
|
||||
run_format_prompt(
|
||||
example,
|
||||
args.prompt_format,
|
||||
app_state.clone(),
|
||||
progress.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Command::Predict(args) => {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
app_state.clone(),
|
||||
progress.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Command::Distill => {
|
||||
run_distill(example).await;
|
||||
}
|
||||
Command::Score(args) | Command::Eval(args) => {
|
||||
run_scoring(
|
||||
example,
|
||||
&args,
|
||||
app_state.clone(),
|
||||
progress.clone(),
|
||||
cx.clone(),
|
||||
)
|
||||
.await;
|
||||
}
|
||||
Command::Clean => {
|
||||
unreachable!()
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
Progress::global().increment_failed();
|
||||
let failed_example_path =
|
||||
FAILED_EXAMPLES_DIR.join(format!("{}.json", example.name));
|
||||
app_state
|
||||
.fs
|
||||
.write(
|
||||
&failed_example_path,
|
||||
&serde_json::to_vec_pretty(&example).unwrap(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let err_path =
|
||||
FAILED_EXAMPLES_DIR.join(format!("{}_err.txt", example.name));
|
||||
app_state
|
||||
.fs
|
||||
.write(&err_path, e.to_string().as_bytes())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let msg = format!(
|
||||
indoc::indoc! {"
|
||||
While processing {}:
|
||||
|
||||
{:?}
|
||||
|
||||
Written to: \x1b[36m{}\x1b[0m
|
||||
|
||||
Explore this example data with:
|
||||
fx \x1b[36m{}\x1b[0m
|
||||
|
||||
Re-run this example with:
|
||||
cargo run -p edit_prediction_cli -- {} \x1b[36m{}\x1b[0m
|
||||
"},
|
||||
example.name,
|
||||
e,
|
||||
err_path.display(),
|
||||
failed_example_path.display(),
|
||||
command,
|
||||
failed_example_path.display(),
|
||||
);
|
||||
if args.failfast || total_examples == 1 {
|
||||
Progress::global().finalize();
|
||||
panic!("{}", msg);
|
||||
} else {
|
||||
log::error!("{}", msg);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
futures::future::join_all(futures).await;
|
||||
}
|
||||
progress.clear();
|
||||
});
|
||||
futures::future::join_all(futures).await;
|
||||
}
|
||||
Progress::global().finalize();
|
||||
|
||||
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
|
||||
write_examples(&examples, output.as_ref());
|
||||
}
|
||||
if args.output.is_some() || !matches!(command, Command::Eval(_)) {
|
||||
write_examples(&examples, output.as_ref());
|
||||
}
|
||||
|
||||
match &command {
|
||||
Command::Predict(args) => predict::sync_batches(&args.provider).await,
|
||||
Command::Eval(_) => score::print_report(&examples),
|
||||
_ => (),
|
||||
};
|
||||
match &command {
|
||||
Command::Predict(args) => predict::sync_batches(&args.provider).await?,
|
||||
Command::Eval(_) => score::print_report(&examples),
|
||||
_ => (),
|
||||
};
|
||||
|
||||
anyhow::Ok(())
|
||||
}
|
||||
.await;
|
||||
|
||||
if let Err(e) = result {
|
||||
panic!("Fatal error: {:?}", e);
|
||||
}
|
||||
|
||||
let _ = cx.update(|cx| cx.quit());
|
||||
})
|
||||
|
||||
@@ -18,6 +18,8 @@ pub static RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| {
|
||||
});
|
||||
pub static LATEST_EXAMPLE_RUN_DIR: LazyLock<PathBuf> = LazyLock::new(|| DATA_DIR.join("latest"));
|
||||
pub static LLM_CACHE_DB: LazyLock<PathBuf> = LazyLock::new(|| CACHE_DIR.join("llm_cache.sqlite"));
|
||||
pub static FAILED_EXAMPLES_DIR: LazyLock<PathBuf> =
|
||||
LazyLock::new(|| ensure_dir(&RUN_DIR.join("failed")));
|
||||
|
||||
fn ensure_dir(path: &Path) -> PathBuf {
|
||||
std::fs::create_dir_all(path).expect("Failed to create directory");
|
||||
|
||||
@@ -9,6 +9,7 @@ use crate::{
|
||||
progress::{InfoStyle, Progress, Step},
|
||||
retrieve_context::run_context_retrieval,
|
||||
};
|
||||
use anyhow::Context as _;
|
||||
use edit_prediction::{DebugEvent, EditPredictionStore};
|
||||
use futures::{FutureExt as _, StreamExt as _, future::Shared};
|
||||
use gpui::{AppContext as _, AsyncApp, Task};
|
||||
@@ -25,41 +26,33 @@ pub async fn run_prediction(
|
||||
provider: Option<PredictionProvider>,
|
||||
repetition_count: usize,
|
||||
app_state: Arc<EpAppState>,
|
||||
progress: Arc<Progress>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
) -> anyhow::Result<()> {
|
||||
if !example.predictions.is_empty() {
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let provider = provider.unwrap();
|
||||
let provider = provider.context("provider is required")?;
|
||||
|
||||
run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;
|
||||
run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
if matches!(
|
||||
provider,
|
||||
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching
|
||||
) {
|
||||
let _step_progress = progress.start(Step::Predict, &example.name);
|
||||
let _step_progress = Progress::global().start(Step::Predict, &example.name);
|
||||
|
||||
if example.prompt.is_none() {
|
||||
run_format_prompt(
|
||||
example,
|
||||
PromptFormat::Teacher,
|
||||
app_state.clone(),
|
||||
progress,
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
run_format_prompt(example, PromptFormat::Teacher, app_state.clone(), cx).await?;
|
||||
}
|
||||
|
||||
let batched = matches!(provider, PredictionProvider::Teacher);
|
||||
return predict_anthropic(example, repetition_count, batched).await;
|
||||
}
|
||||
|
||||
run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await;
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
let _step_progress = progress.start(Step::Predict, &example.name);
|
||||
let _step_progress = Progress::global().start(Step::Predict, &example.name);
|
||||
|
||||
if matches!(
|
||||
provider,
|
||||
@@ -70,10 +63,9 @@ pub async fn run_prediction(
|
||||
.get_or_init(|| {
|
||||
let client = app_state.client.clone();
|
||||
cx.spawn(async move |cx| {
|
||||
client
|
||||
.sign_in_with_optional_connect(true, cx)
|
||||
.await
|
||||
.unwrap();
|
||||
if let Err(e) = client.sign_in_with_optional_connect(true, cx).await {
|
||||
eprintln!("Authentication failed: {}", e);
|
||||
}
|
||||
})
|
||||
.shared()
|
||||
})
|
||||
@@ -81,33 +73,30 @@ pub async fn run_prediction(
|
||||
.await;
|
||||
}
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
let ep_store = cx.update(|cx| {
|
||||
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
|
||||
})??;
|
||||
|
||||
ep_store
|
||||
.update(&mut cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
|
||||
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
|
||||
unreachable!()
|
||||
}
|
||||
};
|
||||
store.set_edit_prediction_model(model);
|
||||
})
|
||||
.unwrap();
|
||||
let state = example.state.as_ref().unwrap();
|
||||
ep_store.update(&mut cx, |store, _cx| {
|
||||
let model = match provider {
|
||||
PredictionProvider::Zeta1 => edit_prediction::EditPredictionModel::Zeta1,
|
||||
PredictionProvider::Zeta2 => edit_prediction::EditPredictionModel::Zeta2,
|
||||
PredictionProvider::Sweep => edit_prediction::EditPredictionModel::Sweep,
|
||||
PredictionProvider::Mercury => edit_prediction::EditPredictionModel::Mercury,
|
||||
PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
|
||||
unreachable!()
|
||||
}
|
||||
};
|
||||
store.set_edit_prediction_model(model);
|
||||
})?;
|
||||
let state = example.state.as_ref().context("state must be set")?;
|
||||
let run_dir = RUN_DIR.join(&example.name);
|
||||
|
||||
let updated_example = Arc::new(Mutex::new(example.clone()));
|
||||
let current_run_ix = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let mut debug_rx = ep_store
|
||||
.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))
|
||||
.unwrap();
|
||||
let mut debug_rx =
|
||||
ep_store.update(&mut cx, |store, cx| store.debug_info(&state.project, cx))?;
|
||||
let debug_task = cx.background_spawn({
|
||||
let updated_example = updated_example.clone();
|
||||
let current_run_ix = current_run_ix.clone();
|
||||
@@ -161,14 +150,14 @@ pub async fn run_prediction(
|
||||
run_dir.clone()
|
||||
};
|
||||
|
||||
fs::create_dir_all(&run_dir).unwrap();
|
||||
fs::create_dir_all(&run_dir)?;
|
||||
if LATEST_EXAMPLE_RUN_DIR.is_symlink() {
|
||||
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
fs::remove_file(&*LATEST_EXAMPLE_RUN_DIR)?;
|
||||
}
|
||||
#[cfg(unix)]
|
||||
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
std::os::unix::fs::symlink(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
|
||||
#[cfg(windows)]
|
||||
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR).unwrap();
|
||||
std::os::windows::fs::symlink_dir(&run_dir, &*LATEST_EXAMPLE_RUN_DIR)?;
|
||||
|
||||
updated_example
|
||||
.lock()
|
||||
@@ -189,10 +178,8 @@ pub async fn run_prediction(
|
||||
cloud_llm_client::PredictEditsRequestTrigger::Cli,
|
||||
cx,
|
||||
)
|
||||
})
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
})?
|
||||
.await?;
|
||||
|
||||
let actual_patch = prediction
|
||||
.and_then(|prediction| {
|
||||
@@ -221,20 +208,23 @@ pub async fn run_prediction(
|
||||
}
|
||||
}
|
||||
|
||||
ep_store
|
||||
.update(&mut cx, |store, _| {
|
||||
store.remove_project(&state.project);
|
||||
})
|
||||
.unwrap();
|
||||
debug_task.await.unwrap();
|
||||
ep_store.update(&mut cx, |store, _| {
|
||||
store.remove_project(&state.project);
|
||||
})?;
|
||||
debug_task.await?;
|
||||
|
||||
*example = Arc::into_inner(updated_example)
|
||||
.unwrap()
|
||||
.ok_or_else(|| anyhow::anyhow!("Failed to unwrap Arc"))?
|
||||
.into_inner()
|
||||
.unwrap();
|
||||
.map_err(|_| anyhow::anyhow!("Failed to unwrap Mutex"))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batched: bool) {
|
||||
async fn predict_anthropic(
|
||||
example: &mut Example,
|
||||
_repetition_count: usize,
|
||||
batched: bool,
|
||||
) -> anyhow::Result<()> {
|
||||
let llm_model_name = "claude-sonnet-4-5";
|
||||
let max_tokens = 16384;
|
||||
let llm_client = if batched {
|
||||
@@ -242,12 +232,9 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
|
||||
} else {
|
||||
AnthropicClient::plain()
|
||||
};
|
||||
let llm_client = llm_client.expect("Failed to create LLM client");
|
||||
let llm_client = llm_client.context("Failed to create LLM client")?;
|
||||
|
||||
let prompt = example
|
||||
.prompt
|
||||
.as_ref()
|
||||
.unwrap_or_else(|| panic!("Prompt is required for an example {}", &example.name));
|
||||
let prompt = example.prompt.as_ref().context("Prompt is required")?;
|
||||
|
||||
let messages = vec![anthropic::Message {
|
||||
role: anthropic::Role::User,
|
||||
@@ -259,11 +246,10 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
|
||||
|
||||
let Some(response) = llm_client
|
||||
.generate(llm_model_name, max_tokens, messages)
|
||||
.await
|
||||
.unwrap()
|
||||
.await?
|
||||
else {
|
||||
// Request stashed for batched processing
|
||||
return;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
let actual_output = response
|
||||
@@ -276,7 +262,7 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
|
||||
.collect::<Vec<String>>()
|
||||
.join("\n");
|
||||
|
||||
let actual_patch = TeacherPrompt::parse(example, &actual_output);
|
||||
let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
|
||||
|
||||
let prediction = ExamplePrediction {
|
||||
actual_patch,
|
||||
@@ -285,19 +271,21 @@ async fn predict_anthropic(example: &mut Example, _repetition_count: usize, batc
|
||||
};
|
||||
|
||||
example.predictions.push(prediction);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn sync_batches(provider: &PredictionProvider) {
|
||||
pub async fn sync_batches(provider: &PredictionProvider) -> anyhow::Result<()> {
|
||||
match provider {
|
||||
PredictionProvider::Teacher => {
|
||||
let cache_path = crate::paths::LLM_CACHE_DB.as_ref();
|
||||
let llm_client =
|
||||
AnthropicClient::batch(cache_path).expect("Failed to create LLM client");
|
||||
AnthropicClient::batch(cache_path).context("Failed to create LLM client")?;
|
||||
llm_client
|
||||
.sync_batches()
|
||||
.await
|
||||
.expect("Failed to sync batches");
|
||||
.context("Failed to sync batches")?;
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -2,10 +2,12 @@ use std::{
|
||||
borrow::Cow,
|
||||
collections::HashMap,
|
||||
io::{IsTerminal, Write},
|
||||
sync::{Arc, Mutex},
|
||||
sync::{Arc, Mutex, OnceLock},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use log::{Level, Log, Metadata, Record};
|
||||
|
||||
pub struct Progress {
|
||||
inner: Mutex<ProgressInner>,
|
||||
}
|
||||
@@ -18,6 +20,8 @@ struct ProgressInner {
|
||||
max_example_name_len: usize,
|
||||
status_lines_displayed: usize,
|
||||
total_examples: usize,
|
||||
failed_examples: usize,
|
||||
last_line_is_logging: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
@@ -72,70 +76,120 @@ impl Step {
|
||||
}
|
||||
}
|
||||
|
||||
const RIGHT_MARGIN: usize = 4;
|
||||
static GLOBAL: OnceLock<Arc<Progress>> = OnceLock::new();
|
||||
static LOGGER: ProgressLogger = ProgressLogger;
|
||||
|
||||
const MARGIN: usize = 4;
|
||||
const MAX_STATUS_LINES: usize = 10;
|
||||
|
||||
impl Progress {
|
||||
pub fn new(total_examples: usize) -> Arc<Self> {
|
||||
Arc::new(Self {
|
||||
inner: Mutex::new(ProgressInner {
|
||||
completed: Vec::new(),
|
||||
in_progress: HashMap::new(),
|
||||
is_tty: std::io::stderr().is_terminal(),
|
||||
terminal_width: get_terminal_width(),
|
||||
max_example_name_len: 0,
|
||||
status_lines_displayed: 0,
|
||||
total_examples,
|
||||
}),
|
||||
})
|
||||
/// Returns the global Progress instance, initializing it if necessary.
|
||||
pub fn global() -> Arc<Progress> {
|
||||
GLOBAL
|
||||
.get_or_init(|| {
|
||||
let progress = Arc::new(Self {
|
||||
inner: Mutex::new(ProgressInner {
|
||||
completed: Vec::new(),
|
||||
in_progress: HashMap::new(),
|
||||
is_tty: std::io::stderr().is_terminal(),
|
||||
terminal_width: get_terminal_width(),
|
||||
max_example_name_len: 0,
|
||||
status_lines_displayed: 0,
|
||||
total_examples: 0,
|
||||
failed_examples: 0,
|
||||
last_line_is_logging: false,
|
||||
}),
|
||||
});
|
||||
let _ = log::set_logger(&LOGGER);
|
||||
log::set_max_level(log::LevelFilter::Error);
|
||||
progress
|
||||
})
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> Arc<StepProgress> {
|
||||
{
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
pub fn set_total_examples(&self, total: usize) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.total_examples = total;
|
||||
}
|
||||
|
||||
Self::clear_status_lines(&mut inner);
|
||||
pub fn increment_failed(&self) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
inner.failed_examples += 1;
|
||||
}
|
||||
|
||||
inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
|
||||
/// Prints a message to stderr, clearing and redrawing status lines to avoid corruption.
|
||||
/// This should be used for any output that needs to appear above the status lines.
|
||||
fn log(&self, message: &str) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
Self::clear_status_lines(&mut inner);
|
||||
|
||||
inner.in_progress.insert(
|
||||
example_name.to_string(),
|
||||
InProgressTask {
|
||||
step,
|
||||
started_at: Instant::now(),
|
||||
substatus: None,
|
||||
info: None,
|
||||
},
|
||||
);
|
||||
|
||||
Self::print_status_lines(&mut inner);
|
||||
if !inner.last_line_is_logging {
|
||||
let reset = "\x1b[0m";
|
||||
let dim = "\x1b[2m";
|
||||
let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
|
||||
eprintln!("{dim}{divider}{reset}");
|
||||
inner.last_line_is_logging = true;
|
||||
}
|
||||
|
||||
Arc::new(StepProgress {
|
||||
eprintln!("{}", message);
|
||||
}
|
||||
|
||||
pub fn start(self: &Arc<Self>, step: Step, example_name: &str) -> StepProgress {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
|
||||
Self::clear_status_lines(&mut inner);
|
||||
|
||||
inner.max_example_name_len = inner.max_example_name_len.max(example_name.len());
|
||||
inner.in_progress.insert(
|
||||
example_name.to_string(),
|
||||
InProgressTask {
|
||||
step,
|
||||
started_at: Instant::now(),
|
||||
substatus: None,
|
||||
info: None,
|
||||
},
|
||||
);
|
||||
|
||||
Self::print_status_lines(&mut inner);
|
||||
|
||||
StepProgress {
|
||||
progress: self.clone(),
|
||||
step,
|
||||
example_name: example_name.to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn finish(&self, step: Step, example_name: &str) {
|
||||
fn finish(&self, step: Step, example_name: &str) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
|
||||
let task = inner.in_progress.remove(example_name);
|
||||
if let Some(task) = task {
|
||||
if task.step == step {
|
||||
inner.completed.push(CompletedTask {
|
||||
step: task.step,
|
||||
example_name: example_name.to_string(),
|
||||
duration: task.started_at.elapsed(),
|
||||
info: task.info,
|
||||
});
|
||||
let Some(task) = inner.in_progress.remove(example_name) else {
|
||||
return;
|
||||
};
|
||||
|
||||
Self::clear_status_lines(&mut inner);
|
||||
Self::print_completed(&inner, inner.completed.last().unwrap());
|
||||
Self::print_status_lines(&mut inner);
|
||||
} else {
|
||||
inner.in_progress.insert(example_name.to_string(), task);
|
||||
}
|
||||
if task.step == step {
|
||||
inner.completed.push(CompletedTask {
|
||||
step: task.step,
|
||||
example_name: example_name.to_string(),
|
||||
duration: task.started_at.elapsed(),
|
||||
info: task.info,
|
||||
});
|
||||
|
||||
Self::clear_status_lines(&mut inner);
|
||||
Self::print_logging_closing_divider(&mut inner);
|
||||
Self::print_completed(&inner, inner.completed.last().unwrap());
|
||||
Self::print_status_lines(&mut inner);
|
||||
} else {
|
||||
inner.in_progress.insert(example_name.to_string(), task);
|
||||
}
|
||||
}
|
||||
|
||||
fn print_logging_closing_divider(inner: &mut ProgressInner) {
|
||||
if inner.last_line_is_logging {
|
||||
let reset = "\x1b[0m";
|
||||
let dim = "\x1b[2m";
|
||||
let divider = "─".repeat(inner.terminal_width.saturating_sub(MARGIN));
|
||||
eprintln!("{dim}{divider}{reset}");
|
||||
inner.last_line_is_logging = false;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -182,7 +236,7 @@ impl Progress {
|
||||
let duration_with_margin = format!("{duration} ");
|
||||
let padding_needed = inner
|
||||
.terminal_width
|
||||
.saturating_sub(RIGHT_MARGIN)
|
||||
.saturating_sub(MARGIN)
|
||||
.saturating_sub(duration_with_margin.len())
|
||||
.saturating_sub(strip_ansi_len(&prefix));
|
||||
let padding = " ".repeat(padding_needed);
|
||||
@@ -216,27 +270,41 @@ impl Progress {
|
||||
// Build the done/in-progress/total label
|
||||
let done_count = inner.completed.len();
|
||||
let in_progress_count = inner.in_progress.len();
|
||||
let failed_count = inner.failed_examples;
|
||||
|
||||
let failed_label = if failed_count > 0 {
|
||||
format!(" {} failed ", failed_count)
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
let range_label = format!(
|
||||
" {}/{}/{} ",
|
||||
done_count, in_progress_count, inner.total_examples
|
||||
);
|
||||
|
||||
// Print a divider line with range label aligned with timestamps
|
||||
// Print a divider line with failed count on left, range label on right
|
||||
let failed_visible_len = strip_ansi_len(&failed_label);
|
||||
let range_visible_len = range_label.len();
|
||||
let left_divider_len = inner
|
||||
let middle_divider_len = inner
|
||||
.terminal_width
|
||||
.saturating_sub(RIGHT_MARGIN)
|
||||
.saturating_sub(MARGIN * 2)
|
||||
.saturating_sub(failed_visible_len)
|
||||
.saturating_sub(range_visible_len);
|
||||
let left_divider = "─".repeat(left_divider_len);
|
||||
let right_divider = "─".repeat(RIGHT_MARGIN);
|
||||
eprintln!("{dim}{left_divider}{reset}{range_label}{dim}{right_divider}{reset}");
|
||||
let left_divider = "─".repeat(MARGIN);
|
||||
let middle_divider = "─".repeat(middle_divider_len);
|
||||
let right_divider = "─".repeat(MARGIN);
|
||||
eprintln!(
|
||||
"{dim}{left_divider}{reset}{failed_label}{dim}{middle_divider}{reset}{range_label}{dim}{right_divider}{reset}"
|
||||
);
|
||||
|
||||
let mut tasks: Vec<_> = inner.in_progress.iter().collect();
|
||||
tasks.sort_by_key(|(name, _)| *name);
|
||||
|
||||
let total_tasks = tasks.len();
|
||||
let mut lines_printed = 0;
|
||||
|
||||
for (name, task) in tasks.iter() {
|
||||
for (name, task) in tasks.iter().take(MAX_STATUS_LINES) {
|
||||
let elapsed = format_duration(task.started_at.elapsed());
|
||||
let substatus_part = task
|
||||
.substatus
|
||||
@@ -256,7 +324,7 @@ impl Progress {
|
||||
let duration_with_margin = format!("{elapsed} ");
|
||||
let padding_needed = inner
|
||||
.terminal_width
|
||||
.saturating_sub(RIGHT_MARGIN)
|
||||
.saturating_sub(MARGIN)
|
||||
.saturating_sub(duration_with_margin.len())
|
||||
.saturating_sub(strip_ansi_len(&prefix));
|
||||
let padding = " ".repeat(padding_needed);
|
||||
@@ -265,13 +333,34 @@ impl Progress {
|
||||
lines_printed += 1;
|
||||
}
|
||||
|
||||
// Show "+N more" on its own line if there are more tasks
|
||||
if total_tasks > MAX_STATUS_LINES {
|
||||
let remaining = total_tasks - MAX_STATUS_LINES;
|
||||
eprintln!("{:>12} +{remaining} more", "");
|
||||
lines_printed += 1;
|
||||
}
|
||||
|
||||
inner.status_lines_displayed = lines_printed + 1; // +1 for the divider line
|
||||
let _ = std::io::stderr().flush();
|
||||
}
|
||||
|
||||
pub fn clear(&self) {
|
||||
pub fn finalize(&self) {
|
||||
let mut inner = self.inner.lock().unwrap();
|
||||
Self::clear_status_lines(&mut inner);
|
||||
|
||||
// Print summary if there were failures
|
||||
if inner.failed_examples > 0 {
|
||||
let total_processed = inner.completed.len() + inner.failed_examples;
|
||||
let percentage = if total_processed > 0 {
|
||||
inner.failed_examples as f64 / total_processed as f64 * 100.0
|
||||
} else {
|
||||
0.0
|
||||
};
|
||||
eprintln!(
|
||||
"\n{} of {} examples failed ({:.1}%)",
|
||||
inner.failed_examples, total_processed, percentage
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -314,6 +403,53 @@ impl Drop for StepProgress {
|
||||
}
|
||||
}
|
||||
|
||||
struct ProgressLogger;
|
||||
|
||||
impl Log for ProgressLogger {
|
||||
fn enabled(&self, metadata: &Metadata) -> bool {
|
||||
metadata.level() <= Level::Info
|
||||
}
|
||||
|
||||
fn log(&self, record: &Record) {
|
||||
if !self.enabled(record.metadata()) {
|
||||
return;
|
||||
}
|
||||
|
||||
let level_color = match record.level() {
|
||||
Level::Error => "\x1b[31m",
|
||||
Level::Warn => "\x1b[33m",
|
||||
Level::Info => "\x1b[32m",
|
||||
Level::Debug => "\x1b[34m",
|
||||
Level::Trace => "\x1b[35m",
|
||||
};
|
||||
let reset = "\x1b[0m";
|
||||
let bold = "\x1b[1m";
|
||||
|
||||
let level_label = match record.level() {
|
||||
Level::Error => "Error",
|
||||
Level::Warn => "Warn",
|
||||
Level::Info => "Info",
|
||||
Level::Debug => "Debug",
|
||||
Level::Trace => "Trace",
|
||||
};
|
||||
|
||||
let message = format!(
|
||||
"{bold}{level_color}{level_label:>12}{reset} {}",
|
||||
record.args()
|
||||
);
|
||||
|
||||
if let Some(progress) = GLOBAL.get() {
|
||||
progress.log(&message);
|
||||
} else {
|
||||
eprintln!("{}", message);
|
||||
}
|
||||
}
|
||||
|
||||
fn flush(&self) {
|
||||
let _ = std::io::stderr().flush();
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
fn get_terminal_width() -> usize {
|
||||
unsafe {
|
||||
|
||||
@@ -4,6 +4,7 @@ use crate::{
|
||||
load_project::run_load_project,
|
||||
progress::{InfoStyle, Progress, Step, StepProgress},
|
||||
};
|
||||
use anyhow::Context as _;
|
||||
use collections::HashSet;
|
||||
use edit_prediction::{DebugEvent, EditPredictionStore};
|
||||
use futures::{FutureExt as _, StreamExt as _, channel::mpsc};
|
||||
@@ -16,39 +17,36 @@ use std::time::Duration;
|
||||
pub async fn run_context_retrieval(
|
||||
example: &mut Example,
|
||||
app_state: Arc<EpAppState>,
|
||||
progress: Arc<Progress>,
|
||||
mut cx: AsyncApp,
|
||||
) {
|
||||
) -> anyhow::Result<()> {
|
||||
if example.context.is_some() {
|
||||
return;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
run_load_project(example, app_state.clone(), progress.clone(), cx.clone()).await;
|
||||
run_load_project(example, app_state.clone(), cx.clone()).await?;
|
||||
|
||||
let step_progress = progress.start(Step::Context, &example.name);
|
||||
let step_progress: Arc<StepProgress> = Progress::global()
|
||||
.start(Step::Context, &example.name)
|
||||
.into();
|
||||
|
||||
let state = example.state.as_ref().unwrap();
|
||||
let project = state.project.clone();
|
||||
|
||||
let _lsp_handle = project
|
||||
.update(&mut cx, |project, cx| {
|
||||
project.register_buffer_with_language_servers(&state.buffer, cx)
|
||||
})
|
||||
.unwrap();
|
||||
wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await;
|
||||
let _lsp_handle = project.update(&mut cx, |project, cx| {
|
||||
project.register_buffer_with_language_servers(&state.buffer, cx)
|
||||
})?;
|
||||
wait_for_language_servers_to_start(&project, &state.buffer, &step_progress, &mut cx).await?;
|
||||
|
||||
let ep_store = cx
|
||||
.update(|cx| EditPredictionStore::try_global(cx).unwrap())
|
||||
.unwrap();
|
||||
let ep_store = cx.update(|cx| {
|
||||
EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
|
||||
})??;
|
||||
|
||||
let mut events = ep_store
|
||||
.update(&mut cx, |store, cx| {
|
||||
store.register_buffer(&state.buffer, &project, cx);
|
||||
store.set_use_context(true);
|
||||
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
|
||||
store.debug_info(&project, cx)
|
||||
})
|
||||
.unwrap();
|
||||
let mut events = ep_store.update(&mut cx, |store, cx| {
|
||||
store.register_buffer(&state.buffer, &project, cx);
|
||||
store.set_use_context(true);
|
||||
store.refresh_context(&project, &state.buffer, state.cursor_position, cx);
|
||||
store.debug_info(&project, cx)
|
||||
})?;
|
||||
|
||||
while let Some(event) = events.next().await {
|
||||
match event {
|
||||
@@ -59,9 +57,8 @@ pub async fn run_context_retrieval(
|
||||
}
|
||||
}
|
||||
|
||||
let context_files = ep_store
|
||||
.update(&mut cx, |store, cx| store.context_for_project(&project, cx))
|
||||
.unwrap();
|
||||
let context_files =
|
||||
ep_store.update(&mut cx, |store, cx| store.context_for_project(&project, cx))?;
|
||||
|
||||
let excerpt_count: usize = context_files.iter().map(|f| f.excerpts.len()).sum();
|
||||
step_progress.set_info(format!("{} excerpts", excerpt_count), InfoStyle::Normal);
|
||||
@@ -69,6 +66,7 @@ pub async fn run_context_retrieval(
|
||||
example.context = Some(ExampleContext {
|
||||
files: context_files,
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn wait_for_language_servers_to_start(
|
||||
@@ -76,10 +74,8 @@ async fn wait_for_language_servers_to_start(
|
||||
buffer: &Entity<Buffer>,
|
||||
step_progress: &Arc<StepProgress>,
|
||||
cx: &mut AsyncApp,
|
||||
) {
|
||||
let lsp_store = project
|
||||
.read_with(cx, |project, _| project.lsp_store())
|
||||
.unwrap();
|
||||
) -> anyhow::Result<()> {
|
||||
let lsp_store = project.read_with(cx, |project, _| project.lsp_store())?;
|
||||
|
||||
let (language_server_ids, mut starting_language_server_ids) = buffer
|
||||
.update(cx, |buffer, cx| {
|
||||
@@ -122,7 +118,7 @@ async fn wait_for_language_servers_to_start(
|
||||
}
|
||||
},
|
||||
_ = timeout.clone().fuse() => {
|
||||
panic!("LSP wait timed out after 5 minutes");
|
||||
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -131,8 +127,7 @@ async fn wait_for_language_servers_to_start(
|
||||
|
||||
if !language_server_ids.is_empty() {
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||
.unwrap()
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
|
||||
.detach();
|
||||
}
|
||||
|
||||
@@ -174,10 +169,8 @@ async fn wait_for_language_servers_to_start(
|
||||
];
|
||||
|
||||
project
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))
|
||||
.unwrap()
|
||||
.await
|
||||
.unwrap();
|
||||
.update(cx, |project, cx| project.save_buffer(buffer.clone(), cx))?
|
||||
.await?;
|
||||
|
||||
let mut pending_language_server_ids = HashSet::from_iter(language_server_ids.into_iter());
|
||||
while !pending_language_server_ids.is_empty() {
|
||||
@@ -188,11 +181,12 @@ async fn wait_for_language_servers_to_start(
|
||||
}
|
||||
},
|
||||
_ = timeout.clone().fuse() => {
|
||||
panic!("LSP wait timed out after 5 minutes");
|
||||
return Err(anyhow::anyhow!("LSP wait timed out after 5 minutes"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
drop(subscriptions);
|
||||
step_progress.clear_substatus();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -14,20 +14,18 @@ pub async fn run_scoring(
|
||||
example: &mut Example,
|
||||
args: &PredictArgs,
|
||||
app_state: Arc<EpAppState>,
|
||||
progress: Arc<Progress>,
|
||||
cx: AsyncApp,
|
||||
) {
|
||||
) -> anyhow::Result<()> {
|
||||
run_prediction(
|
||||
example,
|
||||
Some(args.provider),
|
||||
args.repetitions,
|
||||
app_state,
|
||||
progress.clone(),
|
||||
cx,
|
||||
)
|
||||
.await;
|
||||
.await?;
|
||||
|
||||
let _progress = progress.start(Step::Score, &example.name);
|
||||
let _progress = Progress::global().start(Step::Score, &example.name);
|
||||
|
||||
let expected_patch = parse_patch(&example.expected_patch);
|
||||
|
||||
@@ -45,6 +43,7 @@ pub async fn run_scoring(
|
||||
}
|
||||
|
||||
example.score = scores;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn parse_patch(patch: &str) -> Vec<DiffLine<'_>> {
|
||||
|
||||
@@ -21,7 +21,6 @@ default = ["font-kit", "wayland", "x11", "windows-manifest"]
|
||||
test-support = [
|
||||
"leak-detection",
|
||||
"collections/test-support",
|
||||
"rand",
|
||||
"util/test-support",
|
||||
"http_client/test-support",
|
||||
"wayland",
|
||||
@@ -109,7 +108,7 @@ parking = "2.0.0"
|
||||
parking_lot.workspace = true
|
||||
postage.workspace = true
|
||||
profiling.workspace = true
|
||||
rand = { optional = true, workspace = true }
|
||||
rand.workspace = true
|
||||
raw-window-handle = "0.6"
|
||||
refineable.workspace = true
|
||||
resvg = { version = "0.45.0", default-features = false, features = [
|
||||
@@ -158,8 +157,10 @@ media.workspace = true
|
||||
objc.workspace = true
|
||||
objc2 = { version = "0.6", optional = true }
|
||||
objc2-metal = { version = "0.3", optional = true }
|
||||
mach2.workspace = true
|
||||
#TODO: replace with "objc2"
|
||||
metal.workspace = true
|
||||
flume = "0.11"
|
||||
|
||||
[target.'cfg(any(target_os = "linux", target_os = "freebsd", target_os = "macos"))'.dependencies]
|
||||
pathfinder_geometry = "0.5"
|
||||
|
||||
@@ -84,6 +84,8 @@ mod macos {
|
||||
.allowlist_var("_dispatch_main_q")
|
||||
.allowlist_var("_dispatch_source_type_data_add")
|
||||
.allowlist_var("DISPATCH_QUEUE_PRIORITY_HIGH")
|
||||
.allowlist_var("DISPATCH_QUEUE_PRIORITY_DEFAULT")
|
||||
.allowlist_var("DISPATCH_QUEUE_PRIORITY_LOW")
|
||||
.allowlist_var("DISPATCH_TIME_NOW")
|
||||
.allowlist_function("dispatch_get_global_queue")
|
||||
.allowlist_function("dispatch_async_f")
|
||||
|
||||
@@ -38,10 +38,11 @@ use crate::{
|
||||
AssetSource, BackgroundExecutor, Bounds, ClipboardItem, CursorStyle, DispatchPhase, DisplayId,
|
||||
EventEmitter, FocusHandle, FocusMap, ForegroundExecutor, Global, KeyBinding, KeyContext,
|
||||
Keymap, Keystroke, LayoutId, Menu, MenuItem, OwnedMenu, PathPromptOptions, Pixels, Platform,
|
||||
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, PromptBuilder,
|
||||
PromptButton, PromptHandle, PromptLevel, Render, RenderImage, RenderablePromptHandle,
|
||||
Reservation, ScreenCaptureSource, SharedString, SubscriberSet, Subscription, SvgRenderer, Task,
|
||||
TextSystem, Window, WindowAppearance, WindowHandle, WindowId, WindowInvalidator,
|
||||
PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper, Point, Priority,
|
||||
PromptBuilder, PromptButton, PromptHandle, PromptLevel, Render, RenderImage,
|
||||
RenderablePromptHandle, Reservation, ScreenCaptureSource, SharedString, SubscriberSet,
|
||||
Subscription, SvgRenderer, Task, TextSystem, Window, WindowAppearance, WindowHandle, WindowId,
|
||||
WindowInvalidator,
|
||||
colors::{Colors, GlobalColors},
|
||||
current_platform, hash, init_app_menus,
|
||||
};
|
||||
@@ -1494,6 +1495,24 @@ impl App {
|
||||
.spawn(async move { f(&mut cx).await })
|
||||
}
|
||||
|
||||
/// Spawns the future returned by the given function on the main thread with
|
||||
/// the given priority. The closure will be invoked with [AsyncApp], which
|
||||
/// allows the application state to be accessed across await points.
|
||||
pub fn spawn_with_priority<AsyncFn, R>(&self, priority: Priority, f: AsyncFn) -> Task<R>
|
||||
where
|
||||
AsyncFn: AsyncFnOnce(&mut AsyncApp) -> R + 'static,
|
||||
R: 'static,
|
||||
{
|
||||
if self.quitting {
|
||||
debug_panic!("Can't spawn on main thread after on_app_quit")
|
||||
};
|
||||
|
||||
let mut cx = self.to_async();
|
||||
|
||||
self.foreground_executor
|
||||
.spawn_with_priority(priority, async move { f(&mut cx).await })
|
||||
}
|
||||
|
||||
/// Schedules the given function to be run at the end of the current effect cycle, allowing entities
|
||||
/// that are currently on the stack to be returned to the app.
|
||||
pub fn defer(&mut self, f: impl FnOnce(&mut App) + 'static) {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use crate::{
|
||||
AnyView, AnyWindowHandle, AppContext, AsyncApp, DispatchPhase, Effect, EntityId, EventEmitter,
|
||||
FocusHandle, FocusOutEvent, Focusable, Global, KeystrokeObserver, Reservation, SubscriberSet,
|
||||
Subscription, Task, WeakEntity, WeakFocusHandle, Window, WindowHandle,
|
||||
FocusHandle, FocusOutEvent, Focusable, Global, KeystrokeObserver, Priority, Reservation,
|
||||
SubscriberSet, Subscription, Task, WeakEntity, WeakFocusHandle, Window, WindowHandle,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use futures::FutureExt;
|
||||
@@ -667,6 +667,25 @@ impl<'a, T: 'static> Context<'a, T> {
|
||||
window.spawn(self, async move |cx| f(view, cx).await)
|
||||
}
|
||||
|
||||
/// Schedule a future to be run asynchronously with the given priority.
|
||||
/// The given callback is invoked with a [`WeakEntity<V>`] to avoid leaking the entity for a long-running process.
|
||||
/// It's also given an [`AsyncWindowContext`], which can be used to access the state of the entity across await points.
|
||||
/// The returned future will be polled on the main thread.
|
||||
#[track_caller]
|
||||
pub fn spawn_in_with_priority<AsyncFn, R>(
|
||||
&self,
|
||||
priority: Priority,
|
||||
window: &Window,
|
||||
f: AsyncFn,
|
||||
) -> Task<R>
|
||||
where
|
||||
R: 'static,
|
||||
AsyncFn: AsyncFnOnce(WeakEntity<T>, &mut AsyncWindowContext) -> R + 'static,
|
||||
{
|
||||
let view = self.weak_entity();
|
||||
window.spawn_with_priority(priority, self, async move |cx| f(view, cx).await)
|
||||
}
|
||||
|
||||
/// Register a callback to be invoked when the given global state changes.
|
||||
pub fn observe_global_in<G: Global>(
|
||||
&mut self,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant};
|
||||
use crate::{App, PlatformDispatcher, RunnableMeta, RunnableVariant, TaskTiming, profiler};
|
||||
use async_task::Runnable;
|
||||
use futures::channel::mpsc;
|
||||
use parking_lot::{Condvar, Mutex};
|
||||
@@ -47,6 +47,52 @@ pub struct ForegroundExecutor {
|
||||
not_send: PhantomData<Rc<()>>,
|
||||
}
|
||||
|
||||
/// Realtime task priority
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum RealtimePriority {
|
||||
/// Audio task
|
||||
Audio,
|
||||
/// Other realtime task
|
||||
#[default]
|
||||
Other,
|
||||
}
|
||||
|
||||
/// Task priority
|
||||
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
|
||||
#[repr(u8)]
|
||||
pub enum Priority {
|
||||
/// Realtime priority
|
||||
///
|
||||
/// Spawning a task with this priority will spin it off on a separate thread dedicated just to that task.
|
||||
Realtime(RealtimePriority),
|
||||
/// High priority
|
||||
///
|
||||
/// Only use for tasks that are critical to the user experience / responsiveness of the editor.
|
||||
High,
|
||||
/// Medium priority, probably suits most of your use cases.
|
||||
#[default]
|
||||
Medium,
|
||||
/// Low priority
|
||||
///
|
||||
/// Prioritize this for background work that can come in large quantities
|
||||
/// to not starve the executor of resources for high priority tasks
|
||||
Low,
|
||||
}
|
||||
|
||||
impl Priority {
|
||||
#[allow(dead_code)]
|
||||
pub(crate) const fn probability(&self) -> u32 {
|
||||
match self {
|
||||
// realtime priorities are not considered for probability scheduling
|
||||
Priority::Realtime(_) => 0,
|
||||
Priority::High => 60,
|
||||
Priority::Medium => 30,
|
||||
Priority::Low => 10,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Task is a primitive that allows work to happen in the background.
|
||||
///
|
||||
/// It implements [`Future`] so you can `.await` on it.
|
||||
@@ -152,7 +198,20 @@ impl BackgroundExecutor {
|
||||
where
|
||||
R: Send + 'static,
|
||||
{
|
||||
self.spawn_internal::<R>(Box::pin(future), None)
|
||||
self.spawn_with_priority(Priority::default(), future)
|
||||
}
|
||||
|
||||
/// Enqueues the given future to be run to completion on a background thread.
|
||||
#[track_caller]
|
||||
pub fn spawn_with_priority<R>(
|
||||
&self,
|
||||
priority: Priority,
|
||||
future: impl Future<Output = R> + Send + 'static,
|
||||
) -> Task<R>
|
||||
where
|
||||
R: Send + 'static,
|
||||
{
|
||||
self.spawn_internal::<R>(Box::pin(future), None, priority)
|
||||
}
|
||||
|
||||
/// Enqueues the given future to be run to completion on a background thread and blocking the current task on it.
|
||||
@@ -199,7 +258,13 @@ impl BackgroundExecutor {
|
||||
let _notify_guard = NotifyOnDrop(pair);
|
||||
future.await
|
||||
},
|
||||
move |runnable| dispatcher.dispatch(RunnableVariant::Meta(runnable), None),
|
||||
move |runnable| {
|
||||
dispatcher.dispatch(
|
||||
RunnableVariant::Meta(runnable),
|
||||
None,
|
||||
Priority::default(),
|
||||
)
|
||||
},
|
||||
)
|
||||
};
|
||||
runnable.schedule();
|
||||
@@ -217,7 +282,7 @@ impl BackgroundExecutor {
|
||||
where
|
||||
R: Send + 'static,
|
||||
{
|
||||
self.spawn_internal::<R>(Box::pin(future), Some(label))
|
||||
self.spawn_internal::<R>(Box::pin(future), Some(label), Priority::default())
|
||||
}
|
||||
|
||||
#[track_caller]
|
||||
@@ -225,15 +290,55 @@ impl BackgroundExecutor {
|
||||
&self,
|
||||
future: AnyFuture<R>,
|
||||
label: Option<TaskLabel>,
|
||||
priority: Priority,
|
||||
) -> Task<R> {
|
||||
let dispatcher = self.dispatcher.clone();
|
||||
let location = core::panic::Location::caller();
|
||||
let (runnable, task) = async_task::Builder::new()
|
||||
.metadata(RunnableMeta { location })
|
||||
.spawn(
|
||||
move |_| future,
|
||||
move |runnable| dispatcher.dispatch(RunnableVariant::Meta(runnable), label),
|
||||
let (runnable, task) = if let Priority::Realtime(realtime) = priority {
|
||||
let location = core::panic::Location::caller();
|
||||
let (mut tx, rx) = flume::bounded::<Runnable<RunnableMeta>>(1);
|
||||
|
||||
dispatcher.spawn_realtime(
|
||||
realtime,
|
||||
Box::new(move || {
|
||||
while let Ok(runnable) = rx.recv() {
|
||||
let start = Instant::now();
|
||||
let location = runnable.metadata().location;
|
||||
let mut timing = TaskTiming {
|
||||
location,
|
||||
start,
|
||||
end: None,
|
||||
};
|
||||
profiler::add_task_timing(timing);
|
||||
|
||||
runnable.run();
|
||||
|
||||
let end = Instant::now();
|
||||
timing.end = Some(end);
|
||||
profiler::add_task_timing(timing);
|
||||
}
|
||||
}),
|
||||
);
|
||||
|
||||
async_task::Builder::new()
|
||||
.metadata(RunnableMeta { location })
|
||||
.spawn(
|
||||
move |_| future,
|
||||
move |runnable| {
|
||||
let _ = tx.send(runnable);
|
||||
},
|
||||
)
|
||||
} else {
|
||||
let location = core::panic::Location::caller();
|
||||
async_task::Builder::new()
|
||||
.metadata(RunnableMeta { location })
|
||||
.spawn(
|
||||
move |_| future,
|
||||
move |runnable| {
|
||||
dispatcher.dispatch(RunnableVariant::Meta(runnable), label, priority)
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
runnable.schedule();
|
||||
Task(TaskState::Spawned(task))
|
||||
}
|
||||
@@ -406,11 +511,28 @@ impl BackgroundExecutor {
|
||||
where
|
||||
F: FnOnce(&mut Scope<'scope>),
|
||||
{
|
||||
let mut scope = Scope::new(self.clone());
|
||||
let mut scope = Scope::new(self.clone(), Priority::default());
|
||||
(scheduler)(&mut scope);
|
||||
let spawned = mem::take(&mut scope.futures)
|
||||
.into_iter()
|
||||
.map(|f| self.spawn(f))
|
||||
.map(|f| self.spawn_with_priority(scope.priority, f))
|
||||
.collect::<Vec<_>>();
|
||||
for task in spawned {
|
||||
task.await;
|
||||
}
|
||||
}
|
||||
|
||||
/// Scoped lets you start a number of tasks and waits
|
||||
/// for all of them to complete before returning.
|
||||
pub async fn scoped_priority<'scope, F>(&self, priority: Priority, scheduler: F)
|
||||
where
|
||||
F: FnOnce(&mut Scope<'scope>),
|
||||
{
|
||||
let mut scope = Scope::new(self.clone(), priority);
|
||||
(scheduler)(&mut scope);
|
||||
let spawned = mem::take(&mut scope.futures)
|
||||
.into_iter()
|
||||
.map(|f| self.spawn_with_priority(scope.priority, f))
|
||||
.collect::<Vec<_>>();
|
||||
for task in spawned {
|
||||
task.await;
|
||||
@@ -546,6 +668,19 @@ impl ForegroundExecutor {
|
||||
/// Enqueues the given Task to run on the main thread at some point in the future.
|
||||
#[track_caller]
|
||||
pub fn spawn<R>(&self, future: impl Future<Output = R> + 'static) -> Task<R>
|
||||
where
|
||||
R: 'static,
|
||||
{
|
||||
self.spawn_with_priority(Priority::default(), future)
|
||||
}
|
||||
|
||||
/// Enqueues the given Task to run on the main thread at some point in the future.
|
||||
#[track_caller]
|
||||
pub fn spawn_with_priority<R>(
|
||||
&self,
|
||||
priority: Priority,
|
||||
future: impl Future<Output = R> + 'static,
|
||||
) -> Task<R>
|
||||
where
|
||||
R: 'static,
|
||||
{
|
||||
@@ -557,16 +692,19 @@ impl ForegroundExecutor {
|
||||
dispatcher: Arc<dyn PlatformDispatcher>,
|
||||
future: AnyLocalFuture<R>,
|
||||
location: &'static core::panic::Location<'static>,
|
||||
priority: Priority,
|
||||
) -> Task<R> {
|
||||
let (runnable, task) = spawn_local_with_source_location(
|
||||
future,
|
||||
move |runnable| dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable)),
|
||||
move |runnable| {
|
||||
dispatcher.dispatch_on_main_thread(RunnableVariant::Meta(runnable), priority)
|
||||
},
|
||||
RunnableMeta { location },
|
||||
);
|
||||
runnable.schedule();
|
||||
Task(TaskState::Spawned(task))
|
||||
}
|
||||
inner::<R>(dispatcher, Box::pin(future), location)
|
||||
inner::<R>(dispatcher, Box::pin(future), location, priority)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -642,6 +780,7 @@ where
|
||||
/// Scope manages a set of tasks that are enqueued and waited on together. See [`BackgroundExecutor::scoped`].
|
||||
pub struct Scope<'a> {
|
||||
executor: BackgroundExecutor,
|
||||
priority: Priority,
|
||||
futures: Vec<Pin<Box<dyn Future<Output = ()> + Send + 'static>>>,
|
||||
tx: Option<mpsc::Sender<()>>,
|
||||
rx: mpsc::Receiver<()>,
|
||||
@@ -649,10 +788,11 @@ pub struct Scope<'a> {
|
||||
}
|
||||
|
||||
impl<'a> Scope<'a> {
|
||||
fn new(executor: BackgroundExecutor) -> Self {
|
||||
fn new(executor: BackgroundExecutor, priority: Priority) -> Self {
|
||||
let (tx, rx) = mpsc::channel(1);
|
||||
Self {
|
||||
executor,
|
||||
priority,
|
||||
tx: Some(tx),
|
||||
rx,
|
||||
futures: Default::default(),
|
||||
|
||||
@@ -1416,9 +1416,9 @@ where
|
||||
/// ```
|
||||
pub fn contains(&self, point: &Point<T>) -> bool {
|
||||
point.x >= self.origin.x
|
||||
&& point.x <= self.origin.x.clone() + self.size.width.clone()
|
||||
&& point.x < self.origin.x.clone() + self.size.width.clone()
|
||||
&& point.y >= self.origin.y
|
||||
&& point.y <= self.origin.y.clone() + self.size.height.clone()
|
||||
&& point.y < self.origin.y.clone() + self.size.height.clone()
|
||||
}
|
||||
|
||||
/// Checks if this bounds is completely contained within another bounds.
|
||||
|
||||
@@ -31,6 +31,8 @@ mod path_builder;
|
||||
mod platform;
|
||||
pub mod prelude;
|
||||
mod profiler;
|
||||
#[cfg(any(target_os = "windows", target_os = "linux"))]
|
||||
mod queue;
|
||||
mod scene;
|
||||
mod shared_string;
|
||||
mod shared_uri;
|
||||
@@ -89,16 +91,20 @@ pub use keymap::*;
|
||||
pub use path_builder::*;
|
||||
pub use platform::*;
|
||||
pub use profiler::*;
|
||||
#[cfg(any(target_os = "windows", target_os = "linux"))]
|
||||
pub(crate) use queue::{PriorityQueueReceiver, PriorityQueueSender};
|
||||
pub use refineable::*;
|
||||
pub use scene::*;
|
||||
pub use shared_string::*;
|
||||
pub use shared_uri::*;
|
||||
pub use smol::Timer;
|
||||
use std::{any::Any, future::Future};
|
||||
pub use style::*;
|
||||
pub use styled::*;
|
||||
pub use subscription::*;
|
||||
pub use svg_renderer::*;
|
||||
pub(crate) use tab_stop::*;
|
||||
use taffy::TaffyLayoutEngine;
|
||||
pub use taffy::{AvailableSpace, LayoutId};
|
||||
#[cfg(any(test, feature = "test-support"))]
|
||||
pub use test::*;
|
||||
@@ -109,9 +115,6 @@ pub use util::{FutureExt, Timeout, arc_cow::ArcCow};
|
||||
pub use view::*;
|
||||
pub use window::*;
|
||||
|
||||
use std::{any::Any, future::Future};
|
||||
use taffy::TaffyLayoutEngine;
|
||||
|
||||
/// The context trait, allows the different contexts in GPUI to be used
|
||||
/// interchangeably for certain operations.
|
||||
pub trait AppContext {
|
||||
|
||||
@@ -39,9 +39,10 @@ use crate::{
|
||||
Action, AnyWindowHandle, App, AsyncWindowContext, BackgroundExecutor, Bounds,
|
||||
DEFAULT_WINDOW_SIZE, DevicePixels, DispatchEventResult, Font, FontId, FontMetrics, FontRun,
|
||||
ForegroundExecutor, GlyphId, GpuSpecs, ImageSource, Keymap, LineLayout, Pixels, PlatformInput,
|
||||
Point, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, Scene, ShapedGlyph,
|
||||
ShapedRun, SharedString, Size, SvgRenderer, SystemWindowTab, Task, TaskLabel, TaskTiming,
|
||||
ThreadTaskTimings, Window, WindowControlArea, hash, point, px, size,
|
||||
Point, Priority, RealtimePriority, RenderGlyphParams, RenderImage, RenderImageParams,
|
||||
RenderSvgParams, Scene, ShapedGlyph, ShapedRun, SharedString, Size, SvgRenderer,
|
||||
SystemWindowTab, Task, TaskLabel, TaskTiming, ThreadTaskTimings, Window, WindowControlArea,
|
||||
hash, point, px, size,
|
||||
};
|
||||
use anyhow::Result;
|
||||
use async_task::Runnable;
|
||||
@@ -587,9 +588,10 @@ pub trait PlatformDispatcher: Send + Sync {
|
||||
fn get_all_timings(&self) -> Vec<ThreadTaskTimings>;
|
||||
fn get_current_thread_timings(&self) -> Vec<TaskTiming>;
|
||||
fn is_main_thread(&self) -> bool;
|
||||
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>);
|
||||
fn dispatch_on_main_thread(&self, runnable: RunnableVariant);
|
||||
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority);
|
||||
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority);
|
||||
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant);
|
||||
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>);
|
||||
|
||||
fn now(&self) -> Instant {
|
||||
Instant::now()
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
use crate::{
|
||||
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, RunnableVariant, THREAD_TIMINGS, TaskLabel,
|
||||
TaskTiming, ThreadTaskTimings,
|
||||
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, Priority, PriorityQueueReceiver,
|
||||
PriorityQueueSender, RealtimePriority, RunnableVariant, THREAD_TIMINGS, TaskLabel, TaskTiming,
|
||||
ThreadTaskTimings, profiler,
|
||||
};
|
||||
use calloop::{
|
||||
EventLoop,
|
||||
EventLoop, PostAction,
|
||||
channel::{self, Sender},
|
||||
timer::TimeoutAction,
|
||||
};
|
||||
@@ -19,9 +20,9 @@ struct TimerAfter {
|
||||
}
|
||||
|
||||
pub(crate) struct LinuxDispatcher {
|
||||
main_sender: Sender<RunnableVariant>,
|
||||
main_sender: PriorityQueueCalloopSender<RunnableVariant>,
|
||||
timer_sender: Sender<TimerAfter>,
|
||||
background_sender: flume::Sender<RunnableVariant>,
|
||||
background_sender: PriorityQueueSender<RunnableVariant>,
|
||||
_background_threads: Vec<thread::JoinHandle<()>>,
|
||||
main_thread_id: thread::ThreadId,
|
||||
}
|
||||
@@ -29,18 +30,20 @@ pub(crate) struct LinuxDispatcher {
|
||||
const MIN_THREADS: usize = 2;
|
||||
|
||||
impl LinuxDispatcher {
|
||||
pub fn new(main_sender: Sender<RunnableVariant>) -> Self {
|
||||
let (background_sender, background_receiver) = flume::unbounded::<RunnableVariant>();
|
||||
pub fn new(main_sender: PriorityQueueCalloopSender<RunnableVariant>) -> Self {
|
||||
let (background_sender, background_receiver) = PriorityQueueReceiver::new();
|
||||
let thread_count =
|
||||
std::thread::available_parallelism().map_or(MIN_THREADS, |i| i.get().max(MIN_THREADS));
|
||||
|
||||
// These thread should really be lower prio then the foreground
|
||||
// executor
|
||||
let mut background_threads = (0..thread_count)
|
||||
.map(|i| {
|
||||
let receiver = background_receiver.clone();
|
||||
let mut receiver = background_receiver.clone();
|
||||
std::thread::Builder::new()
|
||||
.name(format!("Worker-{i}"))
|
||||
.spawn(move || {
|
||||
for runnable in receiver {
|
||||
for runnable in receiver.iter() {
|
||||
let start = Instant::now();
|
||||
|
||||
let mut location = match runnable {
|
||||
@@ -51,7 +54,7 @@ impl LinuxDispatcher {
|
||||
start,
|
||||
end: None,
|
||||
};
|
||||
Self::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
|
||||
runnable.run();
|
||||
timing
|
||||
@@ -63,7 +66,7 @@ impl LinuxDispatcher {
|
||||
start,
|
||||
end: None,
|
||||
};
|
||||
Self::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
|
||||
runnable.run();
|
||||
timing
|
||||
@@ -72,7 +75,7 @@ impl LinuxDispatcher {
|
||||
|
||||
let end = Instant::now();
|
||||
location.end = Some(end);
|
||||
Self::add_task_timing(location);
|
||||
profiler::add_task_timing(location);
|
||||
|
||||
log::trace!(
|
||||
"background thread {}: ran runnable. took: {:?}",
|
||||
@@ -113,7 +116,7 @@ impl LinuxDispatcher {
|
||||
start,
|
||||
end: None,
|
||||
};
|
||||
Self::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
|
||||
runnable.run();
|
||||
timing
|
||||
@@ -124,7 +127,7 @@ impl LinuxDispatcher {
|
||||
start,
|
||||
end: None,
|
||||
};
|
||||
Self::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
|
||||
runnable.run();
|
||||
timing
|
||||
@@ -133,7 +136,7 @@ impl LinuxDispatcher {
|
||||
let end = Instant::now();
|
||||
|
||||
timing.end = Some(end);
|
||||
Self::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
}
|
||||
TimeoutAction::Drop
|
||||
},
|
||||
@@ -157,22 +160,6 @@ impl LinuxDispatcher {
|
||||
main_thread_id: thread::current().id(),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn add_task_timing(timing: TaskTiming) {
|
||||
THREAD_TIMINGS.with(|timings| {
|
||||
let mut timings = timings.lock();
|
||||
let timings = &mut timings.timings;
|
||||
|
||||
if let Some(last_timing) = timings.iter_mut().rev().next() {
|
||||
if last_timing.location == timing.location {
|
||||
last_timing.end = timing.end;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
timings.push_back(timing);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
impl PlatformDispatcher for LinuxDispatcher {
|
||||
@@ -199,22 +186,26 @@ impl PlatformDispatcher for LinuxDispatcher {
|
||||
thread::current().id() == self.main_thread_id
|
||||
}
|
||||
|
||||
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>) {
|
||||
self.background_sender.send(runnable).unwrap();
|
||||
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>, priority: Priority) {
|
||||
self.background_sender
|
||||
.send(priority, runnable)
|
||||
.unwrap_or_else(|_| panic!("blocking sender returned without value"));
|
||||
}
|
||||
|
||||
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
|
||||
self.main_sender.send(runnable).unwrap_or_else(|runnable| {
|
||||
// NOTE: Runnable may wrap a Future that is !Send.
|
||||
//
|
||||
// This is usually safe because we only poll it on the main thread.
|
||||
// However if the send fails, we know that:
|
||||
// 1. main_receiver has been dropped (which implies the app is shutting down)
|
||||
// 2. we are on a background thread.
|
||||
// It is not safe to drop something !Send on the wrong thread, and
|
||||
// the app will exit soon anyway, so we must forget the runnable.
|
||||
std::mem::forget(runnable);
|
||||
});
|
||||
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
|
||||
self.main_sender
|
||||
.send(priority, runnable)
|
||||
.unwrap_or_else(|runnable| {
|
||||
// NOTE: Runnable may wrap a Future that is !Send.
|
||||
//
|
||||
// This is usually safe because we only poll it on the main thread.
|
||||
// However if the send fails, we know that:
|
||||
// 1. main_receiver has been dropped (which implies the app is shutting down)
|
||||
// 2. we are on a background thread.
|
||||
// It is not safe to drop something !Send on the wrong thread, and
|
||||
// the app will exit soon anyway, so we must forget the runnable.
|
||||
std::mem::forget(runnable);
|
||||
});
|
||||
}
|
||||
|
||||
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
|
||||
@@ -222,4 +213,252 @@ impl PlatformDispatcher for LinuxDispatcher {
|
||||
.send(TimerAfter { duration, runnable })
|
||||
.ok();
|
||||
}
|
||||
|
||||
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
|
||||
std::thread::spawn(move || {
|
||||
// SAFETY: always safe to call
|
||||
let thread_id = unsafe { libc::pthread_self() };
|
||||
|
||||
let policy = match priority {
|
||||
RealtimePriority::Audio => libc::SCHED_FIFO,
|
||||
RealtimePriority::Other => libc::SCHED_RR,
|
||||
};
|
||||
let sched_priority = match priority {
|
||||
RealtimePriority::Audio => 65,
|
||||
RealtimePriority::Other => 45,
|
||||
};
|
||||
|
||||
let sched_param = libc::sched_param { sched_priority };
|
||||
// SAFETY: sched_param is a valid initialized structure
|
||||
let result = unsafe { libc::pthread_setschedparam(thread_id, policy, &sched_param) };
|
||||
if result != 0 {
|
||||
log::warn!("failed to set realtime thread priority to {:?}", priority);
|
||||
}
|
||||
|
||||
f();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PriorityQueueCalloopSender<T> {
|
||||
sender: PriorityQueueSender<T>,
|
||||
ping: calloop::ping::Ping,
|
||||
}
|
||||
|
||||
impl<T> PriorityQueueCalloopSender<T> {
|
||||
fn new(tx: PriorityQueueSender<T>, ping: calloop::ping::Ping) -> Self {
|
||||
Self { sender: tx, ping }
|
||||
}
|
||||
|
||||
fn send(&self, priority: Priority, item: T) -> Result<(), crate::queue::SendError<T>> {
|
||||
let res = self.sender.send(priority, item);
|
||||
if res.is_ok() {
|
||||
self.ping.ping();
|
||||
}
|
||||
res
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for PriorityQueueCalloopSender<T> {
|
||||
fn drop(&mut self) {
|
||||
self.ping.ping();
|
||||
}
|
||||
}
|
||||
|
||||
pub struct PriorityQueueCalloopReceiver<T> {
|
||||
receiver: PriorityQueueReceiver<T>,
|
||||
source: calloop::ping::PingSource,
|
||||
ping: calloop::ping::Ping,
|
||||
}
|
||||
|
||||
impl<T> PriorityQueueCalloopReceiver<T> {
|
||||
pub fn new() -> (PriorityQueueCalloopSender<T>, Self) {
|
||||
let (ping, source) = calloop::ping::make_ping().expect("Failed to create a Ping.");
|
||||
|
||||
let (tx, rx) = PriorityQueueReceiver::new();
|
||||
|
||||
(
|
||||
PriorityQueueCalloopSender::new(tx, ping.clone()),
|
||||
Self {
|
||||
receiver: rx,
|
||||
source,
|
||||
ping,
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
use calloop::channel::Event;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ChannelError(calloop::ping::PingError);
|
||||
|
||||
impl std::fmt::Display for ChannelError {
|
||||
#[cfg_attr(feature = "nightly_coverage", coverage(off))]
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
std::fmt::Display::fmt(&self.0, f)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for ChannelError {
|
||||
#[cfg_attr(feature = "nightly_coverage", coverage(off))]
|
||||
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
|
||||
Some(&self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> calloop::EventSource for PriorityQueueCalloopReceiver<T> {
|
||||
type Event = Event<T>;
|
||||
type Metadata = ();
|
||||
type Ret = ();
|
||||
type Error = ChannelError;
|
||||
|
||||
fn process_events<F>(
|
||||
&mut self,
|
||||
readiness: calloop::Readiness,
|
||||
token: calloop::Token,
|
||||
mut callback: F,
|
||||
) -> Result<calloop::PostAction, Self::Error>
|
||||
where
|
||||
F: FnMut(Self::Event, &mut Self::Metadata) -> Self::Ret,
|
||||
{
|
||||
let mut clear_readiness = false;
|
||||
let mut disconnected = false;
|
||||
|
||||
let action = self
|
||||
.source
|
||||
.process_events(readiness, token, |(), &mut ()| {
|
||||
let mut is_empty = true;
|
||||
|
||||
let mut receiver = self.receiver.clone();
|
||||
for runnable in receiver.try_iter() {
|
||||
match runnable {
|
||||
Ok(r) => {
|
||||
callback(Event::Msg(r), &mut ());
|
||||
is_empty = false;
|
||||
}
|
||||
Err(_) => {
|
||||
disconnected = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if disconnected {
|
||||
callback(Event::Closed, &mut ());
|
||||
}
|
||||
|
||||
if is_empty {
|
||||
clear_readiness = true;
|
||||
}
|
||||
})
|
||||
.map_err(ChannelError)?;
|
||||
|
||||
if disconnected {
|
||||
Ok(PostAction::Remove)
|
||||
} else if clear_readiness {
|
||||
Ok(action)
|
||||
} else {
|
||||
// Re-notify the ping source so we can try again.
|
||||
self.ping.ping();
|
||||
Ok(PostAction::Continue)
|
||||
}
|
||||
}
|
||||
|
||||
fn register(
|
||||
&mut self,
|
||||
poll: &mut calloop::Poll,
|
||||
token_factory: &mut calloop::TokenFactory,
|
||||
) -> calloop::Result<()> {
|
||||
self.source.register(poll, token_factory)
|
||||
}
|
||||
|
||||
fn reregister(
|
||||
&mut self,
|
||||
poll: &mut calloop::Poll,
|
||||
token_factory: &mut calloop::TokenFactory,
|
||||
) -> calloop::Result<()> {
|
||||
self.source.reregister(poll, token_factory)
|
||||
}
|
||||
|
||||
fn unregister(&mut self, poll: &mut calloop::Poll) -> calloop::Result<()> {
|
||||
self.source.unregister(poll)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn calloop_works() {
|
||||
let mut event_loop = calloop::EventLoop::try_new().unwrap();
|
||||
let handle = event_loop.handle();
|
||||
|
||||
let (tx, rx) = PriorityQueueCalloopReceiver::new();
|
||||
|
||||
struct Data {
|
||||
got_msg: bool,
|
||||
got_closed: bool,
|
||||
}
|
||||
|
||||
let mut data = Data {
|
||||
got_msg: false,
|
||||
got_closed: false,
|
||||
};
|
||||
|
||||
let _channel_token = handle
|
||||
.insert_source(rx, move |evt, &mut (), data: &mut Data| match evt {
|
||||
Event::Msg(()) => {
|
||||
data.got_msg = true;
|
||||
}
|
||||
|
||||
Event::Closed => {
|
||||
data.got_closed = true;
|
||||
}
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
// nothing is sent, nothing is received
|
||||
event_loop
|
||||
.dispatch(Some(::std::time::Duration::ZERO), &mut data)
|
||||
.unwrap();
|
||||
|
||||
assert!(!data.got_msg);
|
||||
assert!(!data.got_closed);
|
||||
// a message is send
|
||||
|
||||
tx.send(Priority::Medium, ()).unwrap();
|
||||
event_loop
|
||||
.dispatch(Some(::std::time::Duration::ZERO), &mut data)
|
||||
.unwrap();
|
||||
|
||||
assert!(data.got_msg);
|
||||
assert!(!data.got_closed);
|
||||
|
||||
// the sender is dropped
|
||||
drop(tx);
|
||||
event_loop
|
||||
.dispatch(Some(::std::time::Duration::ZERO), &mut data)
|
||||
.unwrap();
|
||||
|
||||
assert!(data.got_msg);
|
||||
assert!(data.got_closed);
|
||||
}
|
||||
}
|
||||
|
||||
// running 1 test
|
||||
// test platform::linux::dispatcher::tests::tomato ... FAILED
|
||||
|
||||
// failures:
|
||||
|
||||
// ---- platform::linux::dispatcher::tests::tomato stdout ----
|
||||
// [crates/gpui/src/platform/linux/dispatcher.rs:262:9]
|
||||
// returning 1 tasks to process
|
||||
// [crates/gpui/src/platform/linux/dispatcher.rs:480:75] evt = Msg(
|
||||
// (),
|
||||
// )
|
||||
// returning 0 tasks to process
|
||||
|
||||
// thread 'platform::linux::dispatcher::tests::tomato' (478301) panicked at crates/gpui/src/platform/linux/dispatcher.rs:515:9:
|
||||
// assertion failed: data.got_closed
|
||||
// note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace
|
||||
|
||||
@@ -14,7 +14,7 @@ use std::{
|
||||
};
|
||||
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use calloop::{LoopSignal, channel::Channel};
|
||||
use calloop::LoopSignal;
|
||||
use futures::channel::oneshot;
|
||||
use util::ResultExt as _;
|
||||
use util::command::{new_smol_command, new_std_command};
|
||||
@@ -25,8 +25,8 @@ use crate::{
|
||||
Action, AnyWindowHandle, BackgroundExecutor, ClipboardItem, CursorStyle, DisplayId,
|
||||
ForegroundExecutor, Keymap, LinuxDispatcher, Menu, MenuItem, OwnedMenu, PathPromptOptions,
|
||||
Pixels, Platform, PlatformDisplay, PlatformKeyboardLayout, PlatformKeyboardMapper,
|
||||
PlatformTextSystem, PlatformWindow, Point, Result, RunnableVariant, Task, WindowAppearance,
|
||||
WindowParams, px,
|
||||
PlatformTextSystem, PlatformWindow, Point, PriorityQueueCalloopReceiver, Result,
|
||||
RunnableVariant, Task, WindowAppearance, WindowParams, px,
|
||||
};
|
||||
|
||||
#[cfg(any(feature = "wayland", feature = "x11"))]
|
||||
@@ -149,8 +149,8 @@ pub(crate) struct LinuxCommon {
|
||||
}
|
||||
|
||||
impl LinuxCommon {
|
||||
pub fn new(signal: LoopSignal) -> (Self, Channel<RunnableVariant>) {
|
||||
let (main_sender, main_receiver) = calloop::channel::channel::<RunnableVariant>();
|
||||
pub fn new(signal: LoopSignal) -> (Self, PriorityQueueCalloopReceiver<RunnableVariant>) {
|
||||
let (main_sender, main_receiver) = PriorityQueueCalloopReceiver::new();
|
||||
|
||||
#[cfg(any(feature = "wayland", feature = "x11"))]
|
||||
let text_system = Arc::new(crate::CosmicTextSystem::new());
|
||||
|
||||
@@ -77,10 +77,10 @@ use crate::{
|
||||
LinuxKeyboardLayout, Modifiers, ModifiersChangedEvent, MouseButton, MouseDownEvent,
|
||||
MouseExitEvent, MouseMoveEvent, MouseUpEvent, NavigationDirection, Pixels, PlatformDisplay,
|
||||
PlatformInput, PlatformKeyboardLayout, Point, ResultExt as _, SCROLL_LINES, ScrollDelta,
|
||||
ScrollWheelEvent, Size, TouchPhase, WindowParams, point, px, size,
|
||||
ScrollWheelEvent, Size, TouchPhase, WindowParams, point, profiler, px, size,
|
||||
};
|
||||
use crate::{
|
||||
LinuxDispatcher, RunnableVariant, TaskTiming,
|
||||
RunnableVariant, TaskTiming,
|
||||
platform::{PlatformWindow, blade::BladeContext},
|
||||
};
|
||||
use crate::{
|
||||
@@ -503,7 +503,7 @@ impl WaylandClient {
|
||||
start,
|
||||
end: None,
|
||||
};
|
||||
LinuxDispatcher::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
|
||||
runnable.run();
|
||||
timing
|
||||
@@ -515,7 +515,7 @@ impl WaylandClient {
|
||||
start,
|
||||
end: None,
|
||||
};
|
||||
LinuxDispatcher::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
|
||||
runnable.run();
|
||||
timing
|
||||
@@ -524,7 +524,7 @@ impl WaylandClient {
|
||||
|
||||
let end = Instant::now();
|
||||
timing.end = Some(end);
|
||||
LinuxDispatcher::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{Capslock, LinuxDispatcher, ResultExt as _, RunnableVariant, TaskTiming, xcb_flush};
|
||||
use crate::{Capslock, ResultExt as _, RunnableVariant, TaskTiming, profiler, xcb_flush};
|
||||
use anyhow::{Context as _, anyhow};
|
||||
use ashpd::WindowIdentifier;
|
||||
use calloop::{
|
||||
@@ -322,7 +322,7 @@ impl X11Client {
|
||||
start,
|
||||
end: None,
|
||||
};
|
||||
LinuxDispatcher::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
|
||||
runnable.run();
|
||||
timing
|
||||
@@ -334,7 +334,7 @@ impl X11Client {
|
||||
start,
|
||||
end: None,
|
||||
};
|
||||
LinuxDispatcher::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
|
||||
runnable.run();
|
||||
timing
|
||||
@@ -343,7 +343,7 @@ impl X11Client {
|
||||
|
||||
let end = Instant::now();
|
||||
timing.end = Some(end);
|
||||
LinuxDispatcher::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,11 +3,22 @@
|
||||
#![allow(non_snake_case)]
|
||||
|
||||
use crate::{
|
||||
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, RunnableMeta, RunnableVariant, THREAD_TIMINGS,
|
||||
TaskLabel, TaskTiming, ThreadTaskTimings,
|
||||
GLOBAL_THREAD_TIMINGS, PlatformDispatcher, Priority, RealtimePriority, RunnableMeta,
|
||||
RunnableVariant, THREAD_TIMINGS, TaskLabel, TaskTiming, ThreadTaskTimings,
|
||||
};
|
||||
|
||||
use anyhow::Context;
|
||||
use async_task::Runnable;
|
||||
use mach2::{
|
||||
kern_return::KERN_SUCCESS,
|
||||
mach_time::mach_timebase_info_data_t,
|
||||
thread_policy::{
|
||||
THREAD_EXTENDED_POLICY, THREAD_EXTENDED_POLICY_COUNT, THREAD_PRECEDENCE_POLICY,
|
||||
THREAD_PRECEDENCE_POLICY_COUNT, THREAD_TIME_CONSTRAINT_POLICY,
|
||||
THREAD_TIME_CONSTRAINT_POLICY_COUNT, thread_extended_policy_data_t,
|
||||
thread_precedence_policy_data_t, thread_time_constraint_policy_data_t,
|
||||
},
|
||||
};
|
||||
use objc::{
|
||||
class, msg_send,
|
||||
runtime::{BOOL, YES},
|
||||
@@ -15,9 +26,11 @@ use objc::{
|
||||
};
|
||||
use std::{
|
||||
ffi::c_void,
|
||||
mem::MaybeUninit,
|
||||
ptr::{NonNull, addr_of},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use util::ResultExt;
|
||||
|
||||
/// All items in the generated file are marked as pub, so we're gonna wrap it in a separate mod to prevent
|
||||
/// these pub items from leaking into public API.
|
||||
@@ -56,7 +69,7 @@ impl PlatformDispatcher for MacDispatcher {
|
||||
is_main_thread == YES
|
||||
}
|
||||
|
||||
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>) {
|
||||
fn dispatch(&self, runnable: RunnableVariant, _: Option<TaskLabel>, priority: Priority) {
|
||||
let (context, trampoline) = match runnable {
|
||||
RunnableVariant::Meta(runnable) => (
|
||||
runnable.into_raw().as_ptr() as *mut c_void,
|
||||
@@ -67,16 +80,24 @@ impl PlatformDispatcher for MacDispatcher {
|
||||
Some(trampoline_compat as unsafe extern "C" fn(*mut c_void)),
|
||||
),
|
||||
};
|
||||
|
||||
let queue_priority = match priority {
|
||||
Priority::Realtime(_) => unreachable!(),
|
||||
Priority::High => DISPATCH_QUEUE_PRIORITY_HIGH as isize,
|
||||
Priority::Medium => DISPATCH_QUEUE_PRIORITY_DEFAULT as isize,
|
||||
Priority::Low => DISPATCH_QUEUE_PRIORITY_LOW as isize,
|
||||
};
|
||||
|
||||
unsafe {
|
||||
dispatch_async_f(
|
||||
dispatch_get_global_queue(DISPATCH_QUEUE_PRIORITY_HIGH.try_into().unwrap(), 0),
|
||||
dispatch_get_global_queue(queue_priority, 0),
|
||||
context,
|
||||
trampoline,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
|
||||
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: Priority) {
|
||||
let (context, trampoline) = match runnable {
|
||||
RunnableVariant::Meta(runnable) => (
|
||||
runnable.into_raw().as_ptr() as *mut c_void,
|
||||
@@ -110,6 +131,120 @@ impl PlatformDispatcher for MacDispatcher {
|
||||
dispatch_after_f(when, queue, context, trampoline);
|
||||
}
|
||||
}
|
||||
|
||||
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
|
||||
std::thread::spawn(move || {
|
||||
match priority {
|
||||
RealtimePriority::Audio => set_audio_thread_priority(),
|
||||
RealtimePriority::Other => set_high_thread_priority(),
|
||||
}
|
||||
.context(format!("for priority {:?}", priority))
|
||||
.log_err();
|
||||
|
||||
f();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn set_high_thread_priority() -> anyhow::Result<()> {
|
||||
// SAFETY: always safe to call
|
||||
let thread_id = unsafe { libc::pthread_self() };
|
||||
|
||||
// SAFETY: all sched_param members are valid when initialized to zero.
|
||||
let mut sched_param = unsafe { MaybeUninit::<libc::sched_param>::zeroed().assume_init() };
|
||||
sched_param.sched_priority = 45;
|
||||
|
||||
let result = unsafe { libc::pthread_setschedparam(thread_id, libc::SCHED_FIFO, &sched_param) };
|
||||
if result != 0 {
|
||||
anyhow::bail!("failed to set realtime thread priority")
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_audio_thread_priority() -> anyhow::Result<()> {
|
||||
// https://chromium.googlesource.com/chromium/chromium/+/master/base/threading/platform_thread_mac.mm#93
|
||||
|
||||
// SAFETY: always safe to call
|
||||
let thread_id = unsafe { libc::pthread_self() };
|
||||
|
||||
// SAFETY: thread_id is a valid thread id
|
||||
let thread_id = unsafe { libc::pthread_mach_thread_np(thread_id) };
|
||||
|
||||
// Fixed priority thread
|
||||
let mut policy = thread_extended_policy_data_t { timeshare: 0 };
|
||||
|
||||
// SAFETY: thread_id is a valid thread id
|
||||
// SAFETY: thread_extended_policy_data_t is passed as THREAD_EXTENDED_POLICY
|
||||
let result = unsafe {
|
||||
mach2::thread_policy::thread_policy_set(
|
||||
thread_id,
|
||||
THREAD_EXTENDED_POLICY,
|
||||
&mut policy as *mut _ as *mut _,
|
||||
THREAD_EXTENDED_POLICY_COUNT,
|
||||
)
|
||||
};
|
||||
|
||||
if result != KERN_SUCCESS {
|
||||
anyhow::bail!("failed to set thread extended policy");
|
||||
}
|
||||
|
||||
// relatively high priority
|
||||
let mut precedence = thread_precedence_policy_data_t { importance: 63 };
|
||||
|
||||
// SAFETY: thread_id is a valid thread id
|
||||
// SAFETY: thread_precedence_policy_data_t is passed as THREAD_PRECEDENCE_POLICY
|
||||
let result = unsafe {
|
||||
mach2::thread_policy::thread_policy_set(
|
||||
thread_id,
|
||||
THREAD_PRECEDENCE_POLICY,
|
||||
&mut precedence as *mut _ as *mut _,
|
||||
THREAD_PRECEDENCE_POLICY_COUNT,
|
||||
)
|
||||
};
|
||||
|
||||
if result != KERN_SUCCESS {
|
||||
anyhow::bail!("failed to set thread precedence policy");
|
||||
}
|
||||
|
||||
const GUARANTEED_AUDIO_DUTY_CYCLE: f32 = 0.75;
|
||||
const MAX_AUDIO_DUTY_CYCLE: f32 = 0.85;
|
||||
|
||||
// ~128 frames @ 44.1KHz
|
||||
const TIME_QUANTUM: f32 = 2.9;
|
||||
|
||||
const AUDIO_TIME_NEEDED: f32 = GUARANTEED_AUDIO_DUTY_CYCLE * TIME_QUANTUM;
|
||||
const MAX_TIME_ALLOWED: f32 = MAX_AUDIO_DUTY_CYCLE * TIME_QUANTUM;
|
||||
|
||||
let mut timebase_info = mach_timebase_info_data_t { numer: 0, denom: 0 };
|
||||
// SAFETY: timebase_info is a valid pointer to a mach_timebase_info_data_t struct
|
||||
unsafe { mach2::mach_time::mach_timebase_info(&mut timebase_info) };
|
||||
|
||||
let ms_to_abs_time = ((timebase_info.denom as f32) / (timebase_info.numer as f32)) * 1000000f32;
|
||||
|
||||
let mut time_constraints = thread_time_constraint_policy_data_t {
|
||||
period: (TIME_QUANTUM * ms_to_abs_time) as u32,
|
||||
computation: (AUDIO_TIME_NEEDED * ms_to_abs_time) as u32,
|
||||
constraint: (MAX_TIME_ALLOWED * ms_to_abs_time) as u32,
|
||||
preemptible: 0,
|
||||
};
|
||||
|
||||
// SAFETY: thread_id is a valid thread id
|
||||
// SAFETY: thread_precedence_pthread_time_constraint_policy_data_t is passed as THREAD_TIME_CONSTRAINT_POLICY
|
||||
let result = unsafe {
|
||||
mach2::thread_policy::thread_policy_set(
|
||||
thread_id,
|
||||
THREAD_TIME_CONSTRAINT_POLICY,
|
||||
&mut time_constraints as *mut _ as *mut _,
|
||||
THREAD_TIME_CONSTRAINT_POLICY_COUNT,
|
||||
)
|
||||
};
|
||||
|
||||
if result != KERN_SUCCESS {
|
||||
anyhow::bail!("failed to set thread time constraint policy");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
extern "C" fn trampoline(runnable: *mut c_void) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use crate::{PlatformDispatcher, RunnableVariant, TaskLabel};
|
||||
use crate::{PlatformDispatcher, Priority, RunnableVariant, TaskLabel};
|
||||
use backtrace::Backtrace;
|
||||
use collections::{HashMap, HashSet, VecDeque};
|
||||
use parking::Unparker;
|
||||
@@ -284,7 +284,7 @@ impl PlatformDispatcher for TestDispatcher {
|
||||
state.start_time + state.time
|
||||
}
|
||||
|
||||
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
|
||||
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, _priority: Priority) {
|
||||
{
|
||||
let mut state = self.state.lock();
|
||||
if label.is_some_and(|label| state.deprioritized_task_labels.contains(&label)) {
|
||||
@@ -296,7 +296,7 @@ impl PlatformDispatcher for TestDispatcher {
|
||||
self.unpark_all();
|
||||
}
|
||||
|
||||
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
|
||||
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, _priority: Priority) {
|
||||
self.state
|
||||
.lock()
|
||||
.foreground
|
||||
@@ -318,4 +318,10 @@ impl PlatformDispatcher for TestDispatcher {
|
||||
fn as_test(&self) -> Option<&TestDispatcher> {
|
||||
Some(self)
|
||||
}
|
||||
|
||||
fn spawn_realtime(&self, _priority: crate::RealtimePriority, f: Box<dyn FnOnce() + Send>) {
|
||||
std::thread::spawn(move || {
|
||||
f();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,24 +4,31 @@ use std::{
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use flume::Sender;
|
||||
use anyhow::Context;
|
||||
use util::ResultExt;
|
||||
use windows::{
|
||||
System::Threading::{ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler},
|
||||
System::Threading::{
|
||||
ThreadPool, ThreadPoolTimer, TimerElapsedHandler, WorkItemHandler, WorkItemPriority,
|
||||
},
|
||||
Win32::{
|
||||
Foundation::{LPARAM, WPARAM},
|
||||
System::Threading::{
|
||||
GetCurrentThread, HIGH_PRIORITY_CLASS, SetPriorityClass, SetThreadPriority,
|
||||
THREAD_PRIORITY_HIGHEST, THREAD_PRIORITY_TIME_CRITICAL,
|
||||
},
|
||||
UI::WindowsAndMessaging::PostMessageW,
|
||||
},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, RunnableVariant, SafeHwnd, THREAD_TIMINGS,
|
||||
TaskLabel, TaskTiming, ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD,
|
||||
GLOBAL_THREAD_TIMINGS, HWND, PlatformDispatcher, Priority, PriorityQueueSender,
|
||||
RealtimePriority, RunnableVariant, SafeHwnd, THREAD_TIMINGS, TaskLabel, TaskTiming,
|
||||
ThreadTaskTimings, WM_GPUI_TASK_DISPATCHED_ON_MAIN_THREAD, profiler,
|
||||
};
|
||||
|
||||
pub(crate) struct WindowsDispatcher {
|
||||
pub(crate) wake_posted: AtomicBool,
|
||||
main_sender: Sender<RunnableVariant>,
|
||||
main_sender: PriorityQueueSender<RunnableVariant>,
|
||||
main_thread_id: ThreadId,
|
||||
pub(crate) platform_window_handle: SafeHwnd,
|
||||
validation_number: usize,
|
||||
@@ -29,7 +36,7 @@ pub(crate) struct WindowsDispatcher {
|
||||
|
||||
impl WindowsDispatcher {
|
||||
pub(crate) fn new(
|
||||
main_sender: Sender<RunnableVariant>,
|
||||
main_sender: PriorityQueueSender<RunnableVariant>,
|
||||
platform_window_handle: HWND,
|
||||
validation_number: usize,
|
||||
) -> Self {
|
||||
@@ -45,7 +52,7 @@ impl WindowsDispatcher {
|
||||
}
|
||||
}
|
||||
|
||||
fn dispatch_on_threadpool(&self, runnable: RunnableVariant) {
|
||||
fn dispatch_on_threadpool(&self, priority: WorkItemPriority, runnable: RunnableVariant) {
|
||||
let handler = {
|
||||
let mut task_wrapper = Some(runnable);
|
||||
WorkItemHandler::new(move |_| {
|
||||
@@ -53,7 +60,8 @@ impl WindowsDispatcher {
|
||||
Ok(())
|
||||
})
|
||||
};
|
||||
ThreadPool::RunAsync(&handler).log_err();
|
||||
|
||||
ThreadPool::RunWithPriorityAsync(&handler, priority).log_err();
|
||||
}
|
||||
|
||||
fn dispatch_on_threadpool_after(&self, runnable: RunnableVariant, duration: Duration) {
|
||||
@@ -79,7 +87,7 @@ impl WindowsDispatcher {
|
||||
start,
|
||||
end: None,
|
||||
};
|
||||
Self::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
|
||||
runnable.run();
|
||||
|
||||
@@ -91,7 +99,7 @@ impl WindowsDispatcher {
|
||||
start,
|
||||
end: None,
|
||||
};
|
||||
Self::add_task_timing(timing);
|
||||
profiler::add_task_timing(timing);
|
||||
|
||||
runnable.run();
|
||||
|
||||
@@ -102,23 +110,7 @@ impl WindowsDispatcher {
|
||||
let end = Instant::now();
|
||||
timing.end = Some(end);
|
||||
|
||||
Self::add_task_timing(timing);
|
||||
}
|
||||
|
||||
pub(crate) fn add_task_timing(timing: TaskTiming) {
|
||||
THREAD_TIMINGS.with(|timings| {
|
||||
let mut timings = timings.lock();
|
||||
let timings = &mut timings.timings;
|
||||
|
||||
if let Some(last_timing) = timings.iter_mut().rev().next() {
|
||||
if last_timing.location == timing.location {
|
||||
last_timing.end = timing.end;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
timings.push_back(timing);
|
||||
});
|
||||
profiler::add_task_timing(timing);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,15 +138,22 @@ impl PlatformDispatcher for WindowsDispatcher {
|
||||
current().id() == self.main_thread_id
|
||||
}
|
||||
|
||||
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>) {
|
||||
self.dispatch_on_threadpool(runnable);
|
||||
fn dispatch(&self, runnable: RunnableVariant, label: Option<TaskLabel>, priority: Priority) {
|
||||
let priority = match priority {
|
||||
Priority::Realtime(_) => unreachable!(),
|
||||
Priority::High => WorkItemPriority::High,
|
||||
Priority::Medium => WorkItemPriority::Normal,
|
||||
Priority::Low => WorkItemPriority::Low,
|
||||
};
|
||||
self.dispatch_on_threadpool(priority, runnable);
|
||||
|
||||
if let Some(label) = label {
|
||||
log::debug!("TaskLabel: {label:?}");
|
||||
}
|
||||
}
|
||||
|
||||
fn dispatch_on_main_thread(&self, runnable: RunnableVariant) {
|
||||
match self.main_sender.send(runnable) {
|
||||
fn dispatch_on_main_thread(&self, runnable: RunnableVariant, priority: Priority) {
|
||||
match self.main_sender.send(priority, runnable) {
|
||||
Ok(_) => {
|
||||
if !self.wake_posted.swap(true, Ordering::AcqRel) {
|
||||
unsafe {
|
||||
@@ -185,4 +184,28 @@ impl PlatformDispatcher for WindowsDispatcher {
|
||||
fn dispatch_after(&self, duration: Duration, runnable: RunnableVariant) {
|
||||
self.dispatch_on_threadpool_after(runnable, duration);
|
||||
}
|
||||
|
||||
fn spawn_realtime(&self, priority: RealtimePriority, f: Box<dyn FnOnce() + Send>) {
|
||||
std::thread::spawn(move || {
|
||||
// SAFETY: always safe to call
|
||||
let thread_handle = unsafe { GetCurrentThread() };
|
||||
|
||||
let thread_priority = match priority {
|
||||
RealtimePriority::Audio => THREAD_PRIORITY_TIME_CRITICAL,
|
||||
RealtimePriority::Other => THREAD_PRIORITY_HIGHEST,
|
||||
};
|
||||
|
||||
// SAFETY: thread_handle is a valid handle to a thread
|
||||
unsafe { SetPriorityClass(thread_handle, HIGH_PRIORITY_CLASS) }
|
||||
.context("thread priority class")
|
||||
.log_err();
|
||||
|
||||
// SAFETY: thread_handle is a valid handle to a thread
|
||||
unsafe { SetThreadPriority(thread_handle, thread_priority) }
|
||||
.context("thread priority")
|
||||
.log_err();
|
||||
|
||||
f();
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -243,7 +243,8 @@ impl WindowsWindowInner {
|
||||
|
||||
fn handle_timer_msg(&self, handle: HWND, wparam: WPARAM) -> Option<isize> {
|
||||
if wparam.0 == SIZE_MOVE_LOOP_TIMER_ID {
|
||||
for runnable in self.main_receiver.drain() {
|
||||
let mut runnables = self.main_receiver.clone().try_iter();
|
||||
while let Some(Ok(runnable)) = runnables.next() {
|
||||
WindowsDispatcher::execute_runnable(runnable);
|
||||
}
|
||||
self.handle_paint_msg(handle)
|
||||
|
||||
@@ -51,7 +51,7 @@ struct WindowsPlatformInner {
|
||||
raw_window_handles: std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
|
||||
// The below members will never change throughout the entire lifecycle of the app.
|
||||
validation_number: usize,
|
||||
main_receiver: flume::Receiver<RunnableVariant>,
|
||||
main_receiver: PriorityQueueReceiver<RunnableVariant>,
|
||||
dispatcher: Arc<WindowsDispatcher>,
|
||||
}
|
||||
|
||||
@@ -98,7 +98,7 @@ impl WindowsPlatform {
|
||||
OleInitialize(None).context("unable to initialize Windows OLE")?;
|
||||
}
|
||||
let directx_devices = DirectXDevices::new().context("Creating DirectX devices")?;
|
||||
let (main_sender, main_receiver) = flume::unbounded::<RunnableVariant>();
|
||||
let (main_sender, main_receiver) = PriorityQueueReceiver::new();
|
||||
let validation_number = if usize::BITS == 64 {
|
||||
rand::random::<u64>() as usize
|
||||
} else {
|
||||
@@ -857,22 +857,24 @@ impl WindowsPlatformInner {
|
||||
}
|
||||
break 'tasks;
|
||||
}
|
||||
match self.main_receiver.try_recv() {
|
||||
Err(_) => break 'timeout_loop,
|
||||
Ok(runnable) => WindowsDispatcher::execute_runnable(runnable),
|
||||
let mut main_receiver = self.main_receiver.clone();
|
||||
match main_receiver.try_pop() {
|
||||
Ok(Some(runnable)) => WindowsDispatcher::execute_runnable(runnable),
|
||||
_ => break 'timeout_loop,
|
||||
}
|
||||
}
|
||||
|
||||
// Someone could enqueue a Runnable here. The flag is still true, so they will not PostMessage.
|
||||
// We need to check for those Runnables after we clear the flag.
|
||||
self.dispatcher.wake_posted.store(false, Ordering::Release);
|
||||
match self.main_receiver.try_recv() {
|
||||
Err(_) => break 'tasks,
|
||||
Ok(runnable) => {
|
||||
let mut main_receiver = self.main_receiver.clone();
|
||||
match main_receiver.try_pop() {
|
||||
Ok(Some(runnable)) => {
|
||||
self.dispatcher.wake_posted.store(true, Ordering::Release);
|
||||
|
||||
WindowsDispatcher::execute_runnable(runnable);
|
||||
}
|
||||
_ => break 'tasks,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -934,7 +936,7 @@ pub(crate) struct WindowCreationInfo {
|
||||
pub(crate) windows_version: WindowsVersion,
|
||||
pub(crate) drop_target_helper: IDropTargetHelper,
|
||||
pub(crate) validation_number: usize,
|
||||
pub(crate) main_receiver: flume::Receiver<RunnableVariant>,
|
||||
pub(crate) main_receiver: PriorityQueueReceiver<RunnableVariant>,
|
||||
pub(crate) platform_window_handle: HWND,
|
||||
pub(crate) disable_direct_composition: bool,
|
||||
pub(crate) directx_devices: DirectXDevices,
|
||||
@@ -947,8 +949,8 @@ struct PlatformWindowCreateContext {
|
||||
inner: Option<Result<Rc<WindowsPlatformInner>>>,
|
||||
raw_window_handles: std::sync::Weak<RwLock<SmallVec<[SafeHwnd; 4]>>>,
|
||||
validation_number: usize,
|
||||
main_sender: Option<flume::Sender<RunnableVariant>>,
|
||||
main_receiver: Option<flume::Receiver<RunnableVariant>>,
|
||||
main_sender: Option<PriorityQueueSender<RunnableVariant>>,
|
||||
main_receiver: Option<PriorityQueueReceiver<RunnableVariant>>,
|
||||
directx_devices: Option<DirectXDevices>,
|
||||
dispatcher: Option<Arc<WindowsDispatcher>>,
|
||||
}
|
||||
|
||||
@@ -81,7 +81,7 @@ pub(crate) struct WindowsWindowInner {
|
||||
pub(crate) executor: ForegroundExecutor,
|
||||
pub(crate) windows_version: WindowsVersion,
|
||||
pub(crate) validation_number: usize,
|
||||
pub(crate) main_receiver: flume::Receiver<RunnableVariant>,
|
||||
pub(crate) main_receiver: PriorityQueueReceiver<RunnableVariant>,
|
||||
pub(crate) platform_window_handle: HWND,
|
||||
}
|
||||
|
||||
@@ -362,7 +362,7 @@ struct WindowCreateContext {
|
||||
windows_version: WindowsVersion,
|
||||
drop_target_helper: IDropTargetHelper,
|
||||
validation_number: usize,
|
||||
main_receiver: flume::Receiver<RunnableVariant>,
|
||||
main_receiver: PriorityQueueReceiver<RunnableVariant>,
|
||||
platform_window_handle: HWND,
|
||||
appearance: WindowAppearance,
|
||||
disable_direct_composition: bool,
|
||||
|
||||
@@ -216,3 +216,19 @@ impl Drop for ThreadTimings {
|
||||
thread_timings.swap_remove(index);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn add_task_timing(timing: TaskTiming) {
|
||||
THREAD_TIMINGS.with(|timings| {
|
||||
let mut timings = timings.lock();
|
||||
let timings = &mut timings.timings;
|
||||
|
||||
if let Some(last_timing) = timings.iter_mut().rev().next() {
|
||||
if last_timing.location == timing.location {
|
||||
last_timing.end = timing.end;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
timings.push_back(timing);
|
||||
});
|
||||
}
|
||||
|
||||
329
crates/gpui/src/queue.rs
Normal file
329
crates/gpui/src/queue.rs
Normal file
@@ -0,0 +1,329 @@
|
||||
use std::{
|
||||
fmt,
|
||||
iter::FusedIterator,
|
||||
sync::{Arc, atomic::AtomicUsize},
|
||||
};
|
||||
|
||||
use rand::{Rng, SeedableRng, rngs::SmallRng};
|
||||
|
||||
use crate::Priority;
|
||||
|
||||
struct PriorityQueues<T> {
|
||||
high_priority: Vec<T>,
|
||||
medium_priority: Vec<T>,
|
||||
low_priority: Vec<T>,
|
||||
}
|
||||
|
||||
impl<T> PriorityQueues<T> {
|
||||
fn is_empty(&self) -> bool {
|
||||
self.high_priority.is_empty()
|
||||
&& self.medium_priority.is_empty()
|
||||
&& self.low_priority.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
struct PriorityQueueState<T> {
|
||||
queues: parking_lot::Mutex<PriorityQueues<T>>,
|
||||
condvar: parking_lot::Condvar,
|
||||
receiver_count: AtomicUsize,
|
||||
sender_count: AtomicUsize,
|
||||
}
|
||||
|
||||
impl<T> PriorityQueueState<T> {
|
||||
fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
|
||||
if self
|
||||
.receiver_count
|
||||
.load(std::sync::atomic::Ordering::Relaxed)
|
||||
== 0
|
||||
{
|
||||
return Err(SendError(item));
|
||||
}
|
||||
|
||||
let mut queues = self.queues.lock();
|
||||
match priority {
|
||||
Priority::Realtime(_) => unreachable!(),
|
||||
Priority::High => queues.high_priority.push(item),
|
||||
Priority::Medium => queues.medium_priority.push(item),
|
||||
Priority::Low => queues.low_priority.push(item),
|
||||
};
|
||||
self.condvar.notify_one();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn recv<'a>(&'a self) -> Result<parking_lot::MutexGuard<'a, PriorityQueues<T>>, RecvError> {
|
||||
let mut queues = self.queues.lock();
|
||||
|
||||
let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
|
||||
if queues.is_empty() && sender_count == 0 {
|
||||
return Err(crate::queue::RecvError);
|
||||
}
|
||||
|
||||
// parking_lot doesn't do spurious wakeups so an if is fine
|
||||
if queues.is_empty() {
|
||||
self.condvar.wait(&mut queues);
|
||||
}
|
||||
|
||||
Ok(queues)
|
||||
}
|
||||
|
||||
fn try_recv<'a>(
|
||||
&'a self,
|
||||
) -> Result<Option<parking_lot::MutexGuard<'a, PriorityQueues<T>>>, RecvError> {
|
||||
let mut queues = self.queues.lock();
|
||||
|
||||
let sender_count = self.sender_count.load(std::sync::atomic::Ordering::Relaxed);
|
||||
if queues.is_empty() && sender_count == 0 {
|
||||
return Err(crate::queue::RecvError);
|
||||
}
|
||||
|
||||
if queues.is_empty() {
|
||||
Ok(None)
|
||||
} else {
|
||||
Ok(Some(queues))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PriorityQueueSender<T> {
|
||||
state: Arc<PriorityQueueState<T>>,
|
||||
}
|
||||
|
||||
impl<T> PriorityQueueSender<T> {
|
||||
fn new(state: Arc<PriorityQueueState<T>>) -> Self {
|
||||
Self { state }
|
||||
}
|
||||
|
||||
pub(crate) fn send(&self, priority: Priority, item: T) -> Result<(), SendError<T>> {
|
||||
self.state.send(priority, item)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for PriorityQueueSender<T> {
|
||||
fn drop(&mut self) {
|
||||
self.state
|
||||
.sender_count
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct PriorityQueueReceiver<T> {
|
||||
state: Arc<PriorityQueueState<T>>,
|
||||
rand: SmallRng,
|
||||
disconnected: bool,
|
||||
}
|
||||
|
||||
impl<T> Clone for PriorityQueueReceiver<T> {
|
||||
fn clone(&self) -> Self {
|
||||
self.state
|
||||
.receiver_count
|
||||
.fetch_add(1, std::sync::atomic::Ordering::AcqRel);
|
||||
Self {
|
||||
state: Arc::clone(&self.state),
|
||||
rand: SmallRng::seed_from_u64(0),
|
||||
disconnected: self.disconnected,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct SendError<T>(T);
|
||||
|
||||
impl<T: fmt::Debug> fmt::Debug for SendError<T> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.debug_tuple("SendError").field(&self.0).finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub(crate) struct RecvError;
|
||||
|
||||
#[allow(dead_code)]
|
||||
impl<T> PriorityQueueReceiver<T> {
|
||||
pub(crate) fn new() -> (PriorityQueueSender<T>, Self) {
|
||||
let state = PriorityQueueState {
|
||||
queues: parking_lot::Mutex::new(PriorityQueues {
|
||||
high_priority: Vec::new(),
|
||||
medium_priority: Vec::new(),
|
||||
low_priority: Vec::new(),
|
||||
}),
|
||||
condvar: parking_lot::Condvar::new(),
|
||||
receiver_count: AtomicUsize::new(1),
|
||||
sender_count: AtomicUsize::new(1),
|
||||
};
|
||||
let state = Arc::new(state);
|
||||
|
||||
let sender = PriorityQueueSender::new(Arc::clone(&state));
|
||||
|
||||
let receiver = PriorityQueueReceiver {
|
||||
state,
|
||||
rand: SmallRng::seed_from_u64(0),
|
||||
disconnected: false,
|
||||
};
|
||||
|
||||
(sender, receiver)
|
||||
}
|
||||
|
||||
/// Tries to pop one element from the priority queue without blocking.
|
||||
///
|
||||
/// This will early return if there are no elements in the queue.
|
||||
///
|
||||
/// This method is best suited if you only intend to pop one element, for better performance
|
||||
/// on large queues see [`Self::try_iter`]
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// If the sender was dropped
|
||||
pub(crate) fn try_pop(&mut self) -> Result<Option<T>, RecvError> {
|
||||
self.pop_inner(false)
|
||||
}
|
||||
|
||||
/// Pops an element from the priority queue blocking if necessary.
|
||||
///
|
||||
/// This method is best suited if you only intend to pop one element, for better performance
|
||||
/// on large queues see [`Self::iter``]
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// If the sender was dropped
|
||||
pub(crate) fn pop(&mut self) -> Result<T, RecvError> {
|
||||
self.pop_inner(true).map(|e| e.unwrap())
|
||||
}
|
||||
|
||||
/// Returns an iterator over the elements of the queue
|
||||
/// this iterator will end when all elements have been consumed and will not wait for new ones.
|
||||
pub(crate) fn try_iter(self) -> TryIter<T> {
|
||||
TryIter {
|
||||
receiver: self,
|
||||
ended: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an iterator over the elements of the queue
|
||||
/// this iterator will wait for new elements if the queue is empty.
|
||||
pub(crate) fn iter(self) -> Iter<T> {
|
||||
Iter(self)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
// algorithm is the loaded die from biased coin from
|
||||
// https://www.keithschwarz.com/darts-dice-coins/
|
||||
fn pop_inner(&mut self, block: bool) -> Result<Option<T>, RecvError> {
|
||||
use Priority as P;
|
||||
|
||||
let mut queues = if !block {
|
||||
let Some(queues) = self.state.try_recv()? else {
|
||||
return Ok(None);
|
||||
};
|
||||
queues
|
||||
} else {
|
||||
self.state.recv()?
|
||||
};
|
||||
|
||||
let high = P::High.probability() * !queues.high_priority.is_empty() as u32;
|
||||
let medium = P::Medium.probability() * !queues.medium_priority.is_empty() as u32;
|
||||
let low = P::Low.probability() * !queues.low_priority.is_empty() as u32;
|
||||
let mut mass = high + medium + low; //%
|
||||
|
||||
if !queues.high_priority.is_empty() {
|
||||
let flip = self.rand.random_ratio(P::High.probability(), mass);
|
||||
if flip {
|
||||
return Ok(queues.high_priority.pop());
|
||||
}
|
||||
mass -= P::High.probability();
|
||||
}
|
||||
|
||||
if !queues.medium_priority.is_empty() {
|
||||
let flip = self.rand.random_ratio(P::Medium.probability(), mass);
|
||||
if flip {
|
||||
return Ok(queues.medium_priority.pop());
|
||||
}
|
||||
mass -= P::Medium.probability();
|
||||
}
|
||||
|
||||
if !queues.low_priority.is_empty() {
|
||||
let flip = self.rand.random_ratio(P::Low.probability(), mass);
|
||||
if flip {
|
||||
return Ok(queues.low_priority.pop());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Drop for PriorityQueueReceiver<T> {
|
||||
fn drop(&mut self) {
|
||||
self.state
|
||||
.receiver_count
|
||||
.fetch_sub(1, std::sync::atomic::Ordering::AcqRel);
|
||||
}
|
||||
}
|
||||
|
||||
/// If None is returned the sender disconnected
|
||||
pub(crate) struct Iter<T>(PriorityQueueReceiver<T>);
|
||||
impl<T> Iterator for Iter<T> {
|
||||
type Item = T;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.0.pop_inner(true).ok().flatten()
|
||||
}
|
||||
}
|
||||
impl<T> FusedIterator for Iter<T> {}
|
||||
|
||||
/// If None is returned there are no more elements in the queue
|
||||
pub(crate) struct TryIter<T> {
|
||||
receiver: PriorityQueueReceiver<T>,
|
||||
ended: bool,
|
||||
}
|
||||
impl<T> Iterator for TryIter<T> {
|
||||
type Item = Result<T, RecvError>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.ended {
|
||||
return None;
|
||||
}
|
||||
|
||||
let res = self.receiver.pop_inner(false);
|
||||
self.ended = res.is_err();
|
||||
|
||||
res.transpose()
|
||||
}
|
||||
}
|
||||
impl<T> FusedIterator for TryIter<T> {}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use collections::HashSet;
|
||||
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn all_tasks_get_yielded() {
|
||||
let (tx, mut rx) = PriorityQueueReceiver::new();
|
||||
tx.send(Priority::Medium, 20).unwrap();
|
||||
tx.send(Priority::High, 30).unwrap();
|
||||
tx.send(Priority::Low, 10).unwrap();
|
||||
tx.send(Priority::Medium, 21).unwrap();
|
||||
tx.send(Priority::High, 31).unwrap();
|
||||
|
||||
drop(tx);
|
||||
|
||||
assert_eq!(
|
||||
rx.iter().collect::<HashSet<_>>(),
|
||||
[30, 31, 20, 21, 10].into_iter().collect::<HashSet<_>>()
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_high_prio_task_get_scheduled_quickly() {
|
||||
let (tx, mut rx) = PriorityQueueReceiver::new();
|
||||
for _ in 0..100 {
|
||||
tx.send(Priority::Low, 1).unwrap();
|
||||
}
|
||||
|
||||
assert_eq!(rx.pop().unwrap(), 1);
|
||||
tx.send(Priority::High, 3).unwrap();
|
||||
assert_eq!(rx.pop().unwrap(), 3);
|
||||
assert_eq!(rx.pop().unwrap(), 1);
|
||||
}
|
||||
}
|
||||
@@ -9,14 +9,15 @@ use crate::{
|
||||
KeyBinding, KeyContext, KeyDownEvent, KeyEvent, Keystroke, KeystrokeEvent, LayoutId,
|
||||
LineLayoutIndex, Modifiers, ModifiersChangedEvent, MonochromeSprite, MouseButton, MouseEvent,
|
||||
MouseMoveEvent, MouseUpEvent, Path, Pixels, PlatformAtlas, PlatformDisplay, PlatformInput,
|
||||
PlatformInputHandler, PlatformWindow, Point, PolychromeSprite, PromptButton, PromptLevel, Quad,
|
||||
Render, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams, Replay, ResizeEdge,
|
||||
SMOOTH_SVG_SCALE_FACTOR, SUBPIXEL_VARIANTS_X, SUBPIXEL_VARIANTS_Y, ScaledPixels, Scene, Shadow,
|
||||
SharedString, Size, StrikethroughStyle, Style, SubscriberSet, Subscription, SystemWindowTab,
|
||||
SystemWindowTabController, TabStopMap, TaffyLayoutEngine, Task, TextStyle, TextStyleRefinement,
|
||||
TransformationMatrix, Underline, UnderlineStyle, WindowAppearance, WindowBackgroundAppearance,
|
||||
WindowBounds, WindowControls, WindowDecorations, WindowOptions, WindowParams, WindowTextSystem,
|
||||
point, prelude::*, px, rems, size, transparent_black,
|
||||
PlatformInputHandler, PlatformWindow, Point, PolychromeSprite, Priority, PromptButton,
|
||||
PromptLevel, Quad, Render, RenderGlyphParams, RenderImage, RenderImageParams, RenderSvgParams,
|
||||
Replay, ResizeEdge, SMOOTH_SVG_SCALE_FACTOR, SUBPIXEL_VARIANTS_X, SUBPIXEL_VARIANTS_Y,
|
||||
ScaledPixels, Scene, Shadow, SharedString, Size, StrikethroughStyle, Style, SubscriberSet,
|
||||
Subscription, SystemWindowTab, SystemWindowTabController, TabStopMap, TaffyLayoutEngine, Task,
|
||||
TextStyle, TextStyleRefinement, TransformationMatrix, Underline, UnderlineStyle,
|
||||
WindowAppearance, WindowBackgroundAppearance, WindowBounds, WindowControls, WindowDecorations,
|
||||
WindowOptions, WindowParams, WindowTextSystem, point, prelude::*, px, rems, size,
|
||||
transparent_black,
|
||||
};
|
||||
use anyhow::{Context as _, Result, anyhow};
|
||||
use collections::{FxHashMap, FxHashSet};
|
||||
@@ -1725,6 +1726,27 @@ impl Window {
|
||||
})
|
||||
}
|
||||
|
||||
/// Spawn the future returned by the given closure on the application thread
|
||||
/// pool, with the given priority. The closure is provided a handle to the
|
||||
/// current window and an `AsyncWindowContext` for use within your future.
|
||||
#[track_caller]
|
||||
pub fn spawn_with_priority<AsyncFn, R>(
|
||||
&self,
|
||||
priority: Priority,
|
||||
cx: &App,
|
||||
f: AsyncFn,
|
||||
) -> Task<R>
|
||||
where
|
||||
R: 'static,
|
||||
AsyncFn: AsyncFnOnce(&mut AsyncWindowContext) -> R + 'static,
|
||||
{
|
||||
let handle = self.handle;
|
||||
cx.spawn_with_priority(priority, async move |app| {
|
||||
let mut async_window_cx = AsyncWindowContext::new_context(app.clone(), handle);
|
||||
f(&mut async_window_cx).await
|
||||
})
|
||||
}
|
||||
|
||||
fn bounds_changed(&mut self, cx: &mut App) {
|
||||
self.scale_factor = self.platform_window.scale_factor();
|
||||
self.viewport_size = self.platform_window.content_size();
|
||||
|
||||
@@ -12,7 +12,7 @@ mod session;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use async_dispatcher::{Dispatcher, Runnable, set_dispatcher};
|
||||
use gpui::{App, PlatformDispatcher, RunnableVariant};
|
||||
use gpui::{App, PlatformDispatcher, Priority, RunnableVariant};
|
||||
use project::Fs;
|
||||
pub use runtimelib::ExecutionState;
|
||||
|
||||
@@ -46,7 +46,7 @@ fn zed_dispatcher(cx: &mut App) -> impl Dispatcher {
|
||||
impl Dispatcher for ZedDispatcher {
|
||||
fn dispatch(&self, runnable: Runnable) {
|
||||
self.dispatcher
|
||||
.dispatch(RunnableVariant::Compat(runnable), None);
|
||||
.dispatch(RunnableVariant::Compat(runnable), None, Priority::default());
|
||||
}
|
||||
|
||||
fn dispatch_after(&self, duration: Duration, runnable: Runnable) {
|
||||
|
||||
@@ -2452,6 +2452,12 @@ impl Workspace {
|
||||
.0
|
||||
.split(' ')
|
||||
.flat_map(|k| Keystroke::parse(k).log_err())
|
||||
.map(|k| {
|
||||
cx.keyboard_mapper()
|
||||
.map_key_equivalent(k, true)
|
||||
.inner()
|
||||
.clone()
|
||||
})
|
||||
.collect();
|
||||
let _ = self.send_keystrokes_impl(keystrokes, window, cx);
|
||||
}
|
||||
|
||||
@@ -22,7 +22,8 @@ use git::{
|
||||
COMMIT_MESSAGE, DOT_GIT, FSMONITOR_DAEMON, GITIGNORE, INDEX_LOCK, LFS_DIR, status::GitSummary,
|
||||
};
|
||||
use gpui::{
|
||||
App, AppContext as _, AsyncApp, BackgroundExecutor, Context, Entity, EventEmitter, Task,
|
||||
App, AppContext as _, AsyncApp, BackgroundExecutor, Context, Entity, EventEmitter, Priority,
|
||||
Task,
|
||||
};
|
||||
use ignore::IgnoreStack;
|
||||
use language::DiskState;
|
||||
@@ -4144,7 +4145,7 @@ impl BackgroundScanner {
|
||||
|
||||
let progress_update_count = AtomicUsize::new(0);
|
||||
self.executor
|
||||
.scoped(|scope| {
|
||||
.scoped_priority(Priority::Low, |scope| {
|
||||
for _ in 0..self.executor.num_cpus() {
|
||||
scope.spawn(async {
|
||||
let mut last_progress_update_count = 0;
|
||||
|
||||
@@ -52,6 +52,8 @@ extend-exclude = [
|
||||
"crates/project_panel/benches/linux_repo_snapshot.txt",
|
||||
# Some multibuffer test cases have word fragments that register as typos
|
||||
"crates/multi_buffer/src/multi_buffer_tests.rs",
|
||||
# Macos apis
|
||||
"crates/gpui/src/platform/mac/dispatcher.rs",
|
||||
]
|
||||
|
||||
[default]
|
||||
|
||||
Reference in New Issue
Block a user