1 use rand::{Rng, SeedableRng};
2 
3 use base64::engine::{general_purpose::STANDARD, Engine};
4 use base64::*;
5 
6 use base64::engine::general_purpose::{GeneralPurpose, NO_PAD};
7 
8 // generate random contents of the specified length and test encode/decode roundtrip
roundtrip_random<E: Engine>( byte_buf: &mut Vec<u8>, str_buf: &mut String, engine: &E, byte_len: usize, approx_values_per_byte: u8, max_rounds: u64, )9 fn roundtrip_random<E: Engine>(
10     byte_buf: &mut Vec<u8>,
11     str_buf: &mut String,
12     engine: &E,
13     byte_len: usize,
14     approx_values_per_byte: u8,
15     max_rounds: u64,
16 ) {
17     // let the short ones be short but don't let it get too crazy large
18     let num_rounds = calculate_number_of_rounds(byte_len, approx_values_per_byte, max_rounds);
19     let mut r = rand::rngs::SmallRng::from_entropy();
20     let mut decode_buf = Vec::new();
21 
22     for _ in 0..num_rounds {
23         byte_buf.clear();
24         str_buf.clear();
25         decode_buf.clear();
26         while byte_buf.len() < byte_len {
27             byte_buf.push(r.gen::<u8>());
28         }
29 
30         engine.encode_string(&byte_buf, str_buf);
31         engine.decode_vec(&str_buf, &mut decode_buf).unwrap();
32 
33         assert_eq!(byte_buf, &decode_buf);
34     }
35 }
36 
// Pick how many random rounds to run for inputs of `byte_len` bytes.
//
// Starts from `approx_values_per_byte` and squares the running product once per input byte
// (saturating instead of overflowing), bailing out with `max` as soon as the product has grown
// past it. The returned value never exceeds `max`.
//
// NOTE(review): squaring grows much faster than `approx_values_per_byte.pow(byte_len)`;
// presumably only a rough "more rounds for longer inputs" scale is wanted -- confirm before
// relying on the exact counts.
fn calculate_number_of_rounds(byte_len: usize, approx_values_per_byte: u8, max: u64) -> u64 {
    let mut prod = u64::from(approx_values_per_byte);

    for _ in 0..byte_len {
        if prod > max {
            return max;
        }

        // don't overflow
        prod = prod.saturating_mul(prod);
    }

    // The loop only compares against `max` before each squaring, so the value produced by the
    // final iteration (or the initial value when `byte_len` is 0) still has to be clamped.
    prod.min(max)
}
51 
// Random round-trips through the STANDARD engine for every input length in 0..40, passing 4 as
// `approx_values_per_byte` and 10000 as `max_rounds`.
#[test]
fn roundtrip_random_short_standard() {
    // scratch buffers shared by all lengths
    let (mut bytes, mut encoded) = (Vec::<u8>::new(), String::new());

    for len in 0..40 {
        roundtrip_random(&mut bytes, &mut encoded, &STANDARD, len, 4, 10_000);
    }
}
61 
// Random round-trips through the STANDARD engine for every input length in 40..100, passing 4
// as `approx_values_per_byte` and 1000 as `max_rounds`.
#[test]
fn roundtrip_random_with_fast_loop_standard() {
    // scratch buffers shared by all lengths
    let (mut bytes, mut encoded) = (Vec::<u8>::new(), String::new());

    for len in 40..100 {
        roundtrip_random(&mut bytes, &mut encoded, &STANDARD, len, 4, 1_000);
    }
}
71 
// Random round-trips for every input length in 0..40 using an engine built from the standard
// alphabet with the NO_PAD configuration (4 as `approx_values_per_byte`, 10000 as `max_rounds`).
#[test]
fn roundtrip_random_short_no_padding() {
    let no_pad_engine = GeneralPurpose::new(&alphabet::STANDARD, NO_PAD);
    // scratch buffers shared by all lengths
    let (mut bytes, mut encoded) = (Vec::<u8>::new(), String::new());

    for len in 0..40 {
        roundtrip_random(&mut bytes, &mut encoded, &no_pad_engine, len, 4, 10_000);
    }
}
82 
// Random round-trips for every input length in 40..100 using an engine built from the standard
// alphabet with the NO_PAD configuration (4 as `approx_values_per_byte`, 1000 as `max_rounds`).
#[test]
fn roundtrip_random_no_padding() {
    let no_pad_engine = GeneralPurpose::new(&alphabet::STANDARD, NO_PAD);
    // scratch buffers shared by all lengths
    let (mut bytes, mut encoded) = (Vec::<u8>::new(), String::new());

    for len in 40..100 {
        roundtrip_random(&mut bytes, &mut encoded, &no_pad_engine, len, 4, 1_000);
    }
}
94 
// Decodes and re-encodes inputs that end in a 10-character tail after 0..25 leading "ABCD" quads.
#[test]
fn roundtrip_decode_trailing_10_bytes() {
    // Why this case gets its own test: decoding consumes the input in 8 byte blocks wherever
    // possible (ideally unrolled to 32 bytes at a time) in stages 1 and 2. Each block is written
    // out as a full u64 (8 bytes), so 2 garbage bytes are always written and are expected to be
    // overwritten by the NEXT block. When the next block holds only 2 bytes it decodes to a
    // single byte, which is too short to cover those 2 garbage bytes; stage 3 exists to handle
    // exactly that situation.
    let engine = GeneralPurpose::new(&alphabet::STANDARD, NO_PAD);

    for quad_count in 0..25 {
        // `quad_count` full quads followed by the fixed 10-character tail
        let encoded: String = "ABCD".repeat(quad_count) + "EFGHIJKLZg";

        let raw = engine.decode(&encoded).unwrap();
        // 3 bytes per leading quad, plus 7 bytes from the tail
        assert_eq!(quad_count * 3 + 7, raw.len());

        let reencoded = engine.encode(&raw);
        assert_eq!(encoded, reencoded);
    }
}
115 
// Formatting a `Base64Display` wrapper must produce the same text as `STANDARD.encode` for an
// input containing every byte value from 0 through 255 in order.
#[test]
fn display_wrapper_matches_normal_encode() {
    // all 256 byte values, ascending
    let bytes: Vec<u8> = (0..=255).collect();

    let via_encode = STANDARD.encode(&bytes);
    let via_display = display::Base64Display::new(&bytes, &STANDARD).to_string();

    assert_eq!(via_encode, via_display);
}
130 
// For inputs made of 1..100 "AAA" triples, `encode_slice` must report `OutputSliceTooSmall` for
// each of the four output lengths just below `4 * triples`, and must succeed (returning
// `4 * triples`) once the output buffer reaches that length.
#[test]
fn encode_engine_slice_error_when_buffer_too_small() {
    for triple_count in 1..100 {
        let input = "AAA".repeat(triple_count);
        let full_len = triple_count * 4;
        // start one whole quad short of the length that is expected to succeed
        let mut out = vec![0; full_len - 4];

        // grow the buffer one byte at a time; each of these four lengths must be rejected
        for _ in 0..4 {
            assert_eq!(
                EncodeSliceError::OutputSliceTooSmall,
                STANDARD.encode_slice(&input, &mut out).unwrap_err()
            );
            out.push(0);
        }

        // the buffer is now exactly `full_len` bytes long
        assert_eq!(full_len, STANDARD.encode_slice(&input, &mut out).unwrap());
    }
}
162