From beb29b3e095fd6c56e7a7afc2e2f155b33af1b9e Mon Sep 17 00:00:00 2001 From: Adrian Lehmann Date: Tue, 2 Jul 2024 17:47:17 -0500 Subject: [PATCH] Add normal form, leader rewriting and fix equiv proof not have congruence --- src/EGraphs/EGraphs.jl | 3 ++ src/EGraphs/exprproof.jl | 61 ++++++++++++++++++++++++++++++++++++++++ src/EGraphs/proof.jl | 36 ++++++++++++++++++++++-- test/egraphs/proof.jl | 34 ++++++++++++++++++++-- 4 files changed, 128 insertions(+), 6 deletions(-) create mode 100644 src/EGraphs/exprproof.jl diff --git a/src/EGraphs/EGraphs.jl b/src/EGraphs/EGraphs.jl index b9e2a1bb..3a68100c 100644 --- a/src/EGraphs/EGraphs.jl +++ b/src/EGraphs/EGraphs.jl @@ -32,4 +32,7 @@ using .Schedulers include("saturation.jl") export SaturationParams, saturate! +include("exprproof.jl") +export PositionedProof, find_node_proof + end diff --git a/src/EGraphs/exprproof.jl b/src/EGraphs/exprproof.jl new file mode 100644 index 00000000..183f0f72 --- /dev/null +++ b/src/EGraphs/exprproof.jl @@ -0,0 +1,61 @@ +export PositionedProof, find_node_proof + + +mutable struct PositionedProof + """ + Positioned proof is a structure that keeps track of where we apply proofs to in larger expressions. + """ + proof::Vector{ProofNode} + children::Vector{PositionedProof} + # TODO: Track what is matched +end + +function find_node_proof(g::EGraph, node1::Id, node2::Id)::PositionedProof + # Proof search that can deal with expressions, too. + + # Idea: + + # Walk expr trees + + # For each node: + # If has flat proof, proof to leader + # Else, recursively unfold + + # If no proof found for subexpr, return nothing + + # Issues: how to relate expressions? + # Especially if different Size + # e.g. a*(b+c) = ab+bc (which is different size AST) + # bigger problem comes when a=z then z*(b+c) = ab+bc + + # So I guess the way we should go about it is go to base terms, rewrite to leader + + flat_proof = find_flat_proof(g.proof, node1, node2) + # If there is a basic proof, no need to construct something more complicated + # TODO: Profile if this kills performance + if length(flat_proof) != 0 + return flat_proof + end + + # Idea: rewrite both sides to "normal forms" and concat + # TODO: This is definetely suboptimal and should be optimized + + + +end + +# +function rewrite_to_normal_form(g::EGraph, node::Id)::PositionedProof + # Start off by rewriting node to leader + lp = rewrite_to_leader(g.proof, node1) + leader = lp.leader + leader_proof = lp.proof + + expr = g.nodes[leader] + proof = PositionedProof(leader_proof, []) + + for (idx, child) in enumerate(v_children(expr)) + proof.children[idx] = rewrite_to_normal_form(g, child) + end + return PositionedProof +end \ No newline at end of file diff --git a/src/EGraphs/proof.jl b/src/EGraphs/proof.jl index 956d46fb..3e6d493a 100644 --- a/src/EGraphs/proof.jl +++ b/src/EGraphs/proof.jl @@ -1,4 +1,4 @@ -export ProofConnection, ProofNode, EGraphProof, find_flat_proof +export ProofConnection, ProofNode, EGraphProof, find_flat_proof, rewrite_to_leader mutable struct ProofConnection """ @@ -24,6 +24,7 @@ end mutable struct ProofNode + # TODO: Explain existence_node::Id # TODO is this the parent in the unionfind? parent_connection::ProofConnection @@ -79,6 +80,7 @@ function make_leader(proof::EGraphProof, node::Id)::Bool true end + function Base.union!(proof::EGraphProof, node1::Id, node2::Id, rule_idx::Int) # TODO maybe should have extra argument called `rhs_new` in egg that is true when called from # application of rules where the instantiation of the rhs creates new e-classes @@ -106,7 +108,11 @@ end @inline isroot(pn::ProofNode) = isroot(pn.parent_connection) @inline isroot(pc::ProofConnection) = pc.current === pc.next -function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id) + + + + +function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id)::Vector{ProofNode} # We're doing a lowest common ancestor search. # We cache the IDs we have seen seen_set = Set{Id}() @@ -117,6 +123,9 @@ function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id) # No existence_node would ever have id 0 lca = UInt(0) curr = proof.explain_find[node1] + if (node1 == node2) + return [curr] + end # Walk up to the root while true @@ -155,4 +164,25 @@ function find_flat_proof(proof::EGraphProof, node1::Id, node2::Id) # TODO maybe reverse append!(ret, walk_from2) ret -end \ No newline at end of file +end + +struct LeaderProof + leader::Id + proof::Vector{ProofNode} +end + +function rewrite_to_leader(proof::EGraphProof, node::Id)::LeaderProof + # Returns the leader of e-class and a proof to transform node into said leader + curr_proof = proof.explain_find[node] + proofs = [] + final_id = node + if curr_proof.parent_connection.current == curr_proof.parent_connection.next + return LeaderProof(node, [curr_proof]) # Special case to report congruence + end + while curr_proof.parent_connection.current != curr_proof.parent_connection.next + push!(proofs, curr_proof) + final_id = curr_proof.parent_connection.next + curr_proof = proof.explain_find[curr_proof.parent_connection.next] + end + return LeaderProof(final_id, proofs) +end diff --git a/test/egraphs/proof.jl b/test/egraphs/proof.jl index 69c2144a..36dfc362 100644 --- a/test/egraphs/proof.jl +++ b/test/egraphs/proof.jl @@ -1,8 +1,12 @@ using Metatheory, Test +using Metatheory.Library + g = EGraph(; proof = true) id_a = addexpr!(g, :a) +println(find_flat_proof(g.proof, id_a, id_a)) +@test length(find_flat_proof(g.proof, id_a, id_a)) == 1 # print_proof(g) @@ -31,10 +35,34 @@ id_d = addexpr!(g, :d) union!(g, id_a, id_d, 3) print_proof(g) - +println(find_flat_proof(g.proof, id_c, id_d)) # Takes 4 steps -@test length(find_flat_proof(g.proof, id_a, id_d)) == 4 +@test length(find_flat_proof(g.proof, id_c, id_d)) == 3 + +# TODO: Why doesn't d have a its leader +for id in [id_a, id_b, id_c, id_d] + leader = rewrite_to_leader(g.proof, id) + @test leader.leader == id_d + @test length(leader.proof) == length(find_flat_proof(g.proof, id, id_a)) +end + id_e = addexpr!(g, :e) -@test isempty(find_flat_proof(g.proof, id_a, id_e)) \ No newline at end of file +@test isempty(find_flat_proof(g.proof, id_a, id_e)) + +comm_monoid = @commutative_monoid (*) 1 + +fold_mul = @theory begin + ~a::Number * ~b::Number => ~a * ~b +end + +ex = :(a * 4) +id_ex = addexpr!(g, ex) +ex_to = :(e * 4) +id_ex_to = addexpr!(g, ex_to) +print_proof(g) + +println(find_node_proof(g, id_ex, id_ex_to)) # Current challenge + +