Skip to content

Commit

Permalink
Future-proof reference to deprecated pytype_aval_mappings
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 730610446
  • Loading branch information
vanderplas authored and tensorflower-gardener committed Feb 24, 2025
1 parent 3ea485c commit 135080b
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions tensorflow_probability/python/internal/backend/numpy/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,8 +676,10 @@ def assign_sub(self, value, **_):
if JAX_MODE:
jax.interpreters.xla.canonicalize_dtype_handlers[NumpyVariable] = (
jax.interpreters.xla.canonicalize_dtype_handlers[onp.ndarray])
jax.interpreters.xla.pytype_aval_mappings[NumpyVariable] = (
jax.interpreters.xla.pytype_aval_mappings[onp.ndarray])
if hasattr(jax.interpreters.xla, 'pytype_aval_mappings'):
# Deprecated in JAX v0.5.0
jax.interpreters.xla.pytype_aval_mappings[NumpyVariable] = (
jax.interpreters.xla.pytype_aval_mappings[onp.ndarray])
jax.core.pytype_aval_mappings[NumpyVariable] = (
jax.core.pytype_aval_mappings[onp.ndarray])

Expand Down

0 comments on commit 135080b

Please sign in to comment.