From 9e76ebe3165a14facd4e2f2eaa9a485fa876e859 Mon Sep 17 00:00:00 2001 From: Owen Lynch Date: Fri, 16 Aug 2024 10:46:31 -0700 Subject: [PATCH 01/10] added threading support --- src/EGraphs/saturation.jl | 4 +-- src/Rules.jl | 27 +++++++++++++++++--- test/egraphs/ematch.jl | 52 +++++++++++++++++++-------------------- 3 files changed, 51 insertions(+), 32 deletions(-) diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index eeb6bb39..e90c46f5 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -87,14 +87,14 @@ function eqsat_search!( for i in ids_left cansearch(scheduler, rule_idx, i) || continue - n_matches += rule.ematcher_left!(g, rule_idx, i, rule.stack, ematch_buffer) + n_matches += rule.ematcher_left!(g, rule_idx, i, get_local_stack(), ematch_buffer) inform!(scheduler, rule_idx, i, n_matches) end if is_bidirectional(rule) for i in ids_right cansearch(scheduler, rule_idx, i) || continue - n_matches += rule.ematcher_right!(g, rule_idx, i, rule.stack, ematch_buffer) + n_matches += rule.ematcher_right!(g, rule_idx, i, get_local_stack(), ematch_buffer) inform!(scheduler, rule_idx, i, n_matches) end end diff --git a/src/Rules.jl b/src/Rules.jl index d32e064e..5ce97a49 100644 --- a/src/Rules.jl +++ b/src/Rules.jl @@ -1,5 +1,6 @@ module Rules +using Base.Threads using TermInterface using AutoHashEquals using Metatheory.Patterns @@ -16,7 +17,8 @@ export RewriteRule, Theory, direct, direct_left_to_right, - direct_right_to_left + direct_right_to_left, + get_local_stack const STACK_SIZE = 512 @@ -71,11 +73,28 @@ Base.@kwdef struct RewriteRule{Op<:Function} ematcher_right!::Union{Nothing,Function} = nothing matcher_left::Function matcher_right::Union{Nothing,Function} = nothing - stack::OptBuffer{UInt16} = OptBuffer{UInt16}(STACK_SIZE) lhs_original = nothing rhs_original = nothing end +# Modeled off https://github.com/JuliaLang/julia/blob/bc4b2e848400764e389c825b57d1481ed76f4d85/stdlib/Random/src/RNGs.jl +const THREAD_STACKS = OptBuffer{UInt16}[] +@inline get_local_stack() = get_local_stack(Threads.threadid()) +@noinline function get_local_stack(tid::Int) + @assert 0 < tid <= length(THREAD_STACKS) + if @inbounds isassigned(THREAD_STACKS, tid) + @inbounds stack = THREAD_STACKS[tid] + else + stack = OptBuffer{UInt16}(STACK_SIZE) + @inbounds THREAD_STACKS[tid] = stack + end + return stack +end + +function __init__() + resize!(empty!(THREAD_STACKS), Threads.nthreads()) +end + function --> end const DirectedRule = RewriteRule{typeof(-->)} const EqualityRule = RewriteRule{typeof(==)} @@ -99,8 +118,8 @@ function Base.show(io::IO, r::RewriteRule) end -(r::DirectedRule)(term) = r.matcher_left(term, (bindings...) -> instantiate(term, r.right, bindings), r.stack) -(r::DynamicRule)(term) = r.matcher_left(term, (bindings...) -> r.right(term, nothing, bindings...), r.stack) +(r::DirectedRule)(term) = r.matcher_left(term, (bindings...) -> instantiate(term, r.right, bindings), get_local_stack()) +(r::DynamicRule)(term) = r.matcher_left(term, (bindings...) -> r.right(term, nothing, bindings...), get_local_stack()) # --------------------- # Theories diff --git a/test/egraphs/ematch.jl b/test/egraphs/ematch.jl index d2c5fb6d..995ebba0 100644 --- a/test/egraphs/ematch.jl +++ b/test/egraphs/ematch.jl @@ -13,48 +13,48 @@ b = OptBuffer{UInt128}(10) r = @rule 2 --> true g = EGraph(2) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 end @testset "Composite Ground Terms" begin r = @rule f(2, 3) --> true g = EGraph(:(f(2, 3))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 - @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0 - @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 + @test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0 + @test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0 g = EGraph(:(f(2, 4))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 - @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0 - @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0 + @test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0 + @test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0 r = @rule f(2, h(3, 4)) --> true g = EGraph(:(f(2, h(3, 4)))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 - @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0 - @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 + @test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0 + @test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0 end @testset "Pattern Variables" begin g = EGraph(:(f(2, 1))) r = @rule ~a --> true - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 - @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 1 - @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 + @test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 1 + @test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 1 end @testset "Type Assertions" begin r = @rule ~a::Int --> true g = EGraph(:(f(2, 1))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0 g = EGraph(:3) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 new_id = addexpr!(g, :f) union!(g, g.root, new_id) @@ -62,7 +62,7 @@ end new_id = addexpr!(g, 4) union!(g, g.root, new_id) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 2 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 2 end @testset "Predicate Assertions" begin @@ -78,13 +78,13 @@ end end g = EGraph(:(f(2, 1))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0 g = EGraph(:2) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 g = EGraph(:3) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0 new_id = addexpr!(g, :f) union!(g, g.root, new_id) @@ -92,7 +92,7 @@ end new_id = addexpr!(g, 4) union!(g, g.root, new_id) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 end @@ -100,18 +100,18 @@ end g = EGraph(:(f(2, 1))) r = @rule f(2, ~a) --> true - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 - @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0 - @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 + @test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0 + @test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0 r = @rule f(~a, ~a) --> true - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0 g = EGraph(:(f(2, 2))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 g = EGraph(:(f(h(3, 4), h(3, 4)))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 end From dd76de604daec18ec5da3fd8d7159e7e67e51c1e Mon Sep 17 00:00:00 2001 From: Owen Lynch Date: Wed, 21 Aug 2024 11:37:41 -0700 Subject: [PATCH 02/10] addressed review comments --- src/EGraphs/saturation.jl | 5 +++-- src/Rules.jl | 24 +++++++++++++++--------- 2 files changed, 18 insertions(+), 11 deletions(-) diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index e90c46f5..f5a4241e 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -73,6 +73,7 @@ function eqsat_search!( @debug "SEARCHING" + stack = get_local_stack() for (rule_idx, rule) in enumerate(theory) prev_matches = n_matches @timeit report.to string(rule_idx) begin @@ -87,14 +88,14 @@ function eqsat_search!( for i in ids_left cansearch(scheduler, rule_idx, i) || continue - n_matches += rule.ematcher_left!(g, rule_idx, i, get_local_stack(), ematch_buffer) + n_matches += rule.ematcher_left!(g, rule_idx, i, stack, ematch_buffer) inform!(scheduler, rule_idx, i, n_matches) end if is_bidirectional(rule) for i in ids_right cansearch(scheduler, rule_idx, i) || continue - n_matches += rule.ematcher_right!(g, rule_idx, i, get_local_stack(), ematch_buffer) + n_matches += rule.ematcher_right!(g, rule_idx, i, stack, ematch_buffer) inform!(scheduler, rule_idx, i, n_matches) end end diff --git a/src/Rules.jl b/src/Rules.jl index 5ce97a49..7d3c0850 100644 --- a/src/Rules.jl +++ b/src/Rules.jl @@ -29,14 +29,14 @@ Rules defined as with the --> are called *directed rewrite* rules. Application of a *directed rewrite* rule is a replacement of the `left` pattern with the `right` substitution, with the correct instantiation -of pattern variables. +of pattern variables. ```julia @rule ~a * ~b --> ~b * ~a ``` -An *equational rule* is a symbolic substitution rule with operator `==` that -can be rewritten bidirectionally. Therefore, it can only be used +An *equational rule* is a symbolic substitution rule with operator `==` that +can be rewritten bidirectionally. Therefore, it can only be used with the EGraphs backend. ```julia @@ -45,7 +45,7 @@ with the EGraphs backend. Rules defined with the `!=` act as *anti*-rules for checking contradictions in e-graph rewriting. If two terms, corresponding to the left and right hand side of an -*anti-rule* are found in an `EGraph`, saturation is halted immediately. +*anti-rule* are found in an `EGraph`, saturation is halted immediately. ```julia !a != a @@ -77,8 +77,14 @@ Base.@kwdef struct RewriteRule{Op<:Function} rhs_original = nothing end -# Modeled off https://github.com/JuliaLang/julia/blob/bc4b2e848400764e389c825b57d1481ed76f4d85/stdlib/Random/src/RNGs.jl const THREAD_STACKS = OptBuffer{UInt16}[] +""" +Retrieve the per-thread stack thread used for program counters in matching. + +We need a stack for each thread so that multithreading works correctly. + +Modeled off [Julia's global RNG](https://github.com/JuliaLang/julia/blob/bc4b2e848400764e389c825b57d1481ed76f4d85/stdlib/Random/src/RNGs.jl) +""" @inline get_local_stack() = get_local_stack(Threads.threadid()) @noinline function get_local_stack(tid::Int) @assert 0 < tid <= length(THREAD_STACKS) @@ -184,10 +190,10 @@ function Base.inv(r::RewriteRule) end """ -Turns an EqualityRule into a DirectedRule. For example, +Turns an EqualityRule into a DirectedRule. For example, ```julia -direct(@rule f(~x) == g(~x)) == f(~x) --> g(~x) +direct(@rule f(~x) == g(~x)) == f(~x) --> g(~x) ``` """ function direct(r::EqualityRule) @@ -195,10 +201,10 @@ function direct(r::EqualityRule) end """ -Turns an EqualityRule into a DirectedRule, but right to left. For example, +Turns an EqualityRule into a DirectedRule, but right to left. For example, ```julia -direct(@rule f(~x) == g(~x)) == g(~x) --> f(~x) +direct(@rule f(~x) == g(~x)) == g(~x) --> f(~x) ``` """ direct_right_to_left(r::EqualityRule) = inv(direct(r)) From f4fc6308228770664b61ee82def90baae02f2365 Mon Sep 17 00:00:00 2001 From: Owen Lynch Date: Wed, 28 Aug 2024 15:52:42 -0700 Subject: [PATCH 03/10] ci test commit --- NEWS.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 623e2ab3..0c49b75a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # 3.0 - Updated TermInterface to 1.0.1 +- Use a custom per-thread stack for e-matching # 2.0 - No longer dispatch against types, but instead dispatch against objects. @@ -30,4 +31,4 @@ Metatheory.jl + SymbolicUtils.jl = ❤️ - Removed `@metatheory_init` - Rules now support type and function predicates as in SymbolicUtils.jl - Redesigned the library -- Introduced `@timerewrite` to time the execution of classical rewriting systems. \ No newline at end of file +- Introduced `@timerewrite` to time the execution of classical rewriting systems. From ca867735988b909fa401de2556fd1ad747f805cf Mon Sep 17 00:00:00 2001 From: a Date: Thu, 12 Sep 2024 10:44:14 +0200 Subject: [PATCH 04/10] change target --- .github/workflows/benchmark_pr.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/benchmark_pr.yml b/.github/workflows/benchmark_pr.yml index 587fb8d4..223d456c 100644 --- a/.github/workflows/benchmark_pr.yml +++ b/.github/workflows/benchmark_pr.yml @@ -47,7 +47,7 @@ jobs: ls -l ~/.julia/bin mkdir results benchpkg Metatheory \ - --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" \ + --rev="${{github.event.pull_request.base.sha}},${{github.event.pull_request.head.sha}}" \ --url=${{ github.event.repository.clone_url }} \ --bench-on="${{github.event.pull_request.head.sha}}" \ --output-dir=results/ --tune @@ -68,7 +68,7 @@ jobs: - name: Create markdown table from benchmarks run: | julia --project=egg-benchmark/scripts egg-benchmark/scripts/load_results.jl \ - -b ${{github.event.pull_request.head.sha}} -b "${{github.event.repository.default_branch}}" \ + -b ${{github.event.pull_request.head.sha}} -b "${{github.event.pull_request.base.sha}}" \ --mt-results=results/ \ --egg-results=egg-benchmark/target/criterion \ -o table.md From 111905fe912e003a2a974e42f3c92f76bcb8bfd8 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sat, 21 Sep 2024 10:53:43 +0200 Subject: [PATCH 05/10] Add an unit test for race conditions in saturation. --- test/egraphs/concurrency.jl | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 test/egraphs/concurrency.jl diff --git a/test/egraphs/concurrency.jl b/test/egraphs/concurrency.jl new file mode 100644 index 00000000..8dbe07bc --- /dev/null +++ b/test/egraphs/concurrency.jl @@ -0,0 +1,28 @@ +# MT currently does not support thread-parallel saturation with a shared state. +# But it should be possible to saturate independent egraphs withing separate threads. + +using Test, Metatheory + +function run_eq() + theory = @theory a b c begin + a + b == b + a + a + (b + c) == (a + b) + c + end + + g = EGraph{Expr}(:(1 + (2 + (3 + (4 + (5 + 6)))))); + saturate!(g, theory, SaturationParams(timeout=100)) + end + + function test_threads() + @assert Threads.threadpoolsize() > 1 # this test is only useful in multi-threaded scenarios. + + # run equality saturation in parallel threads (no shared state) + Threads.@threads for _ in 1:1000 + run_eq() + end + true + end + +@testset "Concurrency" begin + @test test_threads() broken=true +end \ No newline at end of file From 57ad9770e39585cc6751cbb6adba97d0537e20c4 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sat, 21 Sep 2024 11:03:12 +0200 Subject: [PATCH 06/10] Allow test to fail. --- test/egraphs/concurrency.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/egraphs/concurrency.jl b/test/egraphs/concurrency.jl index 8dbe07bc..77294baa 100644 --- a/test/egraphs/concurrency.jl +++ b/test/egraphs/concurrency.jl @@ -24,5 +24,5 @@ function run_eq() end @testset "Concurrency" begin - @test test_threads() broken=true + @test test_threads() end \ No newline at end of file From d8d63aa497cc6e8eae8b5d9dc0146f160ebb8eda Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sat, 21 Sep 2024 11:06:28 +0200 Subject: [PATCH 07/10] Set number of threads for runtests --- .github/workflows/ci.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 125f32f3..b35c9bd0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,6 +38,8 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + JULIA_NUM_THREADS: 2 - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v4 with: From 5398503b9f77da44a6ac104f0b9d14510598e774 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sat, 21 Sep 2024 11:13:00 +0200 Subject: [PATCH 08/10] Update the copy of the pattern enode in instantiate_enode!() --- src/EGraphs/saturation.jl | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index db8f5643..8afb219a 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -124,21 +124,21 @@ end instantiate_enode!(bindings, @nospecialize(g::EGraph), p::PatVar)::Id = v_pair_first(bindings[p.idx]) function instantiate_enode!(bindings, g::EGraph{ExpressionType}, p::PatExpr)::Id where {ExpressionType} add_constant_hashed!(g, p.head, p.head_hash) - + n = copy(p.n) for i in v_children_range(p.n) - @inbounds p.n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH]) + @inbounds n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH]) end - add!(g, p.n, true) + add!(g, n, false) end function instantiate_enode!(bindings, g::EGraph{Expr}, p::PatExpr)::Id add_constant_hashed!(g, p.quoted_head, p.quoted_head_hash) v_set_head!(p.n, p.quoted_head_hash) - + n = copy(p.n) for i in v_children_range(p.n) - @inbounds p.n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH]) + @inbounds n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH]) end - add!(g, p.n, true) + add!(g, n, false) end """ From 8289f6188bdb38deff075f97fd3eaac6b7dd44d6 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sat, 21 Sep 2024 11:21:31 +0200 Subject: [PATCH 09/10] Use nthreads() instead of threadpoolsize() because the later was only introduced in v1.9 --- test/egraphs/concurrency.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/egraphs/concurrency.jl b/test/egraphs/concurrency.jl index 77294baa..307538be 100644 --- a/test/egraphs/concurrency.jl +++ b/test/egraphs/concurrency.jl @@ -14,7 +14,7 @@ function run_eq() end function test_threads() - @assert Threads.threadpoolsize() > 1 # this test is only useful in multi-threaded scenarios. + @assert Threads.nthreads() > 1 # this test is only useful in multi-threaded scenarios. # run equality saturation in parallel threads (no shared state) Threads.@threads for _ in 1:1000 From 52b48461e7365a70050d4998ed9888e403003641 Mon Sep 17 00:00:00 2001 From: Gabriel Kronberger Date: Sat, 21 Sep 2024 11:31:06 +0200 Subject: [PATCH 10/10] Try to increase number of threads to increase the probability of fail. --- .github/workflows/ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b35c9bd0..5c100471 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,7 +39,7 @@ jobs: - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 env: - JULIA_NUM_THREADS: 2 + JULIA_NUM_THREADS: 20 - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v4 with: