Files
zed/crates/scheduler/src/executor.rs
Antonio Scandurra ebe835dd91 Checkpoint
2025-09-05 10:37:36 +02:00

233 lines
6.1 KiB
Rust

use crate::{Scheduler, SessionId, Timer};
use chrono::{DateTime, Utc};
use futures::FutureExt as _;
use std::{
future::Future,
marker::PhantomData,
mem::ManuallyDrop,
panic::Location,
pin::Pin,
rc::Rc,
sync::Arc,
task::{Context, Poll},
thread::{self, ThreadId},
time::Duration,
};
#[derive(Clone)]
pub struct ForegroundExecutor {
session_id: SessionId,
scheduler: Arc<dyn Scheduler>,
not_send: PhantomData<Rc<()>>,
}
impl ForegroundExecutor {
#[track_caller]
pub fn spawn<F>(&self, future: F) -> Task<F::Output>
where
F: Future + 'static,
F::Output: 'static,
{
let session_id = self.session_id;
let scheduler = Arc::clone(&self.scheduler);
let (runnable, task) = spawn_local_with_source_location(future, move |runnable| {
scheduler.schedule_foreground(session_id, runnable);
});
runnable.schedule();
Task(TaskState::Spawned(task))
}
pub fn block_on<Fut: Future>(&self, future: Fut) -> Fut::Output {
let mut output = None;
self.scheduler.block(
self.session_id,
async { output = Some(future.await) }.boxed_local(),
None,
);
output.unwrap()
}
pub fn block_with_timeout<Fut: Unpin + Future>(
&self,
future: &mut Fut,
timeout: Duration,
) -> Option<Fut::Output> {
let mut output = None;
self.scheduler.block(
self.session_id,
async { output = Some(future.await) }.boxed_local(),
Some(timeout),
);
output
}
pub fn timer(&self, duration: Duration) -> Timer {
self.scheduler.timer(duration)
}
}
impl ForegroundExecutor {
pub fn new(session_id: SessionId, scheduler: Arc<dyn Scheduler>) -> Self {
Self {
session_id,
scheduler,
not_send: PhantomData,
}
}
}
impl BackgroundExecutor {
pub fn new(scheduler: Arc<dyn Scheduler>) -> Self {
Self { scheduler }
}
}
#[derive(Clone)]
pub struct BackgroundExecutor {
scheduler: Arc<dyn Scheduler>,
}
impl BackgroundExecutor {
pub fn now(&self) -> DateTime<Utc> {
self.scheduler.now()
}
pub fn spawn<F>(&self, future: F) -> Task<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let scheduler = Arc::clone(&self.scheduler);
let (runnable, task) = async_task::spawn(future, move |runnable| {
scheduler.schedule_background(runnable);
});
runnable.schedule();
Task(TaskState::Spawned(task))
}
pub fn timer(&self, duration: Duration) -> Timer {
self.scheduler.timer(duration)
}
pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
&self.scheduler
}
}
/// Task is a primitive that allows work to happen in the background.
///
/// It implements [`Future`] so you can `.await` on it.
///
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
/// the task to continue running, but with no way to return a value.
#[must_use]
#[derive(Debug)]
pub struct Task<T>(TaskState<T>);
#[derive(Debug)]
enum TaskState<T> {
/// A task that is ready to return a value
Ready(Option<T>),
/// A task that is currently running.
Spawned(async_task::Task<T>),
}
impl<T> Task<T> {
/// Creates a new task that will resolve with the value
pub fn ready(val: T) -> Self {
Task(TaskState::Ready(Some(val)))
}
pub fn is_ready(&self) -> bool {
match &self.0 {
TaskState::Ready(_) => true,
TaskState::Spawned(task) => task.is_finished(),
}
}
/// Detaching a task runs it to completion in the background
pub fn detach(self) {
match self {
Task(TaskState::Ready(_)) => {}
Task(TaskState::Spawned(task)) => task.detach(),
}
}
}
impl<T> Future for Task<T> {
type Output = T;
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match unsafe { self.get_unchecked_mut() } {
Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
Task(TaskState::Spawned(task)) => Pin::new(task).poll(cx),
}
}
}
/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
///
/// Copy-modified from:
/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
#[track_caller]
fn spawn_local_with_source_location<Fut, S>(
future: Fut,
schedule: S,
) -> (async_task::Runnable, async_task::Task<Fut::Output, ()>)
where
Fut: Future + 'static,
Fut::Output: 'static,
S: async_task::Schedule + Send + Sync + 'static,
{
#[inline]
fn thread_id() -> ThreadId {
std::thread_local! {
static ID: ThreadId = thread::current().id();
}
ID.try_with(|id| *id)
.unwrap_or_else(|_| thread::current().id())
}
struct Checked<F> {
id: ThreadId,
inner: ManuallyDrop<F>,
location: &'static Location<'static>,
}
impl<F> Drop for Checked<F> {
fn drop(&mut self) {
assert!(
self.id == thread_id(),
"local task dropped by a thread that didn't spawn it. Task spawned at {}",
self.location
);
unsafe {
ManuallyDrop::drop(&mut self.inner);
}
}
}
impl<F: Future> Future for Checked<F> {
type Output = F::Output;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
assert!(
self.id == thread_id(),
"local task polled by a thread that didn't spawn it. Task spawned at {}",
self.location
);
unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
}
}
// Wrap the future into one that checks which thread it's on.
let future = Checked {
id: thread_id(),
inner: ManuallyDrop::new(future),
location: Location::caller(),
};
unsafe { async_task::spawn_unchecked(future, schedule) }
}