1 use byteorder::{NetworkEndian, ReadBytesExt};
2 use log::*;
3 use std::{
4     borrow::Cow,
5     default::Default,
6     fmt,
7     io::{Cursor, ErrorKind, Read, Write},
8     result::Result as StdResult,
9     str::Utf8Error,
10     string::{FromUtf8Error, String},
11 };
12 
13 use super::{
14     coding::{CloseCode, Control, Data, OpCode},
15     mask::{apply_mask, generate_mask},
16 };
17 use crate::error::{Error, ProtocolError, Result};
18 
19 /// A struct representing the close command.
20 #[derive(Debug, Clone, Eq, PartialEq)]
21 pub struct CloseFrame<'t> {
22     /// The reason as a code.
23     pub code: CloseCode,
24     /// The reason as text string.
25     pub reason: Cow<'t, str>,
26 }
27 
28 impl<'t> CloseFrame<'t> {
29     /// Convert into a owned string.
into_owned(self) -> CloseFrame<'static>30     pub fn into_owned(self) -> CloseFrame<'static> {
31         CloseFrame { code: self.code, reason: self.reason.into_owned().into() }
32     }
33 }
34 
35 impl<'t> fmt::Display for CloseFrame<'t> {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result36     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
37         write!(f, "{} ({})", self.reason, self.code)
38     }
39 }
40 
41 /// A struct representing a WebSocket frame header.
42 #[allow(missing_copy_implementations)]
43 #[derive(Debug, Clone, Eq, PartialEq)]
44 pub struct FrameHeader {
45     /// Indicates that the frame is the last one of a possibly fragmented message.
46     pub is_final: bool,
47     /// Reserved for protocol extensions.
48     pub rsv1: bool,
49     /// Reserved for protocol extensions.
50     pub rsv2: bool,
51     /// Reserved for protocol extensions.
52     pub rsv3: bool,
53     /// WebSocket protocol opcode.
54     pub opcode: OpCode,
55     /// A frame mask, if any.
56     pub mask: Option<[u8; 4]>,
57 }
58 
59 impl Default for FrameHeader {
default() -> Self60     fn default() -> Self {
61         FrameHeader {
62             is_final: true,
63             rsv1: false,
64             rsv2: false,
65             rsv3: false,
66             opcode: OpCode::Control(Control::Close),
67             mask: None,
68         }
69     }
70 }
71 
72 impl FrameHeader {
73     /// Parse a header from an input stream.
74     /// Returns `None` if insufficient data and does not consume anything in this case.
75     /// Payload size is returned along with the header.
parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>>76     pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> {
77         let initial = cursor.position();
78         match Self::parse_internal(cursor) {
79             ret @ Ok(None) => {
80                 cursor.set_position(initial);
81                 ret
82             }
83             ret => ret,
84         }
85     }
86 
87     /// Get the size of the header formatted with given payload length.
88     #[allow(clippy::len_without_is_empty)]
len(&self, length: u64) -> usize89     pub fn len(&self, length: u64) -> usize {
90         2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 }
91     }
92 
93     /// Format a header for given payload size.
format(&self, length: u64, output: &mut impl Write) -> Result<()>94     pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> {
95         let code: u8 = self.opcode.into();
96 
97         let one = {
98             code | if self.is_final { 0x80 } else { 0 }
99                 | if self.rsv1 { 0x40 } else { 0 }
100                 | if self.rsv2 { 0x20 } else { 0 }
101                 | if self.rsv3 { 0x10 } else { 0 }
102         };
103 
104         let lenfmt = LengthFormat::for_length(length);
105 
106         let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } };
107 
108         output.write_all(&[one, two])?;
109         match lenfmt {
110             LengthFormat::U8(_) => (),
111             LengthFormat::U16 => {
112                 output.write_all(&(length as u16).to_be_bytes())?;
113             }
114             LengthFormat::U64 => {
115                 output.write_all(&length.to_be_bytes())?;
116             }
117         }
118 
119         if let Some(ref mask) = self.mask {
120             output.write_all(mask)?
121         }
122 
123         Ok(())
124     }
125 
126     /// Generate a random frame mask and store this in the header.
127     ///
128     /// Of course this does not change frame contents. It just generates a mask.
set_random_mask(&mut self)129     pub(crate) fn set_random_mask(&mut self) {
130         self.mask = Some(generate_mask())
131     }
132 }
133 
134 impl FrameHeader {
135     /// Internal parse engine.
136     /// Returns `None` if insufficient data.
137     /// Payload size is returned along with the header.
parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>>138     fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> {
139         let (first, second) = {
140             let mut head = [0u8; 2];
141             if cursor.read(&mut head)? != 2 {
142                 return Ok(None);
143             }
144             trace!("Parsed headers {:?}", head);
145             (head[0], head[1])
146         };
147 
148         trace!("First: {:b}", first);
149         trace!("Second: {:b}", second);
150 
151         let is_final = first & 0x80 != 0;
152 
153         let rsv1 = first & 0x40 != 0;
154         let rsv2 = first & 0x20 != 0;
155         let rsv3 = first & 0x10 != 0;
156 
157         let opcode = OpCode::from(first & 0x0F);
158         trace!("Opcode: {:?}", opcode);
159 
160         let masked = second & 0x80 != 0;
161         trace!("Masked: {:?}", masked);
162 
163         let length = {
164             let length_byte = second & 0x7F;
165             let length_length = LengthFormat::for_byte(length_byte).extra_bytes();
166             if length_length > 0 {
167                 match cursor.read_uint::<NetworkEndian>(length_length) {
168                     Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => {
169                         return Ok(None);
170                     }
171                     Err(err) => {
172                         return Err(err.into());
173                     }
174                     Ok(read) => read,
175                 }
176             } else {
177                 u64::from(length_byte)
178             }
179         };
180 
181         let mask = if masked {
182             let mut mask_bytes = [0u8; 4];
183             if cursor.read(&mut mask_bytes)? != 4 {
184                 return Ok(None);
185             } else {
186                 Some(mask_bytes)
187             }
188         } else {
189             None
190         };
191 
192         // Disallow bad opcode
193         match opcode {
194             OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => {
195                 return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F)))
196             }
197             _ => (),
198         }
199 
200         let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask };
201 
202         Ok(Some((hdr, length)))
203     }
204 }
205 
206 /// A struct representing a WebSocket frame.
207 #[derive(Debug, Clone, Eq, PartialEq)]
208 pub struct Frame {
209     header: FrameHeader,
210     payload: Vec<u8>,
211 }
212 
213 impl Frame {
214     /// Get the length of the frame.
215     /// This is the length of the header + the length of the payload.
216     #[inline]
len(&self) -> usize217     pub fn len(&self) -> usize {
218         let length = self.payload.len();
219         self.header.len(length as u64) + length
220     }
221 
222     /// Check if the frame is empty.
223     #[inline]
is_empty(&self) -> bool224     pub fn is_empty(&self) -> bool {
225         self.len() == 0
226     }
227 
228     /// Get a reference to the frame's header.
229     #[inline]
header(&self) -> &FrameHeader230     pub fn header(&self) -> &FrameHeader {
231         &self.header
232     }
233 
234     /// Get a mutable reference to the frame's header.
235     #[inline]
header_mut(&mut self) -> &mut FrameHeader236     pub fn header_mut(&mut self) -> &mut FrameHeader {
237         &mut self.header
238     }
239 
240     /// Get a reference to the frame's payload.
241     #[inline]
payload(&self) -> &Vec<u8>242     pub fn payload(&self) -> &Vec<u8> {
243         &self.payload
244     }
245 
246     /// Get a mutable reference to the frame's payload.
247     #[inline]
payload_mut(&mut self) -> &mut Vec<u8>248     pub fn payload_mut(&mut self) -> &mut Vec<u8> {
249         &mut self.payload
250     }
251 
252     /// Test whether the frame is masked.
253     #[inline]
is_masked(&self) -> bool254     pub(crate) fn is_masked(&self) -> bool {
255         self.header.mask.is_some()
256     }
257 
258     /// Generate a random mask for the frame.
259     ///
260     /// This just generates a mask, payload is not changed. The actual masking is performed
261     /// either on `format()` or on `apply_mask()` call.
262     #[inline]
set_random_mask(&mut self)263     pub(crate) fn set_random_mask(&mut self) {
264         self.header.set_random_mask()
265     }
266 
267     /// This method unmasks the payload and should only be called on frames that are actually
268     /// masked. In other words, those frames that have just been received from a client endpoint.
269     #[inline]
apply_mask(&mut self)270     pub(crate) fn apply_mask(&mut self) {
271         if let Some(mask) = self.header.mask.take() {
272             apply_mask(&mut self.payload, mask)
273         }
274     }
275 
276     /// Consume the frame into its payload as binary.
277     #[inline]
into_data(self) -> Vec<u8>278     pub fn into_data(self) -> Vec<u8> {
279         self.payload
280     }
281 
282     /// Consume the frame into its payload as string.
283     #[inline]
into_string(self) -> StdResult<String, FromUtf8Error>284     pub fn into_string(self) -> StdResult<String, FromUtf8Error> {
285         String::from_utf8(self.payload)
286     }
287 
288     /// Get frame payload as `&str`.
289     #[inline]
to_text(&self) -> Result<&str, Utf8Error>290     pub fn to_text(&self) -> Result<&str, Utf8Error> {
291         std::str::from_utf8(&self.payload)
292     }
293 
294     /// Consume the frame into a closing frame.
295     #[inline]
into_close(self) -> Result<Option<CloseFrame<'static>>>296     pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> {
297         match self.payload.len() {
298             0 => Ok(None),
299             1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)),
300             _ => {
301                 let mut data = self.payload;
302                 let code = u16::from_be_bytes([data[0], data[1]]).into();
303                 data.drain(0..2);
304                 let text = String::from_utf8(data)?;
305                 Ok(Some(CloseFrame { code, reason: text.into() }))
306             }
307         }
308     }
309 
310     /// Create a new data frame.
311     #[inline]
message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame312     pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame {
313         debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame.");
314 
315         Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data }
316     }
317 
318     /// Create a new Pong control frame.
319     #[inline]
pong(data: Vec<u8>) -> Frame320     pub fn pong(data: Vec<u8>) -> Frame {
321         Frame {
322             header: FrameHeader {
323                 opcode: OpCode::Control(Control::Pong),
324                 ..FrameHeader::default()
325             },
326             payload: data,
327         }
328     }
329 
330     /// Create a new Ping control frame.
331     #[inline]
ping(data: Vec<u8>) -> Frame332     pub fn ping(data: Vec<u8>) -> Frame {
333         Frame {
334             header: FrameHeader {
335                 opcode: OpCode::Control(Control::Ping),
336                 ..FrameHeader::default()
337             },
338             payload: data,
339         }
340     }
341 
342     /// Create a new Close control frame.
343     #[inline]
close(msg: Option<CloseFrame>) -> Frame344     pub fn close(msg: Option<CloseFrame>) -> Frame {
345         let payload = if let Some(CloseFrame { code, reason }) = msg {
346             let mut p = Vec::with_capacity(reason.as_bytes().len() + 2);
347             p.extend(u16::from(code).to_be_bytes());
348             p.extend_from_slice(reason.as_bytes());
349             p
350         } else {
351             Vec::new()
352         };
353 
354         Frame { header: FrameHeader::default(), payload }
355     }
356 
357     /// Create a frame from given header and data.
from_payload(header: FrameHeader, payload: Vec<u8>) -> Self358     pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self {
359         Frame { header, payload }
360     }
361 
362     /// Write a frame out to a buffer
format(mut self, output: &mut impl Write) -> Result<()>363     pub fn format(mut self, output: &mut impl Write) -> Result<()> {
364         self.header.format(self.payload.len() as u64, output)?;
365         self.apply_mask();
366         output.write_all(self.payload())?;
367         Ok(())
368     }
369 }
370 
371 impl fmt::Display for Frame {
fmt(&self, f: &mut fmt::Formatter) -> fmt::Result372     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
373         use std::fmt::Write;
374 
375         write!(
376             f,
377             "
378 <FRAME>
379 final: {}
380 reserved: {} {} {}
381 opcode: {}
382 length: {}
383 payload length: {}
384 payload: 0x{}
385             ",
386             self.header.is_final,
387             self.header.rsv1,
388             self.header.rsv2,
389             self.header.rsv3,
390             self.header.opcode,
391             // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()),
392             self.len(),
393             self.payload.len(),
394             self.payload.iter().fold(String::new(), |mut output, byte| {
395                 _ = write!(output, "{byte:02x}");
396                 output
397             })
398         )
399     }
400 }
401 
/// How a payload length is encoded in a frame header.
///
/// There are three encodings:
/// - lengths below 126 fit directly in the 7-bit field of the second header byte;
/// - lengths below 65536 use the marker 126 followed by a 16-bit field;
/// - everything else uses the marker 127 followed by a 64-bit field.
enum LengthFormat {
    /// Length stored inline in the 7-bit field.
    U8(u8),
    /// Marker 126, followed by a 2-byte length.
    U16,
    /// Marker 127, followed by an 8-byte length.
    U64,
}

impl LengthFormat {
    /// Choose the smallest encoding able to represent `length`.
    #[inline]
    fn for_length(length: u64) -> Self {
        match length {
            0..=125 => LengthFormat::U8(length as u8),
            126..=65_535 => LengthFormat::U16,
            _ => LengthFormat::U64,
        }
    }

    /// Number of extended-length bytes that follow the second header byte.
    #[inline]
    fn extra_bytes(&self) -> usize {
        match self {
            LengthFormat::U8(_) => 0,
            LengthFormat::U16 => 2,
            LengthFormat::U64 => 8,
        }
    }

    /// Value of the 7-bit length field, without the mask bit.
    #[inline]
    fn length_byte(&self) -> u8 {
        match self {
            LengthFormat::U8(inline) => *inline,
            LengthFormat::U16 => 126,
            LengthFormat::U64 => 127,
        }
    }

    /// Decode the length encoding from a second header byte; the mask bit is ignored.
    #[inline]
    fn for_byte(byte: u8) -> Self {
        let field = byte & 0x7F;
        match field {
            126 => LengthFormat::U16,
            127 => LengthFormat::U64,
            inline => LengthFormat::U8(inline),
        }
    }
}
452 
#[cfg(test)]
mod tests {
    use super::*;

    use super::super::coding::{Data, OpCode};
    use std::io::Cursor;

    #[test]
    fn parse() {
        let mut raw: Cursor<Vec<u8>> =
            Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
        let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 7);
        let mut payload = Vec::new();
        raw.read_to_end(&mut payload).unwrap();
        let frame = Frame::from_payload(header, payload);
        assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]);
    }

    /// Incomplete headers must yield `Ok(None)` and leave the cursor where it started.
    #[test]
    fn parse_incomplete_header_rewinds() {
        // Only the first fixed byte is present.
        let mut raw = Cursor::new(vec![0x82u8]);
        assert!(FrameHeader::parse(&mut raw).unwrap().is_none());
        assert_eq!(raw.position(), 0);

        // A 16-bit extended length is announced, but only one of its bytes is present.
        let mut raw = Cursor::new(vec![0x82u8, 0x7E, 0x01]);
        assert!(FrameHeader::parse(&mut raw).unwrap().is_none());
        assert_eq!(raw.position(), 0);

        // The mask bit is set, but the mask key is truncated.
        let mut raw = Cursor::new(vec![0x81u8, 0x85, 0x37, 0xfa]);
        assert!(FrameHeader::parse(&mut raw).unwrap().is_none());
        assert_eq!(raw.position(), 0);
    }

    #[test]
    fn parse_extended_16bit_length() {
        let mut raw = Cursor::new(vec![0x82u8, 0x7E, 0x01, 0x00]);
        let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 256);
        assert!(header.is_final);
        assert_eq!(header.opcode, OpCode::Data(Data::Binary));
        assert_eq!(header.mask, None);
        assert_eq!(raw.position(), 4);
    }

    /// Single-frame masked text message "Hello" (RFC 6455, section 5.7).
    #[test]
    fn parse_masked_and_unmask() {
        let mut raw = Cursor::new(vec![
            0x81u8, 0x85, 0x37, 0xfa, 0x21, 0x3d, 0x7f, 0x9f, 0x4d, 0x51, 0x58,
        ]);
        let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap();
        assert_eq!(length, 5);
        assert_eq!(header.opcode, OpCode::Data(Data::Text));
        assert_eq!(header.mask, Some([0x37, 0xfa, 0x21, 0x3d]));
        let mut payload = Vec::new();
        raw.read_to_end(&mut payload).unwrap();
        let mut frame = Frame::from_payload(header, payload);
        assert!(frame.is_masked());
        frame.apply_mask();
        assert!(!frame.is_masked());
        assert_eq!(frame.into_string().unwrap(), "Hello");
    }

    #[test]
    fn parse_rejects_reserved_opcode() {
        // Opcode 0x3 is a reserved data opcode.
        let mut raw = Cursor::new(vec![0x83u8, 0x00]);
        assert!(FrameHeader::parse(&mut raw).is_err());
    }

    #[test]
    fn header_len_boundaries() {
        let mut header = FrameHeader::default();
        assert_eq!(header.len(125), 2);
        assert_eq!(header.len(126), 4);
        assert_eq!(header.len(65_535), 4);
        assert_eq!(header.len(65_536), 10);
        header.mask = Some([1, 2, 3, 4]);
        assert_eq!(header.len(0), 6);
    }

    #[test]
    fn format() {
        let frame = Frame::ping(vec![0x01, 0x02]);
        let mut buf = Vec::with_capacity(frame.len());
        frame.format(&mut buf).unwrap();
        assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]);
    }

    #[test]
    fn format_extended_16bit_length() {
        let frame = Frame::message(vec![0xAB; 200], OpCode::Data(Data::Binary), true);
        let expected_len = frame.len();
        assert_eq!(expected_len, 204);
        let mut buf = Vec::new();
        frame.format(&mut buf).unwrap();
        assert_eq!(buf.len(), expected_len);
        assert_eq!(buf[..4], [0x82, 0x7E, 0x00, 0xC8]);
        assert!(buf[4..].iter().all(|&b| b == 0xAB));
    }

    #[test]
    fn close_round_trip() {
        assert_eq!(Frame::close(None).into_close().unwrap(), None);

        let frame = Frame::close(Some(CloseFrame { code: CloseCode::Normal, reason: "bye".into() }));
        assert_eq!(frame.payload()[..], [0x03, 0xE8, b'b', b'y', b'e']);
        assert_eq!(
            frame.into_close().unwrap(),
            Some(CloseFrame { code: CloseCode::Normal, reason: "bye".into() })
        );

        // A lone byte cannot hold a status code.
        let truncated = Frame::from_payload(FrameHeader::default(), vec![0x03]);
        assert!(truncated.into_close().is_err());
    }

    #[test]
    fn display() {
        let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true);
        let view = format!("{}", f);
        assert!(view.contains("payload:"));
    }
}
487