We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Dirichlet distribution ignores floatX config and creates float64 variables in the graph
import pymc as pm
import pytensor.tensor as pt  # noqa: F401  (kept from the original report; unused here)
import pytensor


def test_dirichlet():
    """Build a model with a Dirichlet RV and print its point log-probabilities.

    The concentration vector is cast with ``pm.floatX`` so that, under
    ``floatX="float32"``, its dtype follows the configured float type.
    The traceback in the report shows an exception raised inside
    ``pm.Dirichlet(...)``, during ``SimplexTransform.forward``.
    """
    with pm.Model() as model:
        c = pm.floatX([1, 1, 1])
        print(c, c.dtype)
        # Registering the RV inside the model context is the step that raised
        # in the reported traceback; the returned variable itself is not needed.
        pm.Dirichlet("a", c)
        print(model.point_logps())


if __name__ == "__main__":
    # warn_float64="raise" turns any float64 TensorVariable creation into an
    # exception, which makes an unintended float64 upcast visible immediately.
    with pytensor.config.change_flags(warn_float64="raise", floatX="float32"):
        test_dirichlet()
--------------------------------------------------------------------------- Exception Traceback (most recent call last) Cell In[10], line 2 1 with pytensor.config.change_flags(warn_float64="raise", floatX="float32"): ----> 2 test_dirichlet() Cell In[8], line 5, in test_dirichlet() 3 c = pm.floatX([1, 1, 1]) 4 print(c, c.dtype) ----> 5 d = pm.Dirichlet("a", c) 6 print(model.point_logps()) File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/distributions/distribution.py:314, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs) 310 kwargs["shape"] = tuple(observed.shape) 312 rv_out = cls.dist(*args, **kwargs) --> 314 rv_out = model.register_rv( 315 rv_out, 316 name, 317 observed, 318 total_size, 319 dims=dims, 320 transform=transform, 321 initval=initval, 322 ) 324 # add in pretty-printing support 325 rv_out.str_repr = types.MethodType(str_for_dist, rv_out) File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/model.py:1333, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval) 1331 raise ValueError("total_size can only be passed to observed RVs") 1332 self.free_RVs.append(rv_var) -> 1333 self.create_value_var(rv_var, transform) 1334 self.add_named_variable(rv_var, dims) 1335 self.set_initval(rv_var, initval) File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/model.py:1526, in Model.create_value_var(self, rv_var, transform, value_var) 1523 value_var.tag.test_value = rv_var.tag.test_value 1524 else: 1525 # Create value variable with the same type as the transformed RV -> 1526 value_var = transform.forward(rv_var, *rv_var.owner.inputs).type() 1527 value_var.name = f"{rv_var.name}_{transform.name}__" 1528 value_var.tag.transform = transform File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pymc/logprob/transforms.py:985, in SimplexTransform.forward(self, value, *inputs) 983 def forward(self, value, *inputs): 984 log_value = 
pt.log(value) --> 985 shift = pt.sum(log_value, -1, keepdims=True) / value.shape[-1] 986 return log_value[..., :-1] - shift File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/var.py:173, in _tensor_py_operators.__truediv__(self, other) 172 def __truediv__(self, other): --> 173 return at.math.true_div(self, other) File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/graph/op.py:295, in Op.__call__(self, *inputs, **kwargs) 253 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs. 254 255 This method is just a wrapper around :meth:`Op.make_node`. (...) 292 293 """ 294 return_list = kwargs.pop("return_list", False) --> 295 node = self.make_node(*inputs, **kwargs) 297 if config.compute_test_value != "off": 298 compute_test_value(node) File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/elemwise.py:486, in Elemwise.make_node(self, *inputs) 484 inputs = [as_tensor_variable(i) for i in inputs] 485 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs) --> 486 outputs = [ 487 TensorType(dtype=dtype, shape=shape)() 488 for dtype, shape in zip(out_dtypes, out_shapes) 489 ] 490 return Apply(self, inputs, outputs) File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/elemwise.py:487, in <listcomp>(.0) 484 inputs = [as_tensor_variable(i) for i in inputs] 485 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs) 486 outputs = [ --> 487 TensorType(dtype=dtype, shape=shape)() 488 for dtype, shape in zip(out_dtypes, out_shapes) 489 ] 490 return Apply(self, inputs, outputs) File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/graph/type.py:228, in Type.__call__(self, name) 219 def __call__(self, name: Optional[str] = None) -> variable_type: 220 """Return a new `Variable` instance of Type `self`. 221 222 Parameters (...) 
226 227 """ --> 228 return utils.add_tag_trace(self.make_variable(name)) File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/graph/type.py:200, in Type.make_variable(self, name) 191 def make_variable(self, name: Optional[str] = None) -> variable_type: 192 """Return a new `Variable` instance of this `Type`. 193 194 Parameters (...) 198 199 """ --> 200 return self.variable_type(self, None, name=name) File ~/micromamba/envs/pymc-blog/lib/python3.9/site-packages/pytensor/tensor/var.py:863, in TensorVariable.__init__(self, type, owner, index, name) 861 warnings.warn(msg, stacklevel=1 + nb_rm) 862 elif config.warn_float64 == "raise": --> 863 raise Exception(msg) 864 elif config.warn_float64 == "pdb": 865 import pdb Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
master
Related to pymc-devs/pymc-extras#182
The text was updated successfully, but these errors were encountered:
local_sum_make_vector
Successfully merging a pull request may close this issue.
Describe the issue:
Dirichlet distribution ignores floatX config and creates float64 variables in the graph
Reproducible code example:
Error message:
PyMC version information:
master
Context for the issue:
Related to pymc-devs/pymc-extras#182
The text was updated successfully, but these errors were encountered: