Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Dirichlet is not tolerant to floatX=float32 #6779

Closed
ferrine opened this issue Jun 15, 2023 · 0 comments · Fixed by #6780
Closed

BUG: Dirichlet is not tolerant to floatX=float32 #6779

ferrine opened this issue Jun 15, 2023 · 0 comments · Fixed by #6780

Comments

@ferrine
Copy link
Member

ferrine commented Jun 15, 2023

Describe the issue:

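The Dirichlet distribution ignores the `floatX` config and creates float64 variables in the graph. According to the traceback below, the float64 variable appears in `SimplexTransform.forward`, the transform applied when the Dirichlet's value variable is created. There, the float32 sum of `log_value` is divided by `value.shape[-1]`, an integer shape entry, and that division upcasts the result to float64.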

Reproducible code example:

import pymc as pm
import pytensor.tensor as pt
import pytensor

def test_dirichlet():
    """Build a single-variable Dirichlet model and print its point log-probabilities.

    Intended to be run with ``floatX="float32"`` and ``warn_float64="raise"``.
    In that configuration, creating any float64 TensorVariable while the
    Dirichlet is registered raises an exception instead of passing silently.
    """
    with pm.Model() as model:
        concentration = pm.floatX([1, 1, 1])
        # Show the concentration values and confirm they carry the floatX dtype.
        print(concentration, concentration.dtype)
        pm.Dirichlet("a", concentration)
    # The model handle is still usable after leaving its context.
    print(model.point_logps())
    
# Run the reproducer with floatX set to float32, and turn every creation of a
# float64 TensorVariable into an exception. The first float64 variable that
# sneaks into the graph then fails loudly at the point where it is built.
with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
    test_dirichlet()

Error message:

---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
Cell In[10], line 2
      1 with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
----> 2     test_dirichlet()

Cell In[8], line 5, in test_dirichlet()
      3     c = pm.floatX([1, 1, 1])
      4     print(c, c.dtype)
----> 5     d = pm.Dirichlet("a", c)
      6 print(model.point_logps())

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/distributions/distribution.py:314, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    310         kwargs["shape"] = tuple(observed.shape)
    312 rv_out = cls.dist(*args, **kwargs)
--> 314 rv_out = model.register_rv(
    315     rv_out,
    316     name,
    317     observed,
    318     total_size,
    319     dims=dims,
    320     transform=transform,
    321     initval=initval,
    322 )
    324 # add in pretty-printing support
    325 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/model.py:1333, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval)
   1331     raise ValueError("total_size can only be passed to observed RVs")
   1332 self.free_RVs.append(rv_var)
-> 1333 self.create_value_var(rv_var, transform)
   1334 self.add_named_variable(rv_var, dims)
   1335 self.set_initval(rv_var, initval)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/model.py:1526, in Model.create_value_var(self, rv_var, transform, value_var)
   1523         value_var.tag.test_value = rv_var.tag.test_value
   1524 else:
   1525     # Create value variable with the same type as the transformed RV
-> 1526     value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
   1527     value_var.name = f"{rv_var.name}_{transform.name}__"
   1528     value_var.tag.transform = transform

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/logprob/transforms.py:985, in SimplexTransform.forward(self, value, *inputs)
    983 def forward(self, value, *inputs):
    984     log_value = pt.log(value)
--> 985     shift = pt.sum(log_value, -1, keepdims=True) / value.shape[-1]
    986     return log_value[..., :-1] - shift

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/var.py:173, in _tensor_py_operators.__truediv__(self, other)
    172 def __truediv__(self, other):
--> 173     return at.math.true_div(self, other)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/graph/op.py:295, in Op.__call__(self, *inputs, **kwargs)
    253 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    254 
    255 This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    292 
    293 """
    294 return_list = kwargs.pop("return_list", False)
--> 295 node = self.make_node(*inputs, **kwargs)
    297 if config.compute_test_value != "off":
    298     compute_test_value(node)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/elemwise.py:486, in Elemwise.make_node(self, *inputs)
    484 inputs = [as_tensor_variable(i) for i in inputs]
    485 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
--> 486 outputs = [
    487     TensorType(dtype=dtype, shape=shape)()
    488     for dtype, shape in zip(out_dtypes, out_shapes)
    489 ]
    490 return Apply(self, inputs, outputs)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/elemwise.py:487, in <listcomp>(.0)
    484 inputs = [as_tensor_variable(i) for i in inputs]
    485 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
    486 outputs = [
--> 487     TensorType(dtype=dtype, shape=shape)()
    488     for dtype, shape in zip(out_dtypes, out_shapes)
    489 ]
    490 return Apply(self, inputs, outputs)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/graph/type.py:228, in Type.__call__(self, name)
    219 def __call__(self, name: Optional[str] = None) -> variable_type:
    220     """Return a new `Variable` instance of Type `self`.
    221 
    222     Parameters
   (...)
    226 
    227     """
--> 228     return utils.add_tag_trace(self.make_variable(name))

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/graph/type.py:200, in Type.make_variable(self, name)
    191 def make_variable(self, name: Optional[str] = None) -> variable_type:
    192     """Return a new `Variable` instance of this `Type`.
    193 
    194     Parameters
   (...)
    198 
    199     """
--> 200     return self.variable_type(self, None, name=name)

File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/var.py:863, in TensorVariable.__init__(self, type, owner, index, name)
    861     warnings.warn(msg, stacklevel=1 + nb_rm)
    862 elif config.warn_float64 == "raise":
--> 863     raise Exception(msg)
    864 elif config.warn_float64 == "pdb":
    865     import pdb

Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
​

PyMC version information:

master

Context for the issue:

related to pymc-devs/pymc-extras#182

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging a pull request may close this issue.

1 participant