Skip to content

Commit

Permalink
Implement RandomVariable Subtensor lift optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonwillard committed Nov 24, 2020
1 parent 6e67e9e commit f28efcb
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 0 deletions.
64 changes: 64 additions & 0 deletions tests/tensor/random/test_opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from theano.gof.optdb import Query
from theano.tensor.elemwise import DimShuffle
from theano.tensor.random.basic import dirichlet, multivariate_normal, normal
from theano.tensor.random.op import RandomVariable
from theano.tensor.random.opt import lift_rv_shapes
from theano.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor


inplace_mode = Mode("py", Query(include=["random_make_inplace"], exclude=[]))
Expand Down Expand Up @@ -165,3 +167,65 @@ def test_DimShuffle_lift(new_dims, lifted):
)

np.testing.assert_allclose(res_base, res_opt, rtol=1e-3)


def test_Subtensor_lift_univariate():
mean = tt.vector("mean")
mean.tag.test_value = np.array([1.0, 10.0, 100.0], dtype=config.floatX)

std = tt.vector("std")
std.tag.test_value = np.array([1e-5, 2e-5, 3e-5], dtype=config.floatX)

size = tt.iscalars(2)
size[0].tag.test_value = np.array(4, dtype=size[0].dtype)
size[1].tag.test_value = np.array(3, dtype=size[0].dtype)

idx1 = tt.type_other.make_slice(1, mean.shape[0])
idx2 = tt.ivector("idx2")
idx2.tag.test_value = np.array([0, 2], dtype=idx2.dtype)

seed = 1233532
rng_np = np.random.RandomState(seed)
rng = shared(rng_np, borrow=False)

test_rv = normal(mean, std, size=size, rng=rng)

indices = (idx1, idx2)

# Non-lifted simple advanced `Subtensor`
norm_out = test_rv[indices]

inputs = [mean, std, idx2] + size
fn_base = function(inputs, norm_out, mode=no_mode)
fn_opt = function(inputs, norm_out, mode=canonicalize_mode)

res_base = fn_base(*[i.get_test_value() for i in inputs])
res_opt = fn_opt(*[i.get_test_value() for i in inputs])

new_out = fn_opt.maker.fgraph.outputs[0]
assert isinstance(new_out.owner.op, RandomVariable)
assert all(
isinstance(i.owner.op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor))
for i in new_out.owner.inputs[3:]
)

np.testing.assert_allclose(res_base, res_opt, rtol=1e-3)

# Now, let's try some advanced indexing that broadcasts/expands
# dimensions
# idx2 = tt.imatrix("idx2")
# idx2.tag.test_value = np.array([[0, 2], [2, 1]], dtype=idx2.dtype)
#
# indices = (idx1, idx2)
# tuple(idx.get_test_value() for idx in indices)
#
# norm_out = test_rv[indices]
# norm_out.get_test_value()
#
# output_shape = (slice_len(idx1, test_rv.shape[1]),) + tuple(idx2.shape)
# assert tuple(s.get_test_value() for s in output_shape) == norm_out.get_test_value().shape
#
# norm_out_lift = normal(mean[idx2], std[idx2], size=output_shape)
# norm_out_lift.get_test_value()
#
# np.testing.assert_allclose(norm_out.get_test_value(), norm_out_lift.get_test_value(), rtol=1e-3)
97 changes: 97 additions & 0 deletions theano/tensor/random/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
from theano.tensor.opt import in2out, register_canonicalize
from theano.tensor.random.op import RandomVariable
from theano.tensor.random.utils import broadcast_params
from theano.tensor.subtensor import (
AdvancedSubtensor,
AdvancedSubtensor1,
Subtensor,
indexed_result_shape,
)


@local_optimizer([RandomVariable])
Expand Down Expand Up @@ -168,3 +174,94 @@ def local_dimshuffle_rv_lift(fgraph, node):
return [rv_op.make_node(rng, new_size, dtype, *dist_params).outputs[1]]

return False


@register_canonicalize
@local_optimizer([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
def local_subtensor_rv_lift(fgraph, node):
"""Lift ``*Subtensor`` `Op`s up to a `RandomVariable`'s parameters.
In a fashion similar to `local_dimshuffle_rv_lift`, the indexed dimensions
need to be separated into distinct replication-space and (independent)
parameter-space ``*Subtensor``s.
The replication-space ``*Subtensor`` can be used to determine a
sub/super-set of the replication-space and, thus, a "smaller"/"larger"
``size`` tuple. The parameter-space ``*Subtensor`` is simply lifted and
applied to the `RandomVariable`'s distribution parameters.
Consider the following example graph:
``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``. The
``*Subtensor`` `Op` requests indices ``idx1``, ``idx2``, and ``idx3``,
which correspond to all three ``size`` dimensions. Now, depending on the
broadcasted dimensions of ``mu`` and ``std``, this ``*Subtensor`` `Op`
could be reducing the ``size`` parameter and/or subsetting the independent
``mu`` and ``std`` parameters. Only once the dimensions are properly
separated into the two replication/parameter subspaces can we determine how
the ``*Subtensor`` indices are distributed.
For instance, ``normal(mu, std, size=(d1, d2, d3))[idx1, idx2, idx3]``
could become ``normal(mu[idx1], std[idx2], size=np.shape(idx1) + np.shape(idx2) + np.shape(idx3))``
if ``mu.shape == std.shape == ()``
``normal`` is a rather simple case, because it's univariate. Multivariate
cases require a mapping between the parameter space and the image of the
random variable. This may not always be possible, but for many common
distributions it is. For example, the dimensions of the multivariate
normal's image can be mapped directly to each dimension of its parameters.
We use these mappings to change a graph like ``multivariate_normal(mu, Sigma)[idx1]``
into ``multivariate_normal(mu[idx1], Sigma[idx1, idx1])``. Notice how
Also, there's the important matter of "advanced" indexing, which may not
only subset an array, but also broadcast it to a larger size.
"""

ds_op = node.op

if not isinstance(ds_op, (AdvancedSubtensor, AdvancedSubtensor1, Subtensor)):
return False

rv_node = node.inputs[0].owner
if not (rv_node and isinstance(rv_node.op, RandomVariable)):
return False

rv_op = rv_node.op
rng, size, dtype, *dist_params = rv_node.inputs

base_rv = node.inputs[0]
rv_op = base_rv.owner.op
rng, size, dtype, *dist_params = base_rv.owner.inputs
st_indices = node.inputs[1:]
output_shape = indexed_result_shape(base_rv.shape, st_indices)

# We need to separate dimensions into replications and independents
num_ind_dims = None
if len(dist_params) == 1:
num_ind_dims = dist_params[0].ndim
else:
# When there is more than one distribution parameter, assume that all
# of them will broadcast to the maximum number of dimensions
num_ind_dims = max(d.ndim for d in dist_params)

reps_ind_split_idx = len(output_shape) - (num_ind_dims + rv_op.ndim_supp)

# These are the indices that need to be applied to the parameters
ind_indices = tuple(st_indices[reps_ind_split_idx:])

size_lifted = (
output_shape if rv_op.ndim_supp == 0 else output_shape[: -rv_op.ndim_supp]
)

# TODO: For multidimensional distributions, we need a map that tells us
# which dimensions of the parameters need to be indexed.
#
# For example, `multivariate_normal` would have the following:
# `RandomVariable.param_to_image_dims = ((0,), (0, 1))`
#
# I.e. the first parameter's (i.e. mean's) first dimension maps directly to
# the dimension of the RV's image, and its second parameter's
# (i.e. covariance's) first and second dimensions map directly to the
# dimension of the RV's image.
args_lifted = tuple(p[ind_indices] for p in dist_params)

return [rv_op.make_node(rng, size_lifted, dtype, *args_lifted).outputs[1]]

0 comments on commit f28efcb

Please sign in to comment.