jax-metal (and its backend) doesn't yet support reducers with custom computation functions, nor multi-operand reductions. Argmax/argmin are mapped to the corresponding backend ops as special cases.
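For reference, here is a minimal sketch of the kind of multi-operand reduce an argmax expands to: one lax.reduce over a (value, index) pair with a custom reducer. It is an illustration under assumptions, not JAX's actual argmax lowering. The helper names are hypothetical, ties resolve to the lower index, and NaN handling is omitted. It runs on the CPU backend, but it is exactly the pattern (custom computation plus multiple operands) that, per the comment above, jax-metal does not support.

```python
import jax.numpy as jnp
from jax import lax


def _argmax_reducer(a, b):
    # Hypothetical reducer: each side is a (value, index) pair of scalars.
    a_val, a_idx = a
    b_val, b_idx = b
    # Keep `a` if its value is larger, or if values tie and its index is smaller.
    pick_a = (a_val > b_val) | ((a_val == b_val) & (a_idx < b_idx))
    return (lax.select(pick_a, a_val, b_val), lax.select(pick_a, a_idx, b_idx))


def argmax_via_reduce(x, axis):
    # Index operand: position along `axis`, broadcast to x's shape.
    idx = lax.broadcasted_iota(jnp.int32, x.shape, axis)
    # Init values: -inf for the value operand, 0 for the index operand.
    init = (jnp.array(-jnp.inf, x.dtype), jnp.array(0, jnp.int32))
    # Variadic reduce: two operands, two init values, one custom computation.
    _, out_idx = lax.reduce((x, idx), init, _argmax_reducer, (axis,))
    return out_idx
```

Usage: `argmax_via_reduce(jnp.array([[1., 5., 3.], [7., 2., 7.]]), 1)` should agree with `jnp.argmax(..., axis=1)` on the CPU backend.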
I see. So I assume that for single-operand reduce there are also special cases for +, * and similar, which makes sense. I noticed that a reducer like x, y -> x + y + 1 ignores the + 1, and that's what made me suspect it :D
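A small sketch for checking whether a backend evaluates the reducer body or pattern-matches it to a plain sum: reduce with x + y + 1, using -1 as the init value (it is the identity of that reducer, so the result is well defined). A backend that honors the computation should give sum(x) + n - 1 for n elements; one that treats it as a plain add would return sum(x). The names `plus_one` and `reduce_plus_one` are hypothetical.

```python
import jax.numpy as jnp
from jax import lax


def plus_one(x, y):
    # Not a plain add: every combine step adds an extra 1.
    return x + y + 1


def reduce_plus_one(x):
    # -1 is the identity of plus_one: plus_one(v, -1) == v.
    init = jnp.array(-1, dtype=x.dtype)
    # Single-operand reduce over axis 0 with the custom computation.
    return lax.reduce(x, init, plus_one, (0,))


x = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)
# On CPU this evaluates to 6 + (3 - 1) = 8.0; a plain sum would be 6.0.
print(reduce_plus_one(x))
```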
For context, we are trying to integrate the Metal plugin into Nx (an Elixir project that uses XLA similarly to JAX). We implement argmax/argmin on top of reduce, but our IR does not match JAX's exactly. I may try to align the IR in the meantime.
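To see what IR to align with, one option is to lower `jnp.argmax` without executing it and print the StableHLO text that JAX emits, then diff it against the IR Nx generates. This is a sketch: `argmax_rows` is a hypothetical wrapper, and the input shape and dtype are arbitrary choices.

```python
import jax
import jax.numpy as jnp


def argmax_rows(x):
    # Hypothetical wrapper: argmax along the last axis of a 2-D input.
    return jnp.argmax(x, axis=1)


x = jnp.zeros((2, 3), dtype=jnp.float32)
# Lower (trace + convert to StableHLO) without compiling or running it.
stablehlo_text = jax.jit(argmax_rows).lower(x).as_text()
# The printed module should contain the variadic reduce over (value, index).
print(stablehlo_text)
```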
Description
HLO
This fails with:
Interestingly, jnp.argmax works, and it is lowered to a similar reduce over the operand and the index (just more elaborate).

System info (python version, jaxlib version, accelerator, etc.)
jax-metal 0.0.7
(I also tried with jax/jaxlib 0.4.28 and ENABLE_PJRT_COMPATIBILITY=1, but got the same result.)