1 use std::{
2     cmp,
3     io::{self, Read as _},
4     iter,
5 };
6 
7 use rand::{Rng as _, RngCore as _};
8 
9 use super::decoder::{DecoderReader, BUF_SIZE};
10 use crate::{
11     alphabet,
12     engine::{general_purpose::STANDARD, Engine, GeneralPurpose},
13     tests::{random_alphabet, random_config, random_engine},
14     DecodeError, PAD_BYTE,
15 };
16 
// Decodes a table of known-good encodings, reading `n` bytes at a time for every `n` from 1 up to
// the encoded length, and checks the concatenated output against the expected text.
#[test]
fn simple() {
    let tests: &[(&[u8], &[u8])] = &[
        (&b"0"[..], &b"MA=="[..]),
        (b"01", b"MDE="),
        (b"012", b"MDEy"),
        (b"0123", b"MDEyMw=="),
        (b"01234", b"MDEyMzQ="),
        (b"012345", b"MDEyMzQ1"),
        (b"0123456", b"MDEyMzQ1Ng=="),
        (b"01234567", b"MDEyMzQ1Njc="),
        (b"012345678", b"MDEyMzQ1Njc4"),
        (b"0123456789", b"MDEyMzQ1Njc4OQ=="),
    ][..];

    for (text_expected, base64data) in tests.iter() {
        // Read n bytes at a time.
        for n in 1..=base64data.len() {
            let mut wrapped_reader = io::Cursor::new(base64data);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, &STANDARD);

            let mut text_got = Vec::new();
            let mut buffer = vec![0u8; n];
            loop {
                // Every input here is valid, so a read error is a decoder bug: surface it instead
                // of silently treating it as end of stream.
                let read = decoder
                    .read(&mut buffer[..])
                    .expect("decoding valid base64 should not fail");
                if read == 0 {
                    break;
                }
                text_got.extend_from_slice(&buffer[..read]);
            }

            assert_eq!(
                text_got,
                *text_expected,
                "\nGot: {}\nExpected: {}",
                String::from_utf8_lossy(&text_got[..]),
                String::from_utf8_lossy(text_expected)
            );
        }
    }
}
58 
// Inputs with garbage after otherwise-valid base64 must make the stream decoder fail, whatever
// size of destination buffer the caller reads into.
#[test]
fn trailing_junk() {
    let tests: &[&[u8]] = &[&b"MDEyMzQ1Njc4*!@#$%^&"[..], b"MDEyMzQ1Njc4OQ== "][..];

    for base64data in tests {
        // Try every read size from a single byte up to the whole encoded length.
        for n in 1..=base64data.len() {
            let mut wrapped_reader = io::Cursor::new(base64data);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, &STANDARD);

            let mut buffer = vec![0u8; n];
            // Keep draining until the decoder either errors (expected) or reports a clean EOF.
            let saw_error = loop {
                match decoder.read(&mut buffer[..]) {
                    Ok(0) => break false,
                    Ok(_) => continue,
                    Err(_) => break true,
                }
            };

            assert!(saw_error);
        }
    }
}
88 
// Round-trips random data through a randomly configured engine while the underlying source hands
// the decoder only a few bytes per read call; the decoded output must match the original exactly.
#[test]
fn handles_short_read_from_delegate() {
    let mut rng = rand::thread_rng();
    let mut original = Vec::new();
    let mut encoded = String::new();
    let mut roundtrip = Vec::new();

    for _ in 0..10_000 {
        original.clear();
        encoded.clear();
        roundtrip.clear();

        let size = rng.gen_range(0..(10 * BUF_SIZE));
        original.extend(iter::repeat(0).take(size));
        rng.fill_bytes(&mut original[..size]);
        assert_eq!(size, original.len());

        let engine = random_engine(&mut rng);
        engine.encode_string(&original[..], &mut encoded);

        // The decoder pulls from a source that yields at most a handful of bytes per call.
        let mut cursor = io::Cursor::new(encoded.as_bytes());
        let mut throttled = RandomShortRead {
            delegate: &mut cursor,
            rng: &mut rng,
        };
        let mut decoder = DecoderReader::new(&mut throttled, &engine);

        let decoded_len = decoder.read_to_end(&mut roundtrip).unwrap();
        assert_eq!(size, decoded_len);
        assert_eq!(&original[..], &roundtrip[..]);
    }
}
123 
// Decodes random data while the caller asks for randomly sized (often small) chunks at a time.
#[test]
fn read_in_short_increments() {
    let mut rng = rand::thread_rng();
    let mut plain = Vec::new();
    let mut encoded = String::new();
    let mut output = Vec::new();

    for _ in 0..10_000 {
        plain.clear();
        encoded.clear();
        output.clear();

        let size = rng.gen_range(0..(10 * BUF_SIZE));
        plain.extend(iter::repeat(0).take(size));
        // Oversize the destination so individual reads may request more than is left to decode.
        output.extend(iter::repeat(0).take(size * 3));

        rng.fill_bytes(&mut plain[..]);
        assert_eq!(size, plain.len());

        let engine = random_engine(&mut rng);
        engine.encode_string(&plain[..], &mut encoded);

        let mut source = io::Cursor::new(&encoded[..]);
        let mut decoder = DecoderReader::new(&mut source, &engine);

        consume_with_short_reads_and_validate(&mut rng, &plain[..], &mut output, &mut decoder);
    }
}
154 
// Like `read_in_short_increments`, but the decoder is additionally wrapped in a reader that
// truncates each request to a random small length.
#[test]
fn read_in_short_increments_with_short_delegate_reads() {
    let mut rng = rand::thread_rng();
    let mut plain = Vec::new();
    let mut encoded = String::new();
    let mut output = Vec::new();

    for _ in 0..10_000 {
        plain.clear();
        encoded.clear();
        output.clear();

        let size = rng.gen_range(0..(10 * BUF_SIZE));
        plain.extend(iter::repeat(0).take(size));
        // Oversize the destination so individual reads may request more than is left to decode.
        output.extend(iter::repeat(0).take(size * 3));

        rng.fill_bytes(&mut plain[..]);
        assert_eq!(size, plain.len());

        let engine = random_engine(&mut rng);
        engine.encode_string(&plain[..], &mut encoded);

        let mut base_reader = io::Cursor::new(&encoded[..]);
        let mut decoder = DecoderReader::new(&mut base_reader, &engine);
        // `rng` is mutably borrowed by the consume helper below, so the wrapper needs its own.
        let mut delegate_rng = rand::thread_rng();
        let mut short_reader = RandomShortRead {
            delegate: &mut decoder,
            rng: &mut delegate_rng,
        };

        consume_with_short_reads_and_validate(&mut rng, &plain[..], &mut output, &mut short_reader);
    }
}
194 
// Substitutes every alphabet symbol in turn as the final character of an unpadded encoding and
// checks that streaming decoding returns exactly what bulk `decode_vec` returns (Ok or the same
// DecodeError).
#[test]
fn reports_invalid_last_symbol_correctly() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut b64_bytes = Vec::new();
    let mut decoded = Vec::new();
    let mut bulk_decoded = Vec::new();

    for _ in 0..1_000 {
        bytes.clear();
        b64.clear();
        b64_bytes.clear();

        let size = rng.gen_range(1..(10 * BUF_SIZE));
        bytes.extend(iter::repeat(0).take(size));
        rng.fill_bytes(&mut bytes[..]);
        assert_eq!(size, bytes.len());

        let config = random_config(&mut rng);
        let alphabet = random_alphabet(&mut rng);
        // Encode without padding so the final byte is always a data symbol; with padding,
        // overwriting the trailing `=` would produce invalid-padding errors instead.
        let engine = GeneralPurpose::new(alphabet, config.with_encode_padding(false));
        engine.encode_string(&bytes[..], &mut b64);
        b64_bytes.extend(b64.bytes());
        assert_eq!(b64_bytes.len(), b64.len());

        // Change the last character to every possible symbol. The stream must behave the same as
        // bulk decoding whether the result is valid or invalid.
        for &s1 in alphabet.symbols.iter() {
            // `read_to_end` appends, so start each attempt from an empty output.
            decoded.clear();
            bulk_decoded.clear();

            *b64_bytes.last_mut().unwrap() = s1;
            let bulk_res = engine.decode_vec(&b64_bytes[..], &mut bulk_decoded);

            let mut wrapped_reader = io::Cursor::new(&b64_bytes[..]);
            let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine);

            // Unwrap the io::Error back into the DecodeError it carries, if any.
            let stream_res = decoder.read_to_end(&mut decoded).map(|_| ()).map_err(|e| {
                e.into_inner()
                    .and_then(|e| e.downcast::<DecodeError>().ok())
            });

            assert_eq!(bulk_res.map_err(|e| Some(Box::new(e))), stream_res);
        }
    }
}
245 
// Corrupts one random symbol with `*` and checks that the stream decoder reports the same
// DecodeError as bulk decoding, wrapped in an io::Error of kind `InvalidData`.
#[test]
fn reports_invalid_byte_correctly() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut stream_decoded = Vec::new();
    let mut bulk_decoded = Vec::new();

    for _ in 0..10_000 {
        bytes.clear();
        b64.clear();
        stream_decoded.clear();
        bulk_decoded.clear();

        let size = rng.gen_range(1..(10 * BUF_SIZE));
        bytes.extend(iter::repeat(0).take(size));
        rng.fill_bytes(&mut bytes[..size]);
        assert_eq!(size, bytes.len());

        let engine = GeneralPurpose::new(&alphabet::STANDARD, random_config(&mut rng));

        engine.encode_string(&bytes[..], &mut b64);
        // replace one byte, somewhere, with '*', which is invalid
        let bad_byte_pos = rng.gen_range(0..b64.len());
        let mut b64_bytes = b64.bytes().collect::<Vec<u8>>();
        b64_bytes[bad_byte_pos] = b'*';

        // The corrupted text is only read, so borrow it rather than cloning the whole buffer.
        let mut wrapped_reader = io::Cursor::new(&b64_bytes[..]);
        let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine);

        // `None` when decoding succeeded or the io::Error does not wrap a DecodeError.
        let read_decode_err = decoder.read_to_end(&mut stream_decoded).err().and_then(|e| {
            let kind = e.kind();
            e.into_inner()
                .and_then(|inner| inner.downcast::<DecodeError>().ok())
                .map(|inner| (*inner, kind))
        });

        let bulk_decode_err = engine.decode_vec(&b64_bytes[..], &mut bulk_decoded).err();

        // it's tricky to predict where the invalid data's offset will be since if it's in the last
        // chunk it will be reported at the first padding location because it's treated as invalid
        // padding. So, we just check that it's the same as it is for decoding all at once.
        assert_eq!(
            bulk_decode_err.map(|e| (e, io::ErrorKind::InvalidData)),
            read_decode_err
        );
    }
}
299 
// Two independently padded encodings are concatenated, so padding appears mid-stream. Reading them
// through tiny delegate reads must yield the same InvalidByte error (pointing at the first `=`) as
// bulk decoding.
#[test]
fn internal_padding_error_with_short_read_concatenated_texts_invalid_byte_error() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut reader_decoded = Vec::new();
    let mut bulk_decoded = Vec::new();

    // STANDARD encodes with padding and requires that padding be present, so the presence of
    // padding alone does not cause InvalidPadding.
    let engine = STANDARD;

    for _ in 0..10_000 {
        bytes.clear();
        b64.clear();
        reader_decoded.clear();
        bulk_decoded.clear();

        // Need at least 2 bytes so a split point strictly between bytes exists.
        let size = rng.gen_range(2..(10 * BUF_SIZE));
        bytes.resize(size, 0);
        rng.fill_bytes(&mut bytes[..size]);

        // A deterministic padding location keeps the expected error stable. A random `=` could be
        // InvalidLastSymbol for some buffer sizes and InvalidByte when decoded all at once.
        let split = loop {
            // A prefix length that is not a multiple of 3 forces padding on the first text.
            let candidate = rng.gen_range(1..size);
            if candidate % 3 != 0 {
                break candidate;
            }
        };

        engine.encode_string(&bytes[..split], &mut b64);
        assert!(b64.contains('='), "split: {}, b64: {}", split, b64);
        let bad_byte_pos = b64.find('=').unwrap();
        engine.encode_string(&bytes[split..], &mut b64);
        let b64_bytes = b64.as_bytes();

        // Short delegate reads make it plausible for the padding to land on a read boundary.
        let read_len = rng.gen_range(1..10);
        let mut wrapped_reader = ShortRead {
            max_read_len: read_len,
            delegate: io::Cursor::new(&b64_bytes),
        };

        let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine);

        let read_decode_err = decoder
            .read_to_end(&mut reader_decoded)
            .map_err(|e| {
                *e.into_inner()
                    .and_then(|e| e.downcast::<DecodeError>().ok())
                    .unwrap()
            })
            .unwrap_err();

        let bulk_decode_err = engine.decode_vec(b64_bytes, &mut bulk_decoded).unwrap_err();

        assert_eq!(
            bulk_decode_err,
            read_decode_err,
            "read len: {}, bad byte pos: {}, b64: {}",
            read_len,
            bad_byte_pos,
            std::str::from_utf8(b64_bytes).unwrap()
        );

        // Offset of the first `=`: 4 symbols per complete input triple, then 2 symbols for one
        // leftover byte or 3 symbols for two leftover bytes.
        let first_pad_offset = split / 3 * 4
            + match split % 3 {
                1 => 2,
                2 => 3,
                _ => unreachable!(),
            };
        assert_eq!(
            DecodeError::InvalidByte(first_pad_offset, PAD_BYTE),
            read_decode_err
        );
    }
}
384 
// Injects a single `=` at a random position before the final quad of a padded encoding. The exact
// error depends on where the `=` falls in its quad and on read boundaries, but decoding through
// short delegate reads must always fail.
#[test]
fn internal_padding_anywhere_error() {
    let mut rng = rand::thread_rng();
    let mut bytes = Vec::new();
    let mut b64 = String::new();
    let mut reader_decoded = Vec::new();

    // STANDARD encodes with padding and requires that padding be present, so the presence of
    // padding alone does not cause InvalidPadding.
    let engine = STANDARD;

    for _ in 0..10_000 {
        bytes.clear();
        b64.clear();
        reader_decoded.clear();

        bytes.resize(10 * BUF_SIZE, 0);
        rng.fill_bytes(&mut bytes[..]);

        engine.encode_string(&bytes[..], &mut b64);
        let mut b64_bytes = b64.as_bytes().to_vec();
        // Put padding anywhere other than the last quad. The bound must use the *encoded* length;
        // bounding by the raw input length would confine the injection to roughly the first three
        // quarters of the text.
        let pad_pos = rng.gen_range(0..b64_bytes.len() - 4);
        b64_bytes[pad_pos] = PAD_BYTE;

        // Short delegate reads make it plausible for the padding to land on a read boundary.
        let read_len = rng.gen_range(1..10);
        let mut wrapped_reader = ShortRead {
            max_read_len: read_len,
            delegate: io::Cursor::new(&b64_bytes),
        };

        let mut decoder = DecoderReader::new(&mut wrapped_reader, &engine);

        let result = decoder.read_to_end(&mut reader_decoded);
        assert!(
            result.is_err(),
            "padding at {} with read len {} was accepted",
            pad_pos,
            read_len
        );
    }
}
427 
consume_with_short_reads_and_validate<R: io::Read>( rng: &mut rand::rngs::ThreadRng, expected_bytes: &[u8], decoded: &mut [u8], short_reader: &mut R, )428 fn consume_with_short_reads_and_validate<R: io::Read>(
429     rng: &mut rand::rngs::ThreadRng,
430     expected_bytes: &[u8],
431     decoded: &mut [u8],
432     short_reader: &mut R,
433 ) {
434     let mut total_read = 0_usize;
435     loop {
436         assert!(
437             total_read <= expected_bytes.len(),
438             "tr {} size {}",
439             total_read,
440             expected_bytes.len()
441         );
442         if total_read == expected_bytes.len() {
443             assert_eq!(expected_bytes, &decoded[..total_read]);
444             // should be done
445             assert_eq!(0, short_reader.read(&mut *decoded).unwrap());
446             // didn't write anything
447             assert_eq!(expected_bytes, &decoded[..total_read]);
448 
449             break;
450         }
451         let decode_len = rng.gen_range(1..cmp::max(2, expected_bytes.len() * 2));
452 
453         let read = short_reader
454             .read(&mut decoded[total_read..total_read + decode_len])
455             .unwrap();
456         total_read += read;
457     }
458 }
459 
460 /// Limits how many bytes a reader will provide in each read call.
461 /// Useful for shaking out code that may work fine only with typical input sources that always fill
462 /// the buffer.
463 struct RandomShortRead<'a, 'b, R: io::Read, N: rand::Rng> {
464     delegate: &'b mut R,
465     rng: &'a mut N,
466 }
467 
468 impl<'a, 'b, R: io::Read, N: rand::Rng> io::Read for RandomShortRead<'a, 'b, R, N> {
read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error>469     fn read(&mut self, buf: &mut [u8]) -> Result<usize, io::Error> {
470         // avoid 0 since it means EOF for non-empty buffers
471         let effective_len = cmp::min(self.rng.gen_range(1..20), buf.len());
472 
473         self.delegate.read(&mut buf[..effective_len])
474     }
475 }
476 
/// Caps every read at `max_read_len` bytes, forwarding the shortened slice to `delegate`.
///
/// `max_read_len` must be at least 1 (the callers in this file draw it from `1..10`): a cap of 0
/// would make every read return `Ok(0)`, which looks like EOF.
struct ShortRead<R> {
    delegate: R,
    max_read_len: usize,
}

impl<R: io::Read> io::Read for ShortRead<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        // The cap is an upper bound, so take the *smaller* of it and the destination length.
        // Using `max` here would panic when `buf` is shorter than the cap, and would otherwise
        // never shorten a read at all.
        let len = self.max_read_len.min(buf.len());
        self.delegate.read(&mut buf[..len])
    }
}
488