1 use std::fmt;
2 use std::marker::PhantomData;
3 use std::mem::ManuallyDrop;
4 use std::ops::Deref;
5 #[cfg(unix)]
6 use std::os::unix::io::{AsFd, AsRawFd, FromRawFd};
7 #[cfg(windows)]
8 use std::os::windows::io::{AsRawSocket, AsSocket, FromRawSocket};
9 
10 use crate::Socket;
11 
/// A reference to a [`Socket`] that can be used to configure socket types other
/// than the `Socket` type itself.
///
/// This allows for example a [`TcpStream`], found in the standard library, to
/// be configured using all the additional methods found in the [`Socket`] API.
///
/// `SockRef` can be created from any socket type that implements [`AsFd`]
/// (Unix) or [`AsSocket`] (Windows) using the [`From`] implementation.
///
/// A `SockRef` dereferences to [`Socket`], so `Socket` methods taking `&self`
/// can be called on it directly. Dropping a `SockRef` does not close the
/// underlying socket.
///
/// [`TcpStream`]: std::net::TcpStream
// Don't use intra-doc links because they won't build on every platform.
/// [`AsFd`]: https://doc.rust-lang.org/stable/std/os/unix/io/trait.AsFd.html
/// [`AsSocket`]: https://doc.rust-lang.org/stable/std/os/windows/io/trait.AsSocket.html
///
/// # Examples
///
/// Below is an example of converting a [`TcpStream`] into a [`SockRef`].
///
/// ```
/// use std::net::{TcpStream, SocketAddr};
///
/// use socket2::SockRef;
///
/// # fn main() -> Result<(), Box<dyn std::error::Error>> {
/// // Create `TcpStream` from the standard library.
/// let address: SocketAddr = "127.0.0.1:1234".parse()?;
/// # let b1 = std::sync::Arc::new(std::sync::Barrier::new(2));
/// # let b2 = b1.clone();
/// # let handle = std::thread::spawn(move || {
/// #    let listener = std::net::TcpListener::bind(address).unwrap();
/// #    b2.wait();
/// #    let (stream, _) = listener.accept().unwrap();
/// #    std::thread::sleep(std::time::Duration::from_millis(10));
/// #    drop(stream);
/// # });
/// # b1.wait();
/// let stream = TcpStream::connect(address)?;
///
/// // Create a `SockRef`erence to the stream.
/// let socket_ref = SockRef::from(&stream);
/// // Use `Socket::set_nodelay` on the stream.
/// socket_ref.set_nodelay(true)?;
/// drop(socket_ref);
///
/// assert_eq!(stream.nodelay()?, true);
/// # handle.join().unwrap();
/// # Ok(())
/// # }
/// ```
pub struct SockRef<'s> {
    /// A `Socket` built from the borrowed descriptor/handle. We do not own the
    /// underlying socket, but `Socket` closes itself when dropped, so it is
    /// wrapped in `ManuallyDrop` to make sure that close never runs.
    socket: ManuallyDrop<Socket>,
    /// Zero-sized marker tying this value to the lifetime `'s` of the borrow
    /// it was created from, so the original socket has to stay open for as
    /// long as this "reference" to it exists.
    _lifetime: PhantomData<&'s Socket>,
}
70 
71 impl<'s> Deref for SockRef<'s> {
72     type Target = Socket;
73 
deref(&self) -> &Self::Target74     fn deref(&self) -> &Self::Target {
75         &self.socket
76     }
77 }
78 
79 /// On Windows, a corresponding `From<&impl AsSocket>` implementation exists.
80 #[cfg(unix)]
81 #[cfg_attr(docsrs, doc(cfg(unix)))]
82 impl<'s, S> From<&'s S> for SockRef<'s>
83 where
84     S: AsFd,
85 {
86     /// The caller must ensure `S` is actually a socket.
from(socket: &'s S) -> Self87     fn from(socket: &'s S) -> Self {
88         let fd = socket.as_fd().as_raw_fd();
89         assert!(fd >= 0);
90         SockRef {
91             socket: ManuallyDrop::new(unsafe { Socket::from_raw_fd(fd) }),
92             _lifetime: PhantomData,
93         }
94     }
95 }
96 
/// Windows conversion from any `&impl AsSocket`. On Unix the equivalent
/// `From<&impl AsFd>` implementation is provided instead.
#[cfg(windows)]
#[cfg_attr(docsrs, doc(cfg(windows)))]
impl<'s, S: AsSocket> From<&'s S> for SockRef<'s> {
    /// Borrows `socket` as a `SockRef` for the lifetime `'s`; see the
    /// `From<&impl AsFd>` implementation.
    ///
    /// # Panics
    ///
    /// Panics if the raw socket obtained from `socket` equals
    /// `INVALID_SOCKET`.
    fn from(socket: &'s S) -> Self {
        let socket = socket.as_socket().as_raw_socket();
        assert!(socket != windows_sys::Win32::Networking::WinSock::INVALID_SOCKET as _);
        // SAFETY: the raw socket was just read from the caller's borrowed
        // handle and checked against `INVALID_SOCKET`. The resulting `Socket`
        // goes straight into `ManuallyDrop`, so this value never closes it.
        let inner = ManuallyDrop::new(unsafe { Socket::from_raw_socket(socket) });
        Self {
            socket: inner,
            _lifetime: PhantomData,
        }
    }
}
114 
115 impl fmt::Debug for SockRef<'_> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result116     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117         f.debug_struct("SockRef")
118             .field("raw", &self.socket.as_raw())
119             .field("local_addr", &self.socket.local_addr().ok())
120             .field("peer_addr", &self.socket.peer_addr().ok())
121             .finish()
122     }
123 }
124