From 05462f2feb15381bff0427b72a9aa3fc11532cfc Mon Sep 17 00:00:00 2001 From: hottwaj Date: Wed, 26 Feb 2020 10:25:20 +0000 Subject: [PATCH] Initial changes to allow pymc3.Data() to support both int and float input data (previously all input data was coerced to float) WIP for #3813 --- pymc3/data.py | 14 ++++++++++++-- pymc3/model.py | 8 +++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/pymc3/data.py b/pymc3/data.py index c638478b081..e39809123ef 100644 --- a/pymc3/data.py +++ b/pymc3/data.py @@ -478,10 +478,20 @@ class Data: For more information, take a look at this example notebook https://docs.pymc.io/notebooks/data_container.html """ - def __new__(self, name, value): + def __new__(self, name, value, dtype = None): + if dtype is None: + if hasattr(value, 'dtype'): + # if no dtype given, but available as attr of value, use that as dtype + dtype = value.dtype + elif isinstance(value, int): + dtype = int + else: + # otherwise, assume float + dtype = float + # `pm.model.pandas_to_array` takes care of parameter `value` and # transforms it to something digestible for pymc3 - shared_object = theano.shared(pm.model.pandas_to_array(value), name) + shared_object = theano.shared(pm.model.pandas_to_array(value, dtype = dtype), name) # To draw the node for this variable in the graphviz Digraph we need # its shape. diff --git a/pymc3/model.py b/pymc3/model.py index 3de6e4f380e..a4616fd9f3c 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -1473,7 +1473,7 @@ def init_value(self): return self.tag.test_value -def pandas_to_array(data): +def pandas_to_array(data, dtype = float): if hasattr(data, 'values'): # pandas if data.isnull().any().any(): # missing values ret = np.ma.MaskedArray(data.values, data.isnull().values) @@ -1492,8 +1492,10 @@ def pandas_to_array(data): ret = generator(data) else: ret = np.asarray(data) - return pm.floatX(ret) - + if dtype in [float, np.float32, np.float64]: + return pm.floatX(ret) + elif dtype in [int, np.int32, np.int64]: + return pm.intX(ret) def as_tensor(data, name, model, distribution): dtype = distribution.dtype