Implement DimShuffle lifting optimization for RandomVariables
This optimization does *not* preserve equality between the numeric
results of the untransformed and transformed graphs when the RNGs and seeds are
equal.  The reason is that the underlying sampler methods themselves are not
implemented in Theano, so we cannot apply the requisite DimShuffle-like
operations to the intermediate samples used to generate multiple replications
and/or independent variates.

For example, sampling a normal of size (3, 2) requires a draw of size (3, 2)
from a standard normal, and we can't transpose that (3, 2) array.  If we could,
then we would be able to maintain numerical equality between graphs.
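
To make the caveat concrete, here is a minimal NumPy-only sketch (an
illustration, not part of the change; the seed is the one used in the new
tests): drawing with the original size and then transposing samples the same
distribution as drawing with the transposed size directly, but the two results
are not element-wise equal under the same seed.

    import numpy as np

    seed = 1233532

    # Untransformed graph: draw a (3, 2) standard normal, then transpose.
    ref = np.random.RandomState(seed).standard_normal((3, 2)).T

    # Transformed graph: draw a (2, 3) standard normal directly.
    lifted = np.random.RandomState(seed).standard_normal((2, 3))

    # Same shape, same distribution, different element layout.
    assert ref.shape == lifted.shape == (2, 3)
    assert not np.array_equal(ref, lifted)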
brandonwillard committed Nov 10, 2020
1 parent 1ead811 commit 7c36c55
Showing 2 changed files with 268 additions and 7 deletions.
141 changes: 138 additions & 3 deletions tests/tensor/random/test_opt.py
@@ -1,13 +1,19 @@
import numpy as np
import pytest

import theano.tensor as tt
from theano import change_flags, config, shared
from theano.compile.function import function
from theano.compile.mode import Mode
from theano.gof.optdb import Query
-from theano.tensor.random.basic import normal
+from theano.tensor.elemwise import DimShuffle
+from theano.tensor.random.basic import dirichlet, multivariate_normal, normal
+from theano.tensor.random.opt import lift_rv_shapes


-opts = Query(include=["random_make_inplace"], exclude=[])
-inplace_mode = Mode("py", opts)
+inplace_mode = Mode("py", Query(include=["random_make_inplace"], exclude=[]))
+canonicalize_mode = Mode("py", Query(include=["canonicalize"], exclude=[]))
+no_mode = Mode("py", Query(include=[], exclude=[]))
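# `canonicalize_mode` runs only rewrites registered under the "canonicalize"
# tag (which, via `register_canonicalize`, includes the new `DimShuffle`
# lift), while `no_mode` applies no rewrites at all.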


def test_inplace_optimization():
@@ -30,3 +36,132 @@ def test_inplace_optimization():
        np.array_equal(a.data, b.data)
        for a, b in zip(new_out.owner.inputs[1:], out.owner.inputs[1:])
    )


def check_shape_lifted_rv(rv, params, size, rng):
    """Check that draws from an RV graph are unchanged by `lift_rv_shapes`."""
    tt_params = []
    for p in params:
        p_tt = tt.as_tensor(p)
        p_tt = p_tt.type()
        p_tt.tag.test_value = p
        tt_params.append(p_tt)

    tt_size = []
    for s in size:
        s_tt = tt.as_tensor(s)
        s_tt = s_tt.type()
        s_tt.tag.test_value = s
        tt_size.append(s_tt)

    rv = rv(*tt_params, size=tt_size, rng=rng)
    rv_lifted = lift_rv_shapes(rv.owner)

    # Make sure the size input is empty
    assert np.array_equal(rv_lifted.inputs[1].data, [])

    f_ref = function(
        tt_params + tt_size,
        rv,
        mode=no_mode,
    )
    f_lifted = function(
        tt_params + tt_size,
        rv_lifted.outputs[1],
        mode=no_mode,
    )
    f_ref_val = f_ref(*(params + size))
    f_lifted_val = f_lifted(*(params + size))
    assert np.array_equal(f_ref_val, f_lifted_val)


@change_flags(compute_test_value="raise")
def test_lift_rv_shapes():

    rng = shared(np.random.RandomState(1233532), borrow=False)

    test_params = [
        np.array(1.0, dtype=config.floatX),
        np.array(5.0, dtype=config.floatX),
    ]
    test_size = []
    check_shape_lifted_rv(normal, test_params, test_size, rng)

    test_params = [
        np.array([0.0, 1.0], dtype=config.floatX),
        np.array(5.0, dtype=config.floatX),
    ]
    test_size = [3, 2]
    check_shape_lifted_rv(normal, test_params, test_size, rng)

    test_params = [
        np.array([[0], [10], [100]], dtype=config.floatX),
        np.diag(np.array([1e-6], dtype=config.floatX)),
    ]
    test_size = [2, 3]
    check_shape_lifted_rv(multivariate_normal, test_params, test_size, rng)

    test_params = [
        np.array([[100, 1, 1], [1, 100, 1], [1, 1, 100]], dtype=config.floatX)
    ]
    test_size = [2, 3]
    check_shape_lifted_rv(dirichlet, test_params, test_size, rng)


@pytest.mark.parametrize(
    "new_dims,lifted",
    [
        ((0, 2, 1), True),
        (("x", 0, 2, 1, "x"), True),
        (("x", 0, "x", 2, "x", 1, "x"), True),
        (("x", 1, 0, 2, "x"), False),
    ],
)
@change_flags(compute_test_value="off")
def test_DimShuffle_lift(new_dims, lifted):

    rng = shared(np.random.RandomState(1233532), borrow=False)

    mean = tt.matrix("mean")
    mean.tag.test_value = np.array([[-1, 20], [300, -4000]], dtype=config.floatX)
    std = tt.matrix("std")
    std.tag.test_value = np.array([[1e-6, 2e-6]], dtype=config.floatX)
    s1 = tt.iscalar("size_d1")
    s1.tag.test_value = 3
    s2 = tt.iscalar("size_d2")
    s2.tag.test_value = 2
    s3 = tt.iscalar("size_d3")
    s3.tag.test_value = 2
    size = [s1, s2, s3]

    norm_dm = normal(mean, std, size=size, rng=rng).dimshuffle(new_dims)

    f_opt = function(
        [mean, std] + size,
        norm_dm,
        mode=canonicalize_mode,
    )

    (new_out,) = f_opt.maker.fgraph.outputs

    if lifted:
        assert new_out.owner.op == normal
        assert all(isinstance(i.owner.op, DimShuffle) for i in new_out.owner.inputs[3:])
    else:
        assert isinstance(new_out.owner.op, DimShuffle)
        return

    f_base = function(
        [mean, std] + size,
        norm_dm,
        mode=no_mode,
    )

    res_base = f_base(
        mean.tag.test_value, std.tag.test_value, *[s.tag.test_value for s in size]
    )
    res_opt = f_opt(
        mean.tag.test_value, std.tag.test_value, *[s.tag.test_value for s in size]
    )

    assert np.allclose(res_base, res_opt, rtol=1e-3)
134 changes: 130 additions & 4 deletions theano/tensor/random/opt.py
@@ -1,19 +1,20 @@
import theano.tensor as tt
from theano.compile import optdb
from theano.gof.opt import local_optimizer
-from theano.tensor.opt import in2out
+from theano.tensor.elemwise import DimShuffle
+from theano.tensor.extra_ops import broadcast_to
+from theano.tensor.opt import in2out, register_canonicalize
from theano.tensor.random.op import RandomVariable
from theano.tensor.random.utils import broadcast_params


@local_optimizer([RandomVariable])
def random_make_inplace(node):
    op = node.op

    if isinstance(op, RandomVariable) and not op.inplace:

        name, ndim_supp, ndims_params, dtype, _ = op._props()
        new_op = type(op)(name, ndim_supp, ndims_params, dtype, True)

        return new_op.make_node(*node.inputs).outputs

    return False
@@ -26,3 +27,128 @@ def random_make_inplace(node):
"fast_run",
"inplace",
)


def lift_rv_shapes(node):
    """Lift a `RandomVariable`'s shape-related parameters.

    In other words, this broadcasts the distribution parameters against one
    another and against the extra dimensions implied by the `size` parameter,
    so that the explicit `size` input is no longer needed.

    For example, ``normal([0.0, 1.0], 5.0, size=(3, 2))`` becomes
    ``normal([[0., 1.], [0., 1.], [0., 1.]], [[5., 5.], [5., 5.], [5., 5.]])``.

    """

    if not isinstance(node.op, RandomVariable):
        return False

    rng, size, dtype, *dist_params = node.inputs

    # Broadcast the distribution parameters against one another, ignoring
    # each parameter's core dimensions (given by `ndims_params`)
    dist_params = broadcast_params(dist_params, node.op.ndims_params)

    # Broadcast each parameter across the `size` dimensions as well
    dist_params = [
        broadcast_to(
            p, (tuple(size) + tuple(p.shape)) if node.op.ndim_supp > 0 else size
        )
        for p in dist_params
    ]

    return node.op.make_node(rng, None, dtype, *dist_params)
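
# An illustrative sketch (not part of this commit) of `lift_rv_shapes`,
# mirroring the docstring example and `check_shape_lifted_rv` in the new
# tests, with ``normal`` from `theano.tensor.random.basic`:
#
#   rv = normal([0.0, 1.0], 5.0, size=(3, 2))
#   lifted_node = lift_rv_shapes(rv.owner)
#
#   # The lifted node's `size` input is empty; the parameters now carry
#   # the (3, 2) dimensions themselves.
#   assert np.array_equal(lifted_node.inputs[1].data, [])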


@register_canonicalize
@local_optimizer([DimShuffle])
def local_dimshuffle_rv_lift(node):
    """Lift `DimShuffle`s through `RandomVariable` `Op`s.

    For example, ``normal(mu, std).T == normal(mu.T, std.T)``.

    This rewrite only applies to univariate (i.e. ``ndim_supp == 0``)
    `RandomVariable`s.

    """

    ds_op = node.op

    if not isinstance(ds_op, DimShuffle):
        return False

    rv_node = node.inputs[0].owner
    if not (
        rv_node and isinstance(rv_node.op, RandomVariable) and rv_node.op.ndim_supp == 0
    ):
        return False

    rv_op = rv_node.op
    rng, size, dtype, *dist_params = rv_node.inputs

    # We need to know the dimensions that were *not* added by the `size`
    # parameter (i.e. the dimensions corresponding to independent variates
    # with different parameter values)
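    # For example (illustrative, with assumed shapes): in
    # ``normal(tt.matrix(), tt.matrix(), size=(s1, s2, s3))`` the output has
    # three dimensions, of which the trailing two broadcast with the
    # two-dimensional parameters (the independent variates) and only the
    # leading one is a pure replication dimension added by `size`.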
    num_ind_dims = None
    if len(dist_params) == 1:
        num_ind_dims = dist_params[0].ndim
    else:
        # When there is more than one distribution parameter, assume that all
        # of them will broadcast to the maximum number of dimensions
        num_ind_dims = max(d.ndim for d in dist_params)

    # If the indices in `ds_new_order` stay entirely within the replication
    # indices group or the independent variates indices group, then the
    # `DimShuffle` can be lifted; otherwise, we give up.
    ds_new_order = ds_op.new_order
    dim_orders = [(n, d) for n, d in enumerate(ds_new_order) if isinstance(d, int)]

    # Output dimensions before this index are replication dimensions added by
    # `size`; dimensions at or after it belong to the independent variates.
    reps_ind_split_idx = len(dim_orders) - (num_ind_dims + rv_op.ndim_supp)

    ds_reps_new_dims = dim_orders[:reps_ind_split_idx]
    ds_ind_new_dims = dim_orders[reps_ind_split_idx:]
    ds_only_in_ind = ds_ind_new_dims and all(
        d >= reps_ind_split_idx for n, d in ds_ind_new_dims
    )

    if ds_only_in_ind:

        # Update the `size` array to reflect the `DimShuffle`d dimensions,
        # since the trailing dimensions in `size` represent the independent
        # variates dimensions (for univariate distributions, at least)
        new_size = [
            tt.constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order
        ]
        # Compute the new axes parameter(s) for the `DimShuffle` that will be
        # applied to the `RandomVariable` parameters (they need to be offset)
        rv_params_new_order = [
            d - reps_ind_split_idx if isinstance(d, int) else d
            for d in ds_new_order[ds_ind_new_dims[0][0] :]
        ]
        # Lift the `DimShuffle`s into the parameters
        # NOTE: The parameters might not be broadcasted against each other, so
        # we can only apply the parts of the `DimShuffle` that are relevant.
        new_dist_params = []
        for d in dist_params:
            if d.ndim < len(ds_ind_new_dims):
                _rv_params_new_order = [
                    o
                    for o in rv_params_new_order
                    if (isinstance(o, int) and o < d.ndim) or o == "x"
                ]
            else:
                _rv_params_new_order = rv_params_new_order

            new_dist_params.append(
                type(ds_op)(d.type.broadcastable, _rv_params_new_order)(d)
            )
        return [rv_op.make_node(rng, new_size, dtype, *new_dist_params).outputs[1]]

    ds_only_in_reps = ds_reps_new_dims and all(
        d < reps_ind_split_idx for n, d in ds_reps_new_dims
    )

    if ds_only_in_reps:
        # Update the `size` array to reflect the `DimShuffle`d dimensions.
        # There should be no need to `DimShuffle` now.
        new_size = [
            tt.constant(1, dtype="int64") if o == "x" else size[o] for o in ds_new_order
        ]
        return [rv_op.make_node(rng, new_size, dtype, *dist_params).outputs[1]]

    return False
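
# An illustrative summary (not part of this commit) of the rewrite above,
# mirroring `test_DimShuffle_lift` in the new tests: under the
# "canonicalize" pass,
#
#   normal(mu, std).T
#
# is replaced by a graph equivalent to
#
#   normal(mu.T, std.T)
#
# i.e. the `DimShuffle` is applied to the parameters rather than the draws.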
