1 use core::cell::UnsafeCell;
2 use core::fmt;
3 use core::task::Waker;
4 
5 #[cfg(not(feature = "portable-atomic"))]
6 use core::sync::atomic::AtomicUsize;
7 use core::sync::atomic::Ordering;
8 #[cfg(feature = "portable-atomic")]
9 use portable_atomic::AtomicUsize;
10 
11 use crate::raw::TaskVTable;
12 use crate::state::*;
13 use crate::utils::abort_on_panic;
14 
/// The header of a task.
///
/// This header is stored in memory at the beginning of the heap-allocated task.
///
/// Access to the `awaiter` cell is coordinated through bits of `state` (`NOTIFYING`,
/// `REGISTERING`, `AWAITER`) rather than a lock; see `Header::take` and `Header::register`.
pub(crate) struct Header<M> {
    /// Current state of the task.
    ///
    /// Contains flags representing the current state and the reference count.
    /// Individual flags are tested with bit masks (`SCHEDULED`, `RUNNING`, `COMPLETED`,
    /// `CLOSED`, `AWAITER`, `TASK`, `NOTIFYING`, `REGISTERING`), and the reference count is
    /// read as `state / REFERENCE`.
    pub(crate) state: AtomicUsize,

    /// The task that is blocked on the `Task` handle.
    ///
    /// This waker needs to be woken up once the task completes or is closed.
    ///
    /// This cell is not protected by a lock. A thread may read or write it only while it
    /// exclusively owns either the `NOTIFYING` bit or the `REGISTERING` bit of `state`:
    /// - `take` owns `NOTIFYING` when it set it and neither `NOTIFYING` nor `REGISTERING`
    ///   was previously set.
    /// - `register` owns `REGISTERING` when its CAS set it while `NOTIFYING` was clear.
    pub(crate) awaiter: UnsafeCell<Option<Waker>>,

    /// The virtual table.
    ///
    /// In addition to the actual waker virtual table, it also contains pointers to several other
    /// methods necessary for bookkeeping the heap-allocated task.
    pub(crate) vtable: &'static TaskVTable,

    /// Metadata associated with the task.
    ///
    /// This metadata may be provided to the user.
    pub(crate) metadata: M,

    /// Whether or not a panic that occurs in the task should be propagated.
    ///
    /// Only present when the `std` feature is enabled.
    #[cfg(feature = "std")]
    pub(crate) propagate_panic: bool,
}
44 
impl<M> Header<M> {
    /// Notifies the awaiter blocked on this task.
    ///
    /// If the awaiter is the same as the current waker, it will not be notified.
    ///
    /// The awaiter is obtained through `take`. When `take` returns `None`, nothing is woken
    /// here. That happens when there is no awaiter, when the awaiter `will_wake` the same
    /// task as `current`, or when another thread currently owns the awaiter slot (see
    /// `take` for how that case is handed off). The call to `wake` is wrapped in
    /// `abort_on_panic`; presumably a panicking waker aborts instead of unwinding (TODO
    /// confirm against `crate::utils`).
    #[inline]
    pub(crate) fn notify(&self, current: Option<&Waker>) {
        if let Some(w) = self.take(current) {
            abort_on_panic(|| w.wake());
        }
    }

    /// Takes the awaiter blocked on this task.
    ///
    /// If there is no awaiter or if it is the same as the current waker, returns `None`.
    ///
    /// This method also returns `None` without touching the awaiter slot when the previous
    /// state already had `NOTIFYING` or `REGISTERING` set:
    /// - If `NOTIFYING` was already set, another `take` owns the slot and will clear the bit.
    /// - If `REGISTERING` was set, the `NOTIFYING` bit set here stays set. `register`
    ///   observes it, takes the freshly registered waker itself, wakes it, and clears
    ///   `NOTIFYING`. The notification is therefore handed off rather than lost.
    ///
    /// On the path that does access the slot, both `NOTIFYING` and `AWAITER` are cleared.
    #[inline]
    pub(crate) fn take(&self, current: Option<&Waker>) -> Option<Waker> {
        // Set the bit indicating that the task is notifying its awaiter.
        // The returned *previous* state tells us whether we gained exclusive access to
        // the awaiter slot.
        let state = self.state.fetch_or(NOTIFYING, Ordering::AcqRel);

        // If the task was not notifying or registering an awaiter...
        if state & (NOTIFYING | REGISTERING) == 0 {
            // Take the waker out.
            // SAFETY: neither NOTIFYING nor REGISTERING was set before our `fetch_or`, and we
            // set NOTIFYING. Until we clear it below, `register` cannot acquire REGISTERING:
            // its CAS loop bails out while NOTIFYING is set. Any other `take` sees NOTIFYING
            // and skips the slot. So no other thread accesses `awaiter` concurrently.
            let waker = unsafe { (*self.awaiter.get()).take() };

            // Unset the bit indicating that the task is notifying its awaiter.
            // AWAITER is cleared too, since the slot is now empty.
            self.state
                .fetch_and(!NOTIFYING & !AWAITER, Ordering::Release);

            // Finally, hand the waker back to the caller only if it would wake a different
            // task than `current`. Otherwise drop it here. Dropping a waker runs its vtable
            // `drop`, so the drop is guarded like the other waker calls.
            if let Some(w) = waker {
                match current {
                    None => return Some(w),
                    Some(c) if !w.will_wake(c) => return Some(w),
                    Some(_) => abort_on_panic(|| drop(w)),
                }
            }
        }

        None
    }

    /// Registers a new awaiter blocked on this task.
    ///
    /// This method is called when `Task` is polled and it has not yet completed.
    ///
    /// Protocol:
    /// 1. If a notification is in progress (`NOTIFYING` set), wake `waker` immediately and
    ///    return without registering.
    /// 2. Otherwise set `REGISTERING` via CAS, which gives exclusive access to the slot, and
    ///    store a clone of `waker`.
    /// 3. Clear `REGISTERING` (and any `NOTIFYING` set meanwhile by a backed-off `take`).
    ///    If such a notification arrived, take the waker back out, clear `AWAITER`, and wake
    ///    it. Otherwise leave it stored and set `AWAITER`.
    #[inline]
    pub(crate) fn register(&self, waker: &Waker) {
        // Load the state and synchronize with it.
        // `fetch_or(0, ..)` is a read-modify-write that leaves the value unchanged. It is
        // presumably used instead of `load` because RMWs read the latest value in the
        // modification order.
        let mut state = self.state.fetch_or(0, Ordering::Acquire);

        loop {
            // There can't be two concurrent registrations because `Task` can only be polled
            // by a unique pinned reference.
            debug_assert!(state & REGISTERING == 0);

            // If we're in the notifying state at this moment, just wake and return without
            // registering.
            if state & NOTIFYING != 0 {
                abort_on_panic(|| waker.wake_by_ref());
                return;
            }

            // Mark the state to let other threads know we're registering a new awaiter.
            // The CAS only succeeds against a state with NOTIFYING clear (checked above).
            // `compare_exchange_weak` may fail spuriously, in which case we simply retry.
            match self.state.compare_exchange_weak(
                state,
                state | REGISTERING,
                Ordering::AcqRel,
                Ordering::Acquire,
            ) {
                Ok(_) => {
                    state |= REGISTERING;
                    break;
                }
                Err(s) => state = s,
            }
        }

        // Put the waker into the awaiter field.
        // SAFETY: we set REGISTERING while NOTIFYING was clear, so no `take` is inside its
        // slot access. Any `take` starting now sees REGISTERING and leaves the slot alone.
        // Overwriting may drop a previously stored waker, which is why the store is guarded.
        unsafe {
            abort_on_panic(|| (*self.awaiter.get()) = Some(waker.clone()));
        }

        // This variable will contain the newly registered waker if a notification comes in before
        // we complete registration.
        let mut waker = None;

        loop {
            // If there was a notification, take the waker out of the awaiter field.
            // A `take` that ran after we set REGISTERING set NOTIFYING and backed off, so we
            // perform the notification on its behalf. On a CAS retry the slot may already be
            // empty (we took it on an earlier iteration), leaving `waker` as it was.
            if state & NOTIFYING != 0 {
                // SAFETY: we still hold REGISTERING, so we retain exclusive access to the slot.
                if let Some(w) = unsafe { (*self.awaiter.get()).take() } {
                    abort_on_panic(|| waker = Some(w));
                }
            }

            // The new state is not being notified nor registered, but there might or might not be
            // an awaiter depending on whether there was a concurrent notification.
            // Clearing NOTIFYING here completes the hand-off from the backed-off `take`.
            let new = if waker.is_none() {
                (state & !NOTIFYING & !REGISTERING) | AWAITER
            } else {
                state & !NOTIFYING & !REGISTERING & !AWAITER
            };

            // Retry if the state changed underneath us (possibly a new NOTIFYING), which is
            // rechecked at the top of the loop.
            match self
                .state
                .compare_exchange_weak(state, new, Ordering::AcqRel, Ordering::Acquire)
            {
                Ok(_) => break,
                Err(s) => state = s,
            }
        }

        // If there was a notification during registration, wake the awaiter now.
        // This happens after REGISTERING is released, so the slot is no longer held.
        if let Some(w) = waker {
            abort_on_panic(|| w.wake());
        }
    }
}
161 
162 impl<M: fmt::Debug> fmt::Debug for Header<M> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result163     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
164         let state = self.state.load(Ordering::SeqCst);
165 
166         f.debug_struct("Header")
167             .field("scheduled", &(state & SCHEDULED != 0))
168             .field("running", &(state & RUNNING != 0))
169             .field("completed", &(state & COMPLETED != 0))
170             .field("closed", &(state & CLOSED != 0))
171             .field("awaiter", &(state & AWAITER != 0))
172             .field("task", &(state & TASK != 0))
173             .field("ref_count", &(state / REFERENCE))
174             .field("metadata", &self.metadata)
175             .finish()
176     }
177 }
178