Skip to content

Commit

Permalink
Add canonicalization for addition with negation x + (-y) -> x - y
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Aug 9, 2021
1 parent 8d03db2 commit 9e6b7e6
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
30 changes: 30 additions & 0 deletions aesara/tensor/math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1888,6 +1888,36 @@ def local_sub_neg_to_add(fgraph, node):
return [new_out]


@register_canonicalize
@local_optimizer([add])
def local_add_neg_to_sub(fgraph, node):
"""
-x + y -> y - x
x + (-y) -> x - y
"""

# Rewrite is only applicable when there are two inputs to add
if node.op == add and len(node.inputs) == 2:
node.outputs[0].dtype

# Look for pattern with either input order
for first, second in (node.inputs, reversed(node.inputs)):
if second.owner:
if second.owner.op == neg:
pre_neg = second.owner.inputs[0]
new_out = sub(first, pre_neg)
return [new_out]

# Check if negation is hidden behind a DimShuffle
elif isinstance(second.owner.op, DimShuffle):
dimshuffle_op = second.owner.op
dimshuffle_input = second.owner.inputs[0]
if dimshuffle_input.owner and dimshuffle_input.owner.op == neg:
pre_neg = dimshuffle_input.owner.inputs[0]
new_out = sub(first, dimshuffle_op(pre_neg))
return [new_out]


@local_optimizer([mul])
def local_mul_zero(fgraph, node):
"""
Expand Down
33 changes: 33 additions & 0 deletions tests/tensor/test_math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4566,3 +4566,36 @@ def test_local_sub_neg_to_add(x, y, x_test, y_test):
]
assert nodes == [aet.add]
assert np.allclose(f(x_test, y_test), x_test - (-y_test))


@pytest.mark.parametrize(
"x, x_test",
[
(scalar(), np.full((), 1.0, dtype=config.floatX)),
(vector(), np.full(1, 2.0, dtype=config.floatX)),
(matrix(), np.full((2, 2), 3.0, dtype=config.floatX)),
],
)
@pytest.mark.parametrize(
"y, y_test",
[
(scalar(), np.full((), 1.0, dtype=config.floatX)),
(vector(), np.full(1, 2.0, dtype=config.floatX)),
(matrix(), np.full((2, 2), 3.0, dtype=config.floatX)),
],
)
@pytest.mark.parametrize("first_negative", (True, False))
def test_local_add_neg_to_sub(x, y, x_test, y_test, first_negative):
mode = Mode("py").including("canonicalize")

out = -x + y if first_negative else x + (-y)
f = function([x, y], out, mode=mode)

nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle)
]
assert nodes == [aet.sub]
exp = -x_test + y_test if first_negative else x_test + (-y_test)
assert np.allclose(f(x_test, y_test), exp)

0 comments on commit 9e6b7e6

Please sign in to comment.