vhost: Don't require GuestMemory in ::new()

GuestMemory is only needed for the set_mem_table and set_vring_addr
methods so take it in as a parameter there rather than storing it in the
struct.  Vhost-user devices don't have access to GuestMemory when the
vhost device is first constructed.

BUG=b:179756331
TEST=unit tests

Change-Id: Id446db43777c26b0dfbe8b37366f2da93de53b23
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/3153211
Tested-by: kokoro <noreply+kokoro@google.com>
Commit-Queue: Chirantan Ekbote <chirantan@chromium.org>
Reviewed-by: Daniel Verkamp <dverkamp@chromium.org>
Reviewed-by: Keiichi Watanabe <keiichiw@chromium.org>
This commit is contained in:
Chirantan Ekbote 2021-09-10 18:27:16 +09:00 committed by Commit Bot
parent eb1640e301
commit 3e8d52b802
7 changed files with 65 additions and 92 deletions

View file

@ -50,7 +50,6 @@ where
ip_addr: Ipv4Addr,
netmask: Ipv4Addr,
mac_addr: MacAddress,
mem: &GuestMemory,
) -> Result<Net<T, U>> {
let kill_evt = Event::new().map_err(Error::CreateKillEvent)?;
@ -72,7 +71,7 @@ where
.map_err(Error::TapSetVnetHdrSize)?;
tap.enable().map_err(Error::TapEnable)?;
let vhost_net_handle = U::new(vhost_net_device_path, mem).map_err(Error::VhostOpen)?;
let vhost_net_handle = U::new(vhost_net_device_path).map_err(Error::VhostOpen)?;
let avail_features = base_features
| 1 << virtio_net::VIRTIO_NET_F_GUEST_CSUM
@ -190,7 +189,7 @@ where
fn activate(
&mut self,
_: GuestMemory,
mem: GuestMemory,
interrupt: Interrupt,
queues: Vec<Queue>,
queue_evts: Vec<Event>,
@ -238,8 +237,13 @@ where
}
Ok(())
};
let result =
worker.run(queue_evts, QUEUE_SIZES, activate_vqs, cleanup_vqs);
let result = worker.run(
mem,
queue_evts,
QUEUE_SIZES,
activate_vqs,
cleanup_vqs,
);
if let Err(e) = result {
error!("net worker thread exited with error: {}", e);
}
@ -369,7 +373,6 @@ pub mod tests {
}
fn create_net_common() -> Net<FakeTap, FakeNet<FakeTap>> {
let guest_memory = create_guest_memory().unwrap();
let features = base_features(ProtectionType::Unprotected);
Net::<FakeTap, FakeNet<FakeTap>>::new(
&PathBuf::from(""),
@ -377,7 +380,6 @@ pub mod tests {
Ipv4Addr::new(127, 0, 0, 1),
Ipv4Addr::new(255, 255, 255, 0),
"de:21:e8:47:6b:6a".parse().unwrap(),
&guest_memory,
)
.unwrap()
}

View file

@ -31,15 +31,9 @@ pub struct Vsock {
impl Vsock {
/// Create a new virtio-vsock device with the given VM cid.
pub fn new(
vhost_vsock_device_path: &Path,
base_features: u64,
cid: u64,
mem: &GuestMemory,
) -> Result<Vsock> {
pub fn new(vhost_vsock_device_path: &Path, base_features: u64, cid: u64) -> Result<Vsock> {
let kill_evt = Event::new().map_err(Error::CreateKillEvent)?;
let handle =
VhostVsockHandle::new(vhost_vsock_device_path, mem).map_err(Error::VhostOpen)?;
let handle = VhostVsockHandle::new(vhost_vsock_device_path).map_err(Error::VhostOpen)?;
let avail_features = base_features
| 1 << virtio_sys::vhost::VIRTIO_F_NOTIFY_ON_EMPTY
@ -147,7 +141,7 @@ impl VirtioDevice for Vsock {
fn activate(
&mut self,
_: GuestMemory,
mem: GuestMemory,
interrupt: Interrupt,
queues: Vec<Queue>,
queue_evts: Vec<Event>,
@ -184,7 +178,7 @@ impl VirtioDevice for Vsock {
};
let cleanup_vqs = |_handle: &VhostVsockHandle| -> Result<()> { Ok(()) };
let result =
worker.run(queue_evts, QUEUE_SIZES, activate_vqs, cleanup_vqs);
worker.run(mem, queue_evts, QUEUE_SIZES, activate_vqs, cleanup_vqs);
if let Err(e) = result {
error!("vsock worker thread exited with error: {:?}", e);
}

View file

@ -6,6 +6,7 @@ use std::os::raw::c_ulonglong;
use base::{error, Error as SysError, Event, PollToken, Tube, WaitContext};
use vhost::Vhost;
use vm_memory::GuestMemory;
use super::control_socket::{VhostDevRequest, VhostDevResponse};
use super::{Error, Result};
@ -46,6 +47,7 @@ impl<T: Vhost> Worker<T> {
pub fn run<F1, F2>(
&mut self,
mem: GuestMemory,
queue_evts: Vec<Event>,
queue_sizes: &[u16],
activate_vqs: F1,
@ -66,7 +68,7 @@ impl<T: Vhost> Worker<T> {
.map_err(Error::VhostSetFeatures)?;
self.vhost_handle
.set_mem_table()
.set_mem_table(&mem)
.map_err(Error::VhostSetMemTable)?;
for (queue_index, queue) in self.queues.iter().enumerate() {
@ -76,6 +78,7 @@ impl<T: Vhost> Worker<T> {
self.vhost_handle
.set_vring_addr(
&mem,
queue_sizes[queue_index],
queue.actual_size(),
queue_index,

View file

@ -594,7 +594,6 @@ fn create_net_device(
host_ip: Ipv4Addr,
netmask: Ipv4Addr,
mac_address: MacAddress,
mem: &GuestMemory,
) -> DeviceResult {
let mut vq_pairs = cfg.net_vq_pairs.unwrap_or(1);
let vcpu_count = cfg.vcpu_count.unwrap_or(1);
@ -611,7 +610,6 @@ fn create_net_device(
host_ip,
netmask,
mac_address,
mem,
)
.map_err(Error::VhostNetDeviceNew)?;
Box::new(dev) as Box<dyn VirtioDevice>
@ -941,9 +939,9 @@ fn register_video_device(
Ok(())
}
fn create_vhost_vsock_device(cfg: &Config, cid: u64, mem: &GuestMemory) -> DeviceResult {
fn create_vhost_vsock_device(cfg: &Config, cid: u64) -> DeviceResult {
let features = virtio::base_features(cfg.protected_vm);
let dev = virtio::vhost::Vsock::new(&cfg.vhost_vsock_device_path, features, cid, mem)
let dev = virtio::vhost::Vsock::new(&cfg.vhost_vsock_device_path, features, cid)
.map_err(Error::VhostVsockDeviceNew)?;
Ok(VirtioDeviceStub {
@ -1319,13 +1317,7 @@ fn create_virtio_devices(
if !cfg.vhost_user_net.is_empty() {
return Err(Error::VhostUserNetWithNetArgs);
}
devs.push(create_net_device(
cfg,
host_ip,
netmask,
mac_address,
vm.get_memory(),
)?);
devs.push(create_net_device(cfg, host_ip, netmask, mac_address)?);
}
for net in &cfg.vhost_user_net {
@ -1469,7 +1461,7 @@ fn create_virtio_devices(
}
if let Some(cid) = cfg.cid {
devs.push(create_vhost_vsock_device(cfg, cid, vm.get_memory())?);
devs.push(create_vhost_vsock_device(cfg, cid)?);
}
for vhost_user_fs in &cfg.vhost_user_fs {

View file

@ -12,7 +12,6 @@ pub use crate::vsock::Vsock;
use std::alloc::Layout;
use std::fmt::{self, Display};
use std::io::Error as IoError;
use std::mem;
use std::ptr::null;
use assertions::const_assert;
@ -64,9 +63,6 @@ fn ioctl_result<T>() -> Result<T> {
/// transfer. The device itself only needs to deal with setting up the kernel driver and
/// managing the control channel.
pub trait Vhost: AsRawDescriptor + std::marker::Sized {
/// Get the guest memory mapping.
fn mem(&self) -> &GuestMemory;
/// Set the current process as the owner of this file descriptor.
/// This must be run before any other vhost ioctls.
fn set_owner(&self) -> Result<()> {
@ -109,14 +105,14 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized {
}
/// Set the guest memory mappings for vhost to use.
fn set_mem_table(&self) -> Result<()> {
const SIZE_OF_MEMORY: usize = mem::size_of::<virtio_sys::vhost_memory>();
const SIZE_OF_REGION: usize = mem::size_of::<virtio_sys::vhost_memory_region>();
const ALIGN_OF_MEMORY: usize = mem::align_of::<virtio_sys::vhost_memory>();
const ALIGN_OF_REGION: usize = mem::align_of::<virtio_sys::vhost_memory_region>();
fn set_mem_table(&self, mem: &GuestMemory) -> Result<()> {
const SIZE_OF_MEMORY: usize = std::mem::size_of::<virtio_sys::vhost_memory>();
const SIZE_OF_REGION: usize = std::mem::size_of::<virtio_sys::vhost_memory_region>();
const ALIGN_OF_MEMORY: usize = std::mem::align_of::<virtio_sys::vhost_memory>();
const ALIGN_OF_REGION: usize = std::mem::align_of::<virtio_sys::vhost_memory_region>();
const_assert!(ALIGN_OF_MEMORY >= ALIGN_OF_REGION);
let num_regions = self.mem().num_regions() as usize;
let num_regions = mem.num_regions() as usize;
let size = SIZE_OF_MEMORY + num_regions * SIZE_OF_REGION;
let layout = Layout::from_size_align(size, ALIGN_OF_MEMORY).expect("impossible layout");
let mut allocation = LayoutAllocation::zeroed(layout);
@ -130,17 +126,15 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized {
// we correctly specify the size to match the amount of backing memory.
let vhost_regions = unsafe { vhost_memory.regions.as_mut_slice(num_regions as usize) };
let _ = self
.mem()
.with_regions::<_, ()>(|index, guest_addr, size, host_addr, _, _| {
vhost_regions[index] = virtio_sys::vhost_memory_region {
guest_phys_addr: guest_addr.offset() as u64,
memory_size: size as u64,
userspace_addr: host_addr as u64,
flags_padding: 0u64,
};
Ok(())
});
let _ = mem.with_regions::<_, ()>(|index, guest_addr, size, host_addr, _, _| {
vhost_regions[index] = virtio_sys::vhost_memory_region {
guest_phys_addr: guest_addr.offset() as u64,
memory_size: size as u64,
userspace_addr: host_addr as u64,
flags_padding: 0u64,
};
Ok(())
});
// This ioctl is called with a pointer that is valid for the lifetime
// of this function. The kernel will make its own copy of the memory
@ -179,6 +173,7 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized {
#[allow(clippy::if_same_then_else)]
fn is_valid(
&self,
mem: &GuestMemory,
queue_max_size: u16,
queue_size: u16,
desc_addr: GuestAddress,
@ -192,17 +187,17 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized {
false
} else if desc_addr
.checked_add(desc_table_size as u64)
.map_or(true, |v| !self.mem().address_in_range(v))
.map_or(true, |v| !mem.address_in_range(v))
{
false
} else if avail_addr
.checked_add(avail_ring_size as u64)
.map_or(true, |v| !self.mem().address_in_range(v))
.map_or(true, |v| !mem.address_in_range(v))
{
false
} else if used_addr
.checked_add(used_ring_size as u64)
.map_or(true, |v| !self.mem().address_in_range(v))
.map_or(true, |v| !mem.address_in_range(v))
{
false
} else {
@ -223,6 +218,7 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized {
/// * `log_addr` - Optional address for logging.
fn set_vring_addr(
&self,
mem: &GuestMemory,
queue_max_size: u16,
queue_size: u16,
queue_index: usize,
@ -234,25 +230,29 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized {
) -> Result<()> {
// TODO(smbarber): Refactor out virtio from crosvm so we can
// validate a Queue struct directly.
if !self.is_valid(queue_max_size, queue_size, desc_addr, used_addr, avail_addr) {
if !self.is_valid(
mem,
queue_max_size,
queue_size,
desc_addr,
used_addr,
avail_addr,
) {
return Err(Error::InvalidQueue);
}
let desc_addr = self
.mem()
let desc_addr = mem
.get_host_address(desc_addr)
.map_err(Error::DescriptorTableAddress)?;
let used_addr = self
.mem()
let used_addr = mem
.get_host_address(used_addr)
.map_err(Error::UsedAddress)?;
let avail_addr = self
.mem()
let avail_addr = mem
.get_host_address(avail_addr)
.map_err(Error::AvailAddress)?;
let log_addr = match log_addr {
None => null(),
Some(a) => self.mem().get_host_address(a).map_err(Error::LogAddress)?,
Some(a) => mem.get_host_address(a).map_err(Error::LogAddress)?,
};
let vring_addr = virtio_sys::vhost_vring_addr {
@ -360,8 +360,7 @@ mod tests {
}
fn create_fake_vhost_net() -> FakeNet<FakeTap> {
let gm = create_guest_memory().unwrap();
FakeNet::<FakeTap>::new(&PathBuf::from(""), &gm).unwrap()
FakeNet::<FakeTap>::new(&PathBuf::from("")).unwrap()
}
#[test]
@ -393,7 +392,8 @@ mod tests {
#[test]
fn set_mem_table() {
let vhost_net = create_fake_vhost_net();
let res = vhost_net.set_mem_table();
let gm = create_guest_memory().unwrap();
let res = vhost_net.set_mem_table(&gm);
assert_ok_or_known_failure(res);
}
@ -407,7 +407,9 @@ mod tests {
#[test]
fn set_vring_addr() {
let vhost_net = create_fake_vhost_net();
let gm = create_guest_memory().unwrap();
let res = vhost_net.set_vring_addr(
&gm,
1,
1,
0,

View file

@ -11,7 +11,6 @@ use std::{
};
use base::{ioctl_with_ref, AsRawDescriptor, RawDescriptor};
use vm_memory::GuestMemory;
use super::{ioctl_result, Error, Result, Vhost};
@ -23,13 +22,12 @@ pub struct Net<T> {
// descriptor must be dropped first, which will stop and tear down the
// vhost-net worker before GuestMemory can potentially be unmapped.
descriptor: File,
mem: GuestMemory,
phantom: PhantomData<T>,
}
pub trait NetT<T: TapT>: Vhost + AsRawDescriptor + Send + Sized {
/// Create a new NetT instance
fn new(vhost_net_device_path: &Path, mem: &GuestMemory) -> Result<Self>;
fn new(vhost_net_device_path: &Path) -> Result<Self>;
/// Set the tap file descriptor that will serve as the VHOST_NET backend.
/// This will start the vhost worker for the given queue.
@ -48,7 +46,7 @@ where
///
/// # Arguments
/// * `vhost_net_device_path` - Path to the vhost-net device to open.
fn new(vhost_net_device_path: &Path, mem: &GuestMemory) -> Result<Net<T>> {
fn new(vhost_net_device_path: &Path) -> Result<Net<T>> {
Ok(Net::<T> {
descriptor: OpenOptions::new()
.read(true)
@ -56,7 +54,6 @@ where
.custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK)
.open(vhost_net_device_path)
.map_err(Error::VhostOpen)?,
mem: mem.clone(),
phantom: PhantomData,
})
}
@ -83,11 +80,7 @@ where
}
}
impl<T> Vhost for Net<T> {
fn mem(&self) -> &GuestMemory {
&self.mem
}
}
impl<T> Vhost for Net<T> {}
impl<T> AsRawDescriptor for Net<T> {
fn as_raw_descriptor(&self) -> RawDescriptor {
@ -104,7 +97,6 @@ pub mod fakes {
pub struct FakeNet<T> {
descriptor: File,
mem: GuestMemory,
phantom: PhantomData<T>,
}
@ -118,7 +110,7 @@ pub mod fakes {
where
T: TapT,
{
fn new(_vhost_net_device_path: &Path, mem: &GuestMemory) -> Result<FakeNet<T>> {
fn new(_vhost_net_device_path: &Path) -> Result<FakeNet<T>> {
Ok(FakeNet::<T> {
descriptor: OpenOptions::new()
.read(true)
@ -126,7 +118,6 @@ pub mod fakes {
.create(true)
.open(TMP_FILE)
.unwrap(),
mem: mem.clone(),
phantom: PhantomData,
})
}
@ -136,11 +127,7 @@ pub mod fakes {
}
}
impl<T> Vhost for FakeNet<T> {
fn mem(&self) -> &GuestMemory {
&self.mem
}
}
impl<T> Vhost for FakeNet<T> {}
impl<T> AsRawDescriptor for FakeNet<T> {
fn as_raw_descriptor(&self) -> RawDescriptor {

View file

@ -10,19 +10,17 @@ use std::{
use base::{ioctl_with_ref, AsRawDescriptor, RawDescriptor};
use virtio_sys::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING};
use vm_memory::GuestMemory;
use super::{ioctl_result, Error, Result, Vhost};
/// Handle for running VHOST_VSOCK ioctls.
pub struct Vsock {
descriptor: File,
mem: GuestMemory,
}
impl Vsock {
/// Open a handle to a new VHOST_VSOCK instance.
pub fn new(vhost_vsock_device_path: &Path, mem: &GuestMemory) -> Result<Vsock> {
pub fn new<P: AsRef<Path>>(vhost_vsock_device_path: P) -> Result<Vsock> {
Ok(Vsock {
descriptor: OpenOptions::new()
.read(true)
@ -30,7 +28,6 @@ impl Vsock {
.custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK)
.open(vhost_vsock_device_path)
.map_err(Error::VhostOpen)?,
mem: mem.clone(),
})
}
@ -69,11 +66,7 @@ impl Vsock {
}
}
impl Vhost for Vsock {
fn mem(&self) -> &GuestMemory {
&self.mem
}
}
impl Vhost for Vsock {}
impl AsRawDescriptor for Vsock {
fn as_raw_descriptor(&self) -> RawDescriptor {