Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add specialization for subtraction and addition with negative terms #549

Merged
merged 3 commits into from
Oct 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 51 additions & 32 deletions aesara/tensor/rewriting/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1788,6 +1788,53 @@ def local_neg_div_neg(fgraph, node):
return [true_div(new_num, denom)]


@register_canonicalize
@register_specialize
@node_rewriter([sub])
def local_sub_neg_to_add(fgraph, node):
"""
x - (-y) -> x + y

"""
if node.op == sub:
minuend, subtrahend = node.inputs

if subtrahend.owner:
if subtrahend.owner.op == neg:
pre_neg = subtrahend.owner.inputs[0]
new_out = add(minuend, pre_neg)
return [new_out]


@register_specialize
@node_rewriter([add])
def local_add_neg_to_sub(fgraph, node):
"""
-x + y -> y - x
x + (-y) -> x - y

"""
# This rewrite is only registered during specialization, because the
# `local_neg_to_mul` rewrite modifies the relevant pattern during canonicalization

# Rewrite is only applicable when there are two inputs to add
if node.op == add and len(node.inputs) == 2:

# Look for pattern with either input order
for first, second in (node.inputs, reversed(node.inputs)):
if second.owner:
if second.owner.op == neg:
pre_neg = second.owner.inputs[0]
new_out = sub(first, pre_neg)
return [new_out]

# Check if it is a negative constant
const = get_constant(second)
if const is not None and const < 0:
new_out = sub(first, np.abs(const))
return [new_out]


@register_canonicalize
@node_rewriter([mul])
def local_mul_zero(fgraph, node):
Expand Down Expand Up @@ -2528,8 +2575,6 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_one_plus_erf)
register_specialize(local_one_plus_erf)

# Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erf(x)=>erfc(x)
local_one_minus_erf = PatternNodeRewriter(
(sub, 1, (erf, "x")),
Expand All @@ -2543,21 +2588,9 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_one_minus_erf)
register_specialize(local_one_minus_erf)

local_one_minus_erf2 = PatternNodeRewriter(
(add, 1, (neg, (erf, "x"))),
(erfc, "x"),
allow_multiple_clients=True,
name="local_one_minus_erf2",
tracks=[erf],
get_nodes=get_clients_at_depth2,
)
register_canonicalize(local_one_minus_erf2)
register_stabilize(local_one_minus_erf2)
register_specialize(local_one_minus_erf2)

# (-1)+erf(x) => -erfc(x)
# There is no need for erf(x)+(-1) nor erf(x) - 1, as the canonicalize will
# convert those to the matched pattern
# There is no need for erf(x)+(-1) nor erf(x) - 1, as the `local_add_mul`
# canonicalize will convert those to the matched pattern
local_erf_minus_one = PatternNodeRewriter(
(add, -1, (erf, "x")),
(neg, (erfc, "x")),
Expand All @@ -2570,8 +2603,6 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_erf_minus_one)
register_specialize(local_erf_minus_one)

# Only one of the two rewrites below is needed if a canonicalization is added
# for sub(x, y) -> add(x, -y) or a specialization for add(x, -y) -> sub(x, y)
# 1-erfc(x) => erf(x)
local_one_minus_erfc = PatternNodeRewriter(
(sub, 1, (erfc, "x")),
Expand All @@ -2585,21 +2616,9 @@ def local_greedy_distributor(fgraph, node):
register_stabilize(local_one_minus_erfc)
register_specialize(local_one_minus_erfc)

local_one_minus_erfc2 = PatternNodeRewriter(
(add, 1, (neg, (erfc, "x"))),
(erf, "x"),
allow_multiple_clients=True,
name="local_one_minus_erfc2",
tracks=[erfc],
get_nodes=get_clients_at_depth2,
)
register_canonicalize(local_one_minus_erfc2)
register_stabilize(local_one_minus_erfc2)
register_specialize(local_one_minus_erfc2)

# (-1)+erfc(-x)=>erf(x)
# erfc(-x)-1=>erf(x)
local_erf_neg_minus_one = PatternNodeRewriter(
(add, -1, (erfc, (neg, "x"))),
(sub, (erfc, (neg, "x")), 1),
(erf, "x"),
allow_multiple_clients=True,
name="local_erf_neg_minus_one",
Expand Down
85 changes: 82 additions & 3 deletions tests/tensor/rewriting/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -4042,9 +4042,14 @@ def test_local_expm1():
for n in h.maker.fgraph.toposort()
)

assert not any(
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, aes.basic.Expm1)
for n in r.maker.fgraph.toposort()
# This rewrite works when `local_add_neg_to_sub` specialization rewrite is invoked
expect_rewrite = config.mode != "FAST_COMPILE"
assert (
any(
isinstance(n.op, Elemwise) and isinstance(n.op.scalar_op, aes.basic.Expm1)
for n in r.maker.fgraph.toposort()
)
== expect_rewrite
)


Expand Down Expand Up @@ -4618,3 +4623,77 @@ def test_deprecations():
"""Make sure we can import from deprecated modules."""
with pytest.deprecated_call():
from aesara.tensor.math_opt import AlgebraicCanonizer # noqa: F401 F811


def test_local_sub_neg_to_add():
x = scalar("x")
y = vector("y")

f = function([x, y], x - (-y), mode=Mode("py"))

nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle)
]
assert nodes == [at.add]

x_test = np.full((), 1.0, dtype=config.floatX)
y_test = np.full(5, 2.0, dtype=config.floatX)
assert np.allclose(f(x_test, y_test), x_test - (-y_test))


def test_local_sub_neg_to_add_const():
# This rewrite is achieved by the local_add_canonizer
x = vector("x")
const = 5.0

f = function([x], x - (-const), mode=Mode("py"))

nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle)
]
assert nodes == [at.add]

x_test = np.array([3, 4], dtype=config.floatX)
assert np.allclose(f(x_test), x_test - (-const))


@pytest.mark.parametrize("first_negative", (True, False))
def test_local_add_neg_to_sub(first_negative):
x = scalar("x")
y = vector("y")
out = -x + y if first_negative else x + (-y)

f = function([x, y], out, mode=Mode("py"))

nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle)
]
assert nodes == [at.sub]

x_test = np.full((), 1.0, dtype=config.floatX)
y_test = np.full(5, 2.0, dtype=config.floatX)
exp = -x_test + y_test if first_negative else x_test + (-y_test)
assert np.allclose(f(x_test, y_test), exp)


def test_local_add_neg_to_sub_const():
x = vector("x")
const = 5.0

f = function([x], x + (-const), mode=Mode("py"))

nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle)
]
assert nodes == [at.sub]

x_test = np.array([3, 4], dtype=config.floatX)
assert np.allclose(f(x_test), x_test + (-const))