sys_util: Migrate additional code from libchromeos-rs.

This migrates:
* libchromeos::base::rand_util
* libchromeos::net
* libchromeos::scoped_path
* libchromeos::vsock

BUG=chromium:1193155
TEST=cargo test -- --test-threads=1

Change-Id: I801c7cbf8001dcd386792bb120781dacb9a699c7
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/2832311
Tested-by: Allen Webb <allenwebb@google.com>
Tested-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Stephen Barber <smbarber@chromium.org>
Commit-Queue: Allen Webb <allenwebb@google.com>
This commit is contained in:
Allen Webb 2021-04-16 13:12:10 -05:00 committed by Commit Bot
parent e35d827ec4
commit c9de52bea5
5 changed files with 1012 additions and 3 deletions

View file

@ -34,8 +34,10 @@ pub mod net;
mod passwd;
mod poll;
mod priority;
pub mod rand;
mod raw_fd;
pub mod sched;
pub mod scoped_path;
pub mod scoped_signal_handler;
mod seek_hole;
mod shm;
@ -45,6 +47,7 @@ mod sock_ctrl_msg;
mod struct_util;
mod terminal;
mod timerfd;
pub mod vsock;
mod write_zeroes;
pub use crate::alloc::LayoutAllocation;

View file

@ -5,23 +5,282 @@
use std::ffi::OsString;
use std::fs::remove_file;
use std::io;
use std::mem;
use std::mem::{self, size_of};
use std::net::{SocketAddr, SocketAddrV4, SocketAddrV6, TcpListener, TcpStream, ToSocketAddrs};
use std::ops::Deref;
use std::os::unix::{
ffi::{OsStrExt, OsStringExt},
io::{AsRawFd, FromRawFd, RawFd},
io::{AsRawFd, FromRawFd, IntoRawFd, RawFd},
};
use std::path::Path;
use std::path::PathBuf;
use std::ptr::null_mut;
use std::time::Duration;
use libc::{recvfrom, MSG_PEEK, MSG_TRUNC};
use libc::{
c_int, in6_addr, in_addr, recvfrom, sa_family_t, sockaddr, sockaddr_in, sockaddr_in6,
socklen_t, AF_INET, AF_INET6, MSG_PEEK, MSG_TRUNC, SOCK_CLOEXEC, SOCK_STREAM,
};
use serde::{Deserialize, Serialize};
use crate::sock_ctrl_msg::{ScmSocket, SCM_SOCKET_MAX_FD_COUNT};
use crate::{AsRawDescriptor, RawDescriptor};
/// Assist in handling both IP version 4 and IP version 6.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum InetVersion {
V4,
V6,
}
impl InetVersion {
pub fn from_sockaddr(s: &SocketAddr) -> Self {
match s {
SocketAddr::V4(_) => InetVersion::V4,
SocketAddr::V6(_) => InetVersion::V6,
}
}
}
impl From<InetVersion> for sa_family_t {
fn from(v: InetVersion) -> sa_family_t {
match v {
InetVersion::V4 => AF_INET as sa_family_t,
InetVersion::V6 => AF_INET6 as sa_family_t,
}
}
}
fn sockaddrv4_to_lib_c(s: &SocketAddrV4) -> sockaddr_in {
sockaddr_in {
sin_family: AF_INET as sa_family_t,
sin_port: s.port().to_be(),
sin_addr: in_addr {
s_addr: u32::from_ne_bytes(s.ip().octets()),
},
sin_zero: [0; 8],
}
}
fn sockaddrv6_to_lib_c(s: &SocketAddrV6) -> sockaddr_in6 {
sockaddr_in6 {
sin6_family: AF_INET6 as sa_family_t,
sin6_port: s.port().to_be(),
sin6_flowinfo: 0,
sin6_addr: in6_addr {
s6_addr: s.ip().octets(),
},
sin6_scope_id: 0,
}
}
/// A TCP socket.
///
/// Do not use this class unless you need to change socket options or query the
/// state of the socket prior to calling listen or connect. Instead use either TcpStream or
/// TcpListener.
#[derive(Debug)]
pub struct TcpSocket {
inet_version: InetVersion,
fd: RawFd,
}
impl TcpSocket {
pub fn new(inet_version: InetVersion) -> io::Result<Self> {
let fd = unsafe {
libc::socket(
Into::<sa_family_t>::into(inet_version) as c_int,
SOCK_STREAM | SOCK_CLOEXEC,
0,
)
};
if fd < 0 {
Err(io::Error::last_os_error())
} else {
Ok(TcpSocket { inet_version, fd })
}
}
pub fn bind<A: ToSocketAddrs>(&mut self, addr: A) -> io::Result<()> {
let sockaddr = addr
.to_socket_addrs()
.map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
.next()
.unwrap();
let ret = match sockaddr {
SocketAddr::V4(a) => {
let sin = sockaddrv4_to_lib_c(&a);
// Safe because this doesn't modify any memory and we check the return value.
unsafe {
libc::bind(
self.fd,
&sin as *const sockaddr_in as *const sockaddr,
size_of::<sockaddr_in>() as socklen_t,
)
}
}
SocketAddr::V6(a) => {
let sin6 = sockaddrv6_to_lib_c(&a);
// Safe because this doesn't modify any memory and we check the return value.
unsafe {
libc::bind(
self.fd,
&sin6 as *const sockaddr_in6 as *const sockaddr,
size_of::<sockaddr_in6>() as socklen_t,
)
}
}
};
if ret < 0 {
let bind_err = io::Error::last_os_error();
Err(bind_err)
} else {
Ok(())
}
}
pub fn connect<A: ToSocketAddrs>(self, addr: A) -> io::Result<TcpStream> {
let sockaddr = addr
.to_socket_addrs()
.map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?
.next()
.unwrap();
let ret = match sockaddr {
SocketAddr::V4(a) => {
let sin = sockaddrv4_to_lib_c(&a);
// Safe because this doesn't modify any memory and we check the return value.
unsafe {
libc::connect(
self.fd,
&sin as *const sockaddr_in as *const sockaddr,
size_of::<sockaddr_in>() as socklen_t,
)
}
}
SocketAddr::V6(a) => {
let sin6 = sockaddrv6_to_lib_c(&a);
// Safe because this doesn't modify any memory and we check the return value.
unsafe {
libc::connect(
self.fd,
&sin6 as *const sockaddr_in6 as *const sockaddr,
size_of::<sockaddr_in>() as socklen_t,
)
}
}
};
if ret < 0 {
let connect_err = io::Error::last_os_error();
Err(connect_err)
} else {
// Safe because the ownership of the raw fd is released from self and taken over by the
// new TcpStream.
Ok(unsafe { TcpStream::from_raw_fd(self.into_raw_fd()) })
}
}
pub fn listen(self) -> io::Result<TcpListener> {
// Safe because this doesn't modify any memory and we check the return value.
let ret = unsafe { libc::listen(self.fd, 1) };
if ret < 0 {
let listen_err = io::Error::last_os_error();
Err(listen_err)
} else {
// Safe because the ownership of the raw fd is released from self and taken over by the
// new TcpListener.
Ok(unsafe { TcpListener::from_raw_fd(self.into_raw_fd()) })
}
}
/// Returns the port that this socket is bound to. This can only succeed after bind is called.
pub fn local_port(&self) -> io::Result<u16> {
match self.inet_version {
InetVersion::V4 => {
let mut sin = sockaddr_in {
sin_family: 0,
sin_port: 0,
sin_addr: in_addr { s_addr: 0 },
sin_zero: [0; 8],
};
// Safe because we give a valid pointer for addrlen and check the length.
let mut addrlen = size_of::<sockaddr_in>() as socklen_t;
let ret = unsafe {
// Get the socket address that was actually bound.
libc::getsockname(
self.fd,
&mut sin as *mut sockaddr_in as *mut sockaddr,
&mut addrlen as *mut socklen_t,
)
};
if ret < 0 {
let getsockname_err = io::Error::last_os_error();
Err(getsockname_err)
} else {
// If this doesn't match, it's not safe to get the port out of the sockaddr.
assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
Ok(u16::from_be(sin.sin_port))
}
}
InetVersion::V6 => {
let mut sin6 = sockaddr_in6 {
sin6_family: 0,
sin6_port: 0,
sin6_flowinfo: 0,
sin6_addr: in6_addr { s6_addr: [0; 16] },
sin6_scope_id: 0,
};
// Safe because we give a valid pointer for addrlen and check the length.
let mut addrlen = size_of::<sockaddr_in6>() as socklen_t;
let ret = unsafe {
// Get the socket address that was actually bound.
libc::getsockname(
self.fd,
&mut sin6 as *mut sockaddr_in6 as *mut sockaddr,
&mut addrlen as *mut socklen_t,
)
};
if ret < 0 {
let getsockname_err = io::Error::last_os_error();
Err(getsockname_err)
} else {
// If this doesn't match, it's not safe to get the port out of the sockaddr.
assert_eq!(addrlen as usize, size_of::<sockaddr_in>());
Ok(u16::from_be(sin6.sin6_port))
}
}
}
}
}
impl IntoRawFd for TcpSocket {
fn into_raw_fd(self) -> RawFd {
let fd = self.fd;
mem::forget(self);
fd
}
}
impl AsRawFd for TcpSocket {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl Drop for TcpSocket {
fn drop(&mut self) {
// Safe because this doesn't modify any memory and we are the only
// owner of the file descriptor.
unsafe { libc::close(self.fd) };
}
}
// Offset of sun_path in structure sockaddr_un.
fn sun_path_offset() -> usize {
// Prefer 0 to null() so that we do not need to subtract from the `sub_path` pointer.

114
sys_util/src/rand.rs Normal file
View file

@ -0,0 +1,114 @@
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Rust implementation of functionality parallel to libchrome's base/rand_util.h.
use std::thread::sleep;
use std::time::Duration;
use libc::{c_uint, c_void};
use crate::{
errno::{errno_result, Result},
handle_eintr_errno,
};
/// How long to wait before calling getrandom again if it does not return
/// enough bytes.
const POLL_INTERVAL: Duration = Duration::from_millis(50);
/// Represents whether or not the random bytes are pulled from the source of
/// /dev/random or /dev/urandom.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Source {
// This is the default and uses the same source as /dev/urandom.
Pseudorandom,
// This uses the same source as /dev/random and may be.
Random,
}
impl Default for Source {
fn default() -> Self {
Source::Pseudorandom
}
}
impl Source {
fn to_getrandom_flags(&self) -> c_uint {
match self {
Source::Random => libc::GRND_RANDOM,
Source::Pseudorandom => 0,
}
}
}
/// Fills `output` completely with random bytes from the specified `source`.
pub fn rand_bytes(mut output: &mut [u8], source: Source) -> Result<()> {
if output.is_empty() {
return Ok(());
}
loop {
// Safe because output is mutable and the writes are limited by output.len().
let bytes = handle_eintr_errno!(unsafe {
libc::getrandom(
output.as_mut_ptr() as *mut c_void,
output.len(),
source.to_getrandom_flags(),
)
});
if bytes < 0 {
return errno_result();
}
if bytes as usize == output.len() {
return Ok(());
}
// Wait for more entropy and try again for the remaining bytes.
sleep(POLL_INTERVAL);
output = &mut output[bytes as usize..];
}
}
/// Allocates a vector of length `len` filled with random bytes from the
/// specified `source`.
pub fn rand_vec(len: usize, source: Source) -> Result<Vec<u8>> {
let mut rand = Vec::with_capacity(len);
if len == 0 {
return Ok(rand);
}
// Safe because rand will either be initialized by getrandom or dropped.
unsafe { rand.set_len(len) };
rand_bytes(rand.as_mut_slice(), source)?;
Ok(rand)
}
#[cfg(test)]
mod tests {
use super::*;
const TEST_SIZE: usize = 64;
#[test]
fn randbytes_success() {
let mut rand = vec![0u8; TEST_SIZE];
rand_bytes(&mut rand, Source::Pseudorandom).unwrap();
assert_ne!(&rand, &[0u8; TEST_SIZE]);
}
#[test]
fn randvec_success() {
let rand = rand_vec(TEST_SIZE, Source::Pseudorandom).unwrap();
assert_eq!(rand.len(), TEST_SIZE);
assert_ne!(&rand, &[0u8; TEST_SIZE]);
}
#[test]
fn sourcerandom_success() {
let rand = rand_vec(TEST_SIZE, Source::Random).unwrap();
assert_ne!(&rand, &[0u8; TEST_SIZE]);
}
}

138
sys_util/src/scoped_path.rs Normal file
View file

@ -0,0 +1,138 @@
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use std::env::{current_exe, temp_dir};
use std::fs::{create_dir_all, remove_dir_all};
use std::io::Result;
use std::ops::Deref;
use std::path::{Path, PathBuf};
use std::thread::panicking;
use crate::{getpid, gettid};
/// Returns a stable path based on the label, pid, and tid. If the label isn't provided the
/// current_exe is used instead.
pub fn get_temp_path(label: Option<&str>) -> PathBuf {
if let Some(label) = label {
temp_dir().join(format!("{}-{}-{}", label, getpid(), gettid()))
} else {
get_temp_path(Some(
current_exe()
.unwrap()
.file_name()
.unwrap()
.to_str()
.unwrap(),
))
}
}
/// Automatically deletes the path it contains when it goes out of scope unless it is a test and
/// drop is called after a panic!.
///
/// This is particularly useful for creating temporary directories for use with tests.
pub struct ScopedPath<P: AsRef<Path>>(P);
impl<P: AsRef<Path>> ScopedPath<P> {
pub fn create(p: P) -> Result<Self> {
create_dir_all(p.as_ref())?;
Ok(ScopedPath(p))
}
}
impl<P: AsRef<Path>> AsRef<Path> for ScopedPath<P> {
fn as_ref(&self) -> &Path {
self.0.as_ref()
}
}
impl<P: AsRef<Path>> Deref for ScopedPath<P> {
type Target = Path;
fn deref(&self) -> &Self::Target {
self.0.as_ref()
}
}
impl<P: AsRef<Path>> Drop for ScopedPath<P> {
fn drop(&mut self) {
// Leave the files on a failed test run for debugging.
if panicking() && cfg!(test) {
eprintln!("NOTE: Not removing {}", self.display());
return;
}
if let Err(e) = remove_dir_all(&**self) {
eprintln!("Failed to remove {}: {}", self.display(), e);
}
}
}
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use std::panic::catch_unwind;
#[test]
fn gettemppath() {
assert_ne!("", get_temp_path(None).to_string_lossy());
assert!(get_temp_path(None).starts_with(temp_dir()));
assert_eq!(
get_temp_path(None),
get_temp_path(Some(
current_exe()
.unwrap()
.file_name()
.unwrap()
.to_str()
.unwrap()
))
);
assert_ne!(
get_temp_path(Some("label")),
get_temp_path(Some(
current_exe()
.unwrap()
.file_name()
.unwrap()
.to_str()
.unwrap()
))
);
}
#[test]
fn scopedpath_exists() {
let tmp_path = get_temp_path(None);
{
let scoped_path = ScopedPath::create(&tmp_path).unwrap();
assert!(scoped_path.exists());
}
assert!(!tmp_path.exists());
}
#[test]
fn scopedpath_notexists() {
let tmp_path = get_temp_path(None);
{
let _scoped_path = ScopedPath(&tmp_path);
}
assert!(!tmp_path.exists());
}
#[test]
fn scopedpath_panic() {
let tmp_path = get_temp_path(None);
assert!(catch_unwind(|| {
{
let scoped_path = ScopedPath::create(&tmp_path).unwrap();
assert!(scoped_path.exists());
panic!()
}
})
.is_err());
assert!(tmp_path.exists());
remove_dir_all(&tmp_path).unwrap();
}
}

495
sys_util/src/vsock.rs Normal file
View file

@ -0,0 +1,495 @@
// Copyright 2021 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
/// Support for virtual sockets.
use std::fmt;
use std::io;
use std::mem::{self, size_of};
use std::num::ParseIntError;
use std::os::raw::{c_uchar, c_uint, c_ushort};
use std::os::unix::io::{AsRawFd, IntoRawFd, RawFd};
use std::result;
use std::str::FromStr;
use libc::{
self, c_void, sa_family_t, size_t, sockaddr, socklen_t, F_GETFL, F_SETFL, O_NONBLOCK,
VMADDR_CID_ANY, VMADDR_CID_HOST, VMADDR_CID_HYPERVISOR,
};
// The domain for vsock sockets.
const AF_VSOCK: sa_family_t = 40;
// Vsock loopback address.
const VMADDR_CID_LOCAL: c_uint = 1;
/// Vsock equivalent of binding on port 0. Binds to a random port.
pub const VMADDR_PORT_ANY: c_uint = c_uint::max_value();
// The number of bytes of padding to be added to the sockaddr_vm struct. Taken directly
// from linux/vm_sockets.h.
const PADDING: usize = size_of::<sockaddr>()
- size_of::<sa_family_t>()
- size_of::<c_ushort>()
- (2 * size_of::<c_uint>());
#[repr(C)]
#[derive(Default)]
struct sockaddr_vm {
svm_family: sa_family_t,
svm_reserved1: c_ushort,
svm_port: c_uint,
svm_cid: c_uint,
svm_zero: [c_uchar; PADDING],
}
#[derive(Debug)]
pub struct AddrParseError;
impl fmt::Display for AddrParseError {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "failed to parse vsock address")
}
}
/// The vsock equivalent of an IP address.
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
pub enum VsockCid {
/// Vsock equivalent of INADDR_ANY. Indicates the context id of the current endpoint.
Any,
/// An address that refers to the bare-metal machine that serves as the hypervisor.
Hypervisor,
/// The loopback address.
Local,
/// The parent machine. It may not be the hypervisor for nested VMs.
Host,
/// An assigned CID that serves as the address for VSOCK.
Cid(c_uint),
}
impl fmt::Display for VsockCid {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
match &self {
VsockCid::Any => write!(fmt, "Any"),
VsockCid::Hypervisor => write!(fmt, "Hypervisor"),
VsockCid::Local => write!(fmt, "Local"),
VsockCid::Host => write!(fmt, "Host"),
VsockCid::Cid(c) => write!(fmt, "'{}'", c),
}
}
}
impl From<c_uint> for VsockCid {
fn from(c: c_uint) -> Self {
match c {
VMADDR_CID_ANY => VsockCid::Any,
VMADDR_CID_HYPERVISOR => VsockCid::Hypervisor,
VMADDR_CID_LOCAL => VsockCid::Local,
VMADDR_CID_HOST => VsockCid::Host,
_ => VsockCid::Cid(c),
}
}
}
impl FromStr for VsockCid {
type Err = ParseIntError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let c: c_uint = s.parse()?;
Ok(c.into())
}
}
impl From<VsockCid> for c_uint {
fn from(cid: VsockCid) -> c_uint {
match cid {
VsockCid::Any => VMADDR_CID_ANY,
VsockCid::Hypervisor => VMADDR_CID_HYPERVISOR,
VsockCid::Local => VMADDR_CID_LOCAL,
VsockCid::Host => VMADDR_CID_HOST,
VsockCid::Cid(c) => c,
}
}
}
/// An address associated with a virtual socket.
#[derive(Debug, Copy, Clone, Hash, Eq, PartialEq)]
pub struct SocketAddr {
pub cid: VsockCid,
pub port: c_uint,
}
pub trait ToSocketAddr {
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError>;
}
impl ToSocketAddr for SocketAddr {
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
Ok(*self)
}
}
impl ToSocketAddr for str {
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
self.parse()
}
}
impl ToSocketAddr for (VsockCid, c_uint) {
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
let (cid, port) = *self;
Ok(SocketAddr { cid, port })
}
}
impl<'a, T: ToSocketAddr + ?Sized> ToSocketAddr for &'a T {
fn to_socket_addr(&self) -> result::Result<SocketAddr, AddrParseError> {
(**self).to_socket_addr()
}
}
impl FromStr for SocketAddr {
type Err = AddrParseError;
/// Parse a vsock SocketAddr from a string. vsock socket addresses are of the form
/// "vsock:cid:port".
fn from_str(s: &str) -> Result<SocketAddr, AddrParseError> {
let components: Vec<&str> = s.split(':').collect();
if components.len() != 3 || components[0] != "vsock" {
return Err(AddrParseError);
}
Ok(SocketAddr {
cid: components[1].parse().map_err(|_| AddrParseError)?,
port: components[2].parse().map_err(|_| AddrParseError)?,
})
}
}
impl fmt::Display for SocketAddr {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
write!(fmt, "{}:{}", self.cid, self.port)
}
}
/// Sets `fd` to be blocking or nonblocking. `fd` must be a valid fd of a type that accepts the
/// `O_NONBLOCK` flag. This includes regular files, pipes, and sockets.
unsafe fn set_nonblocking(fd: RawFd, nonblocking: bool) -> io::Result<()> {
let flags = libc::fcntl(fd, F_GETFL, 0);
if flags < 0 {
return Err(io::Error::last_os_error());
}
let flags = if nonblocking {
flags | O_NONBLOCK
} else {
flags & !O_NONBLOCK
};
let ret = libc::fcntl(fd, F_SETFL, flags);
if ret < 0 {
return Err(io::Error::last_os_error());
}
Ok(())
}
/// A virtual socket.
///
/// Do not use this class unless you need to change socket options or query the
/// state of the socket prior to calling listen or connect. Instead use either VsockStream or
/// VsockListener.
#[derive(Debug)]
pub struct VsockSocket {
fd: RawFd,
}
impl VsockSocket {
pub fn new() -> io::Result<Self> {
let fd = unsafe { libc::socket(libc::AF_VSOCK, libc::SOCK_STREAM | libc::SOCK_CLOEXEC, 0) };
if fd < 0 {
Err(io::Error::last_os_error())
} else {
Ok(VsockSocket { fd })
}
}
pub fn bind<A: ToSocketAddr>(&mut self, addr: A) -> io::Result<()> {
let sockaddr = addr
.to_socket_addr()
.map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
// The compiler should optimize this out since these are both compile-time constants.
assert_eq!(size_of::<sockaddr_vm>(), size_of::<sockaddr>());
let svm = sockaddr_vm {
svm_family: AF_VSOCK,
svm_cid: sockaddr.cid.into(),
svm_port: sockaddr.port,
..Default::default()
};
// Safe because this doesn't modify any memory and we check the return value.
let ret = unsafe {
libc::bind(
self.fd,
&svm as *const sockaddr_vm as *const sockaddr,
size_of::<sockaddr_vm>() as socklen_t,
)
};
if ret < 0 {
let bind_err = io::Error::last_os_error();
Err(bind_err)
} else {
Ok(())
}
}
pub fn connect<A: ToSocketAddr>(self, addr: A) -> io::Result<VsockStream> {
let sockaddr = addr
.to_socket_addr()
.map_err(|_| io::Error::from_raw_os_error(libc::EINVAL))?;
let svm = sockaddr_vm {
svm_family: AF_VSOCK,
svm_cid: sockaddr.cid.into(),
svm_port: sockaddr.port,
..Default::default()
};
// Safe because this just connects a vsock socket, and the return value is checked.
let ret = unsafe {
libc::connect(
self.fd,
&svm as *const sockaddr_vm as *const sockaddr,
size_of::<sockaddr_vm>() as socklen_t,
)
};
if ret < 0 {
let connect_err = io::Error::last_os_error();
Err(connect_err)
} else {
Ok(VsockStream { sock: self })
}
}
pub fn listen(self) -> io::Result<VsockListener> {
// Safe because this doesn't modify any memory and we check the return value.
let ret = unsafe { libc::listen(self.fd, 1) };
if ret < 0 {
let listen_err = io::Error::last_os_error();
return Err(listen_err);
}
Ok(VsockListener { sock: self })
}
/// Returns the port that this socket is bound to. This can only succeed after bind is called.
pub fn local_port(&self) -> io::Result<u32> {
let mut svm: sockaddr_vm = Default::default();
// Safe because we give a valid pointer for addrlen and check the length.
let mut addrlen = size_of::<sockaddr_vm>() as socklen_t;
let ret = unsafe {
// Get the socket address that was actually bound.
libc::getsockname(
self.fd,
&mut svm as *mut sockaddr_vm as *mut sockaddr,
&mut addrlen as *mut socklen_t,
)
};
if ret < 0 {
let getsockname_err = io::Error::last_os_error();
Err(getsockname_err)
} else {
// If this doesn't match, it's not safe to get the port out of the sockaddr.
assert_eq!(addrlen as usize, size_of::<sockaddr_vm>());
Ok(svm.svm_port)
}
}
pub fn try_clone(&self) -> io::Result<Self> {
// Safe because this doesn't modify any memory and we check the return value.
let dup_fd = unsafe { libc::fcntl(self.fd, libc::F_DUPFD_CLOEXEC, 0) };
if dup_fd < 0 {
Err(io::Error::last_os_error())
} else {
Ok(Self { fd: dup_fd })
}
}
pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
// Safe because the fd is valid and owned by this stream.
unsafe { set_nonblocking(self.fd, nonblocking) }
}
}
impl IntoRawFd for VsockSocket {
fn into_raw_fd(self) -> RawFd {
let fd = self.fd;
mem::forget(self);
fd
}
}
impl AsRawFd for VsockSocket {
fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl Drop for VsockSocket {
fn drop(&mut self) {
// Safe because this doesn't modify any memory and we are the only
// owner of the file descriptor.
unsafe { libc::close(self.fd) };
}
}
/// A virtual stream socket.
#[derive(Debug)]
pub struct VsockStream {
sock: VsockSocket,
}
impl VsockStream {
pub fn connect<A: ToSocketAddr>(addr: A) -> io::Result<VsockStream> {
let sock = VsockSocket::new()?;
sock.connect(addr)
}
/// Returns the port that this stream is bound to.
pub fn local_port(&self) -> io::Result<u32> {
self.sock.local_port()
}
pub fn try_clone(&self) -> io::Result<VsockStream> {
self.sock.try_clone().map(|f| VsockStream { sock: f })
}
pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
self.sock.set_nonblocking(nonblocking)
}
}
impl io::Read for VsockStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
// Safe because this will only modify the contents of |buf| and we check the return value.
let ret = unsafe {
libc::read(
self.sock.as_raw_fd(),
buf as *mut [u8] as *mut c_void,
buf.len() as size_t,
)
};
if ret < 0 {
return Err(io::Error::last_os_error());
}
Ok(ret as usize)
}
}
impl io::Write for VsockStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
// Safe because this doesn't modify any memory and we check the return value.
let ret = unsafe {
libc::write(
self.sock.as_raw_fd(),
buf as *const [u8] as *const c_void,
buf.len() as size_t,
)
};
if ret < 0 {
return Err(io::Error::last_os_error());
}
Ok(ret as usize)
}
fn flush(&mut self) -> io::Result<()> {
// No buffered data so nothing to do.
Ok(())
}
}
impl AsRawFd for VsockStream {
fn as_raw_fd(&self) -> RawFd {
self.sock.as_raw_fd()
}
}
impl IntoRawFd for VsockStream {
fn into_raw_fd(self) -> RawFd {
self.sock.into_raw_fd()
}
}
/// Represents a virtual socket server.
#[derive(Debug)]
pub struct VsockListener {
sock: VsockSocket,
}
impl VsockListener {
/// Creates a new `VsockListener` bound to the specified port on the current virtual socket
/// endpoint.
pub fn bind<A: ToSocketAddr>(addr: A) -> io::Result<VsockListener> {
let mut sock = VsockSocket::new()?;
sock.bind(addr)?;
sock.listen()
}
/// Returns the port that this listener is bound to.
pub fn local_port(&self) -> io::Result<u32> {
self.sock.local_port()
}
/// Accepts a new incoming connection on this listener. Blocks the calling thread until a
/// new connection is established. When established, returns the corresponding `VsockStream`
/// and the remote peer's address.
pub fn accept(&self) -> io::Result<(VsockStream, SocketAddr)> {
let mut svm: sockaddr_vm = Default::default();
// Safe because this will only modify |svm| and we check the return value.
let mut socklen: socklen_t = size_of::<sockaddr_vm>() as socklen_t;
let fd = unsafe {
libc::accept4(
self.sock.as_raw_fd(),
&mut svm as *mut sockaddr_vm as *mut sockaddr,
&mut socklen as *mut socklen_t,
libc::SOCK_CLOEXEC,
)
};
if fd < 0 {
return Err(io::Error::last_os_error());
}
if svm.svm_family != AF_VSOCK {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
format!("unexpected address family: {}", svm.svm_family),
));
}
Ok((
VsockStream {
sock: VsockSocket { fd },
},
SocketAddr {
cid: svm.svm_cid.into(),
port: svm.svm_port,
},
))
}
pub fn set_nonblocking(&mut self, nonblocking: bool) -> io::Result<()> {
self.sock.set_nonblocking(nonblocking)
}
}
impl AsRawFd for VsockListener {
fn as_raw_fd(&self) -> RawFd {
self.sock.as_raw_fd()
}
}