1 /*
2 * Copyright 2019 fsyncd, Berlin, Germany.
 * Additional material Copyright the Rust project and its contributors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18 //! Virtio socket support for Rust.
19
20 use libc::{
21 accept4, ioctl, sa_family_t, sockaddr, sockaddr_vm, socklen_t, suseconds_t, timeval, AF_VSOCK,
22 FIONBIO, SOCK_CLOEXEC,
23 };
24 use nix::{
25 ioctl_read_bad,
26 sys::socket::{
27 self, bind, connect, getpeername, getsockname, listen, recv, send, shutdown, socket,
28 sockopt::{ReceiveTimeout, SendTimeout, SocketError},
29 AddressFamily, Backlog, GetSockOpt, MsgFlags, SetSockOpt, SockFlag, SockType,
30 },
31 };
32 use std::mem::size_of;
33 use std::net::Shutdown;
34 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
35 use std::time::Duration;
36 use std::{fs::File, os::fd::OwnedFd};
37 use std::{
38 io::{Error, ErrorKind, Read, Result, Write},
39 os::fd::{AsFd, BorrowedFd},
40 };
41
42 pub use libc::{VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, VMADDR_CID_LOCAL};
43 pub use nix::sys::socket::{SockaddrLike, VsockAddr};
44
new_socket() -> Result<OwnedFd>45 fn new_socket() -> Result<OwnedFd> {
46 Ok(socket(
47 AddressFamily::Vsock,
48 SockType::Stream,
49 SockFlag::SOCK_CLOEXEC,
50 None,
51 )?)
52 }
53
54 /// An iterator that infinitely accepts connections on a VsockListener.
55 #[derive(Debug)]
56 pub struct Incoming<'a> {
57 listener: &'a VsockListener,
58 }
59
60 impl<'a> Iterator for Incoming<'a> {
61 type Item = Result<VsockStream>;
62
next(&mut self) -> Option<Result<VsockStream>>63 fn next(&mut self) -> Option<Result<VsockStream>> {
64 Some(self.listener.accept().map(|p| p.0))
65 }
66 }
67
/// Server side of a virtio socket: a bound socket that accepts incoming connections.
#[derive(Debug)]
pub struct VsockListener {
    /// Owned descriptor of the listening socket; it is closed when the listener is dropped.
    socket: OwnedFd,
}
73
74 impl VsockListener {
75 /// Create a new VsockListener which is bound and listening on the socket address.
bind(addr: &impl SockaddrLike) -> Result<Self>76 pub fn bind(addr: &impl SockaddrLike) -> Result<Self> {
77 if addr.family() != Some(AddressFamily::Vsock) {
78 return Err(Error::new(
79 ErrorKind::Other,
80 "requires a virtio socket address",
81 ));
82 }
83
84 let socket = new_socket()?;
85
86 bind(socket.as_raw_fd(), addr)?;
87
88 // rust stdlib uses a 128 connection backlog
89 listen(&socket, Backlog::new(128).unwrap_or(Backlog::MAXCONN))?;
90
91 Ok(Self { socket })
92 }
93
94 /// Create a new VsockListener with specified cid and port.
bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener>95 pub fn bind_with_cid_port(cid: u32, port: u32) -> Result<VsockListener> {
96 Self::bind(&VsockAddr::new(cid, port))
97 }
98
99 /// The local socket address of the listener.
local_addr(&self) -> Result<VsockAddr>100 pub fn local_addr(&self) -> Result<VsockAddr> {
101 Ok(getsockname(self.socket.as_raw_fd())?)
102 }
103
104 /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>105 pub fn try_clone(&self) -> Result<Self> {
106 Ok(Self {
107 socket: self.socket.try_clone()?,
108 })
109 }
110
111 /// Accept a new incoming connection from this listener.
accept(&self) -> Result<(VsockStream, VsockAddr)>112 pub fn accept(&self) -> Result<(VsockStream, VsockAddr)> {
113 let mut vsock_addr = sockaddr_vm {
114 svm_family: AF_VSOCK as sa_family_t,
115 svm_reserved1: 0,
116 svm_port: 0,
117 svm_cid: 0,
118 svm_zero: [0u8; 4],
119 };
120 let mut vsock_addr_len = size_of::<sockaddr_vm>() as socklen_t;
121 let socket = unsafe {
122 accept4(
123 self.socket.as_raw_fd(),
124 &mut vsock_addr as *mut _ as *mut sockaddr,
125 &mut vsock_addr_len,
126 SOCK_CLOEXEC,
127 )
128 };
129 if socket < 0 {
130 Err(Error::last_os_error())
131 } else {
132 Ok((
133 unsafe { VsockStream::from_raw_fd(socket as RawFd) },
134 VsockAddr::new(vsock_addr.svm_cid, vsock_addr.svm_port),
135 ))
136 }
137 }
138
139 /// An iterator over the connections being received on this listener.
incoming(&self) -> Incoming140 pub fn incoming(&self) -> Incoming {
141 Incoming { listener: self }
142 }
143
144 /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>145 pub fn take_error(&self) -> Result<Option<Error>> {
146 let error = SocketError.get(&self.socket)?;
147 Ok(if error == 0 {
148 None
149 } else {
150 Some(Error::from_raw_os_error(error))
151 })
152 }
153
154 /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>155 pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
156 let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
157 if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 {
158 Err(Error::last_os_error())
159 } else {
160 Ok(())
161 }
162 }
163 }
164
165 impl AsRawFd for VsockListener {
as_raw_fd(&self) -> RawFd166 fn as_raw_fd(&self) -> RawFd {
167 self.socket.as_raw_fd()
168 }
169 }
170
171 impl AsFd for VsockListener {
as_fd(&self) -> BorrowedFd172 fn as_fd(&self) -> BorrowedFd {
173 self.socket.as_fd()
174 }
175 }
176
177 impl FromRawFd for VsockListener {
from_raw_fd(socket: RawFd) -> Self178 unsafe fn from_raw_fd(socket: RawFd) -> Self {
179 Self {
180 socket: OwnedFd::from_raw_fd(socket),
181 }
182 }
183 }
184
185 impl IntoRawFd for VsockListener {
into_raw_fd(self) -> RawFd186 fn into_raw_fd(self) -> RawFd {
187 self.socket.into_raw_fd()
188 }
189 }
190
/// A connected virtio stream socket linking a local endpoint to a remote peer.
#[derive(Debug)]
pub struct VsockStream {
    /// Owned descriptor of the connected socket; it is closed when the stream is dropped.
    socket: OwnedFd,
}
196
197 impl VsockStream {
198 /// Open a connection to a remote host.
connect(addr: &impl SockaddrLike) -> Result<Self>199 pub fn connect(addr: &impl SockaddrLike) -> Result<Self> {
200 if addr.family() != Some(AddressFamily::Vsock) {
201 return Err(Error::new(
202 ErrorKind::Other,
203 "requires a virtio socket address",
204 ));
205 }
206
207 let socket = new_socket()?;
208 connect(socket.as_raw_fd(), addr)?;
209 Ok(Self { socket })
210 }
211
212 /// Open a connection to a remote host with specified cid and port.
connect_with_cid_port(cid: u32, port: u32) -> Result<Self>213 pub fn connect_with_cid_port(cid: u32, port: u32) -> Result<Self> {
214 Self::connect(&VsockAddr::new(cid, port))
215 }
216
217 /// Virtio socket address of the remote peer associated with this connection.
peer_addr(&self) -> Result<VsockAddr>218 pub fn peer_addr(&self) -> Result<VsockAddr> {
219 Ok(getpeername(self.socket.as_raw_fd())?)
220 }
221
222 /// Virtio socket address of the local address associated with this connection.
local_addr(&self) -> Result<VsockAddr>223 pub fn local_addr(&self) -> Result<VsockAddr> {
224 Ok(getsockname(self.socket.as_raw_fd())?)
225 }
226
227 /// Shutdown the read, write, or both halves of this connection.
shutdown(&self, how: Shutdown) -> Result<()>228 pub fn shutdown(&self, how: Shutdown) -> Result<()> {
229 let how = match how {
230 Shutdown::Write => socket::Shutdown::Write,
231 Shutdown::Read => socket::Shutdown::Read,
232 Shutdown::Both => socket::Shutdown::Both,
233 };
234 Ok(shutdown(self.socket.as_raw_fd(), how)?)
235 }
236
237 /// Create a new independently owned handle to the underlying socket.
try_clone(&self) -> Result<Self>238 pub fn try_clone(&self) -> Result<Self> {
239 Ok(Self {
240 socket: self.socket.try_clone()?,
241 })
242 }
243
244 /// Set the timeout on read operations.
set_read_timeout(&self, dur: Option<Duration>) -> Result<()>245 pub fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
246 let timeout = Self::timeval_from_duration(dur)?.into();
247 Ok(ReceiveTimeout.set(&self.socket, &timeout)?)
248 }
249
250 /// Set the timeout on write operations.
set_write_timeout(&self, dur: Option<Duration>) -> Result<()>251 pub fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
252 let timeout = Self::timeval_from_duration(dur)?.into();
253 Ok(SendTimeout.set(&self.socket, &timeout)?)
254 }
255
256 /// Retrieve the latest error associated with the underlying socket.
take_error(&self) -> Result<Option<Error>>257 pub fn take_error(&self) -> Result<Option<Error>> {
258 let error = SocketError.get(&self.socket)?;
259 Ok(if error == 0 {
260 None
261 } else {
262 Some(Error::from_raw_os_error(error))
263 })
264 }
265
266 /// Move this stream in and out of nonblocking mode.
set_nonblocking(&self, nonblocking: bool) -> Result<()>267 pub fn set_nonblocking(&self, nonblocking: bool) -> Result<()> {
268 let mut nonblocking: i32 = if nonblocking { 1 } else { 0 };
269 if unsafe { ioctl(self.socket.as_raw_fd(), FIONBIO, &mut nonblocking) } < 0 {
270 Err(Error::last_os_error())
271 } else {
272 Ok(())
273 }
274 }
275
timeval_from_duration(dur: Option<Duration>) -> Result<timeval>276 fn timeval_from_duration(dur: Option<Duration>) -> Result<timeval> {
277 match dur {
278 Some(dur) => {
279 if dur.as_secs() == 0 && dur.subsec_nanos() == 0 {
280 return Err(Error::new(
281 ErrorKind::InvalidInput,
282 "cannot set a zero duration timeout",
283 ));
284 }
285
286 // https://github.com/rust-lang/libc/issues/1848
287 #[cfg_attr(target_env = "musl", allow(deprecated))]
288 let secs = if dur.as_secs() > libc::time_t::max_value() as u64 {
289 libc::time_t::max_value()
290 } else {
291 dur.as_secs() as libc::time_t
292 };
293 let mut timeout = timeval {
294 tv_sec: secs,
295 tv_usec: i64::from(dur.subsec_micros()) as suseconds_t,
296 };
297 if timeout.tv_sec == 0 && timeout.tv_usec == 0 {
298 timeout.tv_usec = 1;
299 }
300 Ok(timeout)
301 }
302 None => Ok(timeval {
303 tv_sec: 0,
304 tv_usec: 0,
305 }),
306 }
307 }
308 }
309
310 impl Read for VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>311 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
312 <&Self>::read(&mut &*self, buf)
313 }
314 }
315
316 impl Write for VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>317 fn write(&mut self, buf: &[u8]) -> Result<usize> {
318 <&Self>::write(&mut &*self, buf)
319 }
320
flush(&mut self) -> Result<()>321 fn flush(&mut self) -> Result<()> {
322 Ok(())
323 }
324 }
325
326 impl Read for &VsockStream {
read(&mut self, buf: &mut [u8]) -> Result<usize>327 fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
328 Ok(recv(self.socket.as_raw_fd(), buf, MsgFlags::empty())?)
329 }
330 }
331
332 impl Write for &VsockStream {
write(&mut self, buf: &[u8]) -> Result<usize>333 fn write(&mut self, buf: &[u8]) -> Result<usize> {
334 Ok(send(self.socket.as_raw_fd(), buf, MsgFlags::MSG_NOSIGNAL)?)
335 }
336
flush(&mut self) -> Result<()>337 fn flush(&mut self) -> Result<()> {
338 Ok(())
339 }
340 }
341
342 impl AsRawFd for VsockStream {
as_raw_fd(&self) -> RawFd343 fn as_raw_fd(&self) -> RawFd {
344 self.socket.as_raw_fd()
345 }
346 }
347
348 impl AsFd for VsockStream {
as_fd(&self) -> BorrowedFd349 fn as_fd(&self) -> BorrowedFd {
350 self.socket.as_fd()
351 }
352 }
353
354 impl FromRawFd for VsockStream {
from_raw_fd(socket: RawFd) -> Self355 unsafe fn from_raw_fd(socket: RawFd) -> Self {
356 Self {
357 socket: OwnedFd::from_raw_fd(socket),
358 }
359 }
360 }
361
362 impl IntoRawFd for VsockStream {
into_raw_fd(self) -> RawFd363 fn into_raw_fd(self) -> RawFd {
364 self.socket.into_raw_fd()
365 }
366 }
367
368 const IOCTL_VM_SOCKETS_GET_LOCAL_CID: usize = 0x7b9;
369 ioctl_read_bad!(
370 vm_sockets_get_local_cid,
371 IOCTL_VM_SOCKETS_GET_LOCAL_CID,
372 u32
373 );
374
375 /// Gets the CID of the local machine.
376 ///
377 /// Note that when calling [`VsockListener::bind`], you should generally use [`VMADDR_CID_ANY`]
378 /// instead, and for making a loopback connection you should use [`VMADDR_CID_LOCAL`].
get_local_cid() -> Result<u32>379 pub fn get_local_cid() -> Result<u32> {
380 let f = File::open("/dev/vsock")?;
381 let mut cid = 0;
382 // SAFETY: the kernel only modifies the given u32 integer.
383 unsafe { vm_sockets_get_local_cid(f.as_raw_fd(), &mut cid) }?;
384 Ok(cid)
385 }
386