fixed (some) handling of symbolic scalar shapes
brandonwillard committed May 30, 2016
1 parent 9c72a85 commit cc71de3
Showing 1 changed file with 51 additions and 18 deletions.
69 changes: 51 additions & 18 deletions pymc3/distributions/distribution.py
@@ -12,6 +12,32 @@
            'MultivariateDiscrete', 'UnivariateDiscrete', 'NoDistribution']
 
 
+def _as_tensor_shape_variable(var):
+    """ Just a collection of useful shape stuff from
+    `_infer_ndim_bcast` """
+
+    if var is None:
+        return T.constant([], dtype='int64')
+
+    res = var
+    if isinstance(res, (tuple, list)):
+        if len(res) == 0:
+            return T.constant([], dtype='int64')
+        res = T.as_tensor_variable(res, ndim=1)
+
+    else:
+        if res.ndim != 1:
+            raise TypeError("shape must be a vector or list of scalar, got\
+ '%s'" % res)
+
+    if (not (res.dtype.startswith('int') or
+             res.dtype.startswith('uint'))):
+
+        raise TypeError('shape must be an integer vector or list',
+                        res.dtype)
+    return res
+
+
 class Distribution(object):
     """Statistical distribution"""
     def __new__(cls, name, *args, **kwargs):
@@ -40,7 +66,8 @@ def dist(cls, *args, **kwargs):
         dist.__init__(*args, **kwargs)
         return dist
 
-    def __init__(self, shape_supp, shape_ind, shape_reps, bcast, dtype, testval=None, defaults=[], transform=None):
+    def __init__(self, shape_supp, shape_ind, shape_reps, bcast, dtype,
+                 testval=None, defaults=None, transform=None):
         r"""
         Distributions are specified in terms of the shape of their support, the shape
         of the space of independent instances and the shape of the space of replications.
@@ -114,23 +141,29 @@ def __init__(self, shape_supp, shape_ind, shape_reps, bcast, dtype, testval=None
             A transform function
         """
 
-        self.shape_supp = T.cast(T.as_tensor_variable(shape_supp, ndim=1), 'int64')
-        self.shape_ind = T.cast(T.as_tensor_variable(shape_ind, ndim=1), 'int64')
-        self.shape_reps = T.cast(T.as_tensor_variable(shape_reps, ndim=1), 'int64')
-        self.shape = T.concatenate((self.shape_supp, self.shape_ind, self.shape_reps))
-
-        #
-        # Following Theano Op's handling of shape parameters (well, at least
-        # theano.tensor.raw_random.RandomFunction.make_node).
-        #
-        #if self.shape.type.ndim != 1 or \
-        #        not (self.shape.type.dtype == 'int64') and \
-        #        not (self.shape.type.dtype == 'int32'):
-        #    raise TypeError("Expected int elements in shape")
-
-        #self.ndim = self.shape.ndim
-        #self.dtype = dtype
-        #self.bcast = bcast
+        self.shape_supp = _as_tensor_shape_variable(shape_supp)
+        self.ndim_supp = T.get_vector_length(self.shape_supp)
+        self.shape_ind = _as_tensor_shape_variable(shape_ind)
+        self.ndim_ind = T.get_vector_length(self.shape_ind)
+        self.shape_reps = _as_tensor_shape_variable(shape_reps)
+        self.ndim_reps = T.get_vector_length(self.shape_reps)
+
+        ndim_sum = self.ndim_supp + self.ndim_ind + self.ndim_reps
+        if ndim_sum == 0:
+            self.shape = T.constant([], dtype='int64')
+        else:
+            self.shape = tuple(self.shape_reps) +\
+                tuple(self.shape_ind) +\
+                tuple(self.shape_supp)
+
+        if testval is None:
+            if ndim_sum == 0:
+                testval = T.constant(0, dtype=dtype)
+            else:
+                testval = T.zeros(self.shape)
+
+        self.ndim = T.get_vector_length(self.shape)
+
         self.testval = testval
         self.defaults = defaults
         self.transform = transform
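Below is a minimal sketch (not part of this commit) of how the three shape pieces described in the docstring above are meant to combine: replications vary slowest, then independent instances, then the support dimensions, and a fully scalar distribution collapses to an empty shape vector. It assumes Theano is available; the helper name `as_shape` is illustrative, not the committed `_as_tensor_shape_variable`.

    import theano.tensor as T

    def as_shape(s):
        # Coerce None or an empty tuple/list into an empty int64 shape vector,
        # and anything else into a 1-d tensor, roughly mirroring the new helper.
        if s is None or (isinstance(s, (tuple, list)) and len(s) == 0):
            return T.constant([], dtype='int64')
        return T.as_tensor_variable(s, ndim=1)

    # Scalar support, 3 independent instances, a 2x4 block of replications:
    supp, ind, reps = as_shape(()), as_shape((3,)), as_shape((2, 4))
    full_shape = T.concatenate((reps, ind, supp))

    print(full_shape.eval())          # [2 4 3]
    print(T.get_vector_length(supp))  # 0, i.e. a symbolic scalar shape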
