# Review notes for this patch. They sit ahead of the first `diff --git` header and belong to no hunk;
# every patch byte below is unchanged.
#
# src/code.jl
#   * The `using` list of `module Code` gains `DocStringExtensions`; the new docstring interpolates
#     `$(SIGNATURES)`.
#   * Adds `topological_sort(graph)`: a depth-first walk memoised in an `IdDict` (`visited`).
#       - `iscall` node: arguments are visited first, the node is rebuilt with `maketerm` from the
#         rewritten arguments, bound to a fresh `newsym` symbol, and `sym ← new_node` is pushed.
#       - `_is_array_of_symbolics` node: `dfs` is mapped over the elements and the resulting array is
#         bound to a fresh symbol in the same way.
#       - any other node: recorded in `visited` and returned unchanged; no `Assignment` is emitted
#         (so "each `Assignment` represents a node" in the docstring covers only the two cases above).
#     Because arguments are pushed before the node that uses them, dependencies precede dependents
#     in the returned `Assignment[]`.
#   * Rewrites `cse(expr)`: an empty sort result yields `Let(Assignment[], expr)`; otherwise the last
#     assignment is popped, its `rhs` becomes the `Let` body, and the rest become the pairs.
#   * Removes `cse_state!`, `cse_block!` and `cse_block`. The removed `cse_block!` only introduced a
#     binding when `get(state, x, 0) > 1`; the new code binds every call node it visits, which is why
#     the "CSE" testset expectation changes from 2 pairs to 4.
#
# test/cse.jl
#   * Imports `topological_sort` from `SymbolicUtils.Code`, updates the "CSE" expectations, and adds a
#     "DAG CSE" testset covering a shared sub-expression, a single call, a bare symbol, and a term
#     with an array-of-symbolics argument.
#
# NOTE(review): in this copy the patch's line breaks appear to have been replaced by spaces; they
#   would need to be restored before the patch can be applied — confirm against the original commit.
# NOTE(review): memoisation is by object identity (`IdDict`); the docstring states hash consing is
#   assumed. Structurally equal but non-identical inputs would presumably get separate bindings —
#   confirm this is acceptable for all callers of `cse`.
# NOTE(review): `dfs` visits `arguments(node)` but not `operation(node)`, whereas the `_cse!` context
#   lines show `op = _cse!(mem, operation(expr))` — confirm symbolic operations need no handling here.
# NOTE(review): `dfs` is recursive, so its depth follows expression nesting — confirm this is fine
#   for very deep expressions.
diff --git a/src/code.jl b/src/code.jl index 774c2693..e25a661a 100644 --- a/src/code.jl +++ b/src/code.jl @@ -1,6 +1,7 @@ module Code -using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions +using StaticArrays, SparseArrays, LinearAlgebra, NaNMath, SpecialFunctions, + DocStringExtensions export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, SetArray, MakeArray, MakeSparseArray, MakeTuple, AtIndex, @@ -696,6 +697,52 @@ end @inline newsym(::Type{T}) where T = Sym{T}(gensym("cse")) +""" +$(SIGNATURES) + +Perform a topological sort on a symbolic expression represented as a Directed Acyclic +Graph (DAG). + +This function takes a symbolic expression `graph` (potentially containing shared common +sub-expressions) and returns an array of `Assignment` objects. Each `Assignment` +represents a node in the sorted order, assigning a fresh symbol to its corresponding +expression. The order ensures that all dependencies of a node appear before the node itself +in the array. + +Hash consing is assumed, meaning that structurally identical expressions are represented by +the same object in memory. This allows for efficient equality checks using `IdDict`. 
+""" +function topological_sort(graph) + sorted_nodes = Assignment[] + visited = IdDict() + + function dfs(node) + if haskey(visited, node) + return visited[node] + end + if iscall(node) + args = map(dfs, arguments(node)) + new_node = maketerm(typeof(node), operation(node), args, metadata(node)) + sym = newsym(symtype(new_node)) + push!(sorted_nodes, sym ← new_node) + visited[node] = sym + return sym + elseif _is_array_of_symbolics(node) + new_node = map(dfs, node) + sym = newsym(typeof(new_node)) + push!(sorted_nodes, sym ← new_node) + visited[node] = sym + return sym + else + visited[node] = node + return node + end + end + + dfs(graph) + return sorted_nodes +end + function _cse!(mem, expr) iscall(expr) || return expr op = _cse!(mem, operation(expr)) @@ -714,12 +761,16 @@ function _cse!(mem, expr) end function cse(expr) - state = Dict{Any, Int}() - cse_state!(state, expr) - cse_block(state, expr) + sorted_nodes = topological_sort(expr) + if isempty(sorted_nodes) + return Let(Assignment[], expr) + else + last_assignment = pop!(sorted_nodes) + body = rhs(last_assignment) + return Let(sorted_nodes, body) + end end - function _cse(exprs::AbstractArray) letblock = cse(Term{Any}(tuple, vec(exprs))) letblock.pairs, reshape(arguments(letblock.body), size(exprs)) @@ -746,41 +797,4 @@ function cse(x::MakeSparseArray) end end - -function cse_state!(state, t) - !iscall(t) && return t - state[t] = Base.get(state, t, 0) + 1 - foreach(x->cse_state!(state, x), arguments(t)) -end - -function cse_block!(assignments, counter, names, name, state, x) - if get(state, x, 0) > 1 - if haskey(names, x) - return names[x] - else - sym = Sym{symtype(x)}(Symbol(name, counter[])) - names[x] = sym - push!(assignments, sym ← x) - counter[] += 1 - return sym - end - elseif iscall(x) - args = map(a->cse_block!(assignments, counter, names, name, state,a), arguments(x)) - if isterm(x) - return term(operation(x), args...) 
- else - return maketerm(typeof(x), operation(x), args, metadata(x)) - end - else - return x - end -end - -function cse_block(state, t, name=Symbol("var-", hash(t))) - assignments = Assignment[] - counter = Ref{Int}(1) - names = Dict{Any, BasicSymbolic}() - Let(assignments, cse_block!(assignments, counter, names, name, state, t)) -end - end diff --git a/test/cse.jl b/test/cse.jl index fcc36d47..7e2ef68e 100644 --- a/test/cse.jl +++ b/test/cse.jl @@ -1,10 +1,56 @@ using SymbolicUtils, SymbolicUtils.Code, Test +using SymbolicUtils.Code: topological_sort + @testset "CSE" begin @syms x t = cse(hypot(hypot(cos(x), sin(x)), atan(cos(x), sin(x)))) @test t isa Let - @test length(t.pairs) == 2 - @test occursin(t.pairs[1].lhs, t.body) - @test occursin(t.pairs[2].lhs, t.body) + @test length(t.pairs) == 4 + @test occursin(t.pairs[3].lhs, t.body) + @test occursin(t.pairs[4].lhs, t.body) +end + +@testset "DAG CSE" begin + @syms a b + expr = sin(a + b) * (a + b) + sorted_nodes = topological_sort(expr) + @test length(sorted_nodes) == 3 + @test isequal(sorted_nodes[1].rhs, a + b) + @test isequal(sin(sorted_nodes[1].lhs), sorted_nodes[2].rhs) + + expr = (a + b)^(a + b) + sorted_nodes = topological_sort(expr) + @test length(sorted_nodes) == 2 + @test isequal(sorted_nodes[1].rhs, a + b) + ab_node = sorted_nodes[1].lhs + @test isequal(ab_node^ab_node, sorted_nodes[2].rhs) + let_expr = cse(expr) + @test length(let_expr.pairs) == 1 + @test isequal(let_expr.pairs[1].rhs, a + b) + corresponding_sym = let_expr.pairs[1].lhs + @test isequal(let_expr.body, corresponding_sym^corresponding_sym) + + expr = a + b + sorted_nodes = topological_sort(expr) + @test length(sorted_nodes) == 1 + @test isequal(sorted_nodes[1].rhs, a + b) + let_expr = cse(expr) + @test isempty(let_expr.pairs) + @test isequal(let_expr.body, a + b) + + expr = a + sorted_nodes = topological_sort(expr) + @test isempty(sorted_nodes) + let_expr = cse(expr) + @test isempty(let_expr.pairs) + @test isequal(let_expr.body, a) + + # 
array symbolics + # https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/688#pullrequestreview-2554931739 + @syms c + function foo end + ex = term(foo, [a^2 + b^2, b^2 + c], c; type = Real) + sorted_nodes = topological_sort(ex) + @test length(sorted_nodes) == 6 end