diff --git a/src/decompression.jl b/src/decompression.jl index d8bddc0..cbf5f8b 100644 --- a/src/decompression.jl +++ b/src/decompression.jl @@ -299,30 +299,37 @@ end function decompress_aux!( A::AbstractMatrix{R}, B::AbstractMatrix{R}, result::TreeSetColoringResult ) where {R<:Real} + n = checksquare(A) A .= zero(R) S = get_matrix(result) color = column_colors(result) - forest = result.tree_set.forest + # forest is a structure DisjointSets from DataStructures.jl + # - forest.intmap: a dictionary that maps an edge (i, j) to an integer k + # - forest.revmap: a dictionary that does the reverse of intmap, mapping an integer k to an edge (i, j) + # - forest.internal.ngroups: the number of trees in the forest + forest = result.tree_set.forest ntrees = forest.internal.ngroups - for edge in forest.revmap - # ensure that all paths are compressed - find_root!(forest, edge) - end - roots = forest.internal.parents - unique_roots = unique(roots) + # vector of trees where each tree contains the indices of its edges trees = [Int[] for i in 1:ntrees] + + # dictionary that maps a tree's root to the index of the tree + roots = Dict{Int, Int}() + k = 0 - for root in unique_roots - k += 1 - for (pos, val) in enumerate(roots) - if root == val - push!(trees[k], pos) - end + for edge in forest.revmap + root_edge = find_root!(forest, edge) + root = forest.intmap[root_edge] + if !haskey(roots, root) + k += 1 + roots[root] = k end + index_tree = roots[root] + push!(trees[index_tree], forest.intmap[edge]) end + # vector of dictionaries where each dictionary stores the degree of each vertex in a tree degrees = [Dict{Int,Int}() for k in 1:ntrees] for k in 1:ntrees tree = trees[k] @@ -336,7 +343,8 @@ function decompress_aux!( end end - n = checksquare(A) + # stored_values holds the sum of edge values for subtrees in a tree. + # For each vertex i, stored_values[i] is the sum of edge values in the subtree rooted at i. stored_values = Vector{R}(undef, n) # Recover the diagonal coefficients of A