Skip to content

Commit

Permalink
Add normal form, leader rewriting and fix equiv proof not have congru…
Browse files Browse the repository at this point in the history
…ence
  • Loading branch information
adrianleh committed Jul 2, 2024
1 parent 539c2c3 commit beb29b3
Show file tree
Hide file tree
Showing 4 changed files with 128 additions and 6 deletions.
3 changes: 3 additions & 0 deletions src/EGraphs/EGraphs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,7 @@ using .Schedulers
include("saturation.jl")
export SaturationParams, saturate!

include("exprproof.jl")
export PositionedProof, find_node_proof

end
61 changes: 61 additions & 0 deletions src/EGraphs/exprproof.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
export PositionedProof, find_node_proof


"""
    PositionedProof

Positioned proof is a structure that keeps track of where we apply proofs to in larger
expressions.

# Fields
- `proof`: the flat proof steps applied at this position.
- `children`: positioned proofs for the sub-expressions below this position.
"""
mutable struct PositionedProof
    proof::Vector{ProofNode}
    children::Vector{PositionedProof}
    # TODO: Track what is matched
end

"""
    find_node_proof(g::EGraph, node1::Id, node2::Id)::PositionedProof

Proof search between `node1` and `node2` that is meant to deal with expressions, too.

When `find_flat_proof` yields a non-empty proof, that proof is wrapped in a
`PositionedProof` without children and returned.

# Throws
- `ErrorException` when no flat proof exists: the structural ("normal form") search
  sketched in the comments below is not implemented yet.
"""
function find_node_proof(g::EGraph, node1::Id, node2::Id)::PositionedProof
    # Idea:
    #   Walk expr trees
    #   For each node:
    #     If has flat proof, proof to leader
    #     Else, recursively unfold
    #   If no proof found for subexpr, return nothing
    #
    # Issues: how to relate expressions?
    # Especially if different size,
    # e.g. a*(b+c) = ab+bc (which is a different size AST).
    # A bigger problem comes when a=z, then z*(b+c) = ab+bc.
    #
    # So I guess the way we should go about it is go to base terms, rewrite to leader.

    flat_proof = find_flat_proof(g.proof, node1, node2)
    # If there is a basic proof, no need to construct something more complicated
    # TODO: Profile if this kills performance
    if !isempty(flat_proof)
        # Wrap the flat proof so the declared return type is actually honoured;
        # returning the bare vector would fail when converted to PositionedProof.
        return PositionedProof(flat_proof, PositionedProof[])
    end

    # Idea: rewrite both sides to "normal forms" and concat
    # TODO: This is definitely suboptimal and should be optimized
    #
    # Fail loudly instead of falling off the end of the function: an implicit `nothing`
    # cannot be converted to the declared PositionedProof return type anyway.
    return error(
        "find_node_proof: no flat proof between $(node1) and $(node2); " *
        "proof search through normal forms is not implemented yet",
    )
end

#
"""
    rewrite_to_normal_form(g::EGraph, node::Id)::PositionedProof

Build a positioned proof that first rewrites `node` to its leader (via
`rewrite_to_leader`) and then does the same, recursively, for every child of the
leader's node.

The returned `PositionedProof` holds the leader proof in `proof` and one entry per child
in `children`, in the order given by `v_children`.
"""
function rewrite_to_normal_form(g::EGraph, node::Id)::PositionedProof
    # Start off by rewriting node to leader
    lp = rewrite_to_leader(g.proof, node)

    # NOTE(review): assumes `g.nodes` can be indexed by the leader id — confirm.
    expr = g.nodes[lp.leader]

    # NOTE(review): there is no guard against cycles here; if a child can lead back to
    # `node`, this recursion does not terminate — confirm whether that can happen.
    children = PositionedProof[
        rewrite_to_normal_form(g, child) for child in v_children(expr)
    ]
    return PositionedProof(lp.proof, children)
end
36 changes: 33 additions & 3 deletions src/EGraphs/proof.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export ProofConnection, ProofNode, EGraphProof, find_flat_proof
export ProofConnection, ProofNode, EGraphProof, find_flat_proof, rewrite_to_leader

mutable struct ProofConnection
"""
Expand All @@ -24,6 +24,7 @@ end


mutable struct ProofNode
# TODO: Explain
existence_node::Id
# TODO is this the parent in the unionfind?
parent_connection::ProofConnection
Expand Down Expand Up @@ -79,6 +80,7 @@ function make_leader(proof::EGraphProof, node::Id)::Bool
true
end


function Base.union!(proof::EGraphProof, node1::Id, node2::Id, rule_idx::Int)
# TODO maybe should have extra argument called `rhs_new` in egg that is true when called from
# application of rules where the instantiation of the rhs creates new e-classes
Expand Down Expand Up @@ -106,7 +108,11 @@ end
# A proof node is a root exactly when its parent connection is a root connection.
@inline function isroot(pn::ProofNode)
    return isroot(pn.parent_connection)
end

# A connection is a root connection when it points back at itself.
@inline function isroot(pc::ProofConnection)
    return pc.current === pc.next
end

function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id)




function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id)::Vector{ProofNode}
# We're doing a lowest common ancestor search.
# We cache the IDs we have seen
seen_set = Set{Id}()
Expand All @@ -117,6 +123,9 @@ function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id)
# No existence_node would ever have id 0
lca = UInt(0)
curr = proof.explain_find[node1]
if (node1 == node2)
return [curr]
end

# Walk up to the root
while true
Expand Down Expand Up @@ -155,4 +164,25 @@ function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id)
# TODO maybe reverse
append!(ret, walk_from2)
ret
end
end

"""
    LeaderProof

Result of `rewrite_to_leader`: the leader that was reached together with the proof nodes
collected on the way to it.
"""
struct LeaderProof
# Id at which the walk along the parent connections stopped.
leader::Id
# Proof nodes visited before reaching the leader; when the queried node already is the
# leader, this holds just that node's own proof node.
proof::Vector{ProofNode}
end

"""
    rewrite_to_leader(proof::EGraphProof, node::Id)::LeaderProof

Return the leader of the e-class of `node` and a proof to transform `node` into said
leader.

Starting at `proof.explain_find[node]`, the parent connections are followed until a root
proof node is reached. Every non-root proof node visited on the way is recorded, in walk
order. If `node` is already a root, the result holds `node` itself and a one-element
proof containing its own proof node.
"""
function rewrite_to_leader(proof::EGraphProof, node::Id)::LeaderProof
    curr_proof = proof.explain_find[node]

    # Special case to report congruence: the node already is its own leader.
    if isroot(curr_proof)
        return LeaderProof(node, ProofNode[curr_proof])
    end

    # Concretely typed accumulator, so no Vector{Any} has to be converted on return.
    steps = ProofNode[]
    final_id = node
    while !isroot(curr_proof)
        push!(steps, curr_proof)
        final_id = curr_proof.parent_connection.next
        curr_proof = proof.explain_find[final_id]
    end
    return LeaderProof(final_id, steps)
end
34 changes: 31 additions & 3 deletions test/egraphs/proof.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
using Metatheory, Test
using Metatheory.Library


g = EGraph(; proof = true)

id_a = addexpr!(g, :a)
println(find_flat_proof(g.proof, id_a, id_a))
@test length(find_flat_proof(g.proof, id_a, id_a)) == 1

# print_proof(g)

Expand Down Expand Up @@ -31,10 +35,34 @@ id_d = addexpr!(g, :d)

union!(g, id_a, id_d, 3)
print_proof(g)

println(find_flat_proof(g.proof, id_c, id_d))
# Takes 4 steps
@test length(find_flat_proof(g.proof, id_a, id_d)) == 4
@test length(find_flat_proof(g.proof, id_c, id_d)) == 3

# TODO: Why doesn't d have a as its leader? (the checks below expect d to be the leader)
for id in [id_a, id_b, id_c, id_d]
leader = rewrite_to_leader(g.proof, id)
@test leader.leader == id_d
@test length(leader.proof) == length(find_flat_proof(g.proof, id, id_a))
end



id_e = addexpr!(g, :e)
@test isempty(find_flat_proof(g.proof, id_a, id_e))
@test isempty(find_flat_proof(g.proof, id_a, id_e))

comm_monoid = @commutative_monoid (*) 1

fold_mul = @theory begin
~a::Number * ~b::Number => ~a * ~b
end

ex = :(a * 4)
id_ex = addexpr!(g, ex)
ex_to = :(e * 4)
id_ex_to = addexpr!(g, ex_to)
print_proof(g)

println(find_node_proof(g, id_ex, id_ex_to)) # Current challenge


0 comments on commit beb29b3

Please sign in to comment.