I'm observing a consistent segmentation fault when running `jax.lax.scan` over the code below on a JAX Metal device (M1 Ultra Mac Studio, macOS Sequoia 15.0). I've attempted to reduce it to the minimum conditions needed to trigger the bug.
```python
import jax
import jax.numpy as jnp

jax.config.update('jax_platform_name', 'METAL')  # fine when setting to 'cpu'

n_layer = 5
state_width = 2  # fine when setting equal to to_add
to_add = 1
x_size = 5

out_state = jax.random.uniform(jax.random.key(0), (n_layer, state_width))
print(out_state)

def do_loop(x, i):
    s1 = out_state[i, :to_add]  # fine when replacing i with constant
    to_update = jnp.concat((s1, x[:-to_add]))
    # to_update = jnp.concat((s1, jnp.zeros(x_size - to_add)))  # still has same bug
    # to_update = s1 @ jnp.zeros((to_add, x_size))  # still has same bug
    return to_update, i

x = jnp.zeros(x_size)
for i in range(n_layer):  # fine with python for-loop
    x, _ = do_loop(x, i)
print("correct output", x)

x = jnp.zeros(x_size)
print(" scan output", jax.lax.scan(do_loop, x, jnp.arange(n_layer))[0])  # segmentation fault
```
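Working through `do_loop` by hand gives the value the scan should produce: each step prepends `out_state[i, 0]` and drops the last element of `x`, so the body acts as a shift register, and with `n_layer == x_size` the final `x` should be `out_state[::-1, 0]`. The sketch below is a cross-check added for illustration, not part of the original reproduction. It pins the CPU backend (via the `jax_platforms` option), where the report says results are correct, and asserts that the Python loop, the scan, and that closed form agree. It uses `jnp.concatenate` in place of `jnp.concat`, which behaves the same here.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumption: force the CPU backend so the check runs on any machine,
# including one with jax-metal installed (where the scan would crash).
jax.config.update("jax_platforms", "cpu")

n_layer = 5
state_width = 2
to_add = 1
x_size = 5

out_state = jax.random.uniform(jax.random.key(0), (n_layer, state_width))


def do_loop(x, i):
    # Prepend the first `to_add` entries of row i, drop the last `to_add` of x.
    s1 = out_state[i, :to_add]
    return jnp.concatenate((s1, x[:-to_add])), i


def run_python_loop():
    x = jnp.zeros(x_size)
    for i in range(n_layer):
        x, _ = do_loop(x, i)
    return x


def run_scan():
    x0 = jnp.zeros(x_size)
    final_x, _ = jax.lax.scan(do_loop, x0, jnp.arange(n_layer))
    return final_x


loop_out = run_python_loop()
scan_out = run_scan()

# With to_add == 1 and n_layer == x_size the body is a shift register,
# so the final state is column 0 of out_state in reverse row order.
expected = out_state[::-1, 0]

np.testing.assert_allclose(loop_out, expected, rtol=1e-6)
np.testing.assert_allclose(scan_out, expected, rtol=1e-6)
print("loop, scan and closed form agree:", scan_out)
```

Running the same comparison without the platform override on a jax-metal install is what exercises the crash.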
Some notes:
- This bug only happens on Metal; CPU is perfectly fine.
- This bug is only triggered when indexing with a scan input AND taking a partial slice; it does not occur when only one of these conditions is met.
- This bug is not triggered by Python for-loops (even when using `jit`).
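Since the crash needs both a traced index coming from the scan input and a partial slice, one restructuring that might be worth trying is to pass the rows of `out_state` as the scanned `xs`, so each step receives its row directly and only a static slice remains. This is an untested idea, not a confirmed workaround: nothing in this report shows whether it avoids the Metal segfault. The sketch compares it against a plain Python loop, which the notes say works on Metal; `reference` and `do_loop_rows` are hypothetical names.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Uses the default backend, so on a jax-metal install this runs on Metal.
n_layer = 5
state_width = 2
to_add = 1
x_size = 5

out_state = jax.random.uniform(jax.random.key(0), (n_layer, state_width))


def reference(x):
    # Python for-loop version, reported to work on Metal.
    for i in range(n_layer):
        x = jnp.concatenate((out_state[i, :to_add], x[:-to_add]))
    return x


def do_loop_rows(x, row):
    # Hypothetical alternative: scan hands over row i of out_state directly,
    # so no index derived from the scan input is needed.
    s1 = row[:to_add]
    return jnp.concatenate((s1, x[:-to_add])), None


x0 = jnp.zeros(x_size)
expected = reference(x0)
by_rows, _ = jax.lax.scan(do_loop_rows, x0, out_state)

np.testing.assert_allclose(by_rows, expected, rtol=1e-6)
print(by_rows)
```

The trade-off is that the scanned array must be materialised with a leading `n_layer` axis, which it already is here.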
System info (python version, jaxlib version, accelerator, etc.)
Additionally, the jax-metal version is 0.1.0.
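A sketch for gathering the details this section asks for (Python, jaxlib and jax-metal versions, and the visible devices), assuming only the standard library and `jax`: `collect_system_info` is a hypothetical helper, while `jax.print_environment_info()` is JAX's own summary of the installation.

```python
import importlib.metadata
import platform

import jax


def collect_system_info() -> dict:
    """Hypothetical helper: gather the fields requested by the issue template."""
    info = {
        "python": platform.python_version(),
        # mac_ver() returns an empty release string on non-macOS systems.
        "macos": platform.mac_ver()[0] or "n/a",
        "jax": jax.__version__,
        "devices": [str(d) for d in jax.devices()],
    }
    # jaxlib and jax-metal are separate distributions; either may be absent.
    for dist in ("jaxlib", "jax-metal"):
        try:
            info[dist] = importlib.metadata.version(dist)
        except importlib.metadata.PackageNotFoundError:
            info[dist] = "not installed"
    return info


if __name__ == "__main__":
    for key, value in collect_system_info().items():
        print(f"{key}: {value}")
    # JAX's built-in environment summary.
    jax.print_environment_info()
```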