1 //! OID encoder with `const` support.
2 
3 use crate::{
4     arcs::{ARC_MAX_FIRST, ARC_MAX_SECOND},
5     Arc, Error, ObjectIdentifier, Result,
6 };
7 
/// BER/DER encoder
///
/// Builds the encoding of an OID one arc at a time into a fixed-size inline
/// buffer. Every method on it is a `const fn`, so OIDs can be assembled in
/// constant contexts.
#[derive(Debug)]
pub(crate) struct Encoder {
    /// Current state: whether zero, one, or at least two of the leading arcs
    /// have been consumed so far.
    state: State,

    /// Bytes of the OID being encoded in-progress.
    ///
    /// Only `bytes[..cursor]` is part of the encoding; anything past `cursor`
    /// is not meaningful.
    bytes: [u8; ObjectIdentifier::MAX_SIZE],

    /// Current position within the byte buffer: the number of encoded bytes
    /// written so far, which is also the index where the next arc will start.
    cursor: usize,
}
20 
/// Current state of the encoder
///
/// The first two arcs of an OID are packed together into a single output
/// byte, so the encoder has to remember the first arc until the second one
/// arrives before it can write anything.
#[derive(Debug)]
enum State {
    /// Initial state - no arcs yet encoded
    Initial,

    /// First arc parsed; its value is held here (nothing has been written to
    /// the buffer yet) until the second arc is supplied.
    FirstArc(Arc),

    /// Encoding base 128 body of the OID: byte 0 has been written (or the
    /// encoder was resumed from an existing OID), and each further arc is
    /// appended at the cursor.
    Body,
}
33 
34 impl Encoder {
35     /// Create a new encoder initialized to an empty default state.
new() -> Self36     pub(crate) const fn new() -> Self {
37         Self {
38             state: State::Initial,
39             bytes: [0u8; ObjectIdentifier::MAX_SIZE],
40             cursor: 0,
41         }
42     }
43 
44     /// Extend an existing OID.
extend(oid: ObjectIdentifier) -> Self45     pub(crate) const fn extend(oid: ObjectIdentifier) -> Self {
46         Self {
47             state: State::Body,
48             bytes: oid.bytes,
49             cursor: oid.length as usize,
50         }
51     }
52 
53     /// Encode an [`Arc`] as base 128 into the internal buffer.
arc(mut self, arc: Arc) -> Result<Self>54     pub(crate) const fn arc(mut self, arc: Arc) -> Result<Self> {
55         match self.state {
56             State::Initial => {
57                 if arc > ARC_MAX_FIRST {
58                     return Err(Error::ArcInvalid { arc });
59                 }
60 
61                 self.state = State::FirstArc(arc);
62                 Ok(self)
63             }
64             // Ensured not to overflow by `ARC_MAX_SECOND` check
65             #[allow(clippy::integer_arithmetic)]
66             State::FirstArc(first_arc) => {
67                 if arc > ARC_MAX_SECOND {
68                     return Err(Error::ArcInvalid { arc });
69                 }
70 
71                 self.state = State::Body;
72                 self.bytes[0] = (first_arc * (ARC_MAX_SECOND + 1)) as u8 + arc as u8;
73                 self.cursor = 1;
74                 Ok(self)
75             }
76             // TODO(tarcieri): finer-grained overflow safety / checked arithmetic
77             #[allow(clippy::integer_arithmetic)]
78             State::Body => {
79                 // Total number of bytes in encoded arc - 1
80                 let nbytes = base128_len(arc);
81 
82                 // Shouldn't overflow on any 16-bit+ architectures
83                 if self.cursor + nbytes + 1 >= ObjectIdentifier::MAX_SIZE {
84                     return Err(Error::Length);
85                 }
86 
87                 let new_cursor = self.cursor + nbytes + 1;
88 
89                 // TODO(tarcieri): use `?` when stable in `const fn`
90                 match self.encode_base128_byte(arc, nbytes, false) {
91                     Ok(mut encoder) => {
92                         encoder.cursor = new_cursor;
93                         Ok(encoder)
94                     }
95                     Err(err) => Err(err),
96                 }
97             }
98         }
99     }
100 
101     /// Finish encoding an OID.
finish(self) -> Result<ObjectIdentifier>102     pub(crate) const fn finish(self) -> Result<ObjectIdentifier> {
103         if self.cursor >= 2 {
104             Ok(ObjectIdentifier {
105                 bytes: self.bytes,
106                 length: self.cursor as u8,
107             })
108         } else {
109             Err(Error::NotEnoughArcs)
110         }
111     }
112 
113     /// Encode a single byte of a Base 128 value.
encode_base128_byte(mut self, mut n: u32, i: usize, continued: bool) -> Result<Self>114     const fn encode_base128_byte(mut self, mut n: u32, i: usize, continued: bool) -> Result<Self> {
115         let mask = if continued { 0b10000000 } else { 0 };
116 
117         // Underflow checked by branch
118         #[allow(clippy::integer_arithmetic)]
119         if n > 0x80 {
120             self.bytes[checked_add!(self.cursor, i)] = (n & 0b1111111) as u8 | mask;
121             n >>= 7;
122 
123             if i > 0 {
124                 self.encode_base128_byte(n, i.saturating_sub(1), true)
125             } else {
126                 Err(Error::Base128)
127             }
128         } else {
129             self.bytes[self.cursor] = n as u8 | mask;
130             Ok(self)
131         }
132     }
133 }
134 
135 /// Compute the length - 1 of an arc when encoded in base 128.
base128_len(arc: Arc) -> usize136 const fn base128_len(arc: Arc) -> usize {
137     match arc {
138         0..=0x7f => 0,
139         0x80..=0x3fff => 1,
140         0x4000..=0x1fffff => 2,
141         0x200000..=0x1fffffff => 3,
142         _ => 4,
143     }
144 }
145 
#[cfg(test)]
mod tests {
    use super::Encoder;
    use hex_literal::hex;

    /// Expected BER/DER bytes for the OID `1.2.840.10045.2.1`
    const EXAMPLE_OID_BER: &[u8] = &hex!("2A8648CE3D0201");

    #[test]
    fn encode() {
        // Feed every arc of the example OID through the encoder in order.
        let mut encoder = Encoder::new();
        for &arc in &[1, 2, 840, 10045, 2, 1] {
            encoder = encoder.arc(arc).unwrap();
        }

        let encoded = &encoder.bytes[..encoder.cursor];
        assert_eq!(encoded, EXAMPLE_OID_BER);
    }
}
166