devices: virtio: overhaul DescriptorChain API

Summary: DescriptorChain now reads the full chain when it is created.

Previously, DescriptorChain worked more like an iterator, returning one
descriptor at a time. This meant that DescriptorChain::checked_new()
returning Ok(...) did not guarantee the whole chain was valid, so
callers (in particular, Reader::new() and Writer::new()) had to be
fallible in order to report errors found later in the chain.

With the new API, DescriptorChain::new() reads the whole chain of
descriptors before returning, performing up front the validation that
was previously deferred to Reader/Writer creation. This means that
Reader::new() and Writer::new() can now be infallible, which eliminates
error-handling paths all over the virtio device code.

Since the Reader::new() and Writer::new() function signatures are
changing anyway, take the opportunity to remove the redundant
GuestMemory parameter to those functions; DescriptorChain already holds
a reference to the GuestMemory.
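
To illustrate, here is a minimal before/after sketch of a typical
device queue loop, distilled from the call sites changed below (the
request processing itself is elided, and the "before" error variant is
device-specific):

  // Before: Reader/Writer creation was fallible and took a redundant
  // GuestMemory, and add_used() identified the chain by index.
  while let Some(avail_desc) = queue.pop(mem) {
      let index = avail_desc.index;
      let mut reader = Reader::new(mem.clone(), avail_desc.clone())
          .map_err(ExecuteError::Descriptor)?;
      let mut writer = Writer::new(mem.clone(), avail_desc)
          .map_err(ExecuteError::Descriptor)?;
      // ... process the request ...
      queue.add_used(mem, index, writer.bytes_written() as u32);
  }

  // After: the chain was fully validated by DescriptorChain::new(), so
  // Reader/Writer creation is infallible, and add_used() consumes the
  // chain itself.
  while let Some(avail_desc) = queue.pop(mem) {
      let mut reader = Reader::new(&avail_desc);
      let mut writer = Writer::new(&avail_desc);
      // ... process the request ...
      queue.add_used(mem, avail_desc, writer.bytes_written() as u32);
  }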

The internal structure of DescriptorChain has been modified in
preparation for supporting packed virtqueues (in addition to the
current split virtqueues). The queue-type-specific code has been
factored out behind a new trait (DescriptorChainIter), keeping it
independent of the high-level descriptor chain validation logic and
IOMMU translation; see the sketch below.
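
For reference, a condensed sketch of the new trait boundary (the
Descriptor type is defined in descriptor_chain.rs below;
PackedDescriptorChain is hypothetical and not part of this change):

  /// Queue-type-specific walk over the raw descriptors of one chain.
  pub trait DescriptorChainIter {
      /// Return the next descriptor, or None when the chain ends.
      fn next(&mut self) -> anyhow::Result<Option<Descriptor>>;
  }

  // SplitDescriptorChain implements this for split virtqueues. A future
  // packed-queue type would only need its own next(), since
  // DescriptorChain::new() supplies validation and IOMMU translation.
  struct PackedDescriptorChain; // packed ring state omitted (future work)

  impl DescriptorChainIter for PackedDescriptorChain {
      fn next(&mut self) -> anyhow::Result<Option<Descriptor>> {
          // Walk the packed descriptor ring here.
          Ok(None)
      }
  }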

BUG=None
TEST=tools/dev_container tools/presubmit

Change-Id: I48fd44b7f3e8b509dcb3683864a3f9621c744c4c
Reviewed-on: https://chromium-review.googlesource.com/c/crosvm/crosvm/+/4391797
Commit-Queue: Daniel Verkamp <dverkamp@chromium.org>
Reviewed-by: David Stevens <stevensd@chromium.org>
Reviewed-by: Keiichi Watanabe <keiichiw@chromium.org>
Reviewed-by: Maciek Swiech <drmasquatch@google.com>
Daniel Verkamp 2023-03-29 15:11:22 -07:00 committed by crosvm LUCI
parent 74d423fb04
commit 63f50362ec
39 changed files with 929 additions and 1148 deletions

View file

@ -48,8 +48,8 @@ mod fuzzer {
)
.unwrap();
let r = Reader::new(mem.clone(), chain.clone()).unwrap();
let w = Writer::new(mem.clone(), chain).unwrap();
let r = Reader::new(&chain);
let w = Writer::new(&chain);
fuzz_server(r, w);
});
});

View file

@ -9,7 +9,6 @@ use std::mem::size_of;
use cros_fuzz::fuzz_target;
use cros_fuzz::rand::FuzzRng;
use devices::virtio::DescriptorChain;
use devices::virtio::Queue;
use rand::Rng;
use rand::RngCore;
@ -107,16 +106,11 @@ fuzz_target!(|data: &[u8]| {
mem.write_all_at_addr(&buf[..], q.used_ring()).unwrap();
while let Some(avail_desc) = q.pop(mem) {
let idx = avail_desc.index;
let total = avail_desc
.into_iter()
.filter(DescriptorChain::is_write_only)
.try_fold(0u32, |sum, cur| sum.checked_add(cur.len));
if let Some(len) = total {
q.add_used(mem, idx, len);
} else {
q.add_used(mem, idx, 0);
}
.writable_mem_regions()
.iter()
.try_fold(0u32, |sum, cur| sum.checked_add(cur.len as u32));
q.add_used(mem, avail_desc, total.unwrap_or(0));
}
});
});

View file

@ -47,9 +47,7 @@ use zerocopy::FromBytes;
use super::async_utils;
use super::copy_config;
use super::descriptor_utils;
use super::DescriptorChain;
use super::DescriptorError;
use super::DeviceType;
use super::Interrupt;
use super::Queue;
@ -73,9 +71,6 @@ pub enum BalloonError {
/// Failed to create async message receiver.
#[error("failed to create async message receiver: {0}")]
CreatingMessageReceiver(base::TubeError),
/// Virtio descriptor error
#[error("virtio descriptor error: {0}")]
Descriptor(DescriptorError),
/// Failed to receive command message.
#[error("failed to receive command message: {0}")]
ReceivingCommand(base::TubeError),
@ -258,7 +253,7 @@ fn release_ranges<F>(
release_memory_tube: Option<&Tube>,
inflate_ranges: Vec<(u64, u64)>,
desc_handler: &mut F,
) -> descriptor_utils::Result<()>
) -> anyhow::Result<()>
where
F: FnMut(GuestAddress, u64),
{
@ -296,10 +291,9 @@ where
// Processes one message's list of addresses.
fn handle_address_chain<F>(
release_memory_tube: Option<&Tube>,
avail_desc: DescriptorChain,
mem: &GuestMemory,
avail_desc: &DescriptorChain,
desc_handler: &mut F,
) -> descriptor_utils::Result<()>
) -> anyhow::Result<()>
where
F: FnMut(GuestAddress, u64),
{
@ -309,7 +303,7 @@ where
// gains in a newly booted system, so it's worth attempting.
let mut range_start = 0;
let mut range_size = 0;
let mut reader = Reader::new(mem.clone(), avail_desc)?;
let mut reader = Reader::new(avail_desc);
let mut inflate_ranges: Vec<(u64, u64)> = Vec::new();
for res in reader.iter::<Le32>() {
let pfn = match res {
@ -361,13 +355,10 @@ async fn handle_queue<F>(
}
Ok(d) => d,
};
let index = avail_desc.index;
if let Err(e) =
handle_address_chain(release_memory_tube, avail_desc, mem, &mut desc_handler)
{
if let Err(e) = handle_address_chain(release_memory_tube, &avail_desc, &mut desc_handler) {
error!("balloon: failed to process inflate addresses: {}", e);
}
queue.add_used(mem, index, 0);
queue.add_used(mem, avail_desc, 0);
queue.trigger_interrupt(mem, &interrupt);
}
}
@ -375,20 +366,18 @@ async fn handle_queue<F>(
// Processes one page-reporting descriptor.
fn handle_reported_buffer<F>(
release_memory_tube: Option<&Tube>,
avail_desc: DescriptorChain,
avail_desc: &DescriptorChain,
desc_handler: &mut F,
) -> descriptor_utils::Result<()>
) -> anyhow::Result<()>
where
F: FnMut(GuestAddress, u64),
{
let mut reported_ranges: Vec<(u64, u64)> = Vec::new();
let regions = avail_desc.into_iter();
for desc in regions {
let (desc_regions, _exported) = desc.into_mem_regions();
for r in desc_regions {
reported_ranges.push((r.gpa.offset(), r.len));
}
}
let reported_ranges: Vec<(u64, u64)> = avail_desc
.readable_mem_regions()
.iter()
.chain(avail_desc.writable_mem_regions().iter())
.map(|r| (r.offset, r.len as u64))
.collect();
release_ranges(release_memory_tube, reported_ranges, desc_handler)
}
@ -412,11 +401,11 @@ async fn handle_reporting_queue<F>(
}
Ok(d) => d,
};
let index = avail_desc.index;
if let Err(e) = handle_reported_buffer(release_memory_tube, avail_desc, &mut desc_handler) {
if let Err(e) = handle_reported_buffer(release_memory_tube, &avail_desc, &mut desc_handler)
{
error!("balloon: failed to process reported buffer: {}", e);
}
queue.add_used(mem, index, 0);
queue.add_used(mem, avail_desc, 0);
queue.trigger_interrupt(mem, &interrupt);
}
}
@ -451,12 +440,12 @@ async fn handle_stats_queue(
) {
// Consume the first stats buffer sent from the guest at startup. It was not
// requested by anyone, and the stats are stale.
let mut index = match queue.next_async(mem, &mut queue_event).await {
let mut avail_desc = match queue.next_async(mem, &mut queue_event).await {
Err(e) => {
error!("Failed to read descriptor {}", e);
return;
}
Ok(d) => d.index,
Ok(d) => d,
};
loop {
// Wait for a request to read the stats.
@ -469,24 +458,17 @@ async fn handle_stats_queue(
};
// Request a new stats_desc to the guest.
queue.add_used(mem, index, 0);
queue.add_used(mem, avail_desc, 0);
queue.trigger_interrupt(mem, &interrupt);
let stats_desc = match queue.next_async(mem, &mut queue_event).await {
avail_desc = match queue.next_async(mem, &mut queue_event).await {
Err(e) => {
error!("Failed to read descriptor {}", e);
return;
}
Ok(d) => d,
};
index = stats_desc.index;
let mut reader = match Reader::new(mem.clone(), stats_desc) {
Ok(r) => r,
Err(e) => {
error!("balloon: failed to CREATE Reader: {}", e);
continue;
}
};
let mut reader = Reader::new(&avail_desc);
let stats = parse_balloon_stats(&mut reader);
let actual_pages = state.lock().await.actual_pages as u64;
@ -566,15 +548,10 @@ async fn handle_events_queue(
.next_async(mem, &mut queue_event)
.await
.map_err(BalloonError::AsyncAwait)?;
let index = avail_desc.index;
match Reader::new(mem.clone(), avail_desc) {
Ok(mut r) => {
handle_event(state.clone(), interrupt.clone(), &mut r, command_tube).await?
}
Err(e) => error!("balloon: failed to create Reader: {}", e),
};
let mut reader = Reader::new(&avail_desc);
handle_event(state.clone(), interrupt.clone(), &mut reader, command_tube).await?;
queue.add_used(mem, index, 0);
queue.add_used(mem, avail_desc, 0);
queue.trigger_interrupt(mem, &interrupt);
}
}
@ -608,9 +585,7 @@ async fn handle_wss_op_queue(
.next_async(mem, &mut queue_event)
.await
.map_err(BalloonError::AsyncAwait)?;
let index = avail_desc.index;
let mut writer = Writer::new(mem.clone(), avail_desc).map_err(BalloonError::Descriptor)?;
let mut writer = Writer::new(&avail_desc);
match op {
WSSOp::WSSReport { id } => {
@ -636,7 +611,7 @@ async fn handle_wss_op_queue(
}
}
queue.add_used(mem, index, writer.bytes_written() as u32);
queue.add_used(mem, avail_desc, writer.bytes_written() as u32);
queue.trigger_interrupt(mem, &interrupt);
}
@ -687,15 +662,7 @@ async fn handle_wss_data_queue(
.next_async(mem, &mut queue_event)
.await
.map_err(BalloonError::AsyncAwait)?;
let index = avail_desc.index;
let mut reader = match Reader::new(mem.clone(), avail_desc) {
Ok(r) => r,
Err(e) => {
error!("balloon: failed to CREATE Reader: {}", e);
continue;
}
};
let mut reader = Reader::new(&avail_desc);
let wss = parse_balloon_wss(&mut reader);
// Closure to hold the mutex for handling a WSS-R command response
@ -731,7 +698,7 @@ async fn handle_wss_data_queue(
}
}
queue.add_used(mem, index, 0);
queue.add_used(mem, avail_desc, 0);
queue.trigger_interrupt(mem, &interrupt);
}
}
@ -1314,7 +1281,7 @@ mod tests {
.expect("create_descriptor_chain failed");
let mut addrs = Vec::new();
let res = handle_address_chain(None, chain, &memory, &mut |guest_address, len| {
let res = handle_address_chain(None, &chain, &mut |guest_address, len| {
addrs.push((guest_address, len));
});
assert!(res.is_ok());

View file

@ -79,7 +79,6 @@ use crate::virtio::device_constants::block::VIRTIO_BLK_T_OUT;
use crate::virtio::device_constants::block::VIRTIO_BLK_T_WRITE_ZEROES;
use crate::virtio::vhost::user::device::VhostBackendReqConnectionState;
use crate::virtio::DescriptorChain;
use crate::virtio::DescriptorError;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
@ -109,8 +108,6 @@ const DISCARD_SECTOR_ALIGNMENT: u32 = 128;
pub enum ExecuteError {
#[error("failed to copy ID string: {0}")]
CopyId(io::Error),
#[error("virtio descriptor error: {0}")]
Descriptor(DescriptorError),
#[error("failed to perform discard or write zeroes; sector={sector} num_sectors={num_sectors} flags={flags}; {ioerr:?}")]
DiscardWriteZeroes {
ioerr: Option<disk::Error>,
@ -156,7 +153,6 @@ impl ExecuteError {
fn status(&self) -> u8 {
match self {
ExecuteError::CopyId(_) => VIRTIO_BLK_S_IOERR,
ExecuteError::Descriptor(_) => VIRTIO_BLK_S_IOERR,
ExecuteError::DiscardWriteZeroes { .. } => VIRTIO_BLK_S_IOERR,
ExecuteError::Flush(_) => VIRTIO_BLK_S_IOERR,
ExecuteError::MissingStatus => VIRTIO_BLK_S_IOERR,
@ -234,15 +230,13 @@ impl DiskState {
}
async fn process_one_request(
avail_desc: DescriptorChain,
avail_desc: &DescriptorChain,
disk_state: Rc<AsyncMutex<DiskState>>,
flush_timer: Rc<RefCell<TimerAsync>>,
flush_timer_armed: Rc<RefCell<bool>>,
mem: &GuestMemory,
) -> result::Result<usize, ExecuteError> {
let mut reader =
Reader::new(mem.clone(), avail_desc.clone()).map_err(ExecuteError::Descriptor)?;
let mut writer = Writer::new(mem.clone(), avail_desc).map_err(ExecuteError::Descriptor)?;
let mut reader = Reader::new(avail_desc);
let mut writer = Writer::new(avail_desc);
// The last byte of the buffer is virtio_blk_req::status.
// Split it into a separate Writer so that status_writer is the final byte and
@ -287,11 +281,8 @@ pub async fn process_one_chain<I: SignalableInterrupt>(
flush_timer: Rc<RefCell<TimerAsync>>,
flush_timer_armed: Rc<RefCell<bool>>,
) {
let descriptor_index = avail_desc.index;
let len =
match process_one_request(avail_desc, disk_state, flush_timer, flush_timer_armed, &mem)
.await
{
match process_one_request(&avail_desc, disk_state, flush_timer, flush_timer_armed).await {
Ok(len) => len,
Err(e) => {
error!("block: failed to handle request: {}", e);
@ -300,7 +291,7 @@ pub async fn process_one_chain<I: SignalableInterrupt>(
};
let mut queue = queue.borrow_mut();
queue.add_used(&mem, descriptor_index, len as u32);
queue.add_used(&mem, avail_desc, len as u32);
queue.trigger_interrupt(&mem, interrupt);
}
@ -1281,7 +1272,7 @@ mod tests {
})),
}));
let fut = process_one_request(avail_desc, disk_state, flush_timer, flush_timer_armed, &mem);
let fut = process_one_request(&avail_desc, disk_state, flush_timer, flush_timer_armed);
ex.run_until(fut)
.expect("running executor failed")
@ -1344,7 +1335,7 @@ mod tests {
})),
}));
let fut = process_one_request(avail_desc, disk_state, flush_timer, flush_timer_armed, &mem);
let fut = process_one_request(&avail_desc, disk_state, flush_timer, flush_timer_armed);
ex.run_until(fut)
.expect("running executor failed")
@ -1409,7 +1400,7 @@ mod tests {
})),
}));
let fut = process_one_request(avail_desc, disk_state, flush_timer, flush_timer_armed, &mem);
let fut = process_one_request(&avail_desc, disk_state, flush_timer, flush_timer_armed);
ex.run_until(fut)
.expect("running executor failed")

View file

@ -90,16 +90,8 @@ fn handle_input<I: SignalableInterrupt>(
let desc = receive_queue
.peek(mem)
.ok_or(ConsoleError::RxDescriptorsExhausted)?;
let desc_index = desc.index;
// TODO(morg): Handle extra error cases as Err(ConsoleError) instead of just returning.
let mut writer = match Writer::new(mem.clone(), desc) {
Ok(w) => w,
Err(e) => {
error!("console: failed to create Writer: {}", e);
return Ok(());
}
};
let mut writer = Writer::new(&desc);
while writer.available_bytes() > 0 && !buffer.is_empty() {
let (buffer_front, buffer_back) = buffer.as_slices();
let buffer_chunk = if !buffer_front.is_empty() {
@ -115,7 +107,7 @@ fn handle_input<I: SignalableInterrupt>(
if bytes_written > 0 {
receive_queue.pop_peeked(mem);
receive_queue.add_used(mem, desc_index, bytes_written);
receive_queue.add_used(mem, desc, bytes_written);
receive_queue.trigger_interrupt(mem, interrupt);
}
@ -156,17 +148,11 @@ fn process_transmit_queue<I: SignalableInterrupt>(
) {
let mut needs_interrupt = false;
while let Some(avail_desc) = transmit_queue.pop(mem) {
let desc_index = avail_desc.index;
let reader = Reader::new(&avail_desc);
process_transmit_request(reader, output)
.unwrap_or_else(|e| error!("console: process_transmit_request failed: {}", e));
match Reader::new(mem.clone(), avail_desc) {
Ok(reader) => process_transmit_request(reader, output)
.unwrap_or_else(|e| error!("console: process_transmit_request failed: {}", e)),
Err(e) => {
error!("console: failed to create reader: {}", e);
}
};
transmit_queue.add_used(mem, desc_index, 0);
transmit_queue.add_used(mem, avail_desc, 0);
needs_interrupt = true;
}

View file

@ -0,0 +1,388 @@
// Copyright 2023 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
//! Virtqueue descriptor chain abstraction
#![deny(missing_docs)]
use std::sync::Arc;
use anyhow::bail;
use anyhow::Context;
use anyhow::Result;
use base::trace;
use base::Protection;
use cros_async::MemRegion;
use data_model::Le16;
use data_model::Le32;
use data_model::Le64;
use smallvec::SmallVec;
use sync::Mutex;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use zerocopy::AsBytes;
use zerocopy::FromBytes;
use crate::virtio::ipc_memory_mapper::ExportedRegion;
use crate::virtio::ipc_memory_mapper::IpcMemoryMapper;
use crate::virtio::memory_util::read_obj_from_addr_wrapper;
const VIRTQ_DESC_F_NEXT: u16 = 0x1;
const VIRTQ_DESC_F_WRITE: u16 = 0x2;
/// A single virtio split queue descriptor (`struct virtq_desc` in the spec).
#[derive(Copy, Clone, Debug, FromBytes, AsBytes)]
#[repr(C)]
pub struct Desc {
/// Guest address of memory described by this descriptor.
pub addr: Le64,
/// Length of this descriptor's memory region in bytes.
pub len: Le32,
/// `VIRTQ_DESC_F_*` flags for this descriptor.
pub flags: Le16,
/// Index of the next descriptor in the chain (only valid if `flags & VIRTQ_DESC_F_NEXT`).
pub next: Le16,
}
/// Type of access allowed for a single virtio descriptor within a descriptor chain.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DescriptorAccess {
/// Descriptor is readable by the device (written by the driver before putting the descriptor
/// chain on the available queue).
DeviceRead,
/// Descriptor is writable by the device (read by the driver after the device puts the
/// descriptor chain on the used queue).
DeviceWrite,
}
/// A virtio descriptor chain.
///
/// This is a low-level representation of the memory regions in a descriptor chain. Most code should
/// use [`virtio::Reader`](crate::virtio::Reader) and [`virtio::Writer`](crate::virtio::Writer)
/// rather than working with `DescriptorChain` memory regions directly.
pub struct DescriptorChain {
mem: GuestMemory,
/// Index into the descriptor table.
index: u16,
/// The memory regions that make up the descriptor chain.
readable_regions: SmallVec<[MemRegion; 2]>,
writable_regions: SmallVec<[MemRegion; 2]>,
/// The exported iommu regions of the descriptors. Non-empty iff iommu is enabled.
exported_regions: Vec<ExportedRegion>,
}
impl DescriptorChain {
/// Read all descriptors from `chain` into a new `DescriptorChain` instance.
///
/// This function validates the following properties of the descriptor chain:
/// * The chain contains at least one descriptor.
/// * Each descriptor has a non-zero length.
/// * Each descriptor's memory falls entirely within a contiguous region of `mem`.
/// * The total length of the descriptor chain data is representable in `u32`.
///
/// If these properties do not hold, `Err` will be returned.
///
/// # Arguments
///
/// * `chain` - Iterator that will be walked to retrieve all of the descriptors in the chain.
/// * `mem` - The [`GuestMemory`] backing the descriptor table and descriptor memory regions.
/// * `index` - The index of the first descriptor in the chain.
/// * `iommu` - An optional IOMMU to use for mapping descriptor memory regions. If `iommu` is
/// specified, the descriptor memory regions will be mapped for the lifetime of this
/// `DescriptorChain` object.
pub fn new(
mut chain: impl DescriptorChainIter,
mem: &GuestMemory,
index: u16,
iommu: Option<Arc<Mutex<IpcMemoryMapper>>>,
) -> Result<DescriptorChain> {
let mut desc_chain = DescriptorChain {
mem: mem.clone(),
index,
readable_regions: SmallVec::new(),
writable_regions: SmallVec::new(),
exported_regions: Vec::new(),
};
while let Some(desc) = chain.next()? {
if desc.len == 0 {
bail!("invalid zero-length descriptor");
}
if let Some(iommu) = iommu.as_ref() {
desc_chain.add_descriptor_iommu(desc, iommu)?;
} else {
desc_chain.add_descriptor(desc)?;
}
}
desc_chain
.validate_mem_regions()
.context("invalid descriptor chain memory regions")?;
trace!(
"DescriptorChain readable={} writable={}",
desc_chain.readable_mem_regions().len(),
desc_chain.writable_mem_regions().len(),
);
Ok(desc_chain)
}
fn add_descriptor(&mut self, desc: Descriptor) -> Result<()> {
let region = MemRegion {
offset: desc.address,
len: desc.len as usize,
};
match desc.access {
DescriptorAccess::DeviceRead => self.readable_regions.push(region),
DescriptorAccess::DeviceWrite => self.writable_regions.push(region),
}
Ok(())
}
fn add_descriptor_iommu(
&mut self,
desc: Descriptor,
iommu: &Arc<Mutex<IpcMemoryMapper>>,
) -> Result<()> {
let exported_region =
ExportedRegion::new(&self.mem, iommu.clone(), desc.address, u64::from(desc.len))
.context("failed to get mem regions")?;
let regions = exported_region.get_mem_regions();
let required_prot = match desc.access {
DescriptorAccess::DeviceRead => Protection::read(),
DescriptorAccess::DeviceWrite => Protection::write(),
};
if !regions.iter().all(|r| r.prot.allows(&required_prot)) {
bail!("missing RW permissions for descriptor");
}
let regions_iter = regions.iter().map(|r| MemRegion {
offset: r.gpa.offset(),
len: r.len as usize,
});
match desc.access {
DescriptorAccess::DeviceRead => self.readable_regions.extend(regions_iter),
DescriptorAccess::DeviceWrite => self.writable_regions.extend(regions_iter),
}
self.exported_regions.push(exported_region);
Ok(())
}
fn validate_mem_regions(&self) -> Result<()> {
let mut total_len: u32 = 0;
for r in self
.readable_mem_regions()
.iter()
.chain(self.writable_mem_regions().iter())
{
// This cast is safe because the virtio descriptor length field is u32.
let len = r.len as u32;
// Check that all the regions are totally contained in GuestMemory.
if !self.mem.is_valid_range(GuestAddress(r.offset), len.into()) {
bail!(
"descriptor address range out of bounds: addr={:#x} len={:#x}",
r.offset,
r.len
);
}
// Verify that summing the descriptor sizes does not overflow.
// This can happen if a driver tricks a device into reading/writing more data than
// fits in a `u32`.
total_len = total_len
.checked_add(len)
.context("descriptor chain length overflow")?;
}
if total_len == 0 {
bail!("invalid zero-length descriptor chain");
}
Ok(())
}
/// Returns the device-readable memory regions of this descriptor chain.
///
/// Each `MemRegion` is a contiguous range of guest physical memory.
pub fn readable_mem_regions(&self) -> &[MemRegion] {
&self.readable_regions
}
/// Returns the device-writable memory regions of this descriptor chain.
///
/// Each `MemRegion` is a contiguous range of guest physical memory.
pub fn writable_mem_regions(&self) -> &[MemRegion] {
&self.writable_regions
}
/// Returns the IOMMU memory mapper regions for the memory regions in this chain.
///
/// Empty if IOMMU is not enabled.
pub fn exported_regions(&self) -> &[ExportedRegion] {
&self.exported_regions
}
/// Returns a reference to the [`GuestMemory`] instance.
pub fn mem(&self) -> &GuestMemory {
&self.mem
}
/// Returns the index of the first descriptor in the chain.
pub fn index(&self) -> u16 {
self.index
}
}
/// A single descriptor within a [`DescriptorChain`].
pub struct Descriptor {
/// Guest memory address of this descriptor.
/// If IOMMU is enabled, this is an IOVA, which must be translated via the IOMMU to get a
/// guest physical address.
pub address: u64,
/// Length of the descriptor in bytes.
pub len: u32,
/// Whether this descriptor should be treated as writable or readable by the device.
pub access: DescriptorAccess,
}
/// Iterator over the descriptors of a descriptor chain.
pub trait DescriptorChainIter {
/// Return the next descriptor in the chain, or `None` if there are no more descriptors.
fn next(&mut self) -> Result<Option<Descriptor>>;
}
/// Iterator over the descriptors of a split virtqueue descriptor chain.
pub struct SplitDescriptorChain<'m, 'd> {
/// Current descriptor index within `desc_table`, or `None` if the iterator is exhausted.
index: Option<u16>,
/// Number of descriptors returned by the iterator already.
/// If `count` reaches `queue_size`, the chain has a loop and is therefore invalid.
count: u16,
queue_size: u16,
/// If `writable` is true, a writable descriptor has already been encountered.
/// Valid descriptor chains must consist of readable descriptors followed by writable
/// descriptors.
writable: bool,
mem: &'m GuestMemory,
desc_table: GuestAddress,
exported_desc_table: Option<&'d ExportedRegion>,
}
impl<'m, 'd> SplitDescriptorChain<'m, 'd> {
/// Construct a new iterator over a split virtqueue descriptor chain.
///
/// # Arguments
/// * `mem` - The [`GuestMemory`] containing the descriptor chain.
/// * `desc_table` - Guest physical address of the descriptor table.
/// * `queue_size` - Total number of entries in the descriptor table.
/// * `index` - The index of the first descriptor in the chain.
/// * `exported_desc_table` - If specified, contains the IOMMU mapping of the descriptor table
/// region.
pub fn new(
mem: &'m GuestMemory,
desc_table: GuestAddress,
queue_size: u16,
index: u16,
exported_desc_table: Option<&'d ExportedRegion>,
) -> SplitDescriptorChain<'m, 'd> {
trace!("starting split descriptor chain head={index}");
SplitDescriptorChain {
index: Some(index),
count: 0,
queue_size,
writable: false,
mem,
desc_table,
exported_desc_table,
}
}
}
impl DescriptorChainIter for SplitDescriptorChain<'_, '_> {
fn next(&mut self) -> Result<Option<Descriptor>> {
let index = match self.index {
Some(index) => index,
None => return Ok(None),
};
if index >= self.queue_size {
bail!(
"out of bounds descriptor index {} for queue size {}",
index,
self.queue_size
);
}
if self.count >= self.queue_size {
bail!("descriptor chain loop detected");
}
self.count += 1;
let desc_addr = self
.desc_table
.checked_add((index as u64) * 16)
.context("integer overflow")?;
let desc =
read_obj_from_addr_wrapper::<Desc>(self.mem, self.exported_desc_table, desc_addr)
.with_context(|| format!("failed to read desc {:#x}", desc_addr.offset()))?;
let address: u64 = desc.addr.into();
let len: u32 = desc.len.into();
let flags: u16 = desc.flags.into();
let next: u16 = desc.next.into();
trace!("{index:5}: addr={address:#016x} len={len:#08x} flags={flags:#x}");
let unexpected_flags = flags & !(VIRTQ_DESC_F_WRITE | VIRTQ_DESC_F_NEXT);
if unexpected_flags != 0 {
bail!("unexpected flags in descriptor {index}: {unexpected_flags:#x}")
}
let access = if flags & VIRTQ_DESC_F_WRITE != 0 {
DescriptorAccess::DeviceWrite
} else {
DescriptorAccess::DeviceRead
};
if access == DescriptorAccess::DeviceRead && self.writable {
bail!("invalid device-readable descriptor following writable descriptors");
} else if access == DescriptorAccess::DeviceWrite {
self.writable = true;
}
self.index = if flags & VIRTQ_DESC_F_NEXT != 0 {
Some(next)
} else {
None
};
Ok(Some(Descriptor {
address,
len,
access,
}))
}
}
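
Taken together, the pieces compose as in this sketch (modeled on
create_descriptor_chain() in descriptor_utils.rs below; the helper
function and its names are hypothetical):

  use vm_memory::{GuestAddress, GuestMemory};

  use crate::virtio::{DescriptorChain, Reader, SplitDescriptorChain, Writer};

  // Hypothetical: walk one split-queue chain starting at `index`.
  fn walk_chain(
      mem: &GuestMemory,
      desc_table: GuestAddress,
      queue_size: u16,
      index: u16,
  ) -> anyhow::Result<()> {
      let walker = SplitDescriptorChain::new(mem, desc_table, queue_size, index, None);
      let chain = DescriptorChain::new(walker, mem, index, None)?;
      // Infallible now that the whole chain has been validated.
      let mut _reader = Reader::new(&chain);
      let mut _writer = Writer::new(&chain);
      Ok(())
  }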

View file

@ -6,13 +6,11 @@ mod sys;
use std::borrow::Cow;
use std::cmp;
use std::convert::TryInto;
use std::io;
use std::io::Write;
use std::iter::FromIterator;
use std::marker::PhantomData;
use std::ptr::copy_nonoverlapping;
use std::result;
use std::sync::Arc;
use anyhow::Context;
@ -23,12 +21,9 @@ use data_model::zerocopy_from_reader;
use data_model::Le16;
use data_model::Le32;
use data_model::Le64;
use data_model::VolatileMemoryError;
use data_model::VolatileSlice;
use disk::AsyncDisk;
use remain::sorted;
use smallvec::SmallVec;
use thiserror::Error;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use zerocopy::AsBytes;
@ -36,38 +31,36 @@ use zerocopy::FromBytes;
use super::DescriptorChain;
use crate::virtio::ipc_memory_mapper::ExportedRegion;
#[sorted]
#[derive(Error, Debug)]
pub enum Error {
#[error("the combined length of all the buffers in a `DescriptorChain` would overflow")]
DescriptorChainOverflow,
#[error("descriptor guest memory error: {0}")]
GuestMemoryError(vm_memory::GuestMemoryError),
#[error("invalid descriptor chain: {0:#}")]
InvalidChain(anyhow::Error),
#[error("descriptor I/O error: {0}")]
IoError(io::Error),
#[error("`DescriptorChain` split is out of bounds: {0}")]
SplitOutOfBounds(usize),
#[error("volatile memory error: {0}")]
VolatileMemoryError(VolatileMemoryError),
}
pub type Result<T> = result::Result<T, Error>;
use crate::virtio::SplitDescriptorChain;
#[derive(Clone)]
struct DescriptorChainRegions {
regions: DescriptorChainMemRegions,
regions: SmallVec<[cros_async::MemRegion; 16]>,
// For virtio devices that operate on IOVAs rather than guest physical
// addresses, the IOVA regions must be exported from virtio-iommu to get
// the underlying memory regions. It is only valid for the virtio device
// to access those memory regions while they remain exported, so maintain
// references to the exported regions until the descriptor chain is
// dropped.
_exported_regions: Vec<ExportedRegion>,
current: usize,
bytes_consumed: usize,
}
impl DescriptorChainRegions {
fn new(regions: &[cros_async::MemRegion], exported_regions: &[ExportedRegion]) -> Self {
DescriptorChainRegions {
regions: regions.into(),
_exported_regions: exported_regions.into(),
current: 0,
bytes_consumed: 0,
}
}
fn available_bytes(&self) -> usize {
// This is guaranteed not to overflow because the total length of the chain
// is checked during all creations of `DescriptorChainRegions` (see
// `Reader::new()` and `Writer::new()`).
// This is guaranteed not to overflow because the total length of the chain is checked
// during all creations of `DescriptorChain` (see `DescriptorChain::new()`).
self.get_remaining_regions()
.iter()
.fold(0usize, |count, region| count + region.len)
@ -82,7 +75,7 @@ impl DescriptorChainRegions {
/// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls
/// to `consume` will return the same data.
fn get_remaining_regions(&self) -> &[MemRegion] {
&self.regions.regions[self.current..]
&self.regions[self.current..]
}
/// Returns all the remaining buffers in the `DescriptorChain` as `VolatileSlice`s of the given
@ -153,7 +146,7 @@ impl DescriptorChainRegions {
// `get_remaining` here because then the compiler complains that `self.current` is already
// borrowed and doesn't allow us to modify it. We also need to borrow the iovecs mutably.
let current = self.current;
for region in &mut self.regions.regions[current..] {
for region in &mut self.regions[current..] {
if count == 0 {
break;
}
@ -186,7 +179,7 @@ impl DescriptorChainRegions {
let mut rem = offset;
let mut end = self.current;
for region in &mut self.regions.regions[self.current..] {
for region in &mut self.regions[self.current..] {
if rem <= region.len {
region.len = rem;
break;
@ -196,7 +189,7 @@ impl DescriptorChainRegions {
rem -= region.len;
}
self.regions.regions.truncate(end + 1);
self.regions.truncate(end + 1);
other
}
@ -233,72 +226,16 @@ impl<'a, T: FromBytes> Iterator for ReaderIterator<'a, T> {
}
}
#[derive(Clone)]
pub struct DescriptorChainMemRegions {
pub regions: SmallVec<[cros_async::MemRegion; 16]>,
// For virtio devices that operate on IOVAs rather than guest physical
// addresses, the IOVA regions must be exported from virtio-iommu to get
// the underlying memory regions. It is only valid for the virtio device
// to access those memory regions while they remain exported, so maintain
// references to the exported regions until the descriptor chain is
// dropped.
_exported_regions: Option<Vec<ExportedRegion>>,
}
/// Get all the mem regions from a `DescriptorChain` iterator, regardless if the `DescriptorChain`
/// contains GPAs (guest physical address), or IOVAs (io virtual address). IOVAs will
/// be translated to GPAs via IOMMU.
pub fn get_mem_regions<I>(mem: &GuestMemory, vals: I) -> Result<DescriptorChainMemRegions>
where
I: Iterator<Item = DescriptorChain>,
{
let mut total_len: usize = 0;
let mut regions = SmallVec::new();
let mut exported_regions: Option<Vec<ExportedRegion>> = None;
// TODO(jstaron): Update this code to take the indirect descriptors into account.
for desc in vals {
// Verify that summing the descriptor sizes does not overflow.
// This can happen if a driver tricks a device into reading/writing more data than
// fits in a `usize`.
total_len = total_len
.checked_add(desc.len as usize)
.ok_or(Error::DescriptorChainOverflow)?;
let (desc_regions, exported) = desc.into_mem_regions();
for r in desc_regions {
// Check that all the regions are totally contained in GuestMemory.
mem.get_slice_at_addr(r.gpa, r.len.try_into().expect("u32 doesn't fit in usize"))
.map_err(Error::GuestMemoryError)?;
regions.push(cros_async::MemRegion {
offset: r.gpa.offset(),
len: r.len.try_into().expect("u32 doesn't fit in usize"),
});
}
if let Some(exported) = exported {
exported_regions.get_or_insert(vec![]).push(exported);
}
}
Ok(DescriptorChainMemRegions {
regions,
_exported_regions: exported_regions,
})
}
impl Reader {
/// Construct a new Reader wrapper over `desc_chain`.
pub fn new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Reader> {
let regions = get_mem_regions(&mem, desc_chain.into_iter().readable())?;
Ok(Reader {
mem,
regions: DescriptorChainRegions {
regions,
current: 0,
bytes_consumed: 0,
},
})
pub fn new(desc_chain: &DescriptorChain) -> Reader {
Reader {
mem: desc_chain.mem().clone(),
regions: DescriptorChainRegions::new(
desc_chain.readable_mem_regions(),
desc_chain.exported_regions(),
),
}
}
/// Reads an object from the descriptor chain buffer.
@ -553,16 +490,14 @@ pub struct Writer {
impl Writer {
/// Construct a new Writer wrapper over `desc_chain`.
pub fn new(mem: GuestMemory, desc_chain: DescriptorChain) -> Result<Writer> {
let regions = get_mem_regions(&mem, desc_chain.into_iter().writable())?;
Ok(Writer {
mem,
regions: DescriptorChainRegions {
regions,
current: 0,
bytes_consumed: 0,
},
})
pub fn new(desc_chain: &DescriptorChain) -> Writer {
Writer {
mem: desc_chain.mem().clone(),
regions: DescriptorChainRegions::new(
desc_chain.writable_mem_regions(),
desc_chain.exported_regions(),
),
}
}
/// Writes an object to the descriptor chain buffer.
@ -807,7 +742,7 @@ pub fn create_descriptor_chain(
mut buffers_start_addr: GuestAddress,
descriptors: Vec<(DescriptorType, u32)>,
spaces_between_regions: u32,
) -> Result<DescriptorChain> {
) -> anyhow::Result<DescriptorChain> {
let descriptors_len = descriptors.len();
for (index, (type_, size)) in descriptors.into_iter().enumerate() {
let mut flags = 0;
@ -829,20 +764,18 @@ pub fn create_descriptor_chain(
let offset = size + spaces_between_regions;
buffers_start_addr = buffers_start_addr
.checked_add(offset as u64)
.context("Invalid buffers_start_addr)")
.map_err(Error::InvalidChain)?;
.context("Invalid buffers_start_addr)")?;
let _ = memory.write_obj_at_addr(
desc,
descriptor_array_addr
.checked_add(index as u64 * std::mem::size_of::<virtq_desc>() as u64)
.context("Invalid descriptor_array_addr")
.map_err(Error::InvalidChain)?,
.context("Invalid descriptor_array_addr")?,
);
}
DescriptorChain::checked_new(memory, descriptor_array_addr, 0x100, 0, 0, None, None)
.map_err(Error::InvalidChain)
let chain = SplitDescriptorChain::new(memory, descriptor_array_addr, 0x100, 0, None);
DescriptorChain::new(chain, memory, 0, None)
}
#[cfg(test)]
@ -874,7 +807,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
let mut reader = Reader::new(&chain);
assert_eq!(reader.available_bytes(), 106);
assert_eq!(reader.bytes_read(), 0);
@ -915,7 +848,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
let mut writer = Writer::new(&chain);
assert_eq!(writer.available_bytes(), 106);
assert_eq!(writer.bytes_written(), 0);
@ -951,7 +884,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
let mut reader = Reader::new(&chain);
assert_eq!(reader.available_bytes(), 0);
assert_eq!(reader.bytes_read(), 0);
@ -976,7 +909,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
let mut writer = Writer::new(&chain);
assert_eq!(writer.available_bytes(), 0);
assert_eq!(writer.bytes_written(), 0);
@ -1002,7 +935,7 @@ mod tests {
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
let mut reader = Reader::new(&chain);
// Open a file in read-only mode so writes to it trigger an I/O error.
let device_file = if cfg!(windows) { "NUL" } else { "/dev/zero" };
@ -1033,7 +966,7 @@ mod tests {
)
.expect("create_descriptor_chain failed");
let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
let mut writer = Writer::new(&chain);
let mut file = tempfile().unwrap();
@ -1069,9 +1002,8 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut reader =
Reader::new(memory.clone(), chain.clone()).expect("failed to create Reader");
let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
let mut reader = Reader::new(&chain);
let mut writer = Writer::new(&chain);
assert_eq!(reader.bytes_read(), 0);
assert_eq!(writer.bytes_written(), 0);
@ -1114,8 +1046,7 @@ mod tests {
123,
)
.expect("create_descriptor_chain failed");
let mut writer =
Writer::new(memory.clone(), chain_writer).expect("failed to create Writer");
let mut writer = Writer::new(&chain_writer);
writer
.write_obj(secret)
.expect("write_obj should not fail here");
@ -1129,7 +1060,7 @@ mod tests {
123,
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory, chain_reader).expect("failed to create Reader");
let mut reader = Reader::new(&chain_reader);
match reader.read_obj::<Le32>() {
Err(_) => panic!("read_obj should not fail here"),
Ok(read_secret) => assert_eq!(read_secret, secret),
@ -1152,7 +1083,7 @@ mod tests {
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
let mut reader = Reader::new(&chain);
let mut buf = vec![0; 1024];
@ -1187,7 +1118,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
let mut reader = Reader::new(&chain);
let other = reader.split_at(32);
assert_eq!(reader.available_bytes(), 32);
@ -1216,7 +1147,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
let mut reader = Reader::new(&chain);
let other = reader.split_at(24);
assert_eq!(reader.available_bytes(), 24);
@ -1245,7 +1176,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
let mut reader = Reader::new(&chain);
let other = reader.split_at(128);
assert_eq!(reader.available_bytes(), 128);
@ -1274,7 +1205,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
let mut reader = Reader::new(&chain);
let other = reader.split_at(0);
assert_eq!(reader.available_bytes(), 0);
@ -1303,7 +1234,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
let mut reader = Reader::new(&chain);
let other = reader.split_at(256);
assert_eq!(
@ -1328,7 +1259,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory, chain).expect("failed to create Reader");
let mut reader = Reader::new(&chain);
let mut buf = vec![0u8; 64];
assert_eq!(
@ -1352,7 +1283,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut writer = Writer::new(memory, chain).expect("failed to create Writer");
let mut writer = Writer::new(&chain);
let buf = vec![0xdeu8; 64];
assert_eq!(
@ -1381,7 +1312,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut writer = Writer::new(memory.clone(), write_chain).expect("failed to create Writer");
let mut writer = Writer::new(&write_chain);
writer
.consume(vs.clone())
.expect("failed to consume() a vector");
@ -1394,7 +1325,7 @@ mod tests {
0,
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory, read_chain).expect("failed to create Reader");
let mut reader = Reader::new(&read_chain);
let vs_read = reader
.collect::<io::Result<Vec<Le64>>, _>()
.expect("failed to collect() values");
@ -1427,7 +1358,7 @@ mod tests {
let Reader {
mem: _,
mut regions,
} = Reader::new(memory, chain).expect("failed to create Reader");
} = Reader::new(&chain);
let drain = regions
.get_remaining_regions_with_count(::std::usize::MAX)
@ -1484,7 +1415,7 @@ mod tests {
let Reader {
mem: _,
mut regions,
} = Reader::new(memory.clone(), chain).expect("failed to create Reader");
} = Reader::new(&chain);
let drain = regions
.get_remaining_with_count(&memory, ::std::usize::MAX)

View file

@ -31,7 +31,7 @@ mod tests {
)
.expect("create_descriptor_chain failed");
let mut reader = Reader::new(memory.clone(), chain).expect("failed to create Reader");
let mut reader = Reader::new(&chain);
// TODO(b/235104127): Potentially use tempfile for ro_file so that this
// test can run on Windows.
@ -69,7 +69,7 @@ mod tests {
)
.expect("create_descriptor_chain failed");
let mut writer = Writer::new(memory.clone(), chain).expect("failed to create Writer");
let mut writer = Writer::new(&chain);
let file = tempfile().expect("failed to create temp file");

View file

@ -34,7 +34,6 @@ use crate::pci::PciCapability;
use crate::virtio::copy_config;
use crate::virtio::device_constants::fs::FS_MAX_TAG_LEN;
use crate::virtio::device_constants::fs::QUEUE_SIZE;
use crate::virtio::DescriptorError;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::virtio::PciCapabilityType;
@ -75,9 +74,6 @@ pub enum Error {
/// Failed to get the securebits for the worker thread.
#[error("failed to get securebits for the worker thread: {0}")]
GetSecurebits(SysError),
/// The `len` field of the header is too small.
#[error("DescriptorChain is invalid: {0}")]
InvalidDescriptorChain(DescriptorError),
/// A request is missing readable descriptors.
#[error("request does not have any readable descriptors")]
NoReadableDescriptors,

View file

@ -159,14 +159,12 @@ pub fn process_fs_queue<I: SignalableInterrupt, F: FileSystem + Sync>(
) -> Result<()> {
let mapper = Mapper::new(Arc::clone(tube), slot);
while let Some(avail_desc) = queue.pop(mem) {
let reader =
Reader::new(mem.clone(), avail_desc.clone()).map_err(Error::InvalidDescriptorChain)?;
let writer =
Writer::new(mem.clone(), avail_desc.clone()).map_err(Error::InvalidDescriptorChain)?;
let reader = Reader::new(&avail_desc);
let writer = Writer::new(&avail_desc);
let total = server.handle_message(reader, writer, &mapper)?;
queue.add_used(mem, avail_desc.index, total as u32);
queue.add_used(mem, avail_desc, total as u32);
queue.trigger_interrupt(mem, interrupt);
}

View file

@ -129,7 +129,7 @@ enum VirtioGpuRing {
struct FenceDescriptor {
ring: VirtioGpuRing,
fence_id: u64,
index: u16,
desc_chain: DescriptorChain,
len: u32,
}
@ -141,7 +141,7 @@ pub struct FenceState {
pub trait QueueReader {
fn pop(&self, mem: &GuestMemory) -> Option<DescriptorChain>;
fn add_used(&self, mem: &GuestMemory, desc_index: u16, len: u32);
fn add_used(&self, mem: &GuestMemory, desc_chain: DescriptorChain, len: u32);
fn signal_used(&self, mem: &GuestMemory);
}
@ -164,8 +164,8 @@ impl QueueReader for LocalQueueReader {
self.queue.borrow_mut().pop(mem)
}
fn add_used(&self, mem: &GuestMemory, desc_index: u16, len: u32) {
self.queue.borrow_mut().add_used(mem, desc_index, len)
fn add_used(&self, mem: &GuestMemory, desc_chain: DescriptorChain, len: u32) {
self.queue.borrow_mut().add_used(mem, desc_chain, len)
}
fn signal_used(&self, mem: &GuestMemory) {
@ -195,8 +195,8 @@ impl QueueReader for SharedQueueReader {
self.queue.lock().pop(mem)
}
fn add_used(&self, mem: &GuestMemory, desc_index: u16, len: u32) {
self.queue.lock().add_used(mem, desc_index, len)
fn add_used(&self, mem: &GuestMemory, desc_chain: DescriptorChain, len: u32) {
self.queue.lock().add_used(mem, desc_chain, len)
}
fn signal_used(&self, mem: &GuestMemory) {
@ -276,14 +276,20 @@ where
};
let mut fence_state = fence_state.lock();
fence_state.descs.retain(|f_desc| {
if f_desc.ring == ring && f_desc.fence_id <= completed_fence.fence_id {
ctrl_queue.add_used(&mem, f_desc.index, f_desc.len);
// TODO(dverkamp): use `drain_filter()` when it is stabilized
let mut i = 0;
while i < fence_state.descs.len() {
if fence_state.descs[i].ring == ring
&& fence_state.descs[i].fence_id <= completed_fence.fence_id
{
let completed_desc = fence_state.descs.remove(i);
ctrl_queue.add_used(&mem, completed_desc.desc_chain, completed_desc.len);
signal = true;
return false;
} else {
i += 1;
}
true
});
}
// Update the last completed fence for this context
fence_state
.completed_fences
@ -297,7 +303,7 @@ where
}
pub struct ReturnDescriptor {
pub index: u16,
pub desc_chain: DescriptorChain,
pub len: u32,
}
@ -643,23 +649,11 @@ impl Frontend {
pub fn process_queue(&mut self, mem: &GuestMemory, queue: &dyn QueueReader) -> bool {
let mut signal_used = false;
while let Some(desc) = queue.pop(mem) {
match (
Reader::new(mem.clone(), desc.clone()),
Writer::new(mem.clone(), desc.clone()),
) {
(Ok(mut reader), Ok(mut writer)) => {
if let Some(ret_desc) =
self.process_descriptor(mem, desc.index, &mut reader, &mut writer)
{
queue.add_used(mem, ret_desc.index, ret_desc.len);
signal_used = true;
}
}
(_, Err(e)) | (Err(e), _) => {
debug!("invalid descriptor: {}", e);
queue.add_used(mem, desc.index, 0);
signal_used = true;
}
let mut reader = Reader::new(&desc);
let mut writer = Writer::new(&desc);
if let Some(ret_desc) = self.process_descriptor(mem, desc, &mut reader, &mut writer) {
queue.add_used(mem, ret_desc.desc_chain, ret_desc.len);
signal_used = true;
}
}
@ -669,7 +663,7 @@ impl Frontend {
fn process_descriptor(
&mut self,
mem: &GuestMemory,
desc_index: u16,
desc_chain: DescriptorChain,
reader: &mut Reader,
writer: &mut Writer,
) -> Option<ReturnDescriptor> {
@ -741,7 +735,7 @@ impl Frontend {
fence_state.descs.push(FenceDescriptor {
ring,
fence_id,
index: desc_index,
desc_chain,
len,
});
@ -751,10 +745,7 @@ impl Frontend {
// No fence (or already completed fence), respond now.
}
Some(ReturnDescriptor {
index: desc_index,
len,
})
Some(ReturnDescriptor { desc_chain, len })
}
pub fn return_cursor(&mut self) -> Option<ReturnDescriptor> {
@ -979,7 +970,8 @@ impl Worker {
// All cursor commands go first because they have higher priority.
while let Some(desc) = self.state.return_cursor() {
self.cursor_queue.add_used(&self.mem, desc.index, desc.len);
self.cursor_queue
.add_used(&self.mem, desc.desc_chain, desc.len);
signal_used_cursor = true;
}

View file

@ -36,7 +36,6 @@ pub use super::super::device_constants::gpu::VIRTIO_GPU_F_RESOURCE_BLOB;
pub use super::super::device_constants::gpu::VIRTIO_GPU_F_RESOURCE_SYNC;
pub use super::super::device_constants::gpu::VIRTIO_GPU_F_RESOURCE_UUID;
pub use super::super::device_constants::gpu::VIRTIO_GPU_F_VIRGL;
use super::super::DescriptorError;
use super::edid::EdidBytes;
use super::Reader;
use super::Writer;
@ -598,15 +597,6 @@ pub enum GpuCommandDecodeError {
/// An I/O error occurred.
#[error("an I/O error occurred: {0}")]
IO(io::Error),
/// The command referenced an inaccessible area of memory.
#[error("command referenced an inaccessible area of memory: {0}")]
Memory(DescriptorError),
}
impl From<DescriptorError> for GpuCommandDecodeError {
fn from(e: DescriptorError) -> GpuCommandDecodeError {
GpuCommandDecodeError::Memory(e)
}
}
impl From<io::Error> for GpuCommandDecodeError {
@ -810,9 +800,6 @@ pub enum GpuResponseEncodeError {
/// An I/O error occurred.
#[error("an I/O error occurred: {0}")]
IO(io::Error),
/// The response was encoded to an inaccessible area of memory.
#[error("response was encoded to an inaccessible area of memory: {0}")]
Memory(DescriptorError),
/// More displays than are valid were in a `OkDisplayInfo`.
#[error("{0} is more displays than are valid")]
TooManyDisplays(usize),
@ -821,12 +808,6 @@ pub enum GpuResponseEncodeError {
TooManyPlanes(usize),
}
impl From<DescriptorError> for GpuResponseEncodeError {
fn from(e: DescriptorError) -> GpuResponseEncodeError {
GpuResponseEncodeError::Memory(e)
}
}
impl From<io::Error> for GpuResponseEncodeError {
fn from(e: io::Error) -> GpuResponseEncodeError {
GpuResponseEncodeError::IO(e)

View file

@ -39,7 +39,6 @@ use self::event_source::SocketEventSource;
use super::copy_config;
use super::virtio_device::Error as VirtioError;
use super::DescriptorChain;
use super::DescriptorError;
use super::DeviceType;
use super::Interrupt;
use super::Queue;
@ -57,9 +56,6 @@ const QUEUE_SIZES: &[u16] = &[EVENT_QUEUE_SIZE, STATUS_QUEUE_SIZE];
#[sorted]
#[derive(Error, Debug)]
pub enum InputError {
// Virtio descriptor error
#[error("virtio descriptor error: {0}")]
Descriptor(DescriptorError),
// Failed to get axis information of event device
#[error("failed to get axis information of event device: {0}")]
EvdevAbsInfoError(base::Error),
@ -352,12 +348,8 @@ struct Worker<T: EventSource> {
impl<T: EventSource> Worker<T> {
// Fills a virtqueue with events from the source. Returns the number of bytes written.
fn fill_event_virtqueue(
event_source: &mut T,
avail_desc: DescriptorChain,
mem: &GuestMemory,
) -> Result<usize> {
let mut writer = Writer::new(mem.clone(), avail_desc).map_err(InputError::Descriptor)?;
fn fill_event_virtqueue(event_source: &mut T, avail_desc: &DescriptorChain) -> Result<usize> {
let mut writer = Writer::new(avail_desc);
while writer.available_bytes() >= virtio_input_event::SIZE {
if let Some(evt) = event_source.pop_available_event() {
@ -381,25 +373,17 @@ impl<T: EventSource> Worker<T> {
break;
}
Some(avail_desc) => {
let avail_desc_index = avail_desc.index;
let bytes_written =
match Worker::fill_event_virtqueue(&mut self.event_source, &avail_desc) {
Ok(count) => count,
Err(e) => {
error!("Input: failed to send events to guest: {}", e);
break;
}
};
let bytes_written = match Worker::fill_event_virtqueue(
&mut self.event_source,
avail_desc,
&self.guest_memory,
) {
Ok(count) => count,
Err(e) => {
error!("Input: failed to send events to guest: {}", e);
break;
}
};
self.event_queue.add_used(
&self.guest_memory,
avail_desc_index,
bytes_written as u32,
);
self.event_queue
.add_used(&self.guest_memory, avail_desc, bytes_written as u32);
needs_interrupt = true;
}
}
@ -409,12 +393,8 @@ impl<T: EventSource> Worker<T> {
}
// Sends events from the guest to the source. Returns the number of bytes read.
fn read_event_virtqueue(
avail_desc: DescriptorChain,
event_source: &mut T,
mem: &GuestMemory,
) -> Result<usize> {
let mut reader = Reader::new(mem.clone(), avail_desc).map_err(InputError::Descriptor)?;
fn read_event_virtqueue(avail_desc: &DescriptorChain, event_source: &mut T) -> Result<usize> {
let mut reader = Reader::new(avail_desc);
while reader.available_bytes() >= virtio_input_event::SIZE {
let evt: virtio_input_event = reader.read_obj().map_err(InputError::ReadQueue)?;
event_source.send_event(&evt)?;
@ -426,13 +406,8 @@ impl<T: EventSource> Worker<T> {
fn process_status_queue(&mut self) -> Result<bool> {
let mut needs_interrupt = false;
while let Some(avail_desc) = self.status_queue.pop(&self.guest_memory) {
let avail_desc_index = avail_desc.index;
let bytes_read = match Worker::read_event_virtqueue(
avail_desc,
&mut self.event_source,
&self.guest_memory,
) {
let bytes_read = match Worker::read_event_virtqueue(&avail_desc, &mut self.event_source)
{
Ok(count) => count,
Err(e) => {
error!("Input: failed to read events from virtqueue: {}", e);
@ -441,7 +416,7 @@ impl<T: EventSource> Worker<T> {
};
self.status_queue
.add_used(&self.guest_memory, avail_desc_index, bytes_read as u32);
.add_used(&self.guest_memory, avail_desc, bytes_read as u32);
needs_interrupt = true;
}

View file

@ -64,7 +64,6 @@ use crate::virtio::iommu::ipc_memory_mapper::*;
use crate::virtio::iommu::memory_mapper::*;
use crate::virtio::iommu::protocol::*;
use crate::virtio::DescriptorChain;
use crate::virtio::DescriptorError;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
@ -129,12 +128,8 @@ type Result<T> = result::Result<T, IommuError>;
pub enum IommuError {
#[error("async executor error: {0}")]
AsyncExec(AsyncError),
#[error("failed to create reader: {0}")]
CreateReader(DescriptorError),
#[error("failed to create wait context: {0}")]
CreateWaitContext(SysError),
#[error("failed to create writer: {0}")]
CreateWriter(DescriptorError),
#[error("failed getting host address: {0}")]
GetHostAddress(GuestMemoryError),
#[error("failed to read from guest address: {0}")]
@ -549,10 +544,8 @@ impl State {
&mut self,
avail_desc: &DescriptorChain,
) -> Result<(usize, Option<EventAsync>)> {
let mut reader =
Reader::new(self.mem.clone(), avail_desc.clone()).map_err(IommuError::CreateReader)?;
let mut writer =
Writer::new(self.mem.clone(), avail_desc.clone()).map_err(IommuError::CreateWriter)?;
let mut reader = Reader::new(avail_desc);
let mut writer = Writer::new(avail_desc);
// at least we need space to write VirtioIommuReqTail
if writer.available_bytes() < size_of::<virtio_iommu_req_tail>() {
@ -602,7 +595,6 @@ async fn request_queue<I: SignalableInterrupt>(
.next_async(&mem, &mut queue_event)
.await
.map_err(IommuError::ReadAsyncDesc)?;
let desc_index = avail_desc.index;
let (len, fault_resolved_event) = match state.borrow_mut().execute_request(&avail_desc) {
Ok(res) => res,
@ -624,7 +616,7 @@ async fn request_queue<I: SignalableInterrupt>(
debug!("iommu fault resolved");
}
queue.add_used(&mem, desc_index, len as u32);
queue.add_used(&mem, avail_desc, len as u32);
queue.trigger_interrupt(&mem, &interrupt);
}
}

View file

@ -8,6 +8,7 @@ mod async_device;
mod async_utils;
#[cfg(feature = "balloon")]
mod balloon;
mod descriptor_chain;
mod descriptor_utils;
pub mod device_constants;
mod input;
@ -41,7 +42,10 @@ pub mod vsock;
pub use self::balloon::*;
pub use self::block::*;
pub use self::console::*;
pub use self::descriptor_utils::Error as DescriptorError;
pub use self::descriptor_chain::Desc;
pub use self::descriptor_chain::DescriptorChain;
pub use self::descriptor_chain::DescriptorChainIter;
pub use self::descriptor_chain::SplitDescriptorChain;
pub use self::descriptor_utils::*;
#[cfg(feature = "gpu")]
pub use self::gpu::*;

View file

@ -43,7 +43,6 @@ use zerocopy::AsBytes;
use zerocopy::FromBytes;
use super::copy_config;
use super::DescriptorError;
use super::DeviceType;
use super::Interrupt;
use super::Queue;
@ -75,15 +74,15 @@ pub enum NetError {
/// Creating WaitContext failed.
#[error("failed to create wait context: {0}")]
CreateWaitContext(SysError),
/// Descriptor chain was invalid.
#[error("failed to valildate descriptor chain: {0}")]
DescriptorChain(DescriptorError),
/// Adding the tap descriptor back to the event context failed.
#[error("failed to add tap trigger to event context: {0}")]
EventAddTap(SysError),
/// Removing the tap descriptor from the event context failed.
#[error("failed to remove tap trigger from event context: {0}")]
EventRemoveTap(SysError),
/// Invalid control command
#[error("invalid control command")]
InvalidCmd,
/// Error reading data from control queue.
#[error("failed to read control message data: {0}")]
ReadCtrlData(io::Error),
@ -116,6 +115,9 @@ pub enum NetError {
/// Setting tap netmask failed.
#[error("failed to set tap netmask: {0}")]
TapSetNetmask(TapError),
/// Setting tap offload failed.
#[error("failed to set tap offload: {0}")]
TapSetOffload(TapError),
/// Setting vnet header size failed.
#[error("failed to set vnet header size: {0}")]
TapSetVnetHdrSize(TapError),
@ -216,6 +218,56 @@ pub struct VirtioNetConfig {
mtu: Le16,
}
fn process_ctrl_request<T: TapT>(
reader: &mut Reader,
tap: &mut T,
acked_features: u64,
vq_pairs: u16,
) -> Result<(), NetError> {
let ctrl_hdr: virtio_net_ctrl_hdr = reader.read_obj().map_err(NetError::ReadCtrlHeader)?;
match ctrl_hdr.class as c_uint {
VIRTIO_NET_CTRL_GUEST_OFFLOADS => {
if ctrl_hdr.cmd != VIRTIO_NET_CTRL_GUEST_OFFLOADS_SET as u8 {
error!(
"invalid cmd for VIRTIO_NET_CTRL_GUEST_OFFLOADS: {}",
ctrl_hdr.cmd
);
return Err(NetError::InvalidCmd);
}
let offloads: Le64 = reader.read_obj().map_err(NetError::ReadCtrlData)?;
let tap_offloads = virtio_features_to_tap_offload(offloads.into());
tap.set_offload(tap_offloads)
.map_err(NetError::TapSetOffload)?;
}
VIRTIO_NET_CTRL_MQ => {
if ctrl_hdr.cmd == VIRTIO_NET_CTRL_MQ_VQ_PAIRS_SET as u8 {
let pairs: Le16 = reader.read_obj().map_err(NetError::ReadCtrlData)?;
// Simply handle it now
if acked_features & 1 << virtio_net::VIRTIO_NET_F_MQ == 0
|| pairs.to_native() != vq_pairs
{
error!(
"Invalid VQ_PAIRS_SET cmd, driver request pairs: {}, device vq pairs: {}",
pairs.to_native(),
vq_pairs
);
return Err(NetError::InvalidCmd);
}
}
}
_ => {
warn!(
"unimplemented class for VIRTIO_NET_CTRL_GUEST_OFFLOADS: {}",
ctrl_hdr.class
);
return Err(NetError::InvalidCmd);
}
}
Ok(())
}
pub fn process_ctrl<I: SignalableInterrupt, T: TapT>(
interrupt: &I,
ctrl_queue: &mut Queue,
@ -225,65 +277,19 @@ pub fn process_ctrl<I: SignalableInterrupt, T: TapT>(
vq_pairs: u16,
) -> Result<(), NetError> {
while let Some(desc_chain) = ctrl_queue.pop(mem) {
let index = desc_chain.index;
let mut reader =
Reader::new(mem.clone(), desc_chain.clone()).map_err(NetError::DescriptorChain)?;
let mut writer = Writer::new(mem.clone(), desc_chain).map_err(NetError::DescriptorChain)?;
let ctrl_hdr: virtio_net_ctrl_hdr = reader.read_obj().map_err(NetError::ReadCtrlHeader)?;
let mut write_error = || {
let mut reader = Reader::new(&desc_chain);
let mut writer = Writer::new(&desc_chain);
if let Err(e) = process_ctrl_request(&mut reader, tap, acked_features, vq_pairs) {
error!("process_ctrl_request failed: {}", e);
writer
.write_all(&[VIRTIO_NET_ERR as u8])
.map_err(NetError::WriteAck)?;
ctrl_queue.add_used(mem, index, writer.bytes_written() as u32);
Ok(())
};
match ctrl_hdr.class as c_uint {
VIRTIO_NET_CTRL_GUEST_OFFLOADS => {
if ctrl_hdr.cmd != VIRTIO_NET_CTRL_GUEST_OFFLOADS_SET as u8 {
error!(
"invalid cmd for VIRTIO_NET_CTRL_GUEST_OFFLOADS: {}",
ctrl_hdr.cmd
);
write_error()?;
continue;
}
let offloads: Le64 = reader.read_obj().map_err(NetError::ReadCtrlData)?;
let tap_offloads = virtio_features_to_tap_offload(offloads.into());
if let Err(e) = tap.set_offload(tap_offloads) {
error!("Failed to set tap itnerface offload flags: {}", e);
write_error()?;
continue;
}
let ack = VIRTIO_NET_OK as u8;
writer.write_all(&[ack]).map_err(NetError::WriteAck)?;
}
VIRTIO_NET_CTRL_MQ => {
if ctrl_hdr.cmd == VIRTIO_NET_CTRL_MQ_VQ_PAIRS_SET as u8 {
let pairs: Le16 = reader.read_obj().map_err(NetError::ReadCtrlData)?;
// Simply handle it for now
if acked_features & 1 << virtio_net::VIRTIO_NET_F_MQ == 0
|| pairs.to_native() != vq_pairs
{
error!("Invalid VQ_PAIRS_SET cmd, driver request pairs: {}, device vq pairs: {}",
pairs.to_native(), vq_pairs);
write_error()?;
continue;
}
let ack = VIRTIO_NET_OK as u8;
writer.write_all(&[ack]).map_err(NetError::WriteAck)?;
}
}
_ => warn!(
"unimplemented class for VIRTIO_NET_CTRL_GUEST_OFFLOADS: {}",
ctrl_hdr.class
),
} else {
writer
.write_all(&[VIRTIO_NET_OK as u8])
.map_err(NetError::WriteAck)?;
}
ctrl_queue.add_used(mem, index, writer.bytes_written() as u32);
ctrl_queue.add_used(mem, desc_chain, writer.bytes_written() as u32);
}
ctrl_queue.trigger_interrupt(mem, interrupt);
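// Editor's sketch of the calling convention this diff converges on, using the
// crosvm types shown above (illustrative only, not a line from the tree):
while let Some(desc_chain) = queue.pop(mem) {
    let mut reader = Reader::new(&desc_chain); // no Result: the chain was validated by pop()
    let mut writer = Writer::new(&desc_chain);
    // ... device-specific request handling via reader/writer ...
    let len = writer.bytes_written() as u32;
    queue.add_used(mem, desc_chain, len); // the chain itself is consumed, not a bare index
}
queue.trigger_interrupt(mem, interrupt);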

View file

@ -22,7 +22,6 @@ use thiserror::Error;
use vm_memory::GuestMemory;
use super::copy_config;
use super::DescriptorError;
use super::DeviceType;
use super::Interrupt;
use super::Queue;
@ -51,9 +50,6 @@ pub enum P9Error {
/// An internal I/O error occurred.
#[error("P9 internal server error: {0}")]
Internal(io::Error),
/// A DescriptorChain contains invalid data.
#[error("DescriptorChain contains invalid data: {0}")]
InvalidDescriptorChain(DescriptorError),
/// A request is missing readable descriptors.
#[error("request does not have any readable descriptors")]
NoReadableDescriptors,
@ -87,17 +83,15 @@ struct Worker {
impl Worker {
fn process_queue(&mut self) -> P9Result<()> {
while let Some(avail_desc) = self.queue.pop(&self.mem) {
let mut reader = Reader::new(self.mem.clone(), avail_desc.clone())
.map_err(P9Error::InvalidDescriptorChain)?;
let mut writer = Writer::new(self.mem.clone(), avail_desc.clone())
.map_err(P9Error::InvalidDescriptorChain)?;
let mut reader = Reader::new(&avail_desc);
let mut writer = Writer::new(&avail_desc);
self.server
.handle_message(&mut reader, &mut writer)
.map_err(P9Error::Internal)?;
self.queue
.add_used(&self.mem, avail_desc.index, writer.bytes_written() as u32);
.add_used(&self.mem, avail_desc, writer.bytes_written() as u32);
}
self.queue.trigger_interrupt(&self.mem, &self.interrupt);

View file

@ -34,7 +34,6 @@ use zerocopy::FromBytes;
use super::async_utils;
use super::copy_config;
use super::DescriptorChain;
use super::DescriptorError;
use super::DeviceType;
use super::Interrupt;
use super::Queue;
@ -72,9 +71,6 @@ struct virtio_pmem_req {
#[sorted]
#[derive(Error, Debug)]
enum Error {
/// Invalid virtio descriptor chain.
#[error("virtio descriptor error: {0}")]
Descriptor(DescriptorError),
/// Failed to read from virtqueue.
#[error("failed to read from virtqueue: {0}")]
ReadQueue(io::Error),
@ -126,14 +122,13 @@ fn execute_request(
}
fn handle_request(
mem: &GuestMemory,
avail_desc: DescriptorChain,
avail_desc: &DescriptorChain,
pmem_device_tube: &Tube,
mapping_arena_slot: u32,
mapping_size: usize,
) -> Result<usize> {
let mut reader = Reader::new(mem.clone(), avail_desc.clone()).map_err(Error::Descriptor)?;
let mut writer = Writer::new(mem.clone(), avail_desc).map_err(Error::Descriptor)?;
let mut reader = Reader::new(avail_desc);
let mut writer = Writer::new(avail_desc);
let status_code = reader
.read_obj()
@ -166,10 +161,9 @@ async fn handle_queue(
}
Ok(d) => d,
};
let index = avail_desc.index;
let written = match handle_request(
mem,
avail_desc,
&avail_desc,
&pmem_device_tube,
mapping_arena_slot,
mapping_size,
@ -180,7 +174,7 @@ async fn handle_queue(
0
}
};
queue.add_used(mem, index, written as u32);
queue.add_used(mem, avail_desc, written as u32);
queue.trigger_interrupt(mem, &interrupt);
}
}

View file

@ -10,6 +10,7 @@
//! For more information about this device, please visit <go/virtio-pvclock>.
use std::arch::x86_64::_rdtsc;
use std::mem::size_of;
use std::sync::atomic::AtomicU64;
use std::sync::atomic::Ordering;
use std::sync::Arc;
@ -455,16 +456,34 @@ fn run_worker(
match event.token {
Token::SetPvClockPageQueue => {
let _ = set_pvclock_page_queue_evt.wait();
let desc = match set_pvclock_page_queue.pop(&worker.mem) {
Some(desc) => desc,
let desc_chain = match set_pvclock_page_queue.pop(&worker.mem) {
Some(desc_chain) => desc_chain,
None => {
error!("set_pvclock_page queue was empty");
continue;
}
};
let mut req: virtio_pvclock_set_pvclock_page_req =
match worker.mem.read_obj_from_addr(desc.addr) {
// This device does not follow the virtio spec requirements for device-readable
// vs. device-writable descriptors, so we can't use `Reader`/`Writer`. Pick the
// first descriptor from the chain and assume the whole req structure is
// contained within it.
let desc = desc_chain
.readable_mem_regions()
.iter()
.chain(desc_chain.writable_mem_regions().iter())
.next()
.unwrap();
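// Editor's note: readable_mem_regions()/writable_mem_regions() are the escape
// hatch for devices like this one that cannot use Reader/Writer: they return
// the chain's already-validated regions, whose offset/len fields are used
// below to address guest memory directly.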
let len = if desc.len < size_of::<virtio_pvclock_set_pvclock_page_req>() {
error!("pvclock descriptor too short");
0
} else {
let addr = GuestAddress(desc.offset);
let mut req: virtio_pvclock_set_pvclock_page_req = match worker
.mem
.read_obj_from_addr(addr)
{
Ok(req) => req,
Err(e) => {
error!("failed to read request from set_pvclock_page queue: {}", e);
@ -472,20 +491,23 @@ fn run_worker(
}
};
req.status = match worker.set_pvclock_page(req.pvclock_page_pa.into()) {
Err(e) => {
error!("failed to set pvclock page: {:#}", e);
VIRTIO_PVCLOCK_S_IOERR
req.status = match worker.set_pvclock_page(req.pvclock_page_pa.into()) {
Err(e) => {
error!("failed to set pvclock page: {:#}", e);
VIRTIO_PVCLOCK_S_IOERR
}
Ok(_) => VIRTIO_PVCLOCK_S_OK,
};
if let Err(e) = worker.mem.write_obj_at_addr(req, addr) {
error!("failed to write set_pvclock_page status: {}", e);
continue;
}
Ok(_) => VIRTIO_PVCLOCK_S_OK,
desc.len as u32
};
if let Err(e) = worker.mem.write_obj_at_addr(req, desc.addr) {
error!("failed to write set_pvclock_page status: {}", e);
continue;
}
set_pvclock_page_queue.add_used(&worker.mem, desc.index, desc.len);
set_pvclock_page_queue.add_used(&worker.mem, desc_chain, len);
set_pvclock_page_queue.trigger_interrupt(&worker.mem, &interrupt);
}
Token::SuspendResume => {

View file

@ -12,276 +12,27 @@ use anyhow::Context;
use anyhow::Result;
use base::error;
use base::warn;
use base::Protection;
use cros_async::AsyncError;
use cros_async::EventAsync;
use data_model::Le16;
use data_model::Le32;
use data_model::Le64;
use smallvec::smallvec;
use smallvec::SmallVec;
use sync::Mutex;
use virtio_sys::virtio_ring::VIRTIO_RING_F_EVENT_IDX;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use zerocopy::AsBytes;
use zerocopy::FromBytes;
use super::SignalableInterrupt;
use super::VIRTIO_MSI_NO_VECTOR;
use crate::virtio::ipc_memory_mapper::ExportedRegion;
use crate::virtio::ipc_memory_mapper::IpcMemoryMapper;
use crate::virtio::memory_mapper::MemRegion;
use crate::virtio::memory_util::read_obj_from_addr_wrapper;
use crate::virtio::memory_util::write_obj_at_addr_wrapper;
const VIRTQ_DESC_F_NEXT: u16 = 0x1;
const VIRTQ_DESC_F_WRITE: u16 = 0x2;
#[allow(dead_code)]
const VIRTQ_DESC_F_INDIRECT: u16 = 0x4;
use crate::virtio::DescriptorChain;
use crate::virtio::SplitDescriptorChain;
#[allow(dead_code)]
const VIRTQ_USED_F_NO_NOTIFY: u16 = 0x1;
#[allow(dead_code)]
const VIRTQ_AVAIL_F_NO_INTERRUPT: u16 = 0x1;
/// An iterator over a single descriptor chain. Not to be confused with AvailIter,
/// which iterates over the descriptor chain heads in a queue.
pub struct DescIter {
next: Option<DescriptorChain>,
}
impl DescIter {
/// Returns an iterator that only yields the readable descriptors in the chain.
pub fn readable(self) -> impl Iterator<Item = DescriptorChain> {
self.take_while(DescriptorChain::is_read_only)
}
/// Returns an iterator that only yields the writable descriptors in the chain.
pub fn writable(self) -> impl Iterator<Item = DescriptorChain> {
self.skip_while(DescriptorChain::is_read_only)
}
}
impl Iterator for DescIter {
type Item = DescriptorChain;
fn next(&mut self) -> Option<Self::Item> {
if let Some(current) = self.next.take() {
self.next = current.next_descriptor();
Some(current)
} else {
None
}
}
}
/// A virtio descriptor chain.
#[derive(Clone)]
pub struct DescriptorChain {
mem: GuestMemory,
desc_table: GuestAddress,
queue_size: u16,
ttl: u16, // used to prevent infinite chain cycles
/// Index into the descriptor table
pub index: u16,
/// Guest physical address of device specific data, or IO virtual address
/// if iommu is used
pub addr: GuestAddress,
/// Length of device specific data
pub len: u32,
/// Includes next, write, and indirect bits
pub flags: u16,
/// Index into the descriptor table of the next descriptor if flags has
/// the next bit set
pub next: u16,
/// The memory regions associated with the current descriptor.
regions: SmallVec<[MemRegion; 1]>,
/// Translates `addr` to guest physical address
iommu: Option<Arc<Mutex<IpcMemoryMapper>>>,
/// The exported descriptor table of this chain's queue. Present
/// iff iommu is present.
exported_desc_table: Option<ExportedRegion>,
/// The exported iommu region of the current descriptor. Present iff
/// iommu is present.
exported_region: Option<ExportedRegion>,
}
#[derive(Copy, Clone, Debug, FromBytes, AsBytes)]
#[repr(C)]
pub struct Desc {
pub addr: Le64,
pub len: Le32,
pub flags: Le16,
pub next: Le16,
}
impl DescriptorChain {
pub(crate) fn checked_new(
mem: &GuestMemory,
desc_table: GuestAddress,
queue_size: u16,
index: u16,
required_flags: u16,
iommu: Option<Arc<Mutex<IpcMemoryMapper>>>,
exported_desc_table: Option<ExportedRegion>,
) -> Result<DescriptorChain> {
if index >= queue_size {
bail!("index ({}) >= queue_size ({})", index, queue_size);
}
let desc_head = desc_table
.checked_add((index as u64) * 16)
.context("integer overflow")?;
let desc: Desc =
read_obj_from_addr_wrapper(mem, exported_desc_table.as_ref(), desc_head)
.with_context(|| format!("failed to read desc {:x}", desc_head.offset()))?;
let addr = GuestAddress(desc.addr.into());
let len = desc.len.to_native();
let (regions, exported_region) = if let Some(iommu) = &iommu {
if exported_desc_table.is_none() {
bail!("missing exported descriptor table");
}
let exported_region =
ExportedRegion::new(mem, iommu.clone(), addr.offset(), len.into())
.context("failed to get mem regions")?;
let regions = exported_region.get_mem_regions();
let required_prot = if required_flags & VIRTQ_DESC_F_WRITE == 0 {
Protection::read()
} else {
Protection::write()
};
for r in &regions {
if !r.prot.allows(&required_prot) {
bail!("missing RW permissions for descriptor");
}
}
(regions, Some(exported_region))
} else {
(
smallvec![MemRegion {
gpa: addr,
len: len.into(),
prot: Protection::read_write(),
}],
None,
)
};
let chain = DescriptorChain {
mem: mem.clone(),
desc_table,
queue_size,
ttl: queue_size,
index,
addr,
len,
flags: desc.flags.into(),
next: desc.next.into(),
iommu,
regions,
exported_region,
exported_desc_table,
};
if chain.is_valid() && chain.flags & required_flags == required_flags {
Ok(chain)
} else {
bail!("chain is invalid")
}
}
pub fn into_mem_regions(self) -> (SmallVec<[MemRegion; 1]>, Option<ExportedRegion>) {
(self.regions, self.exported_region)
}
fn is_valid(&self) -> bool {
if self.len > 0 {
// Each region in `self.regions` must be a contiguous range in `self.mem`.
if !self
.regions
.iter()
.all(|r| self.mem.is_valid_range(r.gpa, r.len))
{
return false;
}
}
!self.has_next() || self.next < self.queue_size
}
/// Returns true if this descriptor chain has another descriptor linked after it.
pub fn has_next(&self) -> bool {
self.flags & VIRTQ_DESC_F_NEXT != 0 && self.ttl > 1
}
/// If the driver designated this as a write only descriptor.
///
/// If this is false, this descriptor is read only.
/// Write only means that the emulated device can write and the driver can read.
pub fn is_write_only(&self) -> bool {
self.flags & VIRTQ_DESC_F_WRITE != 0
}
/// If the driver designated this as a read only descriptor.
///
/// If this is false, this descriptor is write only.
/// Read only means the emulated device can read and the driver can write.
pub fn is_read_only(&self) -> bool {
self.flags & VIRTQ_DESC_F_WRITE == 0
}
/// Gets the next descriptor in this descriptor chain, if there is one.
///
/// Note that this is distinct from the next descriptor chain returned by `AvailIter`, which is
/// the head of the next _available_ descriptor chain.
pub fn next_descriptor(&self) -> Option<DescriptorChain> {
if self.has_next() {
// Once we see a write-only descriptor, all subsequent descriptors must be write-only.
let required_flags = self.flags & VIRTQ_DESC_F_WRITE;
let iommu = self.iommu.as_ref().map(Arc::clone);
match DescriptorChain::checked_new(
&self.mem,
self.desc_table,
self.queue_size,
self.next,
required_flags,
iommu,
self.exported_desc_table.clone(),
) {
Ok(mut c) => {
c.ttl = self.ttl - 1;
Some(c)
}
Err(e) => {
error!("{:#}", e);
None
}
}
} else {
None
}
}
/// Produces an iterator over all the descriptors in this chain.
pub fn into_iter(self) -> DescIter {
DescIter { next: Some(self) }
}
}
/// Consuming iterator over all available descriptor chain heads in the queue.
pub struct AvailIter<'a, 'b> {
mem: &'a GuestMemory,
@ -674,20 +425,19 @@ impl Queue {
.unwrap();
let iommu = self.iommu.as_ref().map(Arc::clone);
DescriptorChain::checked_new(
let chain = SplitDescriptorChain::new(
mem,
self.desc_table,
self.size,
descriptor_index,
0,
iommu,
self.exported_desc_table.clone(),
)
.map_err(|e| {
error!("{:#}", e);
e
})
.ok()
self.exported_desc_table.as_ref(),
);
DescriptorChain::new(chain, mem, descriptor_index, iommu)
.map_err(|e| {
error!("{:#}", e);
e
})
.ok()
}
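// Editor's note on the new layering (names as in this diff):
// SplitDescriptorChain::new() builds the split-ring-specific iterator over raw
// descriptors, and DescriptorChain::new() drains it on the spot, validating
// the whole chain and applying any IOMMU translation before the chain is ever
// handed to a device.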
/// Remove the first available descriptor chain from the queue.
@ -730,7 +480,8 @@ impl Queue {
}
/// Puts an available descriptor head into the used ring for use by the guest.
pub fn add_used(&mut self, mem: &GuestMemory, desc_index: u16, len: u32) {
pub fn add_used(&mut self, mem: &GuestMemory, desc_chain: DescriptorChain, len: u32) {
let desc_index = desc_chain.index();
if desc_index >= self.size {
error!(
"attempted to add out of bounds descriptor to used ring: {}",
@ -924,10 +675,17 @@ impl Queue {
mod tests {
use std::convert::TryInto;
use data_model::Le16;
use data_model::Le32;
use data_model::Le64;
use memoffset::offset_of;
use zerocopy::AsBytes;
use zerocopy::FromBytes;
use super::super::Interrupt;
use super::*;
use crate::virtio::create_descriptor_chain;
use crate::virtio::Desc;
use crate::IrqLevelEvent;
const GUEST_MEMORY_SIZE: u64 = 0x10000;
@ -1015,6 +773,11 @@ mod tests {
queue.ack_features((1u64) << VIRTIO_RING_F_EVENT_IDX);
}
fn fake_desc_chain(mem: &GuestMemory) -> DescriptorChain {
create_descriptor_chain(mem, GuestAddress(0), GuestAddress(0), Vec::new(), 0)
.expect("failed to create descriptor chain")
}
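// Editor's note: add_used() now takes a DescriptorChain by value, so these
// tests build a trivial chain through create_descriptor_chain() instead of
// passing the bare index 0x0 as before.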
#[test]
fn queue_event_id_guest_fast() {
let mut queue = Queue::new(QUEUE_SIZE.try_into().unwrap());
@ -1032,7 +795,7 @@ mod tests {
// device has handled them, so increase self.next_used to 0x100
let mut device_generate: Wrapping<u16> = Wrapping(0x100);
for _ in 0..device_generate.0 {
queue.add_used(&mem, 0x0, BUFFER_LEN);
queue.add_used(&mem, fake_desc_chain(&mem), BUFFER_LEN);
}
// At this moment driver hasn't handled any interrupts yet, so it
@ -1050,7 +813,7 @@ mod tests {
// Assume the driver submits another u16::MAX - 0x100 requests to the device,
// Device has handled all of them, so increase self.next_used to u16::MAX
for _ in device_generate.0..u16::max_value() {
queue.add_used(&mem, 0x0, BUFFER_LEN);
queue.add_used(&mem, fake_desc_chain(&mem), BUFFER_LEN);
}
device_generate = Wrapping(u16::max_value());
@ -1068,7 +831,7 @@ mod tests {
// Assume the driver submits one more request,
// device has handled it, so wrap self.next_used to 0
queue.add_used(&mem, 0x0, BUFFER_LEN);
queue.add_used(&mem, fake_desc_chain(&mem), BUFFER_LEN);
device_generate += Wrapping(1);
// At this moment driver has handled all the previous interrupts, so it
@ -1101,7 +864,7 @@ mod tests {
// the device has handled 0x100 requests, so increase self.next_used to 0x100
let mut device_generate: Wrapping<u16> = Wrapping(0x100);
for _ in 0..device_generate.0 {
queue.add_used(&mem, 0x0, BUFFER_LEN);
queue.add_used(&mem, fake_desc_chain(&mem), BUFFER_LEN);
}
// At this moment driver hasn't handled any interrupts yet, so it
@ -1118,7 +881,7 @@ mod tests {
// Assume the driver submits one more request,
// device has handled it, so increment self.next_used.
queue.add_used(&mem, 0x0, BUFFER_LEN);
queue.add_used(&mem, fake_desc_chain(&mem), BUFFER_LEN);
device_generate += Wrapping(1);
// At this moment driver hasn't finished last interrupt yet,
@ -1128,7 +891,7 @@ mod tests {
// Assume the driver submits another u16::MAX - 0x101 requests to the device,
// Device has handled all of them, so increase self.next_used to u16::MAX
for _ in device_generate.0..u16::max_value() {
queue.add_used(&mem, 0x0, BUFFER_LEN);
queue.add_used(&mem, fake_desc_chain(&mem), BUFFER_LEN);
}
device_generate = Wrapping(u16::max_value());
@ -1142,7 +905,7 @@ mod tests {
// Assume the driver submits one more request,
// device has handled it, so wrap self.next_used to 0
queue.add_used(&mem, 0x0, BUFFER_LEN);
queue.add_used(&mem, fake_desc_chain(&mem), BUFFER_LEN);
device_generate += Wrapping(1);
// At this moment driver has already finished the last interrupt(0x100),
@ -1151,7 +914,7 @@ mod tests {
// Assume the driver submits one more request,
// device has handled it, so increment self.next_used to 1
queue.add_used(&mem, 0x0, BUFFER_LEN);
queue.add_used(&mem, fake_desc_chain(&mem), BUFFER_LEN);
device_generate += Wrapping(1);
// At this moment the driver hasn't finished the last interrupt (Wrapping(0)) yet,
@ -1168,7 +931,7 @@ mod tests {
// Assume the driver submits one more request,
// device has handled it, so increase self.next_used.
queue.add_used(&mem, 0x0, BUFFER_LEN);
queue.add_used(&mem, fake_desc_chain(&mem), BUFFER_LEN);
device_generate += Wrapping(1);
// At this moment driver has finished all the previous interrupts, so it

View file

@ -2,7 +2,6 @@
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.
use std::io;
use std::io::Write;
use anyhow::anyhow;
@ -50,31 +49,21 @@ impl Worker {
let mut needs_interrupt = false;
while let Some(avail_desc) = queue.pop(&self.mem) {
let index = avail_desc.index;
let mut writer = Writer::new(&avail_desc);
let avail_bytes = writer.available_bytes();
let writer_or_err = Writer::new(self.mem.clone(), avail_desc)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidInput, e));
let written_size = match writer_or_err {
Ok(mut writer) => {
let avail_bytes = writer.available_bytes();
let mut rand_bytes = vec![0u8; avail_bytes];
OsRng.fill_bytes(&mut rand_bytes);
let mut rand_bytes = vec![0u8; avail_bytes];
OsRng.fill_bytes(&mut rand_bytes);
match writer.write_all(&rand_bytes) {
Ok(_) => rand_bytes.len(),
Err(e) => {
warn!("Failed to write random data to the guest: {}", e);
0usize
}
}
}
let written_size = match writer.write_all(&rand_bytes) {
Ok(_) => rand_bytes.len(),
Err(e) => {
warn!("Failed to write random data to the guest: {}", e);
0usize
}
};
queue.add_used(&self.mem, index, written_size as u32);
queue.add_used(&self.mem, avail_desc, written_size as u32);
needs_interrupt = true;
}

View file

@ -80,7 +80,6 @@ pub trait PlaybackBufferWriter {
#[cfg(windows)]
async fn check_and_prefill(
&mut self,
mem: &GuestMemory,
desc_receiver: &mut mpsc::UnboundedReceiver<DescriptorChain>,
sender: &mut mpsc::UnboundedSender<PcmResponse>,
) -> Result<(), Error>;
@ -155,7 +154,6 @@ impl VirtioSndPcmCmd {
// a runtime/internal error
async fn process_pcm_ctrl(
ex: &Executor,
mem: &GuestMemory,
tx_send: &mpsc::UnboundedSender<PcmResponse>,
rx_send: &mpsc::UnboundedSender<PcmResponse>,
streams: &Rc<AsyncMutex<Vec<AsyncMutex<StreamInfo>>>>,
@ -190,7 +188,7 @@ async fn process_pcm_ctrl(
}
result
}
VirtioSndPcmCmd::Prepare => stream.prepare(ex, mem.clone(), tx_send, rx_send).await,
VirtioSndPcmCmd::Prepare => stream.prepare(ex, tx_send, rx_send).await,
VirtioSndPcmCmd::Start => stream.start().await,
VirtioSndPcmCmd::Stop => stream.stop().await,
VirtioSndPcmCmd::Release => stream.release().await,
@ -285,15 +283,13 @@ impl From<Result<u32, Error>> for virtio_snd_pcm_status {
// Drain all DescriptorChains from desc_receiver while processing WorkerStatus::Quit.
async fn drain_desc_receiver(
desc_receiver: &mut mpsc::UnboundedReceiver<DescriptorChain>,
mem: &GuestMemory,
sender: &mut mpsc::UnboundedSender<PcmResponse>,
) -> Result<(), Error> {
let mut o_desc_chain = desc_receiver.next().await;
while let Some(desc_chain) = o_desc_chain {
// From the virtio-snd spec:
// The device MUST complete all pending I/O messages for the specified stream ID.
let desc_index = desc_chain.index;
let writer = Writer::new(mem.clone(), desc_chain).map_err(Error::DescriptorChain)?;
let writer = Writer::new(&desc_chain);
let status = virtio_snd_pcm_status::new(StatusCode::OK, 0);
// Fetch next DescriptorChain to see if the current one is the last one.
o_desc_chain = desc_receiver.next().await;
@ -305,7 +301,7 @@ async fn drain_desc_receiver(
};
sender
.send(PcmResponse {
desc_index,
desc_chain,
status,
writer,
done,
@ -323,17 +319,12 @@ async fn drain_desc_receiver(
Ok(())
}
pub(crate) fn get_index_with_reader_and_writer(
mem: &GuestMemory,
desc_chain: DescriptorChain,
) -> Result<(u16, Reader, Writer), Error> {
let desc_index = desc_chain.index;
let mut reader =
Reader::new(mem.clone(), desc_chain.clone()).map_err(Error::DescriptorChain)?;
pub(crate) fn get_reader_and_writer(desc_chain: &DescriptorChain) -> (Reader, Writer) {
let mut reader = Reader::new(desc_chain);
// stream_id was already read in handle_pcm_queue
reader.consume(std::mem::size_of::<virtio_snd_pcm_xfer>());
let writer = Writer::new(mem.clone(), desc_chain).map_err(Error::DescriptorChain)?;
Ok((desc_index, reader, writer))
let writer = Writer::new(desc_chain);
(reader, writer)
}
/// Start a pcm worker that receives descriptors containing PCM frames (audio data) from the tx/rx
@ -346,18 +337,9 @@ pub async fn start_pcm_worker(
dstream: DirectionalStream,
mut desc_receiver: mpsc::UnboundedReceiver<DescriptorChain>,
status_mutex: Rc<AsyncMutex<WorkerStatus>>,
mem: GuestMemory,
mut sender: mpsc::UnboundedSender<PcmResponse>,
) -> Result<(), Error> {
let res = pcm_worker_loop(
ex,
dstream,
&mut desc_receiver,
&status_mutex,
&mem,
&mut sender,
)
.await;
let res = pcm_worker_loop(ex, dstream, &mut desc_receiver, &status_mutex, &mut sender).await;
*status_mutex.lock().await = WorkerStatus::Quit;
if res.is_err() {
error!(
@ -366,7 +348,7 @@ pub async fn start_pcm_worker(
);
// On error, desc_receiver is guaranteed not to have been drained, so drain it here.
// Note that draining blocks until the stream is released.
drain_desc_receiver(&mut desc_receiver, &mem, &mut sender).await?;
drain_desc_receiver(&mut desc_receiver, &mut sender).await?;
}
res
}
@ -376,7 +358,6 @@ async fn pcm_worker_loop(
dstream: DirectionalStream,
desc_receiver: &mut mpsc::UnboundedReceiver<DescriptorChain>,
status_mutex: &Rc<AsyncMutex<WorkerStatus>>,
mem: &GuestMemory,
sender: &mut mpsc::UnboundedSender<PcmResponse>,
) -> Result<(), Error> {
match dstream {
@ -389,7 +370,7 @@ async fn pcm_worker_loop(
let worker_status = status_mutex.lock().await;
match *worker_status {
WorkerStatus::Quit => {
drain_desc_receiver(desc_receiver, mem, sender).await?;
drain_desc_receiver(desc_receiver, sender).await?;
if let Err(e) = write_data(dst_buf, None, &mut buffer_writer).await {
error!("Error on write_data after worker quit: {}", e)
}
@ -403,7 +384,7 @@ async fn pcm_worker_loop(
// accept arbitrarily sized buffers
#[cfg(windows)]
buffer_writer
.check_and_prefill(mem, desc_receiver, sender)
.check_and_prefill(desc_receiver, sender)
.await?;
match desc_receiver.try_next() {
@ -417,11 +398,10 @@ async fn pcm_worker_loop(
return Err(Error::InvalidPCMWorkerState);
}
Ok(Some(desc_chain)) => {
let (desc_index, reader, writer) =
get_index_with_reader_and_writer(mem, desc_chain)?;
let (reader, writer) = get_reader_and_writer(&desc_chain);
sender
.send(PcmResponse {
desc_index,
desc_chain,
status: write_data(dst_buf, Some(reader), &mut buffer_writer)
.await
.into(),
@ -444,7 +424,7 @@ async fn pcm_worker_loop(
let worker_status = status_mutex.lock().await;
match *worker_status {
WorkerStatus::Quit => {
drain_desc_receiver(desc_receiver, mem, sender).await?;
drain_desc_receiver(desc_receiver, sender).await?;
if let Err(e) = read_data(src_buf, None, period_bytes).await {
error!("Error on read_data after worker quit: {}", e)
}
@ -464,12 +444,11 @@ async fn pcm_worker_loop(
return Err(Error::InvalidPCMWorkerState);
}
Ok(Some(desc_chain)) => {
let (desc_index, _reader, mut writer) =
get_index_with_reader_and_writer(mem, desc_chain)?;
let (_reader, mut writer) = get_reader_and_writer(&desc_chain);
sender
.send(PcmResponse {
desc_index,
desc_chain,
status: read_data(src_buf, Some(&mut writer), period_bytes)
.await
.into(),
@ -488,15 +467,13 @@ async fn pcm_worker_loop(
// Defer pcm message response to the pcm response worker
async fn defer_pcm_response_to_worker(
desc_chain: DescriptorChain,
mem: &GuestMemory,
status: virtio_snd_pcm_status,
response_sender: &mut mpsc::UnboundedSender<PcmResponse>,
) -> Result<(), Error> {
let desc_index = desc_chain.index;
let writer = Writer::new(mem.clone(), desc_chain).map_err(Error::DescriptorChain)?;
let writer = Writer::new(&desc_chain);
response_sender
.send(PcmResponse {
desc_index,
desc_chain,
status,
writer,
done: None,
@ -507,7 +484,7 @@ async fn defer_pcm_response_to_worker(
fn send_pcm_response_with_writer<I: SignalableInterrupt>(
mut writer: Writer,
desc_index: u16,
desc_chain: DescriptorChain,
mem: &GuestMemory,
queue: &mut Queue,
interrupt: &I,
@ -519,7 +496,7 @@ fn send_pcm_response_with_writer<I: SignalableInterrupt>(
.consume_bytes(writer.available_bytes() - std::mem::size_of::<virtio_snd_pcm_status>());
}
writer.write_obj(status).map_err(Error::WriteResponse)?;
queue.add_used(mem, desc_index, writer.bytes_written() as u32);
queue.add_used(mem, desc_chain, writer.bytes_written() as u32);
queue.trigger_interrupt(mem, interrupt);
Ok(())
}
@ -559,7 +536,7 @@ pub async fn send_pcm_response_worker<I: SignalableInterrupt>(
if let Some(r) = res {
send_pcm_response_with_writer(
r.writer,
r.desc_index,
r.desc_chain,
mem,
&mut *queue.lock().await,
&interrupt,
@ -610,8 +587,7 @@ pub async fn handle_pcm_queue(
res = next_async => res.map_err(Error::Async)?,
};
let mut reader =
Reader::new(mem.clone(), desc_chain.clone()).map_err(Error::DescriptorChain)?;
let mut reader = Reader::new(&desc_chain);
let pcm_xfer: virtio_snd_pcm_xfer = reader.read_obj().map_err(Error::ReadMessage)?;
let stream_id: usize = u32::from(pcm_xfer.stream_id) as usize;
@ -627,7 +603,6 @@ pub async fn handle_pcm_queue(
);
defer_pcm_response_to_worker(
desc_chain,
mem,
virtio_snd_pcm_status::new(StatusCode::IoErr, 0),
&mut response_sender,
)
@ -655,7 +630,6 @@ pub async fn handle_pcm_queue(
}
defer_pcm_response_to_worker(
desc_chain,
mem,
virtio_snd_pcm_status::new(StatusCode::IoErr, 0),
&mut response_sender,
)
@ -693,11 +667,8 @@ pub async fn handle_ctrl_queue<I: SignalableInterrupt>(
}
};
let index = desc_chain.index;
let mut reader =
Reader::new(mem.clone(), desc_chain.clone()).map_err(Error::DescriptorChain)?;
let mut writer = Writer::new(mem.clone(), desc_chain).map_err(Error::DescriptorChain)?;
let mut reader = Reader::new(&desc_chain);
let mut writer = Writer::new(&desc_chain);
// Don't advance the reader
let code = reader
.clone()
@ -870,7 +841,6 @@ pub async fn handle_ctrl_queue<I: SignalableInterrupt>(
process_pcm_ctrl(
ex,
&mem.clone(),
&tx_send,
&rx_send,
streams,
@ -895,18 +865,9 @@ pub async fn handle_ctrl_queue<I: SignalableInterrupt>(
.map_err(Error::WriteResponse);
}
};
process_pcm_ctrl(
ex,
&mem.clone(),
&tx_send,
&rx_send,
streams,
cmd,
&mut writer,
stream_id,
)
.await
.and(Ok(()))?;
process_pcm_ctrl(ex, &tx_send, &rx_send, streams, cmd, &mut writer, stream_id)
.await
.and(Ok(()))?;
Ok(())
}
c => {
@ -919,7 +880,7 @@ pub async fn handle_ctrl_queue<I: SignalableInterrupt>(
};
handle_ctrl_msg.await?;
queue.add_used(mem, index, writer.bytes_written() as u32);
queue.add_used(mem, desc_chain, writer.bytes_written() as u32);
queue.trigger_interrupt(mem, &interrupt);
}
Ok(())
@ -939,8 +900,7 @@ pub async fn handle_event_queue<I: SignalableInterrupt>(
.map_err(Error::Async)?;
// TODO(woodychow): Poll and forward events from cras asynchronously (API to be added)
let index = desc_chain.index;
queue.add_used(mem, index, 0);
queue.add_used(mem, desc_chain, 0);
queue.trigger_interrupt(mem, &interrupt);
}
}

View file

@ -57,7 +57,7 @@ use crate::virtio::snd::sys::set_audio_thread_priority;
use crate::virtio::snd::sys::SysAsyncStreamObjects;
use crate::virtio::snd::sys::SysAudioStreamSourceGenerator;
use crate::virtio::snd::sys::SysBufferWriter;
use crate::virtio::DescriptorError;
use crate::virtio::DescriptorChain;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
@ -94,9 +94,6 @@ pub enum Error {
/// Cloning kill event failed.
#[error("Failed to clone kill event: {0}")]
CloneKillEvent(SysError),
/// Descriptor chain was invalid.
#[error("Failed to valildate descriptor chain: {0}")]
DescriptorChain(DescriptorError),
// Future error.
#[error("Unexpected error. Done was not triggered before dropped: {0}")]
DoneNotTriggered(Canceled),
@ -194,7 +191,7 @@ const SUPPORTED_FRAME_RATES: u64 = 1 << VIRTIO_SND_PCM_RATE_8000
// Response from pcm_worker to pcm_queue
pub struct PcmResponse {
pub(crate) desc_index: u16,
pub(crate) desc_chain: DescriptorChain,
pub(crate) status: virtio_snd_pcm_status, // response to the pcm message
pub(crate) writer: Writer,
pub(crate) done: Option<oneshot::Sender<()>>, // when pcm response is written to the queue
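// Editor's note: PcmResponse now carries the DescriptorChain itself rather
// than a desc_index, letting the response worker hand ownership straight back
// to Queue::add_used() once the status has been written.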

View file

@ -14,7 +14,6 @@ use cros_async::Executor;
use futures::channel::mpsc;
use futures::Future;
use futures::TryFutureExt;
use vm_memory::GuestMemory;
use super::Error;
use super::PcmResponse;
@ -202,13 +201,11 @@ impl StreamInfo {
/// Prepares the stream, putting it into [`VIRTIO_SND_R_PCM_PREPARE`] state.
///
/// * `ex`: [`Executor`] to run the pcm worker.
/// * `mem`: [`GuestMemory`] to read or write stream data in descriptor chain.
/// * `tx_send`: Sender for sending `PcmResponse` for tx queue. (playback stream)
/// * `rx_send`: Sender for sending `PcmResponse` for rx queue. (capture stream)
pub async fn prepare(
&mut self,
ex: &Executor,
mem: GuestMemory,
tx_send: &mpsc::UnboundedSender<PcmResponse>,
rx_send: &mpsc::UnboundedSender<PcmResponse>,
) -> Result<(), Error> {
@ -299,7 +296,6 @@ impl StreamInfo {
stream,
receiver,
self.status_mutex.clone(),
mem,
pcm_sender,
);
self.worker_future = Some(Box::new(ex.spawn_local(f).into_future()));
@ -385,8 +381,6 @@ impl StreamInfo {
#[cfg(test)]
mod tests {
use audio_streams::NoopStreamSourceGenerator;
#[cfg(windows)]
use vm_memory::GuestAddress;
use super::*;
@ -419,14 +413,10 @@ mod tests {
expected_ok: bool,
expected_state: u32,
) -> StreamInfo {
#[cfg(windows)]
let mem = GuestMemory::new(&[(GuestAddress(0), 0x10000)]).unwrap();
#[cfg(unix)]
let mem = GuestMemory::new(&[]).unwrap();
let (tx_send, _) = mpsc::unbounded();
let (rx_send, _) = mpsc::unbounded();
let result = ex.run_until(stream.prepare(ex, mem, &tx_send, &rx_send));
let result = ex.run_until(stream.prepare(ex, &tx_send, &rx_send));
assert_eq!(result.unwrap().is_ok(), expected_ok);
assert_eq!(stream.state, expected_state);
stream

View file

@ -28,7 +28,7 @@ use win_audio::AudioSharedFormat;
use win_audio::WinAudioServer;
use win_audio::WinStreamSourceGenerator;
use crate::virtio::snd::common_backend::async_funcs::get_index_with_reader_and_writer;
use crate::virtio::snd::common_backend::async_funcs::get_reader_and_writer;
use crate::virtio::snd::common_backend::async_funcs::PlaybackBufferWriter;
use crate::virtio::snd::common_backend::stream_info::StreamInfo;
use crate::virtio::snd::common_backend::DirectionalStream;
@ -224,7 +224,6 @@ impl PlaybackBufferWriter for WinBufferWriter {
async fn check_and_prefill(
&mut self,
mem: &GuestMemory,
desc_receiver: &mut UnboundedReceiver<DescriptorChain>,
sender: &mut UnboundedSender<PcmResponse>,
) -> Result<(), Error> {
@ -244,13 +243,12 @@ impl PlaybackBufferWriter for WinBufferWriter {
return Err(Error::InvalidPCMWorkerState);
}
Ok(Some(desc_chain)) => {
let (desc_index, mut reader, writer) =
get_index_with_reader_and_writer(mem, desc_chain)?;
let (mut reader, writer) = get_reader_and_writer(&desc_chain);
self.write_to_resampler_buffer(&mut reader)?;
sender
.send(PcmResponse {
desc_index,
desc_chain,
status: Ok(0).into(),
writer,
done: None,

View file

@ -36,7 +36,6 @@ use zerocopy::AsBytes;
use crate::virtio::copy_config;
use crate::virtio::device_constants::snd::virtio_snd_config;
use crate::virtio::DescriptorError;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
@ -56,14 +55,8 @@ pub enum SoundError {
ClientNew(Error),
#[error("Failed to create event pair: {0}")]
CreateEvent(BaseError),
#[error("Failed to create Reader from descriptor chain: {0}")]
CreateReader(DescriptorError),
#[error("Failed to create thread: {0}")]
CreateThread(IoError),
#[error("Failed to create Writer from descriptor chain: {0}")]
CreateWriter(DescriptorError),
#[error("Error with queue descriptor: {0}")]
Descriptor(DescriptorError),
#[error("Attempted a {0} operation while on the wrong state: {1}, this is a bug")]
ImpossibleState(&'static str, &'static str),
#[error("Error consuming queue event: {0}")]

View file

@ -231,14 +231,11 @@ impl Stream {
match self.current_state {
StreamState::Started => {
while let Some(desc) = self.buffer_queue.pop_front() {
let mut reader = Reader::new(self.guest_memory.clone(), desc.clone())
.map_err(SoundError::CreateReader)?;
let mut reader = Reader::new(&desc);
// Ignore the first buffer, it was already read by the time this thread
// receives the descriptor
reader.consume(std::mem::size_of::<virtio_snd_pcm_xfer>());
let desc_index = desc.index;
let mut writer = Writer::new(self.guest_memory.clone(), desc)
.map_err(SoundError::CreateWriter)?;
let mut writer = Writer::new(&desc);
let io_res = if self.capture {
let buffer_size =
writer.available_bytes() - std::mem::size_of::<virtio_snd_pcm_status>();
@ -285,7 +282,7 @@ impl Stream {
let mut io_queue_lock = self.io_queue.lock();
io_queue_lock.add_used(
&self.guest_memory,
desc_index,
desc,
writer.bytes_written() as u32,
);
io_queue_lock.trigger_interrupt(&self.guest_memory, &self.interrupt);
@ -419,8 +416,7 @@ pub fn reply_control_op_status(
queue: &Arc<Mutex<Queue>>,
interrupt: &Interrupt,
) -> Result<()> {
let desc_index = desc.index;
let mut writer = Writer::new(guest_memory.clone(), desc).map_err(SoundError::Descriptor)?;
let mut writer = Writer::new(&desc);
writer
.write_obj(virtio_snd_hdr {
code: Le32::from(code),
@ -428,7 +424,7 @@ pub fn reply_control_op_status(
.map_err(SoundError::QueueIO)?;
{
let mut queue_lock = queue.lock();
queue_lock.add_used(guest_memory, desc_index, writer.bytes_written() as u32);
queue_lock.add_used(guest_memory, desc, writer.bytes_written() as u32);
queue_lock.trigger_interrupt(guest_memory, interrupt);
}
Ok(())
@ -443,8 +439,7 @@ pub fn reply_pcm_buffer_status(
queue: &Arc<Mutex<Queue>>,
interrupt: &Interrupt,
) -> Result<()> {
let desc_index = desc.index;
let mut writer = Writer::new(guest_memory.clone(), desc).map_err(SoundError::Descriptor)?;
let mut writer = Writer::new(&desc);
if writer.available_bytes() > std::mem::size_of::<virtio_snd_pcm_status>() {
writer
.consume_bytes(writer.available_bytes() - std::mem::size_of::<virtio_snd_pcm_status>());
@ -457,7 +452,7 @@ pub fn reply_pcm_buffer_status(
.map_err(SoundError::QueueIO)?;
{
let mut queue_lock = queue.lock();
queue_lock.add_used(guest_memory, desc_index, writer.bytes_written() as u32);
queue_lock.add_used(guest_memory, desc, writer.bytes_written() as u32);
queue_lock.trigger_interrupt(guest_memory, interrupt);
}
Ok(())

View file

@ -204,8 +204,7 @@ impl Worker {
// Err if it encounters an unrecoverable error.
fn process_controlq_buffers(&mut self) -> Result<()> {
while let Some(avail_desc) = lock_pop_unlock(&self.control_queue, &self.guest_memory) {
let mut reader = Reader::new(self.guest_memory.clone(), avail_desc.clone())
.map_err(SoundError::Descriptor)?;
let mut reader = Reader::new(&avail_desc);
let available_bytes = reader.available_bytes();
if available_bytes < std::mem::size_of::<virtio_snd_hdr>() {
error!(
@ -278,9 +277,7 @@ impl Worker {
VIRTIO_SND_S_OK
}
};
let desc_index = avail_desc.index;
let mut writer = Writer::new(self.guest_memory.clone(), avail_desc)
.map_err(SoundError::Descriptor)?;
let mut writer = Writer::new(&avail_desc);
writer
.write_obj(virtio_snd_hdr {
code: Le32::from(code),
@ -290,7 +287,7 @@ impl Worker {
let mut queue_lock = self.control_queue.lock();
queue_lock.add_used(
&self.guest_memory,
desc_index,
avail_desc,
writer.bytes_written() as u32,
);
queue_lock.trigger_interrupt(&self.guest_memory, &self.interrupt);
@ -384,15 +381,10 @@ impl Worker {
fn process_event_triggered(&mut self) -> Result<()> {
while let Some(evt) = self.vios_client.pop_event() {
if let Some(desc) = self.event_queue.pop(&self.guest_memory) {
let desc_index = desc.index;
let mut writer =
Writer::new(self.guest_memory.clone(), desc).map_err(SoundError::Descriptor)?;
let mut writer = Writer::new(&desc);
writer.write_obj(evt).map_err(SoundError::QueueIO)?;
self.event_queue.add_used(
&self.guest_memory,
desc_index,
writer.bytes_written() as u32,
);
self.event_queue
.add_used(&self.guest_memory, desc, writer.bytes_written() as u32);
{
self.event_queue
.trigger_interrupt(&self.guest_memory, &self.interrupt);
@ -507,9 +499,7 @@ impl Worker {
code: u32,
info_vec: Vec<T>,
) -> Result<()> {
let desc_index = desc.index;
let mut writer =
Writer::new(self.guest_memory.clone(), desc).map_err(SoundError::Descriptor)?;
let mut writer = Writer::new(&desc);
writer
.write_obj(virtio_snd_hdr {
code: Le32::from(code),
@ -520,11 +510,7 @@ impl Worker {
}
{
let mut queue_lock = self.control_queue.lock();
queue_lock.add_used(
&self.guest_memory,
desc_index,
writer.bytes_written() as u32,
);
queue_lock.add_used(&self.guest_memory, desc, writer.bytes_written() as u32);
queue_lock.trigger_interrupt(&self.guest_memory, &self.interrupt);
}
Ok(())
@ -578,8 +564,7 @@ fn io_loop(
}
};
while let Some(avail_desc) = lock_pop_unlock(queue, &guest_memory) {
let mut reader = Reader::new(guest_memory.clone(), avail_desc.clone())
.map_err(SoundError::Descriptor)?;
let mut reader = Reader::new(&avail_desc);
let xfer: virtio_snd_pcm_xfer = reader.read_obj().map_err(SoundError::QueueIO)?;
let stream_id = xfer.stream_id.to_native();
if stream_id as usize >= senders.len() {

View file

@ -40,36 +40,29 @@ pub fn process_rx<I: SignalableInterrupt, T: TapT>(
}
};
let index = desc_chain.index;
let bytes_written = match Writer::new(mem.clone(), desc_chain) {
Ok(mut writer) => {
match writer.write_from(&mut tap, writer.available_bytes()) {
Ok(_) => {}
Err(ref e) if e.kind() == io::ErrorKind::WriteZero => {
warn!("net: rx: buffer is too small to hold frame");
break;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
// No more to read from the tap.
break;
}
Err(e) => {
warn!("net: rx: failed to write slice: {}", e);
return Err(NetError::WriteBuffer(e));
}
};
let mut writer = Writer::new(&desc_chain);
writer.bytes_written() as u32
match writer.write_from(&mut tap, writer.available_bytes()) {
Ok(_) => {}
Err(ref e) if e.kind() == io::ErrorKind::WriteZero => {
warn!("net: rx: buffer is too small to hold frame");
break;
}
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
// No more to read from the tap.
break;
}
Err(e) => {
error!("net: failed to create Writer: {}", e);
0
warn!("net: rx: failed to write slice: {}", e);
return Err(NetError::WriteBuffer(e));
}
};
let bytes_written = writer.bytes_written() as u32;
if bytes_written > 0 {
rx_queue.pop_peeked(mem);
rx_queue.add_used(mem, index, bytes_written);
rx_queue.add_used(mem, desc_chain, bytes_written);
needs_interrupt = true;
}
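// Editor's note: rx follows a peek-then-commit pattern -- the chain is only
// committed (pop_peeked() then add_used()) after at least one byte has been
// written; on WriteZero/WouldBlock the loop breaks and the chain stays on the
// avail ring for the next attempt.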
}
@ -92,29 +85,23 @@ pub fn process_tx<I: SignalableInterrupt, T: TapT>(
mut tap: &mut T,
) {
while let Some(desc_chain) = tx_queue.pop(mem) {
let index = desc_chain.index;
match Reader::new(mem.clone(), desc_chain) {
Ok(mut reader) => {
let expected_count = reader.available_bytes();
match reader.read_to(&mut tap, expected_count) {
Ok(count) => {
// Tap writes must be done in one call. If the entire frame was not
// written, it's an error.
if count != expected_count {
error!(
"net: tx: wrote only {} bytes of {} byte frame",
count, expected_count
);
}
}
Err(e) => error!("net: tx: failed to write frame to tap: {}", e),
let mut reader = Reader::new(&desc_chain);
let expected_count = reader.available_bytes();
match reader.read_to(&mut tap, expected_count) {
Ok(count) => {
// Tap writes must be done in one call. If the entire frame was not
// written, it's an error.
if count != expected_count {
error!(
"net: tx: wrote only {} bytes of {} byte frame",
count, expected_count
);
}
}
Err(e) => error!("net: failed to create Reader: {}", e),
Err(e) => error!("net: tx: failed to write frame to tap: {}", e),
}
tx_queue.add_used(mem, index, 0);
tx_queue.add_used(mem, desc_chain, 0);
}
tx_queue.trigger_interrupt(mem, interrupt);

View file

@ -44,31 +44,23 @@ fn rx_single_frame(
None => return false,
};
let index = desc_chain.index;
let bytes_written = match Writer::new(mem.clone(), desc_chain) {
Ok(mut writer) => {
match writer.write_all(&rx_buf[0..rx_count]) {
Ok(()) => (),
Err(ref e) if e.kind() == io::ErrorKind::WriteZero => {
warn!(
"net: rx: buffer is too small to hold frame of size {}",
rx_count
);
}
Err(e) => {
warn!("net: rx: failed to write slice: {}", e);
}
};
writer.bytes_written() as u32
let mut writer = Writer::new(&desc_chain);
match writer.write_all(&rx_buf[0..rx_count]) {
Ok(()) => (),
Err(ref e) if e.kind() == io::ErrorKind::WriteZero => {
warn!(
"net: rx: buffer is too small to hold frame of size {}",
rx_count
);
}
Err(e) => {
error!("net: failed to create Writer: {}", e);
0
warn!("net: rx: failed to write slice: {}", e);
}
};
rx_queue.add_used(mem, index, bytes_written);
let bytes_written = writer.bytes_written() as u32;
rx_queue.add_used(mem, desc_chain, bytes_written);
true
}
@ -175,26 +167,20 @@ pub fn process_tx<I: SignalableInterrupt, T: TapT>(
}
while let Some(desc_chain) = tx_queue.pop(mem) {
let index = desc_chain.index;
match Reader::new(mem.clone(), desc_chain) {
Ok(reader) => {
let mut frame = [0u8; MAX_BUFFER_SIZE];
match read_to_end(reader, &mut frame[..]) {
Ok(len) => {
// We need to copy the frame into a contiguous buffer before writing it to
// slirp because tap requires the whole frame to be written in a single call.
if let Err(err) = tap.write_all(&frame[..len]) {
error!("net: tx: failed to write to tap: {}", err);
}
}
Err(e) => error!("net: tx: failed to read frame into buffer: {}", e),
let mut reader = Reader::new(&desc_chain);
let mut frame = [0u8; MAX_BUFFER_SIZE];
match read_to_end(reader, &mut frame[..]) {
Ok(len) => {
// We need to copy the frame into a contiguous buffer before writing it to
// slirp because tap requires the whole frame to be written in a single call.
if let Err(err) = tap.write_all(&frame[..len]) {
error!("net: tx: failed to write to tap: {}", err);
}
}
Err(e) => error!("net: failed to create Reader: {}", e),
Err(e) => error!("net: tx: failed to read frame into buffer: {}", e),
}
tx_queue.add_used(mem, index, 0);
tx_queue.add_used(mem, desc_chain, 0);
}
tx_queue.trigger_interrupt(mem, interrupt);

View file

@ -20,7 +20,6 @@ use thiserror::Error;
use vm_memory::GuestMemory;
use super::DescriptorChain;
use super::DescriptorError;
use super::DeviceType;
use super::Interrupt;
use super::Queue;
@ -54,9 +53,9 @@ pub trait TpmBackend: Send {
}
impl Worker {
fn perform_work(&mut self, desc: DescriptorChain) -> Result<u32> {
let mut reader = Reader::new(self.mem.clone(), desc.clone()).map_err(Error::Descriptor)?;
let mut writer = Writer::new(self.mem.clone(), desc).map_err(Error::Descriptor)?;
fn perform_work(&mut self, desc: &DescriptorChain) -> Result<u32> {
let mut reader = Reader::new(desc);
let mut writer = Writer::new(desc);
let available_bytes = reader.available_bytes();
if available_bytes > TPM_BUFSIZE {
@ -92,9 +91,7 @@ impl Worker {
fn process_queue(&mut self) -> NeedsInterrupt {
let mut needs_interrupt = NeedsInterrupt::No;
while let Some(avail_desc) = self.queue.pop(&self.mem) {
let index = avail_desc.index;
let len = match self.perform_work(avail_desc) {
let len = match self.perform_work(&avail_desc) {
Ok(len) => len,
Err(err) => {
error!("{}", err);
@ -102,7 +99,7 @@ impl Worker {
}
};
self.queue.add_used(&self.mem, index, len);
self.queue.add_used(&self.mem, avail_desc, len);
needs_interrupt = NeedsInterrupt::Yes;
}
@ -257,8 +254,6 @@ enum Error {
BufferTooSmall { size: usize, required: usize },
#[error("vtpm command is too long: {size} > {} bytes", TPM_BUFSIZE)]
CommandTooLong { size: usize },
#[error("virtio descriptor error: {0}")]
Descriptor(DescriptorError),
#[error("vtpm failed to read from guest memory: {0}")]
Read(io::Error),
#[error(

View file

@ -51,8 +51,8 @@ impl gpu::QueueReader for SharedReader {
self.queue.lock().pop(mem)
}
fn add_used(&self, mem: &GuestMemory, desc_index: u16, len: u32) {
self.queue.lock().add_used(mem, desc_index, len)
fn add_used(&self, mem: &GuestMemory, desc_chain: DescriptorChain, len: u32) {
self.queue.lock().add_used(mem, desc_chain, len)
}
fn signal_used(&self, mem: &GuestMemory) {

View file

@ -468,21 +468,18 @@ mod test {
fn device_write(mem: &QueueMemory, q: &mut DeviceQueue, data: &[u8]) -> usize {
let desc_chain = q.pop(mem).unwrap();
let index = desc_chain.index;
let mut writer = Writer::new(mem.clone(), desc_chain).unwrap();
let mut writer = Writer::new(&desc_chain);
let written = writer.write(data).unwrap();
q.add_used(mem, index, written as u32);
q.add_used(mem, desc_chain, written as u32);
written
}
fn device_read(mem: &QueueMemory, q: &mut DeviceQueue, len: usize) -> Vec<u8> {
let desc_chain = q.pop(mem).unwrap();
let desc_index = desc_chain.index;
let mut reader = Reader::new(mem.clone(), desc_chain).unwrap();
let mut reader = Reader::new(&desc_chain);
let mut buf = vec![0; len];
reader.read_exact(&mut buf).unwrap();
q.add_used(mem, desc_index, len as u32);
q.add_used(mem, desc_chain, len as u32);
buf
}

View file

@ -597,9 +597,8 @@ impl Worker {
// If a sibling is disconnected, send 0-length data to the guest and return an error.
if !is_connected {
// Send 0-length data
let index = desc.index;
self.rx_queue.pop_peeked(&self.mem);
self.rx_queue.add_used(&self.mem, index, 0 /* len */);
self.rx_queue.add_used(&self.mem, desc, 0 /* len */);
if !self.rx_queue.trigger_interrupt(&self.mem, &self.interrupt) {
// This interrupt should always be injected. We'd rather fail
// fast if there is an error.
@ -620,7 +619,6 @@ impl Worker {
.context("failed to read Vhost-user sibling message header")?;
let buf = self.get_sibling_msg_data::<R>(&hdr)?;
let index = desc.index;
let bytes_written = {
let res = if !R::is_header_valid(&hdr) {
Err(anyhow!("invalid header for {:?}", hdr.get_code()))
@ -633,7 +631,7 @@ impl Worker {
// message to the virt queue and return how many bytes
// were written.
match res {
Ok(()) => self.forward_msg_to_device(desc, &hdr, &buf),
Ok(()) => self.forward_msg_to_device(&desc, &hdr, &buf),
Err(e) => Err(e),
}
};
@ -646,7 +644,7 @@ impl Worker {
Ok(bytes_written) => {
// The driver is able to deal with a descriptor with 0 bytes written.
self.rx_queue.pop_peeked(&self.mem);
self.rx_queue.add_used(&self.mem, index, bytes_written);
self.rx_queue.add_used(&self.mem, desc, bytes_written);
if !self.rx_queue.trigger_interrupt(&self.mem, &self.interrupt) {
// This interrupt should always be injected. We'd rather fail
// fast if there is an error.
@ -719,28 +717,22 @@ impl Worker {
// queue. Returns the number of bytes written to the virt queue.
fn forward_msg_to_device<R: Req>(
&mut self,
desc_chain: DescriptorChain,
desc_chain: &DescriptorChain,
hdr: &VhostUserMsgHeader<R>,
buf: &[u8],
) -> Result<u32> {
let bytes_written = match Writer::new(self.mem.clone(), desc_chain) {
Ok(mut writer) => {
if writer.available_bytes()
< buf.len() + std::mem::size_of::<VhostUserMsgHeader<R>>()
{
bail!("rx buffer too small to accomodate server data");
}
// Write the header first, then any data. Do these separately to prevent reordering.
let mut written = writer
.write(hdr.as_slice())
.context("failed to write header")?;
written += writer.write(buf).context("failed to write message body")?;
written as u32
}
Err(e) => {
bail!("failed to create Writer: {}", e);
}
};
let mut writer = Writer::new(desc_chain);
if writer.available_bytes() < buf.len() + std::mem::size_of::<VhostUserMsgHeader<R>>() {
bail!("rx buffer too small to accomodate server data");
}
// Write the header first, then any data. Do these separately to prevent reordering.
let mut written = writer
.write(hdr.as_slice())
.context("failed to write header")?;
written += writer.write(buf).context("failed to write message body")?;
let bytes_written = written as u32;
Ok(bytes_written)
}
@ -1083,45 +1075,40 @@ impl Worker {
// the Vhost-user sibling over its socket connection.
fn process_tx(&mut self) -> Result<()> {
while let Some(desc_chain) = self.tx_queue.pop(&self.mem) {
let index = desc_chain.index;
match Reader::new(self.mem.clone(), desc_chain) {
Ok(mut reader) => {
let expected_count = reader.available_bytes();
let mut msg = vec![0; expected_count];
reader
.read_exact(&mut msg)
.context("virtqueue read failed")?;
let mut reader = Reader::new(&desc_chain);
let expected_count = reader.available_bytes();
let mut msg = vec![0; expected_count];
reader
.read_exact(&mut msg)
.context("virtqueue read failed")?;
// This may be a SlaveReq, but the bytes of any valid SlaveReq
// are also a valid MasterReq.
let hdr =
vhost_header_from_bytes::<MasterReq>(&msg).context("message too short")?;
let (dest, (msg, fd)) = if hdr.is_reply() {
(self.slave_req_helper.as_mut().as_mut(), (msg, None))
} else {
let processed_msg = self.process_message_from_backend(msg)?;
(
self.slave_req_fd
.as_mut()
.context("missing slave_req_fd")?
.as_mut(),
processed_msg,
)
};
// This may be a SlaveReq, but the bytes of any valid SlaveReq
// are also a valid MasterReq.
let hdr = vhost_header_from_bytes::<MasterReq>(&msg).context("message too short")?;
let (dest, (msg, fd)) = if hdr.is_reply() {
(self.slave_req_helper.as_mut().as_mut(), (msg, None))
} else {
let processed_msg = self.process_message_from_backend(msg)?;
(
self.slave_req_fd
.as_mut()
.context("missing slave_req_fd")?
.as_mut(),
processed_msg,
)
};
if let Some(fd) = fd {
let written = dest
.send_with_fd(&[IoSlice::new(msg.as_slice())], fd.as_raw_descriptor())
.context("failed to foward message")?;
dest.write_all(&msg[written..])
} else {
dest.write_all(msg.as_slice())
}
if let Some(fd) = fd {
let written = dest
.send_with_fd(&[IoSlice::new(msg.as_slice())], fd.as_raw_descriptor())
.context("failed to foward message")?;
}
Err(e) => error!("failed to create Reader: {}", e),
dest.write_all(&msg[written..])
} else {
dest.write_all(msg.as_slice())
}
self.tx_queue.add_used(&self.mem, index, 0);
.context("failed to foward message")?;
self.tx_queue.add_used(&self.mem, desc_chain, 0);
if !self.tx_queue.trigger_interrupt(&self.mem, &self.interrupt) {
panic!("failed inject tx queue interrupt");
}

View file

@ -4,13 +4,13 @@
use std::collections::BTreeMap;
use crate::virtio::queue::DescriptorChain;
use crate::virtio::video::command::QueueType;
use crate::virtio::video::device::AsyncCmdResponse;
use crate::virtio::video::device::AsyncCmdTag;
use crate::virtio::video::error::VideoError;
use crate::virtio::video::protocol;
use crate::virtio::video::response::CmdResponse;
use crate::virtio::DescriptorChain;
/// AsyncCmdDescMap is a BTreeMap which stores descriptor chains into which asynchronous
/// responses will be written.

View file

@ -28,7 +28,6 @@ use zerocopy::AsBytes;
use crate::virtio;
use crate::virtio::copy_config;
use crate::virtio::virtio_device::VirtioDevice;
use crate::virtio::DescriptorError;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::Suspendable;
@ -97,9 +96,6 @@ pub enum Error {
/// Making an EventAsync failed.
#[error("failed to create an EventAsync: {0}")]
EventAsyncCreationFailed(cros_async::AsyncError),
/// A DescriptorChain contains invalid data.
#[error("DescriptorChain contains invalid data: {0}")]
InvalidDescriptorChain(DescriptorError),
/// Failed to read a virtio-video command.
#[error("failed to read a command from the guest: {0}")]
ReadFailure(ReadCmdError),

View file

@ -22,7 +22,6 @@ use cros_async::SelectResult;
use futures::FutureExt;
use vm_memory::GuestMemory;
use crate::virtio::queue::DescriptorChain;
use crate::virtio::queue::Queue;
use crate::virtio::video::async_cmd_desc_map::AsyncCmdDescMap;
use crate::virtio::video::command::QueueType;
@ -40,6 +39,7 @@ use crate::virtio::video::response;
use crate::virtio::video::response::Response;
use crate::virtio::video::Error;
use crate::virtio::video::Result;
use crate::virtio::DescriptorChain;
use crate::virtio::Reader;
use crate::virtio::SignalableInterrupt;
use crate::virtio::Writer;
@ -87,9 +87,7 @@ impl<I: SignalableInterrupt> Worker<I> {
return Ok(());
}
while let Some((desc, response)) = responses.pop_front() {
let desc_index = desc.index;
let mut writer =
Writer::new(self.mem.clone(), desc).map_err(Error::InvalidDescriptorChain)?;
let mut writer = Writer::new(&desc);
if let Err(e) = response.write(&mut writer) {
error!(
"failed to write a command response for {:?}: {}",
@ -97,7 +95,7 @@ impl<I: SignalableInterrupt> Worker<I> {
);
}
self.cmd_queue
.add_used(&self.mem, desc_index, writer.bytes_written() as u32);
.add_used(&self.mem, desc, writer.bytes_written() as u32);
}
self.cmd_queue
.trigger_interrupt(&self.mem, &self.cmd_queue_interrupt);
@ -111,14 +109,12 @@ impl<I: SignalableInterrupt> Worker<I> {
.pop(&self.mem)
.ok_or(Error::DescriptorNotAvailable)?;
let desc_index = desc.index;
let mut writer =
Writer::new(self.mem.clone(), desc).map_err(Error::InvalidDescriptorChain)?;
let mut writer = Writer::new(&desc);
event
.write(&mut writer)
.map_err(|error| Error::WriteEventFailure { event, error })?;
self.event_queue
.add_used(&self.mem, desc_index, writer.bytes_written() as u32);
.add_used(&self.mem, desc, writer.bytes_written() as u32);
self.event_queue
.trigger_interrupt(&self.mem, &self.event_queue_interrupt);
Ok(())
@ -205,8 +201,7 @@ impl<I: SignalableInterrupt> Worker<I> {
desc: DescriptorChain,
) -> Result<VecDeque<WritableResp>> {
let mut responses: VecDeque<WritableResp> = Default::default();
let mut reader =
Reader::new(self.mem.clone(), desc.clone()).map_err(Error::InvalidDescriptorChain)?;
let mut reader = Reader::new(&desc);
let cmd = VideoCmd::from_reader(&mut reader).map_err(Error::ReadFailure)?;

View file

@ -58,7 +58,7 @@ use crate::virtio::vsock::sys::windows::protocol::virtio_vsock_event;
use crate::virtio::vsock::sys::windows::protocol::virtio_vsock_hdr;
use crate::virtio::vsock::sys::windows::protocol::vsock_op;
use crate::virtio::vsock::sys::windows::protocol::TYPE_STREAM_SOCKET;
use crate::virtio::DescriptorError;
use crate::virtio::DescriptorChain;
use crate::virtio::DeviceType;
use crate::virtio::Interrupt;
use crate::virtio::Queue;
@ -79,12 +79,8 @@ pub enum VsockError {
CloneDescriptor(SysError),
#[error("Failed to create EventAsync: {0}")]
CreateEventAsync(AsyncError),
#[error("Failed to create queue reader: {0}")]
CreateReader(DescriptorError),
#[error("Failed to create wait context: {0}")]
CreateWaitContext(SysError),
#[error("Failed to create queue writer: {0}")]
CreateWriter(DescriptorError),
#[error("Failed to read queue: {0}")]
ReadQueue(io::Error),
#[error("Failed to reset event object: {0}")]
@ -518,7 +514,15 @@ impl Worker {
) -> Result<()> {
loop {
// Run continuously until exit evt
let (mut reader, index) = self.get_next(&mut queue, &mut queue_evt).await?;
let avail_desc = match queue.next_async(&self.mem, &mut queue_evt).await {
Ok(d) => d,
Err(e) => {
error!("vsock: Failed to read descriptor {}", e);
return Err(VsockError::AwaitQueue(e));
}
};
let mut reader = Reader::new(&avail_desc);
while reader.available_bytes() >= std::mem::size_of::<virtio_vsock_hdr>() {
let header = match reader.read_obj::<virtio_vsock_hdr>() {
Ok(hdr) => hdr,
@ -549,7 +553,7 @@ impl Worker {
};
}
queue.add_used(&self.mem, index, 0);
queue.add_used(&self.mem, avail_desc, 0);
queue.trigger_interrupt(&self.mem, &self.interrupt);
}
}
@ -1111,13 +1115,12 @@ impl Worker {
}
}
// Get a `Writer`, or wait if there are no descriptors currently available on
// the queue.
async fn get_next_writer(
async fn write_bytes_to_queue(
&self,
queue: &mut Queue,
queue_evt: &mut EventAsync,
) -> Result<(Writer, u16)> {
bytes: &[u8],
) -> Result<()> {
let avail_desc = match queue.next_async(&self.mem, queue_evt).await {
Ok(d) => d,
Err(e) => {
@ -1125,22 +1128,8 @@ impl Worker {
return Err(VsockError::AwaitQueue(e));
}
};
let index = avail_desc.index;
Writer::new(self.mem.clone(), avail_desc)
.map_err(|e| {
error!("vsock: failed to create Writer: {}", e);
VsockError::CreateWriter(e)
})
.map(|r| (r, index))
}
async fn write_bytes_to_queue(
&self,
queue: &mut Queue,
queue_evt: &mut EventAsync,
bytes: &[u8],
) -> Result<()> {
let (mut writer, desc_index) = self.get_next_writer(queue, queue_evt).await?;
let mut writer = Writer::new(&avail_desc);
let res = writer.write_all(bytes);
if let Err(e) = res {
@ -1154,7 +1143,7 @@ impl Worker {
let bytes_written = writer.bytes_written() as u32;
if bytes_written > 0 {
queue.add_used(&self.mem, desc_index, bytes_written);
queue.add_used(&self.mem, avail_desc, bytes_written);
queue.trigger_interrupt(&self.mem, &self.interrupt);
Ok(())
} else {
@ -1170,7 +1159,15 @@ impl Worker {
loop {
// Log but don't act on events. They are reserved exclusively for guest migration events
// resulting in CID resets, which we don't support.
let (mut reader, _index) = self.get_next(&mut queue, &mut queue_evt).await?;
let avail_desc = match queue.next_async(&self.mem, &mut queue_evt).await {
Ok(d) => d,
Err(e) => {
error!("vsock: Failed to read descriptor {}", e);
return Err(VsockError::AwaitQueue(e));
}
};
let mut reader = Reader::new(&avail_desc);
for event in reader.iter::<virtio_vsock_event>() {
if event.is_ok() {
error!(
@ -1182,27 +1179,6 @@ impl Worker {
}
}
async fn get_next(
&self,
queue: &mut Queue,
queue_evt: &mut EventAsync,
) -> Result<(Reader, u16)> {
let avail_desc = match queue.next_async(&self.mem, queue_evt).await {
Ok(d) => d,
Err(e) => {
error!("vsock: Failed to read descriptor {}", e);
return Err(VsockError::AwaitQueue(e));
}
};
let index = avail_desc.index;
Reader::new(self.mem.clone(), avail_desc)
.map_err(|e| {
error!("vsock: failed to create Reader: {}", e);
VsockError::CreateReader(e)
})
.map(|r| (r, index))
}
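
With both constructors infallible, the get_next()/get_next_writer() wrappers, and the CreateReader/CreateWriter error variants that existed only to serve them, no longer pay for themselves: each call site collapses to an await plus a one-line constructor. A sketch of what remains; the demo wrapper is invented, the calls are as in the diff:

use cros_async::EventAsync;
use devices::virtio::{Queue, Reader};
use vm_memory::GuestMemory;

// Everything the removed helper did now fits inline at the call site.
async fn next_reader_demo(queue: &mut Queue, evt: &mut EventAsync, mem: &GuestMemory) {
    if let Ok(avail_desc) = queue.next_async(mem, evt).await {
        let _reader = Reader::new(&avail_desc); // cannot fail anymore
    }
}
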
fn run(
&mut self,
rx_queue: Queue,

View file

@ -1700,30 +1700,21 @@ pub fn process_in_queue<I: SignalableInterrupt>(
break;
};
let index = desc.index;
let mut should_pop = false;
if let Some(in_resp) = state.next_recv() {
let bytes_written = match Writer::new(mem.clone(), desc) {
Ok(mut writer) => {
match encode_resp(&mut writer, in_resp) {
Ok(()) => {
should_pop = true;
}
Err(e) => {
error!("failed to encode response to descriptor chain: {}", e);
}
};
writer.bytes_written() as u32
let mut writer = Writer::new(&desc);
match encode_resp(&mut writer, in_resp) {
Ok(()) => {
should_pop = true;
}
Err(e) => {
error!("invalid descriptor: {}", e);
0
error!("failed to encode response to descriptor chain: {}", e);
}
};
}
let bytes_written = writer.bytes_written() as u32;
needs_interrupt = true;
in_queue.pop_peeked(mem);
in_queue.add_used(mem, index, bytes_written);
in_queue.add_used(mem, desc, bytes_written);
} else {
break;
}
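
process_in_queue() uses the peek/commit variant of the pattern: the chain is only removed from the avail ring (pop_peeked()) after a response has been encoded, so a descriptor is never lost when no response is pending. A sketch, assuming Queue::peek() as the counterpart that pop_peeked() later commits; the helper and payload are invented:

use std::io::Write;

use devices::virtio::{Queue, Writer};
use vm_memory::GuestMemory;

// Peek first; commit with pop_peeked() only once a response was written.
fn send_one(queue: &mut Queue, mem: &GuestMemory, resp: &[u8]) -> bool {
    let Some(desc) = queue.peek(mem) else {
        return false; // nothing available; try again on the next kick
    };
    let mut writer = Writer::new(&desc);
    let ok = writer.write_all(resp).is_ok();
    let len = writer.bytes_written() as u32;
    queue.pop_peeked(mem); // remove the entry from the avail ring
    queue.add_used(mem, desc, len);
    ok
}
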
@ -1752,33 +1743,23 @@ pub fn process_out_queue<I: SignalableInterrupt>(
) {
let mut needs_interrupt = false;
while let Some(desc) = out_queue.pop(mem) {
let desc_index = desc.index;
match (
Reader::new(mem.clone(), desc.clone()),
Writer::new(mem.clone(), desc),
) {
(Ok(mut reader), Ok(mut writer)) => {
let resp = match state.execute(&mut reader) {
Ok(r) => r,
Err(e) => WlResp::Err(Box::new(e)),
};
let mut reader = Reader::new(&desc);
let mut writer = Writer::new(&desc);
match encode_resp(&mut writer, resp) {
Ok(()) => {}
Err(e) => {
error!("failed to encode response to descriptor chain: {}", e);
}
}
let resp = match state.execute(&mut reader) {
Ok(r) => r,
Err(e) => WlResp::Err(Box::new(e)),
};
out_queue.add_used(mem, desc_index, writer.bytes_written() as u32);
needs_interrupt = true;
}
(_, Err(e)) | (Err(e), _) => {
error!("invalid descriptor: {}", e);
out_queue.add_used(mem, desc_index, 0);
needs_interrupt = true;
match encode_resp(&mut writer, resp) {
Ok(()) => {}
Err(e) => {
error!("failed to encode response to descriptor chain: {}", e);
}
}
out_queue.add_used(mem, desc, writer.bytes_written() as u32);
needs_interrupt = true;
}
if needs_interrupt {
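
The out-queue loop shows the payoff most clearly: the old match over the (Reader::new(...), Writer::new(...)) results, including the error arm that completed the chain with zero length, is gone. Malformed chains are now rejected once, when DescriptorChain::new() walks the chain, and every popped descriptor yields a usable pair. The generic shape, with an invented stand-in for the wl-specific execute()/encode_resp() step:

use devices::virtio::{Queue, Reader, Writer};
use vm_memory::GuestMemory;

// Generic request/response loop in the new style; `handle` stands in for
// the device-specific execute-and-encode step.
fn service_requests<F>(queue: &mut Queue, mem: &GuestMemory, mut handle: F)
where
    F: FnMut(&mut Reader, &mut Writer),
{
    while let Some(desc) = queue.pop(mem) {
        let mut reader = Reader::new(&desc);
        let mut writer = Writer::new(&desc);
        handle(&mut reader, &mut writer); // no constructor errors to juggle
        let len = writer.bytes_written() as u32;
        queue.add_used(mem, desc, len);
    }
}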