diff --git a/aesara/tensor/elemwise.py b/aesara/tensor/elemwise.py
index 9db605596d..dd3c25783f 100644
--- a/aesara/tensor/elemwise.py
+++ b/aesara/tensor/elemwise.py
@@ -16,7 +16,6 @@
 from aesara.misc.safe_asarray import _asarray
 from aesara.printing import FunctionPrinter, Printer, pprint
 from aesara.scalar import get_scalar_type
-from aesara.scalar.basic import ScalarType
 from aesara.scalar.basic import bool as scalar_bool
 from aesara.scalar.basic import identity as scalar_identity
 from aesara.scalar.basic import transfer_type, upcast
@@ -804,37 +803,17 @@ def perform(self, node, inputs, output_storage):
             storage[0] = variable
 
     def infer_shape(self, fgraph, node, i_shapes):
-        rval = []
-        for o in node.outputs:
-            oshp = []
-            for dim, b in enumerate(o.type.broadcastable):
-                b_dim = None
-                if b:
-                    # this is broadcastable
-                    b_dim = 1
-                else:
-                    # there must be some input that is not broadcastable in
-                    # dimension 'dim'
-                    for ishp, i in zip(i_shapes, node.inputs):
-                        if isinstance(i.type, ScalarType):
-                            continue  # we skip scalar
-                        if not i.type.broadcastable[dim]:
-                            # input i is not broadcastable in position dim
-                            # therefore if its shape is known, we can use it
-                            # as the output shape
-                            if ishp[dim]:
-                                b_dim = ishp[dim]
-                            break
-
-                # b_dim might still be None, if every input's shape was unknown
-                # in dimension 'dim'
-                oshp.append(b_dim)
-            # TODO: it would be interesting to return the constraining
-            # information that if one of the inputs shape[dim] is known
-            # and another input's shape[dim] is not, that we can now assume
-            # that the other input's shape[dim] is the same as the first.
-            rval.append(tuple(oshp))
-        return rval
+
+        if len(node.outputs) > 1:
+            from aesara.tensor.basic_opt import ShapeError
+
+            raise ShapeError(
+                "Multiple outputs are not supported by the default `Elemwise.infer_shape`"
+            )
+
+        out_shape = aesara.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
+
+        return [out_shape]
 
     def _c_all(self, node, nodename, inames, onames, sub):
         # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
diff --git a/aesara/tensor/extra_ops.py b/aesara/tensor/extra_ops.py
index a72d129f9e..c6a920e600 100644
--- a/aesara/tensor/extra_ops.py
+++ b/aesara/tensor/extra_ops.py
@@ -1,4 +1,5 @@
 from collections.abc import Collection
+from functools import reduce
 from typing import Iterable, Tuple, Union
 
 import numpy as np
@@ -6,6 +7,7 @@
 from numpy.core.multiarray import normalize_axis_index
 
 import aesara
+import aesara.scalar.basic as aes
 from aesara.gradient import (
     DisconnectedType,
     _float_zeros_like,
@@ -26,9 +28,7 @@
 from aesara.tensor.exceptions import NotScalarConstantError
 from aesara.tensor.math import abs as at_abs
 from aesara.tensor.math import all as at_all
-from aesara.tensor.math import eq, ge, lt
-from aesara.tensor.math import max as at_max
-from aesara.tensor.math import maximum, minimum, or_, prod
+from aesara.tensor.math import ge, lt, maximum, minimum, prod
 from aesara.tensor.math import sum as at_sum
 from aesara.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
 from aesara.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
@@ -1534,13 +1534,19 @@ def broadcast_shape_iter(
             result_dims.append(maybe_non_bcast_shapes[0])
             continue
 
-        non_bcast_vec = at.as_tensor(maybe_non_bcast_shapes)
-        non_bcast_vec = at.switch(eq(non_bcast_vec, 1), -one_at, non_bcast_vec)
-        dim_max = at_abs(at_max(non_bcast_vec))
+        non_bcast_vec = [
+            aes.switch(aes.eq(nbv, 1), -one_at, nbv)
+            for nbv in maybe_non_bcast_shapes
+        ]
+        dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
 
         assert_dim = Assert("Could not broadcast dimensions")
-        assert_cond = at_all(
-            or_(eq(non_bcast_vec, -one_at), eq(non_bcast_vec, dim_max))
+        assert_cond = reduce(
+            aes.and_,
+            (
+                aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max))
+                for nbv in non_bcast_vec
+            ),
         )
         bcast_dim = assert_dim(dim_max, assert_cond)
 
diff --git a/tests/tensor/test_basic_opt.py b/tests/tensor/test_basic_opt.py
index 3eebb5925d..58a097a311 100644
--- a/tests/tensor/test_basic_opt.py
+++ b/tests/tensor/test_basic_opt.py
@@ -2890,10 +2890,10 @@ def test_infer_shape(self):
         self._compile_and_check([admat], [Shape_i(1)(admat)], [admat_val], Shape_i)
 
 
-class TestShapeFeature:
+class TestSameShape:
     def test_scalar(self):
         x = scalar()
-        cst = at.constant(1).clone()
+        cst = at.constant(1)
         o = x + cst
         fgraph = FunctionGraph([x], [o], clone=False)
         shape_feature = ShapeFeature()
@@ -2902,34 +2902,42 @@ def test_scalar(self):
 
     def test_vector(self):
         x = vector()
-        cst = at.constant(1).clone()
+        cst = at.constant(1)
         o = x + cst
         fgraph = FunctionGraph([x], [o], clone=False)
         shape_feature = ShapeFeature()
         fgraph.attach_feature(shape_feature)
         assert shape_feature.same_shape(x, o)
 
-    def test_vector2(self):
+    def test_no_static_shapes(self):
         x = vector()
         y = vector()
         o = x + y
         fgraph = FunctionGraph([x, y], [o], clone=False)
         shape_feature = ShapeFeature()
         fgraph.attach_feature(shape_feature)
-        assert shape_feature.same_shape(x, o)
+        # We no longer assume that `x` has the same shape as `y` simply because
+        # neither has static shape information. Instead, when no static shape
+        # information is available, we assume that `x` and/or `y` could have
+        # shapes `(1,)` and/or `(n,)`, where `n != 1`, or any combination of
+        # the two.
+        assert not shape_feature.same_shape(x, o)
         # The following case isn't implemented
         assert not shape_feature.same_shape(y, o)
 
-    def test_vector_dim(self):
-        x = vector()
-        y = vector()
+    @pytest.mark.parametrize(
+        "y_dim_0",
+        [2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))],
+    )
+    def test_vector_dim(self, y_dim_0):
+        x = at.tensor(dtype="floatX", shape=(2, None))
+        y = at.tensor(dtype="floatX", shape=(y_dim_0, None))
         o = x + y
         fgraph = FunctionGraph([x, y], [o], clone=False)
         shape_feature = ShapeFeature()
         fgraph.attach_feature(shape_feature)
         assert shape_feature.same_shape(x, o, 0, 0)
-        # The following case isn't implemented
-        assert not shape_feature.same_shape(y, o, 0, 0)
+        assert not shape_feature.same_shape(x, o, 1, 1)
 
     def test_vector_dim_err(self):
         x = vector()
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index 1b298a92a8..0cc6e8658c 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -11,12 +11,13 @@
 import tests.unittest_tools as utt
 from aesara.compile.mode import Mode
 from aesara.configdefaults import config
-from aesara.graph.basic import Variable
+from aesara.graph.basic import Apply, Variable
 from aesara.graph.fg import FunctionGraph
 from aesara.link.basic import PerformLinker
 from aesara.link.c.basic import CLinker, OpWiseCLinker
 from aesara.tensor import as_tensor_variable
 from aesara.tensor.basic import second
+from aesara.tensor.basic_opt import ShapeError
 from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise
 from aesara.tensor.math import all as at_all
 from aesara.tensor.math import any as at_any
@@ -800,6 +801,46 @@ def test_str(self):
         op = Elemwise(aes.add, inplace_pattern=None, name="my_op")
         assert str(op) == "my_op"
 
+    def test_partial_static_shape_info(self):
+        """Make sure that `Elemwise.infer_shape` can handle changes in the static shape information during rewriting."""
+
+        x = TensorType("floatX", shape=(None, None))()
+        z = Elemwise(aes.add)(x, x)
+
+        x_inferred_shape = (aes.constant(1), aes.constant(1))
+
+        res_shape = z.owner.op.infer_shape(
+            None, z.owner, [x_inferred_shape, x_inferred_shape]
+        )
+
+        assert len(res_shape) == 1
+        assert len(res_shape[0]) == 2
+        assert res_shape[0][0].data == 1
+        assert res_shape[0][1].data == 1
+
+    def test_multi_output(self):
+        class CustomElemwise(Elemwise):
+            def make_node(self, *args):
+                res = super().make_node(*args)
+                return Apply(
+                    self,
+                    res.inputs,
+                    # Return two outputs
+                    [
+                        TensorType(dtype="float64", shape=(None, None))()
+                        for i in range(2)
+                    ],
+                )
+
+        z_1, z_2 = CustomElemwise(aes.add)(
+            as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(1))
+        )
+
+        in_1_shape = (aes.constant(1), aes.constant(1))
+
+        with pytest.raises(ShapeError):
+            z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])
+
 
 def test_not_implemented_elemwise_grad():
     # Regression test for unimplemented gradient in an Elemwise Op.
diff --git a/tests/tensor/test_math_opt.py b/tests/tensor/test_math_opt.py
index d32ae1bb13..615e5f9c80 100644
--- a/tests/tensor/test_math_opt.py
+++ b/tests/tensor/test_math_opt.py
@@ -513,7 +513,6 @@ def test_elemwise_multiple_inputs_optimisation2(self):
             assert len(f.maker.fgraph.toposort()) == nb_elemwise
             assert out_dtype == out.dtype
 
-    @pytest.mark.slow
     def test_multiple_case(self):
         # test those case take from the comment in AlgebraicCanonizer
         # x / x -> 1
@@ -594,10 +593,7 @@ def test_multiple_case(self):
             assert out_dtype == out.dtype
             utt.assert_allclose(out, val_inputs[1])
             topo = f.maker.fgraph.toposort()
-            if topo and not (len(topo) == 1 and topo[0].op == deep_copy_op):
-                for node in topo[:-1]:
-                    assert isinstance(node.op, Shape_i)
-                assert isinstance(topo[-1].op, Alloc)
+            assert not any(node.op == at.true_div for node in topo)
 
         # test x / y / x -> 1 / y
         for id, (g, sym_inputs, val_inputs, nb_elemwise, out_dtype) in enumerate(