diff --git a/tests/tensor/random/test_opt.py b/tests/tensor/random/test_opt.py
index 0cc0c1eb05..d2d0f4a28c 100644
--- a/tests/tensor/random/test_opt.py
+++ b/tests/tensor/random/test_opt.py
@@ -1,13 +1,19 @@
 import numpy as np
+import pytest
 
+import theano.tensor as tt
+from theano import change_flags, config, shared
 from theano.compile.function import function
 from theano.compile.mode import Mode
 from theano.gof.optdb import Query
-from theano.tensor.random.basic import normal
+from theano.tensor.elemwise import DimShuffle
+from theano.tensor.random.basic import dirichlet, multivariate_normal, normal
+from theano.tensor.random.opt import lift_rv_shapes
 
 
-opts = Query(include=["random_make_inplace"], exclude=[])
-inplace_mode = Mode("py", opts)
+inplace_mode = Mode("py", Query(include=["random_make_inplace"], exclude=[]))
+canonicalize_mode = Mode("py", Query(include=["canonicalize"], exclude=[]))
+no_mode = Mode("py", Query(include=[], exclude=[]))
 
 
 def test_inplace_optimization():
@@ -30,3 +36,132 @@ def test_inplace_optimization():
         np.array_equal(a.data, b.data)
         for a, b in zip(new_out.owner.inputs[1:], out.owner.inputs[1:])
     )
+
+
+def check_shape_lifted_rv(rv, params, size, rng):
+    tt_params = []
+    for p in params:
+        p_tt = tt.as_tensor(p)
+        p_tt = p_tt.type()
+        p_tt.tag.test_value = p
+        tt_params.append(p_tt)
+
+    tt_size = []
+    for s in size:
+        s_tt = tt.as_tensor(s)
+        s_tt = s_tt.type()
+        s_tt.tag.test_value = s
+        tt_size.append(s_tt)
+
+    rv = rv(*tt_params, size=tt_size, rng=rng)
+    rv_lifted = lift_rv_shapes(rv.owner)
+
+    # Make sure the size input is empty
+    assert np.array_equal(rv_lifted.inputs[1].data, [])
+
+    f_ref = function(
+        tt_params + tt_size,
+        rv,
+        mode=no_mode,
+    )
+    f_lifted = function(
+        tt_params + tt_size,
+        rv_lifted.outputs[1],
+        mode=no_mode,
+    )
+    f_ref_val = f_ref(*(params + size))
+    f_lifted_val = f_lifted(*(params + size))
+    assert np.array_equal(f_ref_val, f_lifted_val)
+
+
+@change_flags(compute_test_value="raise")
+def test_lift_rv_shapes():
+
+    rng = shared(np.random.RandomState(1233532), borrow=False)
+
+    test_params = [
+        np.array(1.0, dtype=config.floatX),
+        np.array(5.0, dtype=config.floatX),
+    ]
+    test_size = []
+    check_shape_lifted_rv(normal, test_params, test_size, rng)
+
+    test_params = [
+        np.array([0.0, 1.0], dtype=config.floatX),
+        np.array(5.0, dtype=config.floatX),
+    ]
+    test_size = [3, 2]
+    check_shape_lifted_rv(normal, test_params, test_size, rng)
+
+    test_params = [
+        np.array([[0], [10], [100]], dtype=config.floatX),
+        np.diag(np.array([1e-6], dtype=config.floatX)),
+    ]
+    test_size = [2, 3]
+    check_shape_lifted_rv(multivariate_normal, test_params, test_size, rng)
+
+    test_params = [
+        np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX)
+    ]
+    test_size = [2, 3]
+    check_shape_lifted_rv(dirichlet, test_params, test_size, rng)
+
+
+@pytest.mark.parametrize(
+    "new_dims,lifted",
+    [
+        ((0, 2, 1), True),
+        (("x", 0, 2, 1, "x"), True),
+        (("x", 0, "x", 2, "x", 1, "x"), True),
+        (("x", 1, 0, 2, "x"), False),
+    ],
+)
+@change_flags(compute_test_value="off")
+def test_DimShuffle_lift(new_dims, lifted):
+
+    rng = shared(np.random.RandomState(1233532), borrow=False)
+
+    mean = tt.matrix("mean")
+    mean.tag.test_value = np.array([[-1, 20], [300, -4000]], dtype=config.floatX)
+    std = tt.matrix("std")
+    std.tag.test_value = np.array([[1e-6, 2e-6]], dtype=config.floatX)
+    s1 = tt.iscalar("size_d1")
+    s1.tag.test_value = 3
+    s2 = tt.iscalar("size_d2")
+    s2.tag.test_value = 2
+    s3 = tt.iscalar("size_d3")
+    s3.tag.test_value = 2
+    size = [s1, s2, s3]
+
+    norm_dm = normal(mean, std, size=size, rng=rng).dimshuffle(new_dims)
+
+    f_opt = function(
+        [mean, std] + size,
+        norm_dm,
+        mode=canonicalize_mode,
+    )
+
+    (new_out,) = f_opt.maker.fgraph.outputs
+
+    if lifted:
+        assert new_out.owner.op == normal
+        assert all(
+            isinstance(i.owner.op, DimShuffle) for i in new_out.owner.inputs[3:]
+        )
+    else:
+        assert isinstance(new_out.owner.op, DimShuffle)
+        return
+
+    f_base = function(
+        [mean, std] + size,
+        norm_dm,
+        mode=no_mode,
+    )
+
+    res_base = f_base(
+        mean.tag.test_value, std.tag.test_value, *[s.tag.test_value for s in size]
+    )
+    res_opt = f_opt(
+        mean.tag.test_value, std.tag.test_value, *[s.tag.test_value for s in size]
+    )
+
+    assert np.allclose(res_base, res_opt, rtol=1e-3)
diff --git a/theano/tensor/random/opt.py b/theano/tensor/random/opt.py
index 9c9c5fe7c2..fc5fcc4154 100644
--- a/theano/tensor/random/opt.py
+++ b/theano/tensor/random/opt.py
@@ -1,7 +1,11 @@
+import theano.tensor as tt
 from theano.compile import optdb
 from theano.gof.opt import local_optimizer
-from theano.tensor.opt import in2out
+from theano.tensor.elemwise import DimShuffle
+from theano.tensor.extra_ops import broadcast_to
+from theano.tensor.opt import in2out, register_canonicalize
 from theano.tensor.random.op import RandomVariable
+from theano.tensor.random.utils import broadcast_params
 
 
 @local_optimizer([RandomVariable])
@@ -9,11 +13,8 @@ def random_make_inplace(node):
     op = node.op
 
     if isinstance(op, RandomVariable) and not op.inplace:
-
         name, ndim_supp, ndims_params, dtype, _ = op._props()
         new_op = type(op)(name, ndim_supp, ndims_params, dtype, True)
-        # rng, size, dtype, *dist_params = node.inputs
-
         return new_op.make_node(*node.inputs).outputs
 
     return False
@@ -26,3 +27,128 @@ def random_make_inplace(node):
     "fast_run",
     "inplace",
 )
+
+
+def lift_rv_shapes(node):
+    """Lift `RandomVariable`'s shape-related parameters.
+
+    In other words, this will broadcast the distribution parameters and
+    extra dimensions added by the `size` parameter.
+
+    For example, ``normal([0.0, 1.0], 5.0, size=(3, 2))`` becomes
+    ``normal([[0., 1.], [0., 1.], [0., 1.]], [[5., 5.], [5., 5.], [5., 5.]])``.
+
+    """
+
+    if not isinstance(node.op, RandomVariable):
+        return False
+
+    rng, size, dtype, *dist_params = node.inputs
+
+    dist_params = broadcast_params(dist_params, node.op.ndims_params)
+
+    dist_params = [
+        broadcast_to(
+            p, (tuple(size) + tuple(p.shape)) if node.op.ndim_supp > 0 else size
+        )
+        for p in dist_params
+    ]
+
+    return node.op.make_node(rng, None, dtype, *dist_params)
+
+
+@register_canonicalize
+@local_optimizer([DimShuffle])
+def local_dimshuffle_rv_lift(node):
+    """Lift `DimShuffle`s through `RandomVariable` `Op`s.
+
+    For example, ``normal(mu, std).T == normal(mu.T, std.T)``.
+    """
+
+    ds_op = node.op
+
+    if not isinstance(ds_op, DimShuffle):
+        return False
+
+    rv_node = node.inputs[0].owner
+    if not (
+        rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0
+    ):
+        return False
+
+    rv_op = rv_node.op
+    rng, size, dtype, *dist_params = rv_node.inputs
+
+    # We need to know the dimensions that were *not* added by the `size`
+    # parameter (i.e. the dimensions corresponding to independent variates with
+    # different parameter values)
+    num_ind_dims = None
+    if len(dist_params) == 1:
+        num_ind_dims = dist_params[0].ndim
+    else:
+        # When there is more than one distribution parameter, assume that all
+        # of them will broadcast to the maximum number of dimensions
+        num_ind_dims = max(d.ndim for d in dist_params)
+
+    # If the indices in `ds_new_order` are entirely within the replication
+    # indices group or the independent variates indices group, then move
+    # forward.
+    ds_new_order = ds_op.new_order
+    dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)]
+
+    reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp)
+
+    ds_reps_new_dims = dim_orders[:reps_ind_split_idx]
+    ds_ind_new_dims = dim_orders[reps_ind_split_idx:]
+    ds_only_in_ind = ds_ind_new_dims and all(
+        d >= reps_ind_split_idx for n, d in ds_ind_new_dims
+    )
+
+    if ds_only_in_ind:
+
+        # Update the `size` array to reflect the `DimShuffle`d dimensions,
+        # since the trailing dimensions in `size` represent the independent
+        # variates dimensions (for univariate distributions, at least)
+        new_size = [
+            tt.constant(1, dtype="int64") if o == "x" else size[o]
+            for o in ds_new_order
+        ]
+
+        # Compute the new axes parameter(s) for the `DimShuffle` that will be
+        # applied to the `RandomVariable` parameters (they need to be offset)
+        rv_params_new_order = [
+            d - reps_ind_split_idx if isinstance(d, int) else d
+            for d in ds_new_order[ds_ind_new_dims[0][0] :]
+        ]
+
+        # Lift the `DimShuffle`s into the parameters
+        # NOTE: The parameters might not be broadcasted against each other, so
+        # we can only apply the parts of the `DimShuffle` that are relevant.
+        new_dist_params = []
+        for d in dist_params:
+            if d.ndim < len(ds_ind_new_dims):
+                _rv_params_new_order = [
+                    o
+                    for o in rv_params_new_order
+                    if (isinstance(o, int) and o < d.ndim) or o == "x"
+                ]
+            else:
+                _rv_params_new_order = rv_params_new_order
+
+            new_dist_params.append(
+                type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d)
+            )
+
+        return [rv_op.make_node(rng, new_size, dtype, *new_dist_params).outputs[1]]
+
+    ds_only_in_reps = ds_reps_new_dims and all(
+        d < reps_ind_split_idx for n, d in ds_reps_new_dims
+    )
+
+    if ds_only_in_reps:
+        # Update the `size` array to reflect the `DimShuffle`d dimensions.
+        # There should be no need to `DimShuffle` now.
+        new_size = [
+            tt.constant(1, dtype="int64") if o == "x" else size[o]
+            for o in ds_new_order
+        ]
+
+        return [rv_op.make_node(rng, new_size, dtype, *dist_params).outputs[1]]
+
+    return False
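
---

For reference, a condensed sketch of what the new `lift_rv_shapes` helper does, mirroring the docstring example and the assertions in `check_shape_lifted_rv` above. This is illustrative usage, not part of the patch, and assumes a build that includes it; the parameter values are arbitrary:

    import numpy as np

    import theano.tensor as tt
    from theano import config, shared
    from theano.tensor.random.basic import normal
    from theano.tensor.random.opt import lift_rv_shapes

    rng = shared(np.random.RandomState(1233532), borrow=False)

    # From the docstring: ``normal([0.0, 1.0], 5.0, size=(3, 2))`` has its
    # `size` folded into fully broadcast distribution parameters
    mean = tt.as_tensor(np.array([0.0, 1.0], dtype=config.floatX))
    std = tt.as_tensor(np.array(5.0, dtype=config.floatX))
    rv = normal(mean, std, size=(3, 2), rng=rng)

    lifted_node = lift_rv_shapes(rv.owner)

    # The lifted node's `size` input is now empty; the mean and std inputs
    # have been broadcast to the full (3, 2) shape instead
    assert np.array_equal(lifted_node.inputs[1].data, [])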
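
Likewise, a condensed sketch of the `DimShuffle` lift that `test_DimShuffle_lift` checks: under the `canonicalize` optimizations, the `DimShuffle` on the `normal` output moves onto its parameter inputs. Again illustrative, not part of the patch; the names and values mirror the test:

    import numpy as np

    import theano.tensor as tt
    from theano import change_flags, shared
    from theano.compile.function import function
    from theano.compile.mode import Mode
    from theano.gof.optdb import Query
    from theano.tensor.elemwise import DimShuffle
    from theano.tensor.random.basic import normal

    canonicalize_mode = Mode("py", Query(include=["canonicalize"], exclude=[]))

    with change_flags(compute_test_value="off"):
        rng = shared(np.random.RandomState(1233532), borrow=False)

        mean = tt.matrix("mean")
        std = tt.matrix("std")
        size = [tt.iscalar("size_d1"), tt.iscalar("size_d2"), tt.iscalar("size_d3")]

        # Swap the two independent-variate dimensions of the output
        norm_dm = normal(mean, std, size=size, rng=rng).dimshuffle((0, 2, 1))

        f_opt = function([mean, std] + size, norm_dm, mode=canonicalize_mode)
        (new_out,) = f_opt.maker.fgraph.outputs

        # The optimized graph's output is a `normal` again; the `DimShuffle`s
        # now sit on its parameter inputs (`inputs[3:]` are the parameters)
        assert new_out.owner.op == normal
        assert all(
            isinstance(i.owner.op, DimShuffle) for i in new_out.owner.inputs[3:]
        )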