diff --git a/src/runtime.rs b/src/runtime.rs index 3d27a712..b3eb74cf 100644 --- a/src/runtime.rs +++ b/src/runtime.rs @@ -8,6 +8,7 @@ use rustc_hash::{FxHashMap, FxHasher}; use smallvec::SmallVec; use std::fmt::Write; use std::hash::BuildHasherDefault; +use std::num::NonZeroU64; use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; use std::sync::Arc; @@ -154,17 +155,13 @@ where /// Read current value of the revision counter. #[inline] pub(crate) fn current_revision(&self) -> Revision { - Revision { - generation: self.shared_state.revision.load(Ordering::SeqCst), - } + Revision::from(self.shared_state.revision.load(Ordering::SeqCst)) } /// Read current value of the revision counter. #[inline] fn pending_revision(&self) -> Revision { - Revision { - generation: self.shared_state.pending_revision.load(Ordering::SeqCst), - } + Revision::from(self.shared_state.pending_revision.load(Ordering::SeqCst)) } /// Check if the current revision is canceled. If this method ever @@ -286,9 +283,7 @@ where let old_revision = self.shared_state.revision.fetch_add(1, Ordering::SeqCst); assert_eq!(current_revision, old_revision); - let new_revision = Revision { - generation: current_revision + 1, - }; + let new_revision = Revision::from(current_revision + 1); debug!("increment_revision: incremented to {:?}", new_revision); @@ -517,7 +512,7 @@ impl ActiveQuery { database_key, changed_at: ChangedAt { is_constant: true, - revision: Revision::START, + revision: Revision::start(), }, dependencies: Some(FxIndexSet::default()), } @@ -563,16 +558,22 @@ pub struct RuntimeId { /// directly as a user of salsa. #[derive(Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct Revision { - generation: u64, + generation: NonZeroU64, } impl Revision { - pub(crate) const START: Self = Revision { generation: 1 }; + fn start() -> Self { + Self::from(1) + } + + fn from(g: u64) -> Self { + Self { + generation: NonZeroU64::new(g).unwrap(), + } + } fn next(self) -> Revision { - Revision { - generation: self.generation + 1, - } + Self::from(self.generation.get() + 1) } }