1 /*
2  * Copyright 2019 fsyncd, Berlin, Germany.
3  * Additional material Copyright the Rust project and it's contributors.
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 //! Virtio socket support for Rust.
19 
20 use libc::{
21     accept4, ioctl, sa_family_t, sockaddr, sockaddr_vm, socklen_t, suseconds_t, timeval, AF_VSOCK,
22     FIONBIO, SOCK_CLOEXEC,
23 };
24 use nix::{
25     ioctl_read_bad,
26     sys::socket::{
27         self, bind, connect, getpeername, getsockname, listen, recv, send, shutdown, socket,
28         sockopt::{ReceiveTimeout, SendTimeout, SocketError},
29         AddressFamily, Backlog, GetSockOpt, MsgFlags, SetSockOpt, SockFlag, SockType,
30     },
31 };
32 use std::mem::size_of;
33 use std::net::Shutdown;
34 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
35 use std::time::Duration;
36 use std::{fs::File, os::fd::OwnedFd};
37 use std::{
38     io::{Error, ErrorKind, Read, Result, Write},
39     os::fd::{AsFd, BorrowedFd},
40 };
41 
42 pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL};
43 pub use nix::sys::socket::{SockaddrLike, VsockAddr};
44 
new_socket() -> Result<OwnedFd>45 fn new_socket() -> Result<OwnedFd> {
46     Ok(socket(
47         AddressFamily::Vsock,
48         SockType::Stream,
49         SockFlag::SOCK_CLOEXEC,
50         None,
51     )?)
52 }
53 
54 /// An iterator that infinitely accepts connections on a VsockListener.
55 #[derive(Debug)]
56 pub struct Incoming<'a> {
57     listener: &'a VsockListener,
58 }
59 
60 impl<'a> Iterator for Incoming<'a> {
61     type Item = Result<VsockStream>;
62 
next(&mut self) -> Option<Result<VsockStream>>63     fn next(&mut self) -> Option<Result<VsockStream>> {
64         Some(self.listener.accept().map(|p| p.0))
65     }
66 }
67 
/// A virtio socket server, listening for connections.
///
/// The listener owns its socket descriptor, which is closed when the
/// listener is dropped.
#[derive(Debug)]
pub struct VsockListener {
    /// Owned descriptor of the underlying socket.
    socket: OwnedFd,
}
73 
74 impl VsockListener {
75     /// Create a new VsockListener which is bound and listening on the socket address.
bind(addr: &impl SockaddrLike) -> Result<Self>76     pub fn bind(addr: &impl SockaddrLike) -> Result<Self> {
77         if addr.family() != Some(AddressFamily::Vsock) {
78             return Err(Error::new(
79                 ErrorKind::Other,
80                 "requires a virtio socket address",
81             ));
82         }
83 
84         let socket = new_socket()?;
85 
86         bind(socket.as_raw_fd(), addr)?;
87 
88         // rust stdlib uses a 128 connection backlog
89         listen(&socket, Backlog::new(128).unwrap_or(Backlog::MAXCONN))?;
90 
91         Ok(Self { socket })
92     }
93 
94     /// Create a new VsockListener with specified cid and port.
bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener>95     pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener> {
96         Self::bind(&VsockAddr::new(cid, port))
97     }
98 
99     /// The local socket address of the listener.
local_addr(&self) -> Result<VsockAddr>100     pub fn local_addr(&self) -> Result<VsockAddr> {
101         Ok(getsockname(self.socket.as_raw_fd())?)
102     }
103 
104     /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>105     pub fn try_clone(&self) -> Result<Self> {
106         Ok(Self {
107             socket: self.socket.try_clone()?,
108         })
109     }
110 
111     /// Accept a new incoming connection from this listener.
accept(&self) -> Result<(VsockStream, VsockAddr)>112     pub fn accept(&self) -> Result<(VsockStream, VsockAddr)> {
113         let mut vsock_addr = sockaddr_vm {
114             svm_family: AF_VSOCK as sa_family_t,
115             svm_reserved1: 0,
116             svm_port: 0,
117             svm_cid: 0,
118             svm_zero: [0u8; 4],
119         };
120         let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
121         let socket = unsafe {
122             accept4(
123                 self.socket.as_raw_fd(),
124                 &mut vsock_addr as *mut _ as *mut sockaddr,
125                 &mut vsock_addr_len,
126                 SOCK_CLOEXEC,
127             )
128         };
129         if socket < 0 {
130             Err(Error::last_os_error())
131         } else {
132             Ok((
133                 unsafe { VsockStream::from_raw_fd(socket as RawFd) },
134                 VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port),
135             ))
136         }
137     }
138 
139     /// An iterator over the connections being received on this listener.
incoming(&self) -> Incoming140     pub fn incoming(&self) -> Incoming {
141         Incoming { listener: self }
142     }
143 
144     /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>145     pub fn take_error(&self) -> Result<Option<Error>> {
146         let error = SocketError.get(&self.socket)?;
147         Ok(if error == 0 {
148             None
149         } else {
150             Some(Error::from_raw_os_error(error))
151         })
152     }
153 
154     /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>155     pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
156         let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
157         if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 {
158             Err(Error::last_os_error())
159         } else {
160             Ok(())
161         }
162     }
163 }
164 
165 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd166     fn as_raw_fd(&self) -> RawFd {
167         self.socket.as_raw_fd()
168     }
169 }
170 
171 impl AsFd for VsockListener {
as_fd(&self) -> BorrowedFd172     fn as_fd(&self) -> BorrowedFd {
173         self.socket.as_fd()
174     }
175 }
176 
177 impl FromRawFd for VsockListener {
from_raw_fd(socket: RawFd) -> Self178     unsafe fn from_raw_fd(socket: RawFd) -> Self {
179         Self {
180             socket: OwnedFd::from_raw_fd(socket),
181         }
182     }
183 }
184 
185 impl IntoRawFd for VsockListener {
into_raw_fd(self) -> RawFd186     fn into_raw_fd(self) -> RawFd {
187         self.socket.into_raw_fd()
188     }
189 }
190 
/// A virtio stream between a local and a remote socket.
///
/// The stream owns its socket descriptor, which is closed when the stream is
/// dropped.
#[derive(Debug)]
pub struct VsockStream {
    /// Owned descriptor of the underlying socket.
    socket: OwnedFd,
}
196 
197 impl VsockStream {
198     /// Open a connection to a remote host.
connect(addr: &impl SockaddrLike) -> Result<Self>199     pub fn connect(addr: &impl SockaddrLike) -> Result<Self> {
200         if addr.family() != Some(AddressFamily::Vsock) {
201             return Err(Error::new(
202                 ErrorKind::Other,
203                 "requires a virtio socket address",
204             ));
205         }
206 
207         let socket = new_socket()?;
208         connect(socket.as_raw_fd(), addr)?;
209         Ok(Self { socket })
210     }
211 
212     /// Open a connection to a remote host with specified cid and port.
connect_with_cid_port(cid: u32, port: u32) -> Result<Self>213     pub fn connect_with_cid_port(cid: u32, port: u32) -> Result<Self> {
214         Self::connect(&VsockAddr::new(cid, port))
215     }
216 
217     /// Virtio socket address of the remote peer associated with this connection.
peer_addr(&self) -> Result<VsockAddr>218     pub fn peer_addr(&self) -> Result<VsockAddr> {
219         Ok(getpeername(self.socket.as_raw_fd())?)
220     }
221 
222     /// Virtio socket address of the local address associated with this connection.
local_addr(&self) -> Result<VsockAddr>223     pub fn local_addr(&self) -> Result<VsockAddr> {
224         Ok(getsockname(self.socket.as_raw_fd())?)
225     }
226 
227     /// Shutdown the read, write, or both halves of this connection.
shutdown(&self, how: Shutdown) -> Result<()>228     pub fn shutdown(&self, how: Shutdown) -> Result<()> {
229         let how = match how {
230             Shutdown::Write => socket::Shutdown::Write,
231             Shutdown::Read => socket::Shutdown::Read,
232             Shutdown::Both => socket::Shutdown::Both,
233         };
234         Ok(shutdown(self.socket.as_raw_fd(), how)?)
235     }
236 
237     /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>238     pub fn try_clone(&self) -> Result<Self> {
239         Ok(Self {
240             socket: self.socket.try_clone()?,
241         })
242     }
243 
244     /// Set the timeout on read operations.
set_read_timeout(&self, dur: Option<Duration>) -> Result<()>245     pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
246         let timeout = Self::timeval_from_duration(dur)?.into();
247         Ok(ReceiveTimeout.set(&self.socket, &timeout)?)
248     }
249 
250     /// Set the timeout on write operations.
set_write_timeout(&self, dur: Option<Duration>) -> Result<()>251     pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
252         let timeout = Self::timeval_from_duration(dur)?.into();
253         Ok(SendTimeout.set(&self.socket, &timeout)?)
254     }
255 
256     /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>257     pub fn take_error(&self) -> Result<Option<Error>> {
258         let error = SocketError.get(&self.socket)?;
259         Ok(if error == 0 {
260             None
261         } else {
262             Some(Error::from_raw_os_error(error))
263         })
264     }
265 
266     /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>267     pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
268         let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
269         if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 {
270             Err(Error::last_os_error())
271         } else {
272             Ok(())
273         }
274     }
275 
timeval_from_duration(dur: Option<Duration>) -> Result<timeval>276     fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
277         match dur {
278             Some(dur) => {
279                 if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
280                     return Err(Error::new(
281                         ErrorKind::InvalidInput,
282                         "cannot set a zero duration timeout",
283                     ));
284                 }
285 
286                 // https://github.com/rust-lang/libc/issues/1848
287                 #[cfg_attr(target_env = "musl", allow(deprecated))]
288                 let secs = if dur.as_secs() > libc::time_t::max_value() as u64 {
289                     libc::time_t::max_value()
290                 } else {
291                     dur.as_secs() as libc::time_t
292                 };
293                 let mut timeout = timeval {
294                     tv_sec: secs,
295                     tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
296                 };
297                 if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
298                     timeout.tv_usec = 1;
299                 }
300                 Ok(timeout)
301             }
302             None => Ok(timeval {
303                 tv_sec: 0,
304                 tv_usec: 0,
305             }),
306         }
307     }
308 }
309 
310 impl Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>311     fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
312         <&Self>::read(&mut &*self, buf)
313     }
314 }
315 
316 impl Write for VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>317     fn write(&mut self, buf: &[u8]) -> Result<usize> {
318         <&Self>::write(&mut &*self, buf)
319     }
320 
flush(&mut self) -> Result<()>321     fn flush(&mut self) -> Result<()> {
322         Ok(())
323     }
324 }
325 
326 impl Read for &VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>327     fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
328         Ok(recv(self.socket.as_raw_fd(), buf, MsgFlags::empty())?)
329     }
330 }
331 
332 impl Write for &VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>333     fn write(&mut self, buf: &[u8]) -> Result<usize> {
334         Ok(send(self.socket.as_raw_fd(), buf, MsgFlags::MSG_NOSIGNAL)?)
335     }
336 
flush(&mut self) -> Result<()>337     fn flush(&mut self) -> Result<()> {
338         Ok(())
339     }
340 }
341 
342 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd343     fn as_raw_fd(&self) -> RawFd {
344         self.socket.as_raw_fd()
345     }
346 }
347 
348 impl AsFd for VsockStream {
as_fd(&self) -> BorrowedFd349     fn as_fd(&self) -> BorrowedFd {
350         self.socket.as_fd()
351     }
352 }
353 
354 impl FromRawFd for VsockStream {
from_raw_fd(socket: RawFd) -> Self355     unsafe fn from_raw_fd(socket: RawFd) -> Self {
356         Self {
357             socket: OwnedFd::from_raw_fd(socket),
358         }
359     }
360 }
361 
362 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd363     fn into_raw_fd(self) -> RawFd {
364         self.socket.into_raw_fd()
365     }
366 }
367 
/// Request number of Linux's `IOCTL_VM_SOCKETS_GET_LOCAL_CID`, i.e. `_IO(7, 0xb9)`:
/// no data direction or size bits, type 7, number 0xb9 (together 0x7b9).
const IOCTL_VM_SOCKETS_GET_LOCAL_CID: usize = (7 << 8) | 0xb9;
369 ioctl_read_bad!(
370     vm_sockets_get_local_cid,
371     IOCTL_VM_SOCKETS_GET_LOCAL_CID,
372     u32
373 );
374 
375 /// Gets the CID of the local machine.
376 ///
377 /// Note that when calling [`VsockListener::bind`], you should generally use [`VMADDR_CID_ANY`]
378 /// instead, and for making a loopback connection you should use [`VMADDR_CID_LOCAL`].
get_local_cid() -> Result<u32>379 pub fn get_local_cid() -> Result<u32> {
380     let f = File::open("/dev/vsock")?;
381     let mut cid = 0;
382     // SAFETY: the kernel only modifies the given u32 integer.
383     unsafe { vm_sockets_get_local_cid(f.as_raw_fd(), &mut cid) }?;
384     Ok(cid)
385 }
386