Skip to content

Commit

Permalink
Only use input shapes to compute output shape in Elemwise.infer_shape
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Jun 8, 2022
1 parent 22416ba commit 064e72f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 33 deletions.
43 changes: 11 additions & 32 deletions aesara/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from aesara.misc.safe_asarray import _asarray
from aesara.printing import FunctionPrinter, Printer, pprint
from aesara.scalar import get_scalar_type
from aesara.scalar.basic import ScalarType
from aesara.scalar.basic import bool as scalar_bool
from aesara.scalar.basic import identity as scalar_identity
from aesara.scalar.basic import transfer_type, upcast
Expand Down Expand Up @@ -804,37 +803,17 @@ def perform(self, node, inputs, output_storage):
storage[0] = variable

def infer_shape(self, fgraph, node, i_shapes):
rval = []
for o in node.outputs:
oshp = []
for dim, b in enumerate(o.type.broadcastable):
b_dim = None
if b:
# this is broadcastable
b_dim = 1
else:
# there must be some input that is not broadcastable in
# dimension 'dim'
for ishp, i in zip(i_shapes, node.inputs):
if isinstance(i.type, ScalarType):
continue # we skip scalar
if not i.type.broadcastable[dim]:
# input i is not broadcastable in position dim
# therefore if its shape is known, we can use it
# as the output shape
if ishp[dim]:
b_dim = ishp[dim]
break

# b_dim might still be None, if every input's shape was unknown
# in dimension 'dim'
oshp.append(b_dim)
# TODO: it would be interesting to return the constraining
# information that if one of the inputs shape[dim] is known
# and another input's shape[dim] is not, that we can now assume
# that the other input's shape[dim] is the same as the first.
rval.append(tuple(oshp))
return rval

if len(node.outputs) > 1:
from aesara.tensor.basic_opt import ShapeError

raise ShapeError(
"Multiple outputs are not supported by the default `Elemwise.infer_shape`"
)

out_shape = aesara.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)

return [out_shape]

def _c_all(self, node, nodename, inames, onames, sub):
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
Expand Down
43 changes: 42 additions & 1 deletion tests/tensor/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
import tests.unittest_tools as utt
from aesara.compile.mode import Mode
from aesara.configdefaults import config
from aesara.graph.basic import Variable
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.link.basic import PerformLinker
from aesara.link.c.basic import CLinker, OpWiseCLinker
from aesara.tensor import as_tensor_variable
from aesara.tensor.basic import second
from aesara.tensor.basic_opt import ShapeError
from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise
from aesara.tensor.math import all as at_all
from aesara.tensor.math import any as at_any
Expand Down Expand Up @@ -800,6 +801,46 @@ def test_str(self):
op = Elemwise(aes.add, inplace_pattern=None, name="my_op")
assert str(op) == "my_op"

def test_partial_static_shape_info(self):
"""Make sure that `Elemwise.infer_shape` can handle changes in the static shape information during rewriting."""

x = TensorType("floatX", shape=(None, None))()
z = Elemwise(aes.add)(x, x)

x_inferred_shape = (aes.constant(1), aes.constant(1))

res_shape = z.owner.op.infer_shape(
None, z.owner, [x_inferred_shape, x_inferred_shape]
)

assert len(res_shape) == 1
assert len(res_shape[0]) == 2
assert res_shape[0][0].data == 1
assert res_shape[0][1].data == 1

def test_multi_output(self):
class CustomElemwise(Elemwise):
def make_node(self, *args):
res = super().make_node(*args)
return Apply(
self,
res.inputs,
# Return two outputs
[
TensorType(dtype="float64", shape=(None, None))()
for i in range(2)
],
)

z_1, z_2 = CustomElemwise(aes.add)(
as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(1))
)

in_1_shape = (aes.constant(1), aes.constant(1))

with pytest.raises(ShapeError):
z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])


def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op.
Expand Down

0 comments on commit 064e72f

Please sign in to comment.