1#![doc = include_str!("../README.md")]
2#![allow(non_upper_case_globals)]
3#![allow(non_camel_case_types)]
4#![allow(non_snake_case)]
5
6use std::{
7 ffi::c_void,
8 future::Future,
9 pin::Pin,
10 ptr::NonNull,
11 sync::{
12 atomic::{AtomicBool, Ordering},
13 Arc, Mutex,
14 },
15 task::{Context, Poll, Waker},
16 time::Duration,
17};
18
19mod sys {
20 #![allow(dead_code)]
21 include!(concat!(env!("OUT_DIR"), "/dispatch_sys.rs"));
22}
23
/// Reason why awaiting a [`Task`] did not yield the task's output.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum JoinError {
    /// The task was cancelled through [`Task::abort`].
    Aborted,
    /// The task's output had already been returned by an earlier poll.
    PollAfterReady,
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::Aborted => "task was aborted",
            Self::PollAfterReady => "task polled after completion",
        };
        f.write_str(message)
    }
}

impl std::error::Error for JoinError {}
43
44enum TaskState<T> {
45 Running(async_task::Task<T>),
46 Completed,
47 Aborted,
48}
49
50pub struct Task<T>(TaskState<T>);
56
57impl<T> Task<T> {
58 pub fn abort(&mut self) {
62 self.0 = TaskState::Aborted;
63 }
64}
65
66impl<T> Drop for Task<T> {
67 fn drop(&mut self) {
68 if let TaskState::Running(task) = std::mem::replace(&mut self.0, TaskState::Completed) {
69 task.detach();
70 }
71 }
72}
73
74impl<T> Future for Task<T> {
75 type Output = Result<T, JoinError>;
76
77 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
78 match &mut self.0 {
79 TaskState::Running(task) => match Pin::new(task).poll(cx) {
80 Poll::Ready(val) => {
81 self.0 = TaskState::Completed;
82 Poll::Ready(Ok(val))
83 }
84 Poll::Pending => Poll::Pending,
85 },
86 TaskState::Completed => Poll::Ready(Err(JoinError::PollAfterReady)),
87 TaskState::Aborted => Poll::Ready(Err(JoinError::Aborted)),
88 }
89 }
90}
91
92pub fn spawn<F, T>(future: F) -> Task<T>
96where
97 F: Future<Output = T> + Send + 'static,
98 T: Send + 'static,
99{
100 let (runnable, task) = async_task::spawn(future, schedule_background);
101 runnable.schedule();
102 Task(TaskState::Running(task))
103}
104
105pub fn spawn_on_main<F, T>(future: F) -> Task<T>
109where
110 F: Future<Output = T> + 'static,
111 T: 'static,
112{
113 let (runnable, task) = async_task::spawn_local(future, schedule_main);
114 runnable.schedule();
115 Task(TaskState::Running(task))
116}
117
118pub fn spawn_after<F, T>(duration: Duration, future: F) -> Task<T>
123where
124 F: Future<Output = T> + Send + 'static,
125 T: Send + 'static,
126{
127 let first_schedule = Arc::new(AtomicBool::new(true));
128 let (runnable, task) = async_task::spawn(future, move |runnable: async_task::Runnable<()>| {
129 let ptr = runnable.into_raw().as_ptr() as *mut c_void;
130
131 if first_schedule.swap(false, Ordering::SeqCst) {
132 unsafe {
138 let when =
139 sys::dispatch_time(sys::DISPATCH_TIME_NOW as u64, duration.as_nanos() as i64);
140 sys::dispatch_after_f(
141 when,
142 sys::dispatch_get_global_queue(
143 sys::DISPATCH_QUEUE_PRIORITY_DEFAULT as isize,
144 0,
145 ),
146 ptr,
147 Some(trampoline),
148 );
149 }
150 } else {
151 unsafe {
156 sys::dispatch_async_f(
157 sys::dispatch_get_global_queue(
158 sys::DISPATCH_QUEUE_PRIORITY_DEFAULT as isize,
159 0,
160 ),
161 ptr,
162 Some(trampoline),
163 );
164 }
165 }
166 });
167 runnable.schedule();
168 Task(TaskState::Running(task))
169}
170
/// Returns a future that completes once `duration` has elapsed.
///
/// The timer is armed on the first poll, not when `sleep` is called. The delay is
/// therefore measured from the first time the returned future is polled.
pub fn sleep(duration: Duration) -> Sleep {
    Sleep {
        duration,
        state: None,
    }
}

/// Future returned by [`sleep`].
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Sleep {
    /// Delay to wait, measured from the first poll.
    duration: Duration,
    /// State shared with the pending GCD timer; `None` until the first poll arms it.
    state: Option<Arc<SleepState>>,
}

/// State shared between a [`Sleep`] future and its GCD timer callback.
struct SleepState {
    /// Waker registered by the most recent poll, taken when the timer fires.
    waker: Mutex<Option<Waker>>,
    /// Set once the timer has fired.
    completed: AtomicBool,
}
197
198impl Future for Sleep {
199 type Output = ();
200
201 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
202 if self.state.is_none() {
204 let state = Arc::new(SleepState {
205 waker: Mutex::new(Some(cx.waker().clone())),
206 completed: AtomicBool::new(false),
207 });
208
209 let gcd_state = Arc::clone(&state);
211 let ptr = Arc::into_raw(gcd_state) as *mut c_void;
212
213 unsafe {
219 let when = sys::dispatch_time(
220 sys::DISPATCH_TIME_NOW as u64,
221 self.duration.as_nanos() as i64,
222 );
223 sys::dispatch_after_f(
224 when,
225 sys::dispatch_get_global_queue(
226 sys::DISPATCH_QUEUE_PRIORITY_DEFAULT as isize,
227 0,
228 ),
229 ptr,
230 Some(sleep_trampoline),
231 );
232 }
233
234 self.state = Some(state);
235 return Poll::Pending;
236 }
237
238 let state = self.state.as_ref().unwrap();
240 if state.completed.load(Ordering::SeqCst) {
241 Poll::Ready(())
242 } else {
243 *state.waker.lock().unwrap() = Some(cx.waker().clone());
245 Poll::Pending
246 }
247 }
248}
249
250extern "C" fn sleep_trampoline(context: *mut c_void) {
251 let state = unsafe { Arc::from_raw(context as *const SleepState) };
254 state.completed.store(true, Ordering::SeqCst);
255 let waker = state.waker.lock().unwrap().take();
256 drop(state);
257 if let Some(waker) = waker {
258 waker.wake();
259 }
260}
261
262fn dispatch_get_main_queue() -> sys::dispatch_queue_t {
263 std::ptr::addr_of!(sys::_dispatch_main_q) as *const _ as sys::dispatch_queue_t
264}
265
266fn schedule_background(runnable: async_task::Runnable<()>) {
267 let ptr = runnable.into_raw().as_ptr() as *mut c_void;
268 unsafe {
273 sys::dispatch_async_f(
274 sys::dispatch_get_global_queue(sys::DISPATCH_QUEUE_PRIORITY_DEFAULT as isize, 0),
275 ptr,
276 Some(trampoline),
277 );
278 }
279}
280
281fn schedule_main(runnable: async_task::Runnable<()>) {
282 let ptr = runnable.into_raw().as_ptr() as *mut c_void;
283 unsafe {
288 sys::dispatch_async_f(dispatch_get_main_queue(), ptr, Some(trampoline));
289 }
290}
291
292extern "C" fn trampoline(context: *mut c_void) {
293 let runnable =
299 unsafe { async_task::Runnable::<()>::from_raw(NonNull::new_unchecked(context as *mut ())) };
300 runnable.run();
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::{mpsc, Arc};
    use std::task::{Context, Poll, Wake, Waker};
    use std::time::Duration;

    /// Test waker that reports every wake-up over a channel.
    struct ChannelWaker(mpsc::Sender<()>);

    impl Wake for ChannelWaker {
        fn wake(self: Arc<Self>) {
            // The receiver may already be gone once the test has finished.
            let _ = self.0.send(());
        }
    }

    #[test]
    fn test_spawn_await() {
        let task = spawn(async { 42 });
        let result = pollster::block_on(task);
        assert_eq!(result, Ok(42));
    }

    #[test]
    fn test_spawn_drop_continues() {
        let (tx, rx) = mpsc::channel();

        drop(spawn(async move {
            tx.send(42).unwrap();
        }));

        let result = rx.recv_timeout(Duration::from_secs(1)).unwrap();
        assert_eq!(result, 42);
    }

    #[test]
    fn test_spawn_after_delays() {
        let start = std::time::Instant::now();

        let task = spawn_after(Duration::from_millis(100), async { 123 });
        let result = pollster::block_on(task);

        assert_eq!(result, Ok(123));
        assert!(
            start.elapsed() >= Duration::from_millis(100),
            "expected at least 100ms delay, got {:?}",
            start.elapsed()
        );
    }

    #[test]
    fn test_sleep() {
        let start = std::time::Instant::now();

        let task = spawn(async {
            sleep(Duration::from_millis(100)).await;
            "done"
        });
        let result = pollster::block_on(task);

        assert_eq!(result, Ok("done"));
        assert!(
            start.elapsed() >= Duration::from_millis(100),
            "expected at least 100ms delay, got {:?}",
            start.elapsed()
        );
    }

    #[test]
    fn test_sleep_zero_duration() {
        let task = spawn(async {
            sleep(Duration::ZERO).await;
            "done"
        });
        let result = pollster::block_on(task);
        assert_eq!(result, Ok("done"));
    }

    #[test]
    fn test_sleep_wakes_most_recent_waker() {
        let mut timer = sleep(Duration::from_millis(50));

        // The first poll arms the timer with a waker that ignores wake-ups.
        let mut noop_cx = Context::from_waker(Waker::noop());
        assert!(Pin::new(&mut timer).poll(&mut noop_cx).is_pending());

        // A later poll with a different waker must be the one that gets woken.
        let (tx, rx) = mpsc::channel();
        let waker = Waker::from(Arc::new(ChannelWaker(tx)));
        let mut cx = Context::from_waker(&waker);
        if Pin::new(&mut timer).poll(&mut cx).is_pending() {
            rx.recv_timeout(Duration::from_secs(1))
                .expect("the most recently registered waker was not woken");
        }
        assert!(Pin::new(&mut timer).poll(&mut cx).is_ready());
    }

    #[test]
    fn test_abort() {
        let (tx, rx) = mpsc::channel();

        let mut task = spawn(async move {
            sleep(Duration::from_millis(200)).await;
            tx.send(()).unwrap();
            42
        });

        std::thread::sleep(Duration::from_millis(10));

        task.abort();

        assert!(rx.recv_timeout(Duration::from_millis(300)).is_err());

        let result = pollster::block_on(task);
        assert_eq!(result, Err(JoinError::Aborted));
    }

    #[test]
    fn test_nested_spawn_await() {
        let task = spawn(async {
            let inner = spawn(async { 42 });
            inner.await.unwrap() + 1
        });

        let result = pollster::block_on(task);
        assert_eq!(result, Ok(43));
    }

    #[test]
    fn test_poll_after_ready() {
        let mut task = spawn(async { 1 });

        // `Waker::noop` already returns a `&'static Waker`, so it is passed directly.
        let mut cx = Context::from_waker(Waker::noop());
        loop {
            match Pin::new(&mut task).poll(&mut cx) {
                Poll::Ready(Ok(1)) => break,
                Poll::Ready(other) => panic!("unexpected result: {:?}", other),
                Poll::Pending => std::thread::yield_now(),
            }
        }

        let result = Pin::new(&mut task).poll(&mut cx);
        assert!(matches!(
            result,
            Poll::Ready(Err(JoinError::PollAfterReady))
        ));
    }
}