async_dispatch/
lib.rs

1#![doc = include_str!("../README.md")]
2#![allow(non_upper_case_globals)]
3#![allow(non_camel_case_types)]
4#![allow(non_snake_case)]
5
6use std::{
7    ffi::c_void,
8    future::Future,
9    pin::Pin,
10    ptr::NonNull,
11    sync::{
12        atomic::{AtomicBool, Ordering},
13        Arc, Mutex,
14    },
15    task::{Context, Poll, Waker},
16    time::Duration,
17};
18
19mod sys {
20    #![allow(dead_code)]
21    include!(concat!(env!("OUT_DIR"), "/dispatch_sys.rs"));
22}
23
/// Reason an awaited task handle could not yield the task's output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum JoinError {
    /// The task was aborted before it produced a value.
    Aborted,
    /// The handle was polled again after it had already resolved.
    PollAfterReady,
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match *self {
            Self::Aborted => "task was aborted",
            Self::PollAfterReady => "task polled after completion",
        };
        f.write_str(message)
    }
}

impl std::error::Error for JoinError {}
43
44enum TaskState<T> {
45    Running(async_task::Task<T>),
46    Completed,
47    Aborted,
48}
49
50/// A handle to a spawned task that can be awaited for its result.
51///
52/// Dropping a Task without awaiting it allows the task to continue running
53/// in the background; the result is simply discarded. This matches tokio's
54/// `JoinHandle` semantics.
55pub struct Task<T>(TaskState<T>);
56
57impl<T> Task<T> {
58    /// Abort the task, cancelling it at the next yield point.
59    ///
60    /// After calling abort, awaiting this task will return `Err(JoinError::Aborted)`.
61    pub fn abort(&mut self) {
62        self.0 = TaskState::Aborted;
63    }
64}
65
66impl<T> Drop for Task<T> {
67    fn drop(&mut self) {
68        if let TaskState::Running(task) = std::mem::replace(&mut self.0, TaskState::Completed) {
69            task.detach();
70        }
71    }
72}
73
74impl<T> Future for Task<T> {
75    type Output = Result<T, JoinError>;
76
77    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
78        match &mut self.0 {
79            TaskState::Running(task) => match Pin::new(task).poll(cx) {
80                Poll::Ready(val) => {
81                    self.0 = TaskState::Completed;
82                    Poll::Ready(Ok(val))
83                }
84                Poll::Pending => Poll::Pending,
85            },
86            TaskState::Completed => Poll::Ready(Err(JoinError::PollAfterReady)),
87            TaskState::Aborted => Poll::Ready(Err(JoinError::Aborted)),
88        }
89    }
90}
91
92/// Spawn a future on a background GCD queue.
93///
94/// The future will be polled on one of the system's global concurrent queues.
95pub fn spawn<F, T>(future: F) -> Task<T>
96where
97    F: Future<Output = T> + Send + 'static,
98    T: Send + 'static,
99{
100    let (runnable, task) = async_task::spawn(future, schedule_background);
101    runnable.schedule();
102    Task(TaskState::Running(task))
103}
104
105/// Spawn a future on the main thread's dispatch queue.
106///
107/// Use this for work that must run on the main thread, such as UI updates.
108pub fn spawn_on_main<F, T>(future: F) -> Task<T>
109where
110    F: Future<Output = T> + 'static,
111    T: 'static,
112{
113    let (runnable, task) = async_task::spawn_local(future, schedule_main);
114    runnable.schedule();
115    Task(TaskState::Running(task))
116}
117
118/// Spawn a future on a background queue after a delay.
119///
120/// The delay only applies to the initial spawn. If the future yields and is
121/// woken again, subsequent polls happen immediately.
122pub fn spawn_after<F, T>(duration: Duration, future: F) -> Task<T>
123where
124    F: Future<Output = T> + Send + 'static,
125    T: Send + 'static,
126{
127    let first_schedule = Arc::new(AtomicBool::new(true));
128    let (runnable, task) = async_task::spawn(future, move |runnable: async_task::Runnable<()>| {
129        let ptr = runnable.into_raw().as_ptr() as *mut c_void;
130
131        if first_schedule.swap(false, Ordering::SeqCst) {
132            // SAFETY: We call GCD's dispatch_after_f with:
133            // - A valid dispatch_time computed from DISPATCH_TIME_NOW
134            // - A valid global queue handle from dispatch_get_global_queue
135            // - A pointer from Runnable::into_raw() which transfers ownership to GCD
136            // - trampoline, which will reconstruct the Runnable exactly once
137            unsafe {
138                let when =
139                    sys::dispatch_time(sys::DISPATCH_TIME_NOW as u64, duration.as_nanos() as i64);
140                sys::dispatch_after_f(
141                    when,
142                    sys::dispatch_get_global_queue(
143                        sys::DISPATCH_QUEUE_PRIORITY_DEFAULT as isize,
144                        0,
145                    ),
146                    ptr,
147                    Some(trampoline),
148                );
149            }
150        } else {
151            // SAFETY: We call GCD's dispatch_async_f with:
152            // - A valid global queue handle from dispatch_get_global_queue
153            // - A pointer from Runnable::into_raw() which transfers ownership to GCD
154            // - trampoline, which will reconstruct the Runnable exactly once
155            unsafe {
156                sys::dispatch_async_f(
157                    sys::dispatch_get_global_queue(
158                        sys::DISPATCH_QUEUE_PRIORITY_DEFAULT as isize,
159                        0,
160                    ),
161                    ptr,
162                    Some(trampoline),
163                );
164            }
165        }
166    });
167    runnable.schedule();
168    Task(TaskState::Running(task))
169}
170
171/// Returns a future that completes after the specified duration.
172///
173/// This is the async equivalent of `std::thread::sleep`. The timer is
174/// managed by GCD and does not block any threads while waiting.
175///
176/// Note: The timer cannot be cancelled. If the `Sleep` future is dropped
177/// before completion, the underlying GCD timer still fires but does nothing.
178pub fn sleep(duration: Duration) -> Sleep {
179    Sleep {
180        duration,
181        state: None,
182    }
183}
184
/// A future that completes after a duration.
///
/// Created by the [`sleep`] function. Like every future it is lazy: the GCD
/// timer is only armed on the first poll.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Sleep {
    /// Requested delay, measured from the first poll.
    duration: Duration,
    /// State shared with the pending GCD timer callback; `None` until the
    /// first poll arms the timer.
    state: Option<Arc<SleepState>>,
}

/// State shared between a [`Sleep`] future and its GCD timer callback.
struct SleepState {
    /// Waker to notify when the timer fires. It is taken by the timer callback,
    /// or cleared when the `Sleep` is dropped.
    waker: Mutex<Option<Waker>>,
    /// Set once the timer callback has run.
    completed: AtomicBool,
}

impl Drop for Sleep {
    /// Releases the registered waker when the future goes away.
    ///
    /// The GCD timer cannot be cancelled and keeps its reference to the shared
    /// state until it fires (for very long durations, effectively forever).
    /// Clearing the waker here has two effects:
    /// - the pending callback no longer wakes a task that has stopped awaiting
    ///   this future;
    /// - the waker, which may keep its task alive, is not retained for the
    ///   rest of the delay.
    fn drop(&mut self) {
        if let Some(state) = self.state.take() {
            let stale = state
                .waker
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner)
                .take();
            // The guard was released at the end of the statement above, so the
            // waker is dropped outside the lock.
            drop(stale);
        }
    }
}
197
198impl Future for Sleep {
199    type Output = ();
200
201    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
202        // If we haven't started the timer yet, do so now
203        if self.state.is_none() {
204            let state = Arc::new(SleepState {
205                waker: Mutex::new(Some(cx.waker().clone())),
206                completed: AtomicBool::new(false),
207            });
208
209            // Clone for GCD to own
210            let gcd_state = Arc::clone(&state);
211            let ptr = Arc::into_raw(gcd_state) as *mut c_void;
212
213            // SAFETY: We call GCD's dispatch_after_f with:
214            // - A valid dispatch_time computed from DISPATCH_TIME_NOW
215            // - A valid global queue handle from dispatch_get_global_queue
216            // - A pointer from Arc::into_raw() which transfers one ref count to GCD
217            // - sleep_trampoline, which will call Arc::from_raw() exactly once
218            unsafe {
219                let when = sys::dispatch_time(
220                    sys::DISPATCH_TIME_NOW as u64,
221                    self.duration.as_nanos() as i64,
222                );
223                sys::dispatch_after_f(
224                    when,
225                    sys::dispatch_get_global_queue(
226                        sys::DISPATCH_QUEUE_PRIORITY_DEFAULT as isize,
227                        0,
228                    ),
229                    ptr,
230                    Some(sleep_trampoline),
231                );
232            }
233
234            self.state = Some(state);
235            return Poll::Pending;
236        }
237
238        // Timer already started - check if it's completed
239        let state = self.state.as_ref().unwrap();
240        if state.completed.load(Ordering::SeqCst) {
241            Poll::Ready(())
242        } else {
243            // Update the waker in case it changed
244            *state.waker.lock().unwrap() = Some(cx.waker().clone());
245            Poll::Pending
246        }
247    }
248}
249
250extern "C" fn sleep_trampoline(context: *mut c_void) {
251    // SAFETY: This pointer was created by Arc::into_raw() in Sleep::poll.
252    // GCD calls this exactly once, so we reclaim the Arc reference here.
253    let state = unsafe { Arc::from_raw(context as *const SleepState) };
254    state.completed.store(true, Ordering::SeqCst);
255    let waker = state.waker.lock().unwrap().take();
256    drop(state);
257    if let Some(waker) = waker {
258        waker.wake();
259    }
260}
261
262fn dispatch_get_main_queue() -> sys::dispatch_queue_t {
263    std::ptr::addr_of!(sys::_dispatch_main_q) as *const _ as sys::dispatch_queue_t
264}
265
266fn schedule_background(runnable: async_task::Runnable<()>) {
267    let ptr = runnable.into_raw().as_ptr() as *mut c_void;
268    // SAFETY: We call GCD's dispatch_async_f with:
269    // - A valid global queue handle from dispatch_get_global_queue
270    // - A pointer from Runnable::into_raw() which transfers ownership to GCD
271    // - trampoline, which will reconstruct the Runnable exactly once
272    unsafe {
273        sys::dispatch_async_f(
274            sys::dispatch_get_global_queue(sys::DISPATCH_QUEUE_PRIORITY_DEFAULT as isize, 0),
275            ptr,
276            Some(trampoline),
277        );
278    }
279}
280
281fn schedule_main(runnable: async_task::Runnable<()>) {
282    let ptr = runnable.into_raw().as_ptr() as *mut c_void;
283    // SAFETY: We call GCD's dispatch_async_f with:
284    // - The main queue handle (a valid static queue)
285    // - A pointer from Runnable::into_raw() which transfers ownership to GCD
286    // - trampoline, which will reconstruct the Runnable exactly once
287    unsafe {
288        sys::dispatch_async_f(dispatch_get_main_queue(), ptr, Some(trampoline));
289    }
290}
291
292extern "C" fn trampoline(context: *mut c_void) {
293    // SAFETY: This function is only called by GCD with a pointer that was created
294    // by Runnable::into_raw() in one of the schedule functions. GCD guarantees:
295    // - The pointer is passed exactly once per dispatch
296    // - The pointer value is unchanged from what we provided
297    // We reconstruct the Runnable, taking back ownership, and run it.
298    let runnable =
299        unsafe { async_task::Runnable::<()>::from_raw(NonNull::new_unchecked(context as *mut ())) };
300    runnable.run();
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::mpsc;
    use std::task::{Context, Poll, Waker};
    use std::time::{Duration, Instant};

    /// Delay used by the timing-sensitive tests.
    const MIN_DELAY: Duration = Duration::from_millis(100);

    /// Fails the test unless at least [`MIN_DELAY`] has elapsed since `start`.
    fn assert_waited_min_delay(start: Instant) {
        let waited = start.elapsed();
        assert!(
            waited >= MIN_DELAY,
            "expected at least 100ms delay, got {:?}",
            waited
        );
    }

    #[test]
    fn test_spawn_await() {
        assert_eq!(pollster::block_on(spawn(async { 42 })), Ok(42));
    }

    #[test]
    fn test_spawn_drop_continues() {
        // Dropping the handle detaches the task instead of cancelling it.
        let (sender, receiver) = mpsc::channel();
        let handle = spawn(async move { sender.send(42).unwrap() });
        drop(handle);

        let received = receiver.recv_timeout(Duration::from_secs(1)).unwrap();
        assert_eq!(received, 42);
    }

    #[test]
    fn test_spawn_after_delays() {
        let start = Instant::now();
        let output = pollster::block_on(spawn_after(MIN_DELAY, async { 123 }));

        assert_eq!(output, Ok(123));
        assert_waited_min_delay(start);
    }

    #[test]
    fn test_sleep() {
        let start = Instant::now();
        let handle = spawn(async {
            sleep(MIN_DELAY).await;
            "done"
        });

        assert_eq!(pollster::block_on(handle), Ok("done"));
        assert_waited_min_delay(start);
    }

    #[test]
    fn test_sleep_zero_duration() {
        let handle = spawn(async {
            sleep(Duration::ZERO).await;
            "done"
        });
        assert_eq!(pollster::block_on(handle), Ok("done"));
    }

    #[test]
    fn test_abort() {
        let (sender, receiver) = mpsc::channel();
        let mut handle = spawn(async move {
            sleep(Duration::from_millis(200)).await;
            sender.send(()).unwrap();
            42
        });

        // Let the task reach its first await point before aborting it.
        std::thread::sleep(Duration::from_millis(10));
        handle.abort();

        // A cancelled task never gets to send.
        assert!(receiver.recv_timeout(Duration::from_millis(300)).is_err());
        // Awaiting the aborted handle reports the abort.
        assert_eq!(pollster::block_on(handle), Err(JoinError::Aborted));
    }

    #[test]
    fn test_nested_spawn_await() {
        let outer = spawn(async { spawn(async { 42 }).await.unwrap() + 1 });
        assert_eq!(pollster::block_on(outer), Ok(43));
    }

    #[test]
    fn test_poll_after_ready() {
        let mut handle = spawn(async { 1 });
        let mut cx = Context::from_waker(Waker::noop());

        // Busy-poll until the task hands back its value.
        loop {
            match Pin::new(&mut handle).poll(&mut cx) {
                Poll::Pending => std::thread::yield_now(),
                Poll::Ready(Ok(1)) => break,
                Poll::Ready(other) => panic!("unexpected result: {:?}", other),
            }
        }

        // Any further poll reports that the output was already taken.
        let again = Pin::new(&mut handle).poll(&mut cx);
        assert!(matches!(
            again,
            Poll::Ready(Err(JoinError::PollAfterReady))
        ));
    }
}