diff --git a/ipld/amt/src/amt.rs b/ipld/amt/src/amt.rs index 0f430b3aa..2ddbce751 100644 --- a/ipld/amt/src/amt.rs +++ b/ipld/amt/src/amt.rs @@ -508,6 +508,7 @@ where /// assert_eq!(num_traversed, 3); /// assert_eq!(next_idx, Some(10)); /// ``` + #[deprecated = "use `.iter_from()` and `.take(limit)` instead"] pub fn for_each_ranged( &self, start_at: Option, @@ -517,25 +518,16 @@ where where F: FnMut(u64, &V) -> anyhow::Result<()>, { - if let Some(start_at) = start_at { - if start_at >= nodes_for_height(self.bit_width(), self.height() + 1) { - return Ok((0, None)); + let mut num_traversed = 0; + for kv in self.iter_from(start_at.unwrap_or(0))? { + let (k, v) = kv?; + if limit.map(|l| num_traversed >= l).unwrap_or(false) { + return Ok((num_traversed, Some(k))); } + num_traversed += 1; + f(k, v)?; } - - let (_, num_traversed, next_index) = self.root.node.for_each_while_ranged( - &self.block_store, - start_at, - limit, - self.height(), - self.bit_width(), - 0, - &mut |i, v| { - f(i, v)?; - Ok(true) - }, - )?; - Ok((num_traversed, next_index)) + Ok((num_traversed, None)) } /// Iterates over values in the Amt and runs a function on the values, for as long as that @@ -547,6 +539,7 @@ where /// `limit` elements have been traversed. Returns a tuple describing the number of elements /// iterated over and optionally the index of the next element in the AMT if more elements /// remain. + #[deprecated = "use `.iter_from()` and `.take(limit)` instead"] pub fn for_each_while_ranged( &self, start_at: Option, @@ -556,22 +549,17 @@ where where F: FnMut(u64, &V) -> anyhow::Result, { - if let Some(start_at) = start_at { - if start_at >= nodes_for_height(self.bit_width(), self.height() + 1) { - return Ok((0, None)); + let mut num_traversed = 0; + let mut keep_going = true; + for kv in self.iter_from(start_at.unwrap_or(0))? { + let (k, v) = kv?; + if !keep_going || limit.map(|l| num_traversed >= l).unwrap_or(false) { + return Ok((num_traversed, Some(k))); } + num_traversed += 1; + keep_going = f(k, v)?; } - - let (_, num_traversed, next_index) = self.root.node.for_each_while_ranged( - &self.block_store, - start_at, - limit, - self.height(), - self.bit_width(), - 0, - &mut f, - )?; - Ok((num_traversed, next_index)) + Ok((num_traversed, None)) } /// Iterates over each value in the Amt and runs a function on the values that allows modifying diff --git a/ipld/amt/src/diff.rs b/ipld/amt/src/diff.rs index 63e60cae0..291e3cc9e 100644 --- a/ipld/amt/src/diff.rs +++ b/ipld/amt/src/diff.rs @@ -9,7 +9,9 @@ use fvm_ipld_blockstore::Blockstore; use fvm_ipld_encoding::CborStore; use serde::{de::DeserializeOwned, Serialize}; +use crate::iter::Iter; use crate::node::{CollapsedNode, Link}; +use crate::root::version; use super::*; @@ -108,14 +110,14 @@ where Node::Leaf { vals } => vals.len(), Node::Link { links } => links.len(), }); - node.for_each_while(ctx.store, ctx.height, ctx.bit_width, offset, &mut |i, x| { + for kv in Iter::<_, _, version::V3>::new(node, ctx.store, ctx.height, ctx.bit_width, offset) { + let (k, v) = kv?; changes.push(Change { - key: i, + key: k, before: None, - after: Some(x.clone()), + after: Some(v.clone()), }); - Ok(true) - })?; + } Ok(changes) } @@ -133,14 +135,14 @@ where Node::Leaf { vals } => vals.len(), Node::Link { links } => links.len(), }); - node.for_each_while(ctx.store, ctx.height, ctx.bit_width, offset, &mut |i, x| { + for kv in Iter::<_, _, version::V3>::new(node, ctx.store, ctx.height, ctx.bit_width, offset) { + let (k, v) = kv?; changes.push(Change { - key: i, - before: Some(x.clone()), + key: k, + before: Some(v.clone()), after: None, }); - Ok(true) - })?; + } Ok(changes) } diff --git a/ipld/amt/src/iter.rs b/ipld/amt/src/iter.rs index 3a56b3e37..03c8cf17f 100644 --- a/ipld/amt/src/iter.rs +++ b/ipld/amt/src/iter.rs @@ -5,6 +5,7 @@ use crate::node::CollapsedNode; use crate::node::{Link, Node}; use crate::MAX_INDEX; use crate::{nodes_for_height, Error}; +use anyhow::anyhow; use fvm_ipld_blockstore::Blockstore; use fvm_ipld_encoding::ser::Serialize; use fvm_ipld_encoding::CborStore; @@ -30,7 +31,7 @@ where /// let kvs: Vec = (0..=5).collect(); /// kvs /// .iter() - /// .map(|k| amt.set(u64::try_from(*k).unwrap(), k.to_string())) + /// .map(|k| amt.set(*k, k.to_string())) /// .collect::>(); /// /// for kv in &amt { @@ -41,17 +42,111 @@ where /// # anyhow::Ok(()) /// ``` pub fn iter(&self) -> Iter<'_, V, &BS, Ver> { - Iter { - stack: vec![IterStack { - node: &self.root.node, - idx: 0, - }], - height: self.root.height, + Iter::new( + &self.root.node, + &self.block_store, + self.height(), + self.bit_width(), + 0, + ) + } + + /// Iterate over the AMT from the given starting point. + /// + /// ```rust + /// use fvm_ipld_amt::Amt; + /// use fvm_ipld_blockstore::MemoryBlockstore; + /// + /// let store = MemoryBlockstore::default(); + /// + /// let mut amt = Amt::new(store); + /// let kvs: Vec = (0..=5).collect(); + /// kvs + /// .iter() + /// .map(|k| amt.set(*k, k.to_string())) + /// .collect::>(); + /// + /// for kv in amt.iter_from(3)? { + /// let (k, v) = kv?; + /// println!("{k:?}: {v:?}"); + /// } + /// + /// # anyhow::Ok(()) + /// ``` + pub fn iter_from(&self, start: u64) -> Result, Error> { + // Short-circuit when we're starting at 0. + if start == 0 { + return Ok(self.iter()); + } + + let height = self.height(); + let bit_width = self.bit_width(); + + // Fast-path for case where start is beyond what we know this amt could currently contain. + if start >= nodes_for_height(bit_width, height + 1) { + return Ok(Iter { + height, + bit_width, + stack: Vec::new(), + blockstore: &self.block_store, + ver: PhantomData, + key: start, + }); + } + + let mut stack = Vec::with_capacity(height as usize); + let mut node = &self.root.node; + let mut offset = 0; + loop { + let start_idx = start.saturating_sub(offset); + match node { + Node::Leaf { vals } => { + if start_idx >= vals.len() as u64 { + // Not deep enough. + return Err(anyhow!("incorrect height for tree depth: expected values at depth {}, found them at {}", height, stack.len()).into()); + } + stack.push(IterStack { + node, + idx: start_idx as usize, + }); + break; + } + Node::Link { links } => { + let nfh = + nodes_for_height(self.bit_width(), self.height() - stack.len() as u32); + let idx: usize = (start_idx / nfh).try_into().expect("index overflow"); + assert!(idx < links.len(), "miscalculated nodes for height"); + let Some(l) = &links[idx] else { + // If there's nothing here, mark this as the starting point. We'll start + // scanning here when we iterate. + stack.push(IterStack { node, idx }); + break; + }; + let sub = match l { + Link::Dirty(sub) => sub, + Link::Cid { cid, cache } => cache.get_or_try_init(|| { + self.block_store + .get_cbor::>(cid)? + .ok_or_else(|| Error::CidNotFound(cid.to_string()))? + .expand(self.bit_width()) + .map(Box::new) + })?, + }; + // Push idx+1 because we've already processed this node. + stack.push(IterStack { node, idx: idx + 1 }); + node = sub; + offset += idx as u64 * nfh; + } + } + } + Ok(Iter { + stack, + height, + bit_width, blockstore: &self.block_store, - bit_width: self.bit_width(), ver: PhantomData, - key: 0, - } + key: start, + }) } } @@ -77,6 +172,27 @@ pub struct Iter<'a, V, BS, Ver> { key: u64, } +impl<'a, V, BS, Ver> Iter<'a, V, &'a BS, Ver> { + pub(crate) fn new( + node: &'a Node, + blockstore: &'a BS, + height: u32, + bit_width: u32, + offset: u64, + ) -> Self { + let mut stack = Vec::with_capacity(height as usize); + stack.push(IterStack { node, idx: 0 }); + Iter { + stack, + height, + blockstore, + bit_width, + ver: PhantomData, + key: offset, + } + } +} + pub struct IterStack<'a, V> { pub(crate) node: &'a Node, pub(crate) idx: usize, diff --git a/ipld/amt/src/node.rs b/ipld/amt/src/node.rs index 2d5becf09..5a594e9bc 100644 --- a/ipld/amt/src/node.rs +++ b/ipld/amt/src/node.rs @@ -420,61 +420,6 @@ where } } - pub(super) fn for_each_while( - &self, - bs: &S, - height: u32, - bit_width: u32, - offset: u64, - f: &mut F, - ) -> Result - where - F: FnMut(u64, &V) -> anyhow::Result, - S: Blockstore, - { - match self { - Node::Leaf { vals } => { - for (i, v) in (0..).zip(vals.iter()) { - if let Some(v) = v { - let keep_going = f(offset + i, v)?; - - if !keep_going { - return Ok(false); - } - } - } - } - Node::Link { links } => { - for (i, l) in (0..).zip(links.iter()) { - if let Some(l) = l { - let offs = offset + (i * nodes_for_height(bit_width, height)); - let keep_going = match l { - Link::Dirty(sub) => { - sub.for_each_while(bs, height - 1, bit_width, offs, f)? - } - Link::Cid { cid, cache } => { - let cached_node = cache.get_or_try_init(|| { - bs.get_cbor::>(cid)? - .ok_or_else(|| Error::CidNotFound(cid.to_string()))? - .expand(bit_width) - .map(Box::new) - })?; - - cached_node.for_each_while(bs, height - 1, bit_width, offs, f)? - } - }; - - if !keep_going { - return Ok(false); - } - } - } - } - } - - Ok(true) - } - /// Returns a `(keep_going, did_mutate)` pair. `keep_going` will be `false` iff /// a closure call returned `Ok(false)`, indicating that a `break` has happened. /// `did_mutate` will be `true` iff any of the values in the node was actually @@ -552,102 +497,6 @@ where Ok((true, did_mutate)) } - - /// Iterates through the current node in the tree and all subtrees. `start_at` refers to the - /// global AMT index, before which no values should be traversed and `limit` is the maximum - /// number of leaf nodes that should be traversed in this subtree. `offset` refers the offset - /// in the global AMT address space that this subtree is rooted at. - #[allow(clippy::too_many_arguments)] - pub(super) fn for_each_while_ranged( - &self, - bs: &S, - start_at: Option, - limit: Option, - height: u32, - bit_width: u32, - offset: u64, - f: &mut F, - ) -> Result<(bool, u64, Option), Error> - where - F: FnMut(u64, &V) -> anyhow::Result, - S: Blockstore, - { - let mut traversed_count = 0_u64; - match self { - Node::Leaf { vals } => { - let start_idx = start_at.map_or(0, |s| s.saturating_sub(offset)); - if start_idx as usize >= vals.len() { - return Ok((false, 0, None)); - } - let mut keep_going = true; - for (i, v) in (start_idx..).zip(vals[start_idx as usize..].iter()) { - let idx = offset + i; - if let Some(v) = v { - if limit.map_or(false, |l| traversed_count >= l) { - return Ok((keep_going, traversed_count, Some(idx))); - } else if !keep_going { - return Ok((false, traversed_count, Some(idx))); - } - keep_going = f(idx, v)?; - traversed_count += 1; - } - } - } - Node::Link { links } => { - let nfh = nodes_for_height(bit_width, height); - let idx: usize = ((start_at.map_or(0, |s| s.saturating_sub(offset))) / nfh) - .try_into() - .expect("index overflow"); - if idx >= links.len() { - return Ok((false, 0, None)); - } - for (i, link) in (idx..).zip(links[idx..].iter()) { - if let Some(l) = link { - let offs = offset + (i as u64 * nfh); - let (keep_going, count, next) = match l { - Link::Dirty(sub) => sub.for_each_while_ranged( - bs, - start_at, - limit.map(|l| l.checked_sub(traversed_count).unwrap()), - height - 1, - bit_width, - offs, - f, - )?, - Link::Cid { cid, cache } => { - let cached_node = cache.get_or_try_init(|| { - bs.get_cbor::>(cid)? - .ok_or_else(|| Error::CidNotFound(cid.to_string()))? - .expand(bit_width) - .map(Box::new) - })?; - - cached_node.for_each_while_ranged( - bs, - start_at, - limit.map(|l| l.checked_sub(traversed_count).unwrap()), - height - 1, - bit_width, - offs, - f, - )? - } - }; - - traversed_count += count; - - if limit.map_or(false, |l| traversed_count >= l) && next.is_some() { - return Ok((keep_going, traversed_count, next)); - } else if !keep_going { - return Ok((false, traversed_count, next)); - } - } - } - } - }; - - Ok((true, traversed_count, None)) - } } #[cfg(test)] diff --git a/ipld/amt/tests/amt_tests.rs b/ipld/amt/tests/amt_tests.rs index 47d3e8822..886be203b 100644 --- a/ipld/amt/tests/amt_tests.rs +++ b/ipld/amt/tests/amt_tests.rs @@ -420,6 +420,7 @@ fn for_each_ranged() { // Iterate over amt with dirty cache from different starting values for start_val in 0..RANGE { let mut retrieved_values = Vec::new(); + #[allow(deprecated)] let (count, next_key) = a .for_each_while_ranged(Some(start_val), None, |index, _: &BytesDe| { retrieved_values.push(index); @@ -435,6 +436,7 @@ fn for_each_ranged() { // Iterate out of bounds for i in [RANGE, RANGE + 1, 2 * RANGE, 8 * RANGE] { + #[allow(deprecated)] let (count, next_key) = a .for_each_while_ranged(Some(i), None, |_, _: &BytesDe| { panic!("didn't expect to iterate") @@ -447,6 +449,7 @@ fn for_each_ranged() { // Iterate over amt with dirty cache with different page sizes for page_size in 1..=RANGE { let mut retrieved_values = Vec::new(); + #[allow(deprecated)] let (count, next_key) = a .for_each_while_ranged(None, Some(page_size), |index, _: &BytesDe| { retrieved_values.push(index); @@ -468,6 +471,7 @@ fn for_each_ranged() { let mut retrieved_values = Vec::new(); let mut start_cursor = None; loop { + #[allow(deprecated)] let (num_traversed, next_cursor) = a .for_each_while_ranged(start_cursor, Some(page_size), |idx, _val| { retrieved_values.push(idx); @@ -493,6 +497,7 @@ fn for_each_ranged() { let mut retrieved_values = Vec::new(); let mut start_cursor = None; loop { + #[allow(deprecated)] let (num_traversed, next_cursor) = a .for_each_ranged(start_cursor, Some(page_size), |idx, _val: &BytesDe| { retrieved_values.push(idx); @@ -518,6 +523,7 @@ fn for_each_ranged() { // Iterate over the amt with dirty cache ignoring gaps in the address space including at the // beginning of the amt, we should only see the values that were not deleted + #[allow(deprecated)] let (num_traversed, next_cursor) = a .for_each_while_ranged(Some(0), Some(501), |i, _v| { assert_eq!((i / 10) % 2, 1); // only "odd" batches of ten 10 - 19, 30 - 39, etc. should be present @@ -530,6 +536,7 @@ fn for_each_ranged() { // flush the amt to the blockstore, reload and repeat the test with a clean cache let cid = a.flush().unwrap(); let a = Amt::load(&cid, &db).unwrap(); + #[allow(deprecated)] let (num_traversed, next_cursor) = a .for_each_while_ranged(Some(0), Some(501), |i, _v: &BytesDe| { assert_eq!((i / 10) % 2, 1); // only "odd" batches of ten 10 - 19, 30 - 39, etc. should be present