Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize CSE: Transition to DAG Representation with Hash Consing for Faster Equality Checks #688

Merged
merged 11 commits into from
Jan 23, 2025
98 changes: 56 additions & 42 deletions src/code.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module Code

using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions
using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions,
DocStringExtensions

export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex,
Expand Down Expand Up @@ -696,6 +697,52 @@ end

@inline newsym(::Type{T}) where T = Sym{T}(gensym("cse"))

"""
$(SIGNATURES)

Perform a topological sort on a symbolic expression represented as a Directed Acyclic
Graph (DAG).

This function takes a symbolic expression `graph` (potentially containing shared common
sub-expressions) and returns an array of `Assignment` objects. Each `Assignment`
represents a node in the sorted order, assigning a fresh symbol to its corresponding
expression. The order ensures that all dependencies of a node appear before the node itself
in the array.

Hash consing is assumed, meaning that structurally identical expressions are represented by
the same object in memory. This allows for efficient equality checks using `IdDict`.
"""
function topological_sort(graph)
    # Assignments in dependency order: a node's children are always pushed before it.
    sorted_nodes = Assignment[]
    # Identity-keyed memo from an already-visited node to whatever replaces it
    # (a fresh symbol for call/array nodes, the node itself for leaves).
    visited = IdDict()

    function dfs(node)
        # A node reached a second time reuses the replacement recorded the first time.
        if haskey(visited, node)
            return visited[node]
        end
        if iscall(node)
            # Rewrite the arguments first so their assignments precede this one.
            args = map(dfs, arguments(node))
            new_node = maketerm(typeof(node), operation(node), args, metadata(node))
            sym = newsym(symtype(new_node))
            push!(sorted_nodes, sym ← new_node)
            visited[node] = sym
            return sym
        elseif _is_array_of_symbolics(node)
            # Arrays get their own assignment; each element is rewritten recursively.
            new_node = map(dfs, node)
            sym = newsym(typeof(new_node))
            push!(sorted_nodes, sym ← new_node)
            visited[node] = sym
            return sym
        else
            # Leaves are kept as-is and produce no assignment.
            visited[node] = node
            return node
        end
    end

    dfs(graph)
    return sorted_nodes
end

function _cse!(mem, expr)
iscall(expr) || return expr
op = _cse!(mem, operation(expr))
Expand All @@ -714,12 +761,16 @@ function _cse!(mem, expr)
end

# Build a `Let` block for `expr` from its topologically sorted assignments.
#
# All assignments except the last become the `Let` bindings; the right-hand side of the
# last assignment (the node for `expr` itself) becomes the body. When `topological_sort`
# yields no assignments, `expr` is returned unchanged as the body of an empty `Let`.
function cse(expr)
    sorted_nodes = topological_sort(expr)
    isempty(sorted_nodes) && return Let(Assignment[], expr)
    # The final assignment corresponds to the root; use its expression directly as the
    # body instead of binding it to a symbol nobody references.
    last_assignment = pop!(sorted_nodes)
    return Let(sorted_nodes, rhs(last_assignment))
end


function _cse(exprs::AbstractArray)
letblock = cse(Term{Any}(tuple, vec(exprs)))
letblock.pairs, reshape(arguments(letblock.body), size(exprs))
Expand All @@ -746,41 +797,4 @@ function cse(x::MakeSparseArray)
end
end


# Count, in `state`, how many times each call expression is encountered while walking
# `t` recursively. Non-call inputs are returned untouched; call inputs return `nothing`.
function cse_state!(state, t)
    if iscall(t)
        state[t] = Base.get(state, t, 0) + 1
        for arg in arguments(t)
            cse_state!(state, arg)
        end
        return nothing
    else
        return t
    end
end

# Rewrite `x`, replacing every subexpression whose count in `state` exceeds one with a
# symbol named `name` followed by the current `counter[]` value.
#
# The first time such a subexpression is seen, its symbol is stored in `names`, an
# assignment is pushed onto `assignments`, and `counter[]` is incremented; later
# occurrences reuse the stored symbol. Other call expressions are rebuilt from their
# rewritten arguments, and anything else is returned unchanged.
function cse_block!(assignments, counter, names, name, state, x)
    if get(state, x, 0) > 1
        haskey(names, x) && return names[x]
        sym = Sym{symtype(x)}(Symbol(name, counter[]))
        names[x] = sym
        push!(assignments, sym ← x)
        counter[] += 1
        return sym
    end

    iscall(x) || return x

    args = map(arguments(x)) do a
        cse_block!(assignments, counter, names, name, state, a)
    end
    return isterm(x) ? term(operation(x), args...) :
           maketerm(typeof(x), operation(x), args, metadata(x))
end

# Build a `Let` whose bindings are the repeated subexpressions of `t` (as counted in
# `state`) and whose body is `t` rewritten in terms of those bindings. Generated symbol
# names start with `name`, which defaults to "var-" followed by `hash(t)`.
function cse_block(state, t, name=Symbol("var-", hash(t)))
    assignments = Assignment[]
    names = Dict{Any, BasicSymbolic}()
    counter = Ref{Int}(1)
    body = cse_block!(assignments, counter, names, name, state, t)
    return Let(assignments, body)
end

end
52 changes: 49 additions & 3 deletions test/cse.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,56 @@
using SymbolicUtils, SymbolicUtils.Code, Test
using SymbolicUtils.Code: topological_sort

@testset "CSE" begin
    @syms x
    t = cse(hypot(hypot(cos(x), sin(x)), atan(cos(x), sin(x))))

    @test t isa Let
    # Every call node except the root gets a binding: cos(x), sin(x), the inner hypot
    # and the atan. The root hypot becomes the body and refers to the last two bindings.
    @test length(t.pairs) == 4
    @test occursin(t.pairs[3].lhs, t.body)
    @test occursin(t.pairs[4].lhs, t.body)
end

@testset "DAG CSE" begin
    @syms a b

    # Shared sub-expression used by two different parents: a + b, sin(...), product.
    product = sin(a + b) * (a + b)
    order = topological_sort(product)
    @test length(order) == 3
    @test isequal(order[1].rhs, a + b)
    @test isequal(sin(order[1].lhs), order[2].rhs)

    # Shared sub-expression used twice by the same parent.
    power = (a + b)^(a + b)
    order = topological_sort(power)
    @test length(order) == 2
    @test isequal(order[1].rhs, a + b)
    sum_sym = order[1].lhs
    @test isequal(sum_sym^sum_sym, order[2].rhs)
    power_let = cse(power)
    @test length(power_let.pairs) == 1
    @test isequal(power_let.pairs[1].rhs, a + b)
    bound_sym = power_let.pairs[1].lhs
    @test isequal(power_let.body, bound_sym^bound_sym)

    # A single call node: one assignment, which `cse` turns into the body.
    single = a + b
    order = topological_sort(single)
    @test length(order) == 1
    @test isequal(order[1].rhs, a + b)
    single_let = cse(single)
    @test isempty(single_let.pairs)
    @test isequal(single_let.body, a + b)

    # A bare symbol: nothing to assign.
    leaf = a
    order = topological_sort(leaf)
    @test isempty(order)
    leaf_let = cse(leaf)
    @test isempty(leaf_let.pairs)
    @test isequal(leaf_let.body, a)

    # array symbolics
    # https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/688#pullrequestreview-2554931739
    @syms c
    function foo end
    array_call = term(foo, [a^2 + b^2, b^2 + c], c; type = Real)
    order = topological_sort(array_call)
    @test length(order) == 6
end
Loading