Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use input shapes to compute output shape in Elemwise.infer_shape #981

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 11 additions & 32 deletions aesara/tensor/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from aesara.misc.safe_asarray import _asarray
from aesara.printing import FunctionPrinter, Printer, pprint
from aesara.scalar import get_scalar_type
from aesara.scalar.basic import ScalarType
from aesara.scalar.basic import bool as scalar_bool
from aesara.scalar.basic import identity as scalar_identity
from aesara.scalar.basic import transfer_type, upcast
Expand Down Expand Up @@ -804,37 +803,17 @@ def perform(self, node, inputs, output_storage):
storage[0] = variable

def infer_shape(self, fgraph, node, i_shapes):
rval = []
for o in node.outputs:
oshp = []
for dim, b in enumerate(o.type.broadcastable):
b_dim = None
if b:
# this is broadcastable
b_dim = 1
else:
# there must be some input that is not broadcastable in
# dimension 'dim'
for ishp, i in zip(i_shapes, node.inputs):
if isinstance(i.type, ScalarType):
continue # we skip scalar
if not i.type.broadcastable[dim]:
# input i is not broadcastable in position dim
# therefore if its shape is known, we can use it
# as the output shape
if ishp[dim]:
b_dim = ishp[dim]
break

# b_dim might still be None, if every input's shape was unknown
# in dimension 'dim'
oshp.append(b_dim)
# TODO: it would be interesting to return the constraining
# information that if one of the inputs shape[dim] is known
# and another input's shape[dim] is not, that we can now assume
# that the other input's shape[dim] is the same as the first.
rval.append(tuple(oshp))
return rval

if len(node.outputs) > 1:
from aesara.tensor.basic_opt import ShapeError

raise ShapeError(
"Multiple outputs are not supported by the default `Elemwise.infer_shape`"
)

out_shape = aesara.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)

return [out_shape]

def _c_all(self, node, nodename, inames, onames, sub):
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
Expand Down
22 changes: 14 additions & 8 deletions aesara/tensor/extra_ops.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from collections.abc import Collection
from functools import reduce
from typing import Iterable, Tuple, Union

import numpy as np
import numpy.core.numeric
from numpy.core.multiarray import normalize_axis_index

import aesara
import aesara.scalar.basic as aes
from aesara.gradient import (
DisconnectedType,
_float_zeros_like,
Expand All @@ -26,9 +28,7 @@
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import abs as at_abs
from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq, ge, lt
from aesara.tensor.math import max as at_max
from aesara.tensor.math import maximum, minimum, or_, prod
from aesara.tensor.math import ge, lt, maximum, minimum, prod
from aesara.tensor.math import sum as at_sum
from aesara.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
from aesara.tensor.type import TensorType, dvector, int_dtypes, integer_dtypes, vector
Expand Down Expand Up @@ -1534,13 +1534,19 @@ def broadcast_shape_iter(
result_dims.append(maybe_non_bcast_shapes[0])
continue

non_bcast_vec = at.as_tensor(maybe_non_bcast_shapes)
non_bcast_vec = at.switch(eq(non_bcast_vec, 1), -one_at, non_bcast_vec)
dim_max = at_abs(at_max(non_bcast_vec))
non_bcast_vec = [
brandonwillard marked this conversation as resolved.
Show resolved Hide resolved
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
for nbv in maybe_non_bcast_shapes
]
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))

assert_dim = Assert("Could not broadcast dimensions")
assert_cond = at_all(
or_(eq(non_bcast_vec, -one_at), eq(non_bcast_vec, dim_max))
assert_cond = reduce(
aes.and_,
(
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max))
for nbv in non_bcast_vec
),
)
bcast_dim = assert_dim(dim_max, assert_cond)

Expand Down
28 changes: 18 additions & 10 deletions tests/tensor/test_basic_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2890,10 +2890,10 @@ def test_infer_shape(self):
self._compile_and_check([admat], [Shape_i(1)(admat)], [admat_val], Shape_i)


class TestShapeFeature:
class TestSameShape:
def test_scalar(self):
x = scalar()
cst = at.constant(1).clone()
cst = at.constant(1)
o = x + cst
fgraph = FunctionGraph([x], [o], clone=False)
shape_feature = ShapeFeature()
Expand All @@ -2902,34 +2902,42 @@ def test_scalar(self):

def test_vector(self):
x = vector()
cst = at.constant(1).clone()
cst = at.constant(1)
o = x + cst
fgraph = FunctionGraph([x], [o], clone=False)
shape_feature = ShapeFeature()
fgraph.attach_feature(shape_feature)
assert shape_feature.same_shape(x, o)

def test_vector2(self):
def test_no_static_shapes(self):
x = vector()
y = vector()
o = x + y
fgraph = FunctionGraph([x, y], [o], clone=False)
shape_feature = ShapeFeature()
fgraph.attach_feature(shape_feature)
assert shape_feature.same_shape(x, o)
# We no longer assume that `x` has the same shape as `y` simply because
# neither has static shape information. Instead, when there is no
# static shape information is available, we assume that `x` and/or `y`
# could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any
# combination of the two.
assert not shape_feature.same_shape(x, o)
# The following case isn't implemented
assert not shape_feature.same_shape(y, o)

def test_vector_dim(self):
x = vector()
y = vector()
@pytest.mark.parametrize(
"y_dim_0",
[2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))],
)
def test_vector_dim(self, y_dim_0):
x = at.tensor(dtype="floatX", shape=(2, None))
y = at.tensor(dtype="floatX", shape=(y_dim_0, None))
o = x + y
fgraph = FunctionGraph([x, y], [o], clone=False)
shape_feature = ShapeFeature()
fgraph.attach_feature(shape_feature)
assert shape_feature.same_shape(x, o, 0, 0)
# The following case isn't implemented
assert not shape_feature.same_shape(y, o, 0, 0)
assert not shape_feature.same_shape(x, o, 1, 1)

def test_vector_dim_err(self):
x = vector()
Expand Down
43 changes: 42 additions & 1 deletion tests/tensor/test_elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
import tests.unittest_tools as utt
from aesara.compile.mode import Mode
from aesara.configdefaults import config
from aesara.graph.basic import Variable
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.link.basic import PerformLinker
from aesara.link.c.basic import CLinker, OpWiseCLinker
from aesara.tensor import as_tensor_variable
from aesara.tensor.basic import second
from aesara.tensor.basic_opt import ShapeError
from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise
from aesara.tensor.math import all as at_all
from aesara.tensor.math import any as at_any
Expand Down Expand Up @@ -800,6 +801,46 @@ def test_str(self):
op = Elemwise(aes.add, inplace_pattern=None, name="my_op")
assert str(op) == "my_op"

def test_partial_static_shape_info(self):
"""Make sure that `Elemwise.infer_shape` can handle changes in the static shape information during rewriting."""

x = TensorType("floatX", shape=(None, None))()
z = Elemwise(aes.add)(x, x)

x_inferred_shape = (aes.constant(1), aes.constant(1))

res_shape = z.owner.op.infer_shape(
None, z.owner, [x_inferred_shape, x_inferred_shape]
)

assert len(res_shape) == 1
assert len(res_shape[0]) == 2
assert res_shape[0][0].data == 1
assert res_shape[0][1].data == 1

def test_multi_output(self):
class CustomElemwise(Elemwise):
def make_node(self, *args):
res = super().make_node(*args)
return Apply(
self,
res.inputs,
# Return two outputs
[
TensorType(dtype="float64", shape=(None, None))()
for i in range(2)
],
)

z_1, z_2 = CustomElemwise(aes.add)(
as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(1))
)

in_1_shape = (aes.constant(1), aes.constant(1))

with pytest.raises(ShapeError):
z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])


def test_not_implemented_elemwise_grad():
# Regression test for unimplemented gradient in an Elemwise Op.
Expand Down
6 changes: 1 addition & 5 deletions tests/tensor/test_math_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,7 +513,6 @@ def test_elemwise_multiple_inputs_optimisation2(self):
assert len(f.maker.fgraph.toposort()) == nb_elemwise
assert out_dtype == out.dtype

@pytest.mark.slow
def test_multiple_case(self):
# test those case take from the comment in AlgebraicCanonizer
# x / x -> 1
Expand Down Expand Up @@ -594,10 +593,7 @@ def test_multiple_case(self):
assert out_dtype == out.dtype
utt.assert_allclose(out, val_inputs[1])
topo = f.maker.fgraph.toposort()
if topo and not (len(topo) == 1 and topo[0].op == deep_copy_op):
for node in topo[:-1]:
assert isinstance(node.op, Shape_i)
assert isinstance(topo[-1].op, Alloc)
assert not any(node.op == at.true_div for node in topo)

# test x / y / x -> 1 / y
for id, (g, sym_inputs, val_inputs, nb_elemwise, out_dtype) in enumerate(
Expand Down