diff --git a/devices/src/usb/xhci/device_slot.rs b/devices/src/usb/xhci/device_slot.rs index 81e6399e0c..5e10491d08 100644 --- a/devices/src/usb/xhci/device_slot.rs +++ b/devices/src/usb/xhci/device_slot.rs @@ -50,6 +50,8 @@ use crate::utils::FailHandle; pub enum Error { #[error("bad device context: {0}")] BadDeviceContextAddr(GuestAddress), + #[error("device slot get a bad endpoint id: {0}")] + BadEndpointId(u8), #[error("bad input context address: {0}")] BadInputContextAddr(GuestAddress), #[error("device slot get a bad port id: {0}")] @@ -412,7 +414,10 @@ impl DeviceSlot { } // Assigns the device address and initializes slot and endpoint 0 context. - pub fn set_address(&self, trb: &AddressDeviceCommandTrb) -> Result { + pub fn set_address( + self: &Arc, + trb: &AddressDeviceCommandTrb, + ) -> Result { if !self.enabled.load(Ordering::SeqCst) { error!( "trying to set address to a disabled device slot {}", @@ -450,9 +455,6 @@ impl DeviceSlot { "port id {} is assigned to slot id {}", port_id, self.slot_id ); - let endpoint_context_addr = self - .get_device_context_addr()? - .unchecked_add(size_of::() as u64); // Initialize the control endpoint. Endpoint id = 1. self.set_trc( @@ -465,7 +467,7 @@ impl DeviceSlot { self.interrupter.clone(), self.slot_id, 1, - endpoint_context_addr, + Arc::downgrade(self), ) .map_err(Error::CreateTransferController)?, ), @@ -516,7 +518,7 @@ impl DeviceSlot { // Adds or drops multiple endpoints in the device slot. pub fn configure_endpoint( - &self, + self: &Arc, trb: &ConfigureEndpointCommandTrb, ) -> Result { let input_control_context = if trb.get_deconfigure() { @@ -780,17 +782,13 @@ impl DeviceSlot { self.port_id.reset(); } - fn add_one_endpoint(&self, device_context_index: u8) -> Result<()> { + fn add_one_endpoint(self: &Arc, device_context_index: u8) -> Result<()> { xhci_trace!( "adding one endpoint, device context index {}", device_context_index ); let mut device_context = self.get_device_context()?; let transfer_ring_index = (device_context_index - 1) as usize; - let endpoint_context_addr = self - .get_device_context_addr()? - .unchecked_add(size_of::() as u64) - .unchecked_add(size_of::() as u64 * transfer_ring_index as u64); let trc = TransferRingController::new( self.mem.clone(), self.hub @@ -800,7 +798,7 @@ impl DeviceSlot { self.interrupter.clone(), self.slot_id, device_context_index, - endpoint_context_addr, + Arc::downgrade(self), ) .map_err(Error::CreateTransferController)?; trc.set_dequeue_pointer( @@ -817,7 +815,7 @@ impl DeviceSlot { self.set_device_context(device_context) } - fn drop_one_endpoint(&self, device_context_index: u8) -> Result<()> { + fn drop_one_endpoint(self: &Arc, device_context_index: u8) -> Result<()> { let endpoint_index = (device_context_index - 1) as usize; self.set_trc(endpoint_index, None); let mut ctx = self.get_device_context()?; @@ -883,4 +881,26 @@ impl DeviceSlot { ctx.slot_context.set_slot_state(state); self.set_device_context(ctx) } + + pub fn halt_endpoint(&self, endpoint_id: u8) -> Result<()> { + if !valid_endpoint_id(endpoint_id) { + return Err(Error::BadEndpointId(endpoint_id)); + } + let index = endpoint_id - 1; + let mut device_context = self.get_device_context()?; + let endpoint_context = &mut device_context.endpoint_context[index as usize]; + match self.get_trc(index as usize) { + Some(trc) => { + endpoint_context.set_tr_dequeue_pointer(DequeuePtr::new(trc.get_dequeue_pointer())); + endpoint_context.set_dequeue_cycle_state(trc.get_consumer_cycle_state()); + } + None => { + error!("trc for endpoint {} not found", endpoint_id); + return Err(Error::BadEndpointId(endpoint_id)); + } + } + endpoint_context.set_endpoint_state(EndpointState::Halted); + self.set_device_context(device_context)?; + Ok(()) + } } diff --git a/devices/src/usb/xhci/transfer_ring_controller.rs b/devices/src/usb/xhci/transfer_ring_controller.rs index da02624d0a..a17a07abb5 100644 --- a/devices/src/usb/xhci/transfer_ring_controller.rs +++ b/devices/src/usb/xhci/transfer_ring_controller.rs @@ -3,13 +3,14 @@ // found in the LICENSE file. use std::sync::Arc; +use std::sync::Weak; use anyhow::Context; use base::Event; use sync::Mutex; -use vm_memory::GuestAddress; use vm_memory::GuestMemory; +use super::device_slot::DeviceSlot; use super::interrupter::Interrupter; use super::usb_hub::UsbPort; use super::xhci_abi::TransferDescriptor; @@ -32,7 +33,6 @@ pub struct TransferRingTrbHandler { slot_id: u8, endpoint_id: u8, transfer_manager: XhciTransferManager, - endpoint_context_addr: GuestAddress, } impl TransferDescriptorHandler for TransferRingTrbHandler { @@ -49,7 +49,6 @@ impl TransferDescriptorHandler for TransferRingTrbHandler { self.endpoint_id, descriptor, completion_event, - self.endpoint_context_addr, ); xhci_transfer .send_to_backend_if_valid() @@ -75,7 +74,7 @@ impl TransferRingController { interrupter: Arc>, slot_id: u8, endpoint_id: u8, - endpoint_context_addr: GuestAddress, + device_slot: Weak, ) -> Result, TransferRingControllerError> { RingBufferController::new_with_handler( format!("transfer ring slot_{} ep_{}", slot_id, endpoint_id), @@ -87,8 +86,7 @@ impl TransferRingController { interrupter, slot_id, endpoint_id, - transfer_manager: XhciTransferManager::new(), - endpoint_context_addr, + transfer_manager: XhciTransferManager::new(device_slot), }, ) } diff --git a/devices/src/usb/xhci/xhci_transfer.rs b/devices/src/usb/xhci/xhci_transfer.rs index f21cd3a84d..b5f7276ed8 100644 --- a/devices/src/usb/xhci/xhci_transfer.rs +++ b/devices/src/usb/xhci/xhci_transfer.rs @@ -21,10 +21,10 @@ use sync::Mutex; use thiserror::Error; use usb_util::TransferStatus; use usb_util::UsbRequestSetup; -use vm_memory::GuestAddress; use vm_memory::GuestMemory; use vm_memory::GuestMemoryError; +use super::device_slot::DeviceSlot; use super::interrupter::Error as InterrupterError; use super::interrupter::Interrupter; use super::scatter_gather_buffer::Error as BufferError; @@ -32,9 +32,6 @@ use super::scatter_gather_buffer::ScatterGatherBuffer; use super::usb_hub::Error as HubError; use super::usb_hub::UsbPort; use super::xhci_abi::AddressedTrb; -use super::xhci_abi::DequeuePtr; -use super::xhci_abi::EndpointContext; -use super::xhci_abi::EndpointState; use super::xhci_abi::Error as TrbError; use super::xhci_abi::EventDataTrb; use super::xhci_abi::SetupStageTrb; @@ -55,6 +52,8 @@ pub enum Error { CreateBuffer(BufferError), #[error("cannot detach from port: {0}")] DetachPort(HubError), + #[error("failed to halt the endpoint: {0}")] + HaltEndpoint(u8), #[error("failed to read guest memory: {0}")] ReadGuestMemory(GuestMemoryError), #[error("cannot send interrupt: {0}")] @@ -183,13 +182,15 @@ impl XhciTransferType { #[derive(Clone)] pub struct XhciTransferManager { transfers: Arc>>>>, + device_slot: Weak, } impl XhciTransferManager { /// Create a new manager. - pub fn new() -> XhciTransferManager { + pub fn new(device_slot: Weak) -> XhciTransferManager { XhciTransferManager { transfers: Arc::new(Mutex::new(Vec::new())), + device_slot, } } @@ -203,7 +204,6 @@ impl XhciTransferManager { endpoint_id: u8, transfer_trbs: TransferDescriptor, completion_event: Event, - endpoint_context_addr: GuestAddress, ) -> XhciTransfer { assert!(!transfer_trbs.is_empty()); let transfer_dir = { @@ -226,7 +226,7 @@ impl XhciTransferManager { endpoint_id, transfer_dir, transfer_trbs, - endpoint_context_addr, + device_slot: self.device_slot.clone(), }; self.transfers.lock().push(Arc::downgrade(&t.state)); t @@ -262,7 +262,7 @@ impl XhciTransferManager { impl Default for XhciTransferManager { fn default() -> Self { - Self::new() + Self::new(Weak::new()) } } @@ -280,7 +280,7 @@ pub struct XhciTransfer { transfer_dir: TransferDirection, transfer_trbs: TransferDescriptor, transfer_completion_event: Event, - endpoint_context_addr: GuestAddress, + device_slot: Weak, } impl Drop for XhciTransfer { @@ -354,18 +354,12 @@ impl XhciTransfer { .map_err(Error::WriteCompletionEvent)?; } TransferStatus::Stalled => { - let mut context = self.get_endpoint_context()?; - let dequeue_pointer = match self.transfer_trbs.last() { - Some(atrb) => atrb.gpa, - None => context.get_tr_dequeue_pointer().get_gpa().0, - }; - warn!( - "xhci: endpoint is stalled. set state to Halted and dequeue pointer to {:#x}", - dequeue_pointer - ); - context.set_endpoint_state(EndpointState::Halted); - context.set_tr_dequeue_pointer(DequeuePtr::new(GuestAddress(dequeue_pointer))); - self.set_endpoint_context(context)?; + warn!("xhci: endpoint is stalled. set state to Halted"); + if let Some(device_slot) = self.device_slot.upgrade() { + device_slot + .halt_endpoint(self.endpoint_id) + .map_err(|_| Error::HaltEndpoint(self.endpoint_id))?; + } self.transfer_completion_event .signal() .map_err(Error::WriteCompletionEvent)?; @@ -507,20 +501,6 @@ impl XhciTransfer { } Ok(valid) } - - fn get_endpoint_context(&self) -> Result { - let ctx = self - .mem - .read_obj_from_addr(self.endpoint_context_addr) - .map_err(Error::ReadGuestMemory)?; - Ok(ctx) - } - - fn set_endpoint_context(&self, endpoint_context: EndpointContext) -> Result<()> { - self.mem - .write_obj_at_addr(endpoint_context, self.endpoint_context_addr) - .map_err(Error::WriteGuestMemory) - } } fn trb_is_valid(atrb: &AddressedTrb) -> bool {