1 //! Methods to connect to a WebSocket as a client.
2
3 use std::{
4 io::{Read, Write},
5 net::{SocketAddr, TcpStream, ToSocketAddrs},
6 result::Result as StdResult,
7 };
8
9 use http::{request::Parts, Uri};
10 use log::*;
11
12 use url::Url;
13
14 use crate::{
15 handshake::client::{generate_key, Request, Response},
16 protocol::WebSocketConfig,
17 stream::MaybeTlsStream,
18 };
19
20 use crate::{
21 error::{Error, Result, UrlError},
22 handshake::{client::ClientHandshake, HandshakeError},
23 protocol::WebSocket,
24 stream::{Mode, NoDelay},
25 };
26
27 /// Connect to the given WebSocket in blocking mode.
28 ///
29 /// Uses a websocket configuration passed as an argument to the function. Calling it with `None` is
30 /// equal to calling `connect()` function.
31 ///
32 /// The URL may be either ws:// or wss://.
33 /// To support wss:// URLs, you must activate the TLS feature on the crate level. Please refer to the
34 /// project's [README][readme] for more information on available features.
35 ///
36 /// This function "just works" for those who wants a simple blocking solution
37 /// similar to `std::net::TcpStream`. If you want a non-blocking or other
38 /// custom stream, call `client` instead.
39 ///
40 /// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
41 /// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
42 /// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
43 ///
44 /// [readme]: https://github.com/snapview/tungstenite-rs/#features
connect_with_config<Req: IntoClientRequest>( request: Req, config: Option<WebSocketConfig>, max_redirects: u8, ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)>45 pub fn connect_with_config<Req: IntoClientRequest>(
46 request: Req,
47 config: Option<WebSocketConfig>,
48 max_redirects: u8,
49 ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
50 fn try_client_handshake(
51 request: Request,
52 config: Option<WebSocketConfig>,
53 ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
54 let uri = request.uri();
55 let mode = uri_mode(uri)?;
56 let host = request.uri().host().ok_or(Error::Url(UrlError::NoHostName))?;
57 let host = if host.starts_with('[') { &host[1..host.len() - 1] } else { host };
58 let port = uri.port_u16().unwrap_or(match mode {
59 Mode::Plain => 80,
60 Mode::Tls => 443,
61 });
62 let addrs = (host, port).to_socket_addrs()?;
63 let mut stream = connect_to_some(addrs.as_slice(), request.uri())?;
64 NoDelay::set_nodelay(&mut stream, true)?;
65
66 #[cfg(not(any(feature = "native-tls", feature = "__rustls-tls")))]
67 let client = client_with_config(request, MaybeTlsStream::Plain(stream), config);
68 #[cfg(any(feature = "native-tls", feature = "__rustls-tls"))]
69 let client = crate::tls::client_tls_with_config(request, stream, config, None);
70
71 client.map_err(|e| match e {
72 HandshakeError::Failure(f) => f,
73 HandshakeError::Interrupted(_) => panic!("Bug: blocking handshake not blocked"),
74 })
75 }
76
77 fn create_request(parts: &Parts, uri: &Uri) -> Request {
78 let mut builder =
79 Request::builder().uri(uri.clone()).method(parts.method.clone()).version(parts.version);
80 *builder.headers_mut().expect("Failed to create `Request`") = parts.headers.clone();
81 builder.body(()).expect("Failed to create `Request`")
82 }
83
84 let (parts, _) = request.into_client_request()?.into_parts();
85 let mut uri = parts.uri.clone();
86
87 for attempt in 0..(max_redirects + 1) {
88 let request = create_request(&parts, &uri);
89
90 match try_client_handshake(request, config) {
91 Err(Error::Http(res)) if res.status().is_redirection() && attempt < max_redirects => {
92 if let Some(location) = res.headers().get("Location") {
93 uri = location.to_str()?.parse::<Uri>()?;
94 debug!("Redirecting to {:?}", uri);
95 continue;
96 } else {
97 warn!("No `Location` found in redirect");
98 return Err(Error::Http(res));
99 }
100 }
101 other => return other,
102 }
103 }
104
105 unreachable!("Bug in a redirect handling logic")
106 }
107
108 /// Connect to the given WebSocket in blocking mode.
109 ///
110 /// The URL may be either ws:// or wss://.
111 /// To support wss:// URLs, feature `native-tls` or `rustls-tls` must be turned on.
112 ///
113 /// This function "just works" for those who wants a simple blocking solution
114 /// similar to `std::net::TcpStream`. If you want a non-blocking or other
115 /// custom stream, call `client` instead.
116 ///
117 /// This function uses `native_tls` or `rustls` to do TLS depending on the feature flags enabled. If
118 /// you want to use other TLS libraries, use `client` instead. There is no need to enable any of
119 /// the `*-tls` features if you don't call `connect` since it's the only function that uses them.
connect<Req: IntoClientRequest>( request: Req, ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)>120 pub fn connect<Req: IntoClientRequest>(
121 request: Req,
122 ) -> Result<(WebSocket<MaybeTlsStream<TcpStream>>, Response)> {
123 connect_with_config(request, None, 3)
124 }
125
connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream>126 fn connect_to_some(addrs: &[SocketAddr], uri: &Uri) -> Result<TcpStream> {
127 for addr in addrs {
128 debug!("Trying to contact {} at {}...", uri, addr);
129 if let Ok(stream) = TcpStream::connect(addr) {
130 return Ok(stream);
131 }
132 }
133 Err(Error::Url(UrlError::UnableToConnect(uri.to_string())))
134 }
135
136 /// Get the mode of the given URL.
137 ///
138 /// This function may be used to ease the creation of custom TLS streams
139 /// in non-blocking algorithms or for use with TLS libraries other than `native_tls` or `rustls`.
uri_mode(uri: &Uri) -> Result<Mode>140 pub fn uri_mode(uri: &Uri) -> Result<Mode> {
141 match uri.scheme_str() {
142 Some("ws") => Ok(Mode::Plain),
143 Some("wss") => Ok(Mode::Tls),
144 _ => Err(Error::Url(UrlError::UnsupportedUrlScheme)),
145 }
146 }
147
148 /// Do the client handshake over the given stream given a web socket configuration. Passing `None`
149 /// as configuration is equal to calling `client()` function.
150 ///
151 /// Use this function if you need a nonblocking handshake support or if you
152 /// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
153 /// Any stream supporting `Read + Write` will do.
client_with_config<Stream, Req>( request: Req, stream: Stream, config: Option<WebSocketConfig>, ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> where Stream: Read + Write, Req: IntoClientRequest,154 pub fn client_with_config<Stream, Req>(
155 request: Req,
156 stream: Stream,
157 config: Option<WebSocketConfig>,
158 ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
159 where
160 Stream: Read + Write,
161 Req: IntoClientRequest,
162 {
163 ClientHandshake::start(stream, request.into_client_request()?, config)?.handshake()
164 }
165
166 /// Do the client handshake over the given stream.
167 ///
168 /// Use this function if you need a nonblocking handshake support or if you
169 /// want to use a custom stream like `mio::net::TcpStream` or `openssl::ssl::SslStream`.
170 /// Any stream supporting `Read + Write` will do.
client<Stream, Req>( request: Req, stream: Stream, ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>> where Stream: Read + Write, Req: IntoClientRequest,171 pub fn client<Stream, Req>(
172 request: Req,
173 stream: Stream,
174 ) -> StdResult<(WebSocket<Stream>, Response), HandshakeError<ClientHandshake<Stream>>>
175 where
176 Stream: Read + Write,
177 Req: IntoClientRequest,
178 {
179 client_with_config(request, stream, None)
180 }
181
182 /// Trait for converting various types into HTTP requests used for a client connection.
183 ///
184 /// This trait is implemented by default for string slices, strings, `url::Url`, `http::Uri` and
185 /// `http::Request<()>`. Note that the implementation for `http::Request<()>` is trivial and will
186 /// simply take your request and pass it as is further without altering any headers or URLs, so
187 /// be aware of this. If you just want to connect to the endpoint with a certain URL, better pass
188 /// a regular string containing the URL in which case `tungstenite-rs` will take care for generating
189 /// the proper `http::Request<()>` for you.
190 pub trait IntoClientRequest {
191 /// Convert into a `Request` that can be used for a client connection.
into_client_request(self) -> Result<Request>192 fn into_client_request(self) -> Result<Request>;
193 }
194
195 impl<'a> IntoClientRequest for &'a str {
into_client_request(self) -> Result<Request>196 fn into_client_request(self) -> Result<Request> {
197 self.parse::<Uri>()?.into_client_request()
198 }
199 }
200
201 impl<'a> IntoClientRequest for &'a String {
into_client_request(self) -> Result<Request>202 fn into_client_request(self) -> Result<Request> {
203 <&str as IntoClientRequest>::into_client_request(self)
204 }
205 }
206
207 impl IntoClientRequest for String {
into_client_request(self) -> Result<Request>208 fn into_client_request(self) -> Result<Request> {
209 <&str as IntoClientRequest>::into_client_request(&self)
210 }
211 }
212
213 impl<'a> IntoClientRequest for &'a Uri {
into_client_request(self) -> Result<Request>214 fn into_client_request(self) -> Result<Request> {
215 self.clone().into_client_request()
216 }
217 }
218
219 impl IntoClientRequest for Uri {
into_client_request(self) -> Result<Request>220 fn into_client_request(self) -> Result<Request> {
221 let authority = self.authority().ok_or(Error::Url(UrlError::NoHostName))?.as_str();
222 let host = authority
223 .find('@')
224 .map(|idx| authority.split_at(idx + 1).1)
225 .unwrap_or_else(|| authority);
226
227 if host.is_empty() {
228 return Err(Error::Url(UrlError::EmptyHostName));
229 }
230
231 let req = Request::builder()
232 .method("GET")
233 .header("Host", host)
234 .header("Connection", "Upgrade")
235 .header("Upgrade", "websocket")
236 .header("Sec-WebSocket-Version", "13")
237 .header("Sec-WebSocket-Key", generate_key())
238 .uri(self)
239 .body(())?;
240 Ok(req)
241 }
242 }
243
244 impl<'a> IntoClientRequest for &'a Url {
into_client_request(self) -> Result<Request>245 fn into_client_request(self) -> Result<Request> {
246 self.as_str().into_client_request()
247 }
248 }
249
250 impl IntoClientRequest for Url {
into_client_request(self) -> Result<Request>251 fn into_client_request(self) -> Result<Request> {
252 self.as_str().into_client_request()
253 }
254 }
255
256 impl IntoClientRequest for Request {
into_client_request(self) -> Result<Request>257 fn into_client_request(self) -> Result<Request> {
258 Ok(self)
259 }
260 }
261
262 impl<'h, 'b> IntoClientRequest for httparse::Request<'h, 'b> {
into_client_request(self) -> Result<Request>263 fn into_client_request(self) -> Result<Request> {
264 use crate::handshake::headers::FromHttparse;
265 Request::from_httparse(self)
266 }
267 }
268