Describe the issue:
`ZeroSumNormal` does not follow `floatX` when it is set to `float32`.
The issue has been mentioned before, but the PR merged in 5.6.0 doesn't seem to resolve it.
Also related: #6871
Reproducible code example:
```python
import pymc as pm
import pytensor

with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
    with pm.Model():
        # no issues here
        a = pm.Normal("a", 0, 1)

    with pm.Model():
        # errors here
        b = pm.ZeroSumNormal("b", 1, shape=(5,))
```
Error message:
```
---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
Input In [10], in <cell line: 5>()
      8     a = pm.Normal("a", 0, 1)
     10 with pm.Model():
     11     # errors here
---> 12     b = pm.ZeroSumNormal("b", 1, shape=(5,))

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/distributions/multivariate.py:2512, in ZeroSumNormal.__new__(cls, zerosum_axes, n_zerosum_axes, support_shape, dims, *args, **kwargs)
   2502 n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes)
   2504 support_shape = get_support_shape(
   2505     support_shape=support_shape,
   2506     shape=None,  # Shape will be checked in `cls.dist`
   (...)
   2509     ndim_supp=n_zerosum_axes,
   2510 )
-> 2512 return super().__new__(
   2513     cls,
   2514     *args,
   2515     n_zerosum_axes=n_zerosum_axes,
   2516     support_shape=support_shape,
   2517     dims=dims,
   2518     **kwargs,
   2519 )

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/distributions/distribution.py:316, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    312     kwargs["shape"] = tuple(observed.shape)
    314 rv_out = cls.dist(*args, **kwargs)
--> 316 rv_out = model.register_rv(
    317     rv_out,
    318     name,
    319     observed,
    320     total_size,
    321     dims=dims,
    322     transform=transform,
    323     initval=initval,
    324 )
    326 # add in pretty-printing support
    327 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/model.py:1289, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval)
   1287         raise ValueError("total_size can only be passed to observed RVs")
   1288     self.free_RVs.append(rv_var)
-> 1289     self.create_value_var(rv_var, transform)
   1290     self.add_named_variable(rv_var, dims)
   1291     self.set_initval(rv_var, initval)

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/model.py:1442, in Model.create_value_var(self, rv_var, transform, value_var)
   1439             value_var.tag.test_value = rv_var.tag.test_value
   1440     else:
   1441         # Create value variable with the same type as the transformed RV
-> 1442         value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
   1443         value_var.name = f"{rv_var.name}_{transform.name}__"
   1444         value_var.tag.transform = transform

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/distributions/transforms.py:320, in ZeroSumTransform.forward(self, value, *rv_inputs)
    318 def forward(self, value, *rv_inputs):
    319     for axis in self.zerosum_axes:
--> 320         value = extend_axis_rev(value, axis=axis)
    321     return value

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/distributions/transforms.py:348, in extend_axis_rev(array, axis)
    345 n = array.shape[normalized_axis]
    346 last = pt.take(array, [-1], axis=normalized_axis)
--> 348 sum_vals = -last * pt.sqrt(n)
    349 norm = sum_vals / (pt.sqrt(n) + n)
    350 slice_before = (slice(None, None),) * normalized_axis

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/graph/op.py:304, in Op.__call__(self, *inputs, **kwargs)
    262 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    263 
    264 This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    301 
    302 """
    303 return_list = kwargs.pop("return_list", False)
--> 304 node = self.make_node(*inputs, **kwargs)
    306 if config.compute_test_value != "off":
    307     compute_test_value(node)

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/tensor/elemwise.py:497, in Elemwise.make_node(self, *inputs)
    495 inputs = [as_tensor_variable(i) for i in inputs]
    496 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
--> 497 outputs = [
    498     TensorType(dtype=dtype, shape=shape)()
    499     for dtype, shape in zip(out_dtypes, out_shapes)
    500 ]
    501 return Apply(self, inputs, outputs)

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/tensor/elemwise.py:498, in <listcomp>(.0)
    495 inputs = [as_tensor_variable(i) for i in inputs]
    496 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
    497 outputs = [
--> 498     TensorType(dtype=dtype, shape=shape)()
    499     for dtype, shape in zip(out_dtypes, out_shapes)
    500 ]
    501 return Apply(self, inputs, outputs)

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/graph/type.py:228, in Type.__call__(self, name)
    219 def __call__(self, name: Optional[str] = None) -> variable_type:
    220     """Return a new `Variable` instance of Type `self`.
    221 
    222     Parameters
   (...)
    226 
    227     """
--> 228     return utils.add_tag_trace(self.make_variable(name))

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/graph/type.py:200, in Type.make_variable(self, name)
    191 def make_variable(self, name: Optional[str] = None) -> variable_type:
    192     """Return a new `Variable` instance of this `Type`.
    193 
    194     Parameters
   (...)
    198 
    199     """
--> 200     return self.variable_type(self, None, name=name)

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/tensor/var.py:860, in TensorVariable.__init__(self, type, owner, index, name)
    858     warnings.warn(msg, stacklevel=1 + nb_rm)
    859 elif config.warn_float64 == "raise":
--> 860     raise Exception(msg)
    861 elif config.warn_float64 == "pdb":
    862     import pdb

Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
```
PyMC version information:
PyMC/PyMC3 Version: 5.7.2
PyTensor/Aesara Version: 2.14.2
Python Version: 3.10.12
Operating system: Darwin arm64
How did you install PyMC/PyMC3: pip
Context for the issue:
`ZeroSumNormal` is needed for hierarchical models, and using `float32` would be preferable since it gives sufficient precision with half the memory of `float64`.
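The failing frame in the traceback is `extend_axis_rev`: `n = array.shape[normalized_axis]` is an integer (int64) shape value, and `pt.sqrt(n)` of an int64 appears to produce a float64 variable, which is what trips `warn_float64="raise"` before any `float32` data is involved. The sketch below mirrors that arithmetic in NumPy, assuming the remedy is simply to cast `n` to the array's own dtype before taking the square root. The names `extend_axis_rev_np`, `extend_axis_np` and `_drop_last` are hypothetical; the traceback only shows lines 345 to 350, so the final "drop the last entry and add `norm`" step and the whole forward map are assumptions, not PyMC's code.

```python
import numpy as np


def _drop_last(array, axis):
    # Hypothetical helper: slice away the last entry along `axis`.
    index = [slice(None)] * array.ndim
    index[axis] = slice(None, -1)
    return array[tuple(index)]


def extend_axis_rev_np(array, axis):
    """NumPy analogue of the `extend_axis_rev` lines in the traceback.

    `n` is cast to the array's dtype *before* `sqrt`, so a float32 input
    never meets a float64 constant.
    """
    axis = axis % array.ndim
    n = np.asarray(array.shape[axis], dtype=array.dtype)  # cast the shape first
    last = np.take(array, [-1], axis=axis)
    sum_vals = -last * np.sqrt(n)
    norm = sum_vals / (np.sqrt(n) + n)
    # Assumption: the step not shown in the traceback drops the last entry
    # along `axis` and shifts the remaining entries by `norm`.
    return _drop_last(array, axis) + norm


def extend_axis_np(array, axis):
    """Assumed inverse: map n-1 free values back to a zero-sum n-vector."""
    axis = axis % array.ndim
    n = np.asarray(array.shape[axis] + 1, dtype=array.dtype)
    sum_vals = array.sum(axis=axis, keepdims=True)
    norm = sum_vals / (np.sqrt(n) + n)
    fill_val = norm - sum_vals / np.sqrt(n)
    return np.concatenate([array, fill_val], axis=axis) - norm


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    x = rng.normal(size=5).astype("float32")
    x -= x.mean()  # make the vector (numerically) zero-sum
    y = extend_axis_rev_np(x, axis=-1)
    print(y.dtype, y.shape)  # float32 (4,)
    print(np.allclose(extend_axis_np(y, axis=-1), x, atol=1e-6))
```

Casting to the input's dtype rather than hard-coding `floatX` keeps the map dtype-preserving for both `float32` and `float64` inputs. In PyTensor terms this would correspond to something like `pt.sqrt(pt.cast(n, array.dtype))`, but that is only a suggestion, not necessarily how PyMC resolves the issue.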