1 use std::cmp;
2 use std::fmt;
3 use std::mem;
4 use std::num::NonZeroUsize;
5 
6 use crate::errors::InvalidThreadAccess;
7 use crate::thread_id;
8 use std::mem::ManuallyDrop;
9 
/// A [`Fragile<T>`] wraps a non-`Send` `T` so that it can be safely sent to
/// other threads.
///
/// Once the value has been wrapped it can be sent to other threads, but access
/// to the value on those threads will fail.
///
/// If the value needs destruction and the fragile wrapper is on another thread
/// the destructor will panic.  Alternatively you can use
/// [`Sticky`](crate::Sticky) which is not going to panic but might temporarily
/// leak the value.
pub struct Fragile<T> {
    // ManuallyDrop is necessary because we need to move out of here without running the
    // Drop code in functions like `into_inner`.
    value: ManuallyDrop<T>,
    // Identifier of the thread that created this wrapper; every access to
    // `value` is compared against it.
    thread_id: NonZeroUsize,
}
25 
26 impl<T> Fragile<T> {
27     /// Creates a new [`Fragile`] wrapping a `value`.
28     ///
29     /// The value that is moved into the [`Fragile`] can be non `Send` and
30     /// will be anchored to the thread that created the object.  If the
31     /// fragile wrapper type ends up being send from thread to thread
32     /// only the original thread can interact with the value.
new(value: T) -> Self33     pub fn new(value: T) -> Self {
34         Fragile {
35             value: ManuallyDrop::new(value),
36             thread_id: thread_id::get(),
37         }
38     }
39 
40     /// Returns `true` if the access is valid.
41     ///
42     /// This will be `false` if the value was sent to another thread.
is_valid(&self) -> bool43     pub fn is_valid(&self) -> bool {
44         thread_id::get() == self.thread_id
45     }
46 
47     #[inline(always)]
assert_thread(&self)48     fn assert_thread(&self) {
49         if !self.is_valid() {
50             panic!("trying to access wrapped value in fragile container from incorrect thread.");
51         }
52     }
53 
54     /// Consumes the `Fragile`, returning the wrapped value.
55     ///
56     /// # Panics
57     ///
58     /// Panics if called from a different thread than the one where the
59     /// original value was created.
into_inner(self) -> T60     pub fn into_inner(self) -> T {
61         self.assert_thread();
62 
63         let mut this = ManuallyDrop::new(self);
64 
65         // SAFETY: `this` is not accessed beyond this point, and because it's in a ManuallyDrop its
66         // destructor is not run.
67         unsafe { ManuallyDrop::take(&mut this.value) }
68     }
69 
70     /// Consumes the `Fragile`, returning the wrapped value if successful.
71     ///
72     /// The wrapped value is returned if this is called from the same thread
73     /// as the one where the original value was created, otherwise the
74     /// [`Fragile`] is returned as `Err(self)`.
try_into_inner(self) -> Result<T, Self>75     pub fn try_into_inner(self) -> Result<T, Self> {
76         if thread_id::get() == self.thread_id {
77             Ok(self.into_inner())
78         } else {
79             Err(self)
80         }
81     }
82 
83     /// Immutably borrows the wrapped value.
84     ///
85     /// # Panics
86     ///
87     /// Panics if the calling thread is not the one that wrapped the value.
88     /// For a non-panicking variant, use [`try_get`](Self::try_get).
get(&self) -> &T89     pub fn get(&self) -> &T {
90         self.assert_thread();
91         &*self.value
92     }
93 
94     /// Mutably borrows the wrapped value.
95     ///
96     /// # Panics
97     ///
98     /// Panics if the calling thread is not the one that wrapped the value.
99     /// For a non-panicking variant, use [`try_get_mut`](Self::try_get_mut).
get_mut(&mut self) -> &mut T100     pub fn get_mut(&mut self) -> &mut T {
101         self.assert_thread();
102         &mut *self.value
103     }
104 
105     /// Tries to immutably borrow the wrapped value.
106     ///
107     /// Returns `None` if the calling thread is not the one that wrapped the value.
try_get(&self) -> Result<&T, InvalidThreadAccess>108     pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
109         if thread_id::get() == self.thread_id {
110             Ok(&*self.value)
111         } else {
112             Err(InvalidThreadAccess)
113         }
114     }
115 
116     /// Tries to mutably borrow the wrapped value.
117     ///
118     /// Returns `None` if the calling thread is not the one that wrapped the value.
try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess>119     pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
120         if thread_id::get() == self.thread_id {
121             Ok(&mut *self.value)
122         } else {
123             Err(InvalidThreadAccess)
124         }
125     }
126 }
127 
128 impl<T> Drop for Fragile<T> {
drop(&mut self)129     fn drop(&mut self) {
130         if mem::needs_drop::<T>() {
131             if thread_id::get() == self.thread_id {
132                 // SAFETY: `ManuallyDrop::drop` cannot be called after this point.
133                 unsafe { ManuallyDrop::drop(&mut self.value) };
134             } else {
135                 panic!("destructor of fragile object ran on wrong thread");
136             }
137         }
138     }
139 }
140 
141 impl<T> From<T> for Fragile<T> {
142     #[inline]
from(t: T) -> Fragile<T>143     fn from(t: T) -> Fragile<T> {
144         Fragile::new(t)
145     }
146 }
147 
148 impl<T: Clone> Clone for Fragile<T> {
149     #[inline]
clone(&self) -> Fragile<T>150     fn clone(&self) -> Fragile<T> {
151         Fragile::new(self.get().clone())
152     }
153 }
154 
155 impl<T: Default> Default for Fragile<T> {
156     #[inline]
default() -> Fragile<T>157     fn default() -> Fragile<T> {
158         Fragile::new(T::default())
159     }
160 }
161 
162 impl<T: PartialEq> PartialEq for Fragile<T> {
163     #[inline]
eq(&self, other: &Fragile<T>) -> bool164     fn eq(&self, other: &Fragile<T>) -> bool {
165         *self.get() == *other.get()
166     }
167 }
168 
169 impl<T: Eq> Eq for Fragile<T> {}
170 
171 impl<T: PartialOrd> PartialOrd for Fragile<T> {
172     #[inline]
partial_cmp(&self, other: &Fragile<T>) -> Option<cmp::Ordering>173     fn partial_cmp(&self, other: &Fragile<T>) -> Option<cmp::Ordering> {
174         self.get().partial_cmp(other.get())
175     }
176 
177     #[inline]
lt(&self, other: &Fragile<T>) -> bool178     fn lt(&self, other: &Fragile<T>) -> bool {
179         *self.get() < *other.get()
180     }
181 
182     #[inline]
le(&self, other: &Fragile<T>) -> bool183     fn le(&self, other: &Fragile<T>) -> bool {
184         *self.get() <= *other.get()
185     }
186 
187     #[inline]
gt(&self, other: &Fragile<T>) -> bool188     fn gt(&self, other: &Fragile<T>) -> bool {
189         *self.get() > *other.get()
190     }
191 
192     #[inline]
ge(&self, other: &Fragile<T>) -> bool193     fn ge(&self, other: &Fragile<T>) -> bool {
194         *self.get() >= *other.get()
195     }
196 }
197 
198 impl<T: Ord> Ord for Fragile<T> {
199     #[inline]
cmp(&self, other: &Fragile<T>) -> cmp::Ordering200     fn cmp(&self, other: &Fragile<T>) -> cmp::Ordering {
201         self.get().cmp(other.get())
202     }
203 }
204 
205 impl<T: fmt::Display> fmt::Display for Fragile<T> {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>206     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
207         fmt::Display::fmt(self.get(), f)
208     }
209 }
210 
211 impl<T: fmt::Debug> fmt::Debug for Fragile<T> {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>212     fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
213         match self.try_get() {
214             Ok(value) => f.debug_struct("Fragile").field("value", value).finish(),
215             Err(..) => {
216                 struct InvalidPlaceholder;
217                 impl fmt::Debug for InvalidPlaceholder {
218                     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
219                         f.write_str("<invalid thread>")
220                     }
221                 }
222 
223                 f.debug_struct("Fragile")
224                     .field("value", &InvalidPlaceholder)
225                     .finish()
226             }
227         }
228     }
229 }
230 
231 // this type is sync because access can only ever happy from the same thread
232 // that created it originally.  All other threads will be able to safely
233 // call some basic operations on the reference and they will fail.
234 unsafe impl<T> Sync for Fragile<T> {}
235 
236 // The entire point of this type is to be Send
237 #[allow(clippy::non_send_fields_in_send_ty)]
238 unsafe impl<T> Send for Fragile<T> {}
239 
#[test]
fn test_basic() {
    use std::thread;

    let wrapped = Fragile::new(true);

    // On the creating thread every accessor succeeds.
    assert_eq!(wrapped.to_string(), "true");
    assert_eq!(wrapped.get(), &true);
    assert!(wrapped.try_get().is_ok());

    // After moving to another thread, non-panicking access reports an error.
    let handle = thread::spawn(move || {
        assert!(wrapped.try_get().is_err());
    });
    handle.join().unwrap();
}
253 
#[test]
fn test_mut() {
    let mut flag = Fragile::new(true);

    // Mutate through the wrapper, then observe the change via both readers.
    let slot = flag.get_mut();
    *slot = false;

    assert_eq!(flag.get(), &false);
    assert_eq!(flag.to_string(), "false");
}
261 
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let wrapped = Fragile::new(true);

    // `get` panics on the foreign thread, `join` then yields `Err`, and the
    // `unwrap` below carries that failure into the test harness.
    let outcome = thread::spawn(move || {
        wrapped.get();
    })
    .join();
    outcome.unwrap();
}
273 
#[test]
fn test_noop_drop_elsewhere() {
    use std::thread;

    // `bool` has no drop glue, so dropping the wrapper on a foreign thread
    // must not panic.
    let wrapped = Fragile::new(true);

    let handle = thread::spawn(move || {
        // Mention the wrapper so the closure captures (moves) it; the result
        // itself is irrelevant.
        let _ = wrapped.try_get();
    });
    handle.join().unwrap();
}
285 
#[test]
fn test_panic_on_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    // Sets the shared flag when dropped, so the test can observe whether the
    // destructor ever ran.
    struct DropFlag(Arc<AtomicBool>);

    impl Drop for DropFlag {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let dropped = Arc::new(AtomicBool::new(false));
    let wrapped = Fragile::new(DropFlag(Arc::clone(&dropped)));

    // Dropping a value with drop glue on a foreign thread must panic...
    let outcome = thread::spawn(move || {
        let _ = wrapped.try_get();
    })
    .join();
    assert!(outcome.is_err());

    // ...and must not have run the wrapped value's destructor.
    assert!(!dropped.load(Ordering::SeqCst));
}
306 
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::sync::mpsc::channel;
    use std::thread;

    // `Rc` is not `Send`, but wrapped in `Fragile` it can travel to another
    // thread and back; only the creating thread may read it.
    let wrapped = Fragile::new(Rc::new(true));
    let (sender, receiver) = channel();

    let worker = thread::spawn(move || {
        assert!(wrapped.try_get().is_err());
        sender.send(wrapped).unwrap();
    });

    let returned = receiver.recv().unwrap();
    assert!(**returned.get());

    worker.join().unwrap();
}
327