1 use byteorder::{NetworkEndian, ReadBytesExt}; 2 use log::*; 3 use std::{ 4 borrow::Cow, 5 default::Default, 6 fmt, 7 io::{Cursor, ErrorKind, Read, Write}, 8 result::Result as StdResult, 9 str::Utf8Error, 10 string::{FromUtf8Error, String}, 11 }; 12 13 use super::{ 14 coding::{CloseCode, Control, Data, OpCode}, 15 mask::{apply_mask, generate_mask}, 16 }; 17 use crate::error::{Error, ProtocolError, Result}; 18 19 /// A struct representing the close command. 20 #[derive(Debug, Clone, Eq, PartialEq)] 21 pub struct CloseFrame<'t> { 22 /// The reason as a code. 23 pub code: CloseCode, 24 /// The reason as text string. 25 pub reason: Cow<'t, str>, 26 } 27 28 impl<'t> CloseFrame<'t> { 29 /// Convert into a owned string. into_owned(self) -> CloseFrame<'static>30 pub fn into_owned(self) -> CloseFrame<'static> { 31 CloseFrame { code: self.code, reason: self.reason.into_owned().into() } 32 } 33 } 34 35 impl<'t> fmt::Display for CloseFrame<'t> { fmt(&self, f: &mut fmt::Formatter) -> fmt::Result36 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 37 write!(f, "{} ({})", self.reason, self.code) 38 } 39 } 40 41 /// A struct representing a WebSocket frame header. 42 #[allow(missing_copy_implementations)] 43 #[derive(Debug, Clone, Eq, PartialEq)] 44 pub struct FrameHeader { 45 /// Indicates that the frame is the last one of a possibly fragmented message. 46 pub is_final: bool, 47 /// Reserved for protocol extensions. 48 pub rsv1: bool, 49 /// Reserved for protocol extensions. 50 pub rsv2: bool, 51 /// Reserved for protocol extensions. 52 pub rsv3: bool, 53 /// WebSocket protocol opcode. 54 pub opcode: OpCode, 55 /// A frame mask, if any. 56 pub mask: Option<[u8; 4]>, 57 } 58 59 impl Default for FrameHeader { default() -> Self60 fn default() -> Self { 61 FrameHeader { 62 is_final: true, 63 rsv1: false, 64 rsv2: false, 65 rsv3: false, 66 opcode: OpCode::Control(Control::Close), 67 mask: None, 68 } 69 } 70 } 71 72 impl FrameHeader { 73 /// Parse a header from an input stream. 74 /// Returns `None` if insufficient data and does not consume anything in this case. 75 /// Payload size is returned along with the header. parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>>76 pub fn parse(cursor: &mut Cursor<impl AsRef<[u8]>>) -> Result<Option<(Self, u64)>> { 77 let initial = cursor.position(); 78 match Self::parse_internal(cursor) { 79 ret @ Ok(None) => { 80 cursor.set_position(initial); 81 ret 82 } 83 ret => ret, 84 } 85 } 86 87 /// Get the size of the header formatted with given payload length. 88 #[allow(clippy::len_without_is_empty)] len(&self, length: u64) -> usize89 pub fn len(&self, length: u64) -> usize { 90 2 + LengthFormat::for_length(length).extra_bytes() + if self.mask.is_some() { 4 } else { 0 } 91 } 92 93 /// Format a header for given payload size. format(&self, length: u64, output: &mut impl Write) -> Result<()>94 pub fn format(&self, length: u64, output: &mut impl Write) -> Result<()> { 95 let code: u8 = self.opcode.into(); 96 97 let one = { 98 code | if self.is_final { 0x80 } else { 0 } 99 | if self.rsv1 { 0x40 } else { 0 } 100 | if self.rsv2 { 0x20 } else { 0 } 101 | if self.rsv3 { 0x10 } else { 0 } 102 }; 103 104 let lenfmt = LengthFormat::for_length(length); 105 106 let two = { lenfmt.length_byte() | if self.mask.is_some() { 0x80 } else { 0 } }; 107 108 output.write_all(&[one, two])?; 109 match lenfmt { 110 LengthFormat::U8(_) => (), 111 LengthFormat::U16 => { 112 output.write_all(&(length as u16).to_be_bytes())?; 113 } 114 LengthFormat::U64 => { 115 output.write_all(&length.to_be_bytes())?; 116 } 117 } 118 119 if let Some(ref mask) = self.mask { 120 output.write_all(mask)? 121 } 122 123 Ok(()) 124 } 125 126 /// Generate a random frame mask and store this in the header. 127 /// 128 /// Of course this does not change frame contents. It just generates a mask. set_random_mask(&mut self)129 pub(crate) fn set_random_mask(&mut self) { 130 self.mask = Some(generate_mask()) 131 } 132 } 133 134 impl FrameHeader { 135 /// Internal parse engine. 136 /// Returns `None` if insufficient data. 137 /// Payload size is returned along with the header. parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>>138 fn parse_internal(cursor: &mut impl Read) -> Result<Option<(Self, u64)>> { 139 let (first, second) = { 140 let mut head = [0u8; 2]; 141 if cursor.read(&mut head)? != 2 { 142 return Ok(None); 143 } 144 trace!("Parsed headers {:?}", head); 145 (head[0], head[1]) 146 }; 147 148 trace!("First: {:b}", first); 149 trace!("Second: {:b}", second); 150 151 let is_final = first & 0x80 != 0; 152 153 let rsv1 = first & 0x40 != 0; 154 let rsv2 = first & 0x20 != 0; 155 let rsv3 = first & 0x10 != 0; 156 157 let opcode = OpCode::from(first & 0x0F); 158 trace!("Opcode: {:?}", opcode); 159 160 let masked = second & 0x80 != 0; 161 trace!("Masked: {:?}", masked); 162 163 let length = { 164 let length_byte = second & 0x7F; 165 let length_length = LengthFormat::for_byte(length_byte).extra_bytes(); 166 if length_length > 0 { 167 match cursor.read_uint::<NetworkEndian>(length_length) { 168 Err(ref err) if err.kind() == ErrorKind::UnexpectedEof => { 169 return Ok(None); 170 } 171 Err(err) => { 172 return Err(err.into()); 173 } 174 Ok(read) => read, 175 } 176 } else { 177 u64::from(length_byte) 178 } 179 }; 180 181 let mask = if masked { 182 let mut mask_bytes = [0u8; 4]; 183 if cursor.read(&mut mask_bytes)? != 4 { 184 return Ok(None); 185 } else { 186 Some(mask_bytes) 187 } 188 } else { 189 None 190 }; 191 192 // Disallow bad opcode 193 match opcode { 194 OpCode::Control(Control::Reserved(_)) | OpCode::Data(Data::Reserved(_)) => { 195 return Err(Error::Protocol(ProtocolError::InvalidOpcode(first & 0x0F))) 196 } 197 _ => (), 198 } 199 200 let hdr = FrameHeader { is_final, rsv1, rsv2, rsv3, opcode, mask }; 201 202 Ok(Some((hdr, length))) 203 } 204 } 205 206 /// A struct representing a WebSocket frame. 207 #[derive(Debug, Clone, Eq, PartialEq)] 208 pub struct Frame { 209 header: FrameHeader, 210 payload: Vec<u8>, 211 } 212 213 impl Frame { 214 /// Get the length of the frame. 215 /// This is the length of the header + the length of the payload. 216 #[inline] len(&self) -> usize217 pub fn len(&self) -> usize { 218 let length = self.payload.len(); 219 self.header.len(length as u64) + length 220 } 221 222 /// Check if the frame is empty. 223 #[inline] is_empty(&self) -> bool224 pub fn is_empty(&self) -> bool { 225 self.len() == 0 226 } 227 228 /// Get a reference to the frame's header. 229 #[inline] header(&self) -> &FrameHeader230 pub fn header(&self) -> &FrameHeader { 231 &self.header 232 } 233 234 /// Get a mutable reference to the frame's header. 235 #[inline] header_mut(&mut self) -> &mut FrameHeader236 pub fn header_mut(&mut self) -> &mut FrameHeader { 237 &mut self.header 238 } 239 240 /// Get a reference to the frame's payload. 241 #[inline] payload(&self) -> &Vec<u8>242 pub fn payload(&self) -> &Vec<u8> { 243 &self.payload 244 } 245 246 /// Get a mutable reference to the frame's payload. 247 #[inline] payload_mut(&mut self) -> &mut Vec<u8>248 pub fn payload_mut(&mut self) -> &mut Vec<u8> { 249 &mut self.payload 250 } 251 252 /// Test whether the frame is masked. 253 #[inline] is_masked(&self) -> bool254 pub(crate) fn is_masked(&self) -> bool { 255 self.header.mask.is_some() 256 } 257 258 /// Generate a random mask for the frame. 259 /// 260 /// This just generates a mask, payload is not changed. The actual masking is performed 261 /// either on `format()` or on `apply_mask()` call. 262 #[inline] set_random_mask(&mut self)263 pub(crate) fn set_random_mask(&mut self) { 264 self.header.set_random_mask() 265 } 266 267 /// This method unmasks the payload and should only be called on frames that are actually 268 /// masked. In other words, those frames that have just been received from a client endpoint. 269 #[inline] apply_mask(&mut self)270 pub(crate) fn apply_mask(&mut self) { 271 if let Some(mask) = self.header.mask.take() { 272 apply_mask(&mut self.payload, mask) 273 } 274 } 275 276 /// Consume the frame into its payload as binary. 277 #[inline] into_data(self) -> Vec<u8>278 pub fn into_data(self) -> Vec<u8> { 279 self.payload 280 } 281 282 /// Consume the frame into its payload as string. 283 #[inline] into_string(self) -> StdResult<String, FromUtf8Error>284 pub fn into_string(self) -> StdResult<String, FromUtf8Error> { 285 String::from_utf8(self.payload) 286 } 287 288 /// Get frame payload as `&str`. 289 #[inline] to_text(&self) -> Result<&str, Utf8Error>290 pub fn to_text(&self) -> Result<&str, Utf8Error> { 291 std::str::from_utf8(&self.payload) 292 } 293 294 /// Consume the frame into a closing frame. 295 #[inline] into_close(self) -> Result<Option<CloseFrame<'static>>>296 pub(crate) fn into_close(self) -> Result<Option<CloseFrame<'static>>> { 297 match self.payload.len() { 298 0 => Ok(None), 299 1 => Err(Error::Protocol(ProtocolError::InvalidCloseSequence)), 300 _ => { 301 let mut data = self.payload; 302 let code = u16::from_be_bytes([data[0], data[1]]).into(); 303 data.drain(0..2); 304 let text = String::from_utf8(data)?; 305 Ok(Some(CloseFrame { code, reason: text.into() })) 306 } 307 } 308 } 309 310 /// Create a new data frame. 311 #[inline] message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame312 pub fn message(data: Vec<u8>, opcode: OpCode, is_final: bool) -> Frame { 313 debug_assert!(matches!(opcode, OpCode::Data(_)), "Invalid opcode for data frame."); 314 315 Frame { header: FrameHeader { is_final, opcode, ..FrameHeader::default() }, payload: data } 316 } 317 318 /// Create a new Pong control frame. 319 #[inline] pong(data: Vec<u8>) -> Frame320 pub fn pong(data: Vec<u8>) -> Frame { 321 Frame { 322 header: FrameHeader { 323 opcode: OpCode::Control(Control::Pong), 324 ..FrameHeader::default() 325 }, 326 payload: data, 327 } 328 } 329 330 /// Create a new Ping control frame. 331 #[inline] ping(data: Vec<u8>) -> Frame332 pub fn ping(data: Vec<u8>) -> Frame { 333 Frame { 334 header: FrameHeader { 335 opcode: OpCode::Control(Control::Ping), 336 ..FrameHeader::default() 337 }, 338 payload: data, 339 } 340 } 341 342 /// Create a new Close control frame. 343 #[inline] close(msg: Option<CloseFrame>) -> Frame344 pub fn close(msg: Option<CloseFrame>) -> Frame { 345 let payload = if let Some(CloseFrame { code, reason }) = msg { 346 let mut p = Vec::with_capacity(reason.as_bytes().len() + 2); 347 p.extend(u16::from(code).to_be_bytes()); 348 p.extend_from_slice(reason.as_bytes()); 349 p 350 } else { 351 Vec::new() 352 }; 353 354 Frame { header: FrameHeader::default(), payload } 355 } 356 357 /// Create a frame from given header and data. from_payload(header: FrameHeader, payload: Vec<u8>) -> Self358 pub fn from_payload(header: FrameHeader, payload: Vec<u8>) -> Self { 359 Frame { header, payload } 360 } 361 362 /// Write a frame out to a buffer format(mut self, output: &mut impl Write) -> Result<()>363 pub fn format(mut self, output: &mut impl Write) -> Result<()> { 364 self.header.format(self.payload.len() as u64, output)?; 365 self.apply_mask(); 366 output.write_all(self.payload())?; 367 Ok(()) 368 } 369 } 370 371 impl fmt::Display for Frame { fmt(&self, f: &mut fmt::Formatter) -> fmt::Result372 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { 373 use std::fmt::Write; 374 375 write!( 376 f, 377 " 378 <FRAME> 379 final: {} 380 reserved: {} {} {} 381 opcode: {} 382 length: {} 383 payload length: {} 384 payload: 0x{} 385 ", 386 self.header.is_final, 387 self.header.rsv1, 388 self.header.rsv2, 389 self.header.rsv3, 390 self.header.opcode, 391 // self.mask.map(|mask| format!("{:?}", mask)).unwrap_or("NONE".into()), 392 self.len(), 393 self.payload.len(), 394 self.payload.iter().fold(String::new(), |mut output, byte| { 395 _ = write!(output, "{byte:02x}"); 396 output 397 }) 398 ) 399 } 400 } 401 402 /// Handling of the length format. 403 enum LengthFormat { 404 U8(u8), 405 U16, 406 U64, 407 } 408 409 impl LengthFormat { 410 /// Get the length format for a given data size. 411 #[inline] for_length(length: u64) -> Self412 fn for_length(length: u64) -> Self { 413 if length < 126 { 414 LengthFormat::U8(length as u8) 415 } else if length < 65536 { 416 LengthFormat::U16 417 } else { 418 LengthFormat::U64 419 } 420 } 421 422 /// Get the size of the length encoding. 423 #[inline] extra_bytes(&self) -> usize424 fn extra_bytes(&self) -> usize { 425 match *self { 426 LengthFormat::U8(_) => 0, 427 LengthFormat::U16 => 2, 428 LengthFormat::U64 => 8, 429 } 430 } 431 432 /// Encode the given length. 433 #[inline] length_byte(&self) -> u8434 fn length_byte(&self) -> u8 { 435 match *self { 436 LengthFormat::U8(b) => b, 437 LengthFormat::U16 => 126, 438 LengthFormat::U64 => 127, 439 } 440 } 441 442 /// Get the length format for a given length byte. 443 #[inline] for_byte(byte: u8) -> Self444 fn for_byte(byte: u8) -> Self { 445 match byte & 0x7F { 446 126 => LengthFormat::U16, 447 127 => LengthFormat::U64, 448 b => LengthFormat::U8(b), 449 } 450 } 451 } 452 453 #[cfg(test)] 454 mod tests { 455 use super::*; 456 457 use super::super::coding::{Data, OpCode}; 458 use std::io::Cursor; 459 460 #[test] parse()461 fn parse() { 462 let mut raw: Cursor<Vec<u8>> = 463 Cursor::new(vec![0x82, 0x07, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); 464 let (header, length) = FrameHeader::parse(&mut raw).unwrap().unwrap(); 465 assert_eq!(length, 7); 466 let mut payload = Vec::new(); 467 raw.read_to_end(&mut payload).unwrap(); 468 let frame = Frame::from_payload(header, payload); 469 assert_eq!(frame.into_data(), vec![0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]); 470 } 471 472 #[test] format()473 fn format() { 474 let frame = Frame::ping(vec![0x01, 0x02]); 475 let mut buf = Vec::with_capacity(frame.len()); 476 frame.format(&mut buf).unwrap(); 477 assert_eq!(buf, vec![0x89, 0x02, 0x01, 0x02]); 478 } 479 480 #[test] display()481 fn display() { 482 let f = Frame::message("hi there".into(), OpCode::Data(Data::Text), true); 483 let view = format!("{}", f); 484 assert!(view.contains("payload:")); 485 } 486 } 487