Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: amt: implement external iteration for ranged iteration
Browse files Browse the repository at this point in the history
part of #1861
Stebalien committed Jan 22, 2024
1 parent 40fd112 commit 1737133
Showing 4 changed files with 124 additions and 127 deletions.
50 changes: 19 additions & 31 deletions ipld/amt/src/amt.rs
Original file line number Diff line number Diff line change
@@ -508,6 +508,7 @@ where
/// assert_eq!(num_traversed, 3);
/// assert_eq!(next_idx, Some(10));
/// ```
#[deprecated = "use `.iter_from()` and `.take(limit)` instead"]
pub fn for_each_ranged<F>(
&self,
start_at: Option<u64>,
@@ -517,25 +518,16 @@ where
where
F: FnMut(u64, &V) -> anyhow::Result<()>,
{
if let Some(start_at) = start_at {
if start_at >= nodes_for_height(self.bit_width(), self.height() + 1) {
return Ok((0, None));
let mut num_traversed = 0;
for kv in self.iter_from(start_at.unwrap_or(0))? {
let (k, v) = kv?;
if limit.map(|l| num_traversed >= l).unwrap_or(false) {
return Ok((num_traversed, Some(k)));
}
num_traversed += 1;
f(k, v)?;
}

let (_, num_traversed, next_index) = self.root.node.for_each_while_ranged(
&self.block_store,
start_at,
limit,
self.height(),
self.bit_width(),
0,
&mut |i, v| {
f(i, v)?;
Ok(true)
},
)?;
Ok((num_traversed, next_index))
Ok((num_traversed, None))
}

/// Iterates over values in the Amt and runs a function on the values, for as long as that
@@ -547,6 +539,7 @@ where
/// `limit` elements have been traversed. Returns a tuple describing the number of elements
/// iterated over and optionally the index of the next element in the AMT if more elements
/// remain.
#[deprecated = "use `.iter_from()` and `.take(limit)` instead"]
pub fn for_each_while_ranged<F>(
&self,
start_at: Option<u64>,
@@ -556,22 +549,17 @@ where
where
F: FnMut(u64, &V) -> anyhow::Result<bool>,
{
if let Some(start_at) = start_at {
if start_at >= nodes_for_height(self.bit_width(), self.height() + 1) {
return Ok((0, None));
let mut num_traversed = 0;
let mut keep_going = true;
for kv in self.iter_from(start_at.unwrap_or(0))? {
let (k, v) = kv?;
if !keep_going || limit.map(|l| num_traversed >= l).unwrap_or(false) {
return Ok((num_traversed, Some(k)));
}
num_traversed += 1;
keep_going = f(k, v)?;
}

let (_, num_traversed, next_index) = self.root.node.for_each_while_ranged(
&self.block_store,
start_at,
limit,
self.height(),
self.bit_width(),
0,
&mut f,
)?;
Ok((num_traversed, next_index))
Ok((num_traversed, None))
}

/// Iterates over each value in the Amt and runs a function on the values that allows modifying
98 changes: 98 additions & 0 deletions ipld/amt/src/iter.rs
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@ use crate::node::CollapsedNode;
use crate::node::{Link, Node};
use crate::MAX_INDEX;
use crate::{nodes_for_height, Error};
use anyhow::anyhow;
use fvm_ipld_blockstore::Blockstore;
use fvm_ipld_encoding::ser::Serialize;
use fvm_ipld_encoding::CborStore;
@@ -53,6 +54,103 @@ where
key: 0,
}
}

/// Iterate over the AMT from the given starting point.
///
/// ```rust
/// use fvm_ipld_amt::Amt;
/// use fvm_ipld_blockstore::MemoryBlockstore;
///
/// let store = MemoryBlockstore::default();
///
/// let mut amt = Amt::new(store);
/// let kvs: Vec<u64> = (0..=5).collect();
/// kvs
/// .iter()
/// .map(|k| amt.set(u64::try_from(*k).unwrap(), k.to_string()))
/// .collect::<Vec<_>>();
///
/// for kv in amt.iter_from(3) {
/// let (k, v) = kv?;
/// println!("{k:?}: {v:?}");
/// }
///
/// # anyhow::Ok(())
/// ```
pub fn iter_from(&self, start: u64) -> Result<Iter<'_, V, &BS, Ver>, Error> {
// Short-circuit when we're starting at 0.
if start == 0 {
return Ok(self.iter());
}

let height = self.height();
let bit_width = self.bit_width();

if start >= nodes_for_height(bit_width, height + 1) {
return Ok(Iter {
height,
bit_width,
stack: Vec::new(),
blockstore: &self.block_store,
ver: PhantomData,
key: start,
});
}

let mut stack = vec![];
let mut node = &self.root.node;
let mut offset = 0;
loop {
let start_idx = start.saturating_sub(offset);
match node {
Node::Leaf { vals } => {
if start_idx >= vals.len() as u64 {
// Not deep enough.
return Err(anyhow!("incorrect height for tree depth: expected values at depth {}, found them at {}", height, stack.len()).into());
}
stack.push(IterStack {
node,
idx: start_idx as usize,
});
break;
}
Node::Link { links } => {
let nfh =
nodes_for_height(self.bit_width(), self.height() - stack.len() as u32);
let idx: usize = (start_idx / nfh).try_into().expect("index overflow");
assert!(idx < links.len(), "miscalculated nodes for height");
let Some(l) = &links[idx] else {
// If there's nothing here, mark this as the starting point. We'll start
// scanning here when we iterate.
stack.push(IterStack { node, idx });
break;
};
let sub = match l {
Link::Dirty(sub) => sub,
Link::Cid { cid, cache } => cache.get_or_try_init(|| {
self.block_store
.get_cbor::<CollapsedNode<V>>(cid)?
.ok_or_else(|| Error::CidNotFound(cid.to_string()))?
.expand(self.bit_width())
.map(Box::new)
})?,
};
// Push idx+1 because we've already processed this node.
stack.push(IterStack { node, idx: idx + 1 });
node = sub;
offset += idx as u64 * nfh;
}
}
}
Ok(Iter {
stack,
height,
bit_width,
blockstore: &self.block_store,
ver: PhantomData,
key: start,
})
}
}

impl<'a, V, BS, Ver> IntoIterator for &'a crate::AmtImpl<V, BS, Ver>
96 changes: 0 additions & 96 deletions ipld/amt/src/node.rs
Original file line number Diff line number Diff line change
@@ -552,102 +552,6 @@ where

Ok((true, did_mutate))
}

/// Iterates through the current node in the tree and all subtrees. `start_at` refers to the
/// global AMT index, before which no values should be traversed and `limit` is the maximum
/// number of leaf nodes that should be traversed in this subtree. `offset` refers the offset
/// in the global AMT address space that this subtree is rooted at.
#[allow(clippy::too_many_arguments)]
pub(super) fn for_each_while_ranged<S, F>(
&self,
bs: &S,
start_at: Option<u64>,
limit: Option<u64>,
height: u32,
bit_width: u32,
offset: u64,
f: &mut F,
) -> Result<(bool, u64, Option<u64>), Error>
where
F: FnMut(u64, &V) -> anyhow::Result<bool>,
S: Blockstore,
{
let mut traversed_count = 0_u64;
match self {
Node::Leaf { vals } => {
let start_idx = start_at.map_or(0, |s| s.saturating_sub(offset));
if start_idx as usize >= vals.len() {
return Ok((false, 0, None));
}
let mut keep_going = true;
for (i, v) in (start_idx..).zip(vals[start_idx as usize..].iter()) {
let idx = offset + i;
if let Some(v) = v {
if limit.map_or(false, |l| traversed_count >= l) {
return Ok((keep_going, traversed_count, Some(idx)));
} else if !keep_going {
return Ok((false, traversed_count, Some(idx)));
}
keep_going = f(idx, v)?;
traversed_count += 1;
}
}
}
Node::Link { links } => {
let nfh = nodes_for_height(bit_width, height);
let idx: usize = ((start_at.map_or(0, |s| s.saturating_sub(offset))) / nfh)
.try_into()
.expect("index overflow");
if idx >= links.len() {
return Ok((false, 0, None));
}
for (i, link) in (idx..).zip(links[idx..].iter()) {
if let Some(l) = link {
let offs = offset + (i as u64 * nfh);
let (keep_going, count, next) = match l {
Link::Dirty(sub) => sub.for_each_while_ranged(
bs,
start_at,
limit.map(|l| l.checked_sub(traversed_count).unwrap()),
height - 1,
bit_width,
offs,
f,
)?,
Link::Cid { cid, cache } => {
let cached_node = cache.get_or_try_init(|| {
bs.get_cbor::<CollapsedNode<V>>(cid)?
.ok_or_else(|| Error::CidNotFound(cid.to_string()))?
.expand(bit_width)
.map(Box::new)
})?;

cached_node.for_each_while_ranged(
bs,
start_at,
limit.map(|l| l.checked_sub(traversed_count).unwrap()),
height - 1,
bit_width,
offs,
f,
)?
}
};

traversed_count += count;

if limit.map_or(false, |l| traversed_count >= l) && next.is_some() {
return Ok((keep_going, traversed_count, next));
} else if !keep_going {
return Ok((false, traversed_count, next));
}
}
}
}
};

Ok((true, traversed_count, None))
}
}

#[cfg(test)]
7 changes: 7 additions & 0 deletions ipld/amt/tests/amt_tests.rs
Original file line number Diff line number Diff line change
@@ -420,6 +420,7 @@ fn for_each_ranged() {
// Iterate over amt with dirty cache from different starting values
for start_val in 0..RANGE {
let mut retrieved_values = Vec::new();
#[allow(deprecated)]
let (count, next_key) = a
.for_each_while_ranged(Some(start_val), None, |index, _: &BytesDe| {
retrieved_values.push(index);
@@ -435,6 +436,7 @@ fn for_each_ranged() {

// Iterate out of bounds
for i in [RANGE, RANGE + 1, 2 * RANGE, 8 * RANGE] {
#[allow(deprecated)]
let (count, next_key) = a
.for_each_while_ranged(Some(i), None, |_, _: &BytesDe| {
panic!("didn't expect to iterate")
@@ -447,6 +449,7 @@ fn for_each_ranged() {
// Iterate over amt with dirty cache with different page sizes
for page_size in 1..=RANGE {
let mut retrieved_values = Vec::new();
#[allow(deprecated)]
let (count, next_key) = a
.for_each_while_ranged(None, Some(page_size), |index, _: &BytesDe| {
retrieved_values.push(index);
@@ -468,6 +471,7 @@ fn for_each_ranged() {
let mut retrieved_values = Vec::new();
let mut start_cursor = None;
loop {
#[allow(deprecated)]
let (num_traversed, next_cursor) = a
.for_each_while_ranged(start_cursor, Some(page_size), |idx, _val| {
retrieved_values.push(idx);
@@ -493,6 +497,7 @@ fn for_each_ranged() {
let mut retrieved_values = Vec::new();
let mut start_cursor = None;
loop {
#[allow(deprecated)]
let (num_traversed, next_cursor) = a
.for_each_ranged(start_cursor, Some(page_size), |idx, _val: &BytesDe| {
retrieved_values.push(idx);
@@ -518,6 +523,7 @@ fn for_each_ranged() {

// Iterate over the amt with dirty cache ignoring gaps in the address space including at the
// beginning of the amt, we should only see the values that were not deleted
#[allow(deprecated)]
let (num_traversed, next_cursor) = a
.for_each_while_ranged(Some(0), Some(501), |i, _v| {
assert_eq!((i / 10) % 2, 1); // only "odd" batches of ten 10 - 19, 30 - 39, etc. should be present
@@ -530,6 +536,7 @@ fn for_each_ranged() {
// flush the amt to the blockstore, reload and repeat the test with a clean cache
let cid = a.flush().unwrap();
let a = Amt::load(&cid, &db).unwrap();
#[allow(deprecated)]
let (num_traversed, next_cursor) = a
.for_each_while_ranged(Some(0), Some(501), |i, _v: &BytesDe| {
assert_eq!((i / 10) % 2, 1); // only "odd" batches of ten 10 - 19, 30 - 39, etc. should be present

0 comments on commit 1737133

Please sign in to comment.