From 3e8d52b802d37db89b9aac86f0fbbac31cd66668 Mon Sep 17 00:00:00 2001 From: Chirantan Ekbote Date: Fri, 10 Sep 2021 18:27:16 +0900 Subject: [PATCH] vhost: Don't require GuestMemory in ::new() GuestMemory is only needed for the set_mem_table and set_vring_addr methods so take it in as a parameter there rather than storing it in the struct. Vhost-user devices don't have access to GuestMemory when the vhost device is first constructed. BUG=b:179756331 TEST=unit tests Change-Id: Id446db43777c26b0dfbe8b37366f2da93de53b23 Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/3153211 Tested-by: kokoro Commit-Queue: Chirantan Ekbote Reviewed-by: Daniel Verkamp Reviewed-by: Keiichi Watanabe --- devices/src/virtio/vhost/net.rs | 16 ++++--- devices/src/virtio/vhost/vsock.rs | 14 ++---- devices/src/virtio/vhost/worker.rs | 5 ++- src/linux.rs | 16 ++----- vhost/src/lib.rs | 72 +++++++++++++++--------------- vhost/src/net.rs | 23 +++------- vhost/src/vsock.rs | 11 +---- 7 files changed, 65 insertions(+), 92 deletions(-) diff --git a/devices/src/virtio/vhost/net.rs b/devices/src/virtio/vhost/net.rs index 5802f64a7e..97eb5460c4 100644 --- a/devices/src/virtio/vhost/net.rs +++ b/devices/src/virtio/vhost/net.rs @@ -50,7 +50,6 @@ where ip_addr: Ipv4Addr, netmask: Ipv4Addr, mac_addr: MacAddress, - mem: &GuestMemory, ) -> Result> { let kill_evt = Event::new().map_err(Error::CreateKillEvent)?; @@ -72,7 +71,7 @@ where .map_err(Error::TapSetVnetHdrSize)?; tap.enable().map_err(Error::TapEnable)?; - let vhost_net_handle = U::new(vhost_net_device_path, mem).map_err(Error::VhostOpen)?; + let vhost_net_handle = U::new(vhost_net_device_path).map_err(Error::VhostOpen)?; let avail_features = base_features | 1 << virtio_net::VIRTIO_NET_F_GUEST_CSUM @@ -190,7 +189,7 @@ where fn activate( &mut self, - _: GuestMemory, + mem: GuestMemory, interrupt: Interrupt, queues: Vec, queue_evts: Vec, @@ -238,8 +237,13 @@ where } Ok(()) }; - let result = - worker.run(queue_evts, QUEUE_SIZES, activate_vqs, cleanup_vqs); + let result = worker.run( + mem, + queue_evts, + QUEUE_SIZES, + activate_vqs, + cleanup_vqs, + ); if let Err(e) = result { error!("net worker thread exited with error: {}", e); } @@ -369,7 +373,6 @@ pub mod tests { } fn create_net_common() -> Net> { - let guest_memory = create_guest_memory().unwrap(); let features = base_features(ProtectionType::Unprotected); Net::>::new( &PathBuf::from(""), @@ -377,7 +380,6 @@ pub mod tests { Ipv4Addr::new(127, 0, 0, 1), Ipv4Addr::new(255, 255, 255, 0), "de:21:e8:47:6b:6a".parse().unwrap(), - &guest_memory, ) .unwrap() } diff --git a/devices/src/virtio/vhost/vsock.rs b/devices/src/virtio/vhost/vsock.rs index c825529022..21ce139436 100644 --- a/devices/src/virtio/vhost/vsock.rs +++ b/devices/src/virtio/vhost/vsock.rs @@ -31,15 +31,9 @@ pub struct Vsock { impl Vsock { /// Create a new virtio-vsock device with the given VM cid. - pub fn new( - vhost_vsock_device_path: &Path, - base_features: u64, - cid: u64, - mem: &GuestMemory, - ) -> Result { + pub fn new(vhost_vsock_device_path: &Path, base_features: u64, cid: u64) -> Result { let kill_evt = Event::new().map_err(Error::CreateKillEvent)?; - let handle = - VhostVsockHandle::new(vhost_vsock_device_path, mem).map_err(Error::VhostOpen)?; + let handle = VhostVsockHandle::new(vhost_vsock_device_path).map_err(Error::VhostOpen)?; let avail_features = base_features | 1 << virtio_sys::vhost::VIRTIO_F_NOTIFY_ON_EMPTY @@ -147,7 +141,7 @@ impl VirtioDevice for Vsock { fn activate( &mut self, - _: GuestMemory, + mem: GuestMemory, interrupt: Interrupt, queues: Vec, queue_evts: Vec, @@ -184,7 +178,7 @@ impl VirtioDevice for Vsock { }; let cleanup_vqs = |_handle: &VhostVsockHandle| -> Result<()> { Ok(()) }; let result = - worker.run(queue_evts, QUEUE_SIZES, activate_vqs, cleanup_vqs); + worker.run(mem, queue_evts, QUEUE_SIZES, activate_vqs, cleanup_vqs); if let Err(e) = result { error!("vsock worker thread exited with error: {:?}", e); } diff --git a/devices/src/virtio/vhost/worker.rs b/devices/src/virtio/vhost/worker.rs index 9d7822ebfa..222ca4e40c 100644 --- a/devices/src/virtio/vhost/worker.rs +++ b/devices/src/virtio/vhost/worker.rs @@ -6,6 +6,7 @@ use std::os::raw::c_ulonglong; use base::{error, Error as SysError, Event, PollToken, Tube, WaitContext}; use vhost::Vhost; +use vm_memory::GuestMemory; use super::control_socket::{VhostDevRequest, VhostDevResponse}; use super::{Error, Result}; @@ -46,6 +47,7 @@ impl Worker { pub fn run( &mut self, + mem: GuestMemory, queue_evts: Vec, queue_sizes: &[u16], activate_vqs: F1, @@ -66,7 +68,7 @@ impl Worker { .map_err(Error::VhostSetFeatures)?; self.vhost_handle - .set_mem_table() + .set_mem_table(&mem) .map_err(Error::VhostSetMemTable)?; for (queue_index, queue) in self.queues.iter().enumerate() { @@ -76,6 +78,7 @@ impl Worker { self.vhost_handle .set_vring_addr( + &mem, queue_sizes[queue_index], queue.actual_size(), queue_index, diff --git a/src/linux.rs b/src/linux.rs index eff84f26c5..4c6e457a14 100644 --- a/src/linux.rs +++ b/src/linux.rs @@ -594,7 +594,6 @@ fn create_net_device( host_ip: Ipv4Addr, netmask: Ipv4Addr, mac_address: MacAddress, - mem: &GuestMemory, ) -> DeviceResult { let mut vq_pairs = cfg.net_vq_pairs.unwrap_or(1); let vcpu_count = cfg.vcpu_count.unwrap_or(1); @@ -611,7 +610,6 @@ fn create_net_device( host_ip, netmask, mac_address, - mem, ) .map_err(Error::VhostNetDeviceNew)?; Box::new(dev) as Box @@ -941,9 +939,9 @@ fn register_video_device( Ok(()) } -fn create_vhost_vsock_device(cfg: &Config, cid: u64, mem: &GuestMemory) -> DeviceResult { +fn create_vhost_vsock_device(cfg: &Config, cid: u64) -> DeviceResult { let features = virtio::base_features(cfg.protected_vm); - let dev = virtio::vhost::Vsock::new(&cfg.vhost_vsock_device_path, features, cid, mem) + let dev = virtio::vhost::Vsock::new(&cfg.vhost_vsock_device_path, features, cid) .map_err(Error::VhostVsockDeviceNew)?; Ok(VirtioDeviceStub { @@ -1319,13 +1317,7 @@ fn create_virtio_devices( if !cfg.vhost_user_net.is_empty() { return Err(Error::VhostUserNetWithNetArgs); } - devs.push(create_net_device( - cfg, - host_ip, - netmask, - mac_address, - vm.get_memory(), - )?); + devs.push(create_net_device(cfg, host_ip, netmask, mac_address)?); } for net in &cfg.vhost_user_net { @@ -1469,7 +1461,7 @@ fn create_virtio_devices( } if let Some(cid) = cfg.cid { - devs.push(create_vhost_vsock_device(cfg, cid, vm.get_memory())?); + devs.push(create_vhost_vsock_device(cfg, cid)?); } for vhost_user_fs in &cfg.vhost_user_fs { diff --git a/vhost/src/lib.rs b/vhost/src/lib.rs index 618c5ac470..4a70599f31 100644 --- a/vhost/src/lib.rs +++ b/vhost/src/lib.rs @@ -12,7 +12,6 @@ pub use crate::vsock::Vsock; use std::alloc::Layout; use std::fmt::{self, Display}; use std::io::Error as IoError; -use std::mem; use std::ptr::null; use assertions::const_assert; @@ -64,9 +63,6 @@ fn ioctl_result() -> Result { /// transfer. The device itself only needs to deal with setting up the kernel driver and /// managing the control channel. pub trait Vhost: AsRawDescriptor + std::marker::Sized { - /// Get the guest memory mapping. - fn mem(&self) -> &GuestMemory; - /// Set the current process as the owner of this file descriptor. /// This must be run before any other vhost ioctls. fn set_owner(&self) -> Result<()> { @@ -109,14 +105,14 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { } /// Set the guest memory mappings for vhost to use. - fn set_mem_table(&self) -> Result<()> { - const SIZE_OF_MEMORY: usize = mem::size_of::(); - const SIZE_OF_REGION: usize = mem::size_of::(); - const ALIGN_OF_MEMORY: usize = mem::align_of::(); - const ALIGN_OF_REGION: usize = mem::align_of::(); + fn set_mem_table(&self, mem: &GuestMemory) -> Result<()> { + const SIZE_OF_MEMORY: usize = std::mem::size_of::(); + const SIZE_OF_REGION: usize = std::mem::size_of::(); + const ALIGN_OF_MEMORY: usize = std::mem::align_of::(); + const ALIGN_OF_REGION: usize = std::mem::align_of::(); const_assert!(ALIGN_OF_MEMORY >= ALIGN_OF_REGION); - let num_regions = self.mem().num_regions() as usize; + let num_regions = mem.num_regions() as usize; let size = SIZE_OF_MEMORY + num_regions * SIZE_OF_REGION; let layout = Layout::from_size_align(size, ALIGN_OF_MEMORY).expect("impossible layout"); let mut allocation = LayoutAllocation::zeroed(layout); @@ -130,17 +126,15 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { // we correctly specify the size to match the amount of backing memory. let vhost_regions = unsafe { vhost_memory.regions.as_mut_slice(num_regions as usize) }; - let _ = self - .mem() - .with_regions::<_, ()>(|index, guest_addr, size, host_addr, _, _| { - vhost_regions[index] = virtio_sys::vhost_memory_region { - guest_phys_addr: guest_addr.offset() as u64, - memory_size: size as u64, - userspace_addr: host_addr as u64, - flags_padding: 0u64, - }; - Ok(()) - }); + let _ = mem.with_regions::<_, ()>(|index, guest_addr, size, host_addr, _, _| { + vhost_regions[index] = virtio_sys::vhost_memory_region { + guest_phys_addr: guest_addr.offset() as u64, + memory_size: size as u64, + userspace_addr: host_addr as u64, + flags_padding: 0u64, + }; + Ok(()) + }); // This ioctl is called with a pointer that is valid for the lifetime // of this function. The kernel will make its own copy of the memory @@ -179,6 +173,7 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { #[allow(clippy::if_same_then_else)] fn is_valid( &self, + mem: &GuestMemory, queue_max_size: u16, queue_size: u16, desc_addr: GuestAddress, @@ -192,17 +187,17 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { false } else if desc_addr .checked_add(desc_table_size as u64) - .map_or(true, |v| !self.mem().address_in_range(v)) + .map_or(true, |v| !mem.address_in_range(v)) { false } else if avail_addr .checked_add(avail_ring_size as u64) - .map_or(true, |v| !self.mem().address_in_range(v)) + .map_or(true, |v| !mem.address_in_range(v)) { false } else if used_addr .checked_add(used_ring_size as u64) - .map_or(true, |v| !self.mem().address_in_range(v)) + .map_or(true, |v| !mem.address_in_range(v)) { false } else { @@ -223,6 +218,7 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { /// * `log_addr` - Optional address for logging. fn set_vring_addr( &self, + mem: &GuestMemory, queue_max_size: u16, queue_size: u16, queue_index: usize, @@ -234,25 +230,29 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { ) -> Result<()> { // TODO(smbarber): Refactor out virtio from crosvm so we can // validate a Queue struct directly. - if !self.is_valid(queue_max_size, queue_size, desc_addr, used_addr, avail_addr) { + if !self.is_valid( + mem, + queue_max_size, + queue_size, + desc_addr, + used_addr, + avail_addr, + ) { return Err(Error::InvalidQueue); } - let desc_addr = self - .mem() + let desc_addr = mem .get_host_address(desc_addr) .map_err(Error::DescriptorTableAddress)?; - let used_addr = self - .mem() + let used_addr = mem .get_host_address(used_addr) .map_err(Error::UsedAddress)?; - let avail_addr = self - .mem() + let avail_addr = mem .get_host_address(avail_addr) .map_err(Error::AvailAddress)?; let log_addr = match log_addr { None => null(), - Some(a) => self.mem().get_host_address(a).map_err(Error::LogAddress)?, + Some(a) => mem.get_host_address(a).map_err(Error::LogAddress)?, }; let vring_addr = virtio_sys::vhost_vring_addr { @@ -360,8 +360,7 @@ mod tests { } fn create_fake_vhost_net() -> FakeNet { - let gm = create_guest_memory().unwrap(); - FakeNet::::new(&PathBuf::from(""), &gm).unwrap() + FakeNet::::new(&PathBuf::from("")).unwrap() } #[test] @@ -393,7 +392,8 @@ mod tests { #[test] fn set_mem_table() { let vhost_net = create_fake_vhost_net(); - let res = vhost_net.set_mem_table(); + let gm = create_guest_memory().unwrap(); + let res = vhost_net.set_mem_table(&gm); assert_ok_or_known_failure(res); } @@ -407,7 +407,9 @@ mod tests { #[test] fn set_vring_addr() { let vhost_net = create_fake_vhost_net(); + let gm = create_guest_memory().unwrap(); let res = vhost_net.set_vring_addr( + &gm, 1, 1, 0, diff --git a/vhost/src/net.rs b/vhost/src/net.rs index 59dcbbcc10..9d03dd7ed5 100644 --- a/vhost/src/net.rs +++ b/vhost/src/net.rs @@ -11,7 +11,6 @@ use std::{ }; use base::{ioctl_with_ref, AsRawDescriptor, RawDescriptor}; -use vm_memory::GuestMemory; use super::{ioctl_result, Error, Result, Vhost}; @@ -23,13 +22,12 @@ pub struct Net { // descriptor must be dropped first, which will stop and tear down the // vhost-net worker before GuestMemory can potentially be unmapped. descriptor: File, - mem: GuestMemory, phantom: PhantomData, } pub trait NetT: Vhost + AsRawDescriptor + Send + Sized { /// Create a new NetT instance - fn new(vhost_net_device_path: &Path, mem: &GuestMemory) -> Result; + fn new(vhost_net_device_path: &Path) -> Result; /// Set the tap file descriptor that will serve as the VHOST_NET backend. /// This will start the vhost worker for the given queue. @@ -48,7 +46,7 @@ where /// /// # Arguments /// * `mem` - Guest memory mapping. - fn new(vhost_net_device_path: &Path, mem: &GuestMemory) -> Result> { + fn new(vhost_net_device_path: &Path) -> Result> { Ok(Net:: { descriptor: OpenOptions::new() .read(true) @@ -56,7 +54,6 @@ where .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK) .open(vhost_net_device_path) .map_err(Error::VhostOpen)?, - mem: mem.clone(), phantom: PhantomData, }) } @@ -83,11 +80,7 @@ where } } -impl Vhost for Net { - fn mem(&self) -> &GuestMemory { - &self.mem - } -} +impl Vhost for Net {} impl AsRawDescriptor for Net { fn as_raw_descriptor(&self) -> RawDescriptor { @@ -104,7 +97,6 @@ pub mod fakes { pub struct FakeNet { descriptor: File, - mem: GuestMemory, phantom: PhantomData, } @@ -118,7 +110,7 @@ pub mod fakes { where T: TapT, { - fn new(_vhost_net_device_path: &Path, mem: &GuestMemory) -> Result> { + fn new(_vhost_net_device_path: &Path) -> Result> { Ok(FakeNet:: { descriptor: OpenOptions::new() .read(true) @@ -126,7 +118,6 @@ pub mod fakes { .create(true) .open(TMP_FILE) .unwrap(), - mem: mem.clone(), phantom: PhantomData, }) } @@ -136,11 +127,7 @@ pub mod fakes { } } - impl Vhost for FakeNet { - fn mem(&self) -> &GuestMemory { - &self.mem - } - } + impl Vhost for FakeNet {} impl AsRawDescriptor for FakeNet { fn as_raw_descriptor(&self) -> RawDescriptor { diff --git a/vhost/src/vsock.rs b/vhost/src/vsock.rs index c5afa89384..32ed1f8a35 100644 --- a/vhost/src/vsock.rs +++ b/vhost/src/vsock.rs @@ -10,19 +10,17 @@ use std::{ use base::{ioctl_with_ref, AsRawDescriptor, RawDescriptor}; use virtio_sys::{VHOST_VSOCK_SET_GUEST_CID, VHOST_VSOCK_SET_RUNNING}; -use vm_memory::GuestMemory; use super::{ioctl_result, Error, Result, Vhost}; /// Handle for running VHOST_VSOCK ioctls. pub struct Vsock { descriptor: File, - mem: GuestMemory, } impl Vsock { /// Open a handle to a new VHOST_VSOCK instance. - pub fn new(vhost_vsock_device_path: &Path, mem: &GuestMemory) -> Result { + pub fn new>(vhost_vsock_device_path: P) -> Result { Ok(Vsock { descriptor: OpenOptions::new() .read(true) @@ -30,7 +28,6 @@ impl Vsock { .custom_flags(libc::O_CLOEXEC | libc::O_NONBLOCK) .open(vhost_vsock_device_path) .map_err(Error::VhostOpen)?, - mem: mem.clone(), }) } @@ -69,11 +66,7 @@ impl Vsock { } } -impl Vhost for Vsock { - fn mem(&self) -> &GuestMemory { - &self.mem - } -} +impl Vhost for Vsock {} impl AsRawDescriptor for Vsock { fn as_raw_descriptor(&self) -> RawDescriptor {