1 //! Methods to connect to a WebSocket as a client.
2 
3 use std::{
4     io::{Read, Write},
5     net::{SocketAddr, TcpStream, ToSocketAddrs},
6     result::Result as StdResult,
7 };
8 
9 use http::{request::Parts, Uri};
10 use log::*;
11 
12 use url::Url;
13 
14 use crate::{
15     handshake::client::{generate_key, Request, Response},
16     protocol::WebSocketConfig,
17     stream::MaybeTlsStream,
18 };
19 
20 use crate::{
21     error::{Error, Result, UrlError},
22     handshake::{client::ClientHandshake, HandshakeError},
23     protocol::WebSocket,
24     stream::{Mode, NoDelay},
25 };
26 
27 /// Connect to the given WebSocket in blocking mode.
28 ///
29 /// Uses a websocket configuration passed as an argument to the function. Calling it with `None` is
30 /// equal to calling `connect()` function.
31 ///
32 /// The URL may be either ws:// or wss://.
33 /// To support wss:// URLs, you must activate the TLS feature on the crate level. Please refer to the
34 /// project's [README][readme] for more information on available features.
35 ///
36 /// This function "just works" for those who wants a simple blocking solution
37 /// similar to `std::net::TcpStream`. If you want a non-blocking or other
38 /// custom stream, call `client` instead.
39 ///
40 /// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
41 /// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
42 /// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
43 ///
44 /// [readme]: https://github.com/snapview/tungstenite-rs/#features
connect_with_config<Req: IntoClientRequest>( request: Req, config: Option<WebSocketConfig>, max_redirects: u8, ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)>45 pub fn connect_with_config<Req: IntoClientRequest>(
46     request: Req,
47     config: Option<WebSocketConfig>,
48     max_redirects: u8,
49 ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
50     fn try_client_handshake(
51         request: Request,
52         config: Option<WebSocketConfig>,
53     ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
54         let uri = request.uri();
55         let mode = uri_mode(uri)?;
56         let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
57         let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
58         let port = uri.port_u16().unwrap_or(match mode {
59             Mode::Plain => 80,
60             Mode::Tls => 443,
61         });
62         let addrs = (host, port).to_socket_addrs()?;
63         let mut stream = connect_to_some(addrs.as_slice(), request.uri())?;
64         NoDelay::set_nodelay(&mut stream, true)?;
65 
66         #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
67         let client = client_with_config(request, MaybeTlsStream::Plain(stream), config);
68         #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
69         let client = crate::tls::client_tls_with_config(request, stream, config, None);
70 
71         client.map_err(|e| match e {
72             HandshakeError::Failure(f) => f,
73             HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
74         })
75     }
76 
77     fn create_request(parts: &Parts, uri: &Uri) -> Request {
78         let mut builder =
79             Request::builder().uri(uri.clone()).method(parts.method.clone()).version(parts.version);
80         *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
81         builder.body(()).expect("Failed to create `Request`")
82     }
83 
84     let (parts, _) = request.into_client_request()?.into_parts();
85     let mut uri = parts.uri.clone();
86 
87     for attempt in 0..(max_redirects + 1) {
88         let request = create_request(&parts, &uri);
89 
90         match try_client_handshake(request, config) {
91             Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
92                 if let Some(location) = res.headers().get("Location") {
93                     uri = location.to_str()?.parse::<Uri>()?;
94                     debug!("Redirecting to {:?}", uri);
95                     continue;
96                 } else {
97                     warn!("No `Location` found in redirect");
98                     return Err(Error::Http(res));
99                 }
100             }
101             other => return other,
102         }
103     }
104 
105     unreachable!("Bug in a redirect handling logic")
106 }
107 
108 /// Connect to the given WebSocket in blocking mode.
109 ///
110 /// The URL may be either ws:// or wss://.
111 /// To support wss:// URLs, feature `native-tls` or `rustls-tls` must be turned on.
112 ///
113 /// This function "just works" for those who wants a simple blocking solution
114 /// similar to `std::net::TcpStream`. If you want a non-blocking or other
115 /// custom stream, call `client` instead.
116 ///
117 /// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
118 /// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
119 /// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
connect<Req: IntoClientRequest>( request: Req, ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)>120 pub fn connect<Req: IntoClientRequest>(
121     request: Req,
122 ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
123     connect_with_config(request, None, 3)
124 }
125 
connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream>126 fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
127     for addr in addrs {
128         debug!("Trying to contact {} at {}...", uri, addr);
129         if let Ok(stream) = TcpStream::connect(addr) {
130             return Ok(stream);
131         }
132     }
133     Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
134 }
135 
136 /// Get the mode of the given URL.
137 ///
138 /// This function may be used to ease the creation of custom TLS streams
139 /// in non-blocking algorithms or for use with TLS libraries other than `native_tls` or `rustls`.
uri_mode(uri: &Uri) -> Result<Mode>140 pub fn uri_mode(uri: &Uri) -> Result<Mode> {
141     match uri.scheme_str() {
142         Some("ws") => Ok(Mode::Plain),
143         Some("wss") => Ok(Mode::Tls),
144         _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)),
145     }
146 }
147 
148 /// Do the client handshake over the given stream given a web socket configuration. Passing `None`
149 /// as configuration is equal to calling `client()` function.
150 ///
151 /// Use this function if you need a nonblocking handshake support or if you
152 /// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
153 /// Any stream supporting `Read + Write` will do.
client_with_config<Stream, Req>( request: Req, stream: Stream, config: Option<WebSocketConfig>, ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> where Stream: Read + Write, Req: IntoClientRequest,154 pub fn client_with_config<Stream, Req>(
155     request: Req,
156     stream: Stream,
157     config: Option<WebSocketConfig>,
158 ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
159 where
160     Stream: Read + Write,
161     Req: IntoClientRequest,
162 {
163     ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
164 }
165 
166 /// Do the client handshake over the given stream.
167 ///
168 /// Use this function if you need a nonblocking handshake support or if you
169 /// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
170 /// Any stream supporting `Read + Write` will do.
client<Stream, Req>( request: Req, stream: Stream, ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> where Stream: Read + Write, Req: IntoClientRequest,171 pub fn client<Stream, Req>(
172     request: Req,
173     stream: Stream,
174 ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
175 where
176     Stream: Read + Write,
177     Req: IntoClientRequest,
178 {
179     client_with_config(request, stream, None)
180 }
181 
182 /// Trait for converting various types into HTTP requests used for a client connection.
183 ///
184 /// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
185 /// `http::Request<()>`. Note that the implementation for `http::Request<()>` is trivial and will
186 /// simply take your request and pass it as is further without altering any headers or URLs, so
187 /// be aware of this. If you just want to connect to the endpoint with a certain URL, better pass
188 /// a regular string containing the URL in which case `tungstenite-rs` will take care for generating
189 /// the proper `http::Request<()>` for you.
190 pub trait IntoClientRequest {
191     /// Convert into a `Request` that can be used for a client connection.
into_client_request(self) -> Result<Request>192     fn into_client_request(self) -> Result<Request>;
193 }
194 
195 impl<'a> IntoClientRequest for &'a str {
into_client_request(self) -> Result<Request>196     fn into_client_request(self) -> Result<Request> {
197         self.parse::<Uri>()?.into_client_request()
198     }
199 }
200 
201 impl<'a> IntoClientRequest for &'a String {
into_client_request(self) -> Result<Request>202     fn into_client_request(self) -> Result<Request> {
203         <&str as IntoClientRequest>::into_client_request(self)
204     }
205 }
206 
207 impl IntoClientRequest for String {
into_client_request(self) -> Result<Request>208     fn into_client_request(self) -> Result<Request> {
209         <&str as IntoClientRequest>::into_client_request(&self)
210     }
211 }
212 
213 impl<'a> IntoClientRequest for &'a Uri {
into_client_request(self) -> Result<Request>214     fn into_client_request(self) -> Result<Request> {
215         self.clone().into_client_request()
216     }
217 }
218 
219 impl IntoClientRequest for Uri {
into_client_request(self) -> Result<Request>220     fn into_client_request(self) -> Result<Request> {
221         let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
222         let host = authority
223             .find('@')
224             .map(|idx| authority.split_at(idx + 1).1)
225             .unwrap_or_else(|| authority);
226 
227         if host.is_empty() {
228             return Err(Error::Url(UrlError::EmptyHostName));
229         }
230 
231         let req = Request::builder()
232             .method("GET")
233             .header("Host", host)
234             .header("Connection", "Upgrade")
235             .header("Upgrade", "websocket")
236             .header("Sec-WebSocket-Version", "13")
237             .header("Sec-WebSocket-Key", generate_key())
238             .uri(self)
239             .body(())?;
240         Ok(req)
241     }
242 }
243 
244 impl<'a> IntoClientRequest for &'a Url {
into_client_request(self) -> Result<Request>245     fn into_client_request(self) -> Result<Request> {
246         self.as_str().into_client_request()
247     }
248 }
249 
250 impl IntoClientRequest for Url {
into_client_request(self) -> Result<Request>251     fn into_client_request(self) -> Result<Request> {
252         self.as_str().into_client_request()
253     }
254 }
255 
256 impl IntoClientRequest for Request {
into_client_request(self) -> Result<Request>257     fn into_client_request(self) -> Result<Request> {
258         Ok(self)
259     }
260 }
261 
262 impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
into_client_request(self) -> Result<Request>263     fn into_client_request(self) -> Result<Request> {
264         use crate::handshake::headers::FromHttparse;
265         Request::from_httparse(self)
266     }
267 }
268