diff --git a/Cargo.lock b/Cargo.lock index 961475dee1..a4cc15d787 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -694,6 +694,7 @@ dependencies = [ "anyhow", "argh", "async-task", + "async-trait", "audio_streams", "balloon_control", "base", diff --git a/crosvm_cli/src/sys/windows/exit.rs b/crosvm_cli/src/sys/windows/exit.rs index bb568b1492..3d1786d77f 100644 --- a/crosvm_cli/src/sys/windows/exit.rs +++ b/crosvm_cli/src/sys/windows/exit.rs @@ -338,6 +338,7 @@ pub enum Exit { CommonChildSetupError = 0xE0000098, CreateImeThread = 0xE0000099, OpenDiskImage = 0xE000009A, + VirtioSoundDeviceNew = 0xE000009B, } impl From for ExitCode { diff --git a/devices/Cargo.toml b/devices/Cargo.toml index 329bccb3f5..55cf5a5feb 100644 --- a/devices/Cargo.toml +++ b/devices/Cargo.toml @@ -33,6 +33,7 @@ argh = "0.1.7" async-task = "4" acpi_tables = {path = "../acpi_tables" } anyhow = "*" +async-trait = "0.1.36" audio_streams = "*" balloon_control = { path = "../common/balloon_control" } base = { path = "../base" } diff --git a/devices/src/virtio/descriptor_utils.rs b/devices/src/virtio/descriptor_utils.rs index 44b80d5217..c1a28ab6ae 100644 --- a/devices/src/virtio/descriptor_utils.rs +++ b/devices/src/virtio/descriptor_utils.rs @@ -339,6 +339,19 @@ impl Reader { read } + /// Reads data from the descriptor chain buffer and passes the `VolatileSlice`s to the callback + /// `cb`. + pub fn read_to_cb usize>( + &mut self, + cb: C, + count: usize, + ) -> usize { + let iovs = self.regions.get_remaining_with_count(&self.mem, count); + let written = cb(&iovs[..]); + self.regions.consume(written); + written + } + /// Reads data from the descriptor chain buffer into a writable object. /// Returns the number of bytes read from the descriptor chain buffer. /// The number of bytes read can be less than `count` if there isn't diff --git a/devices/src/virtio/snd/common_backend/async_funcs.rs b/devices/src/virtio/snd/common_backend/async_funcs.rs index bb8623f9ea..5a3a4eb8c9 100644 --- a/devices/src/virtio/snd/common_backend/async_funcs.rs +++ b/devices/src/virtio/snd/common_backend/async_funcs.rs @@ -8,6 +8,7 @@ use std::io::Read; use std::io::Write; use std::rc::Rc; +use async_trait::async_trait; use audio_streams::capture::AsyncCaptureBuffer; use audio_streams::AsyncPlaybackBuffer; use base::debug; @@ -27,14 +28,16 @@ use futures::SinkExt; use futures::StreamExt; use thiserror::Error as ThisError; use vm_memory::GuestMemory; +#[cfg(windows)] +use win_audio::AudioSharedFormat; -use super::DirectionalStream; use super::Error; use super::SndData; use super::WorkerStatus; use crate::virtio::snd::common::*; use crate::virtio::snd::common_backend::stream_info::SetParams; use crate::virtio::snd::common_backend::stream_info::StreamInfo; +use crate::virtio::snd::common_backend::DirectionalStream; use crate::virtio::snd::common_backend::PcmResponse; use crate::virtio::snd::constants::*; use crate::virtio::snd::layout::*; @@ -44,6 +47,46 @@ use crate::virtio::Reader; use crate::virtio::SignalableInterrupt; use crate::virtio::Writer; +// TODO(b/246601226): Remove once a generic audio_stream solution that can accpet +// arbitrarily size buffers. +/// Trait to wrap system specific helpers for writing to endpoint playback buffers. +#[async_trait(?Send)] +pub trait PlaybackBufferWriter { + fn new( + guest_period_bytes: usize, + #[cfg(windows)] frame_size: usize, + #[cfg(windows)] frame_rate: usize, + #[cfg(windows)] guest_num_channels: usize, + #[cfg(windows)] audio_shared_format: AudioSharedFormat, + ) -> Self + where + Self: Sized; + + /// Returns the period of the endpoint device. + fn endpoint_period_bytes(&self) -> usize; + + /// Read audio samples from the tx virtqueue. + fn copy_to_buffer( + &mut self, + dst_buf: &mut AsyncPlaybackBuffer<'_>, + reader: &mut Reader, + ) -> Result { + dst_buf.copy_from(reader).map_err(Error::Io) + } + /// Check to see if an additional read from the tx virtqueue is needed during a playback + /// loop. If so, read from the virtqueue. + /// + /// Prefill will happen, for example, if the endpoint buffer requires a 513 frame period, but + /// each tx virtqueue read only produces 480 frames. + #[cfg(windows)] + async fn check_and_prefill( + &mut self, + mem: &GuestMemory, + desc_receiver: &mut mpsc::UnboundedReceiver, + sender: &mut mpsc::UnboundedSender, + ) -> Result<(), Error>; +} + #[derive(Debug)] enum VirtioSndPcmCmd { SetParams { set_params: SetParams }, @@ -184,17 +227,20 @@ async fn process_pcm_ctrl( async fn write_data( mut dst_buf: AsyncPlaybackBuffer<'_>, reader: Option, - period_bytes: usize, + buffer_writer: &mut Box, ) -> Result<(), Error> { let transferred = match reader { - Some(mut reader) => dst_buf.copy_from(&mut reader), - None => dst_buf.copy_from(&mut io::repeat(0).take(period_bytes as u64)), - } - .map_err(Error::Io)?; - if transferred as usize != period_bytes { + Some(mut reader) => buffer_writer.copy_to_buffer(&mut dst_buf, &mut reader)?, + None => dst_buf + .copy_from(&mut io::repeat(0).take(buffer_writer.endpoint_period_bytes() as u64)) + .map_err(Error::Io)?, + }; + + if transferred as usize != buffer_writer.endpoint_period_bytes() { error!( "Bytes written {} != period_bytes {}", - transferred, period_bytes + transferred, + buffer_writer.endpoint_period_bytes() ); Err(Error::InvalidBufferSize) } else { @@ -288,6 +334,19 @@ async fn drain_desc_receiver( Ok(()) } +pub(crate) fn get_index_with_reader_and_writer( + mem: &GuestMemory, + desc_chain: DescriptorChain, +) -> Result<(u16, Reader, Writer), Error> { + let desc_index = desc_chain.index; + let mut reader = + Reader::new(mem.clone(), desc_chain.clone()).map_err(Error::DescriptorChain)?; + // stream_id was already read in handle_pcm_queue + reader.consume(std::mem::size_of::()); + let writer = Writer::new(mem.clone(), desc_chain).map_err(Error::DescriptorChain)?; + Ok((desc_index, reader, writer)) +} + /// Start a pcm worker that receives descriptors containing PCM frames (audio data) from the tx/rx /// queue, and forward them to CRAS. One pcm worker per stream. /// @@ -300,7 +359,6 @@ pub async fn start_pcm_worker( status_mutex: Rc>, mem: GuestMemory, mut sender: mpsc::UnboundedSender, - period_bytes: usize, ) -> Result<(), Error> { let res = pcm_worker_loop( ex, @@ -309,7 +367,6 @@ pub async fn start_pcm_worker( &status_mutex, &mem, &mut sender, - period_bytes, ) .await; *status_mutex.lock().await = WorkerStatus::Quit; @@ -332,50 +389,52 @@ async fn pcm_worker_loop( status_mutex: &Rc>, mem: &GuestMemory, sender: &mut mpsc::UnboundedSender, - period_bytes: usize, ) -> Result<(), Error> { match dstream { - DirectionalStream::Output(mut stream) => { - loop { - let dst_buf = stream - .next_playback_buffer(&ex) - .await - .map_err(Error::FetchBuffer)?; - let worker_status = status_mutex.lock().await; - match *worker_status { - WorkerStatus::Quit => { - drain_desc_receiver(desc_receiver, mem, sender).await?; - if let Err(e) = write_data(dst_buf, None, period_bytes).await { - error!("Error on write_data after worker quit: {}", e) - } - break Ok(()); + #[allow(unused_mut)] + DirectionalStream::Output(mut stream, mut buffer_writer) => loop { + let dst_buf = stream + .next_playback_buffer(&ex) + .await + .map_err(Error::FetchBuffer)?; + let worker_status = status_mutex.lock().await; + match *worker_status { + WorkerStatus::Quit => { + drain_desc_receiver(desc_receiver, mem, sender).await?; + if let Err(e) = write_data(dst_buf, None, &mut buffer_writer).await { + error!("Error on write_data after worker quit: {}", e) } - WorkerStatus::Pause => { - write_data(dst_buf, None, period_bytes).await?; - } - WorkerStatus::Running => match desc_receiver.try_next() { + break Ok(()); + } + WorkerStatus::Pause => { + write_data(dst_buf, None, &mut buffer_writer).await?; + } + WorkerStatus::Running => { + // TODO(b/246601226): Remove once a generic audio_stream solution that can + // accpet arbitrarily size buffers + #[cfg(windows)] + buffer_writer + .check_and_prefill(mem, desc_receiver, sender) + .await?; + + match desc_receiver.try_next() { Err(e) => { error!("Underrun. No new DescriptorChain while running: {}", e); - write_data(dst_buf, None, period_bytes).await?; + write_data(dst_buf, None, &mut buffer_writer).await?; } Ok(None) => { error!("Unreachable. status should be Quit when the channel is closed"); - write_data(dst_buf, None, period_bytes).await?; + write_data(dst_buf, None, &mut buffer_writer).await?; return Err(Error::InvalidPCMWorkerState); } Ok(Some(desc_chain)) => { - let desc_index = desc_chain.index; - let mut reader = Reader::new(mem.clone(), desc_chain.clone()) - .map_err(Error::DescriptorChain)?; - // stream_id was already read in handle_pcm_queue - reader.consume(std::mem::size_of::()); - let writer = Writer::new(mem.clone(), desc_chain) - .map_err(Error::DescriptorChain)?; + let (desc_index, reader, writer) = + get_index_with_reader_and_writer(mem, desc_chain)?; sender .send(PcmResponse { desc_index, - status: write_data(dst_buf, Some(reader), period_bytes) + status: write_data(dst_buf, Some(reader), &mut buffer_writer) .await .into(), writer, @@ -384,64 +443,57 @@ async fn pcm_worker_loop( .await .map_err(Error::MpscSend)?; } - }, + } } } - } - DirectionalStream::Input(mut stream) => { - loop { - let src_buf = stream - .next_capture_buffer(&ex) - .await - .map_err(Error::FetchBuffer)?; + }, + DirectionalStream::Input(mut stream, period_bytes) => loop { + let src_buf = stream + .next_capture_buffer(&ex) + .await + .map_err(Error::FetchBuffer)?; - let worker_status = status_mutex.lock().await; - match *worker_status { - WorkerStatus::Quit => { - drain_desc_receiver(desc_receiver, mem, sender).await?; - if let Err(e) = read_data(src_buf, None, period_bytes).await { - error!("Error on read_data after worker quit: {}", e) - } - break Ok(()); + let worker_status = status_mutex.lock().await; + match *worker_status { + WorkerStatus::Quit => { + drain_desc_receiver(desc_receiver, mem, sender).await?; + if let Err(e) = read_data(src_buf, None, period_bytes).await { + error!("Error on read_data after worker quit: {}", e) } - WorkerStatus::Pause => { + break Ok(()); + } + WorkerStatus::Pause => { + read_data(src_buf, None, period_bytes).await?; + } + WorkerStatus::Running => match desc_receiver.try_next() { + Err(e) => { + error!("Overrun. No new DescriptorChain while running: {}", e); read_data(src_buf, None, period_bytes).await?; } - WorkerStatus::Running => match desc_receiver.try_next() { - Err(e) => { - error!("Overrun. No new DescriptorChain while running: {}", e); - read_data(src_buf, None, period_bytes).await?; - } - Ok(None) => { - error!("Unreachable. status should be Quit when the channel is closed"); - read_data(src_buf, None, period_bytes).await?; - return Err(Error::InvalidPCMWorkerState); - } - Ok(Some(desc_chain)) => { - let desc_index = desc_chain.index; - let mut reader = Reader::new(mem.clone(), desc_chain.clone()) - .map_err(Error::DescriptorChain)?; - // stream_id was already read in handle_pcm_queue - reader.consume(std::mem::size_of::()); - let mut writer = Writer::new(mem.clone(), desc_chain) - .map_err(Error::DescriptorChain)?; + Ok(None) => { + error!("Unreachable. status should be Quit when the channel is closed"); + read_data(src_buf, None, period_bytes).await?; + return Err(Error::InvalidPCMWorkerState); + } + Ok(Some(desc_chain)) => { + let (desc_index, _reader, mut writer) = + get_index_with_reader_and_writer(mem, desc_chain)?; - sender - .send(PcmResponse { - desc_index, - status: read_data(src_buf, Some(&mut writer), period_bytes) - .await - .into(), - writer, - done: None, - }) - .await - .map_err(Error::MpscSend)?; - } - }, - } + sender + .send(PcmResponse { + desc_index, + status: read_data(src_buf, Some(&mut writer), period_bytes) + .await + .into(), + writer, + done: None, + }) + .await + .map_err(Error::MpscSend)?; + } + }, } - } + }, } } @@ -745,7 +797,7 @@ pub async fn handle_ctrl_queue( the number of chmaps ({})", start_id, count, - snd_data.pcm_info.len() + snd_data.chmap_info.len() ); return writer .write_obj(VIRTIO_SND_S_BAD_MSG) diff --git a/devices/src/virtio/snd/common_backend/mod.rs b/devices/src/virtio/snd/common_backend/mod.rs index 29d7feb9b5..f23ee74030 100644 --- a/devices/src/virtio/snd/common_backend/mod.rs +++ b/devices/src/virtio/snd/common_backend/mod.rs @@ -10,7 +10,6 @@ use std::thread; use anyhow::Context; use audio_streams::BoxError; -use audio_streams::StreamSourceGenerator; use base::debug; use base::error; use base::warn; @@ -48,6 +47,9 @@ use crate::virtio::snd::parameters::Parameters; use crate::virtio::snd::parameters::StreamSourceBackend; use crate::virtio::snd::sys::create_stream_source_generators as sys_create_stream_source_generators; use crate::virtio::snd::sys::set_audio_thread_priority; +use crate::virtio::snd::sys::SysAsyncStreamObjects; +use crate::virtio::snd::sys::SysAudioStreamSourceGenerator; +use crate::virtio::snd::sys::SysBufferWriter; use crate::virtio::DescriptorError; use crate::virtio::DeviceType; use crate::virtio::Interrupt; @@ -71,6 +73,9 @@ pub enum Error { /// Creating stream failed. #[error("Failed to create stream: {0}")] CreateStream(BoxError), + /// Creating stream failed. + #[error("No stream source found.")] + EmptyStreamSource, /// Creating kill event failed. #[error("Failed to create kill event: {0}")] CreateKillEvent(SysError), @@ -131,8 +136,14 @@ pub enum Error { } pub enum DirectionalStream { - Input(Box), - Output(Box), + Input( + Box, + usize, // `period_size` in `usize` + ), + Output( + Box, + Box, + ), } #[derive(Copy, Clone, std::cmp::PartialEq, Eq)] @@ -170,10 +181,10 @@ const SUPPORTED_FRAME_RATES: u64 = 1 << VIRTIO_SND_PCM_RATE_8000 // Response from pcm_worker to pcm_queue pub struct PcmResponse { - desc_index: u16, - status: virtio_snd_pcm_status, // response to the pcm message - writer: Writer, - done: Option>, // when pcm response is written to the queue + pub(crate) desc_index: u16, + pub(crate) status: virtio_snd_pcm_status, // response to the pcm message + pub(crate) writer: Writer, + pub(crate) done: Option>, // when pcm response is written to the queue } pub struct VirtioSnd { @@ -209,7 +220,7 @@ impl VirtioSnd { pub(crate) fn create_stream_source_generators( params: &Parameters, snd_data: &SndData, -) -> Vec> { +) -> Vec { match params.backend { StreamSourceBackend::NULL => create_null_stream_source_generators(snd_data), StreamSourceBackend::Sys(backend) => { @@ -441,7 +452,7 @@ fn run_worker( snd_data: SndData, queue_evts: Vec, kill_evt: Event, - stream_source_generators: Vec>, + stream_source_generators: Vec, ) -> Result<(), String> { let ex = Executor::new().expect("Failed to create an executor"); diff --git a/devices/src/virtio/snd/common_backend/stream_info.rs b/devices/src/virtio/snd/common_backend/stream_info.rs index 860aacafb3..4277e0bcf0 100644 --- a/devices/src/virtio/snd/common_backend/stream_info.rs +++ b/devices/src/virtio/snd/common_backend/stream_info.rs @@ -6,8 +6,6 @@ use std::fmt; use std::rc::Rc; use audio_streams::SampleFormat; -use audio_streams::StreamSource; -use audio_streams::StreamSourceGenerator; use base::error; use cros_async::sync::Mutex as AsyncMutex; use cros_async::Executor; @@ -22,7 +20,11 @@ use super::WorkerStatus; use crate::virtio::snd::common::*; use crate::virtio::snd::common_backend::async_funcs::*; use crate::virtio::snd::common_backend::DirectionalStream; +use crate::virtio::snd::common_backend::SysAsyncStreamObjects; +use crate::virtio::snd::common_backend::SysBufferWriter; use crate::virtio::snd::constants::*; +use crate::virtio::snd::sys::SysAudioStreamSource; +use crate::virtio::snd::sys::SysAudioStreamSourceGenerator; use crate::virtio::DescriptorChain; /// Parameters for setting parameters in StreamInfo @@ -38,13 +40,13 @@ pub struct SetParams { /// StreamInfo represents a virtio snd stream. pub struct StreamInfo { - stream_source: Option>, - stream_source_generator: Box, - channels: u8, - format: SampleFormat, - frame_rate: u32, + pub(crate) stream_source: Option, + stream_source_generator: SysAudioStreamSourceGenerator, + pub(crate) channels: u8, + pub(crate) format: SampleFormat, + pub(crate) frame_rate: u32, buffer_bytes: usize, - period_bytes: usize, + pub(crate) period_bytes: usize, direction: u8, // VIRTIO_SND_D_* pub state: u32, // VIRTIO_SND_R_PCM_SET_PARAMS -> VIRTIO_SND_R_PCM_STOP, or 0 (uninitialized) @@ -98,7 +100,7 @@ impl StreamInfo { /// Creates a new [`StreamInfo`]. /// /// * `stream_source_generator`: Generator which generates stream source in [`StreamInfo::prepare()`]. - pub fn new(stream_source_generator: Box) -> Self { + pub fn new(stream_source_generator: SysAudioStreamSourceGenerator) -> Self { StreamInfo { stream_source: None, stream_source_generator, @@ -191,53 +193,48 @@ impl StreamInfo { .map_err(Error::GenerateStreamSource)?, ); } - // (*) - // `buffer_size` in `audio_streams` API indicates the buffer size in bytes that the stream - // consumes (or transmits) each time (next_playback/capture_buffer). - // `period_bytes` in virtio-snd device (or ALSA) indicates the device transmits (or - // consumes) for each PCM message. - // Therefore, `buffer_size` in `audio_streams` == `period_bytes` in virtio-snd. - let (stream, pcm_sender) = match self.direction { - VIRTIO_SND_D_OUTPUT => ( - DirectionalStream::Output( - self.stream_source - .as_mut() - .unwrap() - .async_new_async_playback_stream( - self.channels as usize, - self.format, - self.frame_rate, - // See (*) - self.period_bytes / frame_size, - ex, - ) - .await - .map_err(Error::CreateStream)? - .1, - ), - tx_send.clone(), - ), - VIRTIO_SND_D_INPUT => { - ( - DirectionalStream::Input( - self.stream_source - .as_mut() - .unwrap() - .async_new_async_capture_stream( - self.channels as usize, - self.format, - self.frame_rate, - // See (*) - self.period_bytes / frame_size, - &[], - ex, - ) - .await - .map_err(Error::CreateStream)? - .1, + let SysAsyncStreamObjects { stream, pcm_sender } = match self.direction { + VIRTIO_SND_D_OUTPUT => { + let sys_async_stream = self.set_up_async_playback_stream(frame_size, ex).await?; + + let buffer_writer = SysBufferWriter::new( + self.period_bytes, + #[cfg(windows)] + frame_size, + #[cfg(windows)] + usize::try_from(self.frame_rate).expect("Failed to cast from u32 to usize"), + #[cfg(windows)] + usize::try_from(self.channels).expect("Failed to cast from u32 to usize"), + #[cfg(windows)] + sys_async_stream.audio_shared_format, + ); + SysAsyncStreamObjects { + stream: DirectionalStream::Output( + sys_async_stream.async_playback_buffer_stream, + Box::new(buffer_writer), ), - rx_send.clone(), - ) + pcm_sender: tx_send.clone(), + } + } + VIRTIO_SND_D_INPUT => { + let async_stream = self + .stream_source + .as_mut() + .ok_or(Error::EmptyStreamSource)? + .new_async_capture_stream( + self.channels as usize, + self.format, + self.frame_rate, + self.period_bytes / frame_size, + &[], + ex, + ) + .map_err(Error::CreateStream)? + .1; + SysAsyncStreamObjects { + stream: DirectionalStream::Input(async_stream, self.period_bytes), + pcm_sender: rx_send.clone(), + } } _ => unreachable!(), }; @@ -254,7 +251,6 @@ impl StreamInfo { self.status_mutex.clone(), mem, pcm_sender, - self.period_bytes, ); self.worker_future = Some(Box::new(ex.spawn_local(f).into_future())); self.ex = Some(ex.clone()); @@ -336,10 +332,11 @@ impl StreamInfo { } // TODO(b/246997900): Get these new tests to run on Windows. -#[cfg(unix)] #[cfg(test)] mod tests { use audio_streams::NoopStreamSourceGenerator; + #[cfg(windows)] + use vm_memory::GuestAddress; use super::*; @@ -372,6 +369,9 @@ mod tests { expected_ok: bool, expected_state: u32, ) -> StreamInfo { + #[cfg(windows)] + let mem = GuestMemory::new(&[(GuestAddress(0), 0x10000)]).unwrap(); + #[cfg(unix)] let mem = GuestMemory::new(&[]).unwrap(); let (tx_send, _) = mpsc::unbounded(); let (rx_send, _) = mpsc::unbounded(); diff --git a/devices/src/virtio/snd/null_backend.rs b/devices/src/virtio/snd/null_backend.rs index 35c8801ea0..41f52e9c9b 100644 --- a/devices/src/virtio/snd/null_backend.rs +++ b/devices/src/virtio/snd/null_backend.rs @@ -3,14 +3,14 @@ // found in the LICENSE file. use audio_streams::NoopStreamSourceGenerator; -use audio_streams::StreamSourceGenerator; use crate::virtio::snd::common_backend::SndData; +use crate::virtio::snd::sys::SysAudioStreamSourceGenerator; pub(crate) fn create_null_stream_source_generators( snd_data: &SndData, -) -> Vec> { - let mut generators: Vec> = Vec::new(); +) -> Vec { + let mut generators: Vec = Vec::new(); generators.resize_with(snd_data.pcm_info_len(), || { Box::new(NoopStreamSourceGenerator::new()) }); diff --git a/devices/src/virtio/snd/sys/mod.rs b/devices/src/virtio/snd/sys/mod.rs index 1c73fd7c96..b4448d0400 100644 --- a/devices/src/virtio/snd/sys/mod.rs +++ b/devices/src/virtio/snd/sys/mod.rs @@ -7,7 +7,7 @@ cfg_if::cfg_if! { mod unix; use unix as platform; } else if #[cfg(windows)] { - mod windows; + pub(crate) mod windows; use windows as platform; } } @@ -15,3 +15,7 @@ cfg_if::cfg_if! { pub(crate) use platform::create_stream_source_generators; pub(crate) use platform::set_audio_thread_priority; pub use platform::StreamSourceBackend; +pub(crate) use platform::SysAsyncStreamObjects; +pub(crate) use platform::SysAudioStreamSource; +pub(crate) use platform::SysAudioStreamSourceGenerator; +pub(crate) use platform::SysBufferWriter; diff --git a/devices/src/virtio/snd/sys/unix.rs b/devices/src/virtio/snd/sys/unix.rs index 163e463566..ad4c77e2a0 100644 --- a/devices/src/virtio/snd/sys/unix.rs +++ b/devices/src/virtio/snd/sys/unix.rs @@ -2,19 +2,42 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +use async_trait::async_trait; +use audio_streams::AsyncPlaybackBufferStream; +use audio_streams::StreamSource; use audio_streams::StreamSourceGenerator; use base::set_rt_prio_limit; use base::set_rt_round_robin; use base::warn; +use cros_async::Executor; +use futures::channel::mpsc::UnboundedSender; #[cfg(feature = "audio_cras")] use libcras::CrasStreamSourceGenerator; +use crate::virtio::common_backend::PcmResponse; +use crate::virtio::snd::common_backend::async_funcs::PlaybackBufferWriter; +use crate::virtio::snd::common_backend::stream_info::StreamInfo; +use crate::virtio::snd::common_backend::DirectionalStream; +use crate::virtio::snd::common_backend::Error; use crate::virtio::snd::common_backend::SndData; -use crate::virtio::snd::parameters::Error; +use crate::virtio::snd::parameters::Error as ParametersError; use crate::virtio::snd::parameters::Parameters; const AUDIO_THREAD_RTPRIO: u16 = 10; // Matches other cros audio clients. +pub(crate) type SysAudioStreamSourceGenerator = Box; +pub(crate) type SysAudioStreamSource = Box; +pub(crate) type SysBufferWriter = UnixBufferWriter; + +pub(crate) struct SysAsyncStream { + pub(crate) async_playback_buffer_stream: Box, +} + +pub(crate) struct SysAsyncStreamObjects { + pub(crate) stream: DirectionalStream, + pub(crate) pcm_sender: UnboundedSender, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum StreamSourceBackend { #[cfg(feature = "audio_cras")] @@ -22,13 +45,13 @@ pub enum StreamSourceBackend { } impl TryFrom<&str> for StreamSourceBackend { - type Error = Error; + type Error = ParametersError; fn try_from(s: &str) -> Result { match s { #[cfg(feature = "audio_cras")] "cras" => Ok(StreamSourceBackend::CRAS), - _ => Err(Error::InvalidBackend), + _ => Err(ParametersError::InvalidBackend), } } } @@ -68,3 +91,49 @@ pub(crate) fn set_audio_thread_priority() { warn!("Failed to set audio thread to real time: {}", e); } } + +impl StreamInfo { + /// (*) + /// `buffer_size` in `audio_streams` API indicates the buffer size in bytes that the stream + /// consumes (or transmits) each time (next_playback/capture_buffer). + /// `period_bytes` in virtio-snd device (or ALSA) indicates the device transmits (or + /// consumes) for each PCM message. + /// Therefore, `buffer_size` in `audio_streams` == `period_bytes` in virtio-snd. + pub(crate) async fn set_up_async_playback_stream( + &mut self, + frame_size: usize, + ex: &Executor, + ) -> Result { + Ok(SysAsyncStream { + async_playback_buffer_stream: self + .stream_source + .as_mut() + .ok_or(Error::EmptyStreamSource)? + .async_new_async_playback_stream( + self.channels as usize, + self.format, + self.frame_rate, + // See (*) + self.period_bytes / frame_size, + ex, + ) + .await + .map_err(Error::CreateStream)? + .1, + }) + } +} + +pub(crate) struct UnixBufferWriter { + guest_period_bytes: usize, +} + +#[async_trait(?Send)] +impl PlaybackBufferWriter for UnixBufferWriter { + fn new(guest_period_bytes: usize) -> Self { + UnixBufferWriter { guest_period_bytes } + } + fn endpoint_period_bytes(&self) -> usize { + self.guest_period_bytes + } +} diff --git a/devices/src/virtio/snd/sys/windows.rs b/devices/src/virtio/snd/sys/windows.rs index 4b47fb2471..b0e1ae6844 100644 --- a/devices/src/virtio/snd/sys/windows.rs +++ b/devices/src/virtio/snd/sys/windows.rs @@ -2,20 +2,68 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. -use audio_streams::StreamSourceGenerator; +use std::io; +use std::io::Read; +use std::slice; +use std::sync::Arc; +use crate::virtio::DescriptorChain; +use crate::virtio::Reader; +use async_trait::async_trait; +use audio_streams::AsyncPlaybackBuffer; +use audio_streams::AsyncPlaybackBufferStream; +use base::error; +use base::set_audio_thread_priorities; +use base::warn; +use cros_async::Executor; +use futures::channel::mpsc::UnboundedReceiver; +use futures::channel::mpsc::UnboundedSender; +use futures::SinkExt; +use sync::Mutex; +use vm_memory::GuestMemory; +use win_audio::async_stream::WinAudioStreamSourceGenerator; +use win_audio::intermediate_resampler_buffer::IntermediateResamplerBuffer; +use win_audio::AudioSharedFormat; +use win_audio::WinAudioServer; +use win_audio::WinStreamSourceGenerator; + +use crate::virtio::snd::common_backend::async_funcs::get_index_with_reader_and_writer; +use crate::virtio::snd::common_backend::async_funcs::PlaybackBufferWriter; +use crate::virtio::snd::common_backend::stream_info::StreamInfo; +use crate::virtio::snd::common_backend::DirectionalStream; +use crate::virtio::snd::common_backend::Error; +use crate::virtio::snd::common_backend::PcmResponse; use crate::virtio::snd::common_backend::SndData; -use crate::virtio::snd::parameters::Error; +use crate::virtio::snd::parameters::Error as ParametersError; use crate::virtio::snd::parameters::Parameters; +pub(crate) type SysAudioStreamSourceGenerator = Box; +pub(crate) type SysAudioStreamSource = Box; +pub(crate) type SysBufferWriter = WinBufferWriter; + +pub(crate) struct SysAsyncStream { + pub(crate) async_playback_buffer_stream: Box, + pub(crate) audio_shared_format: AudioSharedFormat, +} + +pub(crate) struct SysAsyncStreamObjects { + pub(crate) stream: DirectionalStream, + pub(crate) pcm_sender: UnboundedSender, +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum StreamSourceBackend {} +pub enum StreamSourceBackend { + WINAUDIO, +} impl TryFrom<&str> for StreamSourceBackend { - type Error = Error; + type Error = ParametersError; fn try_from(s: &str) -> Result { - todo!(); + match s { + "winaudio" => Ok(StreamSourceBackend::WINAUDIO), + _ => Err(ParametersError::InvalidBackend), + } } } @@ -23,8 +71,180 @@ pub(crate) fn create_stream_source_generators( _backend: StreamSourceBackend, _params: &Parameters, _snd_data: &SndData, -) -> Vec> { - todo!(); +) -> Vec { + vec![Box::new(WinAudioStreamSourceGenerator {})] } -pub(crate) fn set_audio_thread_priority() {} +pub(crate) fn set_audio_thread_priority() { + if let Err(e) = set_audio_thread_priorities() { + error!("Failed to set audio thread priority: {}", e); + } +} + +impl StreamInfo { + pub(crate) async fn set_up_async_playback_stream( + &mut self, + frame_size: usize, + ex: &Executor, + ) -> Result { + let (async_playback_buffer_stream, audio_shared_format) = self + .stream_source + .as_mut() + .ok_or(Error::EmptyStreamSource)? + .new_async_playback_stream_and_get_shared_format( + self.channels as usize, + self.format, + self.frame_rate as usize, + // `buffer_size` in `audio_streams` API indicates the buffer size in bytes that the stream + // consumes (or transmits) each time (next_playback/capture_buffer). + // `period_bytes` in virtio-snd device (or ALSA) indicates the device transmits (or + // consumes) for each PCM message. + // Therefore, `buffer_size` in `audio_streams` == `period_bytes` in virtio-snd. + self.period_bytes / frame_size, + ex, + ) + .map_err(Error::CreateStream)?; + Ok(SysAsyncStream { + async_playback_buffer_stream, + audio_shared_format, + }) + } +} + +pub(crate) struct WinBufferWriter { + guest_period_bytes: usize, + shared_audio_engine_period_bytes: usize, + guest_num_channels: usize, + intermediate_resampler_buffer: IntermediateResamplerBuffer, +} + +impl WinBufferWriter { + fn needs_prefill(&self) -> bool { + self.intermediate_resampler_buffer.ring_buf.len() + + (self + .intermediate_resampler_buffer + .guest_period_in_target_sample_rate_frames + * self.guest_num_channels) + <= self + .intermediate_resampler_buffer + .shared_audio_engine_period_in_frames + * self.guest_num_channels + } + + fn write_to_resampler_buffer(&mut self, reader: &mut Reader) -> Result { + let written = reader.read_to_cb( + |iovs| { + let mut written = 0; + for iov in iovs { + let buffer_slice = unsafe { slice::from_raw_parts(iov.as_ptr(), iov.size()) }; + self.intermediate_resampler_buffer + .convert_and_add(buffer_slice); + written += iov.size(); + } + written + }, + self.guest_period_bytes, + ); + + if written != self.guest_period_bytes { + error!( + "{} written bytes != guest period bytes of {}", + written, self.guest_period_bytes + ); + Err(Error::InvalidBufferSize) + } else { + Ok(written) + } + } +} + +#[async_trait(?Send)] +impl PlaybackBufferWriter for WinBufferWriter { + fn new( + guest_period_bytes: usize, + frame_size: usize, + frame_rate: usize, + guest_num_channels: usize, + audio_shared_format: AudioSharedFormat, + ) -> Self { + WinBufferWriter { + guest_period_bytes, + shared_audio_engine_period_bytes: audio_shared_format + .shared_audio_engine_period_in_frames + * audio_shared_format.bit_depth + / 8 + * audio_shared_format.channels, + guest_num_channels, + intermediate_resampler_buffer: IntermediateResamplerBuffer::new( + /* from */ frame_rate, + /* to */ audio_shared_format.frame_rate, + guest_period_bytes / frame_size, + audio_shared_format.shared_audio_engine_period_in_frames, + audio_shared_format.channels, + audio_shared_format.channel_mask, + ) + .expect("Failed to create intermediate resampler buffer"), + } + } + fn endpoint_period_bytes(&self) -> usize { + self.shared_audio_engine_period_bytes + } + fn copy_to_buffer( + &mut self, + dst_buf: &mut AsyncPlaybackBuffer<'_>, + reader: &mut Reader, + ) -> Result { + self.write_to_resampler_buffer(reader)?; + + if let Some(next_period) = self.intermediate_resampler_buffer.get_next_period() { + dst_buf + .copy_cb(next_period.len(), |out| out.copy_from_slice(next_period)) + .map_err(Error::Io) + } else { + warn!("Getting the next period failed. Most likely the resampler is being primed."); + dst_buf + .copy_from(&mut io::repeat(0).take(self.shared_audio_engine_period_bytes as u64)) + .map_err(Error::Io) + } + } + + async fn check_and_prefill( + &mut self, + mem: &GuestMemory, + desc_receiver: &mut UnboundedReceiver, + sender: &mut UnboundedSender, + ) -> Result<(), Error> { + if !self.needs_prefill() { + return Ok(()); + } + + match desc_receiver.try_next() { + Err(e) => { + error!( + " Prefill Underrun. No new DescriptorChain while running: {}", + e + ); + } + Ok(None) => { + error!(" Prefill Unreachable. status should be Quit when the channel is closed"); + return Err(Error::InvalidPCMWorkerState); + } + Ok(Some(desc_chain)) => { + let (desc_index, mut reader, writer) = + get_index_with_reader_and_writer(mem, desc_chain)?; + self.write_to_resampler_buffer(&mut reader)?; + + sender + .send(PcmResponse { + desc_index, + status: Ok(()).into(), + writer, + done: None, + }) + .await + .map_err(Error::MpscSend)?; + } + }; + Ok(()) + } +} diff --git a/src/sys/windows.rs b/src/sys/windows.rs index 2cc7c5f4d0..e88a0727bf 100644 --- a/src/sys/windows.rs +++ b/src/sys/windows.rs @@ -84,6 +84,8 @@ use devices::tsc::standard_deviation; use devices::tsc::TscSyncMitigations; use devices::virtio; use devices::virtio::block::block::DiskOption; +use devices::virtio::snd::common_backend::VirtioSnd; +use devices::virtio::snd::parameters::Parameters as SndParameters; #[cfg(feature = "gpu")] use devices::virtio::vhost::user::device::gpu::sys::windows::GpuVmmConfig; #[cfg(feature = "balloon")] @@ -506,6 +508,20 @@ fn create_virtio_devices( devs.push(dev); } + let features = virtio::base_features(cfg.protection_type); + let snd_params = SndParameters { + backend: ExitContext::exit_context( + "winaudio".try_into(), + Exit::VirtioSoundDeviceNew, + "failed to set up virtio sound device", + )?, + ..Default::default() + }; + devs.push(VirtioDeviceStub { + dev: Box::new(VirtioSnd::new(features, snd_params)?), + jail: None, + }); + if let Some(tube) = pvclock_device_tube { #[cfg(feature = "pvclock")] devs.push(VirtioDeviceStub { diff --git a/win_audio/src/lib.rs b/win_audio/src/lib.rs index b100cf049d..35d0220708 100644 --- a/win_audio/src/lib.rs +++ b/win_audio/src/lib.rs @@ -254,6 +254,29 @@ impl WinAudioServer for NoopStreamSource { )) } + fn new_async_playback_stream_and_get_shared_format( + &mut self, + num_channels: usize, + format: SampleFormat, + frame_rate: usize, + buffer_size: usize, + ex: &dyn audio_streams::AudioStreamsExecutor, + ) -> Result<(Box, AudioSharedFormat), BoxError> { + let (_, playback_stream) = self + .new_async_playback_stream(num_channels, format, frame_rate as u32, buffer_size, ex) + .unwrap(); + + // Set shared format to be the same as the incoming audio format. + let format = AudioSharedFormat { + bit_depth: format.sample_bytes() * 8, + frame_rate, + channels: num_channels, + shared_audio_engine_period_in_frames: buffer_size * format.sample_bytes(), + channel_mask: None, + }; + Ok((playback_stream, format)) + } + fn is_noop_stream(&self) -> bool { true }