From a99954cb7c076c9585aba416afdcb86f67e3676f Mon Sep 17 00:00:00 2001 From: Zach Reizner Date: Thu, 23 Aug 2018 13:34:56 -0700 Subject: [PATCH] sys_util: remove Scm struct and sock_ctrl_msg C library The Scm object was made to reduce the number of heap allocations in the hot paths of poll loops, at the cost of some code complexity. As it turns out, the number of file descriptors being sent or received is usually just one or limited to a fixed amount that can easily be covered with a fixed size stack allocated buffer. This change implements that solution, with heap allocation as a backup in the rare case that many file descriptors must be sent or received. This change also moves the msg and cmsg manipulation code out of C and into pure Rust. The move was necessary to allocate the correct amount of buffer space at compile time. It also improves safety by reducing the scope of unsafe code. Deleting the code for building the C library is also a nice bonus. Finally, the removal of the commonly used Scm struct required transitioning existing usage to the ScmSocket trait based methods. This includes all those changes. TEST=cargo test BUG=None Change-Id: If27ba297f5416dd9b8bc686ce740866912fa0aa0 Reviewed-on: https://chromium-review.googlesource.com/1186146 Commit-Ready: ChromeOS CL Exonerator Bot Tested-by: Zach Reizner Reviewed-by: Zach Reizner --- Cargo.lock | 1 - crosvm_plugin/src/lib.rs | 22 +- devices/src/virtio/wl.rs | 46 ++-- src/linux.rs | 8 +- src/main.rs | 8 +- src/plugin/process.rs | 21 +- sys_util/Cargo.toml | 4 - sys_util/build.rs | 11 - sys_util/sock_ctrl_msg.c | 126 --------- sys_util/src/sock_ctrl_msg.rs | 464 ++++++++++++++++++++-------------- vm_control/src/lib.rs | 122 +++++---- 11 files changed, 385 insertions(+), 448 deletions(-) delete mode 100644 sys_util/build.rs delete mode 100644 sys_util/sock_ctrl_msg.c diff --git a/Cargo.lock b/Cargo.lock index 8ac7e4f067..086a9f3a70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -367,7 +367,6 @@ dependencies = [ name = "sys_util" version = "0.1.0" dependencies = [ - "cc 1.0.15 (registry+https://github.com/rust-lang/crates.io-index)", "data_model 0.1.0", "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", "poll_token_derive 0.1.0", diff --git a/crosvm_plugin/src/lib.rs b/crosvm_plugin/src/lib.rs index 06cb367ed3..754f74c03d 100644 --- a/crosvm_plugin/src/lib.rs +++ b/crosvm_plugin/src/lib.rs @@ -40,7 +40,7 @@ use libc::{E2BIG, ENOTCONN, EINVAL, EPROTO, ENOENT}; use protobuf::{Message, ProtobufEnum, RepeatedField, parse_from_bytes}; -use sys_util::Scm; +use sys_util::ScmSocket; use kvm::dirty_log_bitmap_size; @@ -250,7 +250,6 @@ impl Drop for StatUpdater { pub struct crosvm { id_allocator: Arc, socket: UnixDatagram, - fd_messager: Scm, request_buffer: Vec, response_buffer: Vec, vcpus: Arc>, @@ -261,7 +260,6 @@ impl crosvm { let mut crosvm = crosvm { id_allocator: Default::default(), socket, - fd_messager: Scm::new(MAX_DATAGRAM_FD), request_buffer: Vec::new(), response_buffer: vec![0; MAX_DATAGRAM_SIZE], vcpus: Default::default(), @@ -277,7 +275,6 @@ impl crosvm { crosvm { id_allocator, socket, - fd_messager: Scm::new(MAX_DATAGRAM_FD), request_buffer: Vec::new(), response_buffer: vec![0; MAX_DATAGRAM_SIZE], vcpus, @@ -296,16 +293,19 @@ impl crosvm { request .write_to_vec(&mut self.request_buffer) .map_err(proto_error_to_int)?; - self.fd_messager - .send(&self.socket, &[self.request_buffer.as_slice()], fds) + self.socket + .send_with_fds(self.request_buffer.as_slice(), fds) .map_err(|e| -e.errno())?; - let mut datagram_files = Vec::new(); - let msg_size = self.fd_messager - .recv(&self.socket, - &mut [&mut self.response_buffer], - &mut datagram_files) + let mut datagram_fds = [0; MAX_DATAGRAM_FD]; + let (msg_size, fd_count) = self.socket + .recv_with_fds(&mut self.response_buffer, &mut datagram_fds) .map_err(|e| -e.errno())?; + // Safe because the first fd_count fds from recv_with_fds are owned by us and valid. + let datagram_files = datagram_fds[..fd_count] + .iter() + .map(|&fd| unsafe { File::from_raw_fd(fd) }) + .collect(); let response: MainResponse = parse_from_bytes(&self.response_buffer[..msg_size]) .map_err(proto_error_to_int)?; diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs index b946c51cca..7c6ae58c66 100644 --- a/devices/src/virtio/wl.rs +++ b/devices/src/virtio/wl.rs @@ -40,9 +40,7 @@ use std::io::{self, Seek, SeekFrom, Read}; use std::mem::{size_of, size_of_val}; #[cfg(feature = "wl-dmabuf")] use std::os::raw::{c_uint, c_ulonglong}; -use std::os::unix::io::{AsRawFd, RawFd}; -#[cfg(feature = "wl-dmabuf")] -use std::os::unix::io::FromRawFd; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::{UnixDatagram, UnixStream}; use std::path::{PathBuf, Path}; use std::rc::Rc; @@ -59,7 +57,7 @@ use data_model::*; use data_model::VolatileMemoryError; use resources::GpuMemoryDesc; -use sys_util::{Error, Result, EventFd, Scm, SharedMemory, GuestAddress, GuestMemory, +use sys_util::{Error, Result, EventFd, ScmSocket, SharedMemory, GuestAddress, GuestMemory, GuestMemoryError, PollContext, PollToken, FileFlags, pipe, round_up_to_page_size}; #[cfg(feature = "wl-dmabuf")] @@ -432,21 +430,19 @@ impl From for WlError { #[derive(Clone)] struct VmRequester { - inner: Rc>, + inner: Rc>, } impl VmRequester { fn new(vm_socket: UnixDatagram) -> VmRequester { - VmRequester { inner: Rc::new(RefCell::new((Scm::new(1), vm_socket))) } + VmRequester { inner: Rc::new(RefCell::new(vm_socket)) } } fn request(&self, request: VmRequest) -> WlResult { let mut inner = self.inner.borrow_mut(); - let (ref mut scm, ref mut vm_socket) = *inner; - request - .send(scm, vm_socket) - .map_err(WlError::VmControl)?; - VmResponse::recv(scm, vm_socket).map_err(WlError::VmControl) + let ref mut vm_socket = *inner; + request.send(vm_socket).map_err(WlError::VmControl)?; + VmResponse::recv(vm_socket).map_err(WlError::VmControl) } } @@ -803,9 +799,10 @@ impl WlVfd { } // Sends data/files from the guest to the host over this VFD. - fn send(&mut self, scm: &mut Scm, fds: &[RawFd], data: VolatileSlice) -> WlResult { + fn send(&mut self, fds: &[RawFd], data: VolatileSlice) -> WlResult { if let Some(ref socket) = self.socket { - scm.send(socket, &[data], fds) + socket + .send_with_fds(data, fds) .map_err(WlError::SendVfd)?; Ok(WlResp::Ok) } else if let Some((_, ref mut local_pipe)) = self.local_pipe { @@ -822,19 +819,24 @@ impl WlVfd { } // Receives data/files from the host for this VFD and queues it for the guest. - fn recv(&mut self, scm: &mut Scm, in_file_queue: &mut Vec) -> WlResult> { + fn recv(&mut self, in_file_queue: &mut Vec) -> WlResult> { if let Some(socket) = self.socket.take() { - let mut buf = Vec::new(); - buf.resize(IN_BUFFER_LEN, 0); - let old_len = in_file_queue.len(); + let mut buf = vec![0; IN_BUFFER_LEN]; + let mut fd_buf = [0; VIRTWL_SEND_MAX_ALLOCS]; // If any errors happen, the socket will get dropped, preventing more reading. - let len = scm.recv(&socket, &mut [&mut buf[..]], in_file_queue) + let (len, file_count) = socket + .recv_with_fds(&mut buf[..], &mut fd_buf) .map_err(WlError::RecvVfd)?; // If any data gets read, the put the socket back for future recv operations. - if len != 0 || in_file_queue.len() != old_len { + if len != 0 || file_count != 0 { buf.truncate(len); buf.shrink_to_fit(); self.socket = Some(socket); + // Safe because the first file_counts fds from recv_with_fds are owned by us and + // valid. + in_file_queue.extend(fd_buf[..file_count] + .iter() + .map(|&fd| unsafe { File::from_raw_fd(fd) })); return Ok(buf); } Ok(Vec::new()) @@ -893,7 +895,6 @@ struct WlState { poll_ctx: PollContext, vfds: Map, next_vfd_id: u32, - scm: Scm, in_file_queue: Vec, in_queue: VecDeque<(u32 /* vfd_id */, WlRecv)>, current_recv_vfd: Option, @@ -907,7 +908,6 @@ impl WlState { vm: VmRequester::new(vm_socket), poll_ctx: PollContext::new().expect("failed to create PollContext"), use_transition_flags, - scm: Scm::new(VIRTWL_SEND_MAX_ALLOCS), vfds: Map::new(), next_vfd_id: NEXT_VFD_ID_BASE, in_file_queue: Vec::new(), @@ -1126,7 +1126,7 @@ impl WlState { match self.vfds.get_mut(&vfd_id) { Some(vfd) => { - match vfd.send(&mut self.scm, &fds[..vfd_count], data)? { + match vfd.send(&fds[..vfd_count], data)? { WlResp::Ok => {} _ => return Ok(WlResp::InvalidType), } @@ -1145,7 +1145,7 @@ impl WlState { fn recv(&mut self, vfd_id: u32) -> WlResult<()> { let buf = match self.vfds.get_mut(&vfd_id) { - Some(vfd) => vfd.recv(&mut self.scm, &mut self.in_file_queue)?, + Some(vfd) => vfd.recv(&mut self.in_file_queue)?, None => return Ok(()), }; if self.in_file_queue.is_empty() && buf.is_empty() { diff --git a/src/linux.rs b/src/linux.rs index a142d1dda1..a5c87de643 100644 --- a/src/linux.rs +++ b/src/linux.rs @@ -729,8 +729,6 @@ fn run_control(mut linux: RunnableLinuxVm, control_sockets: Vec, balloon_host_socket: UnixDatagram, sigchld_fd: SignalFd) -> Result<()> { - const MAX_VM_FD_RECV: usize = 1; - // Paths to get the currently available memory and the low memory threshold. const LOWMEM_MARGIN: &'static str = "/sys/kernel/mm/chromeos-low_mem/margin"; const LOWMEM_AVAILABLE: &'static str = "/sys/kernel/mm/chromeos-low_mem/available"; @@ -817,8 +815,6 @@ fn run_control(mut linux: RunnableLinuxVm, } vcpu_thread_barrier.wait(); - let mut scm = Scm::new(MAX_VM_FD_RECV); - 'poll: loop { let events = { match poll_ctx.wait() { @@ -945,7 +941,7 @@ fn run_control(mut linux: RunnableLinuxVm, } Token::VmControl { index } => { if let Some(socket) = control_sockets.get(index as usize) { - match VmRequest::recv(&mut scm, socket.as_ref()) { + match VmRequest::recv(socket.as_ref()) { Ok(request) => { let mut running = true; let response = @@ -953,7 +949,7 @@ fn run_control(mut linux: RunnableLinuxVm, &mut linux.resources, &mut running, &balloon_host_socket); - if let Err(e) = response.send(&mut scm, socket.as_ref()) { + if let Err(e) = response.send(socket.as_ref()) { error!("failed to send VmResponse: {:?}", e); } if !running { diff --git a/src/main.rs b/src/main.rs index bd2b632726..de93f1bc9c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,7 +47,7 @@ use std::string::String; use std::thread::sleep; use std::time::Duration; -use sys_util::{Scm, getpid, kill_process_group, reap_child, syslog}; +use sys_util::{getpid, kill_process_group, reap_child, syslog}; use qcow::QcowFile; use argument::{Argument, set_arguments, print_help}; @@ -519,7 +519,6 @@ fn run_vm(args: std::env::Args) -> std::result::Result<(), ()> { } fn stop_vms(args: std::env::Args) -> std::result::Result<(), ()> { - let mut scm = Scm::new(1); if args.len() == 0 { print_help("crosvm stop", "VM_SOCKET...", &[]); println!("Stops the crosvm instance listening on each `VM_SOCKET` given."); @@ -532,7 +531,7 @@ fn stop_vms(args: std::env::Args) -> std::result::Result<(), ()> { Ok(s) }) { Ok(s) => { - if let Err(e) = VmRequest::Exit.send(&mut scm, &s) { + if let Err(e) = VmRequest::Exit.send(&s) { error!("failed to send stop request to socket at '{}': {:?}", socket_path, e); @@ -549,7 +548,6 @@ fn stop_vms(args: std::env::Args) -> std::result::Result<(), ()> { } fn balloon_vms(mut args: std::env::Args) -> std::result::Result<(), ()> { - let mut scm = Scm::new(1); if args.len() < 2 { print_help("crosvm balloon", "PAGE_ADJUST VM_SOCKET...", &[]); println!("Adjust the ballon size of the crosvm instance by `PAGE_ADJUST` pages, `PAGE_ADJUST` can be negative to shrink the balloon."); @@ -569,7 +567,7 @@ fn balloon_vms(mut args: std::env::Args) -> std::result::Result<(), ()> { Ok(s) }) { Ok(s) => { - if let Err(e) = VmRequest::BalloonAdjust(num_pages).send(&mut scm, &s) { + if let Err(e) = VmRequest::BalloonAdjust(num_pages).send(&s) { error!("failed to send balloon request to socket at '{}': {:?}", socket_path, e); diff --git a/src/plugin/process.rs b/src/plugin/process.rs index 0e8a74f6eb..71c16c0b2e 100644 --- a/src/plugin/process.rs +++ b/src/plugin/process.rs @@ -25,7 +25,7 @@ use protobuf::Message; use io_jail::Minijail; use kvm::{Vm, IoeventAddress, NoDatamatch, IrqSource, IrqRoute, PicId, dirty_log_bitmap_size}; use kvm_sys::{kvm_pic_state, kvm_ioapic_state, kvm_pit_state2}; -use sys_util::{EventFd, MemoryMapping, Killable, Scm, SharedMemory, GuestAddress, +use sys_util::{EventFd, MemoryMapping, Killable, ScmSocket, SharedMemory, GuestAddress, Result as SysResult, Error as SysError, SIGRTMIN}; use plugin_proto::*; @@ -119,9 +119,7 @@ pub struct Process { vcpu_sockets: Vec<(UnixDatagram, UnixDatagram)>, // Socket Transmission - scm: Scm, request_buffer: Vec, - datagram_files: Vec, response_buffer: Vec, } @@ -181,9 +179,7 @@ impl Process { per_vcpu_states, kill_evt: EventFd::new().map_err(Error::CreateEventFd)?, vcpu_sockets, - scm: Scm::new(1), request_buffer: vec![0; MAX_DATAGRAM_SIZE], - datagram_files: Vec::new(), response_buffer: Vec::new(), }) } @@ -444,10 +440,8 @@ impl Process { vcpu_handles: &[JoinHandle<()>], tap: Option<&Tap>) -> Result<()> { - let msg_size = self.scm - .recv(&self.request_sockets[index], - &mut [&mut self.request_buffer], - &mut self.datagram_files) + let (msg_size, request_file) = self.request_sockets[index] + .recv_with_fd(&mut self.request_buffer) .map_err(Error::PluginSocketRecv)?; if msg_size == 0 { @@ -475,7 +469,7 @@ impl Process { } } else if create.has_memory() { let memory = create.get_memory(); - match self.datagram_files.pop() { + match request_file { Some(memfd) => { Self::handle_memory(entry, vm, @@ -653,16 +647,13 @@ impl Process { response.errno = e.errno(); } - self.datagram_files.clear(); self.response_buffer.clear(); response .write_to_vec(&mut self.response_buffer) .map_err(Error::EncodeResponse)?; assert_ne!(self.response_buffer.len(), 0); - self.scm - .send(&self.request_sockets[index], - &[&self.response_buffer[..]], - &response_fds) + self.request_sockets[index] + .send_with_fds(&self.response_buffer[..], &response_fds) .map_err(Error::PluginSocketSend)?; Ok(()) diff --git a/sys_util/Cargo.toml b/sys_util/Cargo.toml index 21884f8c16..b3ba0a2bd5 100644 --- a/sys_util/Cargo.toml +++ b/sys_util/Cargo.toml @@ -2,13 +2,9 @@ name = "sys_util" version = "0.1.0" authors = ["The Chromium OS Authors"] -build = "build.rs" [dependencies] data_model = { path = "../data_model" } libc = "*" syscall_defines = { path = "../syscall_defines" } poll_token_derive = { path = "poll_token_derive" } - -[build-dependencies] -cc = "=1.0.15" diff --git a/sys_util/build.rs b/sys_util/build.rs deleted file mode 100644 index f1fc1df98e..0000000000 --- a/sys_util/build.rs +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright 2017 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -extern crate cc; - -fn main() { - cc::Build::new() - .file("sock_ctrl_msg.c") - .compile("sock_ctrl_msg"); -} diff --git a/sys_util/sock_ctrl_msg.c b/sys_util/sock_ctrl_msg.c deleted file mode 100644 index d569a47a0a..0000000000 --- a/sys_util/sock_ctrl_msg.c +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2017 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include -#include // memcpy -#include // close -#include -#include // CMSG_* - -/* - * Returns the number of bytes the `cmsg_buffer` must be for the functions that take a cmsg_buffer - * in this module. - * Arguments: - * * `fd_count` - Maximum number of file descriptors to be sent or received via the cmsg. - */ -size_t scm_cmsg_buffer_len(size_t fd_count) -{ - return CMSG_SPACE(sizeof(int) * fd_count); -} - -/* - * Convenience wrapper around `sendmsg` that builds up the `msghdr` structure for you given the - * array of fds. - * Arguments: - * * `fd` - Unix domain socket to `sendmsg` on. - * * `outv` - Array of `outv_count` length `iovec`s that contain the data to send. - * * `outv_count` - Number of elements in `outv` array. - * * `cmsg_buffer` - A buffer that must be at least `scm_cmsg_buffer_len(fd_count)` bytes long. - * * `fds` - Array of `fd_count` file descriptors to send along with data. - * * `fd_count` - Number of elements in `fds` array. - * Returns: - * A non-negative number indicating how many bytes were sent on success or a negative errno on - * failure. - */ -ssize_t scm_sendmsg(int fd, const struct iovec *outv, size_t outv_count, uint8_t *cmsg_buffer, - const int *fds, size_t fd_count) -{ - if (fd < 0 || ((!cmsg_buffer || !fds) && fd_count > 0)) - return -EINVAL; - - struct msghdr msg; - memset(&msg, 0, sizeof(msg)); - msg.msg_iov = (struct iovec *)outv; // discard const, sendmsg won't mutate it - msg.msg_iovlen = outv_count; - - if (fd_count) { - msg.msg_control = cmsg_buffer; - msg.msg_controllen = scm_cmsg_buffer_len(fd_count); - - struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = CMSG_LEN(fd_count * sizeof(int)); - memcpy(CMSG_DATA(cmsg), fds, fd_count * sizeof(int)); - - msg.msg_controllen = cmsg->cmsg_len; - } - - ssize_t bytes_sent = sendmsg(fd, &msg, MSG_NOSIGNAL); - if (bytes_sent == -1) - return -errno; - - return bytes_sent; -} - -/* - * Convenience wrapper around `recvmsg` that builds up the `msghdr` structure and returns up to - * `*fd_count` file descriptors in the given `fds` array. - * Arguments: - * * `fd` - Unix domain socket to `recvmsg` on. - * * `outv` - Array of `outv_count` length `iovec`s that will contain the received data. - * * `outv_count` - Number of elements in `outv` array. - * * `cmsg_buffer` - A buffer that must be at least `scm_cmsg_buffer_len(*fd_count)` bytes long. - * * `fds` - Array of `fd_count` file descriptors to receive along with data. - * * `fd_count` - Number of elements in `fds` array. - * Returns: - * A non-negative number indicating how many bytes were received on success or a negative errno on - * failure. - */ -ssize_t scm_recvmsg(int fd, struct iovec *outv, size_t outv_count, uint8_t *cmsg_buffer, int *fds, - size_t *fd_count) -{ - if (fd < 0 || !cmsg_buffer || !fds || !fd_count) - return -EINVAL; - - struct msghdr msg; - memset(&msg, 0, sizeof(msg)); - msg.msg_iov = outv; - msg.msg_iovlen = outv_count; - msg.msg_control = cmsg_buffer; - msg.msg_controllen = scm_cmsg_buffer_len(*fd_count); - - ssize_t total_read = recvmsg(fd, &msg, 0); - if (total_read == -1) - return -errno; - - if (total_read == 0 && CMSG_FIRSTHDR(&msg) == NULL) { - *fd_count = 0; - return 0; - } - - size_t fd_idx = 0; - struct cmsghdr *cmsg; - for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg)) { - if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) - continue; - - size_t cmsg_fd_count = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int); - - int *cmsg_fds = (int *)CMSG_DATA(cmsg); - size_t cmsg_fd_idx; - for (cmsg_fd_idx = 0; cmsg_fd_idx < cmsg_fd_count; cmsg_fd_idx++) { - if (fd_idx < *fd_count) { - fds[fd_idx] = cmsg_fds[cmsg_fd_idx]; - fd_idx++; - } else { - close(cmsg_fds[cmsg_fd_idx]); - } - } - } - - *fd_count = fd_idx; - - return total_read; -} diff --git a/sys_util/src/sock_ctrl_msg.rs b/sys_util/src/sock_ctrl_msg.rs index 8a3ce39e7c..ebbf9db211 100644 --- a/sys_util/src/sock_ctrl_msg.rs +++ b/sys_util/src/sock_ctrl_msg.rs @@ -2,41 +2,200 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +//! Used to send and receive messages with file descriptors on sockets that accept control messages +//! (e.g. Unix domain sockets). + use std::fs::File; +use std::mem::size_of; use std::os::unix::io::{AsRawFd, RawFd, FromRawFd}; use std::os::unix::net::{UnixDatagram, UnixStream}; +use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned}; -use libc::{c_void, iovec}; +use libc::{c_void, c_long, iovec, msghdr, cmsghdr, sendmsg, recvmsg, SOL_SOCKET, SCM_RIGHTS, + MSG_NOSIGNAL}; use data_model::VolatileSlice; use {Result, Error}; -// These functions are implemented in C because each of them requires complicated setup with CMSG -// macros. These macros are part of the system headers and are required to be used for portability -// reasons. In practice, the control message ABI can't change but using them is much easier and less -// error prone than trying to port these macros to rust. -extern "C" { - fn scm_cmsg_buffer_len(fd_count: usize) -> usize; - fn scm_sendmsg(fd: RawFd, - outv: *const iovec, - outv_count: usize, - cmsg_buffer: *mut u8, - fds: *const RawFd, - fd_count: usize) - -> isize; - fn scm_recvmsg(fd: RawFd, - outv: *mut iovec, - outv_count: usize, - cmsg_buffer: *mut u8, - fds: *mut RawFd, - fd_count: *mut usize) - -> isize; +// Each of the following macros performs the same function as their C counterparts. They are each +// macros because they are used to size statically allocated arrays. + +macro_rules! CMSG_ALIGN { + ($len:expr) => ( + (($len) + size_of::() - 1) & !(size_of::() - 1) + ) } -fn cmsg_buffer_len(fd_count: usize) -> usize { - // Safe because this function has no side effects, touches no pointers, and never fails. - unsafe { scm_cmsg_buffer_len(fd_count) } +macro_rules! CMSG_SPACE { + ($len:expr) => ( + size_of::() + CMSG_ALIGN!($len) + ) +} + +macro_rules! CMSG_LEN { + ($len:expr) => ( + size_of::() + ($len) + ) +} + + +// This function (macro in the C version) is not used in any compile time constant slots, so is just +// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this +// module supports. +#[allow(non_snake_case)] +#[inline(always)] +fn CMSG_DATA(cmsg_buffer: *mut u8) -> *mut RawFd { + // Essentially returns a pointer to just past the header. + (cmsg_buffer as *mut cmsghdr).wrapping_offset(1) as *mut RawFd +} + +// This function is like CMSG_NEXT, but safer because it reads only from references, although it +// does some pointer arithmetic on cmsg_ptr. +fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut u8) -> *mut u8 { + let next_cmsg = cmsg_ptr.wrapping_offset(CMSG_ALIGN!(cmsg.cmsg_len) as isize); + if next_cmsg + .wrapping_offset(1) + .wrapping_sub(msghdr.msg_control as usize) as usize > msghdr.msg_controllen { + null_mut() + } else { + next_cmsg + } +} + + +const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::() * 32); + +enum CmsgBuffer { + Inline([u8; CMSG_BUFFER_INLINE_CAPACITY]), + Heap(Box<[u8]>), +} + +impl CmsgBuffer { + fn with_capacity(capacity: usize) -> CmsgBuffer { + if capacity <= CMSG_BUFFER_INLINE_CAPACITY { + CmsgBuffer::Inline([0u8; CMSG_BUFFER_INLINE_CAPACITY]) + } else { + CmsgBuffer::Heap(vec![0; capacity].into_boxed_slice()) + } + } + + fn as_mut_ptr(&mut self) -> *mut u8 { + match self { + CmsgBuffer::Inline(a) => a.as_mut_ptr(), + CmsgBuffer::Heap(a) => a.as_mut_ptr(), + } + } +} + +fn raw_sendmsg(fd: RawFd, out_data: D, out_fds: &[RawFd]) -> Result { + let cmsg_capacity = CMSG_SPACE!(size_of::() * out_fds.len()); + let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); + + let mut iovec = iovec { + iov_base: out_data.as_ptr() as *mut c_void, + iov_len: out_data.size(), + }; + + let mut msg = msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: &mut iovec as *mut iovec, + msg_iovlen: 1, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }; + + if !out_fds.is_empty() { + let cmsg = cmsghdr { + cmsg_len: CMSG_LEN!(size_of::() * out_fds.len()), + cmsg_level: SOL_SOCKET, + cmsg_type: SCM_RIGHTS, + }; + unsafe { + // Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr. + write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg); + // Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len() + // file descriptors. + copy_nonoverlapping(out_fds.as_ptr(), + CMSG_DATA(cmsg_buffer.as_mut_ptr()), + out_fds.len()); + } + + msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; + msg.msg_controllen = cmsg_capacity; + } + + // Safe because the msghdr was properly constructed from valid (or null) pointers of the + // indicated length and we check the return value. + let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) }; + + if write_count == -1 { + Err(Error::last()) + } else { + Ok(write_count as usize) + } +} + +fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(usize, usize)> { + let cmsg_capacity = CMSG_SPACE!(size_of::() * in_fds.len()); + let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); + + let mut iovec = iovec { + iov_base: in_data.as_mut_ptr() as *mut c_void, + iov_len: in_data.len(), + }; + + let mut msg = msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: &mut iovec as *mut iovec, + msg_iovlen: 1, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }; + + if !in_fds.is_empty() { + msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; + msg.msg_controllen = cmsg_capacity; + } + + // Safe because the msghdr was properly constructed from valid (or null) pointers of the + // indicated length and we check the return value. + let total_read = unsafe { recvmsg(fd, &mut msg, 0) }; + + if total_read == -1 { + return Err(Error::last()); + } + + if total_read == 0 && msg.msg_controllen < size_of::() { + return Ok((0, 0)); + } + + let mut cmsg_ptr = msg.msg_control as *mut u8; + let mut in_fds_count = 0; + while !cmsg_ptr.is_null() { + // Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that + // that only happens when there is at least sizeof(cmsghdr) space after the pointer to read. + let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() }; + + if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS { + let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) / size_of::(); + unsafe { + copy_nonoverlapping(CMSG_DATA(cmsg_ptr), + in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(), + fd_count); + } + in_fds_count += fd_count; + } + + cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr); + } + + + Ok((total_read as usize, in_fds_count)) } /// Trait for file descriptors can send and receive socket control messages via `sendmsg` and @@ -44,6 +203,69 @@ fn cmsg_buffer_len(fd_count: usize) -> usize { pub trait ScmSocket { /// Gets the file descriptor of this socket. fn socket_fd(&self) -> RawFd; + + /// Sends the given data and file descriptor over the socket. + /// + /// On success, returns the number of bytes sent. + /// + /// # Arguments + /// + /// * `buf` - A buffer of data to send on the `socket`. + /// * `fd` - A file descriptors to be sent. + fn send_with_fd(&self, buf: D, fd: RawFd) -> Result { + self.send_with_fds(buf, &[fd]) + } + + /// Sends the given data and file descriptors over the socket. + /// + /// On success, returns the number of bytes sent. + /// + /// # Arguments + /// + /// * `buf` - A buffer of data to send on the `socket`. + /// * `fds` - A list of file descriptors to be sent. + fn send_with_fds(&self, buf: D, fd: &[RawFd]) -> Result { + raw_sendmsg(self.socket_fd(), buf, fd) + } + + /// Receives data and potentially a file descriptor from the socket. + /// + /// On success, returns the number of bytes and an optional file descriptor. + /// + /// # Arguments + /// + /// * `buf` - A buffer to receive data from the socket.vm + fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option)> { + let mut fd = [0]; + let (read_count, fd_count) = self.recv_with_fds(buf, &mut fd)?; + let file = if fd_count == 0 { + None + } else { + // Safe because the first fd from recv_with_fds is owned by us and valid because this + // branch was taken. + Some(unsafe { File::from_raw_fd(fd[0]) }) + }; + Ok((read_count, file)) + + } + + + /// Receives data and file descriptors from the socket. + /// + /// On success, returns the number of bytes and file descriptors received as a tuple + /// `(bytes count, files count)`. + /// + /// # Arguments + /// + /// * `buf` - A buffer to receive data from the socket. + /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the + /// number of valid file descriptors is indicated by the second element of the + /// returned tuple. The caller owns these file descriptors, but they will not be + /// closed on drop like a `File`-like type would be. It is recommended that each valid + /// file descriptor gets wrapped in a drop type that closes it after this returns. + fn recv_with_fds(&self, buf: &mut [u8], fds: &mut [RawFd]) -> Result<(usize, usize)> { + raw_recvmsg(self.socket_fd(), buf, fds) + } } impl ScmSocket for UnixDatagram { @@ -95,123 +317,6 @@ unsafe impl<'a> IntoIovec for VolatileSlice<'a> { } } -/// Used to send and receive messages with file descriptors on sockets that accept control messages -/// (e.g. Unix domain sockets). -pub struct Scm { - cmsg_buffer: Vec, - vecs: Vec, - fds: Vec, -} - -impl Scm { - /// Constructs a new Scm object with pre-allocated structures. - /// - /// # Arguments - /// - /// * `fd_count` - The maximum number of files that can be received per `recv` call. - pub fn new(fd_count: usize) -> Scm { - Scm { - cmsg_buffer: Vec::with_capacity(cmsg_buffer_len(fd_count)), - vecs: Vec::new(), - fds: vec![-1; fd_count], - } - } - - /// Sends the given data and file descriptors over the given `socket`. - /// - /// On success, returns the number of bytes sent. - /// - /// # Arguments - /// - /// * `socket` - A socket that supports socket control messages. - /// * `bufs` - A list of buffers to send on the `socket`. - /// * `fds` - A list of file descriptors to be sent. - pub fn send(&mut self, - socket: &T, - bufs: &[D], - fds: &[RawFd]) - -> Result { - let cmsg_buf_len = cmsg_buffer_len(fds.len()); - self.cmsg_buffer.reserve(cmsg_buf_len); - self.vecs.clear(); - for ref buf in bufs { - self.vecs - .push(iovec { - iov_base: buf.as_ptr() as *mut c_void, - iov_len: buf.size(), - }); - } - let write_count = unsafe { - // Safe because we are giving scm_sendmsg only valid pointers and lengths and we check - // the return value. - self.cmsg_buffer.set_len(cmsg_buf_len); - scm_sendmsg(socket.socket_fd(), - self.vecs.as_ptr(), - self.vecs.len(), - self.cmsg_buffer.as_mut_ptr(), - fds.as_ptr(), - fds.len()) - }; - - if write_count < 0 { - Err(Error::new(-write_count as i32)) - } else { - Ok(write_count as usize) - } - } - - /// Receives data and file descriptors from the given `socket` into the list of buffers. - /// - /// On success, returns the number of bytes received. - /// - /// # Arguments - /// - /// * `socket` - A socket that supports socket control messages. - /// * `bufs` - A list of buffers to receive data from the `socket`. The `recvmsg` call fills - /// these directly. - /// * `files` - A vector of `File`s to put the received file descriptors into. This vector is - /// not cleared and will have at most `fd_count` (specified in `Scm::new`) `File`s - /// added to it. - pub fn recv(&mut self, - socket: &T, - bufs: &mut [&mut [u8]], - files: &mut Vec) - -> Result { - let cmsg_buf_len = cmsg_buffer_len(files.len()); - self.cmsg_buffer.reserve(cmsg_buf_len); - self.vecs.clear(); - for buf in bufs { - self.vecs - .push(iovec { - iov_base: buf.as_mut_ptr() as *mut c_void, - iov_len: buf.len(), - }); - } - let mut fd_count = self.fds.len(); - let read_count = unsafe { - // Safe because we are giving scm_recvmsg only valid pointers and lengths and we check - // the return value. - self.cmsg_buffer.set_len(cmsg_buf_len); - scm_recvmsg(socket.socket_fd(), - self.vecs.as_mut_ptr(), - self.vecs.len(), - self.cmsg_buffer.as_mut_ptr(), - self.fds.as_mut_ptr(), - &mut fd_count as *mut usize) - }; - - if read_count < 0 { - Err(Error::new(-read_count as i32)) - } else { - // Safe because we have unqiue ownership of each fd we wrap with File. - for &fd in &self.fds[0..fd_count] { - files.push(unsafe { File::from_raw_fd(fd) }); - } - Ok(read_count as usize) - } - } -} - #[cfg(test)] mod tests { use super::*; @@ -228,22 +333,22 @@ mod tests { #[test] fn buffer_len() { - assert_eq!(cmsg_buffer_len(0), size_of::()); - assert_eq!(cmsg_buffer_len(1), + assert_eq!(CMSG_SPACE!(0 * size_of::()), size_of::()); + assert_eq!(CMSG_SPACE!(1 * size_of::()), size_of::() + size_of::()); if size_of::() == 4 { - assert_eq!(cmsg_buffer_len(2), + assert_eq!(CMSG_SPACE!(2 * size_of::()), size_of::() + size_of::()); - assert_eq!(cmsg_buffer_len(3), + assert_eq!(CMSG_SPACE!(3 * size_of::()), size_of::() + size_of::() * 2); - assert_eq!(cmsg_buffer_len(4), + assert_eq!(CMSG_SPACE!(4 * size_of::()), size_of::() + size_of::() * 2); } else if size_of::() == 8 { - assert_eq!(cmsg_buffer_len(2), + assert_eq!(CMSG_SPACE!(2 * size_of::()), size_of::() + size_of::() * 2); - assert_eq!(cmsg_buffer_len(3), + assert_eq!(CMSG_SPACE!(3 * size_of::()), size_of::() + size_of::() * 3); - assert_eq!(cmsg_buffer_len(4), + assert_eq!(CMSG_SPACE!(4 * size_of::()), size_of::() + size_of::() * 4); } } @@ -252,51 +357,42 @@ mod tests { fn send_recv_no_fd() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - let write_count = scm.send(&s1, - [[1u8, 1, 2].as_ref(), [21, 34, 55].as_ref()].as_ref(), - &[]) + let write_count = s1.send_with_fds([1u8, 1, 2, 21, 34, 55].as_ref(), &[]) .expect("failed to send data"); assert_eq!(write_count, 6); - let mut buf1 = [0; 3]; - let mut buf2 = [0; 3]; - let mut bufs = [buf1.as_mut(), buf2.as_mut()]; - let mut files = Vec::new(); - let read_count = scm.recv(&s2, &mut bufs[..], &mut files) + let mut buf = [0; 6]; + let mut files = [0; 1]; + let (read_count, file_count) = s2.recv_with_fds(&mut buf[..], &mut files) .expect("failed to recv data"); assert_eq!(read_count, 6); - assert!(files.is_empty()); - assert_eq!(bufs[0], [1, 1, 2]); - assert_eq!(bufs[1], [21, 34, 55]); + assert_eq!(file_count, 0); + assert_eq!(buf, [1, 1, 2, 21, 34, 55]); } #[test] fn send_recv_only_fd() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let evt = EventFd::new().expect("failed to create eventfd"); - let write_count = scm.send(&s1, &[[].as_ref()], &[evt.as_raw_fd()]) + let write_count = s1.send_with_fd([].as_ref(), evt.as_raw_fd()) .expect("failed to send fd"); assert_eq!(write_count, 0); - let mut files = Vec::new(); - let read_count = scm.recv(&s2, &mut [&mut []], &mut files) - .expect("failed to recv fd"); + let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd"); + + let mut file = file_opt.unwrap(); assert_eq!(read_count, 0); - assert_eq!(files.len(), 1); - assert!(files[0].as_raw_fd() >= 0); - assert_ne!(files[0].as_raw_fd(), s1.as_raw_fd()); - assert_ne!(files[0].as_raw_fd(), s2.as_raw_fd()); - assert_ne!(files[0].as_raw_fd(), evt.as_raw_fd()); + assert!(file.as_raw_fd() >= 0); + assert_ne!(file.as_raw_fd(), s1.as_raw_fd()); + assert_ne!(file.as_raw_fd(), s2.as_raw_fd()); + assert_ne!(file.as_raw_fd(), evt.as_raw_fd()); - files[0] - .write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) + file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) .expect("failed to write to sent fd"); assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); @@ -306,28 +402,28 @@ mod tests { fn send_recv_with_fd() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let evt = EventFd::new().expect("failed to create eventfd"); - let write_count = scm.send(&s1, &[[237].as_ref()], &[evt.as_raw_fd()]) + let write_count = s1.send_with_fds([237].as_ref(), &[evt.as_raw_fd()]) .expect("failed to send fd"); assert_eq!(write_count, 1); - let mut files = Vec::new(); + let mut files = [0; 2]; let mut buf = [0u8]; - let read_count = scm.recv(&s2, &mut [&mut buf], &mut files) + let (read_count, file_count) = s2.recv_with_fds(&mut buf, &mut files) .expect("failed to recv fd"); assert_eq!(read_count, 1); assert_eq!(buf[0], 237); - assert_eq!(files.len(), 1); - assert!(files[0].as_raw_fd() >= 0); - assert_ne!(files[0].as_raw_fd(), s1.as_raw_fd()); - assert_ne!(files[0].as_raw_fd(), s2.as_raw_fd()); - assert_ne!(files[0].as_raw_fd(), evt.as_raw_fd()); + assert_eq!(file_count, 1); + assert!(files[0] >= 0); + assert_ne!(files[0], s1.as_raw_fd()); + assert_ne!(files[0], s2.as_raw_fd()); + assert_ne!(files[0], evt.as_raw_fd()); - files[0] - .write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) + let mut file = unsafe { File::from_raw_fd(files[0]) }; + + file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) .expect("failed to write to sent fd"); assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); diff --git a/vm_control/src/lib.rs b/vm_control/src/lib.rs index ad123a1890..fae2e1d585 100644 --- a/vm_control/src/lib.rs +++ b/vm_control/src/lib.rs @@ -27,7 +27,8 @@ use libc::{ERANGE, EINVAL, ENODEV}; use byteorder::{LittleEndian, WriteBytesExt}; use data_model::{DataInit, Le32, Le64, VolatileMemory}; -use sys_util::{EventFd, Result, Error as SysError, MmapError, MemoryMapping, Scm, GuestAddress}; +use sys_util::{EventFd, Result, Error as SysError, MmapError, MemoryMapping, ScmSocket, + GuestAddress}; use resources::{GpuMemoryDesc, GpuMemoryPlaneDesc, SystemAllocator}; use kvm::{IoeventAddress, Vm}; @@ -134,11 +135,10 @@ impl VmRequest { /// Receive a `VmRequest` from the given socket. /// /// A `VmResponse` should be sent out over the given socket before another request is received. - pub fn recv(scm: &mut Scm, s: &UnixDatagram) -> VmControlResult { + pub fn recv(s: &UnixDatagram) -> VmControlResult { assert_eq!(VM_REQUEST_SIZE, std::mem::size_of::()); let mut buf = [0; VM_REQUEST_SIZE]; - let mut fds = Vec::new(); - let read = scm.recv(s, &mut [&mut buf], &mut fds) + let (read, file) = s.recv_with_fd(&mut buf) .map_err(|e| VmControlError::Recv(e))?; if read != VM_REQUEST_SIZE { return Err(VmControlError::BadSize(read)); @@ -150,7 +150,7 @@ impl VmRequest { match req.type_.into() { VM_REQUEST_TYPE_EXIT => Ok(VmRequest::Exit), VM_REQUEST_TYPE_REGISTER_MEMORY => { - let fd = fds.pop().ok_or(VmControlError::ExpectFd)?; + let fd = file.ok_or(VmControlError::ExpectFd)?; Ok(VmRequest::RegisterMemory(MaybeOwnedFd::Owned(fd), req.size.to_native() as usize)) } @@ -172,7 +172,7 @@ impl VmRequest { /// /// After this request is a sent, a `VmResponse` should be received before sending another /// request. - pub fn send(&self, scm: &mut Scm, s: &UnixDatagram) -> VmControlResult<()> { + pub fn send(&self, s: &UnixDatagram) -> VmControlResult<()> { assert_eq!(VM_REQUEST_SIZE, std::mem::size_of::()); let mut req = VmRequestStruct::default(); let mut fd_buf = [0; 1]; @@ -203,7 +203,7 @@ impl VmRequest { } let mut buf = [0; VM_REQUEST_SIZE]; buf.as_mut().get_ref(0).unwrap().store(req); - scm.send(s, &[buf.as_ref()], &fd_buf[..fd_len]) + s.send_with_fds(buf.as_ref(), &fd_buf[..fd_len]) .map_err(|e| VmControlError::Send(e))?; Ok(()) } @@ -332,10 +332,9 @@ impl VmResponse { /// Receive a `VmResponse` from the given socket. /// /// This should be called after the sending a `VmRequest` before sending another request. - pub fn recv(scm: &mut Scm, s: &UnixDatagram) -> VmControlResult { + pub fn recv(s: &UnixDatagram) -> VmControlResult { let mut buf = [0; VM_RESPONSE_SIZE]; - let mut fds = Vec::new(); - let read = scm.recv(s, &mut [&mut buf], &mut fds) + let (read, file) = s.recv_with_fd(&mut buf) .map_err(|e| VmControlError::Recv(e))?; if read != VM_RESPONSE_SIZE { return Err(VmControlError::BadSize(read)); @@ -354,7 +353,7 @@ impl VmResponse { }) } VM_RESPONSE_TYPE_ALLOCATE_AND_REGISTER_GPU_MEMORY => { - let fd = fds.pop().ok_or(VmControlError::ExpectFd)?; + let fd = file.ok_or(VmControlError::ExpectFd)?; Ok(VmResponse::AllocateAndRegisterGpuMemory { fd: MaybeOwnedFd::Owned(fd), pfn: resp.pfn.into(), @@ -377,7 +376,7 @@ impl VmResponse { /// /// This must be called after receiving a `VmRequest` to indicate the outcome of that request's /// execution. - pub fn send(&self, scm: &mut Scm, s: &UnixDatagram) -> VmControlResult<()> { + pub fn send(&self, s: &UnixDatagram) -> VmControlResult<()> { let mut resp = VmResponseStruct::default(); let mut fd_buf = [0; 1]; let mut fd_len = 0; @@ -408,7 +407,7 @@ impl VmResponse { } let mut buf = [0; VM_RESPONSE_SIZE]; buf.as_mut().get_ref(0).unwrap().store(resp); - scm.send(s, &[buf.as_ref()], &fd_buf[..fd_len]) + s.send_with_fds(buf.as_ref(), &fd_buf[..fd_len]) .map_err(|e| VmControlError::Send(e))?; Ok(()) } @@ -427,9 +426,8 @@ mod tests { #[test] fn request_exit() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - VmRequest::Exit.send(&mut scm, &s1).unwrap(); - match VmRequest::recv(&mut scm, &s2).unwrap() { + VmRequest::Exit.send(&s1).unwrap(); + match VmRequest::recv(&s2).unwrap() { VmRequest::Exit => {} _ => panic!("recv wrong request variant"), } @@ -439,14 +437,13 @@ mod tests { fn request_register_memory() { if !kernel_has_memfd() { return; } let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let shm_size: usize = 4096; let mut shm = SharedMemory::new(None).unwrap(); shm.set_size(shm_size as u64).unwrap(); VmRequest::RegisterMemory(MaybeOwnedFd::Borrowed(shm.as_raw_fd()), shm_size) - .send(&mut scm, &s1) + .send(&s1) .unwrap(); - match VmRequest::recv(&mut scm, &s2).unwrap() { + match VmRequest::recv(&s2).unwrap() { VmRequest::RegisterMemory(MaybeOwnedFd::Owned(fd), size) => { assert!(fd.as_raw_fd() >= 0); assert_eq!(size, shm_size); @@ -458,11 +455,8 @@ mod tests { #[test] fn request_unregister_memory() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - VmRequest::UnregisterMemory(77) - .send(&mut scm, &s1) - .unwrap(); - match VmRequest::recv(&mut scm, &s2).unwrap() { + VmRequest::UnregisterMemory(77).send(&s1).unwrap(); + match VmRequest::recv(&s2).unwrap() { VmRequest::UnregisterMemory(slot) => assert_eq!(slot, 77), _ => panic!("recv wrong request variant"), } @@ -471,11 +465,10 @@ mod tests { #[test] fn request_expect_fd() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let mut bad_request = [0; VM_REQUEST_SIZE]; bad_request[0] = VM_REQUEST_TYPE_REGISTER_MEMORY as u8; - scm.send(&s2, &[bad_request.as_ref()], &[]).unwrap(); - match VmRequest::recv(&mut scm, &s1) { + s2.send_with_fds(bad_request.as_ref(), &[]).unwrap(); + match VmRequest::recv(&s1) { Err(VmControlError::ExpectFd) => {} _ => panic!("recv wrong error variant"), } @@ -484,9 +477,8 @@ mod tests { #[test] fn request_no_data() { let (s1, _) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); s1.shutdown(Shutdown::Both).unwrap(); - match VmRequest::recv(&mut scm, &s1) { + match VmRequest::recv(&s1) { Err(VmControlError::BadSize(s)) => assert_eq!(s, 0), _ => panic!("recv wrong error variant"), } @@ -495,9 +487,8 @@ mod tests { #[test] fn request_bad_size() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - scm.send(&s2, &[[12; 7].as_ref()], &[]).unwrap(); - match VmRequest::recv(&mut scm, &s1) { + s2.send_with_fds([12; 7].as_ref(), &[]).unwrap(); + match VmRequest::recv(&s1) { Err(VmControlError::BadSize(_)) => {} _ => panic!("recv wrong error variant"), } @@ -506,10 +497,9 @@ mod tests { #[test] fn request_invalid_type() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - scm.send(&s2, &[[12; VM_REQUEST_SIZE].as_ref()], &[]) + s2.send_with_fds([12; VM_REQUEST_SIZE].as_ref(), &[]) .unwrap(); - match VmRequest::recv(&mut scm, &s1) { + match VmRequest::recv(&s1) { Err(VmControlError::InvalidType) => {} _ => panic!("recv wrong error variant"), } @@ -518,14 +508,21 @@ mod tests { #[test] fn request_allocate_and_register_gpu_memory() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let gpu_width: u32 = 32; let gpu_height: u32 = 32; let gpu_format: u32 = 0x34325258; - let r = VmRequest::AllocateAndRegisterGpuMemory { width: gpu_width, height: gpu_height, format: gpu_format }; - r.send(&mut scm, &s1).unwrap(); - match VmRequest::recv(&mut scm, &s2).unwrap() { - VmRequest::AllocateAndRegisterGpuMemory {width, height, format} => { + let r = VmRequest::AllocateAndRegisterGpuMemory { + width: gpu_width, + height: gpu_height, + format: gpu_format, + }; + r.send(&s1).unwrap(); + match VmRequest::recv(&s2).unwrap() { + VmRequest::AllocateAndRegisterGpuMemory { + width, + height, + format, + } => { assert_eq!(width, gpu_width); assert_eq!(height, gpu_width); assert_eq!(format, gpu_format); @@ -537,9 +534,8 @@ mod tests { #[test] fn resp_ok() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - VmResponse::Ok.send(&mut scm, &s1).unwrap(); - match VmResponse::recv(&mut scm, &s2).unwrap() { + VmResponse::Ok.send(&s1).unwrap(); + match VmResponse::recv(&s2).unwrap() { VmResponse::Ok => {} _ => panic!("recv wrong response variant"), } @@ -548,10 +544,9 @@ mod tests { #[test] fn resp_err() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let r1 = VmResponse::Err(SysError::new(libc::EDESTADDRREQ)); - r1.send(&mut scm, &s1).unwrap(); - match VmResponse::recv(&mut scm, &s2).unwrap() { + r1.send(&s1).unwrap(); + match VmResponse::recv(&s2).unwrap() { VmResponse::Err(e) => { assert_eq!(e, SysError::new(libc::EDESTADDRREQ)); } @@ -562,12 +557,14 @@ mod tests { #[test] fn resp_memory() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let memory_pfn = 55; let memory_slot = 66; - let r1 = VmResponse::RegisterMemory { pfn: memory_pfn, slot: memory_slot }; - r1.send(&mut scm, &s1).unwrap(); - match VmResponse::recv(&mut scm, &s2).unwrap() { + let r1 = VmResponse::RegisterMemory { + pfn: memory_pfn, + slot: memory_slot, + }; + r1.send(&s1).unwrap(); + match VmResponse::recv(&s2).unwrap() { VmResponse::RegisterMemory { pfn, slot } => { assert_eq!(pfn, memory_pfn); assert_eq!(slot, memory_slot); @@ -579,9 +576,8 @@ mod tests { #[test] fn resp_no_data() { let (s1, _) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); s1.shutdown(Shutdown::Both).unwrap(); - match VmResponse::recv(&mut scm, &s1) { + match VmResponse::recv(&s1) { Err(e) => { assert_eq!(e, VmControlError::BadSize(0)); } @@ -592,9 +588,8 @@ mod tests { #[test] fn resp_bad_size() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - scm.send(&s2, &[[12; 7].as_ref()], &[]).unwrap(); - match VmResponse::recv(&mut scm, &s1) { + s2.send_with_fds([12; 7].as_ref(), &[]).unwrap(); + match VmResponse::recv(&s1) { Err(e) => { assert_eq!(e, VmControlError::BadSize(7)); } @@ -605,10 +600,9 @@ mod tests { #[test] fn resp_invalid_type() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - scm.send(&s2, &[[12; VM_RESPONSE_SIZE].as_ref()], &[]) + s2.send_with_fds([12; VM_RESPONSE_SIZE].as_ref(), &[]) .unwrap(); - match VmResponse::recv(&mut scm, &s1) { + match VmResponse::recv(&s1) { Err(e) => { assert_eq!(e, VmControlError::InvalidType); } @@ -620,7 +614,6 @@ mod tests { fn resp_allocate_and_register_gpu_memory() { if !kernel_has_memfd() { return; } let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let shm_size: usize = 4096; let mut shm = SharedMemory::new(None).unwrap(); shm.set_size(shm_size as u64).unwrap(); @@ -637,9 +630,14 @@ mod tests { slot: memory_slot, desc: GpuMemoryDesc { planes: memory_planes }, }; - r1.send(&mut scm, &s1).unwrap(); - match VmResponse::recv(&mut scm, &s2).unwrap() { - VmResponse::AllocateAndRegisterGpuMemory { fd, pfn, slot, desc } => { + r1.send(&s1).unwrap(); + match VmResponse::recv(&s2).unwrap() { + VmResponse::AllocateAndRegisterGpuMemory { + fd, + pfn, + slot, + desc, + } => { assert!(fd.as_raw_fd() >= 0); assert_eq!(pfn, memory_pfn); assert_eq!(slot, memory_slot);