1 //! Server handshake machine.
2
3 use std::{
4 io::{self, Read, Write},
5 marker::PhantomData,
6 result::Result as StdResult,
7 };
8
9 use http::{
10 response::Builder, HeaderMap, Request as HttpRequest, Response as HttpResponse, StatusCode,
11 };
12 use httparse::Status;
13 use log::*;
14
15 use super::{
16 derive_accept_key,
17 headers::{FromHttparse, MAX_HEADERS},
18 machine::{HandshakeMachine, StageResult, TryParse},
19 HandshakeRole, MidHandshake, ProcessingResult,
20 };
21 use crate::{
22 error::{Error, ProtocolError, Result},
23 protocol::{Role, WebSocket, WebSocketConfig},
24 };
25
/// Server request type.
///
/// Only the request head (method, URI, version, headers) is kept; the body type is `()`.
pub type Request = HttpRequest<()>;

/// Server response type.
///
/// Used for the successful handshake reply; it carries no body.
pub type Response = HttpResponse<()>;

/// Server error response type.
///
/// Returned by a [`Callback`] to reject a handshake. When present, the `String` body is sent to
/// the client right after the response head.
pub type ErrorResponse = HttpResponse<Option<String>>;
34
create_parts<T>(request: &HttpRequest<T>) -> Result<Builder>35 fn create_parts<T>(request: &HttpRequest<T>) -> Result<Builder> {
36 if request.method() != http::Method::GET {
37 return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
38 }
39
40 if request.version() < http::Version::HTTP_11 {
41 return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
42 }
43
44 if !request
45 .headers()
46 .get("Connection")
47 .and_then(|h| h.to_str().ok())
48 .map(|h| h.split(|c| c == ' ' || c == ',').any(|p| p.eq_ignore_ascii_case("Upgrade")))
49 .unwrap_or(false)
50 {
51 return Err(Error::Protocol(ProtocolError::MissingConnectionUpgradeHeader));
52 }
53
54 if !request
55 .headers()
56 .get("Upgrade")
57 .and_then(|h| h.to_str().ok())
58 .map(|h| h.eq_ignore_ascii_case("websocket"))
59 .unwrap_or(false)
60 {
61 return Err(Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader));
62 }
63
64 if !request.headers().get("Sec-WebSocket-Version").map(|h| h == "13").unwrap_or(false) {
65 return Err(Error::Protocol(ProtocolError::MissingSecWebSocketVersionHeader));
66 }
67
68 let key = request
69 .headers()
70 .get("Sec-WebSocket-Key")
71 .ok_or(Error::Protocol(ProtocolError::MissingSecWebSocketKey))?;
72
73 let builder = Response::builder()
74 .status(StatusCode::SWITCHING_PROTOCOLS)
75 .version(request.version())
76 .header("Connection", "Upgrade")
77 .header("Upgrade", "websocket")
78 .header("Sec-WebSocket-Accept", derive_accept_key(key.as_bytes()));
79
80 Ok(builder)
81 }
82
83 /// Create a response for the request.
create_response(request: &Request) -> Result<Response>84 pub fn create_response(request: &Request) -> Result<Response> {
85 Ok(create_parts(request)?.body(())?)
86 }
87
88 /// Create a response for the request with a custom body.
create_response_with_body<T>( request: &HttpRequest<T>, generate_body: impl FnOnce() -> T, ) -> Result<HttpResponse<T>>89 pub fn create_response_with_body<T>(
90 request: &HttpRequest<T>,
91 generate_body: impl FnOnce() -> T,
92 ) -> Result<HttpResponse<T>> {
93 Ok(create_parts(request)?.body(generate_body())?)
94 }
95
96 /// Write `response` to the stream `w`.
write_response<T>(mut w: impl io::Write, response: &HttpResponse<T>) -> Result<()>97 pub fn write_response<T>(mut w: impl io::Write, response: &HttpResponse<T>) -> Result<()> {
98 writeln!(
99 w,
100 "{version:?} {status}\r",
101 version = response.version(),
102 status = response.status()
103 )?;
104
105 for (k, v) in response.headers() {
106 writeln!(w, "{}: {}\r", k, v.to_str()?)?;
107 }
108
109 writeln!(w, "\r")?;
110
111 Ok(())
112 }
113
114 impl TryParse for Request {
try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>>115 fn try_parse(buf: &[u8]) -> Result<Option<(usize, Self)>> {
116 let mut hbuffer = [httparse::EMPTY_HEADER; MAX_HEADERS];
117 let mut req = httparse::Request::new(&mut hbuffer);
118 Ok(match req.parse(buf)? {
119 Status::Partial => None,
120 Status::Complete(size) => Some((size, Request::from_httparse(req)?)),
121 })
122 }
123 }
124
125 impl<'h, 'b: 'h> FromHttparse<httparse::Request<'h, 'b>> for Request {
from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self>126 fn from_httparse(raw: httparse::Request<'h, 'b>) -> Result<Self> {
127 if raw.method.expect("Bug: no method in header") != "GET" {
128 return Err(Error::Protocol(ProtocolError::WrongHttpMethod));
129 }
130
131 if raw.version.expect("Bug: no HTTP version") < /*1.*/1 {
132 return Err(Error::Protocol(ProtocolError::WrongHttpVersion));
133 }
134
135 let headers = HeaderMap::from_httparse(raw.headers)?;
136
137 let mut request = Request::new(());
138 *request.method_mut() = http::Method::GET;
139 *request.headers_mut() = headers;
140 *request.uri_mut() = raw.path.expect("Bug: no path in header").parse()?;
141 // TODO: httparse only supports HTTP 0.9/1.0/1.1 but not HTTP 2.0
142 // so the only valid value we could get in the response would be 1.1.
143 *request.version_mut() = http::Version::HTTP_11;
144
145 Ok(request)
146 }
147 }
148
/// The callback trait.
///
/// The callback is called when the server receives an incoming WebSocket handshake request from
/// the client. With a callback you can analyze the incoming headers, add extra headers to the
/// response that the server sends to the client, and/or reject the connection based on the
/// incoming headers.
pub trait Callback: Sized {
    /// Called once the server has read the request from the client and is ready to reply to it.
    ///
    /// `response` is the reply the server proposes to send. Return it, possibly with extra
    /// headers, to accept the connection.
    ///
    /// Returning an error rejects the incoming connection. The [`ErrorResponse`] is sent to the
    /// client, and the handshake then fails. Its status must not be a 2xx success code;
    /// otherwise the handshake fails immediately with `CustomResponseSuccessful` and nothing is
    /// sent.
    fn on_request(
        self,
        request: &Request,
        response: Response,
    ) -> StdResult<Response, ErrorResponse>;
}
165
166 impl<F> Callback for F
167 where
168 F: FnOnce(&Request, Response) -> StdResult<Response, ErrorResponse>,
169 {
on_request( self, request: &Request, response: Response, ) -> StdResult<Response, ErrorResponse>170 fn on_request(
171 self,
172 request: &Request,
173 response: Response,
174 ) -> StdResult<Response, ErrorResponse> {
175 self(request, response)
176 }
177 }
178
179 /// Stub for callback that does nothing.
180 #[derive(Clone, Copy, Debug)]
181 pub struct NoCallback;
182
183 impl Callback for NoCallback {
on_request( self, _request: &Request, response: Response, ) -> StdResult<Response, ErrorResponse>184 fn on_request(
185 self,
186 _request: &Request,
187 response: Response,
188 ) -> StdResult<Response, ErrorResponse> {
189 Ok(response)
190 }
191 }
192
/// Server handshake role.
///
/// Holds the server-side state of a handshake. After the client's request has been read, it
/// writes a reply and then yields either a [`WebSocket`] or the error response that was sent.
// NOTE(review): the reason for allowing `missing_copy_implementations` is not stated here;
// presumably the type is intentionally not `Copy`, because it owns a consume-once callback.
#[allow(missing_copy_implementations)]
#[derive(Debug)]
pub struct ServerHandshake<S, C> {
    /// Callback invoked once the server has read the client's request and is ready to reply to
    /// it. It may amend the response (e.g. add headers) or reject the connection. It is held in
    /// an `Option` because it is taken out and consumed on first use.
    callback: Option<C>,
    /// WebSocket configuration, handed to the resulting [`WebSocket`].
    config: Option<WebSocketConfig>,
    /// Rejection chosen by the callback. If set, the handshake returns it as an error after it
    /// has been sent to the client.
    error_response: Option<ErrorResponse>,
    /// Marker for the internal stream type.
    _marker: PhantomData<S>,
}
208
209 impl<S: Read + Write, C: Callback> ServerHandshake<S, C> {
210 /// Start server handshake. `callback` specifies a custom callback which the user can pass to
211 /// the handshake, this callback will be called when the a websocket client connects to the
212 /// server, you can specify the callback if you want to add additional header to the client
213 /// upon join based on the incoming headers.
start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self>214 pub fn start(stream: S, callback: C, config: Option<WebSocketConfig>) -> MidHandshake<Self> {
215 trace!("Server handshake initiated.");
216 MidHandshake {
217 machine: HandshakeMachine::start_read(stream),
218 role: ServerHandshake {
219 callback: Some(callback),
220 config,
221 error_response: None,
222 _marker: PhantomData,
223 },
224 }
225 }
226 }
227
228 impl<S: Read + Write, C: Callback> HandshakeRole for ServerHandshake<S, C> {
229 type IncomingData = Request;
230 type InternalStream = S;
231 type FinalResult = WebSocket<S>;
232
stage_finished( &mut self, finish: StageResult<Self::IncomingData, Self::InternalStream>, ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>>233 fn stage_finished(
234 &mut self,
235 finish: StageResult<Self::IncomingData, Self::InternalStream>,
236 ) -> Result<ProcessingResult<Self::InternalStream, Self::FinalResult>> {
237 Ok(match finish {
238 StageResult::DoneReading { stream, result, tail } => {
239 if !tail.is_empty() {
240 return Err(Error::Protocol(ProtocolError::JunkAfterRequest));
241 }
242
243 let response = create_response(&result)?;
244 let callback_result = if let Some(callback) = self.callback.take() {
245 callback.on_request(&result, response)
246 } else {
247 Ok(response)
248 };
249
250 match callback_result {
251 Ok(response) => {
252 let mut output = vec![];
253 write_response(&mut output, &response)?;
254 ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
255 }
256
257 Err(resp) => {
258 if resp.status().is_success() {
259 return Err(Error::Protocol(ProtocolError::CustomResponseSuccessful));
260 }
261
262 self.error_response = Some(resp);
263 let resp = self.error_response.as_ref().unwrap();
264
265 let mut output = vec![];
266 write_response(&mut output, resp)?;
267
268 if let Some(body) = resp.body() {
269 output.extend_from_slice(body.as_bytes());
270 }
271
272 ProcessingResult::Continue(HandshakeMachine::start_write(stream, output))
273 }
274 }
275 }
276
277 StageResult::DoneWriting(stream) => {
278 if let Some(err) = self.error_response.take() {
279 debug!("Server handshake failed.");
280
281 let (parts, body) = err.into_parts();
282 let body = body.map(|b| b.as_bytes().to_vec());
283 return Err(Error::Http(http::Response::from_parts(parts, body)));
284 } else {
285 debug!("Server handshake done.");
286 let websocket = WebSocket::from_raw_socket(stream, Role::Server, self.config);
287 ProcessingResult::Done(websocket)
288 }
289 }
290 })
291 }
292 }
293
#[cfg(test)]
mod tests {
    use super::{super::machine::TryParse, create_response, Request};
    use crate::error::{Error, ProtocolError};

    #[test]
    fn request_parsing() {
        const DATA: &[u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";
        let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
        assert_eq!(req.uri().path(), "/script.ws");
        assert_eq!(req.headers().get("Host").unwrap(), &b"foo.com"[..]);
    }

    #[test]
    fn request_replying() {
        const DATA: &[u8] = b"\
            GET /script.ws HTTP/1.1\r\n\
            Host: foo.com\r\n\
            Connection: upgrade\r\n\
            Upgrade: websocket\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
            \r\n";
        let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
        let response = create_response(&req).unwrap();

        assert_eq!(
            response.headers().get("Sec-WebSocket-Accept").unwrap(),
            b"s3pPLMBiTxaQ9kYGzzhZRbK+xOo=".as_ref()
        );
    }

    /// An incomplete request head is reported as "need more data", not as an error.
    #[test]
    fn partial_request_needs_more_data() {
        const DATA: &[u8] = b"GET /script.ws HTTP/1.1\r\nHost: foo";
        assert!(Request::try_parse(DATA).unwrap().is_none());
    }

    /// Only `GET` may open a WebSocket.
    #[test]
    fn rejects_non_get_method() {
        const DATA: &[u8] = b"POST /script.ws HTTP/1.1\r\nHost: foo.com\r\n\r\n";
        let err = Request::try_parse(DATA).unwrap_err();
        assert!(matches!(err, Error::Protocol(ProtocolError::WrongHttpMethod)));
    }

    /// HTTP/1.0 requests are rejected while parsing.
    #[test]
    fn rejects_http_1_0() {
        const DATA: &[u8] = b"GET /script.ws HTTP/1.0\r\nHost: foo.com\r\n\r\n";
        let err = Request::try_parse(DATA).unwrap_err();
        assert!(matches!(err, Error::Protocol(ProtocolError::WrongHttpVersion)));
    }

    /// A request that is otherwise valid but lacks `Upgrade: websocket` gets no 101 reply.
    #[test]
    fn rejects_missing_upgrade_header() {
        const DATA: &[u8] = b"\
            GET /script.ws HTTP/1.1\r\n\
            Host: foo.com\r\n\
            Connection: upgrade\r\n\
            Sec-WebSocket-Version: 13\r\n\
            Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
            \r\n";
        let (_, req) = Request::try_parse(DATA).unwrap().unwrap();
        let err = create_response(&req).unwrap_err();
        assert!(matches!(err, Error::Protocol(ProtocolError::MissingUpgradeWebSocketHeader)));
    }
}
325