1 use std::cmp;
2 use std::fmt;
3 use std::mem;
4 use std::num::NonZeroUsize;
5
6 use crate::errors::InvalidThreadAccess;
7 use crate::thread_id;
8 use std::mem::ManuallyDrop;
9
/// A [`Fragile<T>`] wraps a non-sendable `T` so that it can be safely sent to
/// other threads.
///
/// Once the value has been wrapped, the wrapper can be sent to other threads,
/// but any access to the value from those threads will fail.
///
/// If the value needs destruction and the fragile wrapper is dropped on
/// another thread, the destructor will panic. Alternatively you can use
/// [`Sticky`](crate::Sticky), which is not going to panic but might
/// temporarily leak the value.
pub struct Fragile<T> {
    // Kept in a `ManuallyDrop` because functions like `into_inner` need to move
    // the value out of the wrapper without running the `Drop` code on it.
    value: ManuallyDrop<T>,
    // Identifier of the thread that created the wrapper; every access compares
    // the calling thread's identifier against it.
    thread_id: NonZeroUsize,
}
25
26 impl<T> Fragile<T> {
27 /// Creates a new [`Fragile`] wrapping a `value`.
28 ///
29 /// The value that is moved into the [`Fragile`] can be non `Send` and
30 /// will be anchored to the thread that created the object. If the
31 /// fragile wrapper type ends up being send from thread to thread
32 /// only the original thread can interact with the value.
new(value: T) -> Self33 pub fn new(value: T) -> Self {
34 Fragile {
35 value: ManuallyDrop::new(value),
36 thread_id: thread_id::get(),
37 }
38 }
39
40 /// Returns `true` if the access is valid.
41 ///
42 /// This will be `false` if the value was sent to another thread.
is_valid(&self) -> bool43 pub fn is_valid(&self) -> bool {
44 thread_id::get() == self.thread_id
45 }
46
47 #[inline(always)]
assert_thread(&self)48 fn assert_thread(&self) {
49 if !self.is_valid() {
50 panic!("trying to access wrapped value in fragile container from incorrect thread.");
51 }
52 }
53
54 /// Consumes the `Fragile`, returning the wrapped value.
55 ///
56 /// # Panics
57 ///
58 /// Panics if called from a different thread than the one where the
59 /// original value was created.
into_inner(self) -> T60 pub fn into_inner(self) -> T {
61 self.assert_thread();
62
63 let mut this = ManuallyDrop::new(self);
64
65 // SAFETY: `this` is not accessed beyond this point, and because it's in a ManuallyDrop its
66 // destructor is not run.
67 unsafe { ManuallyDrop::take(&mut this.value) }
68 }
69
70 /// Consumes the `Fragile`, returning the wrapped value if successful.
71 ///
72 /// The wrapped value is returned if this is called from the same thread
73 /// as the one where the original value was created, otherwise the
74 /// [`Fragile`] is returned as `Err(self)`.
try_into_inner(self) -> Result<T, Self>75 pub fn try_into_inner(self) -> Result<T, Self> {
76 if thread_id::get() == self.thread_id {
77 Ok(self.into_inner())
78 } else {
79 Err(self)
80 }
81 }
82
83 /// Immutably borrows the wrapped value.
84 ///
85 /// # Panics
86 ///
87 /// Panics if the calling thread is not the one that wrapped the value.
88 /// For a non-panicking variant, use [`try_get`](Self::try_get).
get(&self) -> &T89 pub fn get(&self) -> &T {
90 self.assert_thread();
91 &*self.value
92 }
93
94 /// Mutably borrows the wrapped value.
95 ///
96 /// # Panics
97 ///
98 /// Panics if the calling thread is not the one that wrapped the value.
99 /// For a non-panicking variant, use [`try_get_mut`](Self::try_get_mut).
get_mut(&mut self) -> &mut T100 pub fn get_mut(&mut self) -> &mut T {
101 self.assert_thread();
102 &mut *self.value
103 }
104
105 /// Tries to immutably borrow the wrapped value.
106 ///
107 /// Returns `None` if the calling thread is not the one that wrapped the value.
try_get(&self) -> Result<&T, InvalidThreadAccess>108 pub fn try_get(&self) -> Result<&T, InvalidThreadAccess> {
109 if thread_id::get() == self.thread_id {
110 Ok(&*self.value)
111 } else {
112 Err(InvalidThreadAccess)
113 }
114 }
115
116 /// Tries to mutably borrow the wrapped value.
117 ///
118 /// Returns `None` if the calling thread is not the one that wrapped the value.
try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess>119 pub fn try_get_mut(&mut self) -> Result<&mut T, InvalidThreadAccess> {
120 if thread_id::get() == self.thread_id {
121 Ok(&mut *self.value)
122 } else {
123 Err(InvalidThreadAccess)
124 }
125 }
126 }
127
128 impl<T> Drop for Fragile<T> {
drop(&mut self)129 fn drop(&mut self) {
130 if mem::needs_drop::<T>() {
131 if thread_id::get() == self.thread_id {
132 // SAFETY: `ManuallyDrop::drop` cannot be called after this point.
133 unsafe { ManuallyDrop::drop(&mut self.value) };
134 } else {
135 panic!("destructor of fragile object ran on wrong thread");
136 }
137 }
138 }
139 }
140
141 impl<T> From<T> for Fragile<T> {
142 #[inline]
from(t: T) -> Fragile<T>143 fn from(t: T) -> Fragile<T> {
144 Fragile::new(t)
145 }
146 }
147
148 impl<T: Clone> Clone for Fragile<T> {
149 #[inline]
clone(&self) -> Fragile<T>150 fn clone(&self) -> Fragile<T> {
151 Fragile::new(self.get().clone())
152 }
153 }
154
155 impl<T: Default> Default for Fragile<T> {
156 #[inline]
default() -> Fragile<T>157 fn default() -> Fragile<T> {
158 Fragile::new(T::default())
159 }
160 }
161
162 impl<T: PartialEq> PartialEq for Fragile<T> {
163 #[inline]
eq(&self, other: &Fragile<T>) -> bool164 fn eq(&self, other: &Fragile<T>) -> bool {
165 *self.get() == *other.get()
166 }
167 }
168
169 impl<T: Eq> Eq for Fragile<T> {}
170
171 impl<T: PartialOrd> PartialOrd for Fragile<T> {
172 #[inline]
partial_cmp(&self, other: &Fragile<T>) -> Option<cmp::Ordering>173 fn partial_cmp(&self, other: &Fragile<T>) -> Option<cmp::Ordering> {
174 self.get().partial_cmp(other.get())
175 }
176
177 #[inline]
lt(&self, other: &Fragile<T>) -> bool178 fn lt(&self, other: &Fragile<T>) -> bool {
179 *self.get() < *other.get()
180 }
181
182 #[inline]
le(&self, other: &Fragile<T>) -> bool183 fn le(&self, other: &Fragile<T>) -> bool {
184 *self.get() <= *other.get()
185 }
186
187 #[inline]
gt(&self, other: &Fragile<T>) -> bool188 fn gt(&self, other: &Fragile<T>) -> bool {
189 *self.get() > *other.get()
190 }
191
192 #[inline]
ge(&self, other: &Fragile<T>) -> bool193 fn ge(&self, other: &Fragile<T>) -> bool {
194 *self.get() >= *other.get()
195 }
196 }
197
198 impl<T: Ord> Ord for Fragile<T> {
199 #[inline]
cmp(&self, other: &Fragile<T>) -> cmp::Ordering200 fn cmp(&self, other: &Fragile<T>) -> cmp::Ordering {
201 self.get().cmp(other.get())
202 }
203 }
204
205 impl<T: fmt::Display> fmt::Display for Fragile<T> {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>206 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
207 fmt::Display::fmt(self.get(), f)
208 }
209 }
210
211 impl<T: fmt::Debug> fmt::Debug for Fragile<T> {
fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error>212 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
213 match self.try_get() {
214 Ok(value) => f.debug_struct("Fragile").field("value", value).finish(),
215 Err(..) => {
216 struct InvalidPlaceholder;
217 impl fmt::Debug for InvalidPlaceholder {
218 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
219 f.write_str("<invalid thread>")
220 }
221 }
222
223 f.debug_struct("Fragile")
224 .field("value", &InvalidPlaceholder)
225 .finish()
226 }
227 }
228 }
229 }
230
231 // this type is sync because access can only ever happy from the same thread
232 // that created it originally. All other threads will be able to safely
233 // call some basic operations on the reference and they will fail.
234 unsafe impl<T> Sync for Fragile<T> {}
235
236 // The entire point of this type is to be Send
237 #[allow(clippy::non_send_fields_in_send_ty)]
238 unsafe impl<T> Send for Fragile<T> {}
239
/// Accessors succeed on the creating thread; `try_get` fails on another one.
#[test]
fn test_basic() {
    use std::thread;

    let wrapped = Fragile::new(true);

    // On the creating thread every accessor works.
    assert_eq!(wrapped.to_string(), "true");
    assert_eq!(wrapped.get(), &true);
    assert!(wrapped.try_get().is_ok());

    // After moving to another thread the fallible accessor reports an error.
    let worker = thread::spawn(move || {
        assert!(wrapped.try_get().is_err());
    });
    worker.join().unwrap();
}
253
/// A write through `get_mut` is visible through `Display` and `get`.
#[test]
fn test_mut() {
    let mut wrapped = Fragile::new(true);

    let slot = wrapped.get_mut();
    *slot = false;

    assert_eq!(wrapped.to_string(), "false");
    assert_eq!(wrapped.get(), &false);
}
261
/// `get` panics when called on a thread that did not create the wrapper.
#[test]
#[should_panic]
fn test_access_other_thread() {
    use std::thread;

    let wrapped = Fragile::new(true);
    let worker = thread::spawn(move || {
        wrapped.get();
    });

    // Unwrapping the join result re-raises the worker's panic on this thread,
    // which is what `should_panic` expects.
    worker.join().unwrap();
}
273
/// Dropping a wrapper around a `bool` on another thread does not panic.
#[test]
fn test_noop_drop_elsewhere() {
    use std::thread;

    let wrapped = Fragile::new(true);
    let worker = thread::spawn(move || {
        // Touch the wrapper so that it is moved into (and dropped by) this thread.
        wrapped.try_get().ok();
    });

    worker.join().unwrap();
}
285
/// Dropping a wrapper whose value has a destructor on another thread panics,
/// and the value's destructor is never run.
#[test]
fn test_panic_on_drop_elsewhere() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;
    use std::thread;

    // Records in the shared flag whether its destructor has run.
    struct X(Arc<AtomicBool>);

    impl Drop for X {
        fn drop(&mut self) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    let was_called = Arc::new(AtomicBool::new(false));
    let val = Fragile::new(X(was_called.clone()));

    let outcome = thread::spawn(move || {
        val.try_get().ok();
    })
    .join();

    // The worker must have panicked while dropping the wrapper ...
    assert!(outcome.is_err());
    // ... and the inner destructor must not have been invoked.
    assert!(!was_called.load(Ordering::SeqCst));
}
306
/// A wrapped `Rc` can make a round trip through another thread and is usable
/// again once it is back on the creating thread.
#[test]
fn test_rc_sending() {
    use std::rc::Rc;
    use std::sync::mpsc::channel;
    use std::thread;

    let (sender, receiver) = channel();
    let wrapped = Fragile::new(Rc::new(true));

    let worker = thread::spawn(move || {
        // The foreign thread cannot look inside ...
        assert!(wrapped.try_get().is_err());
        // ... but it can pass the wrapper back to where it came from.
        sender.send(wrapped).unwrap();
    });

    let returned = receiver.recv().unwrap();
    assert!(**returned.get());

    worker.join().unwrap();
}
327