diff --git a/aesara/tensor/math_opt.py b/aesara/tensor/math_opt.py index fcf66e8f96..ffe3c3964f 100644 --- a/aesara/tensor/math_opt.py +++ b/aesara/tensor/math_opt.py @@ -1888,6 +1888,36 @@ def local_sub_neg_to_add(fgraph, node): return [new_out] +@register_canonicalize +@local_optimizer([add]) +def local_add_neg_to_sub(fgraph, node): + """ + -x + y -> y - x + x + (-y) -> x - y + """ + + # Rewrite is only applicable when there are two inputs to add + if node.op == add and len(node.inputs) == 2: + node.outputs[0].dtype + + # Look for pattern with either input order + for first, second in (node.inputs, reversed(node.inputs)): + if second.owner: + if second.owner.op == neg: + pre_neg = second.owner.inputs[0] + new_out = sub(first, pre_neg) + return [new_out] + + # Check if negation is hidden behind a DimShuffle + elif isinstance(second.owner.op, DimShuffle): + dimshuffle_op = second.owner.op + dimshuffle_input = second.owner.inputs[0] + if dimshuffle_input.owner and dimshuffle_input.owner.op == neg: + pre_neg = dimshuffle_input.owner.inputs[0] + new_out = sub(first, dimshuffle_op(pre_neg)) + return [new_out] + + @local_optimizer([mul]) def local_mul_zero(fgraph, node): """ diff --git a/tests/tensor/test_math_opt.py b/tests/tensor/test_math_opt.py index 627186a81f..3e0885621a 100644 --- a/tests/tensor/test_math_opt.py +++ b/tests/tensor/test_math_opt.py @@ -4566,3 +4566,36 @@ def test_local_sub_neg_to_add(x, y, x_test, y_test): ] assert nodes == [aet.add] assert np.allclose(f(x_test, y_test), x_test - (-y_test)) + + +@pytest.mark.parametrize( + "x, x_test", + [ + (scalar(), np.full((), 1.0, dtype=config.floatX)), + (vector(), np.full(1, 2.0, dtype=config.floatX)), + (matrix(), np.full((2, 2), 3.0, dtype=config.floatX)), + ], +) +@pytest.mark.parametrize( + "y, y_test", + [ + (scalar(), np.full((), 1.0, dtype=config.floatX)), + (vector(), np.full(1, 2.0, dtype=config.floatX)), + (matrix(), np.full((2, 2), 3.0, dtype=config.floatX)), + ], +) +@pytest.mark.parametrize("first_negative", (True, False)) +def test_local_add_neg_to_sub(x, y, x_test, y_test, first_negative): + mode = Mode("py").including("canonicalize") + + out = -x + y if first_negative else x + (-y) + f = function([x, y], out, mode=mode) + + nodes = [ + node.op + for node in f.maker.fgraph.toposort() + if not isinstance(node.op, DimShuffle) + ] + assert nodes == [aet.sub] + exp = -x_test + y_test if first_negative else x_test + (-y_test) + assert np.allclose(f(x_test, y_test), exp)