I updated to NumPy 2.0 and found that the `MultivariateNormalDiag` and `MultivariateNormalFullCovariance` constructors crash because `np.issctype` has been removed. Is NumPy 2.0 supported, or will it be soon?
Here is a simple repro:
```python
import numpy as np
import jax
import jax.numpy as jnp
from tensorflow_probability.substrates import jax as tfp
tfd = tfp.distributions

print(f"numpy version: {np.__version__}")
print(f"jax version: {jax.__version__}")
print(f"tfp version: {tfp.__version__}")

# works fine
nml = tfd.Normal(jnp.zeros(3), jnp.ones(3))

# fails with numpy 2.0
mvn = tfd.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3))

# also fails with the same error on numpy 2.0
# mvn = tfd.MultivariateNormalFullCovariance(jnp.zeros(3), jnp.eye(3))
```

On my machine with Python 3.10, it produces the following output:

```
numpy version: 2.0.0
jax version: 0.4.30
tfp version: 0.24.0
Traceback (most recent call last):
  File "/Users/scott/Projects/dynamax/tfp_debug_20240618.py", line 15, in <module>
    mvn = tfd.MultivariateNormalDiag(jnp.zeros(3), jnp.ones(3))
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_diag.py", line 209, in __init__
    super(MultivariateNormalDiag, self).__init__(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/mvn_linear_operator.py", line 205, in __init__
    super(MultivariateNormalLinearOperator, self).__init__(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/decorator.py", line 232, in fun
    return caller(func, *(extras + args), **kw)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/distribution.py", line 342, in wrapped_init
    default_init(self_, *args, **kwargs)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/distributions/transformed_distribution.py", line 244, in __init__
    dtype = self.bijector.forward_dtype(self.distribution.dtype)
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1705, in forward_dtype
    input_dtype = nest.map_structure_up_to(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py", line 324, in map_structure_up_to
    return map_structure_with_tuple_paths_up_to(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py", line 353, in map_structure_with_tuple_paths_up_to
    return dm_tree.map_structure_with_path_up_to(
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tree/__init__.py", line 778, in map_structure_with_path_up_to
    results.append(func(*path_and_values))
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/python/internal/backend/jax/nest.py", line 326, in <lambda>
    lambda _, *args: func(*args),  # Discards path.
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/bijectors/bijector.py", line 1707, in <lambda>
    lambda x: dtype_util.convert_to_dtype(x, dtype=self.dtype),
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/tensorflow_probability/substrates/jax/internal/dtype_util.py", line 247, in convert_to_dtype
    elif np.issctype(tensor_or_dtype):
  File "/Users/scott/anaconda3/envs/dynamax/lib/python3.10/site-packages/numpy/__init__.py", line 397, in __getattr__
    raise AttributeError(
AttributeError: `np.issctype` was removed in the NumPy 2.0 release. Use `issubclass(rep, np.generic)` instead.. Did you mean: 'isdtype'?
```
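The last frame shows that the crash comes from `dtype_util.convert_to_dtype`, which calls `np.issctype` on its argument. Below is a minimal sketch of a NumPy-2-compatible check. The helper name `issctype_compat` is hypothetical and this is not TFP's actual fix. It approximates the NumPy 1.x behaviour of `np.issctype`: only type objects and `np.dtype` instances qualify, and anything resolving to `np.object_` is rejected. It uses the `issubclass(rep, np.generic)` test suggested by the error message, plus `np.dtype(...)` for Python types such as `float`.

```python
import numpy as np


def issctype_compat(rep) -> bool:
    """Hypothetical stand-in for np.issctype, which NumPy 2.0 removed.

    Approximates the NumPy 1.x behaviour: only type objects and np.dtype
    instances qualify, and anything that resolves to np.object_ does not.
    """
    # Plain values such as 1.0 or the string "float32" are neither types nor dtypes.
    if not isinstance(rep, (type, np.dtype)):
        return False
    # NumPy scalar classes (including abstract ones such as np.floating) are
    # checked directly, using the test the NumPy 2.0 error message suggests.
    if isinstance(rep, type) and issubclass(rep, np.generic):
        return rep is not np.object_
    # Other Python types (e.g. float) and dtype instances are resolved via np.dtype.
    try:
        sctype = np.dtype(rep).type
    except Exception:
        return False
    return sctype is not np.object_
```

For example, `issctype_compat(np.float32)` and `issctype_compat(np.dtype("float32"))` are both `True`, while `issctype_compat(object)` and `issctype_compat(1.0)` are `False`.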