Skip to content

Commit

Permalink
Remove redundant erf(c) rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo authored and rlouf committed Oct 18, 2022
1 parent 9f4650f commit 1846741
Showing 1 changed file with 4 additions and 32 deletions.
36 changes: 4 additions & 32 deletions aesara/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -2575,8 +2575,6 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_one_plus_erf)
register_specialize(local_one_plus_erf)

# Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erf(x)=>erfc(x)
local_one_minus_erf = PatternNodeRewriter(
(sub, 1, (erf, "x")),
Expand All @@ -2590,21 +2588,9 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_one_minus_erf)
register_specialize(local_one_minus_erf)

local_one_minus_erf2 = PatternNodeRewriter(
(add, 1, (neg, (erf, "x"))),
(erfc, "x"),
allow_multiple_clients=True,
name="local_one_minus_erf2",
tracks=[erf],
get_nodes=get_clients_at_depth2,
)
register_canonicalize(local_one_minus_erf2)
register_stabilize(local_one_minus_erf2)
register_specialize(local_one_minus_erf2)

# (-1)+erf(x) => -erfc(x)
# There is no need for erf(x)+(-1) nor erf(x) - 1, as the canonicalize will
# convert those to the matched pattern
# There is no need for erf(x)+(-1) nor erf(x) - 1, as the `local_add_mul`
# canonicalize will convert those to the matched pattern
local_erf_minus_one = PatternNodeRewriter(
(add, -1, (erf, "x")),
(neg, (erfc, "x")),
Expand All @@ -2617,8 +2603,6 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_erf_minus_one)
register_specialize(local_erf_minus_one)

# Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erfc(x) => erf(x)
local_one_minus_erfc = PatternNodeRewriter(
(sub, 1, (erfc, "x")),
Expand All @@ -2632,21 +2616,9 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_one_minus_erfc)
register_specialize(local_one_minus_erfc)

local_one_minus_erfc2 = PatternNodeRewriter(
(add, 1, (neg, (erfc, "x"))),
(erf, "x"),
allow_multiple_clients=True,
name="local_one_minus_erfc2",
tracks=[erfc],
get_nodes=get_clients_at_depth2,
)
register_canonicalize(local_one_minus_erfc2)
register_stabilize(local_one_minus_erfc2)
register_specialize(local_one_minus_erfc2)

# (-1)+erfc(-x)=>erf(x)
# erfc(-x)-1=>erf(x)
local_erf_neg_minus_one = PatternNodeRewriter(
(add, -1, (erfc, (neg, "x"))),
(sub, (erfc, (neg, "x")), 1),
(erf, "x"),
allow_multiple_clients=True,
name="local_erf_neg_minus_one",
Expand Down

0 comments on commit 1846741

Please sign in to comment.