diff --git a/node.go b/node.go index 9524d3f..1be9639 100644 --- a/node.go +++ b/node.go @@ -221,6 +221,12 @@ func (n *Node[T]) ReverseIterator() *ReverseIterator[T] { return NewReverseIterator(n) } +// Iterator is used to return an iterator at +// the given node to walk the tree +func (n *Node[T]) PathIterator(path []byte) *PathIterator[T] { + return &PathIterator[T]{node: n, path: path} +} + // rawIterator is used to return a raw iterator at the given node to walk the // tree. func (n *Node[T]) rawIterator() *rawIterator[T] { @@ -274,30 +280,12 @@ func (n *Node[T]) WalkPrefix(prefix []byte, fn WalkFn[T]) { // all the entries *under* the given prefix, this walks the // entries *above* the given prefix. func (n *Node[T]) WalkPath(path []byte, fn WalkFn[T]) { - search := path - for { - // Visit the leaf values if any - if n.leaf != nil && fn(n.leaf.key, n.leaf.val) { - return - } - - // Check for key exhaustion - if len(search) == 0 { - return - } + i := n.PathIterator(path) - // Look for an edge - _, n = n.getEdge(search[0]) - if n == nil { + for path, val, ok := i.Next(); ok; path, val, ok = i.Next() { + if fn(path, val) { return } - - // Consume the search prefix - if bytes.HasPrefix(search, n.prefix) { - search = search[len(n.prefix):] - } else { - break - } } } diff --git a/path_iter.go b/path_iter.go new file mode 100644 index 0000000..d0fd345 --- /dev/null +++ b/path_iter.go @@ -0,0 +1,59 @@ +package iradix + +import "bytes" + +// PathIterator is used to iterate over a set of nodes from the root +// down to a specified path. This will iterate overthe same values that +// the Node.WalkPath method will. +type PathIterator[T any] struct { + node *Node[T] + path []byte + done bool +} + +// Next returns the next node in order +func (i *PathIterator[T]) Next() ([]byte, T, bool) { + // This is mostly just an asyncrhonous implementation of the WalkPath + // method on the node. + var zero T + var leaf *leafNode[T] + + for leaf == nil && i.node != nil { + // visit the leaf values if any + if i.node.leaf != nil { + leaf = i.node.leaf + } + + i.iterate() + } + + if leaf != nil { + return leaf.key, leaf.val, true + } + + return nil, zero, false +} + +func (i *PathIterator[T]) iterate() { + // Check for key exhaustion + if len(i.path) == 0 { + i.node = nil + return + } + + // Look for an edge + _, i.node = i.node.getEdge(i.path[0]) + if i.node == nil { + return + } + + // Consume the search prefix + if bytes.HasPrefix(i.path, i.node.prefix) { + i.path = i.path[len(i.node.prefix):] + } else { + // there are no more nodes to iterate through so + // nil out the node to prevent returning results + // for subsequent calls to Next() + i.node = nil + } +}