feat: dag iter

This commit is contained in:
Zixuan Chen 2022-08-04 02:42:00 +08:00
parent 8a15d2e863
commit b8287837dc
5 changed files with 225 additions and 9 deletions

View file

@ -16,6 +16,7 @@ moveit = "0.5.0"
pin-project = "1.0.10"
serde = {version = "1.0.140", features = ["derive"]}
thiserror = "1.0.31"
im = "15.1.0"
[dev-dependencies]
proptest = "1.0.0"

View file

@ -4,6 +4,7 @@ use std::{
};
use fxhash::{FxHashMap, FxHashSet};
mod iter;
#[cfg(test)]
mod test;
@ -14,6 +15,8 @@ use crate::{
version::VersionVector,
};
use self::iter::{iter_dag, DagIterator};
pub trait DagNode {
fn dag_id_start(&self) -> ID;
fn lamport_start(&self) -> Lamport;
@ -64,13 +67,21 @@ fn reverse_path(path: &mut Vec<IdSpan>) {
}
}
/// We have following invariance in DAG
/// - All deps' lamports are smaller than current node's lamport
pub(crate) trait Dag {
type Node: DagNode;
fn get(&self, id: ID) -> Option<&Self::Node>;
fn contains(&self, id: ID) -> bool;
#[inline]
fn contains(&self, id: ID) -> bool {
self.vv().includes(id)
}
fn frontier(&self) -> &[ID];
fn roots(&self) -> Vec<&Self::Node>;
fn vv(&self) -> VersionVector;
//
// TODO: Maybe use Result return type
@ -162,6 +173,13 @@ pub(crate) trait Dag {
ans
}
fn iter(&self) -> DagIterator<'_, Self::Node>
where
Self: Sized,
{
iter_dag(self)
}
}
fn get_version_vector<'a, Get>(get: &'a Get, id: ID) -> VersionVector

View file

@ -0,0 +1,112 @@
use super::*;
#[derive(Debug, Clone, PartialEq, Eq)]
struct IdHeapItem {
id: ID,
lamport: Lamport,
}
impl PartialOrd for IdHeapItem {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
Some(self.lamport.cmp(&other.lamport).reverse())
}
}
impl Ord for IdHeapItem {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.lamport.cmp(&other.lamport).reverse()
}
}
pub(crate) fn iter_dag<T>(dag: &dyn Dag<Node = T>) -> DagIterator<'_, T> {
DagIterator {
dag,
vv_map: Default::default(),
visited: VersionVector::new(),
heap: BinaryHeap::new(),
}
}
pub(crate) struct DagIterator<'a, T> {
dag: &'a dyn Dag<Node = T>,
/// we should keep every nodes starting id inside this map
vv_map: FxHashMap<ID, VersionVector>,
/// Because all deps' lamports are smaller than current node's lamport.
/// We can use the lamport to sort the nodes so that each node's deps are processed before itself.
///
/// The ids in this heap are start ids of nodes. It won't be a id pointing to the middle of a node.
heap: BinaryHeap<IdHeapItem>,
visited: VersionVector,
}
// TODO: Need benchmark on memory
impl<'a, T: DagNode> Iterator for DagIterator<'a, T> {
type Item = (&'a T, VersionVector);
fn next(&mut self) -> Option<Self::Item> {
if self.vv_map.is_empty() {
if self.dag.vv().len() == 0 {
return None;
}
for (&client_id, _) in self.dag.vv().iter() {
let vv = VersionVector::new();
if let Some(node) = self.dag.get(ID::new(client_id, 0)) {
if node.lamport_start() == 0 {
self.vv_map.insert(ID::new(client_id, 0), vv.clone());
}
self.heap.push(IdHeapItem {
id: ID::new(client_id, 0),
lamport: node.lamport_start(),
});
}
self.visited.insert(client_id, 0);
}
}
if !self.heap.is_empty() {
let item = self.heap.pop().unwrap();
let id = item.id;
let node = self.dag.get(id).unwrap();
debug_assert_eq!(id, node.dag_id_start());
let mut vv = {
// calculate vv
let mut vv = None;
for &dep_id in node.deps() {
let dep = self.dag.get(dep_id).unwrap();
let dep_vv = self.vv_map.get(&dep.dag_id_start()).unwrap();
if vv.is_none() {
vv = Some(dep_vv.clone());
} else {
vv.as_mut().unwrap().merge(dep_vv);
}
if dep.dag_id_start() != dep_id {
vv.as_mut().unwrap().set_end(dep_id);
}
}
vv.unwrap_or_else(VersionVector::new)
};
vv.try_update_end(id);
self.vv_map.insert(node.dag_id_start(), vv.clone());
// push next node from the same client to the heap
let next_id = id.inc(node.len() as i32);
if self.dag.contains(next_id) {
let next_node = self.dag.get(next_id).unwrap();
self.heap.push(IdHeapItem {
id: next_id,
lamport: next_node.lamport_start(),
});
}
return Some((node, vv));
}
None
}
}

View file

@ -49,7 +49,7 @@ impl DagNode for TestNode {
struct TestDag {
nodes: FxHashMap<ClientID, Vec<TestNode>>,
frontier: Vec<ID>,
version_vec: FxHashMap<ClientID, Counter>,
version_vec: VersionVector,
next_lamport: Lamport,
client_id: ClientID,
}
@ -77,6 +77,10 @@ impl Dag for TestDag {
.and_then(|x| if *x > id.counter { Some(()) } else { None })
.is_some()
}
fn vv(&self) -> VersionVector {
self.version_vec.clone()
}
}
impl TestDag {
@ -84,7 +88,7 @@ impl TestDag {
Self {
nodes: FxHashMap::default(),
frontier: Vec::new(),
version_vec: FxHashMap::default(),
version_vec: VersionVector::new(),
next_lamport: 0,
client_id,
}
@ -216,6 +220,41 @@ struct Interaction {
len: usize,
}
mod iter {
use super::*;
#[test]
fn test() {
let mut a = TestDag::new(0);
let mut b = TestDag::new(1);
// 0-0
a.push(1);
// 1-0
b.push(1);
a.merge(&b);
// 0-1
a.push(1);
b.merge(&a);
// 1-1
b.push(1);
a.merge(&b);
// 0-2
a.push(1);
let mut count = 0;
for (node, vv) in a.iter() {
count += 1;
if node.id == ID::new(0, 0) {
assert_eq!(vv, vec![ID::new(0, 0)].into());
} else if node.id == ID::new(0, 2) {
assert_eq!(vv, vec![ID::new(0, 2), ID::new(1, 1)].into());
}
}
assert_eq!(count, 5);
}
}
mod get_version_vector {
use super::*;

View file

@ -4,7 +4,8 @@ use std::{
ops::{Deref, DerefMut},
};
use fxhash::FxHashMap;
use fxhash::{FxBuildHasher, FxHashMap};
use im::hashmap::HashMap as ImHashMap;
use crate::{
change::Lamport,
@ -13,12 +14,18 @@ use crate::{
ClientID,
};
/// It's a immutable hash map with O(1) clone. Because
/// - we want a cheap clone op on vv;
/// - neighbor op's VersionVectors are very similar, most of the memory can be shared in
/// immutable hashmap
///
/// see also [im].
#[repr(transparent)]
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct VersionVector(FxHashMap<ClientID, Counter>);
pub struct VersionVector(ImHashMap<ClientID, Counter>);
impl Deref for VersionVector {
type Target = FxHashMap<ClientID, Counter>;
type Target = ImHashMap<ClientID, Counter>;
fn deref(&self) -> &Self::Target {
&self.0
@ -76,7 +83,7 @@ impl DerefMut for VersionVector {
impl VersionVector {
#[inline]
pub fn new() -> Self {
Self(FxHashMap::default())
Self(ImHashMap::new())
}
#[inline]
@ -89,7 +96,7 @@ impl VersionVector {
#[inline]
pub fn try_update_end(&mut self, id: ID) -> bool {
if let Some(end) = self.0.get_mut(&id.client_id) {
if *end < id.counter {
if *end < id.counter + 1 {
*end = id.counter + 1;
true
} else {
@ -115,6 +122,27 @@ impl VersionVector {
ans
}
pub fn merge(&mut self, other: &Self) {
for (&client_id, &other_end) in other.iter() {
if let Some(my_end) = self.get_mut(&client_id) {
if *my_end < other_end {
*my_end = other_end;
}
} else {
self.0.insert(client_id, other_end);
}
}
}
pub fn includes(&mut self, id: ID) -> bool {
if let Some(end) = self.get_mut(&id.client_id) {
if *end > id.counter {
return true;
}
}
false
}
}
impl Default for VersionVector {
@ -125,7 +153,11 @@ impl Default for VersionVector {
impl From<FxHashMap<ClientID, Counter>> for VersionVector {
fn from(map: FxHashMap<ClientID, Counter>) -> Self {
Self(map)
let mut im_map = ImHashMap::new();
for (client_id, counter) in map {
im_map.insert(client_id, counter);
}
Self(im_map)
}
}
@ -170,4 +202,18 @@ mod tests {
assert_eq!(a.partial_cmp(&b), Some(Ordering::Less));
}
}
#[test]
fn im() {
let mut a = VersionVector::new();
a.set_end(ID::new(1, 1));
a.set_end(ID::new(2, 1));
let mut b = a.clone();
b.merge(&vec![ID::new(1, 2), ID::new(2, 2)].into());
assert!(a != b);
assert_eq!(a.get(&1), Some(&2));
assert_eq!(a.get(&2), Some(&2));
assert_eq!(b.get(&1), Some(&3));
assert_eq!(b.get(&2), Some(&3));
}
}