1 use std::fmt; 2 use std::marker::PhantomData; 3 use std::mem::ManuallyDrop; 4 use std::ops::Deref; 5 #[cfg(unix)] 6 use std::os::unix::io::{AsFd, AsRawFd, FromRawFd}; 7 #[cfg(windows)] 8 use std::os::windows::io::{AsRawSocket, AsSocket, FromRawSocket}; 9 10 use crate::Socket; 11 12 /// A reference to a [`Socket`] that can be used to configure socket types other 13 /// than the `Socket` type itself. 14 /// 15 /// This allows for example a [`TcpStream`], found in the standard library, to 16 /// be configured using all the additional methods found in the [`Socket`] API. 17 /// 18 /// `SockRef` can be created from any socket type that implements [`AsFd`] 19 /// (Unix) or [`AsSocket`] (Windows) using the [`From`] implementation. 20 /// 21 /// [`TcpStream`]: std::net::TcpStream 22 // Don't use intra-doc links because they won't build on every platform. 23 /// [`AsFd`]: https://doc.rust-lang.org/stable/std/os/unix/io/trait.AsFd.html 24 /// [`AsSocket`]: https://doc.rust-lang.org/stable/std/os/windows/io/trait.AsSocket.html 25 /// 26 /// # Examples 27 /// 28 /// Below is an example of converting a [`TcpStream`] into a [`SockRef`]. 29 /// 30 /// ``` 31 /// use std::net::{TcpStream, SocketAddr}; 32 /// 33 /// use socket2::SockRef; 34 /// 35 /// # fn main() -> Result<(), Box<dyn std::error::Error>> { 36 /// // Create `TcpStream` from the standard library. 37 /// let address: SocketAddr = "127.0.0.1:1234".parse()?; 38 /// # let b1 = std::sync::Arc::new(std::sync::Barrier::new(2)); 39 /// # let b2 = b1.clone(); 40 /// # let handle = std::thread::spawn(move || { 41 /// # let listener = std::net::TcpListener::bind(address).unwrap(); 42 /// # b2.wait(); 43 /// # let (stream, _) = listener.accept().unwrap(); 44 /// # std::thread::sleep(std::time::Duration::from_millis(10)); 45 /// # drop(stream); 46 /// # }); 47 /// # b1.wait(); 48 /// let stream = TcpStream::connect(address)?; 49 /// 50 /// // Create a `SockRef`erence to the stream. 51 /// let socket_ref = SockRef::from(&stream); 52 /// // Use `Socket::set_nodelay` on the stream. 53 /// socket_ref.set_nodelay(true)?; 54 /// drop(socket_ref); 55 /// 56 /// assert_eq!(stream.nodelay()?, true); 57 /// # handle.join().unwrap(); 58 /// # Ok(()) 59 /// # } 60 /// ``` 61 pub struct SockRef<'s> { 62 /// Because this is a reference we don't own the `Socket`, however `Socket` 63 /// closes itself when dropped, so we use `ManuallyDrop` to prevent it from 64 /// closing itself. 65 socket: ManuallyDrop<Socket>, 66 /// Because we don't own the socket we need to ensure the socket remains 67 /// open while we have a "reference" to it, the lifetime `'s` ensures this. 68 _lifetime: PhantomData<&'s Socket>, 69 } 70 71 impl<'s> Deref for SockRef<'s> { 72 type Target = Socket; 73 deref(&self) -> &Self::Target74 fn deref(&self) -> &Self::Target { 75 &self.socket 76 } 77 } 78 79 /// On Windows, a corresponding `From<&impl AsSocket>` implementation exists. 80 #[cfg(unix)] 81 #[cfg_attr(docsrs, doc(cfg(unix)))] 82 impl<'s, S> From<&'s S> for SockRef<'s> 83 where 84 S: AsFd, 85 { 86 /// The caller must ensure `S` is actually a socket. from(socket: &'s S) -> Self87 fn from(socket: &'s S) -> Self { 88 let fd = socket.as_fd().as_raw_fd(); 89 assert!(fd >= 0); 90 SockRef { 91 socket: ManuallyDrop::new(unsafe { Socket::from_raw_fd(fd) }), 92 _lifetime: PhantomData, 93 } 94 } 95 } 96 97 /// On Unix, a corresponding `From<&impl AsFd>` implementation exists. 98 #[cfg(windows)] 99 #[cfg_attr(docsrs, doc(cfg(windows)))] 100 impl<'s, S> From<&'s S> for SockRef<'s> 101 where 102 S: AsSocket, 103 { 104 /// See the `From<&impl AsFd>` implementation. from(socket: &'s S) -> Self105 fn from(socket: &'s S) -> Self { 106 let socket = socket.as_socket().as_raw_socket(); 107 assert!(socket != windows_sys::Win32::Networking::WinSock::INVALID_SOCKET as _); 108 SockRef { 109 socket: ManuallyDrop::new(unsafe { Socket::from_raw_socket(socket) }), 110 _lifetime: PhantomData, 111 } 112 } 113 } 114 115 impl fmt::Debug for SockRef<'_> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result116 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 117 f.debug_struct("SockRef") 118 .field("raw", &self.socket.as_raw()) 119 .field("local_addr", &self.socket.local_addr().ok()) 120 .field("peer_addr", &self.socket.peer_addr().ok()) 121 .finish() 122 } 123 } 124