diff --git a/io_uring/src/uring.rs b/io_uring/src/uring.rs index 8d569aa878..8a69872d41 100644 --- a/io_uring/src/uring.rs +++ b/io_uring/src/uring.rs @@ -2,9 +2,9 @@ // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. +use std::collections::BTreeMap; use std::fmt; use std::fs::File; -use std::io::IoSlice; use std::os::unix::io::{AsRawFd, FromRawFd, RawFd}; use std::ptr::null_mut; use std::sync::atomic::{AtomicU32, Ordering}; @@ -78,11 +78,11 @@ pub struct URingStats { /// let f = File::open(Path::new("/dev/zero")).unwrap(); /// let mut uring = URingContext::new(16).unwrap(); /// uring -/// .add_poll_fd(f.as_raw_fd(), WatchingEvents::empty().set_read(), 454) +/// .add_poll_fd(f.as_raw_fd(), &WatchingEvents::empty().set_read(), 454) /// .unwrap(); /// let (user_data, res) = uring.wait().unwrap().next().unwrap(); -/// assert_eq!(user_data, 454 as UserData); -/// assert_eq!(res.unwrap(), 1 as i32); +/// assert_eq!(user_data, 454 as io_uring::UserData); +/// assert_eq!(res.unwrap(), 1 as u32); /// /// ``` pub struct URingContext { @@ -260,6 +260,21 @@ impl URingContext { self.add_rw_op(ptr, len, fd, offset, user_data, IORING_OP_READV as u8) } + /// See 'writev' but accepts an iterator instead of a vector if there isn't already a vector in + /// existence. + pub unsafe fn add_writev_iter( + &mut self, + iovecs: I, + fd: RawFd, + offset: u64, + user_data: UserData, + ) -> Result<()> + where + I: Iterator, + { + self.add_writev(iovecs.collect(), fd, offset, user_data) + } + /// Asynchronously writes to `fd` from the addresses given in `iovecs`. /// # Safety /// `add_writev` will write to the address given by `iovecs`. This is only safe if the caller @@ -267,9 +282,10 @@ impl URingContext { /// transaction is complete and that completion has been returned from the `wait` function. In /// addition there must not be any mutable references to the data pointed to by `iovecs` until /// the operation completes. Ensure that the fd remains open until the op completes as well. + /// The iovecs reference must be kept alive until the op returns. pub unsafe fn add_writev( &mut self, - iovecs: &[IoSlice], + iovecs: Vec, fd: RawFd, offset: u64, user_data: UserData, @@ -284,7 +300,24 @@ impl URingContext { sqe.user_data = user_data; sqe.flags = 0; sqe.fd = fd; - }) + })?; + self.complete_ring.add_op_data(user_data, iovecs); + Ok(()) + } + + /// See 'readv' but accepts an iterator instead of a vector if there isn't already a vector in + /// existence. + pub unsafe fn add_readv_iter( + &mut self, + iovecs: I, + fd: RawFd, + offset: u64, + user_data: UserData, + ) -> Result<()> + where + I: Iterator, + { + self.add_readv(iovecs.collect(), fd, offset, user_data) } /// Asynchronously reads from `fd` to the addresses given in `iovecs`. @@ -294,9 +327,10 @@ impl URingContext { /// transaction is complete and that completion has been returned from the `wait` function. In /// addition there must not be any references to the data pointed to by `iovecs` until the /// operation completes. Ensure that the fd remains open until the op completes as well. + /// The iovecs reference must be kept alive until the op returns. pub unsafe fn add_readv( &mut self, - iovecs: &[IoSlice], + iovecs: Vec, fd: RawFd, offset: u64, user_data: UserData, @@ -311,7 +345,9 @@ impl URingContext { sqe.user_data = user_data; sqe.flags = 0; sqe.fd = fd; - }) + })?; + self.complete_ring.add_op_data(user_data, iovecs); + Ok(()) } /// Syncs all completed operations, the ordering with in-flight async ops is not @@ -367,7 +403,7 @@ impl URingContext { pub fn add_poll_fd( &mut self, fd: RawFd, - events: WatchingEvents, + events: &WatchingEvents, user_data: UserData, ) -> Result<()> { self.prep_next_sqe(|sqe, _iovec| { @@ -389,7 +425,7 @@ impl URingContext { pub fn remove_poll_fd( &mut self, fd: RawFd, - events: WatchingEvents, + events: &WatchingEvents, user_data: UserData, ) -> Result<()> { self.prep_next_sqe(|sqe, _iovec| { @@ -524,6 +560,9 @@ struct CompleteQueueState { ring_mask: u32, cqes_offset: u32, completed: usize, + //For ops that pass in arrays of iovecs, they need to be valid for the duration of the + //operation because the kernel might read them at any time. + pending_op_addrs: BTreeMap>, } impl CompleteQueueState { @@ -541,9 +580,14 @@ impl CompleteQueueState { ring_mask, cqes_offset: params.cq_off.cqes, completed: 0, + pending_op_addrs: BTreeMap::new(), } } + fn add_op_data(&mut self, user_data: UserData, addrs: Vec) { + self.pending_op_addrs.insert(user_data, addrs); + } + fn get_cqe(&self, head: u32) -> &io_uring_cqe { unsafe { // Safe because we trust that the kernel has returned enough memory in io_uring_setup @@ -582,6 +626,9 @@ impl Iterator for CompleteQueueState { let user_data = cqe.user_data; let res = cqe.res; + // free the addrs saved for this op. + let _ = self.pending_op_addrs.remove(&user_data); + // Store the new head and ensure the reads above complete before the kernel sees the // update to head, `set_head` uses `Release` ordering let new_head = head.wrapping_add(1); @@ -637,6 +684,7 @@ impl QueuePointers { #[cfg(test)] mod tests { use std::fs::OpenOptions; + use std::io::{IoSlice, IoSliceMut}; use std::io::{Read, Seek, SeekFrom, Write}; use std::path::{Path, PathBuf}; use std::time::Duration; @@ -677,10 +725,18 @@ mod tests { offset: u64, user_data: UserData, ) { - let iovecs = [IoSlice::new(buf)]; + let io_vecs = unsafe { + //safe to transmut from IoSlice to iovec. + vec![IoSliceMut::new(buf)] + .into_iter() + .map(|slice| std::mem::transmute::(slice)) + .collect::>() + }; let (user_data_ret, res) = unsafe { // Safe because the `wait` call waits until the kernel is done with `buf`. - uring.add_readv(&iovecs, fd, offset, user_data).unwrap(); + uring + .add_readv_iter(io_vecs.into_iter(), fd, offset, user_data) + .unwrap(); uring.wait().unwrap().next().unwrap() }; assert_eq!(user_data_ret, user_data); @@ -771,15 +827,27 @@ mod tests { const BUF_SIZE: usize = 0x2000; let mut uring = URingContext::new(queue_size).unwrap(); - let buf = [0u8; BUF_SIZE]; - let buf2 = [0u8; BUF_SIZE]; - let buf3 = [0u8; BUF_SIZE]; - let io_slices = vec![IoSlice::new(&buf), IoSlice::new(&buf2), IoSlice::new(&buf3)]; - let total_len = io_slices.iter().fold(0, |a, iovec| a + iovec.len()); + let mut buf = [0u8; BUF_SIZE]; + let mut buf2 = [0u8; BUF_SIZE]; + let mut buf3 = [0u8; BUF_SIZE]; + let io_vecs = unsafe { + //safe to transmut from IoSlice to iovec. + vec![ + IoSliceMut::new(&mut buf), + IoSliceMut::new(&mut buf2), + IoSliceMut::new(&mut buf3), + ] + .into_iter() + .map(|slice| std::mem::transmute::(slice)) + .collect::>() + }; + let total_len = io_vecs.iter().fold(0, |a, iovec| a + iovec.iov_len); let f = create_test_file(&temp_dir, total_len as u64 * 2); let (user_data_ret, res) = unsafe { // Safe because the `wait` call waits until the kernel is done with `buf`. - uring.add_readv(&io_slices, f.as_raw_fd(), 0, 55).unwrap(); + uring + .add_readv_iter(io_vecs.into_iter(), f.as_raw_fd(), 0, 55) + .unwrap(); uring.wait().unwrap().next().unwrap() }; assert_eq!(user_data_ret, 55); @@ -865,13 +933,19 @@ mod tests { let buf = [0xaau8; BUF_SIZE]; let buf2 = [0xffu8; BUF_SIZE]; let buf3 = [0x55u8; BUF_SIZE]; - let io_slices = vec![IoSlice::new(&buf), IoSlice::new(&buf2), IoSlice::new(&buf3)]; - let total_len = io_slices.iter().fold(0, |a, iovec| a + iovec.len()); + let io_vecs = unsafe { + //safe to transmut from IoSlice to iovec. + vec![IoSlice::new(&buf), IoSlice::new(&buf2), IoSlice::new(&buf3)] + .into_iter() + .map(|slice| std::mem::transmute::(slice)) + .collect::>() + }; + let total_len = io_vecs.iter().fold(0, |a, iovec| a + iovec.iov_len); let mut f = create_test_file(&temp_dir, total_len as u64 * 2); let (user_data_ret, res) = unsafe { // Safe because the `wait` call waits until the kernel is done with `buf`. uring - .add_writev(&io_slices, f.as_raw_fd(), OFFSET, 55) + .add_writev_iter(io_vecs.into_iter(), f.as_raw_fd(), OFFSET, 55) .unwrap(); uring.wait().unwrap().next().unwrap() }; @@ -951,7 +1025,7 @@ mod tests { let f = File::open(Path::new("/dev/zero")).unwrap(); let mut uring = URingContext::new(16).unwrap(); uring - .add_poll_fd(f.as_raw_fd(), WatchingEvents::empty().set_read(), 454) + .add_poll_fd(f.as_raw_fd(), &WatchingEvents::empty().set_read(), 454) .unwrap(); let (user_data, res) = uring.wait().unwrap().next().unwrap(); assert_eq!(user_data, 454 as UserData);