Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multithreading Support #231

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# 3.0
- Updated TermInterface to 1.0.1
- Use a custom per-thread stack for e-matching

# 2.0
- No longer dispatch against types, but instead dispatch against objects.
Expand Down Expand Up @@ -30,4 +31,4 @@ Metatheory.jl + SymbolicUtils.jl = ❤️
- Removed `@metatheory_init`
- Rules now support type and function predicates as in SymbolicUtils.jl
- Redesigned the library
- Introduced `@timerewrite` to time the execution of classical rewriting systems.
- Introduced `@timerewrite` to time the execution of classical rewriting systems.
5 changes: 3 additions & 2 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ function eqsat_search!(


@debug "SEARCHING"
stack = get_local_stack()
for (rule_idx, rule) in enumerate(theory)
prev_matches = n_matches
@timeit report.to string(rule_idx) begin
Expand All @@ -87,14 +88,14 @@ function eqsat_search!(

for i in ids_left
cansearch(scheduler, rule_idx, i) || continue
n_matches += rule.ematcher_left!(g, rule_idx, i, rule.stack, ematch_buffer)
n_matches += rule.ematcher_left!(g, rule_idx, i, stack, ematch_buffer)
inform!(scheduler, rule_idx, i, n_matches)
end

if is_bidirectional(rule)
for i in ids_right
cansearch(scheduler, rule_idx, i) || continue
n_matches += rule.ematcher_right!(g, rule_idx, i, rule.stack, ematch_buffer)
n_matches += rule.ematcher_right!(g, rule_idx, i, stack, ematch_buffer)
inform!(scheduler, rule_idx, i, n_matches)
end
end
Expand Down
49 changes: 37 additions & 12 deletions src/Rules.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Rules

using Base.Threads
using TermInterface
using AutoHashEquals
using Metatheory.Patterns
Expand All @@ -16,7 +17,8 @@ export RewriteRule,
Theory,
direct,
direct_left_to_right,
direct_right_to_left
direct_right_to_left,
get_local_stack

const STACK_SIZE = 512

Expand All @@ -27,14 +29,14 @@ Rules defined as with the --> are
called *directed rewrite* rules. Application of a *directed rewrite* rule
is a replacement of the `left` pattern with
the `right` substitution, with the correct instantiation
of pattern variables.
of pattern variables.

```julia
@rule ~a * ~b --> ~b * ~a
```

An *equational rule* is a symbolic substitution rule with operator `==` that
can be rewritten bidirectionally. Therefore, it can only be used
An *equational rule* is a symbolic substitution rule with operator `==` that
can be rewritten bidirectionally. Therefore, it can only be used
with the EGraphs backend.

```julia
Expand All @@ -43,7 +45,7 @@ with the EGraphs backend.

Rules defined with the `!=` act as *anti*-rules for checking contradictions in e-graph
rewriting. If two terms, corresponding to the left and right hand side of an
*anti-rule* are found in an `EGraph`, saturation is halted immediately.
*anti-rule* are found in an `EGraph`, saturation is halted immediately.

```julia
!a != a
Expand Down Expand Up @@ -71,11 +73,34 @@ Base.@kwdef struct RewriteRule{Op<:Function}
ematcher_right!::Union{Nothing,Function} = nothing
matcher_left::Function
matcher_right::Union{Nothing,Function} = nothing
stack::OptBuffer{UInt16} = OptBuffer{UInt16}(STACK_SIZE)
lhs_original = nothing
rhs_original = nothing
end

const THREAD_STACKS = OptBuffer{UInt16}[]
"""
Retrieve the per-thread stack thread used for program counters in matching.

We need a stack for each thread so that multithreading works correctly.

Modeled off [Julia's global RNG](https://github.com/JuliaLang/julia/blob/bc4b2e848400764e389c825b57d1481ed76f4d85/stdlib/Random/src/RNGs.jl)
"""
@inline get_local_stack() = get_local_stack(Threads.threadid())
0x0f0f0f marked this conversation as resolved.
Show resolved Hide resolved
@noinline function get_local_stack(tid::Int)
@assert 0 < tid <= length(THREAD_STACKS)
if @inbounds isassigned(THREAD_STACKS, tid)
@inbounds stack = THREAD_STACKS[tid]
else
stack = OptBuffer{UInt16}(STACK_SIZE)
@inbounds THREAD_STACKS[tid] = stack
end
return stack
end

function __init__()
resize!(empty!(THREAD_STACKS), Threads.nthreads())
end

function --> end
const DirectedRule = RewriteRule{typeof(-->)}
const EqualityRule = RewriteRule{typeof(==)}
Expand All @@ -99,8 +124,8 @@ function Base.show(io::IO, r::RewriteRule)
end


(r::DirectedRule)(term) = r.matcher_left(term, (bindings...) -> instantiate(term, r.right, bindings), r.stack)
(r::DynamicRule)(term) = r.matcher_left(term, (bindings...) -> r.right(term, nothing, bindings...), r.stack)
(r::DirectedRule)(term) = r.matcher_left(term, (bindings...) -> instantiate(term, r.right, bindings), get_local_stack())
(r::DynamicRule)(term) = r.matcher_left(term, (bindings...) -> r.right(term, nothing, bindings...), get_local_stack())

# ---------------------
# Theories
Expand Down Expand Up @@ -165,21 +190,21 @@ function Base.inv(r::RewriteRule)
end

"""
Turns an EqualityRule into a DirectedRule. For example,
Turns an EqualityRule into a DirectedRule. For example,

```julia
direct(@rule f(~x) == g(~x)) == f(~x) --> g(~x)
direct(@rule f(~x) == g(~x)) == f(~x) --> g(~x)
```
"""
function direct(r::EqualityRule)
RewriteRule(r.name, -->, (getfield(r, k) for k in fieldnames(DirectedRule)[3:end])...)
end

"""
Turns an EqualityRule into a DirectedRule, but right to left. For example,
Turns an EqualityRule into a DirectedRule, but right to left. For example,

```julia
direct(@rule f(~x) == g(~x)) == g(~x) --> f(~x)
direct(@rule f(~x) == g(~x)) == g(~x) --> f(~x)
```
"""
direct_right_to_left(r::EqualityRule) = inv(direct(r))
Expand Down
52 changes: 26 additions & 26 deletions test/egraphs/ematch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,56 +13,56 @@ b = OptBuffer{UInt128}(10)
r = @rule 2 --> true
g = EGraph(2)

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
0x0f0f0f marked this conversation as resolved.
Show resolved Hide resolved
end

@testset "Composite Ground Terms" begin
r = @rule f(2, 3) --> true
g = EGraph(:(f(2, 3)))

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0
@test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
@test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0
@test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0

g = EGraph(:(f(2, 4)))

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0
@test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0
@test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0
@test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0
@test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0


r = @rule f(2, h(3, 4)) --> true
g = EGraph(:(f(2, h(3, 4))))

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0
@test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
@test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0
@test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0
end

@testset "Pattern Variables" begin
g = EGraph(:(f(2, 1)))
r = @rule ~a --> true

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 1
@test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
@test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 1
@test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 1
end

@testset "Type Assertions" begin
r = @rule ~a::Int --> true
g = EGraph(:(f(2, 1)))
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0

g = EGraph(:3)
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1

new_id = addexpr!(g, :f)
union!(g, g.root, new_id)

new_id = addexpr!(g, 4)
union!(g, g.root, new_id)

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 2
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 2
end

@testset "Predicate Assertions" begin
Expand All @@ -78,40 +78,40 @@ end
end

g = EGraph(:(f(2, 1)))
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0

g = EGraph(:2)
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1

g = EGraph(:3)
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0

new_id = addexpr!(g, :f)
union!(g, g.root, new_id)

new_id = addexpr!(g, 4)
union!(g, g.root, new_id)

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
end


@testset "Non-Ground Terms" begin
g = EGraph(:(f(2, 1)))
r = @rule f(2, ~a) --> true

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0
@test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
@test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0
@test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0

r = @rule f(~a, ~a) --> true
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0

g = EGraph(:(f(2, 2)))
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1

g = EGraph(:(f(h(3, 4), h(3, 4))))
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
end


Expand Down
Loading