1 use futures::channel::oneshot;
2 use futures::executor::{block_on, LocalPool};
3 use futures::future::{self, FutureExt, LocalFutureObj, TryFutureExt};
4 use futures::task::LocalSpawn;
5 use std::cell::{Cell, RefCell};
6 use std::panic::AssertUnwindSafe;
7 use std::rc::Rc;
8 use std::task::Poll;
9 use std::thread;
10 
/// Handle to a shared counter that is bumped every time the handle is cloned.
///
/// All clones point at the same `Rc<Cell<i32>>`, so any handle can be used to
/// read how many clones have been made so far.
struct CountClone(Rc<Cell<i32>>);

impl Clone for CountClone {
    /// Increments the shared counter by one, then returns a new handle to it.
    fn clone(&self) -> Self {
        let Self(clone_count) = self;
        clone_count.set(clone_count.get() + 1);
        Self(Rc::clone(clone_count))
    }
}
19 
send_shared_oneshot_and_wait_on_multiple_threads(threads_number: u32)20 fn send_shared_oneshot_and_wait_on_multiple_threads(threads_number: u32) {
21     let (tx, rx) = oneshot::channel::<i32>();
22     let f = rx.shared();
23     let join_handles = (0..threads_number)
24         .map(|_| {
25             let cloned_future = f.clone();
26             thread::spawn(move || {
27                 assert_eq!(block_on(cloned_future).unwrap(), 6);
28             })
29         })
30         .collect::<Vec<_>>();
31 
32     tx.send(6).unwrap();
33 
34     assert_eq!(block_on(f).unwrap(), 6);
35     for join_handle in join_handles {
36         join_handle.join().unwrap();
37     }
38 }
39 
/// Runs the shared-oneshot scenario with a single spawned waiter thread.
#[test]
fn one_thread() {
    let waiting_threads = 1;
    send_shared_oneshot_and_wait_on_multiple_threads(waiting_threads);
}
44 
/// Runs the shared-oneshot scenario with two spawned waiter threads.
#[test]
fn two_threads() {
    let waiting_threads = 2;
    send_shared_oneshot_and_wait_on_multiple_threads(waiting_threads);
}
49 
/// Runs the shared-oneshot scenario with a thousand spawned waiter threads.
#[test]
fn many_threads() {
    let waiting_threads = 1000;
    send_shared_oneshot_and_wait_on_multiple_threads(waiting_threads);
}
54 
/// One handle to a shared future is dropped by a task that loses interest in
/// it; the remaining handle on another thread must still resolve once the
/// underlying value is sent.
#[test]
fn drop_on_one_task_ok() {
    let (value_tx, value_rx) = oneshot::channel::<u32>();
    let first_handle = value_rx.shared();
    let second_handle = first_handle.clone();

    let (cancel_tx, cancel_rx) = oneshot::channel::<u32>();
    let (forward_tx, forward_rx) = oneshot::channel::<u32>();

    // This thread races its shared handle against the cancellation channel and
    // drops whatever the race yields.
    let racing_thread = thread::spawn(move || {
        let shared_side = first_handle.map_err(|_| ());
        let cancel_side = cancel_rx.map_err(|_| ());
        drop(block_on(future::try_select(shared_side, cancel_side)));
    });

    // This thread forwards the shared future's value into the third channel.
    let forwarding_thread = thread::spawn(move || {
        let forwarded =
            second_handle.map_ok(|value| forward_tx.send(value).unwrap()).map_err(|_| ());
        let _ = block_on(forwarded);
    });

    cancel_tx.send(11).unwrap(); // cancel the first handle
    racing_thread.join().unwrap();

    // Should resolve the second handle, which in turn resolves `forward_rx`.
    value_tx.send(42).unwrap();
    let forwarded_value = block_on(forward_rx).unwrap();
    assert_eq!(forwarded_value, 42);
    forwarding_thread.join().unwrap();
}
82 
/// While the shared future is being polled, its own closure empties a slot
/// that holds another handle to the same shared future; the poll must still
/// complete with the closure's value.
#[test]
fn drop_in_poll() {
    let slot = Rc::new(RefCell::new(None));
    let slot_in_future = Rc::clone(&slot);

    let shared = future::lazy(move |_| {
        // Emptying the slot drops the boxed clone stored there.
        slot_in_future.replace(None);
        1
    })
    .shared();

    let boxed_clone = LocalFutureObj::new(Box::new(shared.clone()));
    slot.replace(Some(boxed_clone));

    let output = block_on(shared);
    assert_eq!(output, 1);
}
99 
/// `peek` yields nothing until the shared future has actually been polled to
/// completion, after which the value is visible through a clone.
#[test]
fn peek() {
    let mut pool = LocalPool::new();
    let spawner = pool.spawner();

    let (sender, receiver) = oneshot::channel::<i32>();
    let original = receiver.shared();
    let copy = original.clone();

    // Peeking repeatedly, through either handle, keeps returning nothing.
    for _ in 0..2 {
        assert!(original.peek().is_none());
        assert!(copy.peek().is_none());
    }

    // Sending the value alone changes nothing: it has not been `poll`ed in yet.
    sender.send(42).unwrap();
    for _ in 0..2 {
        assert!(original.peek().is_none());
        assert!(copy.peek().is_none());
    }

    // Drive the original handle on the local pool; afterwards the clone can peek the value.
    let driver = original.map(|_| ());
    spawner.spawn_local_obj(LocalFutureObj::new(Box::new(driver))).unwrap();
    pool.run();
    for _ in 0..2 {
        let seen = copy.peek().unwrap();
        assert_eq!(*seen, Ok(42));
    }
}
129 
/// Exercises `downgrade`/`upgrade`: both succeed while some `Shared` handle is
/// still outstanding and both fail once every handle has been driven to
/// completion.
#[test]
fn downgrade() {
    let (sender, receiver) = oneshot::channel::<i32>();
    let strong = receiver.shared();
    // A `Shared` is outstanding, so a `WeakShared` can be obtained from it.
    let weak = strong.downgrade().unwrap();
    // Upgrading works at this point.
    let mut upgraded = weak.upgrade().unwrap();

    sender.send(42).unwrap();
    let first_output = block_on(strong).unwrap();
    assert_eq!(first_output, 42);

    // `upgraded` is still outstanding, so new weak handles can be created
    // and the existing one can still be upgraded.
    assert!(upgraded.downgrade().is_some());
    assert!(weak.upgrade().is_some());

    let second_output = block_on(&mut upgraded).unwrap();
    assert_eq!(second_output, 42);
    // Every `Shared` has now been exhausted: neither a fresh `WeakShared`
    // nor an upgrade of the existing one is available any more.
    assert!(weak.upgrade().is_none());
    assert!(upgraded.downgrade().is_none());
}
153 
/// Checks `ptr_eq` and `ptr_hash`: clones of one `Shared` compare equal and
/// hash identically, a completed handle no longer compares equal, and
/// handles built from different futures never compare equal.
#[test]
fn ptr_eq() {
    use future::FusedFuture;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::Hasher;

    let (sender, receiver) = oneshot::channel::<i32>();
    let first = receiver.shared();
    let mut second = first.clone();
    let mut first_hasher = DefaultHasher::new();
    let mut second_hasher = DefaultHasher::new();

    // Both handles wrap the same underlying future, so they are
    // pointer-equal, and the relation holds in both directions.
    assert!(first.ptr_eq(&second));
    assert!(second.ptr_eq(&first));

    // Pointer-equal handles must produce the same hash.
    first.ptr_hash(&mut first_hasher);
    second.ptr_hash(&mut second_hasher);
    assert_eq!(first_hasher.finish(), second_hasher.finish());

    sender.send(42).unwrap();
    let output = block_on(&mut second).unwrap();
    assert_eq!(output, 42);

    // `second` has completed, so it no longer compares equal to `first`.
    assert!(second.is_terminated());
    assert!(!first.ptr_eq(&second));

    // The handle that has not completed keeps working with fresh clones.
    let third = first.clone();
    let mut third_hasher = DefaultHasher::new();
    assert!(first.ptr_eq(&third));

    third.ptr_hash(&mut third_hasher);
    assert_eq!(first_hasher.finish(), third_hasher.finish());

    let (_sender, receiver) = oneshot::channel::<i32>();
    let unrelated = receiver.shared();

    // Handles that do not share an underlying future are never pointer-equal.
    assert!(!first.ptr_eq(&unrelated));
}
199 
/// With only one `Shared` handle in existence, resolving it must hand the
/// output over without cloning it.
#[test]
fn dont_clone_in_single_owner_shared_future() {
    let payload = CountClone(Rc::new(Cell::new(0)));
    let (sender, receiver) = oneshot::channel();

    let only_handle = receiver.shared();

    sender.send(payload).ok().unwrap();

    let CountClone(clone_count) = block_on(only_handle).unwrap();
    assert_eq!(clone_count.get(), 0);
}
211 
/// Each poll through a cloned handle costs exactly one clone of the output,
/// while the final poll that consumes the last handle costs none.
#[test]
fn dont_do_unnecessary_clones_on_output() {
    let payload = CountClone(Rc::new(Cell::new(0)));
    let (sender, receiver) = oneshot::channel();

    let shared = receiver.shared();

    sender.send(payload).ok().unwrap();

    // Two polls via temporary clones: the counter goes to 1, then 2.
    for expected_clones in 1..=2 {
        assert_eq!(block_on(shared.clone()).unwrap().0.get(), expected_clones);
    }
    // Consuming the last handle leaves the counter where it was.
    assert_eq!(block_on(shared).unwrap().0.get(), 2);
}
225 
/// A shared future that wakes itself and returns `Pending` until a flag is
/// set must still let the sibling future in a `join` run and set that flag.
#[test]
fn shared_future_that_wakes_itself_until_pending_is_returned() {
    let proceed = Cell::new(false);
    let self_waking = futures::future::poll_fn(|cx| {
        if !proceed.get() {
            // Not allowed to finish yet: ask to be polled again right away.
            cx.waker().wake_by_ref();
            return Poll::Pending;
        }
        Poll::Ready(())
    })
    .shared();

    // The join only finishes if its second half gets to run after the first
    // half has returned `Pending`.
    let unblock = async { proceed.set(true) };
    let joined = futures::future::join(self_waking, unblock);
    assert_eq!(block_on(joined), ((), ()));
}
243 
/// The inner future panics when polled. The first panic is caught; polling
/// the remaining handle afterwards is expected to panic with the message
/// named in `should_panic`.
#[test]
#[should_panic(expected = "inner future panicked during poll")]
fn panic_while_poll() {
    let shared = futures::future::poll_fn::<i8, _>(|_cx| panic!("test")).shared();

    let first_handle = shared.clone();
    // Poll through a clone and swallow the panic raised by the inner future.
    let first_attempt = std::panic::catch_unwind(AssertUnwindSafe(move || {
        block_on(first_handle);
    }));
    first_attempt.unwrap_err();

    // Polling the other handle now is what the test expects to panic.
    block_on(shared);
}
257 
/// Creates and polls a `Shared` from inside a `Drop` impl that runs while the
/// thread is unwinding; the test expects the original `test_marker` panic to
/// be the one reported.
#[test]
#[should_panic(expected = "test_marker")]
fn poll_while_panic() {
    struct PollOnDrop;

    impl Drop for PollOnDrop {
        fn drop(&mut self) {
            let shared = futures::future::ready(1).shared();
            let via_clone = block_on(shared.clone());
            assert_eq!(via_clone, 1);
            let via_original = block_on(shared);
            assert_eq!(via_original, 1);
        }
    }

    // Dropped during the unwind triggered by the panic below.
    let _guard = PollOnDrop {};
    panic!("test_marker");
}
274