1 #![allow(clippy::unit_arg)]
2 
3 use std::cmp;
4 use std::fmt;
5 use std::marker::PhantomData;
6 use std::mem;
7 use std::num::NonZeroUsize;
8 
9 use crate::errors::InvalidThreadAccess;
10 use crate::registry;
11 use crate::thread_id;
12 use crate::StackToken;
13 
/// A [`Sticky<T>`] keeps a value T stored in a thread.
///
/// This type works similarly to [`Fragile`](crate::Fragile) and exposes a
/// similar interface.  The difference is that whereas [`Fragile`](crate::Fragile) has
/// its destructor called in the thread where the value was sent, a
/// [`Sticky`] that is moved to another thread will have the internal
/// destructor called when the originating thread tears down.
///
/// Because [`Sticky`] allows values to be kept alive for longer than the
/// [`Sticky`] itself, it requires all its contents to be `'static` for
/// soundness.  More importantly it also requires the use of [`StackToken`]s.
/// For information about how to use stack tokens and why they are needed,
/// refer to [`stack_token!`](crate::stack_token).
///
/// As this uses TLS internally the general rules about the platform limitations
/// of destructors for TLS apply.
pub struct Sticky<T: 'static> {
    // Handle returned by `registry::insert` in `new`; identifies the boxed
    // value inside the registry.
    item_id: registry::ItemId,
    // Id of the thread that created this `Sticky` (from `thread_id::get()`
    // in `new`); only that thread is allowed to access the value.
    thread_id: NonZeroUsize,
    // Ties the type to `T` without storing one.  `*mut T` also opts out of
    // the automatic `Send`/`Sync` impls, which are implemented manually.
    _marker: PhantomData<*mut T>,
}
35 
36 impl<T> Drop for Sticky<T> {
drop(&mut self)37     fn drop(&mut self) {
38         // if the type needs dropping we can only do so on the
39         // right thread.  worst case we leak the value until the
40         // thread dies.
41         if mem::needs_drop::<T>() {
42             unsafe {
43                 if self.is_valid() {
44                     self.unsafe_take_value();
45                 }
46             }
47 
48         // otherwise we take the liberty to drop the value
49         // right here and now.  We can however only do that if
50         // we are on the right thread.  If we are not, we again
51         // need to wait for the thread to shut down.
52         } else if let Some(entry) = registry::try_remove(self.item_id, self.thread_id) {
53             unsafe {
54                 (entry.drop)(entry.ptr);
55             }
56         }
57     }
58 }
59 
60 impl<T> Sticky<T> {
61     /// Creates a new [`Sticky`] wrapping a `value`.
62     ///
63     /// The value that is moved into the [`Sticky`] can be non `Send` and
64     /// will be anchored to the thread that created the object.  If the
65     /// sticky wrapper type ends up being send from thread to thread
66     /// only the original thread can interact with the value.
new(value: T) -> Self67     pub fn new(value: T) -> Self {
68         let entry = registry::Entry {
69             ptr: Box::into_raw(Box::new(value)).cast(),
70             drop: |ptr| {
71                 let ptr = ptr.cast::<T>();
72                 // SAFETY: This callback will only be called once, with the
73                 // above pointer.
74                 drop(unsafe { Box::from_raw(ptr) });
75             },
76         };
77 
78         let thread_id = thread_id::get();
79         let item_id = registry::insert(thread_id, entry);
80 
81         Sticky {
82             item_id,
83             thread_id,
84             _marker: PhantomData,
85         }
86     }
87 
88     #[inline(always)]
with_value<F: FnOnce(*mut T) -> R, R>(&self, f: F) -> R89     fn with_value<F: FnOnce(*mut T) -> R, R>(&self, f: F) -> R {
90         self.assert_thread();
91 
92         registry::with(self.item_id, self.thread_id, |entry| {
93             f(entry.ptr.cast::<T>())
94         })
95     }
96 
97     /// Returns `true` if the access is valid.
98     ///
99     /// This will be `false` if the value was sent to another thread.
100     #[inline(always)]
is_valid(&self) -> bool101     pub fn is_valid(&self) -> bool {
102         thread_id::get() == self.thread_id
103     }
104 
105     #[inline(always)]
assert_thread(&self)106     fn assert_thread(&self) {
107         if !self.is_valid() {
108             panic!("trying to access wrapped value in sticky container from incorrect thread.");
109         }
110     }
111 
112     /// Consumes the `Sticky`, returning the wrapped value.
113     ///
114     /// # Panics
115     ///
116     /// Panics if called from a different thread than the one where the
117     /// original value was created.
into_inner(mut self) -> T118     pub fn into_inner(mut self) -> T {
119         self.assert_thread();
120         unsafe {
121             let rv = self.unsafe_take_value();
122             mem::forget(self);
123             rv
124         }
125     }
126 
unsafe_take_value(&mut self) -> T127     unsafe fn unsafe_take_value(&mut self) -> T {
128         let ptr = registry::remove(self.item_id, self.thread_id)
129             .ptr
130             .cast::<T>();
131         *Box::from_raw(ptr)
132     }
133 
134     /// Consumes the `Sticky`, returning the wrapped value if successful.
135     ///
136     /// The wrapped value is returned if this is called from the same thread
137     /// as the one where the original value was created, otherwise the
138     /// `Sticky` is returned as `Err(self)`.
try_into_inner(self) -> Result<T, Self>139     pub fn try_into_inner(self) -> Result<T, Self> {
140         if self.is_valid() {
141             Ok(self.into_inner())
142         } else {
143             Err(self)
144         }
145     }
146 
147     /// Immutably borrows the wrapped value.
148     ///
149     /// # Panics
150     ///
151     /// Panics if the calling thread is not the one that wrapped the value.
152     /// For a non-panicking variant, use [`try_get`](#method.try_get`).
get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T153     pub fn get<'stack>(&'stack self, _proof: &'stack StackToken) -> &'stack T {
154         self.with_value(|value| unsafe { &*value })
155     }
156 
157     /// Mutably borrows the wrapped value.
158     ///
159     /// # Panics
160     ///
161     /// Panics if the calling thread is not the one that wrapped the value.
162     /// For a non-panicking variant, use [`try_get_mut`](#method.try_get_mut`).
get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T163     pub fn get_mut<'stack>(&'stack mut self, _proof: &'stack StackToken) -> &'stack mut T {
164         self.with_value(|value| unsafe { &mut *value })
165     }
166 
167     /// Tries to immutably borrow the wrapped value.
168     ///
169     /// Returns `None` if the calling thread is not the one that wrapped the value.
try_get<'stack>( &'stack self, _proof: &'stack StackToken, ) -> Result<&'stack T, InvalidThreadAccess>170     pub fn try_get<'stack>(
171         &'stack self,
172         _proof: &'stack StackToken,
173     ) -> Result<&'stack T, InvalidThreadAccess> {
174         if self.is_valid() {
175             Ok(self.with_value(|value| unsafe { &*value }))
176         } else {
177             Err(InvalidThreadAccess)
178         }
179     }
180 
181     /// Tries to mutably borrow the wrapped value.
182     ///
183     /// Returns `None` if the calling thread is not the one that wrapped the value.
try_get_mut<'stack>( &'stack mut self, _proof: &'stack StackToken, ) -> Result<&'stack mut T, InvalidThreadAccess>184     pub fn try_get_mut<'stack>(
185         &'stack mut self,
186         _proof: &'stack StackToken,
187     ) -> Result<&'stack mut T, InvalidThreadAccess> {
188         if self.is_valid() {
189             Ok(self.with_value(|value| unsafe { &mut *value }))
190         } else {
191             Err(InvalidThreadAccess)
192         }
193     }
194 }
195 
196 impl<T> From<T> for Sticky<T> {
197     #[inline]
from(t: T) -> Sticky<T>198     fn from(t: T) -> Sticky<T> {
199         Sticky::new(t)
200     }
201 }
202 
203 impl<T: Clone> Clone for Sticky<T> {
204     #[inline]
clone(&self) -> Sticky<T>205     fn clone(&self) -> Sticky<T> {
206         crate::stack_token!(tok);
207         Sticky::new(self.get(tok).clone())
208     }
209 }
210 
211 impl<T: Default> Default for Sticky<T> {
212     #[inline]
default() -> Sticky<T>213     fn default() -> Sticky<T> {
214         Sticky::new(T::default())
215     }
216 }
217 
218 impl<T: PartialEq> PartialEq for Sticky<T> {
219     #[inline]
eq(&self, other: &Sticky<T>) -> bool220     fn eq(&self, other: &Sticky<T>) -> bool {
221         crate::stack_token!(tok);
222         *self.get(tok) == *other.get(tok)
223     }
224 }
225 
226 impl<T: Eq> Eq for Sticky<T> {}
227 
228 impl<T: PartialOrd> PartialOrd for Sticky<T> {
229     #[inline]
partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering>230     fn partial_cmp(&self, other: &Sticky<T>) -> Option<cmp::Ordering> {
231         crate::stack_token!(tok);
232         self.get(tok).partial_cmp(other.get(tok))
233     }
234 
235     #[inline]
lt(&self, other: &Sticky<T>) -> bool236     fn lt(&self, other: &Sticky<T>) -> bool {
237         crate::stack_token!(tok);
238         *self.get(tok) < *other.get(tok)
239     }
240 
241     #[inline]
le(&self, other: &Sticky<T>) -> bool242     fn le(&self, other: &Sticky<T>) -> bool {
243         crate::stack_token!(tok);
244         *self.get(tok) <= *other.get(tok)
245     }
246 
247     #[inline]
gt(&self, other: &Sticky<T>) -> bool248     fn gt(&self, other: &Sticky<T>) -> bool {
249         crate::stack_token!(tok);
250         *self.get(tok) > *other.get(tok)
251     }
252 
253     #[inline]
ge(&self, other: &Sticky<T>) -> bool254     fn ge(&self, other: &Sticky<T>) -> bool {
255         crate::stack_token!(tok);
256         *self.get(tok) >= *other.get(tok)
257     }
258 }
259 
260 impl<T: Ord> Ord for Sticky<T> {
261     #[inline]
cmp(&self, other: &Sticky<T>) -> cmp::Ordering262     fn cmp(&self, other: &Sticky<T>) -> cmp::Ordering {
263         crate::stack_token!(tok);
264         self.get(tok).cmp(other.get(tok))
265     }
266 }
267 
268 impl<T: fmt::Display> fmt::Display for Sticky<T> {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>269     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
270         crate::stack_token!(tok);
271         fmt::Display::fmt(self.get(tok), f)
272     }
273 }
274 
275 impl<T: fmt::Debug> fmt::Debug for Sticky<T> {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>276     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
277         crate::stack_token!(tok);
278         match self.try_get(tok) {
279             Ok(value) => f.debug_struct("Sticky").field("value", value).finish(),
280             Err(..) => {
281                 struct InvalidPlaceholder;
282                 impl fmt::Debug for InvalidPlaceholder {
283                     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
284                         f.write_str("<invalid thread>")
285                     }
286                 }
287 
288                 f.debug_struct("Sticky")
289                     .field("value", &InvalidPlaceholder)
290                     .finish()
291             }
292         }
293     }
294 }
295 
// SAFETY (Sync): as with `Fragile`, this type is `Sync` because it only
// accesses TLS data, which is thread local, so nothing needs to be
// synchronized.  Every accessor first checks that the calling thread is the
// one that created the value (`is_valid` / `assert_thread`).
unsafe impl<T> Sync for Sticky<T> {}

// SAFETY (Send): the entire point of this type is to be `Send`.  After a
// move to another thread, accessors panic or return `InvalidThreadAccess`,
// and `Drop` does not run `T`'s destructor off the owning thread.
unsafe impl<T> Send for Sticky<T> {}
302 
#[test]
fn test_basic() {
    use std::thread;

    let sticky = Sticky::new(true);
    crate::stack_token!(token);

    // On the owning thread every accessor works.
    assert_eq!(sticky.to_string(), "true");
    assert_eq!(sticky.get(token), &true);
    assert!(sticky.try_get(token).is_ok());

    // Once moved to another thread, the value becomes inaccessible.
    let worker = thread::spawn(move || {
        crate::stack_token!(token);
        assert!(sticky.try_get(token).is_err());
    });
    worker.join().unwrap();
}
318 
#[test]
fn test_mut() {
    let mut sticky = Sticky::new(true);
    crate::stack_token!(token);

    // Write through the mutable borrow, then observe the change.
    {
        let slot = sticky.get_mut(token);
        *slot = false;
    }

    assert_eq!(sticky.to_string(), "false");
    assert_eq!(sticky.get(token), &false);
}
327 
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let sticky = Sticky::new(true);
    // `get` on a foreign thread panics; `join().unwrap()` re-raises it here.
    let worker = thread::spawn(move || {
        crate::stack_token!(token);
        let _ = sticky.get(token);
    });
    worker.join().unwrap();
}
340 
#[test]
fn test_drop_same_thread() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    // Raises a shared flag when dropped.
    struct SetOnDrop(Arc<AtomicBool>);

    impl Drop for SetOnDrop {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let dropped = Arc::new(AtomicBool::new(false));
    let sticky = Sticky::new(SetOnDrop(Arc::clone(&dropped)));

    // Dropping on the owning thread must run the destructor immediately.
    mem::drop(sticky);
    assert!(dropped.load(Ordering::SeqCst));
}
356 
// Checks two things about a value with a destructor:
// - when the `Sticky` is dropped on a foreign thread while its owning thread
//   is still alive, the destructor does not run;
// - once the owning thread has finished, the destructor has run.
#[test]
fn test_noop_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    // Set to `true` by `X::drop`.
    let was_called = Arc::new(AtomicBool::new(false));

    {
        let was_called = was_called.clone();
        // Owning thread: creates the `Sticky`.
        thread::spawn(move || {
            struct X(Arc<AtomicBool>);
            impl Drop for X {
                fn drop(&mut self) {
                    self.0.store(true, Ordering::SeqCst);
                }
            }

            let val = Sticky::new(X(was_called.clone()));
            // Foreign thread: takes `val` by move and drops it when the
            // closure ends.
            assert!(thread::spawn(move || {
                // moves it here but does not deallocate
                crate::stack_token!(tok);
                val.try_get(tok).ok();
            })
            .join()
            .is_ok());

            // The owning thread is still running, so `X` must not have been
            // dropped yet.
            assert!(!was_called.load(Ordering::SeqCst));
        })
        .join()
        .unwrap();
    }

    // The owning thread has exited, and its teardown released the value.
    assert!(was_called.load(Ordering::SeqCst));
}
392 
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::thread;

    // `Rc` is neither `Send` nor `Sync`, yet the wrapper may still be moved;
    // the receiving thread just cannot reach the value.
    let sticky = Sticky::new(Rc::new(true));
    let outcome = thread::spawn(move || {
        crate::stack_token!(token);
        assert!(sticky.try_get(token).is_err());
    })
    .join();
    outcome.unwrap();
}
405 
#[test]
fn test_two_stickies() {
    // Has a (no-op) destructor, so `needs_drop` is true and each drop takes
    // the same-thread removal path.
    struct Wat;

    impl Drop for Wat {
        fn drop(&mut self) {}
    }

    let stickies = vec![Sticky::new(Wat), Sticky::new(Wat)];

    // Release them one at a time, in creation order; neither may panic.
    for sticky in stickies {
        drop(sticky);
    }
}
424