Refactor output Type inference in RandomVariable and Scan #1253

Merged
7 changes: 4 additions & 3 deletions aesara/compile/debugmode.py
@@ -807,7 +807,7 @@ def _get_preallocated_maps(
for r in considered_outputs:
if isinstance(r.type, TensorType):
# Build a C-contiguous buffer
- new_buf = r.type.value_zeros(r_vals[r].shape)
+ new_buf = np.empty(r_vals[r].shape, dtype=r.type.dtype)
assert new_buf.flags["C_CONTIGUOUS"]
new_buf[...] = np.asarray(def_val).astype(r.type.dtype)

@@ -875,7 +875,8 @@ def _get_preallocated_maps(
buf_shape.append(s)
else:
buf_shape.append(s * 2)
- new_buf = r.type.value_zeros(buf_shape)

+ new_buf = np.empty(buf_shape, dtype=r.type.dtype)
new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
init_strided[r] = new_buf

@@ -950,7 +951,7 @@ def _get_preallocated_maps(
max((s + sd), 0)
for s, sd in zip(r_vals[r].shape, r_shape_diff)
]
- new_buf = r.type.value_zeros(out_shape)
+ new_buf = np.empty(out_shape, dtype=r.type.dtype)
new_buf[...] = np.asarray(def_val).astype(r.type.dtype)
wrong_size[r] = new_buf

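The `debugmode` changes replace `TensorType.value_zeros` with a plain NumPy allocation. A minimal sketch (not part of this diff) of why the two spellings are interchangeable here: the buffer is fully overwritten with `def_val` immediately after allocation, so zero-initialization buys nothing.

```python
import numpy as np

shape, dtype, def_val = (3, 4), "float64", 1.5

# What `r.type.value_zeros(shape)` used to return: a zero-filled array.
old_buf = np.zeros(shape, dtype=dtype)
old_buf[...] = np.asarray(def_val).astype(dtype)

# The replacement: an uninitialized buffer, immediately filled with the same value.
new_buf = np.empty(shape, dtype=dtype)
new_buf[...] = np.asarray(def_val).astype(dtype)

assert new_buf.flags["C_CONTIGUOUS"]
assert np.array_equal(old_buf, new_buf)
```
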
38 changes: 23 additions & 15 deletions aesara/scan/op.py
@@ -161,8 +161,9 @@ def check_broadcast(v1, v2):
which may wrongly be interpreted as broadcastable.

"""
- if not hasattr(v1, "broadcastable") and not hasattr(v2, "broadcastable"):
+ if not isinstance(v1.type, TensorType) and not isinstance(v2.type, TensorType):
return

msg = (
"The broadcast pattern of the output of scan (%s) is "
"inconsistent with the one provided in `output_info` "
@@ -173,13 +174,13 @@ def check_broadcast(v1, v2):
"them consistent, e.g. using aesara.tensor."
"{unbroadcast, specify_broadcastable}."
)
- size = min(len(v1.broadcastable), len(v2.broadcastable))
+ size = min(v1.type.ndim, v2.type.ndim)
for n, (b1, b2) in enumerate(
- zip(v1.broadcastable[-size:], v2.broadcastable[-size:])
+ zip(v1.type.broadcastable[-size:], v2.type.broadcastable[-size:])
):
if b1 != b2:
- a1 = n + size - len(v1.broadcastable) + 1
- a2 = n + size - len(v2.broadcastable) + 1
+ a1 = n + size - v1.type.ndim + 1
+ a2 = n + size - v2.type.ndim + 1
raise TypeError(msg % (v1.type, v2.type, a1, b1, b2, a2))


@@ -200,7 +201,7 @@ def copy_var_format(var, as_var):
rval = as_var.type.filter_variable(rval)
else:
tmp = as_var.type.clone(
- shape=(tuple(var.broadcastable[:1]) + tuple(as_var.broadcastable))
+ shape=(tuple(var.type.shape[:1]) + tuple(as_var.type.shape))
)
rval = tmp.filter_variable(rval)
return rval
@@ -628,6 +629,7 @@ def validate_inner_graph(self):
type_input = self.inner_inputs[inner_iidx].type
type_output = self.inner_outputs[inner_oidx].type
if (
+ # TODO: Use the `Type` interface for this
type_input.dtype != type_output.dtype
or type_input.broadcastable != type_output.broadcastable
):
@@ -805,7 +807,9 @@ def tensorConstructor(shape, dtype):
# output sequence
o = outputs[idx]
self.output_types.append(
- typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
+ # TODO: What can we actually say about the shape of this
+ # added dimension?
+ typeConstructor((None,) + o.type.shape, o.type.dtype)
)

idx += len(info.mit_mot_out_slices[jdx])
@@ -816,7 +820,9 @@

for o in outputs[idx:end]:
self.output_types.append(
- typeConstructor((False,) + o.type.broadcastable, o.type.dtype)
+ # TODO: What can we actually say about the shape of this
+ # added dimension?
+ typeConstructor((None,) + o.type.shape, o.type.dtype)
)

# shared outputs + possibly the ending condition
@@ -1380,11 +1386,13 @@ def prepare_fgraph(self, fgraph):
# the output value, possibly inplace, at the end of the
# function execution. Also, since an update is defined,
# a default value must also be (this is verified by
- # DebugMode). Use an array of size 0 with the correct
- # ndim and dtype (use a shape of 1 on broadcastable
- # dimensions, and 0 on the others).
- default_shape = [1 if _b else 0 for _b in inp.broadcastable]
- default_val = inp.type.value_zeros(default_shape)
+ # DebugMode).
+ # TODO FIXME: Why do we need a "default value" here?
+ # This sounds like a serious design issue.
+ default_shape = tuple(
+     s if s is not None else 0 for s in inp.type.shape
+ )
+ default_val = np.empty(default_shape, dtype=inp.type.dtype)
wrapped_inp = In(
variable=inp,
value=default_val,
@@ -2318,8 +2326,8 @@ def infer_shape(self, fgraph, node, input_shapes):
# equivalent (if False). Here, we only need the variable.
v_shp_i = validator.check(shp_i)
if v_shp_i is None:
if hasattr(r, "broadcastable") and r.broadcastable[i]:
shp.append(1)
if r.type.shape[i] is not None:
shp.append(r.type.shape[i])
else:
shp.append(Shape_i(i)(r))
else:
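
The `scan` changes replace reads of the deprecated `Variable.broadcastable` attribute with information taken from the variable's `Type`. A small sketch (assuming an Aesara version where `TensorType` carries a static `shape`, as in this PR series) of the correspondence being relied on:

```python
from aesara.tensor.type import TensorType

# In the static-shape world, a dimension is broadcastable exactly when its
# static shape entry is 1; unknown extents are encoded as None.
t = TensorType("float64", shape=(None, 1))

assert t.ndim == 2
assert t.shape == (None, 1)
assert t.broadcastable == (False, True)
```
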
8 changes: 0 additions & 8 deletions aesara/sparse/type.py
@@ -224,14 +224,6 @@ def get_size(self, shape_info):
+ (shape_info[2] + shape_info[3]) * np.dtype("int32").itemsize
)

- def value_zeros(self, shape):
-     matrix_constructor = self.format_cls.get(self.format)
-
-     if matrix_constructor is None:
-         raise ValueError(f"Sparse matrix type {self.format} not found in SciPy")
-
-     return matrix_constructor(shape, dtype=self.dtype)

def __eq__(self, other):
res = super().__eq__(other)

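
With `SparseTensorType.value_zeros` gone, callers that still need an empty sparse matrix can build one directly with SciPy. A rough, hypothetical equivalent of the removed helper (the function and `format_cls` mapping below are illustrative, not Aesara API):

```python
import numpy as np
import scipy.sparse as sp

# Map sparse formats to SciPy constructors, mirroring the removed method's lookup.
format_cls = {"csr": sp.csr_matrix, "csc": sp.csc_matrix}

def sparse_value_zeros(fmt, shape, dtype):
    matrix_constructor = format_cls.get(fmt)
    if matrix_constructor is None:
        raise ValueError(f"Sparse matrix type {fmt} not found in SciPy")
    # An empty sparse matrix of the requested format, shape, and dtype.
    return matrix_constructor(shape, dtype=dtype)

m = sparse_value_zeros("csr", (3, 4), np.float64)
assert m.shape == (3, 4) and m.nnz == 0
```
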
43 changes: 31 additions & 12 deletions aesara/tensor/basic.py
@@ -10,7 +10,7 @@
from collections.abc import Sequence
from functools import partial
from numbers import Number
- from typing import Optional
+ from typing import TYPE_CHECKING, Optional
from typing import Sequence as TypeSequence
from typing import Tuple, Union
from typing import cast as type_cast
@@ -68,6 +68,10 @@
from aesara.tensor.var import TensorConstant, TensorVariable, get_unique_value


+ if TYPE_CHECKING:
+     from aesara.tensor import TensorLike


def __oplist_tag(thing, tag):
tags = getattr(thing, "__oplist_tags", [])
tags.append(tag)
@@ -1334,11 +1338,25 @@ def identity_like(x, dtype: Optional[Union[str, np.generic, np.dtype]] = None):
return eye(_x.shape[0], _x.shape[1], k=0, dtype=dtype)


- def infer_broadcastable(shape):
-     """Infer the broadcastable dimensions for `shape`.
+ def infer_static_shape(
+     shape: Union[Variable, TypeSequence[Union[Variable, int]]]
+ ) -> Tuple[TypeSequence["TensorLike"], TypeSequence[Optional[int]]]:
+     """Infer the static shapes implied by the potentially symbolic elements in `shape`.

+     `shape` will be validated and constant folded. As a result, this function
+     can be expensive and shouldn't be used unless absolutely necessary.

+     It mostly exists as a hold-over from pre-static shape times, when it was
+     required in order to produce correct broadcastable arrays and prevent
+     some graphs from being unusable. Now, it is no longer strictly required,
+     so don't use it unless you want the same shape graphs to be rewritten
+     multiple times during graph construction.

+     Returns
+     -------
+     A validated sequence of symbolic shape values, and a sequence of
+     ``None``/``int`` values that can be used as `TensorType.shape` values.

-     `shape` will be validated and constant folded in order to determine
-     which dimensions are broadcastable (i.e. equal to ``1``).
"""
from aesara.tensor.rewriting.basic import topo_constant_folding
from aesara.tensor.rewriting.shape import ShapeFeature
@@ -1362,9 +1380,10 @@ def check_type(s):
clone=True,
)
folded_shape = rewrite_graph(shape_fg, custom_rewrite=topo_constant_folding).outputs

- bcast = tuple(getattr(s, "data", s) == 1 for s in folded_shape)
- return sh, bcast
+ static_shape = tuple(
+     s.data.item() if isinstance(s, Constant) else None for s in folded_shape
+ )
+ return sh, static_shape


class Alloc(COp):
@@ -1394,15 +1413,15 @@ class Alloc(COp):

def make_node(self, value, *shape):
v = as_tensor_variable(value)
- sh, bcast = infer_broadcastable(shape)
+ sh, static_shape = infer_static_shape(shape)
if v.ndim > len(sh):
raise TypeError(
"The Alloc value to use has more dimensions"
" than the specified dimensions",
v.ndim,
len(sh),
)
- otype = TensorType(dtype=v.dtype, shape=bcast)
+ otype = TensorType(dtype=v.dtype, shape=static_shape)
return Apply(self, [v] + sh, [otype()])

def perform(self, node, inputs, out_):
@@ -3823,8 +3842,8 @@ def typecode(self):
return np.dtype(self.dtype).num

def make_node(self, *_shape):
- _shape, bcast = infer_broadcastable(_shape)
- otype = TensorType(dtype=self.dtype, shape=bcast)
+ _shape, static_shape = infer_static_shape(_shape)
+ otype = TensorType(dtype=self.dtype, shape=static_shape)
output = otype()

output.tag.values_eq_approx = values_eq_approx_always_true
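
The renamed helper now returns the static shape (ints and `None`s) rather than a broadcastable pattern. A hedged usage sketch, assuming the API exactly as introduced above:

```python
import aesara.tensor as at
from aesara.tensor.basic import infer_static_shape

n = at.iscalar("n")

# Constant-foldable entries become ints; purely symbolic ones become None.
sh, static_shape = infer_static_shape((n, 5, at.as_tensor_variable(2) + 3))

print(static_shape)  # expected: (None, 5, 5)
```
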
56 changes: 46 additions & 10 deletions aesara/tensor/extra_ops.py
@@ -23,6 +23,7 @@
from aesara.raise_op import Assert
from aesara.scalar import int32 as int_t
from aesara.scalar import upcast
+ from aesara.scalar.basic import Composite
from aesara.tensor import basic as at
from aesara.tensor import get_vector_length
from aesara.tensor.exceptions import NotScalarConstantError
@@ -1552,16 +1553,32 @@ def broadcast_shape_iter(
# be broadcastable or equal to the one non-broadcastable
# constant `const_nt_shape_var`.
assert_dim = Assert("Could not broadcast dimensions")

+ scalar_nonconst_nb_shapes = [
+     at.scalar_from_tensor(s)
+     if isinstance(s.type, TensorType)
+     else s
+     for s in nonconst_nb_shapes
+ ]

+ dummy_nonconst_nb_shapes = [
+     aes.get_scalar_type(dtype=v.dtype)()
+     for v in scalar_nonconst_nb_shapes
+ ]
assert_cond = reduce(
aes.and_,
(
aes.or_(
aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
)
- for nbv in nonconst_nb_shapes
+ for nbv in dummy_nonconst_nb_shapes
),
)
- bcast_dim = assert_dim(const_nt_shape_var, assert_cond)
+ assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])

+ bcast_dim = assert_dim(
+     const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
+ )
else:
bcast_dim = const_nt_shape_var
else:
@@ -1579,21 +1596,37 @@ def broadcast_shape_iter(
result_dims.append(maybe_non_bcast_shapes[0])
continue

+ scalar_maybe_non_bcast_shapes = [
+     at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s
+     for s in maybe_non_bcast_shapes
+ ]
+ dummy_maybe_non_bcast_shapes = [
+     aes.get_scalar_type(dtype=v.dtype)()
+     for v in scalar_maybe_non_bcast_shapes
+ ]
non_bcast_vec = [
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
- for nbv in maybe_non_bcast_shapes
+ for nbv in dummy_maybe_non_bcast_shapes
]
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
+ dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])

+ dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)

assert_dim = Assert("Could not broadcast dimensions")
assert_cond = reduce(
aes.and_,
(
- aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dim_max))
+ aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
for nbv in non_bcast_vec
),
)
- bcast_dim = assert_dim(dim_max, assert_cond)
+ assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])

+ bcast_dim = assert_dim(
+     dim_max_op(*scalar_maybe_non_bcast_shapes),
+     assert_cond_op(*scalar_maybe_non_bcast_shapes),
+ )

result_dims.append(bcast_dim)

@@ -1613,9 +1646,9 @@ def __call__(self, a, shape, **kwargs):
def make_node(self, a, *shape):
a = at.as_tensor_variable(a)

- shape, bcast = at.infer_broadcastable(shape)
+ shape, static_shape = at.infer_static_shape(shape)

- out = TensorType(dtype=a.type.dtype, shape=bcast)()
+ out = TensorType(dtype=a.type.dtype, shape=static_shape)()

# Attempt to prevent in-place operations on this view-based output
out.tag.indestructible = True
@@ -1637,11 +1670,14 @@ def grad(self, inputs, outputs_gradients):
d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)

# Determine the dimensions that were broadcast
- _, shape_bcast = at.infer_broadcastable(shape)
+ _, static_shape = at.infer_static_shape(shape)

+ # TODO: This needs to be performed at run-time when static shape
+ # information isn't available.
bcast_sums = [
i
- for i, (a_b, s_b) in enumerate(zip(a.broadcastable, shape_bcast[-a.ndim :]))
- if a_b and not s_b
+ for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
+ if a_s == 1 and s_s != 1
]

if bcast_sums:
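
The `broadcast_shape_iter` changes build the scalar shape checks on dummy scalar inputs, fuse them into a single `Composite` scalar `Op`, and apply that `Op` to the real scalar variables. A hedged sketch of that pattern (a toy condition, not the PR's exact graph):

```python
import aesara.scalar as aes
from aesara.scalar.basic import Composite

# Real scalar variables that live in the surrounding graph.
real_inputs = [aes.int64("s1"), aes.int64("s2")]

# Dummy stand-ins used only to define the fused scalar computation.
dummy_inputs = [aes.get_scalar_type(dtype=v.dtype)() for v in real_inputs]

one = aes.ScalarConstant(aes.int64, 1)
cond = aes.and_(aes.eq(dummy_inputs[0], one), aes.eq(dummy_inputs[1], one))

# Wrap the whole condition in one Composite Op and apply it to the real inputs,
# so the check appears as a single fused node instead of a chain of scalar Ops.
cond_op = Composite(dummy_inputs, [cond])
applied = cond_op(*real_inputs)

assert isinstance(applied.owner.op, Composite)
```
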
6 changes: 3 additions & 3 deletions aesara/tensor/random/op.py
@@ -14,7 +14,7 @@
constant,
get_scalar_constant_value,
get_vector_length,
- infer_broadcastable,
+ infer_static_shape,
)
from aesara.tensor.random.type import RandomGeneratorType, RandomStateType, RandomType
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
@@ -322,7 +322,7 @@ def make_node(self, rng, size, dtype, *dist_params):
)

shape = self._infer_shape(size, dist_params)
- _, bcast = infer_broadcastable(shape)
+ _, static_shape = infer_static_shape(shape)
dtype = self.dtype or dtype

if dtype == "floatX":
@@ -336,7 +336,7 @@ def make_node(self, rng, size, dtype, *dist_params):
dtype_idx = constant(dtype, dtype="int64")
dtype = all_dtypes[dtype_idx.data]

- outtype = TensorType(dtype=dtype, shape=bcast)
+ outtype = TensorType(dtype=dtype, shape=static_shape)
out_var = outtype()
inputs = (rng, size, dtype_idx) + dist_params
outputs = (rng.type(), out_var)
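
With this change a `RandomVariable`'s output type carries whatever static shape can be inferred from `size`. A hedged illustration of the expected behavior after this PR, using the stock normal RV:

```python
import aesara.tensor as at
from aesara.tensor.random.basic import normal

# A constant-foldable size yields a fully static output shape ...
x = normal(0, 1, size=(2, 3))
print(x.type.shape)  # expected: (2, 3)

# ... while a symbolic size leaves the corresponding entries unknown.
n = at.iscalar("n")
y = normal(0, 1, size=(n, 3))
print(y.type.shape)  # expected: (None, 3)
```
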
7 changes: 0 additions & 7 deletions aesara/tensor/type.py
@@ -331,13 +331,6 @@ def convert_variable(self, var):
# `specify_shape` will combine the more precise shapes of the two types
return aesara.tensor.specify_shape(var, self.shape)

- def value_zeros(self, shape):
-     """Create an numpy ndarray full of 0 values.
-
-     TODO: Remove this trivial method.
-     """
-     return np.zeros(shape, dtype=self.dtype)

@staticmethod
def values_eq(a, b, force_same_dtype=True):
# TODO: check to see if the shapes must match; for now, we err on safe