diff --git a/src/iterators/bfs.jl b/src/iterators/bfs.jl index d45f7793..4e240bfa 100644 --- a/src/iterators/bfs.jl +++ b/src/iterators/bfs.jl @@ -1,9 +1,11 @@ """ - BFSIterator + BFSIterator(graph, source; depth_limit=nothing, neighbors_type=outneighbors) -`BFSIterator` is used to iterate through graph vertices using a breadth-first search. -A source node(s) is optionally supplied as an `Int` or an array-like type that can be -indexed if supplying multiple sources. +`BFSIterator` is used to iterate through graph vertices using a breadth-first search. +A source node(s) must be supplied as an `Int` or an array-like type that can be +indexed if supplying multiple sources. It is also possible to specify a `depth_limit` +which will stop the search once all nodes at that depth are visited and a `neighbors_type` +which specifies what kind of neighbors of a node should be considered when exploring the graph. # Examples ```julia-repl @@ -20,14 +22,18 @@ julia> for node in BFSIterator(g,3) 2 ``` """ -struct BFSIterator{S,G<:AbstractGraph} +struct BFSIterator{S,G<:AbstractGraph,F} graph::G source::S - function BFSIterator(graph::G, source::S) where {S,G} + depth_limit::Int + neighbors_type::F + function BFSIterator( + graph::G, source::S; depth_limit=typemax(Int64), neighbors_type::F=outneighbors + ) where {S,G,F} if any(node -> !has_vertex(graph, node), source) error("Some source nodes for the iterator are not in the graph") end - return new{S,G}(graph, source) + return new{S,G,F}(graph, source, depth_limit, neighbors_type) end end @@ -46,6 +52,7 @@ mutable struct BFSVertexIteratorState next_level::Vector{Int} node_idx::Int n_visited::Int + n_level::Int end Base.IteratorSize(::BFSIterator) = Base.SizeUnknown() @@ -59,7 +66,7 @@ First iteration to visit vertices in a graph using breadth-first search. function Base.iterate(t::BFSIterator{<:Integer}) visited = falses(nv(t.graph)) visited[t.source] = true - state = BFSVertexIteratorState(visited, [t.source], Int[], 0, 0) + state = BFSVertexIteratorState(visited, [t.source], Int[], 0, 0, 0) return Base.iterate(t, state) end @@ -68,7 +75,7 @@ function Base.iterate(t::BFSIterator{<:AbstractArray}) curr_level = unique(s for s in t.source) sort!(curr_level) visited[curr_level] .= true - state = BFSVertexIteratorState(visited, curr_level, Int[], 0, 0) + state = BFSVertexIteratorState(visited, curr_level, Int[], 0, 0, 0) return Base.iterate(t, state) end @@ -80,10 +87,13 @@ Iterator to visit vertices in a graph using breadth-first search. function Base.iterate(t::BFSIterator, state::BFSVertexIteratorState) # we fill nodes in this level if state.node_idx == length(state.curr_level) + state.n_level == t.depth_limit && return nothing + state.n_level += 1 state.n_visited += length(state.curr_level) state.n_visited == nv(t.graph) && return nothing + neighbors_type = t.neighbors_type @inbounds for node in state.curr_level - for adj_node in outneighbors(t.graph, node) + for adj_node in neighbors_type(t.graph, node) if !state.visited[adj_node] push!(state.next_level, adj_node) state.visited[adj_node] = true diff --git a/test/iterators/bfs.jl b/test/iterators/bfs.jl index f459eafa..00e9fd02 100644 --- a/test/iterators/bfs.jl +++ b/test/iterators/bfs.jl @@ -45,4 +45,12 @@ @test sort(nodes_visited[1:3]) == sort(levels[1]) @test sort(nodes_visited[4:8]) == sort(levels[2]) @test sort(nodes_visited[9:end]) == sort(levels[3]) + + nodes_visited = collect(BFSIterator(g2, [8, 1, 6]; depth_limit=1)) + @test sort(nodes_visited[1:3]) == sort(levels[1]) + @test sort(nodes_visited[4:end]) == sort(levels[2]) + + g = path_digraph(7) + nodes_visited = collect(BFSIterator(g, 7; neighbors_type=inneighbors)) + @test nodes_visited == collect(7:-1:1) end