1 use std::borrow::Borrow;
2 use std::collections::HashMap;
3 use std::error;
4 use std::fmt;
5 use std::io;
6 use std::result;
7 
8 use super::{TrieSetSlice, CHUNK_SIZE};
9 
10 // This implementation was pretty much cribbed from raphlinus' contribution
11 // to the standard library: https://github.com/rust-lang/rust/pull/33098/files
12 //
13 // The fundamental principle guiding this implementation is to take advantage
14 // of the fact that similar Unicode codepoints are often grouped together, and
15 // that most boolean Unicode properties are quite sparse over the entire space
16 // of Unicode codepoints.
17 //
18 // To do this, we represent sets using something like a trie (which gives us
19 // prefix compression). The "final" states of the trie are embedded in leaves
20 // or "chunks," where each chunk is a 64 bit integer. Each bit position of the
21 // integer corresponds to whether a particular codepoint is in the set or not.
22 // These chunks are not just a compact representation of the final states of
23 // the trie, but are also a form of suffix compression. In particular, if
// multiple ranges of 64 contiguous codepoints have the same set membership
25 // ordering, then they all map to the exact same chunk in the trie.
26 //
27 // We organize this structure by partitioning the space of Unicode codepoints
28 // into three disjoint sets. The first set corresponds to codepoints
// [0, 0x800), the second [0x800, 0x10000) and the third [0x10000, 0x110000).
30 // These partitions conveniently correspond to the space of 1 or 2 byte UTF-8
31 // encoded codepoints, 3 byte UTF-8 encoded codepoints and 4 byte UTF-8 encoded
32 // codepoints, respectively.
33 //
34 // Each partition has its own tree with its own root. The first partition is
35 // the simplest, since the tree is completely flat. In particular, to determine
36 // the set membership of a Unicode codepoint (that is less than `0x800`), we
37 // do the following (where `cp` is the codepoint we're testing):
38 //
39 //     let chunk_address = cp >> 6;
40 //     let chunk_bit = cp & 0b111111;
//     let chunk = tree1[chunk_address];
42 //     let is_member = 1 == ((chunk >> chunk_bit) & 1);
43 //
44 // We do something similar for the second partition:
45 //
46 //     // we subtract 0x20 since (0x800 >> 6) == 0x20.
47 //     let child_address = (cp >> 6) - 0x20;
48 //     let chunk_address = tree2_level1[child_address];
49 //     let chunk_bit = cp & 0b111111;
50 //     let chunk = tree2_level2[chunk_address];
51 //     let is_member = 1 == ((chunk >> chunk_bit) & 1);
52 //
53 // And so on for the third partition.
54 //
55 // Note that as a special case, if the second or third partitions are empty,
56 // then the trie will store empty slices for those levels. The `contains`
57 // check knows to return `false` in those cases.
58 
/// The number of leaf chunks needed to cover the whole codepoint space
/// `[0, 0x110000)`, where each chunk covers `CHUNK_SIZE` consecutive
/// codepoints.
const CHUNKS: usize = 0x110000 / CHUNK_SIZE;
60 
/// A type alias that maps to `std::result::Result<T, ucd_trie::Error>`.
///
/// This is the return type of the fallible `TrieSetOwned` constructors
/// defined in this module.
pub type Result<T> = result::Result<T, Error>;
63 
/// An error that can occur during construction of a trie.
#[derive(Clone, Debug)]
pub enum Error {
    /// This error is returned when an invalid codepoint is given to
    /// `TrieSetOwned::from_codepoints`. An invalid codepoint is a `u32` that
    /// is greater than `0x10FFFF`.
    ///
    /// The offending value is carried in the variant.
    InvalidCodepoint(u32),
    /// This error is returned when a set of Unicode codepoints could not be
    /// sufficiently compressed into the trie provided by this crate. There is
    /// no work-around for this error at this time.
    ///
    /// Concretely, this happens when a compression step in this module
    /// encounters more than 256 distinct children, because children are
    /// addressed by a `u8` index.
    GaveUp,
}
76 
// No methods are overridden: the default `std::error::Error` implementations
// are used as-is, and neither variant wraps another error value.
impl error::Error for Error {}
78 
79 impl fmt::Display for Error {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result80     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81         match *self {
82             Error::InvalidCodepoint(cp) => write!(
83                 f,
84                 "could not construct trie set containing an \
85                  invalid Unicode codepoint: 0x{:X}",
86                 cp
87             ),
88             Error::GaveUp => {
89                 write!(f, "could not compress codepoint set into a trie")
90             }
91         }
92     }
93 }
94 
95 impl From<Error> for io::Error {
from(err: Error) -> io::Error96     fn from(err: Error) -> io::Error {
97         io::Error::new(io::ErrorKind::Other, err)
98     }
99 }
100 
/// An owned trie set.
///
/// The set is stored as three independent trees, one for each of the
/// codepoint partitions `[0, 0x800)`, `[0x800, 0x10000)` and
/// `[0x10000, 0x110000)`. Leaves are `u64` bitmaps ("chunks") in which each
/// bit records the membership of one codepoint.
#[derive(Clone)]
pub struct TrieSetOwned {
    // Leaf chunks for the first partition, indexed directly by chunk number.
    tree1_level1: Vec<u64>,
    // For each chunk of the second partition (in codepoint order), the index
    // of that chunk's contents in `tree2_level2`. Empty if the partition has
    // no members.
    tree2_level1: Vec<u8>,
    // The distinct leaf chunks of the second partition. Empty if the
    // partition has no members.
    tree2_level2: Vec<u64>,
    // For each group of 64 consecutive chunks of the third partition, the
    // index (counted in whole groups) of that group's contents in
    // `tree3_level2`. Empty if the partition has no members.
    tree3_level1: Vec<u8>,
    // The distinct groups of 64 leaf indices, stored back to back. Each entry
    // is an index into `tree3_level3`. Empty if the partition has no members.
    tree3_level2: Vec<u8>,
    // The distinct leaf chunks of the third partition. Empty if the partition
    // has no members.
    tree3_level3: Vec<u64>,
}
111 
112 impl fmt::Debug for TrieSetOwned {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result113     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
114         write!(f, "TrieSetOwned(...)")
115     }
116 }
117 
118 impl TrieSetOwned {
new(all: &[bool]) -> Result<TrieSetOwned>119     fn new(all: &[bool]) -> Result<TrieSetOwned> {
120         let mut bitvectors = Vec::with_capacity(CHUNKS);
121         for i in 0..CHUNKS {
122             let mut bitvector = 0u64;
123             for j in 0..CHUNK_SIZE {
124                 if all[i * CHUNK_SIZE + j] {
125                     bitvector |= 1 << j;
126                 }
127             }
128             bitvectors.push(bitvector);
129         }
130 
131         let tree1_level1 =
132             bitvectors.iter().cloned().take(0x800 / CHUNK_SIZE).collect();
133 
134         let (mut tree2_level1, mut tree2_level2) = compress_postfix_leaves(
135             &bitvectors[0x800 / CHUNK_SIZE..0x10000 / CHUNK_SIZE],
136         )?;
137         if tree2_level2.len() == 1 && tree2_level2[0] == 0 {
138             tree2_level1.clear();
139             tree2_level2.clear();
140         }
141 
142         let (mid, mut tree3_level3) = compress_postfix_leaves(
143             &bitvectors[0x10000 / CHUNK_SIZE..0x110000 / CHUNK_SIZE],
144         )?;
145         let (mut tree3_level1, mut tree3_level2) =
146             compress_postfix_mid(&mid, 64)?;
147         if tree3_level3.len() == 1 && tree3_level3[0] == 0 {
148             tree3_level1.clear();
149             tree3_level2.clear();
150             tree3_level3.clear();
151         }
152 
153         Ok(TrieSetOwned {
154             tree1_level1,
155             tree2_level1,
156             tree2_level2,
157             tree3_level1,
158             tree3_level2,
159             tree3_level3,
160         })
161     }
162 
163     /// Create a new trie set from a set of Unicode scalar values.
164     ///
165     /// This returns an error if a set could not be sufficiently compressed to
166     /// fit into a trie.
from_scalars<I, C>(scalars: I) -> Result<TrieSetOwned> where I: IntoIterator<Item = C>, C: Borrow<char>,167     pub fn from_scalars<I, C>(scalars: I) -> Result<TrieSetOwned>
168     where
169         I: IntoIterator<Item = C>,
170         C: Borrow<char>,
171     {
172         let mut all = vec![false; 0x110000];
173         for s in scalars {
174             all[*s.borrow() as usize] = true;
175         }
176         TrieSetOwned::new(&all)
177     }
178 
179     /// Create a new trie set from a set of Unicode scalar values.
180     ///
181     /// This returns an error if a set could not be sufficiently compressed to
182     /// fit into a trie. This also returns an error if any of the given
183     /// codepoints are greater than `0x10FFFF`.
from_codepoints<I, C>(codepoints: I) -> Result<TrieSetOwned> where I: IntoIterator<Item = C>, C: Borrow<u32>,184     pub fn from_codepoints<I, C>(codepoints: I) -> Result<TrieSetOwned>
185     where
186         I: IntoIterator<Item = C>,
187         C: Borrow<u32>,
188     {
189         let mut all = vec![false; 0x110000];
190         for cp in codepoints {
191             let cp = *cp.borrow();
192             if cp > 0x10FFFF {
193                 return Err(Error::InvalidCodepoint(cp));
194             }
195             all[cp as usize] = true;
196         }
197         TrieSetOwned::new(&all)
198     }
199 
200     /// Return this set as a slice.
201     #[inline(always)]
as_slice(&self) -> TrieSetSlice<'_>202     pub fn as_slice(&self) -> TrieSetSlice<'_> {
203         TrieSetSlice {
204             tree1_level1: &self.tree1_level1,
205             tree2_level1: &self.tree2_level1,
206             tree2_level2: &self.tree2_level2,
207             tree3_level1: &self.tree3_level1,
208             tree3_level2: &self.tree3_level2,
209             tree3_level3: &self.tree3_level3,
210         }
211     }
212 
213     /// Returns true if and only if the given Unicode scalar value is in this
214     /// set.
contains_char(&self, c: char) -> bool215     pub fn contains_char(&self, c: char) -> bool {
216         self.as_slice().contains_char(c)
217     }
218 
219     /// Returns true if and only if the given codepoint is in this set.
220     ///
221     /// If the given value exceeds the codepoint range (i.e., it's greater
222     /// than `0x10FFFF`), then this returns false.
contains_u32(&self, cp: u32) -> bool223     pub fn contains_u32(&self, cp: u32) -> bool {
224         self.as_slice().contains_u32(cp)
225     }
226 }
227 
compress_postfix_leaves(chunks: &[u64]) -> Result<(Vec<u8>, Vec<u64>)>228 fn compress_postfix_leaves(chunks: &[u64]) -> Result<(Vec<u8>, Vec<u64>)> {
229     let mut root = vec![];
230     let mut children = vec![];
231     let mut bychild = HashMap::new();
232     for &chunk in chunks {
233         if !bychild.contains_key(&chunk) {
234             let start = bychild.len();
235             if start > ::std::u8::MAX as usize {
236                 return Err(Error::GaveUp);
237             }
238             bychild.insert(chunk, start as u8);
239             children.push(chunk);
240         }
241         root.push(bychild[&chunk]);
242     }
243     Ok((root, children))
244 }
245 
compress_postfix_mid( chunks: &[u8], chunk_size: usize, ) -> Result<(Vec<u8>, Vec<u8>)>246 fn compress_postfix_mid(
247     chunks: &[u8],
248     chunk_size: usize,
249 ) -> Result<(Vec<u8>, Vec<u8>)> {
250     let mut root = vec![];
251     let mut children = vec![];
252     let mut bychild = HashMap::new();
253     for i in 0..(chunks.len() / chunk_size) {
254         let chunk = &chunks[i * chunk_size..(i + 1) * chunk_size];
255         if !bychild.contains_key(chunk) {
256             let start = bychild.len();
257             if start > ::std::u8::MAX as usize {
258                 return Err(Error::GaveUp);
259             }
260             bychild.insert(chunk, start as u8);
261             children.extend(chunk);
262         }
263         root.push(bychild[chunk]);
264     }
265     Ok((root, children))
266 }
267 
#[cfg(test)]
mod tests {
    use super::TrieSetOwned;
    use crate::general_category;
    use std::collections::HashSet;

    /// Builds a trie set from the given scalar values, panicking if
    /// construction fails.
    fn mk(scalars: &[char]) -> TrieSetOwned {
        TrieSetOwned::from_scalars(scalars).unwrap()
    }

    /// Expands a list of inclusive `(start, end)` ranges into a flat list of
    /// codepoints.
    fn ranges_to_set(ranges: &[(u32, u32)]) -> Vec<u32> {
        let mut set = vec![];
        for &(start, end) in ranges {
            set.extend(start..=end);
        }
        set
    }

    // The tests below probe one character from each region of the trie:
    // ASCII, a 2-byte UTF-8 character, a 3-byte UTF-8 character and a 4-byte
    // UTF-8 character. The 4-byte characters are written as escapes so that
    // they survive re-encoding of this file.

    #[test]
    fn set1() {
        let set = mk(&['a']);
        assert!(set.contains_char('a'));
        assert!(!set.contains_char('b'));
        assert!(!set.contains_char('β'));
        assert!(!set.contains_char('☃'));
        assert!(!set.contains_char('\u{1F63C}'));
    }

    #[test]
    fn set_combined() {
        let set = mk(&['a', 'b', 'β', '☃', '\u{1F63C}']);
        assert!(set.contains_char('a'));
        assert!(set.contains_char('b'));
        assert!(set.contains_char('β'));
        assert!(set.contains_char('☃'));
        assert!(set.contains_char('\u{1F63C}'));

        assert!(!set.contains_char('c'));
        assert!(!set.contains_char('θ'));
        assert!(!set.contains_char('⛇'));
        assert!(!set.contains_char('\u{1F432}'));
    }

    // Basic tests on all of the general category sets. We check that
    // membership is correct on every Unicode codepoint... because we can.

    macro_rules! category_test {
        ($name:ident, $ranges:ident) => {
            #[test]
            fn $name() {
                let set = ranges_to_set(general_category::$ranges);
                let hashset: HashSet<u32> = set.iter().cloned().collect();
                let trie = TrieSetOwned::from_codepoints(&set).unwrap();
                for cp in 0..0x110000 {
                    assert_eq!(
                        trie.contains_u32(cp),
                        hashset.contains(&cp),
                        "membership mismatch at codepoint 0x{:X}",
                        cp
                    );
                }
                // Test that an invalid codepoint is treated correctly.
                assert!(!trie.contains_u32(0x110000));
                assert!(!hashset.contains(&0x110000));
            }
        };
    }

    category_test!(gencat_cased_letter, CASED_LETTER);
    category_test!(gencat_close_punctuation, CLOSE_PUNCTUATION);
    category_test!(gencat_connector_punctuation, CONNECTOR_PUNCTUATION);
    category_test!(gencat_control, CONTROL);
    category_test!(gencat_currency_symbol, CURRENCY_SYMBOL);
    category_test!(gencat_dash_punctuation, DASH_PUNCTUATION);
    category_test!(gencat_decimal_number, DECIMAL_NUMBER);
    category_test!(gencat_enclosing_mark, ENCLOSING_MARK);
    category_test!(gencat_final_punctuation, FINAL_PUNCTUATION);
    category_test!(gencat_format, FORMAT);
    category_test!(gencat_initial_punctuation, INITIAL_PUNCTUATION);
    category_test!(gencat_letter, LETTER);
    category_test!(gencat_letter_number, LETTER_NUMBER);
    category_test!(gencat_line_separator, LINE_SEPARATOR);
    category_test!(gencat_lowercase_letter, LOWERCASE_LETTER);
    category_test!(gencat_math_symbol, MATH_SYMBOL);
    category_test!(gencat_mark, MARK);
    category_test!(gencat_modifier_letter, MODIFIER_LETTER);
    category_test!(gencat_modifier_symbol, MODIFIER_SYMBOL);
    category_test!(gencat_nonspacing_mark, NONSPACING_MARK);
    category_test!(gencat_number, NUMBER);
    category_test!(gencat_open_punctuation, OPEN_PUNCTUATION);
    category_test!(gencat_other, OTHER);
    category_test!(gencat_other_letter, OTHER_LETTER);
    category_test!(gencat_other_number, OTHER_NUMBER);
    category_test!(gencat_other_punctuation, OTHER_PUNCTUATION);
    category_test!(gencat_other_symbol, OTHER_SYMBOL);
    category_test!(gencat_paragraph_separator, PARAGRAPH_SEPARATOR);
    category_test!(gencat_private_use, PRIVATE_USE);
    category_test!(gencat_punctuation, PUNCTUATION);
    category_test!(gencat_separator, SEPARATOR);
    category_test!(gencat_space_separator, SPACE_SEPARATOR);
    category_test!(gencat_spacing_mark, SPACING_MARK);
    category_test!(gencat_surrogate, SURROGATE);
    category_test!(gencat_symbol, SYMBOL);
    category_test!(gencat_titlecase_letter, TITLECASE_LETTER);
    category_test!(gencat_unassigned, UNASSIGNED);
    category_test!(gencat_uppercase_letter, UPPERCASE_LETTER);
}
372