
BUG: ZeroSumNormal unintentionally upcasts to float64 #6886

Closed · Fixed by #6889

lancechua opened this issue Aug 31, 2023 · 1 comment
@lancechua (Contributor)
Describe the issue:

ZeroSumNormal does not respect floatX when it is set to float32.

The issue has been mentioned here before, but the PR (merged in 5.6.0) doesn't seem to resolve it.

Also related: #6871

Reproducible code example:

import pymc as pm
import pytensor


with pytensor.config.change_flags(floatX="float32", warn_float64="raise"):
    with pm.Model():
        # no issue here: Normal respects floatX
        a = pm.Normal("a", 0, 1)

    with pm.Model():
        # raises here: ZeroSumNormal produces a float64 intermediate
        b = pm.ZeroSumNormal("b", 1, shape=(5,))

Error message:

---------------------------------------------------------------------------
Exception                                 Traceback (most recent call last)
Input In [10], in <cell line: 5>()
      8     a = pm.Normal("a", 0, 1)
     10 with pm.Model():
     11     # errors here    
---> 12     b = pm.ZeroSumNormal("b", 1, shape=(5,))

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/distributions/multivariate.py:2512, in ZeroSumNormal.__new__(cls, zerosum_axes, n_zerosum_axes, support_shape, dims, *args, **kwargs)
   2502     n_zerosum_axes = cls.check_zerosum_axes(n_zerosum_axes)
   2504     support_shape = get_support_shape(
   2505         support_shape=support_shape,
   2506         shape=None,  # Shape will be checked in `cls.dist`
   (...)
   2509         ndim_supp=n_zerosum_axes,
   2510     )
-> 2512 return super().__new__(
   2513     cls,
   2514     *args,
   2515     n_zerosum_axes=n_zerosum_axes,
   2516     support_shape=support_shape,
   2517     dims=dims,
   2518     **kwargs,
   2519 )

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/distributions/distribution.py:316, in Distribution.__new__(cls, name, rng, dims, initval, observed, total_size, transform, *args, **kwargs)
    312         kwargs["shape"] = tuple(observed.shape)
    314 rv_out = cls.dist(*args, **kwargs)
--> 316 rv_out = model.register_rv(
    317     rv_out,
    318     name,
    319     observed,
    320     total_size,
    321     dims=dims,
    322     transform=transform,
    323     initval=initval,
    324 )
    326 # add in pretty-printing support
    327 rv_out.str_repr = types.MethodType(str_for_dist, rv_out)

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/model.py:1289, in Model.register_rv(self, rv_var, name, observed, total_size, dims, transform, initval)
   1287     raise ValueError("total_size can only be passed to observed RVs")
   1288 self.free_RVs.append(rv_var)
-> 1289 self.create_value_var(rv_var, transform)
   1290 self.add_named_variable(rv_var, dims)
   1291 self.set_initval(rv_var, initval)

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/model.py:1442, in Model.create_value_var(self, rv_var, transform, value_var)
   1439         value_var.tag.test_value = rv_var.tag.test_value
   1440 else:
   1441     # Create value variable with the same type as the transformed RV
-> 1442     value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
   1443     value_var.name = f"{rv_var.name}_{transform.name}__"
   1444     value_var.tag.transform = transform

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/distributions/transforms.py:320, in ZeroSumTransform.forward(self, value, *rv_inputs)
    318 def forward(self, value, *rv_inputs):
    319     for axis in self.zerosum_axes:
--> 320         value = extend_axis_rev(value, axis=axis)
    321     return value

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pymc/distributions/transforms.py:348, in extend_axis_rev(array, axis)
    345 n = array.shape[normalized_axis]
    346 last = pt.take(array, [-1], axis=normalized_axis)
--> 348 sum_vals = -last * pt.sqrt(n)
    349 norm = sum_vals / (pt.sqrt(n) + n)
    350 slice_before = (slice(None, None),) * normalized_axis

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/graph/op.py:304, in Op.__call__(self, *inputs, **kwargs)
    262 r"""Construct an `Apply` node using :meth:`Op.make_node` and return its outputs.
    263 
    264 This method is just a wrapper around :meth:`Op.make_node`.
   (...)
    301 
    302 """
    303 return_list = kwargs.pop("return_list", False)
--> 304 node = self.make_node(*inputs, **kwargs)
    306 if config.compute_test_value != "off":
    307     compute_test_value(node)

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/tensor/elemwise.py:497, in Elemwise.make_node(self, *inputs)
    495 inputs = [as_tensor_variable(i) for i in inputs]
    496 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
--> 497 outputs = [
    498     TensorType(dtype=dtype, shape=shape)()
    499     for dtype, shape in zip(out_dtypes, out_shapes)
    500 ]
    501 return Apply(self, inputs, outputs)

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/tensor/elemwise.py:498, in <listcomp>(.0)
    495 inputs = [as_tensor_variable(i) for i in inputs]
    496 out_dtypes, out_shapes, inputs = self.get_output_info(DimShuffle, *inputs)
    497 outputs = [
--> 498     TensorType(dtype=dtype, shape=shape)()
    499     for dtype, shape in zip(out_dtypes, out_shapes)
    500 ]
    501 return Apply(self, inputs, outputs)

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/graph/type.py:228, in Type.__call__(self, name)
    219 def __call__(self, name: Optional[str] = None) -> variable_type:
    220     """Return a new `Variable` instance of Type `self`.
    221 
    222     Parameters
   (...)
    226 
    227     """
--> 228     return utils.add_tag_trace(self.make_variable(name))

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/graph/type.py:200, in Type.make_variable(self, name)
    191 def make_variable(self, name: Optional[str] = None) -> variable_type:
    192     """Return a new `Variable` instance of this `Type`.
    193 
    194     Parameters
   (...)
    198 
    199     """
--> 200     return self.variable_type(self, None, name=name)

File ~/miniconda3/envs/rova_dev/lib/python3.10/site-packages/pytensor/tensor/var.py:860, in TensorVariable.__init__(self, type, owner, index, name)
    858     warnings.warn(msg, stacklevel=1 + nb_rm)
    859 elif config.warn_float64 == "raise":
--> 860     raise Exception(msg)
    861 elif config.warn_float64 == "pdb":
    862     import pdb

Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
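For what it's worth, the traceback points at extend_axis_rev in pymc/distributions/transforms.py: n comes from array.shape[...], so it is an int64 scalar, and pt.sqrt of an integer input upcasts to float64 regardless of floatX. A minimal sketch reproducing the upcast outside of PyMC (variable names are illustrative):

import pytensor
import pytensor.tensor as pt

with pytensor.config.change_flags(floatX="float32"):
    x = pt.vector("x")  # float32, since the dtype defaults to floatX
    n = x.shape[0]      # shape entries are int64 scalars
    print(pt.sqrt(n).dtype)            # float64: sqrt of an integer upcasts
    print((x[-1] * pt.sqrt(n)).dtype)  # the float64 propagates into the product

So the float64 seems to originate at sum_vals = -last * pt.sqrt(n) (and the norm line below it), where n is still an integer.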

PyMC version information:

PyMC/PyMC3 Version: 5.7.2
PyTensor/Aesara Version: 2.14.2
Python Version: 3.10.12
Operating system: Darwin arm64
How did you install PyMC/PyMC3: pip

Context for the issue:

ZeroSumNormal is needed for hierarchical models, and float32 would be preferable since it gives sufficient precision while using less memory.
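One possible fix idea, sketched here as an assumption rather than as the actual #6889 patch: cast the shape scalar to floatX before it enters the float arithmetic.

import pytensor
import pytensor.tensor as pt

with pytensor.config.change_flags(floatX="float32"):
    x = pt.vector("x")
    # Assumption: casting the int64 shape to floatX before sqrt keeps the
    # whole expression in float32; this is not necessarily what #6889 does.
    n = pt.cast(x.shape[0], pytensor.config.floatX)
    print((x[-1] * pt.sqrt(n)).dtype)  # float32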

@lancechua added the bug label on Aug 31, 2023
@welcome (bot) commented Aug 31, 2023
🎉 Welcome to PyMC! 🎉 We're really excited to have your input into the project! 💖

If you haven't done so already, please make sure you check out our Contributing Guidelines and Code of Conduct.
