
unexpected vmap error due to commit c36e1f7 #25289

Open
marcocuturi opened this issue Dec 5, 2024 · 10 comments
Labels
bug Something isn't working

@marcocuturi

Description

Hi,
@michalk8 and I noticed a bug in our tests here that occurs with the latest JAX version. After running git bisect, we found the bad commit to be c36e1f7.
Here's the traceback with JAX_TRACEBACK_FILTERING=off:

self = <sinkhorn_diff_test.TestSinkhornHessian object at 0x10ff434c0>, rng = Array((), dtype=key<fry>) overlaying:
[0 0], lse_mode = False, tau_a = 0.93, tau_b = 0.91, arg = 1, lineax_ridge = 1e-05

    @pytest.mark.fast.with_args(
        "lse_mode,tau_a,tau_b,arg,lineax_ridge", (
            (True, 1.0, 1.0, 0, 0.0),
            (False, 1.0, 1.0, 0, 1e-8),
            (True, 1.0, 1.0, 1, 0.0),
            (True, 1.0, 0.91, 0, 1e-7),
            (True, 0.93, 0.91, 1, 0.0),
            (False, 0.93, 0.91, 1, 1e-5),
        ),
        only_fast=-1
    )
    def test_hessian_sinkhorn(
        self, rng: jax.Array, lse_mode: bool, tau_a: float, tau_b: float,
        arg: int, lineax_ridge: float
    ):
      """Test hessian w.r.t. weights and locations."""
      try:
        from ott.solvers.linear import lineax_implicit  # noqa: F401
        test_back = True
        ridge = lineax_ridge
      except ImportError:
        test_back = False
        ridge = 1e-5
    
      n, m = (12, 15)
      dim = 3
      rngs = jax.random.split(rng, 6)
      x = jax.random.uniform(rngs[0], (n, dim))
      y = jax.random.uniform(rngs[1], (m, dim))
      a = jax.random.uniform(rngs[2], (n,)) + 0.1
      b = jax.random.uniform(rngs[3], (m,)) + 0.1
      a = a / jnp.sum(a)
      b = b / jnp.sum(b)
      epsilon = 0.1
    
      # Add a ridge when using JAX solvers, smaller ridge for lineax solvers
      solver_kwargs = {
          "ridge_identity": ridge,
          "ridge_kernel": ridge if tau_a == tau_b == 1.0 else 0.0
      }
    
      imp_dif = implicit_lib.ImplicitDiff(solver_kwargs=solver_kwargs)
    
      def loss(a: jnp.ndarray, x: jnp.ndarray, implicit: bool = True):
        geom = pointcloud.PointCloud(x, y, epsilon=epsilon)
        prob = linear_problem.LinearProblem(geom, a, b, tau_a, tau_b)
        implicit_diff = imp_dif if implicit else None
        solver = sinkhorn.Sinkhorn(
            lse_mode=lse_mode,
            threshold=1e-4,
            use_danskin=False,
            implicit_diff=implicit_diff,
        )
        return solver(prob).reg_ot_cost
    
      delta_a = jax.random.uniform(rngs[4], (n,))
      delta_a = delta_a - jnp.mean(delta_a)
      delta_x = jax.random.uniform(rngs[5], (n, dim))
    
      hess_loss_imp = jax.jit(
          jax.hessian(lambda a, x: loss(a, x, True), argnums=arg)
      )
>     hess_imp = hess_loss_imp(a, x)

tests/solvers/linear/sinkhorn_diff_test.py:794: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../jax-latest/jax/_src/pjit.py:337: in cache_miss
    outs, out_flat, out_tree, args_flat, jaxpr, attrs_tracked = _python_pjit_helper(
../jax-latest/jax/_src/pjit.py:177: in _python_pjit_helper
    p, args_flat = _infer_params(fun, jit_info, args, kwargs)
../jax-latest/jax/_src/pjit.py:743: in _infer_params
    p, args_flat = _infer_params_impl(
../jax-latest/jax/_src/pjit.py:640: in _infer_params_impl
    jaxpr, consts, out_avals, attrs_tracked = _create_pjit_jaxpr(
../jax-latest/jax/_src/linear_util.py:345: in memoized_fun
    ans = call(fun, *args)
../jax-latest/jax/_src/pjit.py:1287: in _create_pjit_jaxpr
    jaxpr, global_out_avals, consts, attrs_tracked = pe.trace_to_jaxpr_dynamic(
../jax-latest/jax/_src/profiler.py:333: in wrapper
    return func(*args, **kwargs)
../jax-latest/jax/_src/interpreters/partial_eval.py:2160: in trace_to_jaxpr_dynamic
    ans = fun.call_wrapped(*in_tracers)
../jax-latest/jax/_src/linear_util.py:191: in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
../jax-latest/jax/_src/api.py:581: in jacfun
    y, jac = vmap(pushfwd, out_axes=(None, -1))(_std_basis(dyn_args))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

args = ((Traced<ShapedArray(float32[36,12,3])>with<DynamicJaxprTrace>,),), kwargs = {}, args_flat = [Traced<ShapedArray(float32[36,12,3])>with<DynamicJaxprTrace>], in_tree = PyTreeDef((((*,),), {}))
f = Wrapped function:

Core: functools.partial(<function _jvp at 0x108c63490>, Wrapped function:
0   : _argnums_partial   ...il.Unhashable object at 0x12806f520>,))
Core: <lambda>
, (Traced<ShapedArray(float32[12,3])>with<DynamicJaxprTrace>,))

flat_fun = Wrapped function:
0   : flatten_fun_for_vmap   (PyTreeDef((((*,),), {})),)
Core: functools.partial(<function _jvp at 0...il.Unhashable object at 0x12806f520>,))
Core: <lambda>
, (Traced<ShapedArray(float32[12,3])>with<DynamicJaxprTrace>,))

in_axes_flat = [0], axis_size_ = 36, axis_data = AxisData(name=<object object at 0x105014c10>, size=36, spmd_name=None), out_axes_flat = [None, -1]

    @wraps(fun, docstr=docstr)
    @api_boundary
    def vmap_f(*args, **kwargs):
      if isinstance(in_axes, tuple) and len(in_axes) != len(args):
        raise ValueError("vmap in_axes must be an int, None, or a tuple of entries corresponding "
                         "to the positional arguments passed to the function, "
                         f"but got {len(in_axes)=}, {len(args)=}")
      args_flat, in_tree  = tree_flatten((args, kwargs), is_leaf=batching.is_vmappable)
      f = lu.wrap_init(fun)
      flat_fun, out_tree = batching.flatten_fun_for_vmap(f, in_tree)
      in_axes_flat = flatten_axes("vmap in_axes", in_tree, (in_axes, 0), kws=True)
      axis_size_ = (axis_size if axis_size is not None else
                    _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, "vmap"))
      try:
        axis_data = batching.AxisData(axis_name, axis_size_, spmd_axis_name)
        out_flat = batching.batch(
            flat_fun, axis_data, in_axes_flat,
            lambda: flatten_axes("vmap out_axes", out_tree(), out_axes)
        ).call_wrapped(*args_flat)
      except batching.SpecMatchError as e:
        out_axes_flat = flatten_axes("vmap out_axes", out_tree(), out_axes)
        out_axes_full = tree_unflatten(out_tree(), out_axes_flat)
        pairs, _ = tree_flatten_with_path(out_axes_full, is_leaf=lambda x: x is None)
    
        path, _ = pairs[e.leaf_idx]
>       raise ValueError(f'at vmap out_axes{keystr(path)}, got axis spec {e.dst} '
                         f'but output was batched on axis {e.src}') from None
E       ValueError: at vmap out_axes[0], got axis spec None but output was batched on axis 0

../jax-latest/jax/_src/api.py:1003: ValueError

It might be that the bug is coming from transformations created by lineax: when using the CG solver from JAX instead, the test no longer raises the ValueError above (it still fails, but only because of precision).

Code to reproduce:

git clone https://github.com/ott-jax/ott/ && cd ott && pip install -e '.[test]'
pip install git+https://github.com/jax-ml/jax.git@c36e1f7c1ad4782060cbc8e8c596d85dfb83986f
pytest -k 'test_hessian_sinkhorn[False-0.93-0.91-1-1e-05]'

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.36.dev20241029+c36e1f7c1
jaxlib: 0.4.35
numpy:  2.0.2
python: 3.11.10 | packaged by conda-forge | (main, Oct 16 2024, 01:26:25) [Clang 17.0.6 ]
device info: cpu-1, 1 local devices
process_count: 1
platform: uname_result(system='Darwin', node='Michals-MacBook-Pro-3.local', release='24.1.0', version='Darwin Kernel Version 24.1.0: Thu Oct 10 21:03:15 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T6000', machine='arm64')
@marcocuturi marcocuturi added the bug Something isn't working label Dec 5, 2024
@mattjj
Collaborator

mattjj commented Dec 5, 2024

Thanks for the clear report!

@mattjj
Collaborator

mattjj commented Dec 5, 2024

@marcocuturi could you minimize this? There's a lot going on in this code that we aren't familiar with, and it's much harder for us to minimize unfamiliar code than for you to minimize your own code. Think of the time for us to debug this as exponential in the length of the repro you give us.

It might be that the bug is coming from transformations created by lineax,

Are any of lineax's tests failing?

@michalk8

michalk8 commented Dec 5, 2024

Are any of lineax's tests failing?

No, all of them are passing with the latest version of JAX/lineax.

@marcocuturi
Author

marcocuturi commented Dec 6, 2024

thanks a lot @mattjj for taking a look!

Here's a simpler example crafted by @michalk8 and @Algue-Rythme that demonstrates the problem, which does indeed arise from lineax (tagging @patrick-kidger) and leads to the same vmap error:

import jax
import jax.numpy as jnp
import lineax as lx


@jax.custom_vjp
def f(x):
  return x.sum()

def f_fwd(x):
  return x.sum(), (x,)

def f_bwd(res, g):
  x, = res
  A, b = jnp.eye(x.shape[0]), jnp.ones_like(x)
  # op = lx.FunctionLinearOperator(lambda x: x, input_structure=x)  # fails     
  op = lx.MatrixLinearOperator(A)  # fails
  r = lx.linear_solve(op, b).value
  return (r * x * g,)

f.defvjp(f_fwd, f_bwd)


rng = jax.random.key(0)
x = jax.random.normal(rng, (10,))
jax.hessian(f)(x)
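
As background on where the vmap comes from: jax.hessian is forward-over-reverse (jacfwd of jacrev), and jacfwd vmaps the pushforward with out_axes=(None, -1) (visible in the traceback above), so the custom_vjp backward pass that calls lineax ends up running under vmap. The ValueError itself is the generic vmap out_axes mismatch: out_axes=None promises an unbatched output, and vmap raises if that output turns out to carry a batch axis. Here is a toy snippet, unrelated to ott/lineax, that triggers the same class of error in isolation (the exact message may differ across JAX versions):

import jax
import jax.numpy as jnp

# out_axes=None asserts that the output carries no batch axis; since the
# output here does depend on the mapped input, vmap raises a ValueError.
try:
  jax.vmap(lambda x: 2 * x, out_axes=None)(jnp.arange(3.0))
except ValueError as e:
  print(e)  # e.g. "... got axis spec None but output was batched on axis 0"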

@patrick-kidger
Collaborator

Yeah, I've also been seeing widespread failures in Diffrax's test suite, due to what looks like a totally different vmap failure. I've spent most of today digging through this and haven't identified a root cause yet. It might take a while to update the JAX ecosystem to be compatible with this version of JAX.

patrick-kidger added a commit to patrick-kidger/equinox that referenced this issue Dec 7, 2024
See:

- jax-ml/jax#25289
- patrick-kidger/diffrax#532

The problem was that batching has now become a dynamic trace, and our batching rules were not set up to handle the case that every batch axis is `not_mapped`.
@patrick-kidger
Collaborator

Okay, I think I've identified the root cause: with the latest changes, the batch interpreter has become a dynamic trace, i.e. it calls batch rules when it previously wouldn't. This meant that a lot of arrays were now having their non-batch dimensions turned into batch dimensions!

With that problem identified it's been a relatively simple matter to update a couple of batching rules in Equinox to handle this new calling case appropriately.

@marcocuturi @michalk8 can you try patrick-kidger/equinox#907 on your full example / on your tests? If it passes then I'll do a new release of Equinox that is compatible with latest JAX.
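
To make the fix being described concrete, here is a minimal sketch (using a made-up toy primitive, not the actual Equinox code) of a batching rule that tolerates being called when the batch axis is not_mapped, which is the new situation a dynamic batching trace can create:

import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import batching

# Toy primitive, invented purely for illustration.
double_p = core.Primitive("double")
double_p.def_impl(lambda x: 2 * x)
double_p.def_abstract_eval(lambda x: core.ShapedArray(x.shape, x.dtype))

def double_batch_rule(batched_args, batch_dims):
  (x,), (bdim,) = batched_args, batch_dims
  if bdim is batching.not_mapped:
    # With batching as a dynamic trace, the rule may be invoked even when the
    # input carries no batch axis; return an unbatched result rather than
    # assuming a batch dimension exists.
    return double_p.bind(x), batching.not_mapped
  return double_p.bind(x), bdim

batching.primitive_batchers[double_p] = double_batch_rule

print(jax.vmap(lambda x: double_p.bind(x))(jnp.arange(3.0)))  # [0. 2. 4.]

Checking for batching.not_mapped is how a rule opts out of batching when no input is actually mapped.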

@mattjj
Collaborator

mattjj commented Dec 7, 2024

Nice find @patrick-kidger ! That’s right, actually everything is a dynamic tracer now. No more automatic rules-only-called-based-on-data-dependence, though rules themselves can choose to behave based on dependence. I believe it gives rules strictly more power/expressiveness.

@soraros

soraros commented Dec 7, 2024

Hey @mattjj, can we consolidate this kind of knowledge into an updated version of Autodidax?

@michalk8

michalk8 commented Dec 8, 2024

@marcocuturi @michalk8 can you try patrick-kidger/equinox#907 on your full example / on your tests? If it passes then I'll do a new release of Equinox that is compatible with latest JAX.

Works great, thanks!

@mattjj
Collaborator

mattjj commented Dec 8, 2024

@soraros yes, good idea, that’s our plan! Cc @dougalm
