diff --git a/distrax/_src/utils/math.py b/distrax/_src/utils/math.py
index 9a4d373..86d88fb 100644
--- a/distrax/_src/utils/math.py
+++ b/distrax/_src/utils/math.py
@@ -14,10 +14,13 @@
 # ==============================================================================
 """Utility math functions."""
 
+import functools
 from typing import Optional, Tuple
 
 import chex
 import jax
+from jax import core as jax_core
+from jax.custom_derivatives import SymbolicZero
 import jax.numpy as jnp
 
 Array = chex.Array
@@ -44,7 +47,24 @@ def multiply_no_nan(x: Array, y: Array) -> Array:
   return jnp.where(y == 0, jnp.zeros((), dtype=dtype), x * y)
 
 
-@multiply_no_nan.defjvp
+# TODO(dougalm): move helpers like these into JAX AD utils
+def add_maybe_symbolic(x, y):
+  if isinstance(x, SymbolicZero):
+    return y
+  elif isinstance(y, SymbolicZero):
+    return x
+  else:
+    return x + y
+
+
+def scale_maybe_symbolic(result_aval, tangent, scale):
+  if isinstance(tangent, SymbolicZero):
+    return SymbolicZero(result_aval)
+  else:
+    return tangent * scale
+
+
+@functools.partial(multiply_no_nan.defjvp, symbolic_zeros=True)
 def multiply_no_nan_jvp(
     primals: Tuple[Array, Array],
     tangents: Tuple[Array, Array]) -> Tuple[Array, Array]:
@@ -52,8 +72,10 @@ def multiply_no_nan_jvp(
   x, y = primals
   x_dot, y_dot = tangents
   primal_out = multiply_no_nan(x, y)
-  tangent_out = y * x_dot + x * y_dot
-  return primal_out, tangent_out
+  result_aval = jax_core.get_type(primal_out).to_tangent_aval()
+  tangent_out_1 = scale_maybe_symbolic(result_aval, x_dot, y)
+  tangent_out_2 = scale_maybe_symbolic(result_aval, y_dot, x)
+  return primal_out, add_maybe_symbolic(tangent_out_1, tangent_out_2)
 
 
 @jax.custom_jvp
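
Note (not part of the patch): below is a minimal, self-contained sketch of
the symbolic-zeros JVP pattern this change adopts. With
`defjvp(..., symbolic_zeros=True)`, the tangent for an unperturbed input
reaches the rule as a `SymbolicZero` object rather than a dense array of
zeros, so the rule can skip that term instead of multiplying it through
(which is how `0 * inf -> nan` leaks into gradients). The function names
`safe_xlogy` and `safe_xlogy_jvp` are hypothetical illustrations, not part
of distrax or JAX.

    import functools

    import jax
    import jax.numpy as jnp
    from jax.custom_derivatives import SymbolicZero


    @jax.custom_jvp
    def safe_xlogy(x, y):
      # Like x * log(y), but defined as 0 where x == 0, even if log(y) is inf.
      return jnp.where(x == 0, 0.0, x * jnp.log(y))


    @functools.partial(safe_xlogy.defjvp, symbolic_zeros=True)
    def safe_xlogy_jvp(primals, tangents):
      x, y = primals
      x_dot, y_dot = tangents
      primal_out = safe_xlogy(x, y)
      # Accumulate only the terms whose inputs are actually perturbed;
      # SymbolicZero tangents are skipped rather than multiplied through.
      tangent_out = jnp.zeros_like(primal_out)
      if not isinstance(x_dot, SymbolicZero):
        tangent_out = tangent_out + jnp.log(y) * x_dot
      if not isinstance(y_dot, SymbolicZero):
        tangent_out = tangent_out + (x / y) * y_dot
      return primal_out, tangent_out

For example, `jax.grad(safe_xlogy, argnums=1)(2.0, 3.0)` perturbs only `y`,
so `x_dot` should arrive as a `SymbolicZero` and the `log(y)` term is never
formed. The patch's `add_maybe_symbolic`/`scale_maybe_symbolic` helpers
implement the same case analysis generically, including propagating a
`SymbolicZero` output when both incoming tangents are symbolic.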