diff --git a/vhost/src/lib.rs b/vhost/src/lib.rs index b474bf2f7c..cb57e662eb 100644 --- a/vhost/src/lib.rs +++ b/vhost/src/lib.rs @@ -202,42 +202,6 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { Ok(()) } - // TODO(smbarber): This is copypasta. Eliminate the copypasta. - #[allow(clippy::if_same_then_else)] - fn is_valid( - &self, - mem: &GuestMemory, - queue_max_size: u16, - queue_size: u16, - desc_addr: GuestAddress, - avail_addr: GuestAddress, - used_addr: GuestAddress, - ) -> bool { - let desc_table_size = 16 * queue_size as usize; - let avail_ring_size = 6 + 2 * queue_size as usize; - let used_ring_size = 6 + 8 * queue_size as usize; - if queue_size > queue_max_size || queue_size == 0 || (queue_size & (queue_size - 1)) != 0 { - false - } else if desc_addr - .checked_add(desc_table_size as u64) - .map_or(true, |v| !mem.address_in_range(v)) - { - false - } else if avail_addr - .checked_add(avail_ring_size as u64) - .map_or(true, |v| !mem.address_in_range(v)) - { - false - } else if used_addr - .checked_add(used_ring_size as u64) - .map_or(true, |v| !mem.address_in_range(v)) - { - false - } else { - true - } - } - /// Set the addresses for a given vring. /// /// # Arguments @@ -261,28 +225,27 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { avail_addr: GuestAddress, log_addr: Option, ) -> Result<()> { - // TODO(smbarber): Refactor out virtio from crosvm so we can - // validate a Queue struct directly. - if !self.is_valid( - mem, - queue_max_size, - queue_size, - desc_addr, - used_addr, - avail_addr, - ) { + if queue_size > queue_max_size || queue_size == 0 || !queue_size.is_power_of_two() { return Err(Error::InvalidQueue); } - let desc_addr = mem - .get_host_address(desc_addr) + let queue_size = usize::from(queue_size); + + let desc_table_size = 16 * queue_size; + let desc_table = mem + .get_slice_at_addr(desc_addr, desc_table_size) .map_err(Error::DescriptorTableAddress)?; - let used_addr = mem - .get_host_address(used_addr) + + let used_ring_size = 6 + 8 * queue_size; + let used_ring = mem + .get_slice_at_addr(used_addr, used_ring_size) .map_err(Error::UsedAddress)?; - let avail_addr = mem - .get_host_address(avail_addr) + + let avail_ring_size = 6 + 2 * queue_size; + let avail_ring = mem + .get_slice_at_addr(avail_addr, avail_ring_size) .map_err(Error::AvailAddress)?; + let log_addr = match log_addr { None => null(), Some(a) => mem.get_host_address(a).map_err(Error::LogAddress)?, @@ -291,9 +254,9 @@ pub trait Vhost: AsRawDescriptor + std::marker::Sized { let vring_addr = virtio_sys::vhost::vhost_vring_addr { index: queue_index as u32, flags, - desc_user_addr: desc_addr as u64, - used_user_addr: used_addr as u64, - avail_user_addr: avail_addr as u64, + desc_user_addr: desc_table.as_ptr() as u64, + used_user_addr: used_ring.as_ptr() as u64, + avail_user_addr: avail_ring.as_ptr() as u64, log_guest_addr: log_addr as u64, };