Skip to content

Commit

Permalink
Add canonicalization for subtraction with negation x - (-y) -> x + y
Browse files Browse the repository at this point in the history
  • Loading branch information
Ricardo committed Aug 9, 2021
1 parent 36e2e1a commit 8d03db2
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 0 deletions.
26 changes: 26 additions & 0 deletions aesara/tensor/math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1862,6 +1862,32 @@ def local_neg_div_neg(fgraph, node):
return [true_div(new_num, denom)]


@register_canonicalize
@local_optimizer([sub])
def local_sub_neg_to_add(fgraph, node):
"""
x - (-y) -> x + y
"""

if node.op == sub:
minuend, subtrahend = node.inputs

if subtrahend.owner:
if subtrahend.owner.op == neg:
pre_neg = subtrahend.owner.inputs[0]
new_out = add(minuend, pre_neg)
return [new_out]

# Check if negation is hidden behind a DimShuffle
elif isinstance(subtrahend.owner.op, DimShuffle):
dimshuffle_op = subtrahend.owner.op
dimshuffle_input = subtrahend.owner.inputs[0]
if dimshuffle_input.owner and dimshuffle_input.owner.op == neg:
pre_neg = dimshuffle_input.owner.inputs[0]
new_out = add(minuend, dimshuffle_op(pre_neg))
return [new_out]


@local_optimizer([mul])
def local_mul_zero(fgraph, node):
"""
Expand Down
30 changes: 30 additions & 0 deletions tests/tensor/test_math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -4536,3 +4536,33 @@ def test_log1mexp_stabilization():
f(np.array([-0.8, -0.6], dtype=config.floatX)),
np.log(1 - np.exp([-0.8, -0.6])),
)


@pytest.mark.parametrize(
"x, x_test",
[
(scalar(), np.full((), 1.0, dtype=config.floatX)),
(vector(), np.full(1, 2.0, dtype=config.floatX)),
(matrix(), np.full((2, 2), 3.0, dtype=config.floatX)),
],
)
@pytest.mark.parametrize(
"y, y_test",
[
(scalar(), np.full((), 1.0, dtype=config.floatX)),
(vector(), np.full(1, 2.0, dtype=config.floatX)),
(matrix(), np.full((2, 2), 3.0, dtype=config.floatX)),
],
)
def test_local_sub_neg_to_add(x, y, x_test, y_test):
mode = Mode("py").including("canonicalize")

f = function([x, y], x - (-y), mode=mode)

nodes = [
node.op
for node in f.maker.fgraph.toposort()
if not isinstance(node.op, DimShuffle)
]
assert nodes == [aet.add]
assert np.allclose(f(x_test, y_test), x_test - (-y_test))

0 comments on commit 8d03db2

Please sign in to comment.