sys_util: remove Scm struct and sock_ctrl_msg C library

The Scm object was made to reduce the number of heap allocations in
the hot paths of poll loops, at the cost of some code complexity. As it
turns out, the number of file descriptors being sent or received is
usually just one or limited to a fixed amount that can easily be covered
with a fixed size stack allocated buffer.

This change implements that solution, with heap allocation as a backup
in the rare case that many file descriptors must be sent or received.

This change also moves the msg and cmsg manipulation code out of C and
into pure Rust. The move was necessary to allocate the correct amount
of buffer space at compile time. It also improves safety by reducing the
scope of unsafe code. Deleting the code for building the C library is
also a nice bonus.

Finally, the removal of the commonly used Scm struct required
transitioning existing usage to the ScmSocket trait based methods. This
includes all those changes.

TEST=cargo test
BUG=None

Change-Id: If27ba297f5416dd9b8bc686ce740866912fa0aa0
Reviewed-on: https://chromium-review.googlesource.com/1186146
Commit-Ready: ChromeOS CL Exonerator Bot <chromiumos-cl-exonerator@appspot.gserviceaccount.com>
Tested-by: Zach Reizner <zachr@chromium.org>
Reviewed-by: Zach Reizner <zachr@chromium.org>
This commit is contained in:
Zach Reizner 2018-08-23 13:34:56 -07:00 committed by chrome-bot
parent 4a55609f50
commit a99954cb7c
11 changed files with 385 additions and 448 deletions

1
Cargo.lock generated
View file

@ -367,7 +367,6 @@ dependencies = [
name = "sys_util"
version = "0.1.0"
dependencies = [
"cc 1.0.15 (registry+https://github.com/rust-lang/crates.io-index)",
"data_model 0.1.0",
"libc 0.2.40 (registry+https://github.com/rust-lang/crates.io-index)",
"poll_token_derive 0.1.0",

View file

@ -40,7 +40,7 @@ use libc::{E2BIG, ENOTCONN, EINVAL, EPROTO, ENOENT};
use protobuf::{Message, ProtobufEnum, RepeatedField, parse_from_bytes};
use sys_util::Scm;
use sys_util::ScmSocket;
use kvm::dirty_log_bitmap_size;
@ -250,7 +250,6 @@ impl Drop for StatUpdater {
pub struct crosvm {
id_allocator: Arc<IdAllocator>,
socket: UnixDatagram,
fd_messager: Scm,
request_buffer: Vec<u8>,
response_buffer: Vec<u8>,
vcpus: Arc<Vec<crosvm_vcpu>>,
@ -261,7 +260,6 @@ impl crosvm {
let mut crosvm = crosvm {
id_allocator: Default::default(),
socket,
fd_messager: Scm::new(MAX_DATAGRAM_FD),
request_buffer: Vec::new(),
response_buffer: vec![0; MAX_DATAGRAM_SIZE],
vcpus: Default::default(),
@ -277,7 +275,6 @@ impl crosvm {
crosvm {
id_allocator,
socket,
fd_messager: Scm::new(MAX_DATAGRAM_FD),
request_buffer: Vec::new(),
response_buffer: vec![0; MAX_DATAGRAM_SIZE],
vcpus,
@ -296,16 +293,19 @@ impl crosvm {
request
.write_to_vec(&mut self.request_buffer)
.map_err(proto_error_to_int)?;
self.fd_messager
.send(&self.socket, &[self.request_buffer.as_slice()], fds)
self.socket
.send_with_fds(self.request_buffer.as_slice(), fds)
.map_err(|e| -e.errno())?;
let mut datagram_files = Vec::new();
let msg_size = self.fd_messager
.recv(&self.socket,
&mut [&mut self.response_buffer],
&mut datagram_files)
let mut datagram_fds = [0; MAX_DATAGRAM_FD];
let (msg_size, fd_count) = self.socket
.recv_with_fds(&mut self.response_buffer, &mut datagram_fds)
.map_err(|e| -e.errno())?;
// Safe because the first fd_count fds from recv_with_fds are owned by us and valid.
let datagram_files = datagram_fds[..fd_count]
.iter()
.map(|&fd| unsafe { File::from_raw_fd(fd) })
.collect();
let response: MainResponse = parse_from_bytes(&self.response_buffer[..msg_size])
.map_err(proto_error_to_int)?;

View file

@ -40,9 +40,7 @@ use std::io::{self, Seek, SeekFrom, Read};
use std::mem::{size_of, size_of_val};
#[cfg(feature = "wl-dmabuf")]
use std::os::raw::{c_uint, c_ulonglong};
use std::os::unix::io::{AsRawFd, RawFd};
#[cfg(feature = "wl-dmabuf")]
use std::os::unix::io::FromRawFd;
use std::os::unix::io::{AsRawFd, FromRawFd, RawFd};
use std::os::unix::net::{UnixDatagram, UnixStream};
use std::path::{PathBuf, Path};
use std::rc::Rc;
@ -59,7 +57,7 @@ use data_model::*;
use data_model::VolatileMemoryError;
use resources::GpuMemoryDesc;
use sys_util::{Error, Result, EventFd, Scm, SharedMemory, GuestAddress, GuestMemory,
use sys_util::{Error, Result, EventFd, ScmSocket, SharedMemory, GuestAddress, GuestMemory,
GuestMemoryError, PollContext, PollToken, FileFlags, pipe, round_up_to_page_size};
#[cfg(feature = "wl-dmabuf")]
@ -432,21 +430,19 @@ impl From<VolatileMemoryError> for WlError {
#[derive(Clone)]
struct VmRequester {
inner: Rc<RefCell<(Scm, UnixDatagram)>>,
inner: Rc<RefCell<UnixDatagram>>,
}
impl VmRequester {
fn new(vm_socket: UnixDatagram) -> VmRequester {
VmRequester { inner: Rc::new(RefCell::new((Scm::new(1), vm_socket))) }
VmRequester { inner: Rc::new(RefCell::new(vm_socket)) }
}
fn request(&self, request: VmRequest) -> WlResult<VmResponse> {
let mut inner = self.inner.borrow_mut();
let (ref mut scm, ref mut vm_socket) = *inner;
request
.send(scm, vm_socket)
.map_err(WlError::VmControl)?;
VmResponse::recv(scm, vm_socket).map_err(WlError::VmControl)
let ref mut vm_socket = *inner;
request.send(vm_socket).map_err(WlError::VmControl)?;
VmResponse::recv(vm_socket).map_err(WlError::VmControl)
}
}
@ -803,9 +799,10 @@ impl WlVfd {
}
// Sends data/files from the guest to the host over this VFD.
fn send(&mut self, scm: &mut Scm, fds: &[RawFd], data: VolatileSlice) -> WlResult<WlResp> {
fn send(&mut self, fds: &[RawFd], data: VolatileSlice) -> WlResult<WlResp> {
if let Some(ref socket) = self.socket {
scm.send(socket, &[data], fds)
socket
.send_with_fds(data, fds)
.map_err(WlError::SendVfd)?;
Ok(WlResp::Ok)
} else if let Some((_, ref mut local_pipe)) = self.local_pipe {
@ -822,19 +819,24 @@ impl WlVfd {
}
// Receives data/files from the host for this VFD and queues it for the guest.
fn recv(&mut self, scm: &mut Scm, in_file_queue: &mut Vec<File>) -> WlResult<Vec<u8>> {
fn recv(&mut self, in_file_queue: &mut Vec<File>) -> WlResult<Vec<u8>> {
if let Some(socket) = self.socket.take() {
let mut buf = Vec::new();
buf.resize(IN_BUFFER_LEN, 0);
let old_len = in_file_queue.len();
let mut buf = vec![0; IN_BUFFER_LEN];
let mut fd_buf = [0; VIRTWL_SEND_MAX_ALLOCS];
// If any errors happen, the socket will get dropped, preventing more reading.
let len = scm.recv(&socket, &mut [&mut buf[..]], in_file_queue)
let (len, file_count) = socket
.recv_with_fds(&mut buf[..], &mut fd_buf)
.map_err(WlError::RecvVfd)?;
// If any data gets read, the put the socket back for future recv operations.
if len != 0 || in_file_queue.len() != old_len {
if len != 0 || file_count != 0 {
buf.truncate(len);
buf.shrink_to_fit();
self.socket = Some(socket);
// Safe because the first file_counts fds from recv_with_fds are owned by us and
// valid.
in_file_queue.extend(fd_buf[..file_count]
.iter()
.map(|&fd| unsafe { File::from_raw_fd(fd) }));
return Ok(buf);
}
Ok(Vec::new())
@ -893,7 +895,6 @@ struct WlState {
poll_ctx: PollContext<u32>,
vfds: Map<u32, WlVfd>,
next_vfd_id: u32,
scm: Scm,
in_file_queue: Vec<File>,
in_queue: VecDeque<(u32 /* vfd_id */, WlRecv)>,
current_recv_vfd: Option<u32>,
@ -907,7 +908,6 @@ impl WlState {
vm: VmRequester::new(vm_socket),
poll_ctx: PollContext::new().expect("failed to create PollContext"),
use_transition_flags,
scm: Scm::new(VIRTWL_SEND_MAX_ALLOCS),
vfds: Map::new(),
next_vfd_id: NEXT_VFD_ID_BASE,
in_file_queue: Vec::new(),
@ -1126,7 +1126,7 @@ impl WlState {
match self.vfds.get_mut(&vfd_id) {
Some(vfd) => {
match vfd.send(&mut self.scm, &fds[..vfd_count], data)? {
match vfd.send(&fds[..vfd_count], data)? {
WlResp::Ok => {}
_ => return Ok(WlResp::InvalidType),
}
@ -1145,7 +1145,7 @@ impl WlState {
fn recv(&mut self, vfd_id: u32) -> WlResult<()> {
let buf = match self.vfds.get_mut(&vfd_id) {
Some(vfd) => vfd.recv(&mut self.scm, &mut self.in_file_queue)?,
Some(vfd) => vfd.recv(&mut self.in_file_queue)?,
None => return Ok(()),
};
if self.in_file_queue.is_empty() && buf.is_empty() {

View file

@ -729,8 +729,6 @@ fn run_control(mut linux: RunnableLinuxVm,
control_sockets: Vec<UnlinkUnixDatagram>,
balloon_host_socket: UnixDatagram,
sigchld_fd: SignalFd) -> Result<()> {
const MAX_VM_FD_RECV: usize = 1;
// Paths to get the currently available memory and the low memory threshold.
const LOWMEM_MARGIN: &'static str = "/sys/kernel/mm/chromeos-low_mem/margin";
const LOWMEM_AVAILABLE: &'static str = "/sys/kernel/mm/chromeos-low_mem/available";
@ -817,8 +815,6 @@ fn run_control(mut linux: RunnableLinuxVm,
}
vcpu_thread_barrier.wait();
let mut scm = Scm::new(MAX_VM_FD_RECV);
'poll: loop {
let events = {
match poll_ctx.wait() {
@ -945,7 +941,7 @@ fn run_control(mut linux: RunnableLinuxVm,
}
Token::VmControl { index } => {
if let Some(socket) = control_sockets.get(index as usize) {
match VmRequest::recv(&mut scm, socket.as_ref()) {
match VmRequest::recv(socket.as_ref()) {
Ok(request) => {
let mut running = true;
let response =
@ -953,7 +949,7 @@ fn run_control(mut linux: RunnableLinuxVm,
&mut linux.resources,
&mut running,
&balloon_host_socket);
if let Err(e) = response.send(&mut scm, socket.as_ref()) {
if let Err(e) = response.send(socket.as_ref()) {
error!("failed to send VmResponse: {:?}", e);
}
if !running {

View file

@ -47,7 +47,7 @@ use std::string::String;
use std::thread::sleep;
use std::time::Duration;
use sys_util::{Scm, getpid, kill_process_group, reap_child, syslog};
use sys_util::{getpid, kill_process_group, reap_child, syslog};
use qcow::QcowFile;
use argument::{Argument, set_arguments, print_help};
@ -519,7 +519,6 @@ fn run_vm(args: std::env::Args) -> std::result::Result<(), ()> {
}
fn stop_vms(args: std::env::Args) -> std::result::Result<(), ()> {
let mut scm = Scm::new(1);
if args.len() == 0 {
print_help("crosvm stop", "VM_SOCKET...", &[]);
println!("Stops the crosvm instance listening on each `VM_SOCKET` given.");
@ -532,7 +531,7 @@ fn stop_vms(args: std::env::Args) -> std::result::Result<(), ()> {
Ok(s)
}) {
Ok(s) => {
if let Err(e) = VmRequest::Exit.send(&mut scm, &s) {
if let Err(e) = VmRequest::Exit.send(&s) {
error!("failed to send stop request to socket at '{}': {:?}",
socket_path,
e);
@ -549,7 +548,6 @@ fn stop_vms(args: std::env::Args) -> std::result::Result<(), ()> {
}
fn balloon_vms(mut args: std::env::Args) -> std::result::Result<(), ()> {
let mut scm = Scm::new(1);
if args.len() < 2 {
print_help("crosvm balloon", "PAGE_ADJUST VM_SOCKET...", &[]);
println!("Adjust the ballon size of the crosvm instance by `PAGE_ADJUST` pages, `PAGE_ADJUST` can be negative to shrink the balloon.");
@ -569,7 +567,7 @@ fn balloon_vms(mut args: std::env::Args) -> std::result::Result<(), ()> {
Ok(s)
}) {
Ok(s) => {
if let Err(e) = VmRequest::BalloonAdjust(num_pages).send(&mut scm, &s) {
if let Err(e) = VmRequest::BalloonAdjust(num_pages).send(&s) {
error!("failed to send balloon request to socket at '{}': {:?}",
socket_path,
e);

View file

@ -25,7 +25,7 @@ use protobuf::Message;
use io_jail::Minijail;
use kvm::{Vm, IoeventAddress, NoDatamatch, IrqSource, IrqRoute, PicId, dirty_log_bitmap_size};
use kvm_sys::{kvm_pic_state, kvm_ioapic_state, kvm_pit_state2};
use sys_util::{EventFd, MemoryMapping, Killable, Scm, SharedMemory, GuestAddress,
use sys_util::{EventFd, MemoryMapping, Killable, ScmSocket, SharedMemory, GuestAddress,
Result as SysResult, Error as SysError, SIGRTMIN};
use plugin_proto::*;
@ -119,9 +119,7 @@ pub struct Process {
vcpu_sockets: Vec<(UnixDatagram, UnixDatagram)>,
// Socket Transmission
scm: Scm,
request_buffer: Vec<u8>,
datagram_files: Vec<File>,
response_buffer: Vec<u8>,
}
@ -181,9 +179,7 @@ impl Process {
per_vcpu_states,
kill_evt: EventFd::new().map_err(Error::CreateEventFd)?,
vcpu_sockets,
scm: Scm::new(1),
request_buffer: vec![0; MAX_DATAGRAM_SIZE],
datagram_files: Vec::new(),
response_buffer: Vec::new(),
})
}
@ -444,10 +440,8 @@ impl Process {
vcpu_handles: &[JoinHandle<()>],
tap: Option<&Tap>)
-> Result<()> {
let msg_size = self.scm
.recv(&self.request_sockets[index],
&mut [&mut self.request_buffer],
&mut self.datagram_files)
let (msg_size, request_file) = self.request_sockets[index]
.recv_with_fd(&mut self.request_buffer)
.map_err(Error::PluginSocketRecv)?;
if msg_size == 0 {
@ -475,7 +469,7 @@ impl Process {
}
} else if create.has_memory() {
let memory = create.get_memory();
match self.datagram_files.pop() {
match request_file {
Some(memfd) => {
Self::handle_memory(entry,
vm,
@ -653,16 +647,13 @@ impl Process {
response.errno = e.errno();
}
self.datagram_files.clear();
self.response_buffer.clear();
response
.write_to_vec(&mut self.response_buffer)
.map_err(Error::EncodeResponse)?;
assert_ne!(self.response_buffer.len(), 0);
self.scm
.send(&self.request_sockets[index],
&[&self.response_buffer[..]],
&response_fds)
self.request_sockets[index]
.send_with_fds(&self.response_buffer[..], &response_fds)
.map_err(Error::PluginSocketSend)?;
Ok(())

View file

@ -2,13 +2,9 @@
name = "sys_util"
version = "0.1.0"
authors = ["The Chromium OS Authors"]
build = "build.rs"
[dependencies]
data_model = { path = "../data_model" }
libc = "*"
syscall_defines = { path = "../syscall_defines" }
poll_token_derive = { path = "poll_token_derive" }
[build-dependencies]
cc = "=1.0.15"

View file

@ -1,11 +0,0 @@
// Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
extern crate cc;
fn main() {
cc::Build::new()
.file("sock_ctrl_msg.c")
.compile("sock_ctrl_msg");
}

View file

@ -1,126 +0,0 @@
// Copyright 2017 The Chromium OS Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
#include <stdint.h>
#include <string.h> // memcpy
#include <unistd.h> // close
#include <sys/errno.h>
#include <sys/socket.h> // CMSG_*
/*
* Returns the number of bytes the `cmsg_buffer` must be for the functions that take a cmsg_buffer
* in this module.
* Arguments:
* * `fd_count` - Maximum number of file descriptors to be sent or received via the cmsg.
*/
size_t scm_cmsg_buffer_len(size_t fd_count)
{
return CMSG_SPACE(sizeof(int) * fd_count);
}
/*
* Convenience wrapper around `sendmsg` that builds up the `msghdr` structure for you given the
* array of fds.
* Arguments:
* * `fd` - Unix domain socket to `sendmsg` on.
* * `outv` - Array of `outv_count` length `iovec`s that contain the data to send.
* * `outv_count` - Number of elements in `outv` array.
* * `cmsg_buffer` - A buffer that must be at least `scm_cmsg_buffer_len(fd_count)` bytes long.
* * `fds` - Array of `fd_count` file descriptors to send along with data.
* * `fd_count` - Number of elements in `fds` array.
* Returns:
* A non-negative number indicating how many bytes were sent on success or a negative errno on
* failure.
*/
ssize_t scm_sendmsg(int fd, const struct iovec *outv, size_t outv_count, uint8_t *cmsg_buffer,
const int *fds, size_t fd_count)
{
if (fd < 0 || ((!cmsg_buffer || !fds) && fd_count > 0))
return -EINVAL;
struct msghdr msg;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = (struct iovec *)outv; // discard const, sendmsg won't mutate it
msg.msg_iovlen = outv_count;
if (fd_count) {
msg.msg_control = cmsg_buffer;
msg.msg_controllen = scm_cmsg_buffer_len(fd_count);
struct cmsghdr *cmsg = CMSG_FIRSTHDR(&msg);
cmsg->cmsg_level = SOL_SOCKET;
cmsg->cmsg_type = SCM_RIGHTS;
cmsg->cmsg_len = CMSG_LEN(fd_count * sizeof(int));
memcpy(CMSG_DATA(cmsg), fds, fd_count * sizeof(int));
msg.msg_controllen = cmsg->cmsg_len;
}
ssize_t bytes_sent = sendmsg(fd, &msg, MSG_NOSIGNAL);
if (bytes_sent == -1)
return -errno;
return bytes_sent;
}
/*
* Convenience wrapper around `recvmsg` that builds up the `msghdr` structure and returns up to
* `*fd_count` file descriptors in the given `fds` array.
* Arguments:
* * `fd` - Unix domain socket to `recvmsg` on.
* * `outv` - Array of `outv_count` length `iovec`s that will contain the received data.
* * `outv_count` - Number of elements in `outv` array.
* * `cmsg_buffer` - A buffer that must be at least `scm_cmsg_buffer_len(*fd_count)` bytes long.
* * `fds` - Array of `fd_count` file descriptors to receive along with data.
* * `fd_count` - Number of elements in `fds` array.
* Returns:
* A non-negative number indicating how many bytes were received on success or a negative errno on
* failure.
*/
ssize_t scm_recvmsg(int fd, struct iovec *outv, size_t outv_count, uint8_t *cmsg_buffer, int *fds,
size_t *fd_count)
{
if (fd < 0 || !cmsg_buffer || !fds || !fd_count)
return -EINVAL;
struct msghdr msg;
memset(&msg, 0, sizeof(msg));
msg.msg_iov = outv;
msg.msg_iovlen = outv_count;
msg.msg_control = cmsg_buffer;
msg.msg_controllen = scm_cmsg_buffer_len(*fd_count);
ssize_t total_read = recvmsg(fd, &msg, 0);
if (total_read == -1)
return -errno;
if (total_read == 0 && CMSG_FIRSTHDR(&msg) == NULL) {
*fd_count = 0;
return 0;
}
size_t fd_idx = 0;
struct cmsghdr *cmsg;
for (cmsg = CMSG_FIRSTHDR(&msg); cmsg != NULL; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
if (cmsg->cmsg_level != SOL_SOCKET || cmsg->cmsg_type != SCM_RIGHTS)
continue;
size_t cmsg_fd_count = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
int *cmsg_fds = (int *)CMSG_DATA(cmsg);
size_t cmsg_fd_idx;
for (cmsg_fd_idx = 0; cmsg_fd_idx < cmsg_fd_count; cmsg_fd_idx++) {
if (fd_idx < *fd_count) {
fds[fd_idx] = cmsg_fds[cmsg_fd_idx];
fd_idx++;
} else {
close(cmsg_fds[cmsg_fd_idx]);
}
}
}
*fd_count = fd_idx;
return total_read;
}

View file

@ -2,41 +2,200 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Used to send and receive messages with file descriptors on sockets that accept control messages
//! (e.g. Unix domain sockets).
use std::fs::File;
use std::mem::size_of;
use std::os::unix::io::{AsRawFd, RawFd, FromRawFd};
use std::os::unix::net::{UnixDatagram, UnixStream};
use std::ptr::{copy_nonoverlapping, null_mut, write_unaligned};
use libc::{c_void, iovec};
use libc::{c_void, c_long, iovec, msghdr, cmsghdr, sendmsg, recvmsg, SOL_SOCKET, SCM_RIGHTS,
MSG_NOSIGNAL};
use data_model::VolatileSlice;
use {Result, Error};
// These functions are implemented in C because each of them requires complicated setup with CMSG
// macros. These macros are part of the system headers and are required to be used for portability
// reasons. In practice, the control message ABI can't change but using them is much easier and less
// error prone than trying to port these macros to rust.
extern "C" {
fn scm_cmsg_buffer_len(fd_count: usize) -> usize;
fn scm_sendmsg(fd: RawFd,
outv: *const iovec,
outv_count: usize,
cmsg_buffer: *mut u8,
fds: *const RawFd,
fd_count: usize)
-> isize;
fn scm_recvmsg(fd: RawFd,
outv: *mut iovec,
outv_count: usize,
cmsg_buffer: *mut u8,
fds: *mut RawFd,
fd_count: *mut usize)
-> isize;
// Each of the following macros performs the same function as their C counterparts. They are each
// macros because they are used to size statically allocated arrays.
macro_rules! CMSG_ALIGN {
($len:expr) => (
(($len) + size_of::<c_long>() - 1) & !(size_of::<c_long>() - 1)
)
}
fn cmsg_buffer_len(fd_count: usize) -> usize {
// Safe because this function has no side effects, touches no pointers, and never fails.
unsafe { scm_cmsg_buffer_len(fd_count) }
macro_rules! CMSG_SPACE {
($len:expr) => (
size_of::<cmsghdr>() + CMSG_ALIGN!($len)
)
}
macro_rules! CMSG_LEN {
($len:expr) => (
size_of::<cmsghdr>() + ($len)
)
}
// This function (macro in the C version) is not used in any compile time constant slots, so is just
// an ordinary function. The returned pointer is hard coded to be RawFd because that's all that this
// module supports.
#[allow(non_snake_case)]
#[inline(always)]
fn CMSG_DATA(cmsg_buffer: *mut u8) -> *mut RawFd {
// Essentially returns a pointer to just past the header.
(cmsg_buffer as *mut cmsghdr).wrapping_offset(1) as *mut RawFd
}
// This function is like CMSG_NEXT, but safer because it reads only from references, although it
// does some pointer arithmetic on cmsg_ptr.
fn get_next_cmsg(msghdr: &msghdr, cmsg: &cmsghdr, cmsg_ptr: *mut u8) -> *mut u8 {
let next_cmsg = cmsg_ptr.wrapping_offset(CMSG_ALIGN!(cmsg.cmsg_len) as isize);
if next_cmsg
.wrapping_offset(1)
.wrapping_sub(msghdr.msg_control as usize) as usize > msghdr.msg_controllen {
null_mut()
} else {
next_cmsg
}
}
const CMSG_BUFFER_INLINE_CAPACITY: usize = CMSG_SPACE!(size_of::<RawFd>() * 32);
enum CmsgBuffer {
Inline([u8; CMSG_BUFFER_INLINE_CAPACITY]),
Heap(Box<[u8]>),
}
impl CmsgBuffer {
fn with_capacity(capacity: usize) -> CmsgBuffer {
if capacity <= CMSG_BUFFER_INLINE_CAPACITY {
CmsgBuffer::Inline([0u8; CMSG_BUFFER_INLINE_CAPACITY])
} else {
CmsgBuffer::Heap(vec![0; capacity].into_boxed_slice())
}
}
fn as_mut_ptr(&mut self) -> *mut u8 {
match self {
CmsgBuffer::Inline(a) => a.as_mut_ptr(),
CmsgBuffer::Heap(a) => a.as_mut_ptr(),
}
}
}
fn raw_sendmsg<D: IntoIovec>(fd: RawFd, out_data: D, out_fds: &[RawFd]) -> Result<usize> {
let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * out_fds.len());
let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
let mut iovec = iovec {
iov_base: out_data.as_ptr() as *mut c_void,
iov_len: out_data.size(),
};
let mut msg = msghdr {
msg_name: null_mut(),
msg_namelen: 0,
msg_iov: &mut iovec as *mut iovec,
msg_iovlen: 1,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
};
if !out_fds.is_empty() {
let cmsg = cmsghdr {
cmsg_len: CMSG_LEN!(size_of::<RawFd>() * out_fds.len()),
cmsg_level: SOL_SOCKET,
cmsg_type: SCM_RIGHTS,
};
unsafe {
// Safe because cmsg_buffer was allocated to be large enough to contain cmsghdr.
write_unaligned(cmsg_buffer.as_mut_ptr() as *mut cmsghdr, cmsg);
// Safe because the cmsg_buffer was allocated to be large enough to hold out_fds.len()
// file descriptors.
copy_nonoverlapping(out_fds.as_ptr(),
CMSG_DATA(cmsg_buffer.as_mut_ptr()),
out_fds.len());
}
msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
msg.msg_controllen = cmsg_capacity;
}
// Safe because the msghdr was properly constructed from valid (or null) pointers of the
// indicated length and we check the return value.
let write_count = unsafe { sendmsg(fd, &msg, MSG_NOSIGNAL) };
if write_count == -1 {
Err(Error::last())
} else {
Ok(write_count as usize)
}
}
fn raw_recvmsg(fd: RawFd, in_data: &mut [u8], in_fds: &mut [RawFd]) -> Result<(usize, usize)> {
let cmsg_capacity = CMSG_SPACE!(size_of::<RawFd>() * in_fds.len());
let mut cmsg_buffer = CmsgBuffer::with_capacity(cmsg_capacity);
let mut iovec = iovec {
iov_base: in_data.as_mut_ptr() as *mut c_void,
iov_len: in_data.len(),
};
let mut msg = msghdr {
msg_name: null_mut(),
msg_namelen: 0,
msg_iov: &mut iovec as *mut iovec,
msg_iovlen: 1,
msg_control: null_mut(),
msg_controllen: 0,
msg_flags: 0,
};
if !in_fds.is_empty() {
msg.msg_control = cmsg_buffer.as_mut_ptr() as *mut c_void;
msg.msg_controllen = cmsg_capacity;
}
// Safe because the msghdr was properly constructed from valid (or null) pointers of the
// indicated length and we check the return value.
let total_read = unsafe { recvmsg(fd, &mut msg, 0) };
if total_read == -1 {
return Err(Error::last());
}
if total_read == 0 && msg.msg_controllen < size_of::<cmsghdr>() {
return Ok((0, 0));
}
let mut cmsg_ptr = msg.msg_control as *mut u8;
let mut in_fds_count = 0;
while !cmsg_ptr.is_null() {
// Safe because we checked that cmsg_ptr was non-null, and the loop is constructed such that
// that only happens when there is at least sizeof(cmsghdr) space after the pointer to read.
let cmsg = unsafe { (cmsg_ptr as *mut cmsghdr).read_unaligned() };
if cmsg.cmsg_level == SOL_SOCKET && cmsg.cmsg_type == SCM_RIGHTS {
let fd_count = (cmsg.cmsg_len - CMSG_LEN!(0)) / size_of::<RawFd>();
unsafe {
copy_nonoverlapping(CMSG_DATA(cmsg_ptr),
in_fds[in_fds_count..(in_fds_count + fd_count)].as_mut_ptr(),
fd_count);
}
in_fds_count += fd_count;
}
cmsg_ptr = get_next_cmsg(&msg, &cmsg, cmsg_ptr);
}
Ok((total_read as usize, in_fds_count))
}
/// Trait for file descriptors can send and receive socket control messages via `sendmsg` and
@ -44,6 +203,69 @@ fn cmsg_buffer_len(fd_count: usize) -> usize {
pub trait ScmSocket {
/// Gets the file descriptor of this socket.
fn socket_fd(&self) -> RawFd;
/// Sends the given data and file descriptor over the socket.
///
/// On success, returns the number of bytes sent.
///
/// # Arguments
///
/// * `buf` - A buffer of data to send on the `socket`.
/// * `fd` - A file descriptors to be sent.
fn send_with_fd<D: IntoIovec>(&self, buf: D, fd: RawFd) -> Result<usize> {
self.send_with_fds(buf, &[fd])
}
/// Sends the given data and file descriptors over the socket.
///
/// On success, returns the number of bytes sent.
///
/// # Arguments
///
/// * `buf` - A buffer of data to send on the `socket`.
/// * `fds` - A list of file descriptors to be sent.
fn send_with_fds<D: IntoIovec>(&self, buf: D, fd: &[RawFd]) -> Result<usize> {
raw_sendmsg(self.socket_fd(), buf, fd)
}
/// Receives data and potentially a file descriptor from the socket.
///
/// On success, returns the number of bytes and an optional file descriptor.
///
/// # Arguments
///
/// * `buf` - A buffer to receive data from the socket.vm
fn recv_with_fd(&self, buf: &mut [u8]) -> Result<(usize, Option<File>)> {
let mut fd = [0];
let (read_count, fd_count) = self.recv_with_fds(buf, &mut fd)?;
let file = if fd_count == 0 {
None
} else {
// Safe because the first fd from recv_with_fds is owned by us and valid because this
// branch was taken.
Some(unsafe { File::from_raw_fd(fd[0]) })
};
Ok((read_count, file))
}
/// Receives data and file descriptors from the socket.
///
/// On success, returns the number of bytes and file descriptors received as a tuple
/// `(bytes count, files count)`.
///
/// # Arguments
///
/// * `buf` - A buffer to receive data from the socket.
/// * `fds` - A slice of `RawFd`s to put the received file descriptors into. On success, the
/// number of valid file descriptors is indicated by the second element of the
/// returned tuple. The caller owns these file descriptors, but they will not be
/// closed on drop like a `File`-like type would be. It is recommended that each valid
/// file descriptor gets wrapped in a drop type that closes it after this returns.
fn recv_with_fds(&self, buf: &mut [u8], fds: &mut [RawFd]) -> Result<(usize, usize)> {
raw_recvmsg(self.socket_fd(), buf, fds)
}
}
impl ScmSocket for UnixDatagram {
@ -95,123 +317,6 @@ unsafe impl<'a> IntoIovec for VolatileSlice<'a> {
}
}
/// Used to send and receive messages with file descriptors on sockets that accept control messages
/// (e.g. Unix domain sockets).
pub struct Scm {
cmsg_buffer: Vec<u8>,
vecs: Vec<iovec>,
fds: Vec<RawFd>,
}
impl Scm {
/// Constructs a new Scm object with pre-allocated structures.
///
/// # Arguments
///
/// * `fd_count` - The maximum number of files that can be received per `recv` call.
pub fn new(fd_count: usize) -> Scm {
Scm {
cmsg_buffer: Vec::with_capacity(cmsg_buffer_len(fd_count)),
vecs: Vec::new(),
fds: vec![-1; fd_count],
}
}
/// Sends the given data and file descriptors over the given `socket`.
///
/// On success, returns the number of bytes sent.
///
/// # Arguments
///
/// * `socket` - A socket that supports socket control messages.
/// * `bufs` - A list of buffers to send on the `socket`.
/// * `fds` - A list of file descriptors to be sent.
pub fn send<T: ScmSocket, D: IntoIovec>(&mut self,
socket: &T,
bufs: &[D],
fds: &[RawFd])
-> Result<usize> {
let cmsg_buf_len = cmsg_buffer_len(fds.len());
self.cmsg_buffer.reserve(cmsg_buf_len);
self.vecs.clear();
for ref buf in bufs {
self.vecs
.push(iovec {
iov_base: buf.as_ptr() as *mut c_void,
iov_len: buf.size(),
});
}
let write_count = unsafe {
// Safe because we are giving scm_sendmsg only valid pointers and lengths and we check
// the return value.
self.cmsg_buffer.set_len(cmsg_buf_len);
scm_sendmsg(socket.socket_fd(),
self.vecs.as_ptr(),
self.vecs.len(),
self.cmsg_buffer.as_mut_ptr(),
fds.as_ptr(),
fds.len())
};
if write_count < 0 {
Err(Error::new(-write_count as i32))
} else {
Ok(write_count as usize)
}
}
/// Receives data and file descriptors from the given `socket` into the list of buffers.
///
/// On success, returns the number of bytes received.
///
/// # Arguments
///
/// * `socket` - A socket that supports socket control messages.
/// * `bufs` - A list of buffers to receive data from the `socket`. The `recvmsg` call fills
/// these directly.
/// * `files` - A vector of `File`s to put the received file descriptors into. This vector is
/// not cleared and will have at most `fd_count` (specified in `Scm::new`) `File`s
/// added to it.
pub fn recv<T: ScmSocket>(&mut self,
socket: &T,
bufs: &mut [&mut [u8]],
files: &mut Vec<File>)
-> Result<usize> {
let cmsg_buf_len = cmsg_buffer_len(files.len());
self.cmsg_buffer.reserve(cmsg_buf_len);
self.vecs.clear();
for buf in bufs {
self.vecs
.push(iovec {
iov_base: buf.as_mut_ptr() as *mut c_void,
iov_len: buf.len(),
});
}
let mut fd_count = self.fds.len();
let read_count = unsafe {
// Safe because we are giving scm_recvmsg only valid pointers and lengths and we check
// the return value.
self.cmsg_buffer.set_len(cmsg_buf_len);
scm_recvmsg(socket.socket_fd(),
self.vecs.as_mut_ptr(),
self.vecs.len(),
self.cmsg_buffer.as_mut_ptr(),
self.fds.as_mut_ptr(),
&mut fd_count as *mut usize)
};
if read_count < 0 {
Err(Error::new(-read_count as i32))
} else {
// Safe because we have unqiue ownership of each fd we wrap with File.
for &fd in &self.fds[0..fd_count] {
files.push(unsafe { File::from_raw_fd(fd) });
}
Ok(read_count as usize)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
@ -228,22 +333,22 @@ mod tests {
#[test]
fn buffer_len() {
assert_eq!(cmsg_buffer_len(0), size_of::<cmsghdr>());
assert_eq!(cmsg_buffer_len(1),
assert_eq!(CMSG_SPACE!(0 * size_of::<RawFd>()), size_of::<cmsghdr>());
assert_eq!(CMSG_SPACE!(1 * size_of::<RawFd>()),
size_of::<cmsghdr>() + size_of::<c_long>());
if size_of::<RawFd>() == 4 {
assert_eq!(cmsg_buffer_len(2),
assert_eq!(CMSG_SPACE!(2 * size_of::<RawFd>()),
size_of::<cmsghdr>() + size_of::<c_long>());
assert_eq!(cmsg_buffer_len(3),
assert_eq!(CMSG_SPACE!(3 * size_of::<RawFd>()),
size_of::<cmsghdr>() + size_of::<c_long>() * 2);
assert_eq!(cmsg_buffer_len(4),
assert_eq!(CMSG_SPACE!(4 * size_of::<RawFd>()),
size_of::<cmsghdr>() + size_of::<c_long>() * 2);
} else if size_of::<RawFd>() == 8 {
assert_eq!(cmsg_buffer_len(2),
assert_eq!(CMSG_SPACE!(2 * size_of::<RawFd>()),
size_of::<cmsghdr>() + size_of::<c_long>() * 2);
assert_eq!(cmsg_buffer_len(3),
assert_eq!(CMSG_SPACE!(3 * size_of::<RawFd>()),
size_of::<cmsghdr>() + size_of::<c_long>() * 3);
assert_eq!(cmsg_buffer_len(4),
assert_eq!(CMSG_SPACE!(4 * size_of::<RawFd>()),
size_of::<cmsghdr>() + size_of::<c_long>() * 4);
}
}
@ -252,51 +357,42 @@ mod tests {
fn send_recv_no_fd() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
let write_count = scm.send(&s1,
[[1u8, 1, 2].as_ref(), [21, 34, 55].as_ref()].as_ref(),
&[])
let write_count = s1.send_with_fds([1u8, 1, 2, 21, 34, 55].as_ref(), &[])
.expect("failed to send data");
assert_eq!(write_count, 6);
let mut buf1 = [0; 3];
let mut buf2 = [0; 3];
let mut bufs = [buf1.as_mut(), buf2.as_mut()];
let mut files = Vec::new();
let read_count = scm.recv(&s2, &mut bufs[..], &mut files)
let mut buf = [0; 6];
let mut files = [0; 1];
let (read_count, file_count) = s2.recv_with_fds(&mut buf[..], &mut files)
.expect("failed to recv data");
assert_eq!(read_count, 6);
assert!(files.is_empty());
assert_eq!(bufs[0], [1, 1, 2]);
assert_eq!(bufs[1], [21, 34, 55]);
assert_eq!(file_count, 0);
assert_eq!(buf, [1, 1, 2, 21, 34, 55]);
}
#[test]
fn send_recv_only_fd() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
let evt = EventFd::new().expect("failed to create eventfd");
let write_count = scm.send(&s1, &[[].as_ref()], &[evt.as_raw_fd()])
let write_count = s1.send_with_fd([].as_ref(), evt.as_raw_fd())
.expect("failed to send fd");
assert_eq!(write_count, 0);
let mut files = Vec::new();
let read_count = scm.recv(&s2, &mut [&mut []], &mut files)
.expect("failed to recv fd");
let (read_count, file_opt) = s2.recv_with_fd(&mut []).expect("failed to recv fd");
let mut file = file_opt.unwrap();
assert_eq!(read_count, 0);
assert_eq!(files.len(), 1);
assert!(files[0].as_raw_fd() >= 0);
assert_ne!(files[0].as_raw_fd(), s1.as_raw_fd());
assert_ne!(files[0].as_raw_fd(), s2.as_raw_fd());
assert_ne!(files[0].as_raw_fd(), evt.as_raw_fd());
assert!(file.as_raw_fd() >= 0);
assert_ne!(file.as_raw_fd(), s1.as_raw_fd());
assert_ne!(file.as_raw_fd(), s2.as_raw_fd());
assert_ne!(file.as_raw_fd(), evt.as_raw_fd());
files[0]
.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
.expect("failed to write to sent fd");
assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);
@ -306,28 +402,28 @@ mod tests {
fn send_recv_with_fd() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
let evt = EventFd::new().expect("failed to create eventfd");
let write_count = scm.send(&s1, &[[237].as_ref()], &[evt.as_raw_fd()])
let write_count = s1.send_with_fds([237].as_ref(), &[evt.as_raw_fd()])
.expect("failed to send fd");
assert_eq!(write_count, 1);
let mut files = Vec::new();
let mut files = [0; 2];
let mut buf = [0u8];
let read_count = scm.recv(&s2, &mut [&mut buf], &mut files)
let (read_count, file_count) = s2.recv_with_fds(&mut buf, &mut files)
.expect("failed to recv fd");
assert_eq!(read_count, 1);
assert_eq!(buf[0], 237);
assert_eq!(files.len(), 1);
assert!(files[0].as_raw_fd() >= 0);
assert_ne!(files[0].as_raw_fd(), s1.as_raw_fd());
assert_ne!(files[0].as_raw_fd(), s2.as_raw_fd());
assert_ne!(files[0].as_raw_fd(), evt.as_raw_fd());
assert_eq!(file_count, 1);
assert!(files[0] >= 0);
assert_ne!(files[0], s1.as_raw_fd());
assert_ne!(files[0], s2.as_raw_fd());
assert_ne!(files[0], evt.as_raw_fd());
files[0]
.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
let mut file = unsafe { File::from_raw_fd(files[0]) };
file.write(unsafe { from_raw_parts(&1203u64 as *const u64 as *const u8, 8) })
.expect("failed to write to sent fd");
assert_eq!(evt.read().expect("failed to read from eventfd"), 1203);

View file

@ -27,7 +27,8 @@ use libc::{ERANGE, EINVAL, ENODEV};
use byteorder::{LittleEndian, WriteBytesExt};
use data_model::{DataInit, Le32, Le64, VolatileMemory};
use sys_util::{EventFd, Result, Error as SysError, MmapError, MemoryMapping, Scm, GuestAddress};
use sys_util::{EventFd, Result, Error as SysError, MmapError, MemoryMapping, ScmSocket,
GuestAddress};
use resources::{GpuMemoryDesc, GpuMemoryPlaneDesc, SystemAllocator};
use kvm::{IoeventAddress, Vm};
@ -134,11 +135,10 @@ impl VmRequest {
/// Receive a `VmRequest` from the given socket.
///
/// A `VmResponse` should be sent out over the given socket before another request is received.
pub fn recv(scm: &mut Scm, s: &UnixDatagram) -> VmControlResult<VmRequest> {
pub fn recv(s: &UnixDatagram) -> VmControlResult<VmRequest> {
assert_eq!(VM_REQUEST_SIZE, std::mem::size_of::<VmRequestStruct>());
let mut buf = [0; VM_REQUEST_SIZE];
let mut fds = Vec::new();
let read = scm.recv(s, &mut [&mut buf], &mut fds)
let (read, file) = s.recv_with_fd(&mut buf)
.map_err(|e| VmControlError::Recv(e))?;
if read != VM_REQUEST_SIZE {
return Err(VmControlError::BadSize(read));
@ -150,7 +150,7 @@ impl VmRequest {
match req.type_.into() {
VM_REQUEST_TYPE_EXIT => Ok(VmRequest::Exit),
VM_REQUEST_TYPE_REGISTER_MEMORY => {
let fd = fds.pop().ok_or(VmControlError::ExpectFd)?;
let fd = file.ok_or(VmControlError::ExpectFd)?;
Ok(VmRequest::RegisterMemory(MaybeOwnedFd::Owned(fd),
req.size.to_native() as usize))
}
@ -172,7 +172,7 @@ impl VmRequest {
///
/// After this request is a sent, a `VmResponse` should be received before sending another
/// request.
pub fn send(&self, scm: &mut Scm, s: &UnixDatagram) -> VmControlResult<()> {
pub fn send(&self, s: &UnixDatagram) -> VmControlResult<()> {
assert_eq!(VM_REQUEST_SIZE, std::mem::size_of::<VmRequestStruct>());
let mut req = VmRequestStruct::default();
let mut fd_buf = [0; 1];
@ -203,7 +203,7 @@ impl VmRequest {
}
let mut buf = [0; VM_REQUEST_SIZE];
buf.as_mut().get_ref(0).unwrap().store(req);
scm.send(s, &[buf.as_ref()], &fd_buf[..fd_len])
s.send_with_fds(buf.as_ref(), &fd_buf[..fd_len])
.map_err(|e| VmControlError::Send(e))?;
Ok(())
}
@ -332,10 +332,9 @@ impl VmResponse {
/// Receive a `VmResponse` from the given socket.
///
/// This should be called after the sending a `VmRequest` before sending another request.
pub fn recv(scm: &mut Scm, s: &UnixDatagram) -> VmControlResult<VmResponse> {
pub fn recv(s: &UnixDatagram) -> VmControlResult<VmResponse> {
let mut buf = [0; VM_RESPONSE_SIZE];
let mut fds = Vec::new();
let read = scm.recv(s, &mut [&mut buf], &mut fds)
let (read, file) = s.recv_with_fd(&mut buf)
.map_err(|e| VmControlError::Recv(e))?;
if read != VM_RESPONSE_SIZE {
return Err(VmControlError::BadSize(read));
@ -354,7 +353,7 @@ impl VmResponse {
})
}
VM_RESPONSE_TYPE_ALLOCATE_AND_REGISTER_GPU_MEMORY => {
let fd = fds.pop().ok_or(VmControlError::ExpectFd)?;
let fd = file.ok_or(VmControlError::ExpectFd)?;
Ok(VmResponse::AllocateAndRegisterGpuMemory {
fd: MaybeOwnedFd::Owned(fd),
pfn: resp.pfn.into(),
@ -377,7 +376,7 @@ impl VmResponse {
///
/// This must be called after receiving a `VmRequest` to indicate the outcome of that request's
/// execution.
pub fn send(&self, scm: &mut Scm, s: &UnixDatagram) -> VmControlResult<()> {
pub fn send(&self, s: &UnixDatagram) -> VmControlResult<()> {
let mut resp = VmResponseStruct::default();
let mut fd_buf = [0; 1];
let mut fd_len = 0;
@ -408,7 +407,7 @@ impl VmResponse {
}
let mut buf = [0; VM_RESPONSE_SIZE];
buf.as_mut().get_ref(0).unwrap().store(resp);
scm.send(s, &[buf.as_ref()], &fd_buf[..fd_len])
s.send_with_fds(buf.as_ref(), &fd_buf[..fd_len])
.map_err(|e| VmControlError::Send(e))?;
Ok(())
}
@ -427,9 +426,8 @@ mod tests {
#[test]
fn request_exit() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
VmRequest::Exit.send(&mut scm, &s1).unwrap();
match VmRequest::recv(&mut scm, &s2).unwrap() {
VmRequest::Exit.send(&s1).unwrap();
match VmRequest::recv(&s2).unwrap() {
VmRequest::Exit => {}
_ => panic!("recv wrong request variant"),
}
@ -439,14 +437,13 @@ mod tests {
fn request_register_memory() {
if !kernel_has_memfd() { return; }
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
let shm_size: usize = 4096;
let mut shm = SharedMemory::new(None).unwrap();
shm.set_size(shm_size as u64).unwrap();
VmRequest::RegisterMemory(MaybeOwnedFd::Borrowed(shm.as_raw_fd()), shm_size)
.send(&mut scm, &s1)
.send(&s1)
.unwrap();
match VmRequest::recv(&mut scm, &s2).unwrap() {
match VmRequest::recv(&s2).unwrap() {
VmRequest::RegisterMemory(MaybeOwnedFd::Owned(fd), size) => {
assert!(fd.as_raw_fd() >= 0);
assert_eq!(size, shm_size);
@ -458,11 +455,8 @@ mod tests {
#[test]
fn request_unregister_memory() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
VmRequest::UnregisterMemory(77)
.send(&mut scm, &s1)
.unwrap();
match VmRequest::recv(&mut scm, &s2).unwrap() {
VmRequest::UnregisterMemory(77).send(&s1).unwrap();
match VmRequest::recv(&s2).unwrap() {
VmRequest::UnregisterMemory(slot) => assert_eq!(slot, 77),
_ => panic!("recv wrong request variant"),
}
@ -471,11 +465,10 @@ mod tests {
#[test]
fn request_expect_fd() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
let mut bad_request = [0; VM_REQUEST_SIZE];
bad_request[0] = VM_REQUEST_TYPE_REGISTER_MEMORY as u8;
scm.send(&s2, &[bad_request.as_ref()], &[]).unwrap();
match VmRequest::recv(&mut scm, &s1) {
s2.send_with_fds(bad_request.as_ref(), &[]).unwrap();
match VmRequest::recv(&s1) {
Err(VmControlError::ExpectFd) => {}
_ => panic!("recv wrong error variant"),
}
@ -484,9 +477,8 @@ mod tests {
#[test]
fn request_no_data() {
let (s1, _) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
s1.shutdown(Shutdown::Both).unwrap();
match VmRequest::recv(&mut scm, &s1) {
match VmRequest::recv(&s1) {
Err(VmControlError::BadSize(s)) => assert_eq!(s, 0),
_ => panic!("recv wrong error variant"),
}
@ -495,9 +487,8 @@ mod tests {
#[test]
fn request_bad_size() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
scm.send(&s2, &[[12; 7].as_ref()], &[]).unwrap();
match VmRequest::recv(&mut scm, &s1) {
s2.send_with_fds([12; 7].as_ref(), &[]).unwrap();
match VmRequest::recv(&s1) {
Err(VmControlError::BadSize(_)) => {}
_ => panic!("recv wrong error variant"),
}
@ -506,10 +497,9 @@ mod tests {
#[test]
fn request_invalid_type() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
scm.send(&s2, &[[12; VM_REQUEST_SIZE].as_ref()], &[])
s2.send_with_fds([12; VM_REQUEST_SIZE].as_ref(), &[])
.unwrap();
match VmRequest::recv(&mut scm, &s1) {
match VmRequest::recv(&s1) {
Err(VmControlError::InvalidType) => {}
_ => panic!("recv wrong error variant"),
}
@ -518,14 +508,21 @@ mod tests {
#[test]
fn request_allocate_and_register_gpu_memory() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
let gpu_width: u32 = 32;
let gpu_height: u32 = 32;
let gpu_format: u32 = 0x34325258;
let r = VmRequest::AllocateAndRegisterGpuMemory { width: gpu_width, height: gpu_height, format: gpu_format };
r.send(&mut scm, &s1).unwrap();
match VmRequest::recv(&mut scm, &s2).unwrap() {
VmRequest::AllocateAndRegisterGpuMemory {width, height, format} => {
let r = VmRequest::AllocateAndRegisterGpuMemory {
width: gpu_width,
height: gpu_height,
format: gpu_format,
};
r.send(&s1).unwrap();
match VmRequest::recv(&s2).unwrap() {
VmRequest::AllocateAndRegisterGpuMemory {
width,
height,
format,
} => {
assert_eq!(width, gpu_width);
assert_eq!(height, gpu_width);
assert_eq!(format, gpu_format);
@ -537,9 +534,8 @@ mod tests {
#[test]
fn resp_ok() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
VmResponse::Ok.send(&mut scm, &s1).unwrap();
match VmResponse::recv(&mut scm, &s2).unwrap() {
VmResponse::Ok.send(&s1).unwrap();
match VmResponse::recv(&s2).unwrap() {
VmResponse::Ok => {}
_ => panic!("recv wrong response variant"),
}
@ -548,10 +544,9 @@ mod tests {
#[test]
fn resp_err() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
let r1 = VmResponse::Err(SysError::new(libc::EDESTADDRREQ));
r1.send(&mut scm, &s1).unwrap();
match VmResponse::recv(&mut scm, &s2).unwrap() {
r1.send(&s1).unwrap();
match VmResponse::recv(&s2).unwrap() {
VmResponse::Err(e) => {
assert_eq!(e, SysError::new(libc::EDESTADDRREQ));
}
@ -562,12 +557,14 @@ mod tests {
#[test]
fn resp_memory() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
let memory_pfn = 55;
let memory_slot = 66;
let r1 = VmResponse::RegisterMemory { pfn: memory_pfn, slot: memory_slot };
r1.send(&mut scm, &s1).unwrap();
match VmResponse::recv(&mut scm, &s2).unwrap() {
let r1 = VmResponse::RegisterMemory {
pfn: memory_pfn,
slot: memory_slot,
};
r1.send(&s1).unwrap();
match VmResponse::recv(&s2).unwrap() {
VmResponse::RegisterMemory { pfn, slot } => {
assert_eq!(pfn, memory_pfn);
assert_eq!(slot, memory_slot);
@ -579,9 +576,8 @@ mod tests {
#[test]
fn resp_no_data() {
let (s1, _) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
s1.shutdown(Shutdown::Both).unwrap();
match VmResponse::recv(&mut scm, &s1) {
match VmResponse::recv(&s1) {
Err(e) => {
assert_eq!(e, VmControlError::BadSize(0));
}
@ -592,9 +588,8 @@ mod tests {
#[test]
fn resp_bad_size() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
scm.send(&s2, &[[12; 7].as_ref()], &[]).unwrap();
match VmResponse::recv(&mut scm, &s1) {
s2.send_with_fds([12; 7].as_ref(), &[]).unwrap();
match VmResponse::recv(&s1) {
Err(e) => {
assert_eq!(e, VmControlError::BadSize(7));
}
@ -605,10 +600,9 @@ mod tests {
#[test]
fn resp_invalid_type() {
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
scm.send(&s2, &[[12; VM_RESPONSE_SIZE].as_ref()], &[])
s2.send_with_fds([12; VM_RESPONSE_SIZE].as_ref(), &[])
.unwrap();
match VmResponse::recv(&mut scm, &s1) {
match VmResponse::recv(&s1) {
Err(e) => {
assert_eq!(e, VmControlError::InvalidType);
}
@ -620,7 +614,6 @@ mod tests {
fn resp_allocate_and_register_gpu_memory() {
if !kernel_has_memfd() { return; }
let (s1, s2) = UnixDatagram::pair().expect("failed to create socket pair");
let mut scm = Scm::new(1);
let shm_size: usize = 4096;
let mut shm = SharedMemory::new(None).unwrap();
shm.set_size(shm_size as u64).unwrap();
@ -637,9 +630,14 @@ mod tests {
slot: memory_slot,
desc: GpuMemoryDesc { planes: memory_planes },
};
r1.send(&mut scm, &s1).unwrap();
match VmResponse::recv(&mut scm, &s2).unwrap() {
VmResponse::AllocateAndRegisterGpuMemory { fd, pfn, slot, desc } => {
r1.send(&s1).unwrap();
match VmResponse::recv(&s2).unwrap() {
VmResponse::AllocateAndRegisterGpuMemory {
fd,
pfn,
slot,
desc,
} => {
assert!(fd.as_raw_fd() >= 0);
assert_eq!(pfn, memory_pfn);
assert_eq!(slot, memory_slot);