1 //! Server handshake machine.
2 
3 use std::{
4     io::{self, Read, Write},
5     marker::PhantomData,
6     result::Result as StdResult,
7 };
8 
9 use http::{
10     response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
11 };
12 use httparse::Status;
13 use log::*;
14 
15 use super::{
16     derive_accept_key,
17     headers::{FromHttparse, MAX_HEADERS},
18     machine::{HandshakeMachine, StageResult, TryParse},
19     HandshakeRole, MidHandshake, ProcessingResult,
20 };
21 use crate::{
22     error::{Error, ProtocolError, Result},
23     protocol::{Role, WebSocket, WebSocketConfig},
24 };
25 
26 /// Server request type.
27 pub type Request = HttpRequest<()>;
28 
29 /// Server response type.
30 pub type Response = HttpResponse<()>;
31 
32 /// Server error response type.
33 pub type ErrorResponse = HttpResponse<Option<String>>;
34 
create_parts<T>(request: &HttpRequest<T>) -> Result<Builder>35 fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
36     if request.method() != http::Method::GET {
37         return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
38     }
39 
40     if request.version() < http::Version::HTTP_11 {
41         return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
42     }
43 
44     if !request
45         .headers()
46         .get("Connection")
47         .and_then(|h| h.to_str().ok())
48         .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
49         .unwrap_or(false)
50     {
51         return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
52     }
53 
54     if !request
55         .headers()
56         .get("Upgrade")
57         .and_then(|h| h.to_str().ok())
58         .map(|h| h.eq_ignore_ascii_case("websocket"))
59         .unwrap_or(false)
60     {
61         return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
62     }
63 
64     if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
65         return Err(Error::Protocol(ProtocolError::MissingSecWebSocketVersionHeader));
66     }
67 
68     let key = request
69         .headers()
70         .get("Sec-WebSocket-Key")
71         .ok_or(Error::Protocol(ProtocolError::MissingSecWebSocketKey))?;
72 
73     let builder = Response::builder()
74         .status(StatusCode::SWITCHING_PROTOCOLS)
75         .version(request.version())
76         .header("Connection", "Upgrade")
77         .header("Upgrade", "websocket")
78         .header("Sec-WebSocket-Accept", derive_accept_key(key.as_bytes()));
79 
80     Ok(builder)
81 }
82 
83 /// Create a response for the request.
create_response(request: &Request) -> Result<Response>84 pub fn create_response(request: &Request) -> Result<Response> {
85     Ok(create_parts(request)?.body(())?)
86 }
87 
88 /// Create a response for the request with a custom body.
create_response_with_body<T>( request: &HttpRequest<T>, generate_body: impl FnOnce() -> T, ) -> Result<HttpResponse<T>>89 pub fn create_response_with_body<T>(
90     request: &HttpRequest<T>,
91     generate_body: impl FnOnce() -> T,
92 ) -> Result<HttpResponse<T>> {
93     Ok(create_parts(request)?.body(generate_body())?)
94 }
95 
96 /// Write `response` to the stream `w`.
write_response<T>(mut w: impl io::Write, response: &HttpResponse<T>) -> Result<()>97 pub fn write_response<T>(mut w: impl io::Write, response: &HttpResponse<T>) -> Result<()> {
98     writeln!(
99         w,
100         "{version:?} {status}\r",
101         version = response.version(),
102         status = response.status()
103     )?;
104 
105     for (k, v) in response.headers() {
106         writeln!(w, "{}: {}\r", k, v.to_str()?)?;
107     }
108 
109     writeln!(w, "\r")?;
110 
111     Ok(())
112 }
113 
114 impl TryParse for Request {
try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>>115     fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
116         let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
117         let mut req = httparse::Request::new(&mut hbuffer);
118         Ok(match req.parse(buf)? {
119             Status::Partial => None,
120             Status::Complete(size) => Some((size, Request::from_httparse(req)?)),
121         })
122     }
123 }
124 
125 impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self>126     fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
127         if raw.method.expect("Bug: no method in header") != "GET" {
128             return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
129         }
130 
131         if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
132             return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
133         }
134 
135         let headers = HeaderMap::from_httparse(raw.headers)?;
136 
137         let mut request = Request::new(());
138         *request.method_mut() = http::Method::GET;
139         *request.headers_mut() = headers;
140         *request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?;
141         // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0
142         // so the only valid value we could get in the response would be 1.1.
143         *request.version_mut() = http::Version::HTTP_11;
144 
145         Ok(request)
146     }
147 }
148 
149 /// The callback trait.
150 ///
151 /// The callback is called when the server receives an incoming WebSocket
152 /// handshake request from the client. Specifying a callback allows you to analyze incoming headers
153 /// and add additional headers to the response that server sends to the client and/or reject the
154 /// connection based on the incoming headers.
155 pub trait Callback: Sized {
156     /// Called whenever the server read the request from the client and is ready to reply to it.
157     /// May return additional reply headers.
158     /// Returning an error resulting in rejecting the incoming connection.
on_request( self, request: &Request, response: Response, ) -> StdResult<Response, ErrorResponse>159     fn on_request(
160         self,
161         request: &Request,
162         response: Response,
163     ) -> StdResult<Response, ErrorResponse>;
164 }
165 
166 impl<F> Callback for F
167 where
168     F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>,
169 {
on_request( self, request: &Request, response: Response, ) -> StdResult<Response, ErrorResponse>170     fn on_request(
171         self,
172         request: &Request,
173         response: Response,
174     ) -> StdResult<Response, ErrorResponse> {
175         self(request, response)
176     }
177 }
178 
179 /// Stub for callback that does nothing.
180 #[derive(Clone, Copy, Debug)]
181 pub struct NoCallback;
182 
183 impl Callback for NoCallback {
on_request( self, _request: &Request, response: Response, ) -> StdResult<Response, ErrorResponse>184     fn on_request(
185         self,
186         _request: &Request,
187         response: Response,
188     ) -> StdResult<Response, ErrorResponse> {
189         Ok(response)
190     }
191 }
192 
193 /// Server handshake role.
194 #[allow(missing_copy_implementations)]
195 #[derive(Debug)]
196 pub struct ServerHandshake<S, C> {
197     /// Callback which is called whenever the server read the request from the client and is ready
198     /// to reply to it. The callback returns an optional headers which will be added to the reply
199     /// which the server sends to the user.
200     callback: Option<C>,
201     /// WebSocket configuration.
202     config: Option<WebSocketConfig>,
203     /// Error code/flag. If set, an error will be returned after sending response to the client.
204     error_response: Option<ErrorResponse>,
205     /// Internal stream type.
206     _marker: PhantomData<S>,
207 }
208 
209 impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
210     /// Start server handshake. `callback` specifies a custom callback which the user can pass to
211     /// the handshake, this callback will be called when the a websocket client connects to the
212     /// server, you can specify the callback if you want to add additional header to the client
213     /// upon join based on the incoming headers.
start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self>214     pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
215         trace!("Server handshake initiated.");
216         MidHandshake {
217             machine: HandshakeMachine::start_read(stream),
218             role: ServerHandshake {
219                 callback: Some(callback),
220                 config,
221                 error_response: None,
222                 _marker: PhantomData,
223             },
224         }
225     }
226 }
227 
228 impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
229     type IncomingData = Request;
230     type InternalStream = S;
231     type FinalResult = WebSocket<S>;
232 
stage_finished( &mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>, ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>233     fn stage_finished(
234         &mut self,
235         finish: StageResult<Self::IncomingData, Self::InternalStream>,
236     ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
237         Ok(match finish {
238             StageResult::DoneReading { stream, result, tail } => {
239                 if !tail.is_empty() {
240                     return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
241                 }
242 
243                 let response = create_response(&result)?;
244                 let callback_result = if let Some(callback) = self.callback.take() {
245                     callback.on_request(&result, response)
246                 } else {
247                     Ok(response)
248                 };
249 
250                 match callback_result {
251                     Ok(response) => {
252                         let mut output = vec![];
253                         write_response(&mut output, &response)?;
254                         ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
255                     }
256 
257                     Err(resp) => {
258                         if resp.status().is_success() {
259                             return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
260                         }
261 
262                         self.error_response = Some(resp);
263                         let resp = self.error_response.as_ref().unwrap();
264 
265                         let mut output = vec![];
266                         write_response(&mut output, resp)?;
267 
268                         if let Some(body) = resp.body() {
269                             output.extend_from_slice(body.as_bytes());
270                         }
271 
272                         ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
273                     }
274                 }
275             }
276 
277             StageResult::DoneWriting(stream) => {
278                 if let Some(err) = self.error_response.take() {
279                     debug!("Server handshake failed.");
280 
281                     let (parts, body) = err.into_parts();
282                     let body = body.map(|b| b.as_bytes().to_vec());
283                     return Err(Error::Http(http::Response::from_parts(parts, body)));
284                 } else {
285                     debug!("Server handshake done.");
286                     let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
287                     ProcessingResult::Done(websocket)
288                 }
289             }
290         })
291     }
292 }
293 
#[cfg(test)]
mod tests {
    use super::{super::machine::TryParse, create_response, Request};

    #[test]
    fn request_parsing() {
        const DATA: &[u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";
        let parsed = Request::try_parse(DATA).unwrap();
        let (_consumed, request) = parsed.unwrap();

        assert_eq!(request.uri().path(), "/script.ws");
        let host = request.headers().get("Host").unwrap();
        assert_eq!(host, &b"foo.com"[..]);
    }

    #[test]
    fn request_replying() {
        const DATA: &[u8] = b"\
            GET /script.ws HTTP/1.1\r\n\
            Host: foo.com\r\n\
            Connection: upgrade\r\n\
            Upgrade: websocket\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
            \r\n";
        let parsed = Request::try_parse(DATA).unwrap();
        let (_consumed, request) = parsed.unwrap();
        let reply = create_response(&request).unwrap();

        // Expected accept key for the sample nonce in RFC 6455, section 1.3.
        let accept = reply.headers().get("Sec-WebSocket-Accept").unwrap();
        assert_eq!(accept, &b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo="[..]);
    }
}
325