Skip to content

Commit

Permalink
Clean up and fix primal type to tangent type mapping
Browse files Browse the repository at this point in the history
This is part of the ["stackless"](#23299) change. I'm splitting it out into a separate PR because we need it for some work on sharding types.

Changes:
  1. Rename `at_least_vspace` to `to_tangent_type` since that's what we always meant by it. `at_least_vspace` was always a bad name (sorry!) but it makes even less sense when you can have a special tangent type for a primal types that's already a vector space itself.
  2. Replace `Zero.from_value` with `Zero.from_primal_value`, which does the required primal-type-to-tangent-type conversion.
  3. Add `to_tangent_type` calls in various other places they're missing.
  4. Remove non-support for float0 in custom deriviatives?
  5. [Optional, WIP] Reinstate some checks that had been skipped over, presumably because of these bugs. (We'll see how far I get with it. Might end up being a separate PR.)
PiperOrigin-RevId: 675606346
  • Loading branch information
dougalm authored and DistraxDev committed Sep 18, 2024
1 parent 0e44982 commit ea61ce2
Showing 1 changed file with 25 additions and 3 deletions.
28 changes: 25 additions & 3 deletions distrax/_src/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@
# ==============================================================================
"""Utility math functions."""

import functools
from typing import Optional, Tuple

import chex
import jax
from jax import core as jax_core
from jax.custom_derivatives import SymbolicZero
import jax.numpy as jnp

Array = chex.Array
Expand All @@ -44,16 +47,35 @@ def multiply_no_nan(x: Array, y: Array) -> Array:
return jnp.where(y == 0, jnp.zeros((), dtype=dtype), x * y)


@multiply_no_nan.defjvp
# TODO(dougalm): move helpers like these into JAX AD utils
def add_maybe_symbolic(x, y):
if isinstance(x, SymbolicZero):
return y
elif isinstance(y, SymbolicZero):
return x
else:
return x + y


def scale_maybe_symbolic(result_aval, tangent, scale):
if isinstance(tangent, SymbolicZero):
return SymbolicZero(result_aval)
else:
return tangent * scale


@functools.partial(multiply_no_nan.defjvp, symbolic_zeros=True)
def multiply_no_nan_jvp(
primals: Tuple[Array, Array],
tangents: Tuple[Array, Array]) -> Tuple[Array, Array]:
"""Custom gradient computation for `multiply_no_nan`."""
x, y = primals
x_dot, y_dot = tangents
primal_out = multiply_no_nan(x, y)
tangent_out = y * x_dot + x * y_dot
return primal_out, tangent_out
result_aval = jax_core.get_type(primal_out).to_tangent_aval()
tangent_out_1 = scale_maybe_symbolic(result_aval, x_dot, y)
tangent_out_2 = scale_maybe_symbolic(result_aval, y_dot, x)
return primal_out, add_maybe_symbolic(tangent_out_1, tangent_out_2)


@jax.custom_jvp
Expand Down

0 comments on commit ea61ce2

Please sign in to comment.