233 lines
6.1 KiB
Rust
233 lines
6.1 KiB
Rust
use crate::{Scheduler, SessionId, Timer};
|
|
use chrono::{DateTime, Utc};
|
|
use futures::FutureExt as _;
|
|
use std::{
|
|
future::Future,
|
|
marker::PhantomData,
|
|
mem::ManuallyDrop,
|
|
panic::Location,
|
|
pin::Pin,
|
|
rc::Rc,
|
|
sync::Arc,
|
|
task::{Context, Poll},
|
|
thread::{self, ThreadId},
|
|
time::Duration,
|
|
};
|
|
|
|
#[derive(Clone)]
|
|
pub struct ForegroundExecutor {
|
|
session_id: SessionId,
|
|
scheduler: Arc<dyn Scheduler>,
|
|
not_send: PhantomData<Rc<()>>,
|
|
}
|
|
|
|
impl ForegroundExecutor {
|
|
#[track_caller]
|
|
pub fn spawn<F>(&self, future: F) -> Task<F::Output>
|
|
where
|
|
F: Future + 'static,
|
|
F::Output: 'static,
|
|
{
|
|
let session_id = self.session_id;
|
|
let scheduler = Arc::clone(&self.scheduler);
|
|
let (runnable, task) = spawn_local_with_source_location(future, move |runnable| {
|
|
scheduler.schedule_foreground(session_id, runnable);
|
|
});
|
|
runnable.schedule();
|
|
Task(TaskState::Spawned(task))
|
|
}
|
|
|
|
pub fn block_on<Fut: Future>(&self, future: Fut) -> Fut::Output {
|
|
let mut output = None;
|
|
self.scheduler.block(
|
|
self.session_id,
|
|
async { output = Some(future.await) }.boxed_local(),
|
|
None,
|
|
);
|
|
output.unwrap()
|
|
}
|
|
|
|
pub fn block_with_timeout<Fut: Unpin + Future>(
|
|
&self,
|
|
future: &mut Fut,
|
|
timeout: Duration,
|
|
) -> Option<Fut::Output> {
|
|
let mut output = None;
|
|
self.scheduler.block(
|
|
self.session_id,
|
|
async { output = Some(future.await) }.boxed_local(),
|
|
Some(timeout),
|
|
);
|
|
output
|
|
}
|
|
|
|
pub fn timer(&self, duration: Duration) -> Timer {
|
|
self.scheduler.timer(duration)
|
|
}
|
|
}
|
|
|
|
impl ForegroundExecutor {
|
|
pub fn new(session_id: SessionId, scheduler: Arc<dyn Scheduler>) -> Self {
|
|
Self {
|
|
session_id,
|
|
scheduler,
|
|
not_send: PhantomData,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl BackgroundExecutor {
|
|
pub fn new(scheduler: Arc<dyn Scheduler>) -> Self {
|
|
Self { scheduler }
|
|
}
|
|
}
|
|
|
|
#[derive(Clone)]
|
|
pub struct BackgroundExecutor {
|
|
scheduler: Arc<dyn Scheduler>,
|
|
}
|
|
|
|
impl BackgroundExecutor {
|
|
pub fn now(&self) -> DateTime<Utc> {
|
|
self.scheduler.now()
|
|
}
|
|
|
|
pub fn spawn<F>(&self, future: F) -> Task<F::Output>
|
|
where
|
|
F: Future + Send + 'static,
|
|
F::Output: Send + 'static,
|
|
{
|
|
let scheduler = Arc::clone(&self.scheduler);
|
|
let (runnable, task) = async_task::spawn(future, move |runnable| {
|
|
scheduler.schedule_background(runnable);
|
|
});
|
|
runnable.schedule();
|
|
Task(TaskState::Spawned(task))
|
|
}
|
|
|
|
pub fn timer(&self, duration: Duration) -> Timer {
|
|
self.scheduler.timer(duration)
|
|
}
|
|
|
|
pub fn scheduler(&self) -> &Arc<dyn Scheduler> {
|
|
&self.scheduler
|
|
}
|
|
}
|
|
|
|
/// Task is a primitive that allows work to happen in the background.
|
|
///
|
|
/// It implements [`Future`] so you can `.await` on it.
|
|
///
|
|
/// If you drop a task it will be cancelled immediately. Calling [`Task::detach`] allows
|
|
/// the task to continue running, but with no way to return a value.
|
|
#[must_use]
|
|
#[derive(Debug)]
|
|
pub struct Task<T>(TaskState<T>);
|
|
|
|
#[derive(Debug)]
|
|
enum TaskState<T> {
|
|
/// A task that is ready to return a value
|
|
Ready(Option<T>),
|
|
|
|
/// A task that is currently running.
|
|
Spawned(async_task::Task<T>),
|
|
}
|
|
|
|
impl<T> Task<T> {
|
|
/// Creates a new task that will resolve with the value
|
|
pub fn ready(val: T) -> Self {
|
|
Task(TaskState::Ready(Some(val)))
|
|
}
|
|
|
|
pub fn is_ready(&self) -> bool {
|
|
match &self.0 {
|
|
TaskState::Ready(_) => true,
|
|
TaskState::Spawned(task) => task.is_finished(),
|
|
}
|
|
}
|
|
|
|
/// Detaching a task runs it to completion in the background
|
|
pub fn detach(self) {
|
|
match self {
|
|
Task(TaskState::Ready(_)) => {}
|
|
Task(TaskState::Spawned(task)) => task.detach(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<T> Future for Task<T> {
|
|
type Output = T;
|
|
|
|
fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
|
|
match unsafe { self.get_unchecked_mut() } {
|
|
Task(TaskState::Ready(val)) => Poll::Ready(val.take().unwrap()),
|
|
Task(TaskState::Spawned(task)) => Pin::new(task).poll(cx),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Variant of `async_task::spawn_local` that includes the source location of the spawn in panics.
|
|
///
|
|
/// Copy-modified from:
|
|
/// <https://github.com/smol-rs/async-task/blob/ca9dbe1db9c422fd765847fa91306e30a6bb58a9/src/runnable.rs#L405>
|
|
#[track_caller]
|
|
fn spawn_local_with_source_location<Fut, S>(
|
|
future: Fut,
|
|
schedule: S,
|
|
) -> (async_task::Runnable, async_task::Task<Fut::Output, ()>)
|
|
where
|
|
Fut: Future + 'static,
|
|
Fut::Output: 'static,
|
|
S: async_task::Schedule + Send + Sync + 'static,
|
|
{
|
|
#[inline]
|
|
fn thread_id() -> ThreadId {
|
|
std::thread_local! {
|
|
static ID: ThreadId = thread::current().id();
|
|
}
|
|
ID.try_with(|id| *id)
|
|
.unwrap_or_else(|_| thread::current().id())
|
|
}
|
|
|
|
struct Checked<F> {
|
|
id: ThreadId,
|
|
inner: ManuallyDrop<F>,
|
|
location: &'static Location<'static>,
|
|
}
|
|
|
|
impl<F> Drop for Checked<F> {
|
|
fn drop(&mut self) {
|
|
assert!(
|
|
self.id == thread_id(),
|
|
"local task dropped by a thread that didn't spawn it. Task spawned at {}",
|
|
self.location
|
|
);
|
|
unsafe {
|
|
ManuallyDrop::drop(&mut self.inner);
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<F: Future> Future for Checked<F> {
|
|
type Output = F::Output;
|
|
|
|
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
|
|
assert!(
|
|
self.id == thread_id(),
|
|
"local task polled by a thread that didn't spawn it. Task spawned at {}",
|
|
self.location
|
|
);
|
|
unsafe { self.map_unchecked_mut(|c| &mut *c.inner).poll(cx) }
|
|
}
|
|
}
|
|
|
|
// Wrap the future into one that checks which thread it's on.
|
|
let future = Checked {
|
|
id: thread_id(),
|
|
inner: ManuallyDrop::new(future),
|
|
location: Location::caller(),
|
|
};
|
|
|
|
unsafe { async_task::spawn_unchecked(future, schedule) }
|
|
}
|