TypeError when attempting grad of vmap #101

Closed
rhacking opened this issue Aug 7, 2024 · 2 comments

@rhacking

rhacking commented Aug 7, 2024

In the following code

import jax
import jax.numpy as jnp
import numpy as np
import lineax as lx

N = 8
M = 16
K = 128

np.random.seed(42)
A = np.random.randn(M, N)
B = np.random.randn(M, K)

@jax.jit
@jax.grad
def test1(A):
  op = lx.MatrixLinearOperator(A)
  return jax.vmap(lambda b: lx.linear_solve(op, b, lx.AutoLinearSolver(well_posed=False)).value)(B.T).mean()

@jax.jit
@jax.grad
def test2(A):
  op = lx.MatrixLinearOperator(A)
  result = jnp.empty((B.shape[1], A.shape[1]), dtype=jnp.float32)
  for i, b in enumerate(B.T):
    result = result.at[i, :].set(lx.linear_solve(op, b, lx.AutoLinearSolver(well_posed=False)).value)
  return result.mean()

@jax.jit
@jax.grad
def test3(A):
  op = lx.MatrixLinearOperator(A)
  result = jax.lax.scan(lambda _, b: (None, lx.linear_solve(op, b, lx.AutoLinearSolver(well_posed=False)).value), None, B.T)[1]
  return result.mean()

calling test2 and test3 works fine, but calling test1 results in

---------------------------------------------------------------------------
JaxStackTraceBeforeTransformation         Traceback (most recent call last)
File <frozen runpy>:198, in _run_module_as_main()

File <frozen runpy>:88, in _run_code()

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel_launcher.py:18](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel_launcher.py#line=17)
     16 from ipykernel import kernelapp as app
---> 18 app.launch_new_instance()

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/traitlets/config/application.py:1075](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/traitlets/config/application.py#line=1074), in launch_instance()
   1074 app.initialize(argv)
-> 1075 app.start()

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/kernelapp.py:739](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/kernelapp.py#line=738), in start()
    738 try:
--> 739     self.io_loop.start()
    740 except KeyboardInterrupt:

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/tornado/platform/asyncio.py:205](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/tornado/platform/asyncio.py#line=204), in start()
    204 def start(self) -> None:
--> 205     self.asyncio_loop.run_forever()

File [/usr/lib/python3.12/asyncio/base_events.py:641](http://localhost:8888/usr/lib/python3.12/asyncio/base_events.py#line=640), in run_forever()
    640 while True:
--> 641     self._run_once()
    642     if self._stopping:

File [/usr/lib/python3.12/asyncio/base_events.py:1987](http://localhost:8888/usr/lib/python3.12/asyncio/base_events.py#line=1986), in _run_once()
   1986     else:
-> 1987         handle._run()
   1988 handle = None

File [/usr/lib/python3.12/asyncio/events.py:88](http://localhost:8888/usr/lib/python3.12/asyncio/events.py#line=87), in _run()
     87 try:
---> 88     self._context.run(self._callback, *self._args)
     89 except (SystemExit, KeyboardInterrupt):

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/kernelbase.py:545](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/kernelbase.py#line=544), in dispatch_queue()
    544 try:
--> 545     await self.process_one()
    546 except Exception:

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/kernelbase.py:534](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/kernelbase.py#line=533), in process_one()
    533         return
--> 534 await dispatch(*args)

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/kernelbase.py:437](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/kernelbase.py#line=436), in dispatch_shell()
    436     if inspect.isawaitable(result):
--> 437         await result
    438 except Exception:

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/ipkernel.py:362](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/ipkernel.py#line=361), in execute_request()
    361 self._associate_new_top_level_threads_with(parent_header)
--> 362 await super().execute_request(stream, ident, parent)

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/kernelbase.py:778](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/kernelbase.py#line=777), in execute_request()
    777 if inspect.isawaitable(reply_content):
--> 778     reply_content = await reply_content
    780 # Flush output before sending the reply.

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/ipkernel.py:449](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/ipkernel.py#line=448), in do_execute()
    448 if accepts_params["cell_id"]:
--> 449     res = shell.run_cell(
    450         code,
    451         store_history=store_history,
    452         silent=silent,
    453         cell_id=cell_id,
    454     )
    455 else:

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/zmqshell.py:549](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/ipykernel/zmqshell.py#line=548), in run_cell()
    548 self._last_traceback = None
--> 549 return super().run_cell(*args, **kwargs)

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3075](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/IPython/core/interactiveshell.py#line=3074), in run_cell()
   3074 try:
-> 3075     result = self._run_cell(
   3076         raw_cell, store_history, silent, shell_futures, cell_id
   3077     )
   3078 finally:

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3130](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/IPython/core/interactiveshell.py#line=3129), in _run_cell()
   3129 try:
-> 3130     result = runner(coro)
   3131 except BaseException as e:

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/IPython/core/async_helpers.py:129](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/IPython/core/async_helpers.py#line=128), in _pseudo_sync_runner()
    128 try:
--> 129     coro.send(None)
    130 except StopIteration as exc:

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3334](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/IPython/core/interactiveshell.py#line=3333), in run_cell_async()
   3331 interactivity = "none" if silent else self.ast_node_interactivity
-> 3334 has_raised = await self.run_ast_nodes(code_ast.body, cell_name,
   3335        interactivity=interactivity, compiler=compiler, result=result)
   3337 self.last_execution_succeeded = not has_raised

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3517](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/IPython/core/interactiveshell.py#line=3516), in run_ast_nodes()
   3516     asy = compare(code)
-> 3517 if await self.run_code(code, result, async_=asy):
   3518     return True

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/IPython/core/interactiveshell.py:3577](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/IPython/core/interactiveshell.py#line=3576), in run_code()
   3576     else:
-> 3577         exec(code_obj, self.user_global_ns, self.user_ns)
   3578 finally:
   3579     # Reset our crash handler in place

Cell In[72], line 1
----> 1 test1(A)

Cell In[38], line 13, in test1()
     12 op = lx.MatrixLinearOperator(A)
---> 13 return jax.vmap(lambda b: lx.linear_solve(op, b, lx.AutoLinearSolver(well_posed=False)))(B.T).value.mean()

Cell In[38], line 13, in test1.<locals>.<lambda>()
     12 op = lx.MatrixLinearOperator(A)
---> 13 return jax.vmap(lambda b: lx.linear_solve(op, b, lx.AutoLinearSolver(well_posed=False)))(B.T).value.mean()

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/lineax/_solve.py:809](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/lineax/_solve.py#line=808), in linear_solve()
    806 solver = eqxi.nondifferentiable(
    807     solver, name="`lineax.linear_solve(..., solver=...)`"
    808 )
--> 809 solution, result, stats = eqxi.filter_primitive_bind(
    810     linear_solve_p, operator, state, vector, options, solver, throw
    811 )
    812 # TODO: prevent forward-mode autodiff through stats

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/equinox/internal/_primitive.py:264](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/equinox/internal/_primitive.py#line=263), in filter_primitive_bind()
    263 flatten = Flatten()
--> 264 flat_out = prim.bind(*dynamic, treedef=treedef, static=static, flatten=flatten)
    265 treedef_out, static_out = flatten.get()

JaxStackTraceBeforeTransformation: TypeError: dot_general requires contracting dimensions to have the same shape, got (8,) and (128,).

The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.

--------------------

The above exception was the direct cause of the following exception:

TypeError                                 Traceback (most recent call last)
Cell In[72], line 1
----> 1 test1(A)

    [... skipping hidden 38 frame]

File [~/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/jax/_src/lax/lax.py:2705](http://localhost:8888/home/roel/.local/share/virtualenvs/jax_ol-Vu_XrYFB/lib/python3.12/site-packages/jax/_src/lax/lax.py#line=2704), in _dot_general_shape_rule(lhs, rhs, dimension_numbers, precision, preferred_element_type)
   2702 if not core.definitely_equal_shape(lhs_contracting_shape, rhs_contracting_shape):
   2703   msg = ("dot_general requires contracting dimensions to have the same "
   2704          "shape, got {} and {}.")
-> 2705   raise TypeError(msg.format(lhs_contracting_shape, rhs_contracting_shape))
   2707 return _dot_general_shape_computation(lhs.shape, rhs.shape, dimension_numbers)

TypeError: dot_general requires contracting dimensions to have the same shape, got (8,) and (128,).

whereas it runs fine if I remove jax.grad from test1. This happens both with the latest release on PyPI and with the latest version on the main branch. Am I doing something wrong here, or should this work?
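
For anyone who wants a lineax-free baseline to compare against, here is a minimal sketch of the same grad-of-vmap pattern in plain JAX. It is not the lineax code path: the helper `lstsq_normal_eq` is a hypothetical stand-in that solves the least-squares problem through the normal equations, which assumes A has full column rank (expected for a random 16x8 matrix, but an assumption nonetheless). The sketch only shows that differentiating through a vmapped solve and through a scanned solve should agree.

```python
import jax
import jax.numpy as jnp
import numpy as np

M, N, K = 16, 8, 128

rng = np.random.default_rng(42)
A = jnp.asarray(rng.standard_normal((M, N)), dtype=jnp.float32)
B = jnp.asarray(rng.standard_normal((M, K)), dtype=jnp.float32)


def lstsq_normal_eq(A, b):
    # Hypothetical stand-in for the lineax solve: least squares via the
    # normal equations. Assumes A has full column rank.
    return jnp.linalg.solve(A.T @ A, A.T @ b)


@jax.jit
@jax.grad
def ref_vmap(A):
    # Same structure as test1: vmap over the columns of B, then mean.
    return jax.vmap(lambda b: lstsq_normal_eq(A, b))(B.T).mean()


@jax.jit
@jax.grad
def ref_scan(A):
    # Same structure as test3: scan over the columns of B, then mean.
    _, xs = jax.lax.scan(lambda carry, b: (carry, lstsq_normal_eq(A, b)), None, B.T)
    return xs.mean()


g_vmap = ref_vmap(A)
g_scan = ref_scan(A)
print(g_vmap.shape)                      # gradient has the shape of A
print(jnp.max(jnp.abs(g_vmap - g_scan)))  # difference between the two variants
```

If the two gradients agree here but test1 still fails, that points at the transformation rules of the solver primitive rather than at the structure of the user code.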

patrick-kidger added a commit to patrick-kidger/equinox that referenced this issue Aug 7, 2024
In particular this resulted in a trace error here: patrick-kidger/lineax#101
@patrick-kidger
Owner

Thanks for the detailed reproduction! That's really useful.

This should now be fixed in patrick-kidger/equinox#795! If you install Equinox from that branch, you should find that things are resolved. I'll merge that upstream and do a new release of Equinox shortly.
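
After switching installs, it can help to confirm which package versions the environment actually has. The following is a small sketch using only the standard library; the helper name `installed_versions` is made up for illustration, and the thread does not say what version string the fixed Equinox build reports, so this only shows what is installed, not whether the fix is included.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "equinox", "lineax")):
    """Return a dict mapping each package name to its installed version, or None."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Package is not installed in the current environment.
            found[name] = None
    return found


for name, ver in installed_versions().items():
    print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Run this in the same kernel or virtualenv that raised the error, since a notebook kernel may use a different environment than the shell where the install was done.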

@rhacking
Author

rhacking commented Aug 7, 2024

Thanks! That fixed it for me.

@rhacking rhacking closed this as completed Aug 7, 2024
patrick-kidger added a commit to patrick-kidger/equinox that referenced this issue Aug 12, 2024
In particular this resulted in a trace error here: patrick-kidger/lineax#101
Artur-Galstyan pushed a commit to Artur-Galstyan/equinox that referenced this issue Aug 12, 2024
In particular this resulted in a trace error here: patrick-kidger/lineax#101