diff --git a/lib/src/repo.rs b/lib/src/repo.rs index 9441aa778..c64acc1b0 100644 --- a/lib/src/repo.rs +++ b/lib/src/repo.rs @@ -12,11 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::cell::{Cell, RefCell}; use std::collections::{HashMap, HashSet}; use std::fmt::{Debug, Formatter}; use std::io::ErrorKind; -use std::ops::Deref; use std::path::{Path, PathBuf}; use std::sync::Arc; use std::{fs, io}; @@ -25,6 +23,7 @@ use itertools::Itertools; use once_cell::sync::OnceCell; use thiserror::Error; +use self::dirty_cell::DirtyCell; use crate::backend::{Backend, BackendError, ChangeId, CommitId}; use crate::commit::Commit; use crate::commit_builder::CommitBuilder; @@ -444,8 +443,7 @@ impl RepoLoader { pub struct MutableRepo { base_repo: Arc, index: MutableIndex, - view: RefCell, - view_dirty: Cell, + view: DirtyCell, rewritten_commits: HashMap>, abandoned_commits: HashSet, } @@ -461,8 +459,7 @@ impl MutableRepo { MutableRepo { base_repo, index: mut_index, - view: RefCell::new(mut_view), - view_dirty: Cell::new(false), + view: DirtyCell::with_clean(mut_view), rewritten_commits: Default::default(), abandoned_commits: Default::default(), } @@ -489,14 +486,8 @@ impl MutableRepo { } pub fn view(&self) -> &View { - // SAFETY: We don't really rely on runtime Ref/RefMut tracking. Since view_dirty - // is set only by &mut self functions, and view() is the solo function which - // publicly provides view-related reference, there should be no view reference - // alive if view_dirty == true. - self.enforce_view_invariants(); - let view_borrow = self.view.borrow(); - let view = view_borrow.deref(); - unsafe { std::mem::transmute(view) } + self.view + .get_or_ensure_clean(|v| self.enforce_view_invariants(v)) } fn view_mut(&mut self) -> &mut View { @@ -504,12 +495,11 @@ impl MutableRepo { } pub fn has_changes(&self) -> bool { - self.enforce_view_invariants(); - self.view.borrow().deref() != &self.base_repo.view + self.view() != &self.base_repo.view } pub fn consume(self) -> (MutableIndex, View) { - self.enforce_view_invariants(); + self.view.ensure_clean(|v| self.enforce_view_invariants(v)); (self.index, self.view.into_inner()) } @@ -624,7 +614,9 @@ impl MutableRepo { } fn leave_commit(&mut self, workspace_id: &WorkspaceId) { - let maybe_wc_commit_id = self.view.borrow().get_wc_commit_id(workspace_id).cloned(); + let maybe_wc_commit_id = self + .view + .with_ref(|v| v.get_wc_commit_id(workspace_id).cloned()); if let Some(wc_commit_id) = maybe_wc_commit_id { let wc_commit = self.store().get_commit(&wc_commit_id).unwrap(); if wc_commit.is_empty() @@ -637,12 +629,8 @@ impl MutableRepo { } } - fn enforce_view_invariants(&self) { - if !self.view_dirty.get() { - return; - } - let mut view_borrow_mut = self.view.borrow_mut(); - let view = view_borrow_mut.store_view_mut(); + fn enforce_view_invariants(&self, view: &mut View) { + let view = view.store_view_mut(); view.public_head_ids = self .index .heads(view.public_head_ids.iter()) @@ -656,7 +644,6 @@ impl MutableRepo { .iter() .cloned() .collect(); - self.view_dirty.set(false); } pub fn add_head(&mut self, head: &Commit) { @@ -690,27 +677,27 @@ impl MutableRepo { self.index.add_commit(missing_commit); } self.view.get_mut().add_head(head.id()); - *self.view_dirty.get_mut() = true; + self.view.mark_dirty(); } } pub fn remove_head(&mut self, head: &CommitId) { self.view_mut().remove_head(head); - *self.view_dirty.get_mut() = true; + self.view.mark_dirty(); } pub fn add_public_head(&mut self, head: &Commit) { self.view_mut().add_public_head(head.id()); - *self.view_dirty.get_mut() = true; + self.view.mark_dirty(); } pub fn remove_public_head(&mut self, head: &CommitId) { self.view_mut().remove_public_head(head); - *self.view_dirty.get_mut() = true; + self.view.mark_dirty(); } pub fn get_branch(&self, name: &str) -> Option { - self.view.borrow().get_branch(name).cloned() + self.view.with_ref(|v| v.get_branch(name).cloned()) } pub fn set_branch(&mut self, name: String, target: BranchTarget) { @@ -722,7 +709,7 @@ impl MutableRepo { } pub fn get_local_branch(&self, name: &str) -> Option { - self.view.borrow().get_local_branch(name) + self.view.with_ref(|v| v.get_local_branch(name)) } pub fn set_local_branch(&mut self, name: String, target: RefTarget) { @@ -734,7 +721,8 @@ impl MutableRepo { } pub fn get_remote_branch(&self, name: &str, remote_name: &str) -> Option { - self.view.borrow().get_remote_branch(name, remote_name) + self.view + .with_ref(|v| v.get_remote_branch(name, remote_name)) } pub fn set_remote_branch(&mut self, name: String, remote_name: String, target: RefTarget) { @@ -750,7 +738,7 @@ impl MutableRepo { } pub fn get_tag(&self, name: &str) -> Option { - self.view.borrow().get_tag(name) + self.view.with_ref(|v| v.get_tag(name)) } pub fn set_tag(&mut self, name: String, target: RefTarget) { @@ -779,7 +767,7 @@ impl MutableRepo { pub fn set_view(&mut self, data: op_store::View) { self.view_mut().set_view(data); - *self.view_dirty.get_mut() = true; + self.view.mark_dirty(); } pub fn merge(&mut self, base_repo: &ReadonlyRepo, other_repo: &ReadonlyRepo) { @@ -790,9 +778,9 @@ impl MutableRepo { self.index.merge_in(base_repo.index()); self.index.merge_in(other_repo.index()); - self.enforce_view_invariants(); + self.view.ensure_clean(|v| self.enforce_view_invariants(v)); self.merge_view(&base_repo.view, &other_repo.view); - *self.view_dirty.get_mut() = true; + self.view.mark_dirty(); } fn merge_view(&mut self, base: &View, other: &View) { @@ -982,3 +970,58 @@ impl IoResultExt for io::Result { }) } } + +mod dirty_cell { + use std::cell::{Cell, RefCell}; + use std::ops::Deref; + + /// Cell that lazily updates the value after `mark_dirty()`. + #[derive(Clone, Debug)] + pub struct DirtyCell { + value: RefCell, + dirty: Cell, + } + + impl DirtyCell { + pub fn with_clean(value: T) -> Self { + DirtyCell { + value: RefCell::new(value), + dirty: Cell::new(false), + } + } + + pub fn get_or_ensure_clean(&self, f: impl FnOnce(&mut T)) -> &T { + // SAFETY: get_mut/mark_dirty(&mut self) should invalidate any previously-clean + // references leaked by this method. Clean value never changes until then. + self.ensure_clean(f); + let borrow = self.value.borrow(); + let temp_ref = borrow.deref(); + unsafe { std::mem::transmute(temp_ref) } + } + + pub fn ensure_clean(&self, f: impl FnOnce(&mut T)) { + if self.dirty.get() { + // This borrow_mut() ensures that there is no dirty temporary reference. + // Panics if ensure_clean() is invoked from with_ref() callback for example. + f(&mut self.value.borrow_mut()); + self.dirty.set(false); + } + } + + pub fn into_inner(self) -> T { + self.value.into_inner() + } + + pub fn with_ref(&self, f: impl FnOnce(&T) -> R) -> R { + f(&self.value.borrow()) + } + + pub fn get_mut(&mut self) -> &mut T { + self.value.get_mut() + } + + pub fn mark_dirty(&mut self) { + *self.dirty.get_mut() = true; + } + } +}