diff --git a/sys_util/src/lib.rs b/sys_util/src/lib.rs index ab4e66efd6..d1bcf164b0 100644 --- a/sys_util/src/lib.rs +++ b/sys_util/src/lib.rs @@ -34,8 +34,10 @@ pub mod net; mod passwd; mod poll; mod priority; +pub mod rand; mod raw_fd; pub mod sched; +pub mod scoped_path; pub mod scoped_signal_handler; mod seek_hole; mod shm; @@ -45,6 +47,7 @@ mod sock_ctrl_msg; mod struct_util; mod terminal; mod timerfd; +pub mod vsock; mod write_zeroes; pub use crate::alloc::LayoutAllocation; diff --git a/sys_util/src/net.rs b/sys_util/src/net.rs index a4f2d2ec68..b59fb0f212 100644 --- a/sys_util/src/net.rs +++ b/sys_util/src/net.rs @@ -5,23 +5,282 @@ use std::ffi::OsString; use std::fs::remove_file; use std::io; -use std::mem; +use std::mem::{self, size_of}; +use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs}; use std::ops::Deref; use std::os::unix::{ ffi::{OsStrExt, OsStringExt}, - io::{AsRawFd, FromRawFd, RawFd}, + io::{AsRawFd, FromRawFd, IntoRawFd, RawFd}, }; use std::path::Path; use std::path::PathBuf; use std::ptr::null_mut; use std::time::Duration; -use libc::{recvfrom, MSG_PEEK, MSG_TRUNC}; +use libc::{ + c_int, in6_addr, in_addr, recvfrom, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6, + socklen_t, AF_INET, AF_INET6, MSG_PEEK, MSG_TRUNC, SOCK_CLOEXEC, SOCK_STREAM, +}; use serde::{Deserialize, Serialize}; use crate::sock_ctrl_msg::{ScmSocket, SCM_SOCKET_MAX_FD_COUNT}; use crate::{AsRawDescriptor, RawDescriptor}; +/// Assist in handling both IP version 4 and IP version 6. +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum InetVersion { + V4, + V6, +} + +impl InetVersion { + pub fn from_sockaddr(s: &SocketAddr) -> Self { + match s { + SocketAddr::V4(_) => InetVersion::V4, + SocketAddr::V6(_) => InetVersion::V6, + } + } +} + +impl From for sa_family_t { + fn from(v: InetVersion) -> sa_family_t { + match v { + InetVersion::V4 => AF_INET as sa_family_t, + InetVersion::V6 => AF_INET6 as sa_family_t, + } + } +} + +fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in { + sockaddr_in { + sin_family: AF_INET as sa_family_t, + sin_port: s.port().to_be(), + sin_addr: in_addr { + s_addr: u32::from_ne_bytes(s.ip().octets()), + }, + sin_zero: [0; 8], + } +} + +fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 { + sockaddr_in6 { + sin6_family: AF_INET6 as sa_family_t, + sin6_port: s.port().to_be(), + sin6_flowinfo: 0, + sin6_addr: in6_addr { + s6_addr: s.ip().octets(), + }, + sin6_scope_id: 0, + } +} + +/// A TCP socket. +/// +/// Do not use this class unless you need to change socket options or query the +/// state of the socket prior to calling listen or connect. Instead use either TcpStream or +/// TcpListener. +#[derive(Debug)] +pub struct TcpSocket { + inet_version: InetVersion, + fd: RawFd, +} + +impl TcpSocket { + pub fn new(inet_version: InetVersion) -> io::Result { + let fd = unsafe { + libc::socket( + Into::::into(inet_version) as c_int, + SOCK_STREAM | SOCK_CLOEXEC, + 0, + ) + }; + if fd < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(TcpSocket { inet_version, fd }) + } + } + + pub fn bind(&mut self, addr: A) -> io::Result<()> { + let sockaddr = addr + .to_socket_addrs() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))? + .next() + .unwrap(); + + let ret = match sockaddr { + SocketAddr::V4(a) => { + let sin = sockaddrv4_to_lib_c(&a); + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::bind( + self.fd, + &sin as *const sockaddr_in as *const sockaddr, + size_of::() as socklen_t, + ) + } + } + SocketAddr::V6(a) => { + let sin6 = sockaddrv6_to_lib_c(&a); + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::bind( + self.fd, + &sin6 as *const sockaddr_in6 as *const sockaddr, + size_of::() as socklen_t, + ) + } + } + }; + if ret < 0 { + let bind_err = io::Error::last_os_error(); + Err(bind_err) + } else { + Ok(()) + } + } + + pub fn connect(self, addr: A) -> io::Result { + let sockaddr = addr + .to_socket_addrs() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))? + .next() + .unwrap(); + + let ret = match sockaddr { + SocketAddr::V4(a) => { + let sin = sockaddrv4_to_lib_c(&a); + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::connect( + self.fd, + &sin as *const sockaddr_in as *const sockaddr, + size_of::() as socklen_t, + ) + } + } + SocketAddr::V6(a) => { + let sin6 = sockaddrv6_to_lib_c(&a); + // Safe because this doesn't modify any memory and we check the return value. + unsafe { + libc::connect( + self.fd, + &sin6 as *const sockaddr_in6 as *const sockaddr, + size_of::() as socklen_t, + ) + } + } + }; + + if ret < 0 { + let connect_err = io::Error::last_os_error(); + Err(connect_err) + } else { + // Safe because the ownership of the raw fd is released from self and taken over by the + // new TcpStream. + Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) }) + } + } + + pub fn listen(self) -> io::Result { + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { libc::listen(self.fd, 1) }; + if ret < 0 { + let listen_err = io::Error::last_os_error(); + Err(listen_err) + } else { + // Safe because the ownership of the raw fd is released from self and taken over by the + // new TcpListener. + Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) }) + } + } + + /// Returns the port that this socket is bound to. This can only succeed after bind is called. + pub fn local_port(&self) -> io::Result { + match self.inet_version { + InetVersion::V4 => { + let mut sin = sockaddr_in { + sin_family: 0, + sin_port: 0, + sin_addr: in_addr { s_addr: 0 }, + sin_zero: [0; 8], + }; + + // Safe because we give a valid pointer for addrlen and check the length. + let mut addrlen = size_of::() as socklen_t; + let ret = unsafe { + // Get the socket address that was actually bound. + libc::getsockname( + self.fd, + &mut sin as *mut sockaddr_in as *mut sockaddr, + &mut addrlen as *mut socklen_t, + ) + }; + if ret < 0 { + let getsockname_err = io::Error::last_os_error(); + Err(getsockname_err) + } else { + // If this doesn't match, it's not safe to get the port out of the sockaddr. + assert_eq!(addrlen as usize, size_of::()); + + Ok(u16::from_be(sin.sin_port)) + } + } + InetVersion::V6 => { + let mut sin6 = sockaddr_in6 { + sin6_family: 0, + sin6_port: 0, + sin6_flowinfo: 0, + sin6_addr: in6_addr { s6_addr: [0; 16] }, + sin6_scope_id: 0, + }; + + // Safe because we give a valid pointer for addrlen and check the length. + let mut addrlen = size_of::() as socklen_t; + let ret = unsafe { + // Get the socket address that was actually bound. + libc::getsockname( + self.fd, + &mut sin6 as *mut sockaddr_in6 as *mut sockaddr, + &mut addrlen as *mut socklen_t, + ) + }; + if ret < 0 { + let getsockname_err = io::Error::last_os_error(); + Err(getsockname_err) + } else { + // If this doesn't match, it's not safe to get the port out of the sockaddr. + assert_eq!(addrlen as usize, size_of::()); + + Ok(u16::from_be(sin6.sin6_port)) + } + } + } + } +} + +impl IntoRawFd for TcpSocket { + fn into_raw_fd(self) -> RawFd { + let fd = self.fd; + mem::forget(self); + fd + } +} + +impl AsRawFd for TcpSocket { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +impl Drop for TcpSocket { + fn drop(&mut self) { + // Safe because this doesn't modify any memory and we are the only + // owner of the file descriptor. + unsafe { libc::close(self.fd) }; + } +} + // Offset of sun_path in structure sockaddr_un. fn sun_path_offset() -> usize { // Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer. diff --git a/sys_util/src/rand.rs b/sys_util/src/rand.rs new file mode 100644 index 0000000000..7c77993971 --- /dev/null +++ b/sys_util/src/rand.rs @@ -0,0 +1,114 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +//! Rust implementation of functionality parallel to libchrome's base/rand_util.h. + +use std::thread::sleep; +use std::time::Duration; + +use libc::{c_uint, c_void}; + +use crate::{ + errno::{errno_result, Result}, + handle_eintr_errno, +}; + +/// How long to wait before calling getrandom again if it does not return +/// enough bytes. +const POLL_INTERVAL: Duration = Duration::from_millis(50); + +/// Represents whether or not the random bytes are pulled from the source of +/// /dev/random or /dev/urandom. +#[derive(Debug, Clone, Eq, PartialEq)] +pub enum Source { + // This is the default and uses the same source as /dev/urandom. + Pseudorandom, + // This uses the same source as /dev/random and may be. + Random, +} + +impl Default for Source { + fn default() -> Self { + Source::Pseudorandom + } +} + +impl Source { + fn to_getrandom_flags(&self) -> c_uint { + match self { + Source::Random => libc::GRND_RANDOM, + Source::Pseudorandom => 0, + } + } +} + +/// Fills `output` completely with random bytes from the specified `source`. +pub fn rand_bytes(mut output: &mut [u8], source: Source) -> Result<()> { + if output.is_empty() { + return Ok(()); + } + + loop { + // Safe because output is mutable and the writes are limited by output.len(). + let bytes = handle_eintr_errno!(unsafe { + libc::getrandom( + output.as_mut_ptr() as *mut c_void, + output.len(), + source.to_getrandom_flags(), + ) + }); + + if bytes < 0 { + return errno_result(); + } + if bytes as usize == output.len() { + return Ok(()); + } + + // Wait for more entropy and try again for the remaining bytes. + sleep(POLL_INTERVAL); + output = &mut output[bytes as usize..]; + } +} + +/// Allocates a vector of length `len` filled with random bytes from the +/// specified `source`. +pub fn rand_vec(len: usize, source: Source) -> Result> { + let mut rand = Vec::with_capacity(len); + if len == 0 { + return Ok(rand); + } + + // Safe because rand will either be initialized by getrandom or dropped. + unsafe { rand.set_len(len) }; + rand_bytes(rand.as_mut_slice(), source)?; + Ok(rand) +} + +#[cfg(test)] +mod tests { + use super::*; + + const TEST_SIZE: usize = 64; + + #[test] + fn randbytes_success() { + let mut rand = vec![0u8; TEST_SIZE]; + rand_bytes(&mut rand, Source::Pseudorandom).unwrap(); + assert_ne!(&rand, &[0u8; TEST_SIZE]); + } + + #[test] + fn randvec_success() { + let rand = rand_vec(TEST_SIZE, Source::Pseudorandom).unwrap(); + assert_eq!(rand.len(), TEST_SIZE); + assert_ne!(&rand, &[0u8; TEST_SIZE]); + } + + #[test] + fn sourcerandom_success() { + let rand = rand_vec(TEST_SIZE, Source::Random).unwrap(); + assert_ne!(&rand, &[0u8; TEST_SIZE]); + } +} diff --git a/sys_util/src/scoped_path.rs b/sys_util/src/scoped_path.rs new file mode 100644 index 0000000000..745137eaec --- /dev/null +++ b/sys_util/src/scoped_path.rs @@ -0,0 +1,138 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +use std::env::{current_exe, temp_dir}; +use std::fs::{create_dir_all, remove_dir_all}; +use std::io::Result; +use std::ops::Deref; +use std::path::{Path, PathBuf}; +use std::thread::panicking; + +use crate::{getpid, gettid}; + +/// Returns a stable path based on the label, pid, and tid. If the label isn't provided the +/// current_exe is used instead. +pub fn get_temp_path(label: Option<&str>) -> PathBuf { + if let Some(label) = label { + temp_dir().join(format!("{}-{}-{}", label, getpid(), gettid())) + } else { + get_temp_path(Some( + current_exe() + .unwrap() + .file_name() + .unwrap() + .to_str() + .unwrap(), + )) + } +} + +/// Automatically deletes the path it contains when it goes out of scope unless it is a test and +/// drop is called after a panic!. +/// +/// This is particularly useful for creating temporary directories for use with tests. +pub struct ScopedPath>(P); + +impl> ScopedPath

{ + pub fn create(p: P) -> Result { + create_dir_all(p.as_ref())?; + Ok(ScopedPath(p)) + } +} + +impl> AsRef for ScopedPath

{ + fn as_ref(&self) -> &Path { + self.0.as_ref() + } +} + +impl> Deref for ScopedPath

{ + type Target = Path; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } +} + +impl> Drop for ScopedPath

{ + fn drop(&mut self) { + // Leave the files on a failed test run for debugging. + if panicking() && cfg!(test) { + eprintln!("NOTE: Not removing {}", self.display()); + return; + } + if let Err(e) = remove_dir_all(&**self) { + eprintln!("Failed to remove {}: {}", self.display(), e); + } + } +} + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + + use std::panic::catch_unwind; + + #[test] + fn gettemppath() { + assert_ne!("", get_temp_path(None).to_string_lossy()); + assert!(get_temp_path(None).starts_with(temp_dir())); + assert_eq!( + get_temp_path(None), + get_temp_path(Some( + current_exe() + .unwrap() + .file_name() + .unwrap() + .to_str() + .unwrap() + )) + ); + assert_ne!( + get_temp_path(Some("label")), + get_temp_path(Some( + current_exe() + .unwrap() + .file_name() + .unwrap() + .to_str() + .unwrap() + )) + ); + } + + #[test] + fn scopedpath_exists() { + let tmp_path = get_temp_path(None); + { + let scoped_path = ScopedPath::create(&tmp_path).unwrap(); + assert!(scoped_path.exists()); + } + assert!(!tmp_path.exists()); + } + + #[test] + fn scopedpath_notexists() { + let tmp_path = get_temp_path(None); + { + let _scoped_path = ScopedPath(&tmp_path); + } + assert!(!tmp_path.exists()); + } + + #[test] + fn scopedpath_panic() { + let tmp_path = get_temp_path(None); + assert!(catch_unwind(|| { + { + let scoped_path = ScopedPath::create(&tmp_path).unwrap(); + assert!(scoped_path.exists()); + panic!() + } + }) + .is_err()); + assert!(tmp_path.exists()); + remove_dir_all(&tmp_path).unwrap(); + } +} diff --git a/sys_util/src/vsock.rs b/sys_util/src/vsock.rs new file mode 100644 index 0000000000..a51dfdbfcb --- /dev/null +++ b/sys_util/src/vsock.rs @@ -0,0 +1,495 @@ +// Copyright 2021 The Chromium OS Authors. All rights reserved. +// Use of this source code is governed by a BSD-style license that can be +// found in the LICENSE file. + +/// Support for virtual sockets. +use std::fmt; +use std::io; +use std::mem::{self, size_of}; +use std::num::ParseIntError; +use std::os::raw::{c_uchar, c_uint, c_ushort}; +use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd}; +use std::result; +use std::str::FromStr; + +use libc::{ + self, c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK, + VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR, +}; + +// The domain for vsock sockets. +const AF_VSOCK: sa_family_t = 40; + +// Vsock loopback address. +const VMADDR_CID_LOCAL: c_uint = 1; + +/// Vsock equivalent of binding on port 0. Binds to a random port. +pub const VMADDR_PORT_ANY: c_uint = c_uint::max_value(); + +// The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly +// from linux/vm_sockets.h. +const PADDING: usize = size_of::() + - size_of::() + - size_of::() + - (2 * size_of::()); + +#[repr(C)] +#[derive(Default)] +struct sockaddr_vm { + svm_family: sa_family_t, + svm_reserved1: c_ushort, + svm_port: c_uint, + svm_cid: c_uint, + svm_zero: [c_uchar; PADDING], +} + +#[derive(Debug)] +pub struct AddrParseError; + +impl fmt::Display for AddrParseError { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "failed to parse vsock address") + } +} + +/// The vsock equivalent of an IP address. +#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] +pub enum VsockCid { + /// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint. + Any, + /// An address that refers to the bare-metal machine that serves as the hypervisor. + Hypervisor, + /// The loopback address. + Local, + /// The parent machine. It may not be the hypervisor for nested VMs. + Host, + /// An assigned CID that serves as the address for VSOCK. + Cid(c_uint), +} + +impl fmt::Display for VsockCid { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + match &self { + VsockCid::Any => write!(fmt, "Any"), + VsockCid::Hypervisor => write!(fmt, "Hypervisor"), + VsockCid::Local => write!(fmt, "Local"), + VsockCid::Host => write!(fmt, "Host"), + VsockCid::Cid(c) => write!(fmt, "'{}'", c), + } + } +} + +impl From for VsockCid { + fn from(c: c_uint) -> Self { + match c { + VMADDR_CID_ANY => VsockCid::Any, + VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor, + VMADDR_CID_LOCAL => VsockCid::Local, + VMADDR_CID_HOST => VsockCid::Host, + _ => VsockCid::Cid(c), + } + } +} + +impl FromStr for VsockCid { + type Err = ParseIntError; + + fn from_str(s: &str) -> Result { + let c: c_uint = s.parse()?; + Ok(c.into()) + } +} + +impl From for c_uint { + fn from(cid: VsockCid) -> c_uint { + match cid { + VsockCid::Any => VMADDR_CID_ANY, + VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR, + VsockCid::Local => VMADDR_CID_LOCAL, + VsockCid::Host => VMADDR_CID_HOST, + VsockCid::Cid(c) => c, + } + } +} + +/// An address associated with a virtual socket. +#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)] +pub struct SocketAddr { + pub cid: VsockCid, + pub port: c_uint, +} + +pub trait ToSocketAddr { + fn to_socket_addr(&self) -> result::Result; +} + +impl ToSocketAddr for SocketAddr { + fn to_socket_addr(&self) -> result::Result { + Ok(*self) + } +} + +impl ToSocketAddr for str { + fn to_socket_addr(&self) -> result::Result { + self.parse() + } +} + +impl ToSocketAddr for (VsockCid, c_uint) { + fn to_socket_addr(&self) -> result::Result { + let (cid, port) = *self; + Ok(SocketAddr { cid, port }) + } +} + +impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T { + fn to_socket_addr(&self) -> result::Result { + (**self).to_socket_addr() + } +} + +impl FromStr for SocketAddr { + type Err = AddrParseError; + + /// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form + /// "vsock:cid:port". + fn from_str(s: &str) -> Result { + let components: Vec<&str> = s.split(':').collect(); + if components.len() != 3 || components[0] != "vsock" { + return Err(AddrParseError); + } + + Ok(SocketAddr { + cid: components[1].parse().map_err(|_| AddrParseError)?, + port: components[2].parse().map_err(|_| AddrParseError)?, + }) + } +} + +impl fmt::Display for SocketAddr { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + write!(fmt, "{}:{}", self.cid, self.port) + } +} + +/// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the +/// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets. +unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> { + let flags = libc::fcntl(fd, F_GETFL, 0); + if flags < 0 { + return Err(io::Error::last_os_error()); + } + + let flags = if nonblocking { + flags | O_NONBLOCK + } else { + flags & !O_NONBLOCK + }; + + let ret = libc::fcntl(fd, F_SETFL, flags); + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(()) +} + +/// A virtual socket. +/// +/// Do not use this class unless you need to change socket options or query the +/// state of the socket prior to calling listen or connect. Instead use either VsockStream or +/// VsockListener. +#[derive(Debug)] +pub struct VsockSocket { + fd: RawFd, +} + +impl VsockSocket { + pub fn new() -> io::Result { + let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) }; + if fd < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(VsockSocket { fd }) + } + } + + pub fn bind(&mut self, addr: A) -> io::Result<()> { + let sockaddr = addr + .to_socket_addr() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?; + + // The compiler should optimize this out since these are both compile-time constants. + assert_eq!(size_of::(), size_of::()); + + let svm = sockaddr_vm { + svm_family: AF_VSOCK, + svm_cid: sockaddr.cid.into(), + svm_port: sockaddr.port, + ..Default::default() + }; + + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { + libc::bind( + self.fd, + &svm as *const sockaddr_vm as *const sockaddr, + size_of::() as socklen_t, + ) + }; + if ret < 0 { + let bind_err = io::Error::last_os_error(); + Err(bind_err) + } else { + Ok(()) + } + } + + pub fn connect(self, addr: A) -> io::Result { + let sockaddr = addr + .to_socket_addr() + .map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?; + + let svm = sockaddr_vm { + svm_family: AF_VSOCK, + svm_cid: sockaddr.cid.into(), + svm_port: sockaddr.port, + ..Default::default() + }; + + // Safe because this just connects a vsock socket, and the return value is checked. + let ret = unsafe { + libc::connect( + self.fd, + &svm as *const sockaddr_vm as *const sockaddr, + size_of::() as socklen_t, + ) + }; + if ret < 0 { + let connect_err = io::Error::last_os_error(); + Err(connect_err) + } else { + Ok(VsockStream { sock: self }) + } + } + + pub fn listen(self) -> io::Result { + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { libc::listen(self.fd, 1) }; + if ret < 0 { + let listen_err = io::Error::last_os_error(); + return Err(listen_err); + } + Ok(VsockListener { sock: self }) + } + + /// Returns the port that this socket is bound to. This can only succeed after bind is called. + pub fn local_port(&self) -> io::Result { + let mut svm: sockaddr_vm = Default::default(); + + // Safe because we give a valid pointer for addrlen and check the length. + let mut addrlen = size_of::() as socklen_t; + let ret = unsafe { + // Get the socket address that was actually bound. + libc::getsockname( + self.fd, + &mut svm as *mut sockaddr_vm as *mut sockaddr, + &mut addrlen as *mut socklen_t, + ) + }; + if ret < 0 { + let getsockname_err = io::Error::last_os_error(); + Err(getsockname_err) + } else { + // If this doesn't match, it's not safe to get the port out of the sockaddr. + assert_eq!(addrlen as usize, size_of::()); + + Ok(svm.svm_port) + } + } + + pub fn try_clone(&self) -> io::Result { + // Safe because this doesn't modify any memory and we check the return value. + let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) }; + if dup_fd < 0 { + Err(io::Error::last_os_error()) + } else { + Ok(Self { fd: dup_fd }) + } + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { + // Safe because the fd is valid and owned by this stream. + unsafe { set_nonblocking(self.fd, nonblocking) } + } +} + +impl IntoRawFd for VsockSocket { + fn into_raw_fd(self) -> RawFd { + let fd = self.fd; + mem::forget(self); + fd + } +} + +impl AsRawFd for VsockSocket { + fn as_raw_fd(&self) -> RawFd { + self.fd + } +} + +impl Drop for VsockSocket { + fn drop(&mut self) { + // Safe because this doesn't modify any memory and we are the only + // owner of the file descriptor. + unsafe { libc::close(self.fd) }; + } +} + +/// A virtual stream socket. +#[derive(Debug)] +pub struct VsockStream { + sock: VsockSocket, +} + +impl VsockStream { + pub fn connect(addr: A) -> io::Result { + let sock = VsockSocket::new()?; + sock.connect(addr) + } + + /// Returns the port that this stream is bound to. + pub fn local_port(&self) -> io::Result { + self.sock.local_port() + } + + pub fn try_clone(&self) -> io::Result { + self.sock.try_clone().map(|f| VsockStream { sock: f }) + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { + self.sock.set_nonblocking(nonblocking) + } +} + +impl io::Read for VsockStream { + fn read(&mut self, buf: &mut [u8]) -> io::Result { + // Safe because this will only modify the contents of |buf| and we check the return value. + let ret = unsafe { + libc::read( + self.sock.as_raw_fd(), + buf as *mut [u8] as *mut c_void, + buf.len() as size_t, + ) + }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(ret as usize) + } +} + +impl io::Write for VsockStream { + fn write(&mut self, buf: &[u8]) -> io::Result { + // Safe because this doesn't modify any memory and we check the return value. + let ret = unsafe { + libc::write( + self.sock.as_raw_fd(), + buf as *const [u8] as *const c_void, + buf.len() as size_t, + ) + }; + if ret < 0 { + return Err(io::Error::last_os_error()); + } + + Ok(ret as usize) + } + + fn flush(&mut self) -> io::Result<()> { + // No buffered data so nothing to do. + Ok(()) + } +} + +impl AsRawFd for VsockStream { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +} + +impl IntoRawFd for VsockStream { + fn into_raw_fd(self) -> RawFd { + self.sock.into_raw_fd() + } +} + +/// Represents a virtual socket server. +#[derive(Debug)] +pub struct VsockListener { + sock: VsockSocket, +} + +impl VsockListener { + /// Creates a new `VsockListener` bound to the specified port on the current virtual socket + /// endpoint. + pub fn bind(addr: A) -> io::Result { + let mut sock = VsockSocket::new()?; + sock.bind(addr)?; + sock.listen() + } + + /// Returns the port that this listener is bound to. + pub fn local_port(&self) -> io::Result { + self.sock.local_port() + } + + /// Accepts a new incoming connection on this listener. Blocks the calling thread until a + /// new connection is established. When established, returns the corresponding `VsockStream` + /// and the remote peer's address. + pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> { + let mut svm: sockaddr_vm = Default::default(); + + // Safe because this will only modify |svm| and we check the return value. + let mut socklen: socklen_t = size_of::() as socklen_t; + let fd = unsafe { + libc::accept4( + self.sock.as_raw_fd(), + &mut svm as *mut sockaddr_vm as *mut sockaddr, + &mut socklen as *mut socklen_t, + libc::SOCK_CLOEXEC, + ) + }; + if fd < 0 { + return Err(io::Error::last_os_error()); + } + + if svm.svm_family != AF_VSOCK { + return Err(io::Error::new( + io::ErrorKind::InvalidData, + format!("unexpected address family: {}", svm.svm_family), + )); + } + + Ok(( + VsockStream { + sock: VsockSocket { fd }, + }, + SocketAddr { + cid: svm.svm_cid.into(), + port: svm.svm_port, + }, + )) + } + + pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> { + self.sock.set_nonblocking(nonblocking) + } +} + +impl AsRawFd for VsockListener { + fn as_raw_fd(&self) -> RawFd { + self.sock.as_raw_fd() + } +}