diff --git a/Cargo.toml b/Cargo.toml
index a9e2d07e4b..780c439a26 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -132,7 +132,7 @@ direct = ["balloon", "devices/direct", "arch/direct", "x86_64/direct"]
 ffmpeg = ["devices/ffmpeg"]
 gdb = ["gdbstub", "gdbstub_arch", "arch/gdb", "vm_control/gdb", "x86_64/gdb"]
 gfxstream = ["devices/gfxstream"]
-gpu = ["devices/gpu", "vm_memory/udmabuf"]
+gpu = ["devices/gpu"]
 haxm = ["hypervisor/haxm"]
 whpx = ["devices/whpx", "hypervisor/whpx"]
 vaapi = ["devices/vaapi"]
diff --git a/devices/src/vfio.rs b/devices/src/vfio.rs
index f831e21f96..c780219555 100644
--- a/devices/src/vfio.rs
+++ b/devices/src/vfio.rs
@@ -1406,6 +1406,17 @@ impl VfioDevice {
             .map_err(VfioError::Resources)
     }
 
+    pub fn get_iova(&self, alloc: &Alloc) -> Option<AddressRange> {
+        self.iova_alloc.lock().get(alloc).map(|res| res.0)
+    }
+
+    pub fn release_iova(&self, alloc: Alloc) -> Result<AddressRange> {
+        self.iova_alloc
+            .lock()
+            .release(alloc)
+            .map_err(VfioError::Resources)
+    }
+
     pub fn get_max_addr(&self) -> u64 {
         self.iova_alloc.lock().get_max_addr()
     }
diff --git a/devices/src/virtio/vhost/mod.rs b/devices/src/virtio/vhost/mod.rs
index 1ef2b250bc..31f8173a9b 100644
--- a/devices/src/virtio/vhost/mod.rs
+++ b/devices/src/virtio/vhost/mod.rs
@@ -4,6 +4,8 @@
 
 //! Implements vhost-based virtio devices.
 
+use anyhow::anyhow;
+use anyhow::Context;
 use base::Error as SysError;
 use base::TubeError;
 use data_model::DataInit;
@@ -119,7 +121,7 @@ pub enum Error {
 
 pub type Result<T> = std::result::Result<T, Error>;
 
-const HEADER_LEN: usize = std::mem::size_of::<VhostUserMsgHeader<MasterReq>>();
+pub const HEADER_LEN: usize = std::mem::size_of::<VhostUserMsgHeader<MasterReq>>();
 
 pub fn vhost_header_from_bytes<R: Req>(bytes: &[u8]) -> Option<&VhostUserMsgHeader<R>> {
     if bytes.len() < HEADER_LEN {
@@ -129,3 +131,21 @@ pub fn vhost_header_from_bytes<R: Req>(bytes: &[u8]) -> Option<&VhostUserMsgHead
     // This can't fail because we already checked the size and because packed alignment is 1.
     Some(VhostUserMsgHeader::<R>::from_slice(&bytes[0..HEADER_LEN]).unwrap())
 }
+
+pub fn vhost_body_from_message_bytes<T: DataInit>(bytes: &mut [u8]) -> anyhow::Result<&mut T> {
+    let body_len = std::mem::size_of::<T>();
+    let hdr = vhost_header_from_bytes::<MasterReq>(bytes).context("failed to parse header")?;
+
+    if body_len != hdr.get_size() as usize || bytes.len() != body_len + HEADER_LEN {
+        return Err(anyhow!(
+            "parse error: body_len={} hdr_size={} msg_size={}",
+            body_len,
+            hdr.get_size(),
+            bytes.len()
+        ));
+    }
+
+    // We already checked the size. This can only fail due to alignment, but all valid
+    // message types are packed (i.e. alignment=1).
+    Ok(T::from_mut_slice(&mut bytes[HEADER_LEN..]).expect("bad alignment"))
+}
diff --git a/devices/src/virtio/vhost/user/device/vvu/device.rs b/devices/src/virtio/vhost/user/device/vvu/device.rs
index fa7c7718e7..efe5f784e4 100644
--- a/devices/src/virtio/vhost/user/device/vvu/device.rs
+++ b/devices/src/virtio/vhost/user/device/vvu/device.rs
@@ -5,6 +5,7 @@
 
 //! Implement a struct that works as a `vmm_vhost`'s backend.
 
 use std::cmp::Ordering;
+use std::io::Error as IoError;
 use std::io::IoSlice;
 use std::io::IoSliceMut;
 use std::mem;
@@ -21,13 +22,21 @@ use anyhow::Context;
 use anyhow::Result;
 use base::error;
 use base::info;
+use base::AsRawDescriptor;
+use base::Descriptor;
 use base::Event;
+use base::MappedRegion;
+use base::MemoryMappingBuilder;
+use base::MemoryMappingBuilderUnix;
+use base::Protection;
 use base::RawDescriptor;
+use base::SafeDescriptor;
 use cros_async::EventAsync;
 use cros_async::Executor;
 use futures::pin_mut;
 use futures::select;
 use futures::FutureExt;
+use resources::Alloc;
 use sync::Mutex;
 use vmm_vhost::connection::vfio::Device as VfioDeviceTrait;
 use vmm_vhost::connection::vfio::Endpoint as VfioEndpoint;
@@ -35,9 +44,11 @@ use vmm_vhost::connection::vfio::RecvIntoBufsError;
 use vmm_vhost::connection::Endpoint;
 use vmm_vhost::message::*;
 
+use crate::vfio::VfioDevice;
 use crate::virtio::vhost::user::device::vvu::pci::QueueNotifier;
 use crate::virtio::vhost::user::device::vvu::pci::VvuPciDevice;
 use crate::virtio::vhost::user::device::vvu::queue::UserQueue;
+use crate::virtio::vhost::vhost_body_from_message_bytes;
 use crate::virtio::vhost::vhost_header_from_bytes;
 use crate::virtio::vhost::HEADER_LEN;
 
@@ -133,11 +144,21 @@ impl VfioReceiver {
         Ok((len, ret_vec))
     }
 
-    fn recv_into_bufs(&mut self, bufs: &mut [IoSliceMut]) -> Result<usize, RecvIntoBufsError> {
+    fn recv_into_bufs(
+        &mut self,
+        bufs: &mut [IoSliceMut],
+        mut processor: Option<&mut BackendChannelInner>,
+    ) -> Result<usize, RecvIntoBufsError> {
         let mut size = 0;
         for buf in bufs {
-            let (len, _) = self.recv_into_buf(buf)?;
+            let (len, msg) = self.recv_into_buf(buf)?;
             size += len;
+
+            if let (Some(processor), Some(msg)) = (processor.as_mut(), msg) {
+                processor
+                    .postprocess_rx(msg)
+                    .map_err(RecvIntoBufsError::Fatal)?;
+            }
         }
 
         Ok(size)
@@ -148,6 +169,7 @@ impl VfioReceiver {
 #[derive(Default)]
 struct EndpointTxBuffer {
     bytes: Vec<u8>,
+    files: Vec<SafeDescriptor>,
 }
 
 // Utility class for writing an input vhost-user byte stream to the vvu
@@ -163,9 +185,18 @@ impl Queue {
         iovs: &[IoSlice],
         fds: Option<&[RawDescriptor]>,
         tx_state: &mut EndpointTxBuffer,
+        processor: Option<&mut BackendChannelInner>,
     ) -> Result<usize> {
-        if fds.is_some() {
-            bail!("cannot send FDs");
+        if let Some(fds) = fds {
+            if processor.is_none() {
+                bail!("cannot send FDs");
+            }
+
+            let fds: std::result::Result<Vec<SafeDescriptor>, IoError> = fds
+                .iter()
+                .map(|fd| SafeDescriptor::try_from(&Descriptor(*fd) as &dyn AsRawDescriptor))
+                .collect();
+            tx_state.files = fds?;
         }
 
         let mut size = 0;
@@ -181,6 +212,16 @@ impl Queue {
             Ordering::Greater => (),
             Ordering::Equal => {
                 let msg = mem::take(&mut tx_state.bytes);
+                let files = std::mem::take(&mut tx_state.files);
+
+                let msg = if let Some(processor) = processor {
+                    processor
+                        .preprocess_tx(msg, files)
+                        .context("failed to preprocess message")?
+                } else {
+                    msg
+                };
+
                 self.txq.write(&msg).context("Failed to send data")?;
             }
             Ordering::Less => bail!("sent bytes larger than message size"),
@@ -357,6 +398,10 @@ impl VfioDeviceTrait for VvuDevice {
         self.backend_channel = Some(VfioEndpoint::from(BackendChannel {
             receiver: VfioReceiver::new(backend_rxq_receiver, backend_rxq_evt),
             queue: txq.clone(),
+            inner: BackendChannelInner {
+                pending_unmap: None,
+                vfio: device.vfio_dev.clone(),
+            },
             tx_state: EndpointTxBuffer::default(),
         }));
 
@@ -409,7 +454,7 @@ impl VfioDeviceTrait for VvuDevice {
             }
             DeviceState::Running { txq, tx_state, .. } => {
                 let mut queue = txq.lock();
-                queue.send_bufs(iovs, fds, tx_state)
+                queue.send_bufs(iovs, fds, tx_state, None)
             }
         }
     }
 
@@ -419,7 +464,7 @@
             DeviceState::Initialized { .. } => Err(RecvIntoBufsError::Fatal(anyhow!(
                 "VvuDevice hasn't started yet"
             ))),
-            DeviceState::Running { rxq_receiver, .. } => rxq_receiver.recv_into_bufs(bufs),
+            DeviceState::Running { rxq_receiver, .. } => rxq_receiver.recv_into_bufs(bufs, None),
         }
     }
 
@@ -432,10 +477,20 @@
         }
     }
 }
 
+// State of the backend channel not directly related to sending/receiving data.
+struct BackendChannelInner {
+    vfio: Arc<VfioDevice>,
+
+    // Offset of the pending unmap operation. Set when an unmap message is sent,
+    // and cleared when the reply is received.
+    pending_unmap: Option<u64>,
+}
+
 // Struct which implements the Endpoint for backend messages.
 struct BackendChannel {
     receiver: VfioReceiver,
     queue: Arc<Mutex<Queue>>,
+    inner: BackendChannelInner,
     tx_state: EndpointTxBuffer,
 }
@@ -449,11 +504,16 @@ impl VfioDeviceTrait for BackendChannel {
     }
 
     fn send_bufs(&mut self, iovs: &[IoSlice], fds: Option<&[RawFd]>) -> Result<usize> {
-        self.queue.lock().send_bufs(iovs, fds, &mut self.tx_state)
+        self.queue.lock().send_bufs(
+            iovs,
+            fds,
+            &mut self.tx_state,
+            Some(&mut self.inner as &mut BackendChannelInner),
+        )
     }
 
     fn recv_into_bufs(&mut self, bufs: &mut [IoSliceMut]) -> Result<usize, RecvIntoBufsError> {
-        self.receiver.recv_into_bufs(bufs)
+        self.receiver.recv_into_bufs(bufs, Some(&mut self.inner))
     }
 
     fn create_slave_request_endpoint(&mut self) -> Result<Box<dyn Endpoint<SlaveReq>>> {
         Err(anyhow!(
         ))
     }
 }
+
+impl BackendChannelInner {
+    // Preprocess messages before forwarding them to the virtqueue. Returns the bytes to
+    // send to the host.
+    fn preprocess_tx(
+        &mut self,
+        mut msg: Vec<u8>,
+        mut files: Vec<SafeDescriptor>,
+    ) -> Result<Vec<u8>> {
+        // msg came from a ProtocolReader, so this can't fail.
+        let hdr = vhost_header_from_bytes::<SlaveReq>(&msg).expect("framing error");
+        let msg_type = hdr.get_code();
+
+        match msg_type {
+            SlaveReq::SHMEM_MAP => {
+                let file = files.pop().context("missing file to mmap")?;
+
+                // msg came from a ProtocolReader, so this can't fail.
+                let mut msg = vhost_body_from_message_bytes::<VhostUserShmemMapMsg>(&mut msg)
+                    .expect("framing error");
+
+                let mapping = MemoryMappingBuilder::new(msg.len as usize)
+                    .from_descriptor(&file)
+                    .offset(msg.fd_offset)
+                    .protection(Protection::from(msg.flags.bits() as libc::c_int))
+                    .build()
+                    .context("failed to map file")?;
+
+                let iova = self
+                    .vfio
+                    .alloc_iova(msg.len, 4096, Alloc::Anon(msg.shm_offset as usize))
+                    .context("failed to allocate iova")?;
+                // Safe because we're mapping an external file.
+                unsafe {
+                    self.vfio
+                        .vfio_dma_map(iova, msg.len, mapping.as_ptr() as u64, true)
+                        .context("failed to map into IO address space")?;
+                }
+
+                // The udmabuf constructed in the hypervisor corresponds to the region
+                // we mmap'ed, so fd_offset is no longer necessary. Reuse it for the
+                // iova.
+                msg.fd_offset = iova;
+            }
+            SlaveReq::SHMEM_UNMAP => {
+                if self.pending_unmap.is_some() {
+                    bail!("overlapping unmap requests");
+                }
+
+                let msg = vhost_body_from_message_bytes::<VhostUserShmemUnmapMsg>(&mut msg)
+                    .expect("framing error");
+                match self.vfio.get_iova(&Alloc::Anon(msg.shm_offset as usize)) {
+                    None => bail!("unmap doesn't match mapped allocation"),
+                    Some(range) => {
+                        if !range.len().map_or(false, |l| l == msg.len) {
+                            bail!("unmap size mismatch");
+                        }
+                    }
+                }
+
+                self.pending_unmap = Some(msg.shm_offset)
+            }
+            _ => (),
+        }
+
+        if !files.is_empty() {
+            bail!("{} unhandled files for {:?}", files.len(), msg_type);
+        }
+
+        Ok(msg)
+    }
+
+    // Postprocess replies received from the virtqueue. This occurs after the
+    // replies have been forwarded to the endpoint.
+    fn postprocess_rx(&mut self, msg: Vec<u8>) -> Result<()> {
+        // msg was provided by a ProtocolReader, so this can't fail.
+        let hdr = vhost_header_from_bytes::<SlaveReq>(&msg).unwrap();
+
+        if hdr.get_code() == SlaveReq::SHMEM_UNMAP {
+            let offset = self
+                .pending_unmap
+                .take()
+                .ok_or(RecvIntoBufsError::Fatal(anyhow!(
+                    "unexpected unmap response"
+                )))?;
+
+            let r = self
+                .vfio
+                .release_iova(Alloc::Anon(offset as usize))
+                .expect("corrupted IOVA address space");
+            self.vfio
+                .vfio_dma_unmap(r.start, r.len().unwrap())
+                .context("failed to unmap memory")?;
+        }
+
+        Ok(())
+    }
+}
diff --git a/devices/src/virtio/vhost/user/proxy.rs b/devices/src/virtio/vhost/user/proxy.rs
index fbc3e02daa..5f244c81d9 100644
--- a/devices/src/virtio/vhost/user/proxy.rs
+++ b/devices/src/virtio/vhost/user/proxy.rs
@@ -50,6 +50,8 @@ use vm_control::VmMemoryDestination;
 use vm_control::VmMemoryRequest;
 use vm_control::VmMemoryResponse;
 use vm_control::VmMemorySource;
+use vm_memory::udmabuf::UdmabufDriver;
+use vm_memory::GuestAddress;
 use vm_memory::GuestMemory;
 use vmm_vhost::connection::socket::Endpoint as SocketEndpoint;
 use vmm_vhost::connection::EndpointExt;
@@ -60,6 +62,8 @@ use vmm_vhost::message::VhostUserMemory;
 use vmm_vhost::message::VhostUserMemoryRegion;
 use vmm_vhost::message::VhostUserMsgHeader;
 use vmm_vhost::message::VhostUserMsgValidator;
+use vmm_vhost::message::VhostUserShmemMapMsg;
+use vmm_vhost::message::VhostUserShmemMapMsgFlags;
 use vmm_vhost::message::VhostUserU64;
 use vmm_vhost::Protocol;
 use vmm_vhost::SlaveReqHelper;
@@ -71,6 +75,8 @@ use crate::pci::PciBarRegionType;
 use crate::pci::PciCapability;
 use crate::pci::PciCapabilityID;
 use crate::virtio::copy_config;
+use crate::virtio::ipc_memory_mapper::IpcMemoryMapper;
+use crate::virtio::vhost::vhost_body_from_message_bytes;
 use crate::virtio::vhost::vhost_header_from_bytes;
 use crate::virtio::DescriptorChain;
 use crate::virtio::DeviceType;
@@ -261,6 +267,12 @@ struct Worker {
 
     // Channel for backend mesages.
     slave_req_fd: Option<SocketEndpoint<SlaveReq>>,
+
+    // Driver for exporting memory as udmabufs for shared memory regions.
+    udmabuf_driver: Option<UdmabufDriver>,
+
+    // Iommu to translate IOVAs into GPAs for shared memory regions.
+    iommu: Arc<Mutex<IpcMemoryMapper>>,
 }
 
 #[derive(EventToken, Debug, Clone, PartialEq)]
@@ -891,6 +903,7 @@ impl Worker {
             bail!("invalid file count for SET_SLAVE_REQ_FD {}", files.len());
         }
 
+        self.udmabuf_driver = Some(UdmabufDriver::new().context("failed to get udmabuf driver")?);
         // Safe because we own the file.
         let socket = unsafe { UnixStream::from_raw_descriptor(file.into_raw_descriptor()) };
 
@@ -902,11 +915,65 @@ impl Worker {
         Ok(())
     }
 
+    // Exports the udmabuf necessary to fulfil the |msg| mapping request.
+    fn export_udmabuf(&mut self, msg: &VhostUserShmemMapMsg) -> Result<Box<dyn AsRawDescriptor>> {
+        let regions = self
+            .iommu
+            .lock()
+            .translate(msg.fd_offset, msg.len)
+            .context("failed to translate")?;
+
+        let prot = match (
+            msg.flags.contains(VhostUserShmemMapMsgFlags::MAP_R),
+            msg.flags.contains(VhostUserShmemMapMsgFlags::MAP_W),
+        ) {
+            (true, true) => Protection::read_write(),
+            (true, false) => Protection::read(),
+            (false, true) => Protection::write(),
+            (false, false) => bail!("unsupported protection"),
+        };
+        let regions = regions
+            .iter()
+            .map(|r| {
+                if !r.prot.allows(&prot) {
+                    Err(anyhow!("invalid permissions"))
+                } else {
+                    Ok((r.gpa, r.len as usize))
+                }
+            })
+            .collect::<Result<Vec<(GuestAddress, usize)>>>()?;
+
+        // udmabuf_driver is set at the same time as slave_req_fd, so if we've
+        // received a message on slave_req_fd, udmabuf_driver must be present.
+        let udmabuf = self
+            .udmabuf_driver
+            .as_ref()
+            .expect("missing udmabuf driver")
+            .create_udmabuf(&self.mem, &regions)
+            .context("failed to create udmabuf")?;
+
+        Ok(Box::new(udmabuf))
+    }
+
     fn process_message_from_backend(
         &mut self,
-        msg: Vec<u8>,
+        mut msg: Vec<u8>,
     ) -> Result<(Vec<u8>, Option<Box<dyn AsRawDescriptor>>)> {
-        Ok((msg, None))
+        // The message was already parsed as a SlaveReq, so this can't fail.
+        let hdr = vhost_header_from_bytes::<SlaveReq>(&msg).unwrap();
+
+        let fd = if hdr.get_code() == SlaveReq::SHMEM_MAP {
+            let mut msg = vhost_body_from_message_bytes(&mut msg).context("incomplete message")?;
+            let fd = self.export_udmabuf(msg).context("failed to export fd")?;
+            // VVU reuses the fd_offset field for the IOVA of the buffer. The
+            // udmabuf corresponds to exactly what should be mapped, so set
+            // fd_offset to 0 for regular vhost-user.
+            msg.fd_offset = 0;
+            Some(fd)
+        } else {
+            None
+        };
+        Ok((msg, fd))
     }
 
     // Processes data from the device backend (via virtio Tx queue) and forward it to
@@ -1071,6 +1138,8 @@ enum State {
         tx_queue: Queue,
         rx_queue_evt: Event,
         tx_queue_evt: Event,
+
+        iommu: Arc<Mutex<IpcMemoryMapper>>,
     },
     /// The worker thread is running.
     Running {
@@ -1133,6 +1202,8 @@ pub struct VirtioVhostUser {
     // The value is wrapped with `Arc<Mutex<State>>` because it can be modified from the worker thread
     // as well as the main device thread.
     state: Arc<Mutex<State>>,
+
+    iommu: Option<Arc<Mutex<IpcMemoryMapper>>>,
 }
 
 impl VirtioVhostUser {
@@ -1164,6 +1235,7 @@ impl VirtioVhostUser {
                 listener,
             })),
            pci_address,
+            iommu: None,
        })
    }
@@ -1309,6 +1381,7 @@ impl VirtioVhostUser {
            tx_queue,
            rx_queue_evt,
            tx_queue_evt,
+            iommu,
        ) = match old_state {
            State::Activated {
                main_process_tube,
@@ -1319,6 +1392,7 @@ impl VirtioVhostUser {
                tx_queue,
                rx_queue_evt,
                tx_queue_evt,
+                iommu,
            } => (
                main_process_tube,
                listener,
@@ -1328,6 +1402,7 @@ impl VirtioVhostUser {
                tx_queue,
                rx_queue_evt,
                tx_queue_evt,
+                iommu,
            ),
            s => {
                // Unreachable because we've checked the state at the beginning of this function.
@@ -1386,6 +1461,8 @@ impl VirtioVhostUser {
             slave_req_helper,
             registered_memory: Vec::new(),
             slave_req_fd: None,
+            udmabuf_driver: None,
+            iommu: iommu.clone(),
         };
         match worker.run(
             rx_queue_evt.try_clone().unwrap(),
@@ -1421,6 +1498,7 @@ impl VirtioVhostUser {
             tx_queue,
             rx_queue_evt,
             tx_queue_evt,
+            iommu,
         };
 
         Ok(())
@@ -1590,6 +1668,10 @@ impl VirtioDevice for VirtioVhostUser {
         )]
     }
 
+    fn set_iommu(&mut self, iommu: &Arc<Mutex<IpcMemoryMapper>>) {
+        self.iommu = Some(iommu.clone());
+    }
+
     fn activate(
         &mut self,
         mem: GuestMemory,
@@ -1620,6 +1702,7 @@
                     tx_queue: queues.remove(0),
                     rx_queue_evt: queue_evts.remove(0),
                     tx_queue_evt: queue_evts.remove(0),
+                    iommu: self.iommu.take().unwrap(),
                 };
             }
             s => {
diff --git a/devices/src/virtio/virtio_device.rs b/devices/src/virtio/virtio_device.rs
index bec166239a..2014802e2e 100644
--- a/devices/src/virtio/virtio_device.rs
+++ b/devices/src/virtio/virtio_device.rs
@@ -2,12 +2,15 @@
 // Use of this source code is governed by a BSD-style license that can be
 // found in the LICENSE file.
 
+use std::sync::Arc;
+
 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
 use acpi_tables::sdt::SDT;
 use anyhow::Result;
 use base::Event;
 use base::Protection;
 use base::RawDescriptor;
+use sync::Mutex;
 use vm_control::VmMemorySource;
 use vm_memory::GuestAddress;
 use vm_memory::GuestMemory;
@@ -18,6 +21,7 @@ use crate::pci::PciAddress;
 use crate::pci::PciBarConfiguration;
 use crate::pci::PciBarIndex;
 use crate::pci::PciCapability;
+use crate::virtio::ipc_memory_mapper::IpcMemoryMapper;
 
 #[derive(Clone)]
 pub struct SharedMemoryRegion {
@@ -95,6 +99,12 @@ pub trait VirtioDevice: Send {
         let _ = data;
     }
 
+    /// If the device is translated by an IOMMU, called before
+    /// |activate| with the IOMMU's mapper.
+    fn set_iommu(&mut self, iommu: &Arc<Mutex<IpcMemoryMapper>>) {
+        let _ = iommu;
+    }
+
     /// Activates this device for real usage.
     fn activate(
         &mut self,
diff --git a/devices/src/virtio/virtio_pci_device.rs b/devices/src/virtio/virtio_pci_device.rs
index b890458441..4ee528a9ad 100644
--- a/devices/src/virtio/virtio_pci_device.rs
+++ b/devices/src/virtio/virtio_pci_device.rs
@@ -510,6 +510,10 @@ impl VirtioPciDevice {
             .filter(|(q, _)| q.ready())
             .unzip();
 
+        if let Some(iommu) = &self.iommu {
+            self.device.set_iommu(iommu);
+        }
+
         self.device.activate(mem, interrupt, queues, queue_evts);
         self.device_activated = true;
     }
diff --git a/resources/src/address_allocator.rs b/resources/src/address_allocator.rs
index 851a8eae79..8568698384 100644
--- a/resources/src/address_allocator.rs
+++ b/resources/src/address_allocator.rs
@@ -253,17 +253,18 @@ impl AddressAllocator {
         }
     }
 
-    /// Releases exising allocation back to free pool.
-    pub fn release(&mut self, alloc: Alloc) -> Result<()> {
+    /// Releases existing allocation back to free pool and returns the range that was released.
+    pub fn release(&mut self, alloc: Alloc) -> Result<AddressRange> {
         if let Some((range, _tag)) = self.allocs.remove(&alloc) {
-            self.insert_at(range)
+            self.insert_at(range)?;
+            Ok(range)
         } else {
             Err(Error::BadAlloc(alloc))
         }
     }
 
     /// Release a allocation contains the value.
-    pub fn release_containing(&mut self, value: u64) -> Result<()> {
+    pub fn release_containing(&mut self, value: u64) -> Result<AddressRange> {
         if let Some(alloc) = self.find_overlapping(AddressRange {
             start: value,
             end: value,
@@ -422,7 +423,7 @@ impl<'a> AddressAllocatorSet<'a> {
         last_res
     }
 
-    pub fn release(&mut self, alloc: Alloc) -> Result<()> {
+    pub fn release(&mut self, alloc: Alloc) -> Result<AddressRange> {
         let mut last_res = Err(Error::OutOfSpace);
         for allocator in self.allocators.iter_mut() {
             last_res = allocator.release(alloc);
diff --git a/resources/src/system_allocator.rs b/resources/src/system_allocator.rs
index c8c91d5f33..3fc480a387 100644
--- a/resources/src/system_allocator.rs
+++ b/resources/src/system_allocator.rs
@@ -417,7 +417,13 @@ mod tests {
             ),
             Ok(())
         );
-        assert_eq!(a.mmio_allocator(MmioType::Low).release(id), Ok(()));
+        assert_eq!(
+            a.mmio_allocator(MmioType::Low).release(id),
+            Ok(AddressRange {
+                start: 0x3000_5000,
+                end: 0x30009fff
+            })
+        );
         assert_eq!(
             a.reserve_mmio(AddressRange {
                 start: 0x3000_2000,
diff --git a/vm_memory/Cargo.toml b/vm_memory/Cargo.toml
index 6dc0c700e1..6a4c233179 100644
--- a/vm_memory/Cargo.toml
+++ b/vm_memory/Cargo.toml
@@ -15,6 +15,3 @@ bitflags = "1"
 remain = "*"
 serde = { version = "1", features = [ "derive" ] }
 thiserror = "*"
-
-[features]
-udmabuf = []
diff --git a/vm_memory/src/lib.rs b/vm_memory/src/lib.rs
index 35c96ed39d..edaaf90cb9 100644
--- a/vm_memory/src/lib.rs
+++ b/vm_memory/src/lib.rs
@@ -8,7 +8,7 @@ mod guest_address;
 pub mod guest_memory;
 
 cfg_if::cfg_if! {
-    if #[cfg(all(unix, feature = "udmabuf"))] {
+    if #[cfg(unix)] {
         pub mod udmabuf;
         mod udmabuf_bindings;
     }
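---

A minimal usage sketch (not part of the patch) for the new
`vhost_body_from_message_bytes` helper from devices/src/virtio/vhost/mod.rs,
mirroring the fd_offset/IOVA rewrite done in `BackendChannelInner::preprocess_tx`
above. `rewrite_map_fd_offset` is a hypothetical caller; the message types come
from `vmm_vhost::message`.

    use anyhow::anyhow;
    use vmm_vhost::message::SlaveReq;
    use vmm_vhost::message::VhostUserShmemMapMsg;

    use crate::virtio::vhost::vhost_body_from_message_bytes;
    use crate::virtio::vhost::vhost_header_from_bytes;

    /// Rewrites the `fd_offset` field of a serialized SHMEM_MAP message in
    /// place, the way the VVU proxy reuses that field to carry the IOVA.
    fn rewrite_map_fd_offset(bytes: &mut [u8], iova: u64) -> anyhow::Result<()> {
        // Peek at the header first; `None` means the buffer is shorter than
        // HEADER_LEN.
        let hdr = vhost_header_from_bytes::<SlaveReq>(bytes)
            .ok_or_else(|| anyhow!("truncated header"))?;
        if hdr.get_code() != SlaveReq::SHMEM_MAP {
            return Err(anyhow!("not a SHMEM_MAP message"));
        }

        // The helper validates that the buffer is exactly header + body and
        // hands back a mutable view of the body.
        let msg = vhost_body_from_message_bytes::<VhostUserShmemMapMsg>(bytes)?;
        msg.fd_offset = iova;
        Ok(())
    }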