devices: vhost-user: remove Arc from SlaveListener and SlaveReqHandler

The SlaveReqHandler should be the sole owner of a VhostUserBackend since
it calls its mutable methods, so we shouldn't need an Arc here. If
sharing is necessary users can use their own locking mechanism.

Single ownership will help to retrieve the backend after a client
disconnects, to reuse it for another connection if needed.

BUG=b:229554679
BUG=b:216407443
TEST=cargo test -p vmm_vhost
TEST=vhost-user console device works.
TEST=cargo test -p devices

Change-Id: I2358c807ac3ddb1ee4b29d97df0ade5a3e30a85a
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/platform/crosvm/+/3591108
Commit-Queue: Alexandre Courbot <acourbot@chromium.org>
Reviewed-by: Richard Zhang <rizhang@google.com>
Reviewed-by: Keiichi Watanabe <keiichiw@chromium.org>
Tested-by: kokoro <noreply+kokoro@google.com>
This commit is contained in:
Alexandre Courbot 2022-04-14 14:58:10 +09:00 committed by Chromeos LUCI
parent c9b84cf570
commit 64a3fc2d08
9 changed files with 48 additions and 62 deletions

View file

@ -825,9 +825,7 @@ mod tests {
});
// Device side
let handler = Arc::new(std::sync::Mutex::new(DeviceRequestHandler::new(
FakeBackend::new(),
)));
let handler = std::sync::Mutex::new(DeviceRequestHandler::new(FakeBackend::new()));
let mut listener = SlaveListener::<SocketEndpoint<_>, _>::new(listener, handler).unwrap();
// Notify listener is ready.

View file

@ -280,8 +280,7 @@ where
.context("failed to accept an incoming connection")
})
.await?;
let req_handler =
SlaveReqHandler::from_stream(socket, Arc::new(std::sync::Mutex::new(self)));
let req_handler = SlaveReqHandler::from_stream(socket, std::sync::Mutex::new(self));
run_handler(req_handler, ex).await
}
@ -299,11 +298,8 @@ where
let mut listener = VfioListener::new(driver)
.map_err(|e| anyhow!("failed to create a VFIO listener: {}", e))
.and_then(|l| {
SlaveListener::<VfioEndpoint<_, _>, _>::new(
l,
Arc::new(std::sync::Mutex::new(self)),
)
.map_err(|e| anyhow!("failed to create SlaveListener: {}", e))
SlaveListener::<VfioEndpoint<_, _>, _>::new(l, std::sync::Mutex::new(self))
.map_err(|e| anyhow!("failed to create SlaveListener: {}", e))
})?;
let req_handler = listener
@ -376,9 +372,7 @@ mod tests {
});
// Device side
let handler = Arc::new(std::sync::Mutex::new(DeviceRequestHandler::new(
FakeBackend::new(),
)));
let handler = std::sync::Mutex::new(DeviceRequestHandler::new(FakeBackend::new()));
let mut listener = SlaveListener::<SocketEndpoint<_>, _>::new(listener, handler).unwrap();
// Notify listener is ready.

View file

@ -149,7 +149,7 @@ where
EventAsync::new(exit_event, ex).context("failed to create an async event")?;
let mut req_handler =
SlaveReqHandler::from_stream(vhost_user_tube, Arc::new(std::sync::Mutex::new(self)));
SlaveReqHandler::from_stream(vhost_user_tube, std::sync::Mutex::new(self));
let read_event_fut = read_event.next_val().fuse();
let close_event_fut = close_event.next_val().fuse();
@ -220,9 +220,7 @@ mod tests {
});
// Device side
let backend = Arc::new(std::sync::Mutex::new(DeviceRequestHandler::new(
FakeBackend::new(),
)));
let backend = std::sync::Mutex::new(DeviceRequestHandler::new(FakeBackend::new()));
let mut req_handler = SlaveReqHandler::from_stream(dev_tube, backend);

View file

@ -521,7 +521,7 @@ impl VhostUserSlaveReqHandlerMut for VsockBackend {
async fn run_device<P: AsRef<Path>>(
ex: &Executor,
socket: P,
backend: Arc<StdMutex<VsockBackend>>,
backend: StdMutex<VsockBackend>,
) -> anyhow::Result<()> {
let listener = UnixListener::bind(socket)
.map(UnlinkUnixListener)
@ -581,7 +581,6 @@ fn run_vvu_device<P: AsRef<Path>>(
}),
)
.map(StdMutex::new)
.map(Arc::new)
.context("failed to create `VsockBackend`")?;
let driver = VvuDevice::new(device);
@ -623,8 +622,7 @@ pub fn run_vsock_device(program_name: &str, args: &[&str]) -> anyhow::Result<()>
(Some(socket), None) => {
let backend =
VsockBackend::new(&ex, opts.cid, opts.vhost_socket, HandlerType::VhostUser)
.map(StdMutex::new)
.map(Arc::new)?;
.map(StdMutex::new)?;
// TODO: Replace the `and_then` with `Result::flatten` once it is stabilized.
ex.run_until(run_device(&ex, socket, backend))

View file

@ -10,15 +10,12 @@ pub(crate) mod tests {
};
use crate::master::Master;
use crate::message::MasterReq;
use tempfile::{Builder, TempDir};
#[cfg(feature = "device")]
use {
crate::{
slave::SlaveListener,
slave_req_handler::{SlaveReqHandler, VhostUserSlaveReqHandler},
},
std::sync::Arc,
use crate::{
slave::SlaveListener,
slave_req_handler::{SlaveReqHandler, VhostUserSlaveReqHandler},
};
use tempfile::{Builder, TempDir};
pub(crate) type TestMaster = Master<SocketEndpoint<MasterReq>>;
pub(crate) type TestEndpoint = SocketEndpoint<MasterReq>;
@ -39,7 +36,7 @@ pub(crate) mod tests {
#[cfg(feature = "device")]
pub(crate) fn create_master_slave_pair<S>(
backend: Arc<S>,
backend: S,
) -> (TestMaster, SlaveReqHandler<S, TestEndpoint>)
where
S: VhostUserSlaveReqHandler,

View file

@ -10,10 +10,7 @@ pub(crate) mod tests {
use crate::master::Master;
use crate::message::MasterReq;
#[cfg(feature = "device")]
use {
crate::slave_req_handler::{SlaveReqHandler, VhostUserSlaveReqHandler},
std::sync::Arc,
};
use crate::slave_req_handler::{SlaveReqHandler, VhostUserSlaveReqHandler};
use crate::SystemStream;
pub(crate) type TestEndpoint = TubeEndpoint<MasterReq>;
@ -27,7 +24,7 @@ pub(crate) mod tests {
#[cfg(feature = "device")]
pub(crate) fn create_master_slave_pair<S>(
backend: Arc<S>,
backend: S,
) -> (TestMaster, SlaveReqHandler<S, TestEndpoint>)
where
S: VhostUserSlaveReqHandler,

View file

@ -225,7 +225,7 @@ mod tests {
#[test]
fn create_dummy_slave() {
let slave = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
let slave = Mutex::new(DummySlaveReqHandler::new());
slave.set_owner().unwrap();
assert!(slave.set_owner().is_err());
@ -233,40 +233,40 @@ mod tests {
#[test]
fn test_set_owner() {
let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
let (master, mut slave) = create_master_slave_pair(slave_be.clone());
let slave_be = Mutex::new(DummySlaveReqHandler::new());
let (master, mut slave) = create_master_slave_pair(slave_be);
assert!(!slave_be.lock().unwrap().owned);
assert!(!slave.as_ref().lock().unwrap().owned);
master.set_owner().unwrap();
slave.handle_request().unwrap();
assert!(slave_be.lock().unwrap().owned);
assert!(slave.as_ref().lock().unwrap().owned);
master.set_owner().unwrap();
assert!(slave.handle_request().is_err());
assert!(slave_be.lock().unwrap().owned);
assert!(slave.as_ref().lock().unwrap().owned);
}
#[test]
fn test_set_features() {
let mbar = Arc::new(Barrier::new(2));
let sbar = mbar.clone();
let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
let (mut master, mut slave) = create_master_slave_pair(slave_be.clone());
let slave_be = Mutex::new(DummySlaveReqHandler::new());
let (mut master, mut slave) = create_master_slave_pair(slave_be);
thread::spawn(move || {
slave.handle_request().unwrap();
assert!(slave_be.lock().unwrap().owned);
assert!(slave.as_ref().lock().unwrap().owned);
slave.handle_request().unwrap();
slave.handle_request().unwrap();
assert_eq!(
slave_be.lock().unwrap().acked_features,
slave.as_ref().lock().unwrap().acked_features,
VIRTIO_FEATURES & !0x1
);
slave.handle_request().unwrap();
slave.handle_request().unwrap();
assert_eq!(
slave_be.lock().unwrap().acked_protocol_features,
slave.as_ref().lock().unwrap().acked_protocol_features,
VhostUserProtocolFeatures::all().bits()
);
@ -292,26 +292,26 @@ mod tests {
fn test_master_slave_process() {
let mbar = Arc::new(Barrier::new(2));
let sbar = mbar.clone();
let slave_be = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
let (mut master, mut slave) = create_master_slave_pair(slave_be.clone());
let slave_be = Mutex::new(DummySlaveReqHandler::new());
let (mut master, mut slave) = create_master_slave_pair(slave_be);
thread::spawn(move || {
// set_own()
slave.handle_request().unwrap();
assert!(slave_be.lock().unwrap().owned);
assert!(slave.as_ref().lock().unwrap().owned);
// get/set_features()
slave.handle_request().unwrap();
slave.handle_request().unwrap();
assert_eq!(
slave_be.lock().unwrap().acked_features,
slave.as_ref().lock().unwrap().acked_features,
VIRTIO_FEATURES & !0x1
);
slave.handle_request().unwrap();
slave.handle_request().unwrap();
assert_eq!(
slave_be.lock().unwrap().acked_protocol_features,
slave.as_ref().lock().unwrap().acked_protocol_features,
VhostUserProtocolFeatures::all().bits()
);

View file

@ -5,8 +5,6 @@
//!
//! These are used on platforms where the slave has to listen for connections (e.g. POSIX only).
use std::sync::Arc;
use super::connection::{Endpoint, Listener};
use super::message::*;
use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler};
@ -14,14 +12,14 @@ use super::{Result, SlaveReqHandler, VhostUserSlaveReqHandler};
/// Vhost-user slave side connection listener.
pub struct SlaveListener<E: Endpoint<MasterReq>, S: VhostUserSlaveReqHandler> {
listener: E::Listener,
backend: Option<Arc<S>>,
backend: Option<S>,
}
/// Sets up a listener for incoming master connections, and handles construction
/// of a Slave on success.
impl<E: Endpoint<MasterReq>, S: VhostUserSlaveReqHandler> SlaveListener<E, S> {
/// Create a unix domain socket for incoming master connections.
pub fn new(listener: E::Listener, backend: Arc<S>) -> Result<Self> {
pub fn new(listener: E::Listener, backend: S) -> Result<Self> {
Ok(SlaveListener {
listener,
backend: Some(backend),
@ -57,7 +55,7 @@ mod tests {
#[test]
fn test_slave_listener_set_nonblocking() {
let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
let backend = Mutex::new(DummySlaveReqHandler::new());
let listener =
Listener::new("/tmp/vhost_user_lib_unit_test_slave_nonblocking", true).unwrap();
let slave_listener = SlaveListener::<Endpoint<_>, _>::new(listener, backend).unwrap();
@ -76,7 +74,7 @@ mod tests {
use crate::Master;
let path = "/tmp/vhost_user_lib_unit_test_slave_accept";
let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
let backend = Mutex::new(DummySlaveReqHandler::new());
let listener = Listener::new(path, true).unwrap();
let mut slave_listener = SlaveListener::<Endpoint<_>, _>::new(listener, backend).unwrap();

View file

@ -5,7 +5,7 @@ use base::{AsRawDescriptor, RawDescriptor};
use std::fs::File;
use std::mem;
use std::slice;
use std::sync::{Arc, Mutex};
use std::sync::Mutex;
use data_model::DataInit;
@ -395,7 +395,7 @@ impl<E: Endpoint<MasterReq> + AsRawDescriptor> AsRawDescriptor for SlaveReqHelpe
pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> {
slave_req_helper: SlaveReqHelper<E>,
// the vhost-user backend device object
backend: Arc<S>,
backend: S,
virtio_features: u64,
acked_virtio_features: u64,
@ -408,14 +408,20 @@ pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>>
impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S, MasterReqEndpoint> {
/// Create a vhost-user slave endpoint from a connected socket.
pub fn from_stream(socket: SystemStream, backend: Arc<S>) -> Self {
pub fn from_stream(socket: SystemStream, backend: S) -> Self {
Self::new(MasterReqEndpoint::from(socket), backend)
}
}
impl<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> AsRef<S> for SlaveReqHandler<S, E> {
fn as_ref(&self) -> &S {
&self.backend
}
}
impl<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> SlaveReqHandler<S, E> {
/// Create a vhost-user slave endpoint.
pub(super) fn new(endpoint: E, backend: Arc<S>) -> Self {
pub(super) fn new(endpoint: E, backend: S) -> Self {
SlaveReqHandler {
slave_req_helper: SlaveReqHelper::new(endpoint, backend.protocol()),
backend,
@ -432,7 +438,7 @@ impl<S: VhostUserSlaveReqHandler, E: Endpoint<MasterReq>> SlaveReqHandler<S, E>
/// # Arguments
/// * - `path` - path of Unix domain socket listener to connect to
/// * - `backend` - handler for requests from the master to the slave
pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> {
pub fn connect(path: &str, backend: S) -> Result<Self> {
Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend))
}
@ -946,7 +952,7 @@ mod tests {
fn test_slave_req_handler_new() {
let (p1, _p2) = SystemStream::pair().unwrap();
let endpoint = MasterReqEndpoint::from(p1);
let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
let backend = Mutex::new(DummySlaveReqHandler::new());
let mut handler = SlaveReqHandler::new(endpoint, backend);
handler.check_state().unwrap();