diff --git a/base/src/sys/unix/file.rs b/base/src/sys/unix/file.rs
index fac94372de..39b898d3d0 100644
--- a/base/src/sys/unix/file.rs
+++ b/base/src/sys/unix/file.rs
@@ -29,6 +29,39 @@ fn lseek(fd: &dyn AsRawDescriptor, offset: u64, option: LseekOption) -> Result<u64> {
+/// Find the offset range of the next data in the file.
+pub fn find_next_data(
+    fd: &dyn AsRawDescriptor,
+    offset: u64,
+    len: u64,
+) -> Result<Option<Range<u64>>> {
+    let end = offset + len;
+    let offset_data = match lseek(fd, offset, LseekOption::Data) {
+        Ok(offset) => {
+            if offset >= end {
+                return Ok(None);
+            } else {
+                offset
+            }
+        }
+        Err(e) => {
+            return match e.errno() {
+                libc::ENXIO => Ok(None),
+                _ => Err(e),
+            }
+        }
+    };
+    let offset_hole = lseek(fd, offset_data, LseekOption::Hole)?;
+
+    Ok(Some(offset_data..offset_hole.min(end)))
+}
+
 /// Iterator returning the offset range of data in the file.
 ///
 /// This uses `lseek(2)` internally, and thus it changes the file offset.
@@ -53,34 +86,13 @@ impl<'a> FileDataIterator<'a> {
             end: offset + len,
         }
     }
-
-    fn find_next_data(&self) -> Result<Option<Range<u64>>> {
-        let offset_data = match lseek(self.fd, self.offset, LseekOption::Data) {
-            Ok(offset) => {
-                if offset >= self.end {
-                    return Ok(None);
-                } else {
-                    offset
-                }
-            }
-            Err(e) => {
-                return match e.errno() {
-                    libc::ENXIO => Ok(None),
-                    _ => Err(e),
-                }
-            }
-        };
-        let offset_hole = lseek(self.fd, offset_data, LseekOption::Hole)?;
-
-        Ok(Some(offset_data..offset_hole.min(self.end)))
-    }
 }
 
 impl<'a> Iterator for FileDataIterator<'a> {
     type Item = Range<u64>;
 
     fn next(&mut self) -> Option<Self::Item> {
-        match self.find_next_data() {
+        match find_next_data(self.fd, self.offset, self.end - self.offset) {
             Ok(data_range) => {
                 if let Some(ref data_range) = data_range {
                     self.offset = data_range.end;
diff --git a/base/src/sys/unix/mod.rs b/base/src/sys/unix/mod.rs
index 716e5e8a5d..2e6bbd422c 100644
--- a/base/src/sys/unix/mod.rs
+++ b/base/src/sys/unix/mod.rs
@@ -71,6 +71,7 @@ pub use capabilities::drop_capabilities;
 pub use descriptor::*;
 pub use event::EventExt;
 pub(crate) use event::PlatformEvent;
+pub use file::find_next_data;
 pub use file::FileDataIterator;
 pub use file_flags::*;
 pub use get_filesystem_type::*;
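`find_next_data` is the body of `FileDataIterator` extracted into a free function so the swap code can probe a single offset range without constructing an iterator. It wraps the sparse-file protocol of `lseek(2)`: `SEEK_DATA` skips to the next extent containing data, `SEEK_HOLE` finds where that extent ends, and `ENXIO` means there is no data at or after the offset. A minimal sketch of the same protocol against plain libc (illustration only; `next_data_range` is a hypothetical name, not crosvm API):

// Sketch of the SEEK_DATA/SEEK_HOLE protocol that find_next_data wraps.
use std::fs::File;
use std::io;
use std::ops::Range;
use std::os::unix::io::AsRawFd;

fn next_data_range(file: &File, offset: i64, len: i64) -> io::Result<Option<Range<i64>>> {
    let end = offset + len;
    // SAFETY: lseek on a valid fd has no memory side effects.
    let data = unsafe { libc::lseek(file.as_raw_fd(), offset, libc::SEEK_DATA) };
    if data < 0 {
        let err = io::Error::last_os_error();
        return match err.raw_os_error() {
            // ENXIO: no data at or after `offset`.
            Some(libc::ENXIO) => Ok(None),
            _ => Err(err),
        };
    }
    if data >= end {
        return Ok(None);
    }
    // SAFETY: same as above.
    let hole = unsafe { libc::lseek(file.as_raw_fd(), data, libc::SEEK_HOLE) };
    if hole < 0 {
        return Err(io::Error::last_os_error());
    }
    Ok(Some(data..hole.min(end)))
}

fn main() -> io::Result<()> {
    use std::io::{Seek, SeekFrom, Write};
    // tempfile is the same dev-dependency the tests in this patch use.
    let mut f = tempfile::tempfile()?;
    f.seek(SeekFrom::Start(8192))?;
    f.write_all(&[1u8; 4096])?;
    // On filesystems with SEEK_HOLE support this is likely Some(8192..12288);
    // without it, the whole file is reported as a single data extent, which is
    // pessimistic but still correct for the trim use case.
    println!("{:?}", next_data_range(&f, 0, 16384)?);
    Ok(())
}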
diff --git a/common/data_model/src/volatile_memory.rs b/common/data_model/src/volatile_memory.rs
index a3516ed663..3638632328 100644
--- a/common/data_model/src/volatile_memory.rs
+++ b/common/data_model/src/volatile_memory.rs
@@ -312,6 +312,56 @@ impl<'a> VolatileSlice<'a> {
             }
         }
     }
+
+    /// Returns whether all bytes in this slice are zero or not.
+    ///
+    /// This is optimized for a [VolatileSlice] aligned to 16 bytes.
+    ///
+    /// TODO(b/274840085): Use SIMD for better performance.
+    pub fn is_all_zero(&self) -> bool {
+        const MASK_4BIT: usize = 0x0f;
+        let head_addr = self.as_ptr() as usize;
+        // Round up to a multiple of 16.
+        let aligned_head_addr = (head_addr + MASK_4BIT) & !MASK_4BIT;
+        let tail_addr = head_addr + self.size();
+        // Round down to a multiple of 16.
+        let aligned_tail_addr = tail_addr & !MASK_4BIT;
+
+        // Check 16 bytes at once. The addresses are 16-byte aligned, as `u128` reads require.
+        // SAFETY: Each aligned_addr is within the VolatileSlice.
+        if (aligned_head_addr..aligned_tail_addr)
+            .step_by(16)
+            .any(|aligned_addr| unsafe { *(aligned_addr as *const u128) } != 0)
+        {
+            return false;
+        }
+
+        if head_addr == aligned_head_addr && tail_addr == aligned_tail_addr {
+            // If head_addr and tail_addr are aligned, we can skip the unaligned part, which
+            // contains at least 2 conditional branches.
+            true
+        } else {
+            // Check the unaligned parts.
+            // SAFETY: The ranges [head_addr, aligned_head_addr) and [aligned_tail_addr,
+            // tail_addr) are within the VolatileSlice.
+            unsafe {
+                is_all_zero_naive(head_addr, aligned_head_addr)
+                    && is_all_zero_naive(aligned_tail_addr, tail_addr)
+            }
+        }
+    }
+}
+
+/// Check whether every byte is zero.
+///
+/// This checks byte by byte.
+///
+/// ## Safety
+///
+/// * `head_addr` <= `tail_addr`
+/// * Bytes between `head_addr` and `tail_addr` must be valid to access.
+unsafe fn is_all_zero_naive(head_addr: usize, tail_addr: usize) -> bool {
+    (head_addr..tail_addr).all(|addr| *(addr as *const u8) == 0)
 }
 
 impl<'a> VolatileMemory for VolatileSlice<'a> {
@@ -320,6 +370,23 @@ impl<'a> VolatileMemory for VolatileSlice<'a> {
     }
 }
 
+impl PartialEq<VolatileSlice<'_>> for VolatileSlice<'_> {
+    fn eq(&self, other: &VolatileSlice) -> bool {
+        let size = self.size();
+        if size != other.size() {
+            return false;
+        }
+
+        // SAFETY: We pass pointers into valid VolatileSlice regions, and size is checked above.
+        let cmp = unsafe { libc::memcmp(self.as_ptr() as _, other.as_ptr() as _, size) };
+
+        cmp == 0
+    }
+}
+
+/// The `PartialEq` implementation for `VolatileSlice` is reflexive, symmetric, and transitive.
+impl Eq for VolatileSlice<'_> {}
+
 /// A memory location that supports volatile access of a `T`.
 ///
 /// # Examples
@@ -491,4 +558,61 @@ mod tests {
         let res = a.get_slice(55, 50).unwrap_err();
         assert_eq!(res, Error::OutOfBounds { addr: 105 });
     }
+
+    #[test]
+    fn is_all_zero_16bytes_aligned() {
+        let a = VecMem::new(1024);
+        let slice = a.get_slice(0, 1024).unwrap();
+
+        assert!(slice.is_all_zero());
+        a.get_slice(129, 1).unwrap().write_bytes(1);
+        assert!(!slice.is_all_zero());
+    }
+
+    #[test]
+    fn is_all_zero_head_not_aligned() {
+        let a = VecMem::new(1024);
+        let slice = a.get_slice(1, 1023).unwrap();
+
+        assert!(slice.is_all_zero());
+        a.get_slice(0, 1).unwrap().write_bytes(1);
+        assert!(slice.is_all_zero());
+        a.get_slice(1, 1).unwrap().write_bytes(1);
+        assert!(!slice.is_all_zero());
+        a.get_slice(1, 1).unwrap().write_bytes(0);
+        a.get_slice(129, 1).unwrap().write_bytes(1);
+        assert!(!slice.is_all_zero());
+    }
+
+    #[test]
+    fn is_all_zero_tail_not_aligned() {
+        let a = VecMem::new(1024);
+        let slice = a.get_slice(0, 1023).unwrap();
+
+        assert!(slice.is_all_zero());
+        a.get_slice(1023, 1).unwrap().write_bytes(1);
+        assert!(slice.is_all_zero());
+        a.get_slice(1022, 1).unwrap().write_bytes(1);
+        assert!(!slice.is_all_zero());
+        a.get_slice(1022, 1).unwrap().write_bytes(0);
+        a.get_slice(0, 1).unwrap().write_bytes(1);
+        assert!(!slice.is_all_zero());
+    }
+
+    #[test]
+    fn is_all_zero_no_aligned_16bytes() {
+        let a = VecMem::new(1024);
+        let slice = a.get_slice(1, 16).unwrap();
+
+        assert!(slice.is_all_zero());
+        a.get_slice(0, 1).unwrap().write_bytes(1);
+        assert!(slice.is_all_zero());
+        for i in 1..17 {
+            a.get_slice(i, 1).unwrap().write_bytes(1);
+            assert!(!slice.is_all_zero());
+            a.get_slice(i, 1).unwrap().write_bytes(0);
+        }
+        a.get_slice(17, 1).unwrap().write_bytes(1);
+        assert!(slice.is_all_zero());
+    }
 }
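The fast path of `is_all_zero` reads 16 bytes at a time as `u128`, which requires 16-byte-aligned addresses; the bit masks compute the aligned sub-range and the byte-wise `is_all_zero_naive` covers the ragged head and tail. A standalone sketch of that alignment arithmetic (illustration only, with hypothetical names):

// Round an address down and up to the nearest multiple of 16.
fn align_16(addr: usize) -> (usize, usize) {
    const MASK_4BIT: usize = 0x0f;
    // Clearing the low 4 bits rounds down; adding 15 first rounds up.
    let round_down = addr & !MASK_4BIT;
    let round_up = (addr + MASK_4BIT) & !MASK_4BIT;
    (round_down, round_up)
}

fn main() {
    assert_eq!(align_16(0x1001), (0x1000, 0x1010));
    // An already-aligned address rounds to itself in both directions.
    assert_eq!(align_16(0x1000), (0x1000, 0x1000));
}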
diff --git a/src/crosvm/cmdline.rs b/src/crosvm/cmdline.rs
index 48805fd385..7b96445dd7 100644
--- a/src/crosvm/cmdline.rs
+++ b/src/crosvm/cmdline.rs
@@ -309,9 +309,18 @@ pub struct SwapEnableCommand {
     pub socket_path: String,
 }
 
+#[derive(FromArgs)]
+#[argh(subcommand, name = "trim")]
+/// Trim pages in the staging memory
+pub struct SwapTrimCommand {
+    #[argh(positional, arg_name = "VM_SOCKET")]
+    /// VM Socket path
+    pub socket_path: String,
+}
+
 #[derive(FromArgs)]
 #[argh(subcommand, name = "out")]
-/// Enable swap of a VM
+/// Swap out staging memory to swap file
 pub struct SwapOutCommand {
     #[argh(positional, arg_name = "VM_SOCKET")]
     /// VM Socket path
@@ -348,6 +357,7 @@ pub struct SwapCommand {
 /// Swap related operations
 pub enum SwapSubcommands {
     Enable(SwapEnableCommand),
+    Trim(SwapTrimCommand),
     SwapOut(SwapOutCommand),
     Disable(SwapDisableCommand),
     Status(SwapStatusCommand),
diff --git a/src/main.rs b/src/main.rs
index ecb57ab080..967944dc56 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -181,6 +181,7 @@ fn swap_vms(cmd: cmdline::SwapCommand) -> std::result::Result<(), ()> {
     use cmdline::SwapSubcommands::*;
     let (req, path) = match &cmd.nested {
         Enable(params) => (VmRequest::Swap(SwapCommand::Enable), &params.socket_path),
+        Trim(params) => (VmRequest::Swap(SwapCommand::Trim), &params.socket_path),
         SwapOut(params) => (VmRequest::Swap(SwapCommand::SwapOut), &params.socket_path),
         Disable(params) => (VmRequest::Swap(SwapCommand::Disable), &params.socket_path),
         Status(params) => (VmRequest::Swap(SwapCommand::Status), &params.socket_path),
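With the new variant wired into `SwapSubcommands` and `swap_vms`, the request is expected to be reachable as `crosvm swap trim <VM_SOCKET>`. A self-contained sketch of the argh plumbing (hypothetical stand-in types that mirror the definitions above without depending on crosvm internals):

use argh::FromArgs;

#[derive(FromArgs)]
/// Swap related operations (stand-in for crosvm's `swap` command).
struct Swap {
    #[argh(subcommand)]
    nested: SubCommands,
}

#[derive(FromArgs)]
#[argh(subcommand)]
enum SubCommands {
    Trim(Trim),
}

#[derive(FromArgs)]
#[argh(subcommand, name = "trim")]
/// Trim pages in the staging memory
struct Trim {
    #[argh(positional, arg_name = "VM_SOCKET")]
    /// VM Socket path
    socket_path: String,
}

fn main() {
    // e.g. invoked as `swap trim /run/crosvm.sock`
    let cmd: Swap = argh::from_env();
    let SubCommands::Trim(trim) = cmd.nested;
    println!("would send SwapCommand::Trim to {}", trim.socket_path);
}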
diff --git a/swap/src/file.rs b/swap/src/file.rs
index ffb4b0836e..11b6950160 100644
--- a/swap/src/file.rs
+++ b/swap/src/file.rs
@@ -85,6 +85,10 @@ impl<'a> SwapFile<'a> {
         })
     }
 
+    pub(crate) fn base_offset(&self) -> u64 {
+        self.offset
+    }
+
     /// Returns the total count of managed pages.
     pub fn num_pages(&self) -> usize {
         self.present_list.len()
@@ -156,7 +160,10 @@ impl<'a> SwapFile<'a> {
         }
     }
 
-    /// Clears the pages in the file corresponding to the index.
+    /// Marks the pages in the file corresponding to the index as cleared.
+    ///
+    /// The contents of the swap file are preserved and can be reused by
+    /// `SwapFile::mark_as_present()` to reduce disk I/O.
     ///
     /// If the pages are mlock(2)ed, unlock them before MADV_DONTNEED. This returns the number of
     /// pages munlock(2)ed.
@@ -249,6 +256,17 @@ impl<'a> SwapFile<'a> {
         Ok(())
     }
 
+    /// Marks the page as present on the file.
+    ///
+    /// The content written to the swap file by a previous `SwapFile::write_to_file()` is reused.
+    ///
+    /// # Arguments
+    ///
+    /// * `idx` - the index of the page from the head of the pages.
+    pub fn mark_as_present(&mut self, idx: usize) {
+        self.present_list.mark_as_present(idx..idx + 1);
+    }
+
     /// Writes the contents to the swap file.
     ///
     /// # Arguments
@@ -291,7 +309,8 @@ impl<'a> SwapFile<'a> {
         self.present_list.first_data_range(max_pages)
     }
 
-    /// Returns the [VolatileSlice] corresponding to the indices.
+    /// Returns the [VolatileSlice] corresponding to the indices, regardless of whether the pages
+    /// are present or not.
     ///
     /// If the range is out of the region, this returns [Error::OutOfRange].
    ///
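`mark_as_present` is the half of trim that makes clean pages cheap: the page content already on disk stays valid, so trimming only flips the page's bit in the present list instead of rewriting the page on the next swap-out. A simplified sketch of that bookkeeping (hypothetical `PresentList`; the real one lives in the swap crate and is not shown in this patch):

use std::ops::Range;

// Tracks which pages have a valid copy in the swap file.
struct PresentList {
    bits: Vec<bool>,
}

impl PresentList {
    fn new(num_pages: usize) -> Self {
        Self { bits: vec![false; num_pages] }
    }

    fn mark_as_present(&mut self, range: Range<usize>) {
        for idx in range {
            self.bits[idx] = true;
        }
    }

    fn is_present(&self, idx: usize) -> bool {
        self.bits[idx]
    }
}

fn main() {
    let mut list = PresentList::new(4);
    // Trim found page 1 unchanged since swap-out: no write, just flip the bit.
    list.mark_as_present(1..2);
    assert!(list.is_present(1));
    assert!(!list.is_present(0));
}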
diff --git a/swap/src/lib.rs b/swap/src/lib.rs
index 3adb502b47..9885582c49 100644
--- a/swap/src/lib.rs
+++ b/swap/src/lib.rs
@@ -79,6 +79,8 @@ use crate::worker::Worker;
 
 /// The max size of chunks to swap out/in at once.
 const MAX_SWAP_CHUNK_SIZE: usize = 2 * 1024 * 1024; // = 2MB
+/// The max number of pages to trim at once.
+const MAX_TRIM_PAGES: usize = 1024;
 
 /// Current state of vmm-swap.
 ///
@@ -90,6 +92,8 @@ pub enum State {
     Ready,
     /// Pages in guest memory are moved to the staging memory.
     Pending,
+    /// Trimming staging memory.
+    TrimInProgress,
     /// swap-out is in progress.
     SwapOutInProgress,
     /// swap out succeeded.
@@ -104,6 +108,7 @@ impl From<&SwapState<'_>> for State {
     fn from(state: &SwapState<'_>) -> Self {
         match state {
             SwapState::SwapOutPending => State::Pending,
+            SwapState::Trim(_) => State::TrimInProgress,
             SwapState::SwapOutInProgress { .. } => State::SwapOutInProgress,
             SwapState::SwapOutCompleted => State::Active,
             SwapState::SwapInInProgress(_) => State::SwapInInProgress,
@@ -205,6 +210,7 @@ impl Status {
 #[derive(Serialize, Deserialize, Debug)]
 enum Command {
     Enable,
+    Trim,
     SwapOut,
     Disable,
     Exit,
@@ -390,6 +396,17 @@ impl SwapController {
         Ok(())
     }
 
+    /// Trim pages in the staging memory which do not need to be written back to the swap file:
+    ///
+    /// * zero pages
+    /// * pages which are identical to the pages already in the swap file.
+    pub fn trim(&self) -> anyhow::Result<()> {
+        self.command_tube
+            .send(&Command::Trim)
+            .context("send swap trim request")?;
+        Ok(())
+    }
+
     /// Swap out all the pages in the staging memory to the swap files.
     ///
     /// This returns as soon as it succeeds to send request to the monitor process.
@@ -679,6 +696,9 @@ fn monitor_process(
                         // events are obsolete. Run `WaitContext::wait()` again
                         break;
                     }
+                    Command::Trim => {
+                        warn!("swap trim while disabled");
+                    }
                     Command::SwapOut => {
                         warn!("swap out while disabled");
                     }
@@ -705,6 +725,7 @@ fn monitor_process(
 
 enum SwapState<'scope> {
     SwapOutPending,
+    Trim(ScopedJoinHandle<'scope, anyhow::Result<()>>),
     SwapOutInProgress { started_time: Instant },
     SwapOutCompleted,
     SwapInInProgress(ScopedJoinHandle<'scope, anyhow::Result<()>>),
@@ -779,6 +800,19 @@ fn move_guest_to_staging(
     }
 }
 
+fn abort_background_job<T>(
+    join_handle: ScopedJoinHandle<'_, T>,
+    bg_job_control: &BackgroundJobControl,
+) -> anyhow::Result<T> {
+    bg_job_control.abort();
+    // Wait until the background job is aborted and the thread finishes.
+    let result = join_handle
+        .join()
+        .expect("panic on the background job thread");
+    bg_job_control.reset().context("reset swap in event")?;
+    Ok(result)
+}
+
 fn handle_vmm_swap<'scope, 'env>(
     scope: &'scope Scope<'scope, 'env>,
     wait_ctx: &WaitContext<Token>,
@@ -887,15 +921,21 @@ fn handle_vmm_swap<'scope, 'env>(
                 bail!("child process is forked while swap is enabled");
             }
             Command::Enable => {
-                if let SwapState::SwapInInProgress(join_handle) = state {
-                    info!("abort swap-in");
-                    bg_job_control.abort();
-                    // Wait until the background job is aborted and the thread finishes.
-                    if let Err(e) = join_handle.join() {
-                        bail!("failed to join swap in thread: {:?}", e);
+                match state {
+                    SwapState::SwapInInProgress(join_handle) => {
+                        info!("abort swap-in");
+                        abort_background_job(join_handle, bg_job_control)
+                            .context("abort swap in")?
+                            .context("swap_in failure")?;
                     }
-                    bg_job_control.reset().context("reset swap in event")?;
-                };
+                    SwapState::Trim(join_handle) => {
+                        info!("abort trim");
+                        abort_background_job(join_handle, bg_job_control)
+                            .context("abort trim")?
+                            .context("trim failure")?;
+                    }
+                    _ => {}
+                }
 
                 info!("start moving memory to staging");
                 match move_guest_to_staging(page_handler, guest_memory, vm_tube, worker) {
@@ -914,6 +954,46 @@
                     }
                 }
             }
+            Command::Trim => match &state {
+                SwapState::SwapOutPending => {
+                    *state_transition.lock() = StateTransition::default();
+                    let join_handle = scope.spawn(|| {
+                        let mut ctx = page_handler.start_trim();
+                        let job = bg_job_control.new_job();
+                        let start_time = std::time::Instant::now();
+
+                        while !job.is_aborted() {
+                            if let Some(trimmed_pages) =
+                                ctx.trim_pages(MAX_TRIM_PAGES).context("trim pages")?
+                            {
+                                let mut state_transition = state_transition.lock();
+                                state_transition.pages += trimmed_pages;
+                                state_transition.time_ms = start_time.elapsed().as_millis();
+                            } else {
+                                // Traversed all pages.
+                                break;
+                            }
+                        }
+
+                        if job.is_aborted() {
+                            info!("trim is aborted");
+                        } else {
+                            info!(
+                                "trimmed {} clean pages and {} zero pages",
+                                ctx.trimmed_clean_pages(),
+                                ctx.trimmed_zero_pages()
+                            );
+                        }
+                        Ok(())
+                    });
+
+                    state = SwapState::Trim(join_handle);
+                    info!("start trimming staging memory");
+                }
+                state => {
+                    warn!("swap trim is not ready. state: {:?}", State::from(state));
+                }
+            },
             Command::SwapOut => match &state {
                 SwapState::SwapOutPending => {
                     state = SwapState::SwapOutInProgress {
@@ -927,7 +1007,13 @@
                 }
             },
             Command::Disable => {
-                match &state {
+                match state {
+                    SwapState::Trim(join_handle) => {
+                        info!("abort trim");
+                        abort_background_job(join_handle, bg_job_control)
+                            .context("abort trim")?
+                            .context("trim failure")?;
+                    }
                     SwapState::SwapOutInProgress { .. } => {
                         info!("swap out is aborted");
                     }
@@ -975,16 +1061,19 @@
                         if let Err(e) = join_handle.join() {
                             bail!("failed to join swap in thread: {:?}", e);
                         }
+                        return Ok(true);
                     }
-                    _ => {
-                        let mut ctx = page_handler.start_swap_in();
-                        let uffd = uffd_list.main_uffd();
-                        // Swap-in all before exit.
-                        while ctx.swap_in(uffd, MAX_SWAP_CHUNK_SIZE).context("swap in")? > 0
-                        {
-                        }
+                    SwapState::Trim(join_handle) => {
+                        abort_background_job(join_handle, bg_job_control)
+                            .context("abort trim")?
+                            .context("trim failure")?;
                     }
-                }
+                    _ => {}
+                }
+                let mut ctx = page_handler.start_swap_in();
+                let uffd = uffd_list.main_uffd();
+                // Swap-in all before exit.
+                while ctx.swap_in(uffd, MAX_SWAP_CHUNK_SIZE).context("swap in")? > 0 {}
                 return Ok(true);
             }
             Command::Status => {
@@ -1004,28 +1093,37 @@
                     // the obsolete token here.
                     continue;
                 }
-                if let SwapState::SwapInInProgress(join_handle) = state {
-                    match join_handle.join() {
-                        Ok(Ok(_)) => {
-                            let state_transition = state_transition.lock();
-                            info!(
-                                "swap in all {} pages in {} ms.",
-                                state_transition.pages, state_transition.time_ms
-                            );
-                            return Ok(false);
-                        }
-                        Ok(Err(e)) => {
-                            bail!("swap in failed: {:?}", e)
-                        }
-                        Err(e) => {
-                            bail!("failed to wait for the swap in thread: {:?}", e);
-                        }
+                match state {
+                    SwapState::SwapInInProgress(join_handle) => {
+                        join_handle
+                            .join()
+                            .expect("panic on the background job thread")
+                            .context("swap in finish")?;
+                        let state_transition = state_transition.lock();
+                        info!(
+                            "swap in all {} pages in {} ms.",
+                            state_transition.pages, state_transition.time_ms
+                        );
+                        return Ok(false);
+                    }
+                    SwapState::Trim(join_handle) => {
+                        join_handle
+                            .join()
+                            .expect("panic on the background job thread")
+                            .context("trim finish")?;
+                        let state_transition = state_transition.lock();
+                        info!(
+                            "trimmed {} pages in {} ms.",
+                            state_transition.pages, state_transition.time_ms
+                        );
+                        state = SwapState::SwapOutPending;
+                    }
+                    state => {
+                        bail!(
+                            "background job completed but the actual state is {:?}",
+                            State::from(&state)
+                        );
                     }
-                } else {
-                    bail!(
-                        "swap in completed but the actual state is {:?}",
-                        State::from(&state)
-                    );
                 }
             }
         };
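The monitor accepts `Command::Trim` only from the `SwapOutPending` state, and every other path (enable, disable, exit, job completion) must first join or abort the trim thread via `abort_background_job`. A reduced sketch of the resulting state machine (illustration only, collapsed to the transitions trim participates in):

// Simplified mirror of swap::State for the trim-related transitions.
#[derive(Clone, Copy, Debug, PartialEq)]
enum State {
    Ready,
    Pending,
    TrimInProgress,
    SwapOutInProgress,
}

fn on_trim_command(state: State) -> State {
    match state {
        // `Command::Trim` spawns the background job only from Pending.
        State::Pending => State::TrimInProgress,
        // Everywhere else the monitor just logs "swap trim is not ready".
        other => other,
    }
}

// On successful completion the monitor returns to Pending, so a subsequent
// `Command::SwapOut` writes only the pages that survived the trim.
fn on_trim_complete(_: State) -> State {
    State::Pending
}

fn main() {
    assert_eq!(on_trim_command(State::Pending), State::TrimInProgress);
    assert_eq!(on_trim_command(State::Ready), State::Ready);
    assert_eq!(on_trim_complete(State::TrimInProgress), State::Pending);
}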
diff --git a/swap/src/page_handler.rs b/swap/src/page_handler.rs
index 7dde812b23..8966593e3c 100644
--- a/swap/src/page_handler.rs
+++ b/swap/src/page_handler.rs
@@ -13,6 +13,7 @@ use std::sync::Arc;
 
 use anyhow::Context;
 use base::error;
+use base::sys::find_next_data;
 use base::unix::FileDataIterator;
 use base::AsRawDescriptor;
 use base::SharedMemory;
@@ -165,6 +166,7 @@ struct PageHandleContext<'a> {
 pub struct PageHandler<'a> {
     ctx: Mutex<PageHandleContext<'a>>,
     channel: Arc<Channel<MoveToStaging>>,
+    swap_raw_file: &'a File,
 }
 
 impl<'a> PageHandler<'a> {
@@ -180,14 +182,14 @@ impl<'a> PageHandler<'a> {
     /// * `address_ranges` - The list of address range of the regions. the start address must align
     ///   with page. the size must be multiple of pagesize.
     pub fn create(
-        swap_file: &'a File,
+        swap_raw_file: &'a File,
         staging_shmem: &'a SharedMemory,
         address_ranges: &[Range<usize>],
         stating_move_context: Arc<Channel<MoveToStaging>>,
     ) -> Result<Self> {
         // Truncate the file into the size to hold all regions, otherwise access beyond the end of
         // file may cause SIGBUS.
-        swap_file
+        swap_raw_file
             .set_len(
                 address_ranges
                     .iter()
@@ -231,7 +233,7 @@ impl<'a> PageHandler<'a> {
             assert!(is_page_aligned(base_addr));
             assert!(is_page_aligned(region_size));
 
-            let file = SwapFile::new(swap_file, offset_pages, num_of_pages)?;
+            let file = SwapFile::new(swap_raw_file, offset_pages, num_of_pages)?;
             let staging_memory = StagingMemory::new(
                 staging_shmem,
                 pages_to_bytes(offset_pages) as u64,
@@ -259,6 +261,7 @@ impl<'a> PageHandler<'a> {
                 mlock_budget_pages: bytes_to_pages(MLOCK_BUDGET),
             }),
             channel: stating_move_context,
+            swap_raw_file,
         })
     }
 
@@ -571,6 +574,19 @@ impl<'a> PageHandler<'a> {
         }
     }
 
+    /// Creates a new [TrimContext].
+    pub fn start_trim(&'a self) -> TrimContext<'a> {
+        TrimContext {
+            ctx: &self.ctx,
+            swap_raw_file: self.swap_raw_file,
+            cur_page: 0,
+            cur_region: 0,
+            next_data_in_file: 0..0,
+            clean_pages: 0,
+            zero_pages: 0,
+        }
+    }
+
     /// Returns count of pages active on the memory.
     pub fn compute_resident_pages(&self) -> usize {
         self.ctx
@@ -738,3 +754,136 @@ impl Drop for SwapInContext<'_> {
         ctx.mlock_budget_pages = bytes_to_pages(MLOCK_BUDGET);
     }
 }
+
+/// Context for the trim operation.
+///
+/// This drops 2 types of pages in the staging memory to reduce disk writes:
+///
+/// * Clean pages
+///   * The pages which have been swapped out to the disk and have not been changed.
+///   * Drop the pages in the staging memory and mark them as present on the swap file.
+/// * Zero pages
+///   * Drop the pages in the staging memory. The pages will be UFFD_ZEROed on page fault.
+pub struct TrimContext<'a> {
+    ctx: &'a Mutex<PageHandleContext<'a>>,
+    swap_raw_file: &'a File,
+    cur_region: usize,
+    cur_page: usize,
+    /// The page idx range of pages which have been stored in the swap file.
+    next_data_in_file: Range<usize>,
+    clean_pages: usize,
+    zero_pages: usize,
+}
+
+impl TrimContext<'_> {
+    /// Trim pages in the staging memory.
+    ///
+    /// This returns the number of pages trimmed. It returns `None` once it has traversed all
+    /// pages in the staging memory.
+    ///
+    /// # Arguments
+    ///
+    /// `max_pages` - The maximum number of pages to be compared.
+    pub fn trim_pages(&mut self, max_pages: usize) -> anyhow::Result<Option<usize>> {
+        let mut ctx = self.ctx.lock();
+        if self.cur_region >= ctx.regions.len() {
+            return Ok(None);
+        }
+        let region = &mut ctx.regions[self.cur_region];
+        let region_size_bytes = pages_to_bytes(region.file.num_pages()) as u64;
+        let mut n_trimmed = 0;
+
+        for _ in 0..max_pages {
+            if let Some(slice_in_staging) = region
+                .staging_memory
+                .page_content(self.cur_page)
+                .context("get page of staging memory")?
+            {
+                let idx_range = self.cur_page..self.cur_page + 1;
+
+                if self.cur_page >= self.next_data_in_file.end {
+                    let offset_in_region = pages_to_bytes(self.cur_page) as u64;
+                    let offset = region.file.base_offset() + offset_in_region;
+                    if let Some(offset_range) = find_next_data(
+                        self.swap_raw_file,
+                        offset,
+                        region_size_bytes - offset_in_region,
+                    )
+                    .context("find next data in swap file")?
+                    {
+                        let start = bytes_to_pages(
+                            (offset_range.start - region.file.base_offset()) as usize,
+                        );
+                        let end = bytes_to_pages(
+                            (offset_range.end - region.file.base_offset()) as usize,
+                        );
+                        self.next_data_in_file = start..end;
+                    } else {
+                        self.next_data_in_file =
+                            region.file.num_pages()..region.file.num_pages();
+                    }
+                }
+
+                // Check for a zero page in the staging memory first. If the page is non-zero
+                // and has not been changed, the zero check is wasted work, but it is cheaper
+                // than doing file I/O first for pages which were in the swap file and are now
+                // zero. Check the 2 types of pages in the same loop to utilize the CPU cache
+                // for the staging memory.
+                if slice_in_staging.is_all_zero() {
+                    region
+                        .staging_memory
+                        .clear_range(idx_range.clone())
+                        .context("clear a page in staging memory")?;
+                    if self.cur_page >= self.next_data_in_file.start {
+                        // The page is on the swap file as well.
+                        let munlocked_pages = region
+                            .file
+                            .erase_from_disk(idx_range)
+                            .context("clear a page in swap file")?;
+                        if munlocked_pages != 0 {
+                            // Only one of swap-in or trimming runs at a time. This is not an
+                            // expected path. Just log an error because leaking
+                            // mlock_budget_pages is not fatal.
+                            error!("pages are mlock(2)ed while trimming");
+                        }
+                    }
+                    n_trimmed += 1;
+                    self.zero_pages += 1;
+                } else if self.cur_page >= self.next_data_in_file.start {
+                    // The previous content of the page is on the disk.
+                    let slice_in_file = region
+                        .file
+                        .get_slice(idx_range.clone())
+                        .context("get slice in swap file")?;
+
+                    if slice_in_staging == slice_in_file {
+                        region
+                            .staging_memory
+                            .clear_range(idx_range.clone())
+                            .context("clear a page in staging memory")?;
+                        region.file.mark_as_present(self.cur_page);
+                        n_trimmed += 1;
+                        self.clean_pages += 1;
+                    }
+                }
+            }
+
+            self.cur_page += 1;
+            if self.cur_page >= region.file.num_pages() {
+                self.cur_region += 1;
+                self.cur_page = 0;
+                self.next_data_in_file = 0..0;
+                break;
+            }
+        }
+
+        Ok(Some(n_trimmed))
+    }
+
+    /// Total trimmed clean pages.
+    pub fn trimmed_clean_pages(&self) -> usize {
+        self.clean_pages
+    }
+
+    /// Total trimmed zero pages.
+    pub fn trimmed_zero_pages(&self) -> usize {
+        self.zero_pages
+    }
+}
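Per page, `trim_pages` makes a three-way decision: all-zero pages are dropped outright (a later fault is resolved with `UFFD_ZERO`), pages byte-identical to their swapped-out copy are dropped and re-marked present in the file, and everything else stays dirty in the staging memory. A sketch of that decision with the two memories reduced to byte slices (hypothetical types, illustration only):

// The three outcomes of inspecting one staging-memory page.
enum TrimDecision {
    DropAsZero,  // page is all zero; a refault will be UFFD_ZEROed
    DropAsClean, // identical to the swapped-out copy; reuse the file content
    KeepDirty,   // must be written out on the next swap-out
}

fn decide(page_in_staging: &[u8], page_in_file: Option<&[u8]>) -> TrimDecision {
    if page_in_staging.iter().all(|&b| b == 0) {
        TrimDecision::DropAsZero
    } else if page_in_file == Some(page_in_staging) {
        TrimDecision::DropAsClean
    } else {
        TrimDecision::KeepDirty
    }
}

fn main() {
    let zero = [0u8; 4];
    let a = [1u8; 4];
    let b = [2u8; 4];
    assert!(matches!(decide(&zero, None), TrimDecision::DropAsZero));
    assert!(matches!(decide(&a, Some(&a[..])), TrimDecision::DropAsClean));
    assert!(matches!(decide(&a, Some(&b[..])), TrimDecision::KeepDirty));
}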
diff --git a/swap/tests/page_handler.rs b/swap/tests/page_handler.rs
index a3389a0a55..09e9a4d9a9 100644
--- a/swap/tests/page_handler.rs
+++ b/swap/tests/page_handler.rs
@@ -544,7 +544,7 @@ fn move_to_staging_invalid_base_addr() {
     worker.close();
 }
 
-fn swap_out_all(page_handler: &mut PageHandler) {
+fn swap_out_all(page_handler: &PageHandler) {
     while page_handler.swap_out(1024 * 1024).unwrap() != 0 {}
 }
 
@@ -570,7 +570,7 @@ fn swap_out_success() {
         base_addr1..(base_addr1 + 3 * pagesize()),
         base_addr2..(base_addr2 + 3 * pagesize()),
     ];
-    let mut page_handler =
+    let page_handler =
         PageHandler::create(&file, &staging_shmem, &regions, worker.channel.clone()).unwrap();
     // write data before registering to userfaultfd
     unsafe {
@@ -590,7 +590,7 @@ fn swap_out_success() {
             .unwrap();
     }
     worker.channel.wait_complete();
-    swap_out_all(&mut page_handler);
+    swap_out_all(&page_handler);
     // page faults on all pages. page 0 and page 2 will be swapped in from the file. page 1 will
     // be filled with zero.
     for i in 0..3 {
@@ -647,7 +647,7 @@ fn swap_out_handled_page() {
     let base_addr1 = mmap1.as_ptr() as usize;
 
     let regions = [base_addr1..(base_addr1 + 3 * pagesize())];
-    let mut page_handler =
+    let page_handler =
         PageHandler::create(&file, &staging_shmem, &regions, worker.channel.clone()).unwrap();
     // write data before registering to userfaultfd
     unsafe {
@@ -665,7 +665,7 @@ fn swap_out_handled_page() {
     page_handler
         .handle_page_fault(&uffd, base_addr1 + pagesize())
         .unwrap();
-    swap_out_all(&mut page_handler);
+    swap_out_all(&page_handler);
 
     // read values on another thread to avoid blocking forever
     let join_handle = thread::spawn(move || {
@@ -708,7 +708,7 @@ fn swap_out_twice() {
         base_addr1..(base_addr1 + 3 * pagesize()),
         base_addr2..(base_addr2 + 3 * pagesize()),
     ];
-    let mut page_handler =
+    let page_handler =
         PageHandler::create(&file, &staging_shmem, &regions, worker.channel.clone()).unwrap();
     unsafe {
         for i in 0..pagesize() {
@@ -727,7 +727,7 @@ fn swap_out_twice() {
             .unwrap();
     }
     worker.channel.wait_complete();
-    swap_out_all(&mut page_handler);
+    swap_out_all(&page_handler);
     // page faults on all pages in mmap1.
     for i in 0..3 {
         page_handler
@@ -757,7 +757,7 @@ fn swap_out_twice() {
             .unwrap();
     }
     worker.channel.wait_complete();
-    swap_out_all(&mut page_handler);
+    swap_out_all(&page_handler);
 
     // page faults on all pages.
     for i in 0..3 {
@@ -821,7 +821,7 @@ fn swap_in_success() {
         base_addr1..(base_addr1 + 3 * pagesize()),
         base_addr2..(base_addr2 + 3 * pagesize()),
     ];
-    let mut page_handler =
+    let page_handler =
         PageHandler::create(&file, &staging_shmem, &regions, worker.channel.clone()).unwrap();
     unsafe {
         for i in base_addr1 + pagesize()..base_addr1 + 2 * pagesize() {
@@ -843,7 +843,7 @@ fn swap_in_success() {
             .unwrap();
     }
     worker.channel.wait_complete();
-    swap_out_all(&mut page_handler);
+    swap_out_all(&page_handler);
     page_handler
         .handle_page_fault(&uffd, base_addr1 + pagesize())
         .unwrap();
@@ -896,3 +896,134 @@ fn swap_in_success() {
     }
     worker.close();
 }
+
+#[test]
+fn trim_success() {
+    let worker = Worker::new(2, 2);
+    let uffd = create_uffd_for_test();
+    let file = tempfile::tempfile().unwrap();
+    let staging_shmem = SharedMemory::new("test staging memory", 6 * pagesize() as u64).unwrap();
+    let shm = SharedMemory::new("shm", 6 * pagesize() as u64).unwrap();
+    let mmap1 = MemoryMappingBuilder::new(3 * pagesize())
+        .from_shared_memory(&shm)
+        .build()
+        .unwrap();
+    let mmap2 = MemoryMappingBuilder::new(3 * pagesize())
+        .from_shared_memory(&shm)
+        .offset(3 * pagesize() as u64)
+        .build()
+        .unwrap();
+    let base_addr1 = mmap1.as_ptr() as usize;
+    let base_addr2 = mmap2.as_ptr() as usize;
+    let regions = [
+        base_addr1..(base_addr1 + 3 * pagesize()),
+        base_addr2..(base_addr2 + 3 * pagesize()),
+    ];
+    let page_handler =
+        PageHandler::create(&file, &staging_shmem, &regions, worker.channel.clone()).unwrap();
+    unsafe {
+        for i in base_addr1..base_addr1 + pagesize() {
+            *(i as *mut u8) = 0;
+        }
+        for i in base_addr1 + pagesize()..base_addr1 + 2 * pagesize() {
+            *(i as *mut u8) = 1;
+        }
+        for i in base_addr2..base_addr2 + pagesize() {
+            *(i as *mut u8) = 0;
+        }
+        for i in base_addr2 + pagesize()..base_addr2 + 2 * pagesize() {
+            *(i as *mut u8) = 2;
+        }
+        for i in base_addr2 + 2 * pagesize()..base_addr2 + 3 * pagesize() {
+            *(i as *mut u8) = 3;
+        }
+    }
+    unsafe { register_regions(&regions, array::from_ref(&uffd)) }.unwrap();
+
+    unsafe {
+        page_handler.move_to_staging(base_addr1, &shm, 0).unwrap();
+        page_handler
+            .move_to_staging(base_addr2, &shm, 3 * pagesize() as u64)
+            .unwrap();
+    }
+    worker.channel.wait_complete();
+
+    let mut trim_ctx = page_handler.start_trim();
+
+    assert_eq!(trim_ctx.trim_pages(6 * pagesize()).unwrap().unwrap(), 1);
+    assert_eq!(trim_ctx.trimmed_clean_pages(), 0);
+    assert_eq!(trim_ctx.trimmed_zero_pages(), 1);
+    // 1 zero page
+    assert_eq!(trim_ctx.trim_pages(6 * pagesize()).unwrap().unwrap(), 1);
+    assert_eq!(trim_ctx.trimmed_clean_pages(), 0);
+    assert_eq!(trim_ctx.trimmed_zero_pages(), 2);
+
+    swap_out_all(&page_handler);
+    for i in 0..3 {
+        page_handler
+            .handle_page_fault(&uffd, base_addr1 + i * pagesize())
+            .unwrap();
+        page_handler
+            .handle_page_fault(&uffd, base_addr2 + i * pagesize())
+            .unwrap();
+    }
+    unsafe {
+        for i in base_addr2 + pagesize()..base_addr2 + 2 * pagesize() {
+            *(i as *mut u8) = 4;
+        }
+    }
+
+    // move to staging memory.
+    unsafe {
+        page_handler.move_to_staging(base_addr1, &shm, 0).unwrap();
+        page_handler
+            .move_to_staging(base_addr2, &shm, 3 * pagesize() as u64)
+            .unwrap();
+    }
+    worker.channel.wait_complete();
+
+    let mut trim_ctx = page_handler.start_trim();
+    // 2 zero pages and 1 clean page
+    assert_eq!(trim_ctx.trim_pages(6 * pagesize()).unwrap().unwrap(), 3);
+    assert_eq!(trim_ctx.trimmed_clean_pages(), 1);
+    assert_eq!(trim_ctx.trimmed_zero_pages(), 2);
+    // 1 zero page and 1 clean page
+    assert_eq!(trim_ctx.trim_pages(6 * pagesize()).unwrap().unwrap(), 2);
+    assert_eq!(trim_ctx.trimmed_clean_pages(), 2);
+    assert_eq!(trim_ctx.trimmed_zero_pages(), 3);
+    assert!(trim_ctx.trim_pages(pagesize()).unwrap().is_none());
+
+    let mut swap_in_ctx = page_handler.start_swap_in();
+    while swap_in_ctx.swap_in(&uffd, 1024 * 1024).unwrap() != 0 {}
+    unregister_regions(&regions, array::from_ref(&uffd)).unwrap();
+
+    // read values on another thread to avoid blocking forever
+    let join_handle = thread::spawn(move || {
+        let mut result = Vec::new();
+        for i in 0..3 {
+            for j in 0..pagesize() {
+                let ptr = (base_addr1 + i * pagesize() + j) as *mut u8;
+                unsafe {
+                    result.push(*ptr);
+                }
+            }
+        }
+        for i in 0..3 {
+            for j in 0..pagesize() {
+                let ptr = (base_addr2 + i * pagesize() + j) as *mut u8;
+                unsafe {
+                    result.push(*ptr);
+                }
+            }
+        }
+        result
+    });
+    let result = wait_thread_with_timeout(join_handle, 100);
+    let values: Vec<u8> = vec![0, 1, 0, 0, 4, 3];
+    for (i, v) in values.iter().enumerate() {
+        for j in 0..pagesize() {
+            assert_eq!(&result[i * pagesize() + j], v);
+        }
+    }
+    worker.close();
+}
diff --git a/vm_control/src/lib.rs b/vm_control/src/lib.rs
index afe3d4ee87..1925e179c9 100644
--- a/vm_control/src/lib.rs
+++ b/vm_control/src/lib.rs
@@ -1066,6 +1066,7 @@ pub enum PvClockCommandResponse {
 #[derive(Serialize, Deserialize, Debug)]
 pub enum SwapCommand {
     Enable,
+    Trim,
     SwapOut,
     Disable,
     Status,
@@ -1388,6 +1389,19 @@ impl VmRequest {
                 }
                 VmResponse::Err(SysError::new(ENOTSUP))
             }
+            VmRequest::Swap(SwapCommand::Trim) => {
+                #[cfg(feature = "swap")]
+                if let Some(swap_controller) = swap_controller {
+                    return match swap_controller.trim() {
+                        Ok(()) => VmResponse::Ok,
+                        Err(e) => {
+                            error!("swap trim failed: {}", e);
+                            VmResponse::Err(SysError::new(EINVAL))
+                        }
+                    };
+                }
+                VmResponse::Err(SysError::new(ENOTSUP))
+            }
             VmRequest::Swap(SwapCommand::SwapOut) => {
                 #[cfg(feature = "swap")]
                 if let Some(swap_controller) = swap_controller {
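From the client's point of view the new request has three outcomes: `VmResponse::Ok`, `EINVAL` when the controller rejects the trim, and `ENOTSUP` when crosvm was built without the `swap` feature or no controller is attached. A hedged sketch of how a caller might interpret that (hypothetical helper, not crosvm API):

use std::io;

fn interpret_trim_response(raw_errno: Option<i32>) -> io::Result<()> {
    match raw_errno {
        // VmResponse::Ok
        None => Ok(()),
        // Built without the "swap" feature, or the VM has no swap controller.
        Some(libc::ENOTSUP) => Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "vmm-swap is not enabled for this VM",
        )),
        // EINVAL and anything else map straight to an OS error.
        Some(errno) => Err(io::Error::from_raw_os_error(errno)),
    }
}

fn main() {
    assert!(interpret_trim_response(None).is_ok());
    assert!(interpret_trim_response(Some(libc::EINVAL)).is_err());
}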