Skip to content

Commit

Permalink
Remove redundant erf(c) rewrites
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Mar 13, 2022
1 parent eacaa55 commit aeaf64d
Showing 1 changed file with 4 additions and 32 deletions.
36 changes: 4 additions & 32 deletions aesara/tensor/math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2556,8 +2556,6 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_one_plus_erf)
register_specialize(local_one_plus_erf)

# Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erf(x)=>erfc(x)
local_one_minus_erf = PatternSub(
(sub, 1, (erf, "x")),
Expand All @@ -2571,21 +2569,9 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_one_minus_erf)
register_specialize(local_one_minus_erf)

local_one_minus_erf2 = PatternSub(
(add, 1, (neg, (erf, "x"))),
(erfc, "x"),
allow_multiple_clients=True,
name="local_one_minus_erf2",
tracks=[erf],
get_nodes=get_clients_at_depth2,
)
register_canonicalize(local_one_minus_erf2)
register_stabilize(local_one_minus_erf2)
register_specialize(local_one_minus_erf2)

# (-1)+erf(x) => -erfc(x)
# There is no need for erf(x)+(-1) nor erf(x) - 1, as the canonicalize will
# convert those to the matched pattern
# There is no need for erf(x)+(-1) nor erf(x) - 1, as the local_add_mul canonicalize
# will convert those to the matched pattern
local_erf_minus_one = PatternSub(
(add, -1, (erf, "x")),
(neg, (erfc, "x")),
Expand All @@ -2598,8 +2584,6 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_erf_minus_one)
register_specialize(local_erf_minus_one)

# Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erfc(x) => erf(x)
local_one_minus_erfc = PatternSub(
(sub, 1, (erfc, "x")),
Expand All @@ -2613,21 +2597,9 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_one_minus_erfc)
register_specialize(local_one_minus_erfc)

local_one_minus_erfc2 = PatternSub(
(add, 1, (neg, (erfc, "x"))),
(erf, "x"),
allow_multiple_clients=True,
name="local_one_minus_erfc2",
tracks=[erfc],
get_nodes=get_clients_at_depth2,
)
register_canonicalize(local_one_minus_erfc2)
register_stabilize(local_one_minus_erfc2)
register_specialize(local_one_minus_erfc2)

# (-1)+erfc(-x)=>erf(x)
# erfc(-x)-1=>erf(x)
local_erf_neg_minus_one = PatternSub(
(add, -1, (erfc, (neg, "x"))),
(sub, (erfc, (neg, "x")), 1),
(erf, "x"),
allow_multiple_clients=True,
name="local_erf_neg_minus_one",
Expand Down

0 comments on commit aeaf64d

Please sign in to comment.