diff --git a/Cargo.lock b/Cargo.lock index 8ac7e4f067..086a9f3a70 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -367,7 +367,6 @@ dependencies = [ name = "sys_util" version = "0.1.0" dependencies = [ - "cc 1.0.15 (registry+https://github.com/rust-lang/crates.io-index)", "data_model 0.1.0", "libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)", "poll_token_derive 0.1.0", diff --git a/crosvm_plugin/src/lib.rs b/crosvm_plugin/src/lib.rs index 06cb367ed3..754f74c03d 100644 --- a/crosvm_plugin/src/lib.rs +++ b/crosvm_plugin/src/lib.rs @@ -40,7 +40,7 @@ use libc::{E2BIG, ENOTCONN, EINVAL, EPROTO, ENOENT}; use protobuf::{Message, ProtobufEnum, RepeatedField, parse_from_bytes}; -use sys_util::Scm; +use sys_util::ScmSocket; use kvm::dirty_log_bitmap_size; @@ -250,7 +250,6 @@ impl Drop for StatUpdater { pub struct crosvm { id_allocator: Arc, socket: UnixDatagram, - fd_messager: Scm, request_buffer: Vec, response_buffer: Vec, vcpus: Arc>, @@ -261,7 +260,6 @@ impl crosvm { let mut crosvm = crosvm { id_allocator: Default::default(), socket, - fd_messager: Scm::new(MAX_DATAGRAM_FD), request_buffer: Vec::new(), response_buffer: vec![0; MAX_DATAGRAM_SIZE], vcpus: Default::default(), @@ -277,7 +275,6 @@ impl crosvm { crosvm { id_allocator, socket, - fd_messager: Scm::new(MAX_DATAGRAM_FD), request_buffer: Vec::new(), response_buffer: vec![0; MAX_DATAGRAM_SIZE], vcpus, @@ -296,16 +293,19 @@ impl crosvm { request .write_to_vec(&mut self.request_buffer) .map_err(proto_error_to_int)?; - self.fd_messager - .send(&self.socket, &[self.request_buffer.as_slice()], fds) + self.socket + .send_with_fds(self.request_buffer.as_slice(), fds) .map_err(|e| -e.errno())?; - let mut datagram_files = Vec::new(); - let msg_size = self.fd_messager - .recv(&self.socket, - &mut [&mut self.response_buffer], - &mut datagram_files) + let mut datagram_fds = [0; MAX_DATAGRAM_FD]; + let (msg_size, fd_count) = self.socket + .recv_with_fds(&mut self.response_buffer, &mut datagram_fds) .map_err(|e| -e.errno())?; + // Safe because the first fd_count fds from recv_with_fds are owned by us and valid. + let datagram_files = datagram_fds[..fd_count] + .iter() + .map(|&fd| unsafe { File::from_raw_fd(fd) }) + .collect(); let response: MainResponse = parse_from_bytes(&self.response_buffer[..msg_size]) .map_err(proto_error_to_int)?; diff --git a/devices/src/virtio/wl.rs b/devices/src/virtio/wl.rs index b946c51cca..7c6ae58c66 100644 --- a/devices/src/virtio/wl.rs +++ b/devices/src/virtio/wl.rs @@ -40,9 +40,7 @@ use std::io::{self, Seek, SeekFrom, Read}; use std::mem::{size_of, size_of_val}; #[cfg(feature = "wl-dmabuf")] use std::os::raw::{c_uint, c_ulonglong}; -use std::os::unix::io::{AsRawFd, RawFd}; -#[cfg(feature = "wl-dmabuf")] -use std::os::unix::io::FromRawFd; +use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::os::unix::net::{UnixDatagram, UnixStream}; use std::path::{PathBuf, Path}; use std::rc::Rc; @@ -59,7 +57,7 @@ use data_model::*; use data_model::VolatileMemoryError; use resources::GpuMemoryDesc; -use sys_util::{Error, Result, EventFd, Scm, SharedMemory, GuestAddress, GuestMemory, +use sys_util::{Error, Result, EventFd, ScmSocket, SharedMemory, GuestAddress, GuestMemory, GuestMemoryError, PollContext, PollToken, FileFlags, pipe, round_up_to_page_size}; #[cfg(feature = "wl-dmabuf")] @@ -432,21 +430,19 @@ impl From for WlError { #[derive(Clone)] struct VmRequester { - inner: Rc>, + inner: Rc>, } impl VmRequester { fn new(vm_socket: UnixDatagram) -> VmRequester { - VmRequester { inner: Rc::new(RefCell::new((Scm::new(1), vm_socket))) } + VmRequester { inner: Rc::new(RefCell::new(vm_socket)) } } fn request(&self, request: VmRequest) -> WlResult { let mut inner = self.inner.borrow_mut(); - let (ref mut scm, ref mut vm_socket) = *inner; - request - .send(scm, vm_socket) - .map_err(WlError::VmControl)?; - VmResponse::recv(scm, vm_socket).map_err(WlError::VmControl) + let ref mut vm_socket = *inner; + request.send(vm_socket).map_err(WlError::VmControl)?; + VmResponse::recv(vm_socket).map_err(WlError::VmControl) } } @@ -803,9 +799,10 @@ impl WlVfd { } // Sends data/files from the guest to the host over this VFD. - fn send(&mut self, scm: &mut Scm, fds: &[RawFd], data: VolatileSlice) -> WlResult { + fn send(&mut self, fds: &[RawFd], data: VolatileSlice) -> WlResult { if let Some(ref socket) = self.socket { - scm.send(socket, &[data], fds) + socket + .send_with_fds(data, fds) .map_err(WlError::SendVfd)?; Ok(WlResp::Ok) } else if let Some((_, ref mut local_pipe)) = self.local_pipe { @@ -822,19 +819,24 @@ impl WlVfd { } // Receives data/files from the host for this VFD and queues it for the guest. - fn recv(&mut self, scm: &mut Scm, in_file_queue: &mut Vec) -> WlResult> { + fn recv(&mut self, in_file_queue: &mut Vec) -> WlResult> { if let Some(socket) = self.socket.take() { - let mut buf = Vec::new(); - buf.resize(IN_BUFFER_LEN, 0); - let old_len = in_file_queue.len(); + let mut buf = vec![0; IN_BUFFER_LEN]; + let mut fd_buf = [0; VIRTWL_SEND_MAX_ALLOCS]; // If any errors happen, the socket will get dropped, preventing more reading. - let len = scm.recv(&socket, &mut [&mut buf[..]], in_file_queue) + let (len, file_count) = socket + .recv_with_fds(&mut buf[..], &mut fd_buf) .map_err(WlError::RecvVfd)?; // If any data gets read, the put the socket back for future recv operations. - if len != 0 || in_file_queue.len() != old_len { + if len != 0 || file_count != 0 { buf.truncate(len); buf.shrink_to_fit(); self.socket = Some(socket); + // Safe because the first file_counts fds from recv_with_fds are owned by us and + // valid. + in_file_queue.extend(fd_buf[..file_count] + .iter() + .map(|&fd| unsafe { File::from_raw_fd(fd) })); return Ok(buf); } Ok(Vec::new()) @@ -893,7 +895,6 @@ struct WlState { poll_ctx: PollContext, vfds: Map, next_vfd_id: u32, - scm: Scm, in_file_queue: Vec, in_queue: VecDeque<(u32 /* vfd_id */, WlRecv)>, current_recv_vfd: Option, @@ -907,7 +908,6 @@ impl WlState { vm: VmRequester::new(vm_socket), poll_ctx: PollContext::new().expect("failed to create PollContext"), use_transition_flags, - scm: Scm::new(VIRTWL_SEND_MAX_ALLOCS), vfds: Map::new(), next_vfd_id: NEXT_VFD_ID_BASE, in_file_queue: Vec::new(), @@ -1126,7 +1126,7 @@ impl WlState { match self.vfds.get_mut(&vfd_id) { Some(vfd) => { - match vfd.send(&mut self.scm, &fds[..vfd_count], data)? { + match vfd.send(&fds[..vfd_count], data)? { WlResp::Ok => {} _ => return Ok(WlResp::InvalidType), } @@ -1145,7 +1145,7 @@ impl WlState { fn recv(&mut self, vfd_id: u32) -> WlResult<()> { let buf = match self.vfds.get_mut(&vfd_id) { - Some(vfd) => vfd.recv(&mut self.scm, &mut self.in_file_queue)?, + Some(vfd) => vfd.recv(&mut self.in_file_queue)?, None => return Ok(()), }; if self.in_file_queue.is_empty() && buf.is_empty() { diff --git a/src/linux.rs b/src/linux.rs index a142d1dda1..a5c87de643 100644 --- a/src/linux.rs +++ b/src/linux.rs @@ -729,8 +729,6 @@ fn run_control(mut linux: RunnableLinuxVm, control_sockets: Vec, balloon_host_socket: UnixDatagram, sigchld_fd: SignalFd) -> Result<()> { - const MAX_VM_FD_RECV: usize = 1; - // Paths to get the currently available memory and the low memory threshold. const LOWMEM_MARGIN: &'static str = "/sys/kernel/mm/chromeos-low_mem/margin"; const LOWMEM_AVAILABLE: &'static str = "/sys/kernel/mm/chromeos-low_mem/available"; @@ -817,8 +815,6 @@ fn run_control(mut linux: RunnableLinuxVm, } vcpu_thread_barrier.wait(); - let mut scm = Scm::new(MAX_VM_FD_RECV); - 'poll: loop { let events = { match poll_ctx.wait() { @@ -945,7 +941,7 @@ fn run_control(mut linux: RunnableLinuxVm, } Token::VmControl { index } => { if let Some(socket) = control_sockets.get(index as usize) { - match VmRequest::recv(&mut scm, socket.as_ref()) { + match VmRequest::recv(socket.as_ref()) { Ok(request) => { let mut running = true; let response = @@ -953,7 +949,7 @@ fn run_control(mut linux: RunnableLinuxVm, &mut linux.resources, &mut running, &balloon_host_socket); - if let Err(e) = response.send(&mut scm, socket.as_ref()) { + if let Err(e) = response.send(socket.as_ref()) { error!("failed to send VmResponse: {:?}", e); } if !running { diff --git a/src/main.rs b/src/main.rs index bd2b632726..de93f1bc9c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,7 +47,7 @@ use std::string::String; use std::thread::sleep; use std::time::Duration; -use sys_util::{Scm, getpid, kill_process_group, reap_child, syslog}; +use sys_util::{getpid, kill_process_group, reap_child, syslog}; use qcow::QcowFile; use argument::{Argument, set_arguments, print_help}; @@ -519,7 +519,6 @@ fn run_vm(args: std::env::Args) -> std::result::Result<(), ()> { } fn stop_vms(args: std::env::Args) -> std::result::Result<(), ()> { - let mut scm = Scm::new(1); if args.len() == 0 { print_help("crosvm stop", "VM_SOCKET...", &[]); println!("Stops the crosvm instance listening on each `VM_SOCKET` given."); @@ -532,7 +531,7 @@ fn stop_vms(args: std::env::Args) -> std::result::Result<(), ()> { Ok(s) }) { Ok(s) => { - if let Err(e) = VmRequest::Exit.send(&mut scm, &s) { + if let Err(e) = VmRequest::Exit.send(&s) { error!("failed to send stop request to socket at '{}': {:?}", socket_path, e); @@ -549,7 +548,6 @@ fn stop_vms(args: std::env::Args) -> std::result::Result<(), ()> { } fn balloon_vms(mut args: std::env::Args) -> std::result::Result<(), ()> { - let mut scm = Scm::new(1); if args.len() < 2 { print_help("crosvm balloon", "PAGE_ADJUST VM_SOCKET...", &[]); println!("Adjust the ballon size of the crosvm instance by `PAGE_ADJUST` pages, `PAGE_ADJUST` can be negative to shrink the balloon."); @@ -569,7 +567,7 @@ fn balloon_vms(mut args: std::env::Args) -> std::result::Result<(), ()> { Ok(s) }) { Ok(s) => { - if let Err(e) = VmRequest::BalloonAdjust(num_pages).send(&mut scm, &s) { + if let Err(e) = VmRequest::BalloonAdjust(num_pages).send(&s) { error!("failed to send balloon request to socket at '{}': {:?}", socket_path, e); diff --git a/src/plugin/process.rs b/src/plugin/process.rs index 0e8a74f6eb..71c16c0b2e 100644 --- a/src/plugin/process.rs +++ b/src/plugin/process.rs @@ -25,7 +25,7 @@ use protobuf::Message; use io_jail::Minijail; use kvm::{Vm, IoeventAddress, NoDatamatch, IrqSource, IrqRoute, PicId, dirty_log_bitmap_size}; use kvm_sys::{kvm_pic_state, kvm_ioapic_state, kvm_pit_state2}; -use sys_util::{EventFd, MemoryMapping, Killable, Scm, SharedMemory, GuestAddress, +use sys_util::{EventFd, MemoryMapping, Killable, ScmSocket, SharedMemory, GuestAddress, Result as SysResult, Error as SysError, SIGRTMIN}; use plugin_proto::*; @@ -119,9 +119,7 @@ pub struct Process { vcpu_sockets: Vec<(UnixDatagram, UnixDatagram)>, // Socket Transmission - scm: Scm, request_buffer: Vec, - datagram_files: Vec, response_buffer: Vec, } @@ -181,9 +179,7 @@ impl Process { per_vcpu_states, kill_evt: EventFd::new().map_err(Error::CreateEventFd)?, vcpu_sockets, - scm: Scm::new(1), request_buffer: vec![0; MAX_DATAGRAM_SIZE], - datagram_files: Vec::new(), response_buffer: Vec::new(), }) } @@ -444,10 +440,8 @@ impl Process { vcpu_handles: &[JoinHandle<()>], tap: Option<&Tap>) -> Result<()> { - let msg_size = self.scm - .recv(&self.request_sockets[index], - &mut [&mut self.request_buffer], - &mut self.datagram_files) + let (msg_size, request_file) = self.request_sockets[index] + .recv_with_fd(&mut self.request_buffer) .map_err(Error::PluginSocketRecv)?; if msg_size == 0 { @@ -475,7 +469,7 @@ impl Process { } } else if create.has_memory() { let memory = create.get_memory(); - match self.datagram_files.pop() { + match request_file { Some(memfd) => { Self::handle_memory(entry, vm, @@ -653,16 +647,13 @@ impl Process { response.errno = e.errno(); } - self.datagram_files.clear(); self.response_buffer.clear(); response .write_to_vec(&mut self.response_buffer) .map_err(Error::EncodeResponse)?; assert_ne!(self.response_buffer.len(), 0); - self.scm - .send(&self.request_sockets[index], - &[&self.response_buffer[..]], - &response_fds) + self.request_sockets[index] + .send_with_fds(&self.response_buffer[..], &response_fds) .map_err(Error::PluginSocketSend)?; Ok(()) diff --git a/sys_util/Cargo.toml b/sys_util/Cargo.toml index 21884f8c16..b3ba0a2bd5 100644 --- a/sys_util/Cargo.toml +++ b/sys_util/Cargo.toml @@ -2,13 +2,9 @@ name = "sys_util" version = "0.1.0" authors = ["The Chromium OS Authors"] -build = "build.rs" [dependencies] data_model = { path = "../data_model" } libc = "*" syscall_defines = { path = "../syscall_defines" } poll_token_derive = { path = "poll_token_derive" } - -[build-dependencies] -cc = "=1.0.15" diff --git a/sys_util/build.rs b/sys_util/build.rs deleted file mode 100644 index f1fc1df98e..0000000000 --- a/sys_util/build.rs +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright 2017 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -extern crate cc; - -fn main() { - cc::Build::new() - .file("sock_ctrl_msg.c") - .compile("sock_ctrl_msg"); -} diff --git a/sys_util/sock_ctrl_msg.c b/sys_util/sock_ctrl_msg.c deleted file mode 100644 index d569a47a0a..0000000000 --- a/sys_util/sock_ctrl_msg.c +++ /dev/null @@ -1,126 +0,0 @@ -// Copyright 2017 The Chromium OS Authors. All rights reserved. -// Use of this source code is governed by a BSD-style license that can be -// found in the LICENSE file. - -#include -#include // memcpy -#include // close -#include -#include // CMSG_* - -/* - * Returns the number of bytes the `cmsg_buffer` must be for the functions that take a cmsg_buffer - * in this module. - * Arguments: - * * `fd_count` - Maximum number of file descriptors to be sent or received via the cmsg. - */ -size_t scm_cmsg_buffer_len(size_t fd_count) -{ - return CMSG_SPACE(sizeof(int) * fd_count); -} - -/* - * Convenience wrapper around `sendmsg` that builds up the `msghdr` structure for you given the - * array of fds. - * Arguments: - * * `fd` - Unix domain socket to `sendmsg` on. - * * `outv` - Array of `outv_count` length `iovec`s that contain the data to send. - * * `outv_count` - Number of elements in `outv` array. - * * `cmsg_buffer` - A buffer that must be at least `scm_cmsg_buffer_len(fd_count)` bytes long. - * * `fds` - Array of `fd_count` file descriptors to send along with data. - * * `fd_count` - Number of elements in `fds` array. - * Returns: - * A non-negative number indicating how many bytes were sent on success or a negative errno on - * failure. - */ -ssize_t scm_sendmsg(int fd, const struct iovec *outv, size_t outv_count, uint8_t *cmsg_buffer, - const int *fds, size_t fd_count) -{ - if (fd < 0 || ((!cmsg_buffer || !fds) && fd_count > 0)) - return -EINVAL; - - struct msghdr msg; - memset(&msg, 0, sizeof(msg)); - msg.msg_iov = (struct iovec *)outv; // discard const, sendmsg won't mutate it - msg.msg_iovlen = outv_count; - - if (fd_count) { - msg.msg_control = cmsg_buffer; - msg.msg_controllen = scm_cmsg_buffer_len(fd_count); - - struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg); - cmsg->cmsg_level = SOL_SOCKET; - cmsg->cmsg_type = SCM_RIGHTS; - cmsg->cmsg_len = CMSG_LEN(fd_count * sizeof(int)); - memcpy(CMSG_DATA(cmsg), fds, fd_count * sizeof(int)); - - msg.msg_controllen = cmsg->cmsg_len; - } - - ssize_t bytes_sent = sendmsg(fd, &msg, MSG_NOSIGNAL); - if (bytes_sent == -1) - return -errno; - - return bytes_sent; -} - -/* - * Convenience wrapper around `recvmsg` that builds up the `msghdr` structure and returns up to - * `*fd_count` file descriptors in the given `fds` array. - * Arguments: - * * `fd` - Unix domain socket to `recvmsg` on. - * * `outv` - Array of `outv_count` length `iovec`s that will contain the received data. - * * `outv_count` - Number of elements in `outv` array. - * * `cmsg_buffer` - A buffer that must be at least `scm_cmsg_buffer_len(*fd_count)` bytes long. - * * `fds` - Array of `fd_count` file descriptors to receive along with data. - * * `fd_count` - Number of elements in `fds` array. - * Returns: - * A non-negative number indicating how many bytes were received on success or a negative errno on - * failure. - */ -ssize_t scm_recvmsg(int fd, struct iovec *outv, size_t outv_count, uint8_t *cmsg_buffer, int *fds, - size_t *fd_count) -{ - if (fd < 0 || !cmsg_buffer || !fds || !fd_count) - return -EINVAL; - - struct msghdr msg; - memset(&msg, 0, sizeof(msg)); - msg.msg_iov = outv; - msg.msg_iovlen = outv_count; - msg.msg_control = cmsg_buffer; - msg.msg_controllen = scm_cmsg_buffer_len(*fd_count); - - ssize_t total_read = recvmsg(fd, &msg, 0); - if (total_read == -1) - return -errno; - - if (total_read == 0 && CMSG_FIRSTHDR(&msg) == NULL) { - *fd_count = 0; - return 0; - } - - size_t fd_idx = 0; - struct cmsghdr *cmsg; - for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg)) { - if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS) - continue; - - size_t cmsg_fd_count = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int); - - int *cmsg_fds = (int *)CMSG_DATA(cmsg); - size_t cmsg_fd_idx; - for (cmsg_fd_idx = 0; cmsg_fd_idx < cmsg_fd_count; cmsg_fd_idx++) { - if (fd_idx < *fd_count) { - fds[fd_idx] = cmsg_fds[cmsg_fd_idx]; - fd_idx++; - } else { - close(cmsg_fds[cmsg_fd_idx]); - } - } - } - - *fd_count = fd_idx; - - return total_read; -} diff --git a/sys_util/src/sock_ctrl_msg.rs b/sys_util/src/sock_ctrl_msg.rs index 8a3ce39e7c..ebbf9db211 100644 --- a/sys_util/src/sock_ctrl_msg.rs +++ b/sys_util/src/sock_ctrl_msg.rs @@ -2,41 +2,200 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +//! Used to send and receive messages with file descriptors on sockets that accept control messages +//! (e.g. Unix domain sockets). + use std::fs::File; +use std::mem::size_of; use std::os::unix::io::{AsRawFd, RawFd, FromRawFd}; use std::os::unix::net::{UnixDatagram, UnixStream}; +use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned}; -use libc::{c_void, iovec}; +use libc::{c_void, c_long, iovec, msghdr, cmsghdr, sendmsg, recvmsg, SOL_SOCKET, SCM_RIGHTS, + MSG_NOSIGNAL}; use data_model::VolatileSlice; use {Result, Error}; -// These functions are implemented in C because each of them requires complicated setup with CMSG -// macros. These macros are part of the system headers and are required to be used for portability -// reasons. In practice, the control message ABI can't change but using them is much easier and less -// error prone than trying to port these macros to rust. -extern "C" { - fn scm_cmsg_buffer_len(fd_count: usize) -> usize; - fn scm_sendmsg(fd: RawFd, - outv: *const iovec, - outv_count: usize, - cmsg_buffer: *mut u8, - fds: *const RawFd, - fd_count: usize) - -> isize; - fn scm_recvmsg(fd: RawFd, - outv: *mut iovec, - outv_count: usize, - cmsg_buffer: *mut u8, - fds: *mut RawFd, - fd_count: *mut usize) - -> isize; +// Each of the following macros performs the same function as their C counterparts. They are each +// macros because they are used to size statically allocated arrays. + +macro_rules! CMSG_ALIGN { + ($len:expr) => ( + (($len) + size_of::() - 1) & !(size_of::() - 1) + ) } -fn cmsg_buffer_len(fd_count: usize) -> usize { - // Safe because this function has no side effects, touches no pointers, and never fails. - unsafe { scm_cmsg_buffer_len(fd_count) } +macro_rules! CMSG_SPACE { + ($len:expr) => ( + size_of::() + CMSG_ALIGN!($len) + ) +} + +macro_rules! CMSG_LEN { + ($len:expr) => ( + size_of::() + ($len) + ) +} + + +// This function (macro in the C version) is not used in any compile time constant slots, so is just +// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this +// module supports. +#[allow(non_snake_case)] +#[inline(always)] +fn CMSG_DATA(cmsg_buffer: *mut u8) -> *mut RawFd { + // Essentially returns a pointer to just past the header. + (cmsg_buffer as *mut cmsghdr).wrapping_offset(1) as *mut RawFd +} + +// This function is like CMSG_NEXT, but safer because it reads only from references, although it +// does some pointer arithmetic on cmsg_ptr. +fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut u8) -> *mut u8 { + let next_cmsg = cmsg_ptr.wrapping_offset(CMSG_ALIGN!(cmsg.cmsg_len) as isize); + if next_cmsg + .wrapping_offset(1) + .wrapping_sub(msghdr.msg_control as usize) as usize > msghdr.msg_controllen { + null_mut() + } else { + next_cmsg + } +} + + +const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::() * 32); + +enum CmsgBuffer { + Inline([u8; CMSG_BUFFER_INLINE_CAPACITY]), + Heap(Box<[u8]>), +} + +impl CmsgBuffer { + fn with_capacity(capacity: usize) -> CmsgBuffer { + if capacity <= CMSG_BUFFER_INLINE_CAPACITY { + CmsgBuffer::Inline([0u8; CMSG_BUFFER_INLINE_CAPACITY]) + } else { + CmsgBuffer::Heap(vec![0; capacity].into_boxed_slice()) + } + } + + fn as_mut_ptr(&mut self) -> *mut u8 { + match self { + CmsgBuffer::Inline(a) => a.as_mut_ptr(), + CmsgBuffer::Heap(a) => a.as_mut_ptr(), + } + } +} + +fn raw_sendmsg(fd: RawFd, out_data: D, out_fds: &[RawFd]) -> Result { + let cmsg_capacity = CMSG_SPACE!(size_of::() * out_fds.len()); + let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); + + let mut iovec = iovec { + iov_base: out_data.as_ptr() as *mut c_void, + iov_len: out_data.size(), + }; + + let mut msg = msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: &mut iovec as *mut iovec, + msg_iovlen: 1, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }; + + if !out_fds.is_empty() { + let cmsg = cmsghdr { + cmsg_len: CMSG_LEN!(size_of::() * out_fds.len()), + cmsg_level: SOL_SOCKET, + cmsg_type: SCM_RIGHTS, + }; + unsafe { + // Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr. + write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg); + // Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len() + // file descriptors. + copy_nonoverlapping(out_fds.as_ptr(), + CMSG_DATA(cmsg_buffer.as_mut_ptr()), + out_fds.len()); + } + + msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; + msg.msg_controllen = cmsg_capacity; + } + + // Safe because the msghdr was properly constructed from valid (or null) pointers of the + // indicated length and we check the return value. + let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) }; + + if write_count == -1 { + Err(Error::last()) + } else { + Ok(write_count as usize) + } +} + +fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(usize, usize)> { + let cmsg_capacity = CMSG_SPACE!(size_of::() * in_fds.len()); + let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity); + + let mut iovec = iovec { + iov_base: in_data.as_mut_ptr() as *mut c_void, + iov_len: in_data.len(), + }; + + let mut msg = msghdr { + msg_name: null_mut(), + msg_namelen: 0, + msg_iov: &mut iovec as *mut iovec, + msg_iovlen: 1, + msg_control: null_mut(), + msg_controllen: 0, + msg_flags: 0, + }; + + if !in_fds.is_empty() { + msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void; + msg.msg_controllen = cmsg_capacity; + } + + // Safe because the msghdr was properly constructed from valid (or null) pointers of the + // indicated length and we check the return value. + let total_read = unsafe { recvmsg(fd, &mut msg, 0) }; + + if total_read == -1 { + return Err(Error::last()); + } + + if total_read == 0 && msg.msg_controllen < size_of::() { + return Ok((0, 0)); + } + + let mut cmsg_ptr = msg.msg_control as *mut u8; + let mut in_fds_count = 0; + while !cmsg_ptr.is_null() { + // Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that + // that only happens when there is at least sizeof(cmsghdr) space after the pointer to read. + let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() }; + + if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS { + let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) / size_of::(); + unsafe { + copy_nonoverlapping(CMSG_DATA(cmsg_ptr), + in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(), + fd_count); + } + in_fds_count += fd_count; + } + + cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr); + } + + + Ok((total_read as usize, in_fds_count)) } /// Trait for file descriptors can send and receive socket control messages via `sendmsg` and @@ -44,6 +203,69 @@ fn cmsg_buffer_len(fd_count: usize) -> usize { pub trait ScmSocket { /// Gets the file descriptor of this socket. fn socket_fd(&self) -> RawFd; + + /// Sends the given data and file descriptor over the socket. + /// + /// On success, returns the number of bytes sent. + /// + /// # Arguments + /// + /// * `buf` - A buffer of data to send on the `socket`. + /// * `fd` - A file descriptors to be sent. + fn send_with_fd(&self, buf: D, fd: RawFd) -> Result { + self.send_with_fds(buf, &[fd]) + } + + /// Sends the given data and file descriptors over the socket. + /// + /// On success, returns the number of bytes sent. + /// + /// # Arguments + /// + /// * `buf` - A buffer of data to send on the `socket`. + /// * `fds` - A list of file descriptors to be sent. + fn send_with_fds(&self, buf: D, fd: &[RawFd]) -> Result { + raw_sendmsg(self.socket_fd(), buf, fd) + } + + /// Receives data and potentially a file descriptor from the socket. + /// + /// On success, returns the number of bytes and an optional file descriptor. + /// + /// # Arguments + /// + /// * `buf` - A buffer to receive data from the socket.vm + fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option)> { + let mut fd = [0]; + let (read_count, fd_count) = self.recv_with_fds(buf, &mut fd)?; + let file = if fd_count == 0 { + None + } else { + // Safe because the first fd from recv_with_fds is owned by us and valid because this + // branch was taken. + Some(unsafe { File::from_raw_fd(fd[0]) }) + }; + Ok((read_count, file)) + + } + + + /// Receives data and file descriptors from the socket. + /// + /// On success, returns the number of bytes and file descriptors received as a tuple + /// `(bytes count, files count)`. + /// + /// # Arguments + /// + /// * `buf` - A buffer to receive data from the socket. + /// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the + /// number of valid file descriptors is indicated by the second element of the + /// returned tuple. The caller owns these file descriptors, but they will not be + /// closed on drop like a `File`-like type would be. It is recommended that each valid + /// file descriptor gets wrapped in a drop type that closes it after this returns. + fn recv_with_fds(&self, buf: &mut [u8], fds: &mut [RawFd]) -> Result<(usize, usize)> { + raw_recvmsg(self.socket_fd(), buf, fds) + } } impl ScmSocket for UnixDatagram { @@ -95,123 +317,6 @@ unsafe impl<'a> IntoIovec for VolatileSlice<'a> { } } -/// Used to send and receive messages with file descriptors on sockets that accept control messages -/// (e.g. Unix domain sockets). -pub struct Scm { - cmsg_buffer: Vec, - vecs: Vec, - fds: Vec, -} - -impl Scm { - /// Constructs a new Scm object with pre-allocated structures. - /// - /// # Arguments - /// - /// * `fd_count` - The maximum number of files that can be received per `recv` call. - pub fn new(fd_count: usize) -> Scm { - Scm { - cmsg_buffer: Vec::with_capacity(cmsg_buffer_len(fd_count)), - vecs: Vec::new(), - fds: vec![-1; fd_count], - } - } - - /// Sends the given data and file descriptors over the given `socket`. - /// - /// On success, returns the number of bytes sent. - /// - /// # Arguments - /// - /// * `socket` - A socket that supports socket control messages. - /// * `bufs` - A list of buffers to send on the `socket`. - /// * `fds` - A list of file descriptors to be sent. - pub fn send(&mut self, - socket: &T, - bufs: &[D], - fds: &[RawFd]) - -> Result { - let cmsg_buf_len = cmsg_buffer_len(fds.len()); - self.cmsg_buffer.reserve(cmsg_buf_len); - self.vecs.clear(); - for ref buf in bufs { - self.vecs - .push(iovec { - iov_base: buf.as_ptr() as *mut c_void, - iov_len: buf.size(), - }); - } - let write_count = unsafe { - // Safe because we are giving scm_sendmsg only valid pointers and lengths and we check - // the return value. - self.cmsg_buffer.set_len(cmsg_buf_len); - scm_sendmsg(socket.socket_fd(), - self.vecs.as_ptr(), - self.vecs.len(), - self.cmsg_buffer.as_mut_ptr(), - fds.as_ptr(), - fds.len()) - }; - - if write_count < 0 { - Err(Error::new(-write_count as i32)) - } else { - Ok(write_count as usize) - } - } - - /// Receives data and file descriptors from the given `socket` into the list of buffers. - /// - /// On success, returns the number of bytes received. - /// - /// # Arguments - /// - /// * `socket` - A socket that supports socket control messages. - /// * `bufs` - A list of buffers to receive data from the `socket`. The `recvmsg` call fills - /// these directly. - /// * `files` - A vector of `File`s to put the received file descriptors into. This vector is - /// not cleared and will have at most `fd_count` (specified in `Scm::new`) `File`s - /// added to it. - pub fn recv(&mut self, - socket: &T, - bufs: &mut [&mut [u8]], - files: &mut Vec) - -> Result { - let cmsg_buf_len = cmsg_buffer_len(files.len()); - self.cmsg_buffer.reserve(cmsg_buf_len); - self.vecs.clear(); - for buf in bufs { - self.vecs - .push(iovec { - iov_base: buf.as_mut_ptr() as *mut c_void, - iov_len: buf.len(), - }); - } - let mut fd_count = self.fds.len(); - let read_count = unsafe { - // Safe because we are giving scm_recvmsg only valid pointers and lengths and we check - // the return value. - self.cmsg_buffer.set_len(cmsg_buf_len); - scm_recvmsg(socket.socket_fd(), - self.vecs.as_mut_ptr(), - self.vecs.len(), - self.cmsg_buffer.as_mut_ptr(), - self.fds.as_mut_ptr(), - &mut fd_count as *mut usize) - }; - - if read_count < 0 { - Err(Error::new(-read_count as i32)) - } else { - // Safe because we have unqiue ownership of each fd we wrap with File. - for &fd in &self.fds[0..fd_count] { - files.push(unsafe { File::from_raw_fd(fd) }); - } - Ok(read_count as usize) - } - } -} - #[cfg(test)] mod tests { use super::*; @@ -228,22 +333,22 @@ mod tests { #[test] fn buffer_len() { - assert_eq!(cmsg_buffer_len(0), size_of::()); - assert_eq!(cmsg_buffer_len(1), + assert_eq!(CMSG_SPACE!(0 * size_of::()), size_of::()); + assert_eq!(CMSG_SPACE!(1 * size_of::()), size_of::() + size_of::()); if size_of::() == 4 { - assert_eq!(cmsg_buffer_len(2), + assert_eq!(CMSG_SPACE!(2 * size_of::()), size_of::() + size_of::()); - assert_eq!(cmsg_buffer_len(3), + assert_eq!(CMSG_SPACE!(3 * size_of::()), size_of::() + size_of::() * 2); - assert_eq!(cmsg_buffer_len(4), + assert_eq!(CMSG_SPACE!(4 * size_of::()), size_of::() + size_of::() * 2); } else if size_of::() == 8 { - assert_eq!(cmsg_buffer_len(2), + assert_eq!(CMSG_SPACE!(2 * size_of::()), size_of::() + size_of::() * 2); - assert_eq!(cmsg_buffer_len(3), + assert_eq!(CMSG_SPACE!(3 * size_of::()), size_of::() + size_of::() * 3); - assert_eq!(cmsg_buffer_len(4), + assert_eq!(CMSG_SPACE!(4 * size_of::()), size_of::() + size_of::() * 4); } } @@ -252,51 +357,42 @@ mod tests { fn send_recv_no_fd() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - let write_count = scm.send(&s1, - [[1u8, 1, 2].as_ref(), [21, 34, 55].as_ref()].as_ref(), - &[]) + let write_count = s1.send_with_fds([1u8, 1, 2, 21, 34, 55].as_ref(), &[]) .expect("failed to send data"); assert_eq!(write_count, 6); - let mut buf1 = [0; 3]; - let mut buf2 = [0; 3]; - let mut bufs = [buf1.as_mut(), buf2.as_mut()]; - let mut files = Vec::new(); - let read_count = scm.recv(&s2, &mut bufs[..], &mut files) + let mut buf = [0; 6]; + let mut files = [0; 1]; + let (read_count, file_count) = s2.recv_with_fds(&mut buf[..], &mut files) .expect("failed to recv data"); assert_eq!(read_count, 6); - assert!(files.is_empty()); - assert_eq!(bufs[0], [1, 1, 2]); - assert_eq!(bufs[1], [21, 34, 55]); + assert_eq!(file_count, 0); + assert_eq!(buf, [1, 1, 2, 21, 34, 55]); } #[test] fn send_recv_only_fd() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let evt = EventFd::new().expect("failed to create eventfd"); - let write_count = scm.send(&s1, &[[].as_ref()], &[evt.as_raw_fd()]) + let write_count = s1.send_with_fd([].as_ref(), evt.as_raw_fd()) .expect("failed to send fd"); assert_eq!(write_count, 0); - let mut files = Vec::new(); - let read_count = scm.recv(&s2, &mut [&mut []], &mut files) - .expect("failed to recv fd"); + let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd"); + + let mut file = file_opt.unwrap(); assert_eq!(read_count, 0); - assert_eq!(files.len(), 1); - assert!(files[0].as_raw_fd() >= 0); - assert_ne!(files[0].as_raw_fd(), s1.as_raw_fd()); - assert_ne!(files[0].as_raw_fd(), s2.as_raw_fd()); - assert_ne!(files[0].as_raw_fd(), evt.as_raw_fd()); + assert!(file.as_raw_fd() >= 0); + assert_ne!(file.as_raw_fd(), s1.as_raw_fd()); + assert_ne!(file.as_raw_fd(), s2.as_raw_fd()); + assert_ne!(file.as_raw_fd(), evt.as_raw_fd()); - files[0] - .write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) + file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) .expect("failed to write to sent fd"); assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); @@ -306,28 +402,28 @@ mod tests { fn send_recv_with_fd() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let evt = EventFd::new().expect("failed to create eventfd"); - let write_count = scm.send(&s1, &[[237].as_ref()], &[evt.as_raw_fd()]) + let write_count = s1.send_with_fds([237].as_ref(), &[evt.as_raw_fd()]) .expect("failed to send fd"); assert_eq!(write_count, 1); - let mut files = Vec::new(); + let mut files = [0; 2]; let mut buf = [0u8]; - let read_count = scm.recv(&s2, &mut [&mut buf], &mut files) + let (read_count, file_count) = s2.recv_with_fds(&mut buf, &mut files) .expect("failed to recv fd"); assert_eq!(read_count, 1); assert_eq!(buf[0], 237); - assert_eq!(files.len(), 1); - assert!(files[0].as_raw_fd() >= 0); - assert_ne!(files[0].as_raw_fd(), s1.as_raw_fd()); - assert_ne!(files[0].as_raw_fd(), s2.as_raw_fd()); - assert_ne!(files[0].as_raw_fd(), evt.as_raw_fd()); + assert_eq!(file_count, 1); + assert!(files[0] >= 0); + assert_ne!(files[0], s1.as_raw_fd()); + assert_ne!(files[0], s2.as_raw_fd()); + assert_ne!(files[0], evt.as_raw_fd()); - files[0] - .write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) + let mut file = unsafe { File::from_raw_fd(files[0]) }; + + file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) }) .expect("failed to write to sent fd"); assert_eq!(evt.read().expect("failed to read from eventfd"), 1203); diff --git a/vm_control/src/lib.rs b/vm_control/src/lib.rs index ad123a1890..fae2e1d585 100644 --- a/vm_control/src/lib.rs +++ b/vm_control/src/lib.rs @@ -27,7 +27,8 @@ use libc::{ERANGE, EINVAL, ENODEV}; use byteorder::{LittleEndian, WriteBytesExt}; use data_model::{DataInit, Le32, Le64, VolatileMemory}; -use sys_util::{EventFd, Result, Error as SysError, MmapError, MemoryMapping, Scm, GuestAddress}; +use sys_util::{EventFd, Result, Error as SysError, MmapError, MemoryMapping, ScmSocket, + GuestAddress}; use resources::{GpuMemoryDesc, GpuMemoryPlaneDesc, SystemAllocator}; use kvm::{IoeventAddress, Vm}; @@ -134,11 +135,10 @@ impl VmRequest { /// Receive a `VmRequest` from the given socket. /// /// A `VmResponse` should be sent out over the given socket before another request is received. - pub fn recv(scm: &mut Scm, s: &UnixDatagram) -> VmControlResult { + pub fn recv(s: &UnixDatagram) -> VmControlResult { assert_eq!(VM_REQUEST_SIZE, std::mem::size_of::()); let mut buf = [0; VM_REQUEST_SIZE]; - let mut fds = Vec::new(); - let read = scm.recv(s, &mut [&mut buf], &mut fds) + let (read, file) = s.recv_with_fd(&mut buf) .map_err(|e| VmControlError::Recv(e))?; if read != VM_REQUEST_SIZE { return Err(VmControlError::BadSize(read)); @@ -150,7 +150,7 @@ impl VmRequest { match req.type_.into() { VM_REQUEST_TYPE_EXIT => Ok(VmRequest::Exit), VM_REQUEST_TYPE_REGISTER_MEMORY => { - let fd = fds.pop().ok_or(VmControlError::ExpectFd)?; + let fd = file.ok_or(VmControlError::ExpectFd)?; Ok(VmRequest::RegisterMemory(MaybeOwnedFd::Owned(fd), req.size.to_native() as usize)) } @@ -172,7 +172,7 @@ impl VmRequest { /// /// After this request is a sent, a `VmResponse` should be received before sending another /// request. - pub fn send(&self, scm: &mut Scm, s: &UnixDatagram) -> VmControlResult<()> { + pub fn send(&self, s: &UnixDatagram) -> VmControlResult<()> { assert_eq!(VM_REQUEST_SIZE, std::mem::size_of::()); let mut req = VmRequestStruct::default(); let mut fd_buf = [0; 1]; @@ -203,7 +203,7 @@ impl VmRequest { } let mut buf = [0; VM_REQUEST_SIZE]; buf.as_mut().get_ref(0).unwrap().store(req); - scm.send(s, &[buf.as_ref()], &fd_buf[..fd_len]) + s.send_with_fds(buf.as_ref(), &fd_buf[..fd_len]) .map_err(|e| VmControlError::Send(e))?; Ok(()) } @@ -332,10 +332,9 @@ impl VmResponse { /// Receive a `VmResponse` from the given socket. /// /// This should be called after the sending a `VmRequest` before sending another request. - pub fn recv(scm: &mut Scm, s: &UnixDatagram) -> VmControlResult { + pub fn recv(s: &UnixDatagram) -> VmControlResult { let mut buf = [0; VM_RESPONSE_SIZE]; - let mut fds = Vec::new(); - let read = scm.recv(s, &mut [&mut buf], &mut fds) + let (read, file) = s.recv_with_fd(&mut buf) .map_err(|e| VmControlError::Recv(e))?; if read != VM_RESPONSE_SIZE { return Err(VmControlError::BadSize(read)); @@ -354,7 +353,7 @@ impl VmResponse { }) } VM_RESPONSE_TYPE_ALLOCATE_AND_REGISTER_GPU_MEMORY => { - let fd = fds.pop().ok_or(VmControlError::ExpectFd)?; + let fd = file.ok_or(VmControlError::ExpectFd)?; Ok(VmResponse::AllocateAndRegisterGpuMemory { fd: MaybeOwnedFd::Owned(fd), pfn: resp.pfn.into(), @@ -377,7 +376,7 @@ impl VmResponse { /// /// This must be called after receiving a `VmRequest` to indicate the outcome of that request's /// execution. - pub fn send(&self, scm: &mut Scm, s: &UnixDatagram) -> VmControlResult<()> { + pub fn send(&self, s: &UnixDatagram) -> VmControlResult<()> { let mut resp = VmResponseStruct::default(); let mut fd_buf = [0; 1]; let mut fd_len = 0; @@ -408,7 +407,7 @@ impl VmResponse { } let mut buf = [0; VM_RESPONSE_SIZE]; buf.as_mut().get_ref(0).unwrap().store(resp); - scm.send(s, &[buf.as_ref()], &fd_buf[..fd_len]) + s.send_with_fds(buf.as_ref(), &fd_buf[..fd_len]) .map_err(|e| VmControlError::Send(e))?; Ok(()) } @@ -427,9 +426,8 @@ mod tests { #[test] fn request_exit() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - VmRequest::Exit.send(&mut scm, &s1).unwrap(); - match VmRequest::recv(&mut scm, &s2).unwrap() { + VmRequest::Exit.send(&s1).unwrap(); + match VmRequest::recv(&s2).unwrap() { VmRequest::Exit => {} _ => panic!("recv wrong request variant"), } @@ -439,14 +437,13 @@ mod tests { fn request_register_memory() { if !kernel_has_memfd() { return; } let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let shm_size: usize = 4096; let mut shm = SharedMemory::new(None).unwrap(); shm.set_size(shm_size as u64).unwrap(); VmRequest::RegisterMemory(MaybeOwnedFd::Borrowed(shm.as_raw_fd()), shm_size) - .send(&mut scm, &s1) + .send(&s1) .unwrap(); - match VmRequest::recv(&mut scm, &s2).unwrap() { + match VmRequest::recv(&s2).unwrap() { VmRequest::RegisterMemory(MaybeOwnedFd::Owned(fd), size) => { assert!(fd.as_raw_fd() >= 0); assert_eq!(size, shm_size); @@ -458,11 +455,8 @@ mod tests { #[test] fn request_unregister_memory() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - VmRequest::UnregisterMemory(77) - .send(&mut scm, &s1) - .unwrap(); - match VmRequest::recv(&mut scm, &s2).unwrap() { + VmRequest::UnregisterMemory(77).send(&s1).unwrap(); + match VmRequest::recv(&s2).unwrap() { VmRequest::UnregisterMemory(slot) => assert_eq!(slot, 77), _ => panic!("recv wrong request variant"), } @@ -471,11 +465,10 @@ mod tests { #[test] fn request_expect_fd() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let mut bad_request = [0; VM_REQUEST_SIZE]; bad_request[0] = VM_REQUEST_TYPE_REGISTER_MEMORY as u8; - scm.send(&s2, &[bad_request.as_ref()], &[]).unwrap(); - match VmRequest::recv(&mut scm, &s1) { + s2.send_with_fds(bad_request.as_ref(), &[]).unwrap(); + match VmRequest::recv(&s1) { Err(VmControlError::ExpectFd) => {} _ => panic!("recv wrong error variant"), } @@ -484,9 +477,8 @@ mod tests { #[test] fn request_no_data() { let (s1, _) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); s1.shutdown(Shutdown::Both).unwrap(); - match VmRequest::recv(&mut scm, &s1) { + match VmRequest::recv(&s1) { Err(VmControlError::BadSize(s)) => assert_eq!(s, 0), _ => panic!("recv wrong error variant"), } @@ -495,9 +487,8 @@ mod tests { #[test] fn request_bad_size() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - scm.send(&s2, &[[12; 7].as_ref()], &[]).unwrap(); - match VmRequest::recv(&mut scm, &s1) { + s2.send_with_fds([12; 7].as_ref(), &[]).unwrap(); + match VmRequest::recv(&s1) { Err(VmControlError::BadSize(_)) => {} _ => panic!("recv wrong error variant"), } @@ -506,10 +497,9 @@ mod tests { #[test] fn request_invalid_type() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - scm.send(&s2, &[[12; VM_REQUEST_SIZE].as_ref()], &[]) + s2.send_with_fds([12; VM_REQUEST_SIZE].as_ref(), &[]) .unwrap(); - match VmRequest::recv(&mut scm, &s1) { + match VmRequest::recv(&s1) { Err(VmControlError::InvalidType) => {} _ => panic!("recv wrong error variant"), } @@ -518,14 +508,21 @@ mod tests { #[test] fn request_allocate_and_register_gpu_memory() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let gpu_width: u32 = 32; let gpu_height: u32 = 32; let gpu_format: u32 = 0x34325258; - let r = VmRequest::AllocateAndRegisterGpuMemory { width: gpu_width, height: gpu_height, format: gpu_format }; - r.send(&mut scm, &s1).unwrap(); - match VmRequest::recv(&mut scm, &s2).unwrap() { - VmRequest::AllocateAndRegisterGpuMemory {width, height, format} => { + let r = VmRequest::AllocateAndRegisterGpuMemory { + width: gpu_width, + height: gpu_height, + format: gpu_format, + }; + r.send(&s1).unwrap(); + match VmRequest::recv(&s2).unwrap() { + VmRequest::AllocateAndRegisterGpuMemory { + width, + height, + format, + } => { assert_eq!(width, gpu_width); assert_eq!(height, gpu_width); assert_eq!(format, gpu_format); @@ -537,9 +534,8 @@ mod tests { #[test] fn resp_ok() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - VmResponse::Ok.send(&mut scm, &s1).unwrap(); - match VmResponse::recv(&mut scm, &s2).unwrap() { + VmResponse::Ok.send(&s1).unwrap(); + match VmResponse::recv(&s2).unwrap() { VmResponse::Ok => {} _ => panic!("recv wrong response variant"), } @@ -548,10 +544,9 @@ mod tests { #[test] fn resp_err() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let r1 = VmResponse::Err(SysError::new(libc::EDESTADDRREQ)); - r1.send(&mut scm, &s1).unwrap(); - match VmResponse::recv(&mut scm, &s2).unwrap() { + r1.send(&s1).unwrap(); + match VmResponse::recv(&s2).unwrap() { VmResponse::Err(e) => { assert_eq!(e, SysError::new(libc::EDESTADDRREQ)); } @@ -562,12 +557,14 @@ mod tests { #[test] fn resp_memory() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let memory_pfn = 55; let memory_slot = 66; - let r1 = VmResponse::RegisterMemory { pfn: memory_pfn, slot: memory_slot }; - r1.send(&mut scm, &s1).unwrap(); - match VmResponse::recv(&mut scm, &s2).unwrap() { + let r1 = VmResponse::RegisterMemory { + pfn: memory_pfn, + slot: memory_slot, + }; + r1.send(&s1).unwrap(); + match VmResponse::recv(&s2).unwrap() { VmResponse::RegisterMemory { pfn, slot } => { assert_eq!(pfn, memory_pfn); assert_eq!(slot, memory_slot); @@ -579,9 +576,8 @@ mod tests { #[test] fn resp_no_data() { let (s1, _) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); s1.shutdown(Shutdown::Both).unwrap(); - match VmResponse::recv(&mut scm, &s1) { + match VmResponse::recv(&s1) { Err(e) => { assert_eq!(e, VmControlError::BadSize(0)); } @@ -592,9 +588,8 @@ mod tests { #[test] fn resp_bad_size() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - scm.send(&s2, &[[12; 7].as_ref()], &[]).unwrap(); - match VmResponse::recv(&mut scm, &s1) { + s2.send_with_fds([12; 7].as_ref(), &[]).unwrap(); + match VmResponse::recv(&s1) { Err(e) => { assert_eq!(e, VmControlError::BadSize(7)); } @@ -605,10 +600,9 @@ mod tests { #[test] fn resp_invalid_type() { let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); - scm.send(&s2, &[[12; VM_RESPONSE_SIZE].as_ref()], &[]) + s2.send_with_fds([12; VM_RESPONSE_SIZE].as_ref(), &[]) .unwrap(); - match VmResponse::recv(&mut scm, &s1) { + match VmResponse::recv(&s1) { Err(e) => { assert_eq!(e, VmControlError::InvalidType); } @@ -620,7 +614,6 @@ mod tests { fn resp_allocate_and_register_gpu_memory() { if !kernel_has_memfd() { return; } let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair"); - let mut scm = Scm::new(1); let shm_size: usize = 4096; let mut shm = SharedMemory::new(None).unwrap(); shm.set_size(shm_size as u64).unwrap(); @@ -637,9 +630,14 @@ mod tests { slot: memory_slot, desc: GpuMemoryDesc { planes: memory_planes }, }; - r1.send(&mut scm, &s1).unwrap(); - match VmResponse::recv(&mut scm, &s2).unwrap() { - VmResponse::AllocateAndRegisterGpuMemory { fd, pfn, slot, desc } => { + r1.send(&s1).unwrap(); + match VmResponse::recv(&s2).unwrap() { + VmResponse::AllocateAndRegisterGpuMemory { + fd, + pfn, + slot, + desc, + } => { assert!(fd.as_raw_fd() >= 0); assert_eq!(pfn, memory_pfn); assert_eq!(slot, memory_slot);