jax-metal: segmentation fault inside jax.lax.while_loop
#21552
Comments
Pretty sure the issue is specific to the dynamic slice inside the while loop; I've already run into this in several places, and removing the dynamic slice from the code makes it no longer segfault.
The dynamic slice prevents the backend from encoding the whileOp. We are working on a fix.
Running into the same issue in jax-metal 0.1.0:
```python
import jax
import jax.numpy as jnp

def f(x):
    def scan_fn(h, w):
        h_bne = w * h
        return h_bne, None
    return jax.lax.scan(scan_fn, x, jnp.array([[0.0]]))

x = jnp.ones(1)
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))
```
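For context (my reading of the lowering, not an authoritative statement): `jax.lax.scan` is itself lowered to a StableHLO `while` loop that dynamic-slices one row of the scanned array per iteration, which is presumably why this scan repro hits the same dynamic-slice-inside-while path. The lowered text can be inspected without a Metal device:

```python
import jax
import jax.numpy as jnp

def f(x):
    def scan_fn(h, w):
        return w * h, None
    return jax.lax.scan(scan_fn, x, jnp.array([[0.0]]))

# Lowering alone does not touch the Metal backend, so this is safe
# to run anywhere; grep the text for the while/dynamic_slice ops.
hlo = jax.jit(f).lower(jnp.ones(1)).as_text()
print("while op present:", "stablehlo.while" in hlo)
print("dynamic_slice present:", "stablehlo.dynamic_slice" in hlo)
```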
My system info:
I have the same issue :(
I'm having the same issue as well.
The issue seems to have been fixed in jax-metal 0.1.1. I tested the provided repro on an M1 Pro Mac and it executed without any error or segmentation fault:

```
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> def f(x):
...     def cond(carry):
...         i, x, acc = carry
...         return i < x.shape[0]
...     def body(carry):
...         i, x, acc = carry
...         return (i + 1, x, acc + x[i])
...     i = jnp.array(0)
...     acc = jnp.array(0)
...     return jax.lax.while_loop(cond, body, (i, x, acc))
...
>>>
>>> x = jnp.array([1, 2, 3])
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1733393163.091288 2444566 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1733393163.106153 2444566 service.cc:145] XLA service 0x600002c95f00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1733393163.106166 2444566 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1733393163.108507 2444566 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1733393163.108524 2444566 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
>>>
>>> # Print lowered HLO
>>> print(jax.jit(f).lower(x).as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32>) -> (tensor<i32> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<i32> {jax.result_info = "[2]"}) {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %0:3 = stablehlo.while(%iterArg = %c, %iterArg_1 = %arg0, %iterArg_2 = %c_0) : tensor<i32>, tensor<3xi32>, tensor<i32>
     cond {
      %c_3 = stablehlo.constant dense<3> : tensor<i32>
      %1 = stablehlo.compare LT, %iterArg, %c_3, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %1 : tensor<i1>
    } do {
      %c_3 = stablehlo.constant dense<1> : tensor<i32>
      %1 = stablehlo.add %iterArg, %c_3 : tensor<i32>
      %c_4 = stablehlo.constant dense<0> : tensor<i32>
      %2 = stablehlo.compare LT, %iterArg, %c_4, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      %3 = stablehlo.convert %iterArg : tensor<i32>
      %c_5 = stablehlo.constant dense<3> : tensor<i32>
      %4 = stablehlo.add %3, %c_5 : tensor<i32>
      %5 = stablehlo.select %2, %4, %iterArg : tensor<i1>, tensor<i32>
      %6 = stablehlo.dynamic_slice %iterArg_1, %5, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
      %7 = stablehlo.reshape %6 : (tensor<1xi32>) -> tensor<i32>
      %8 = stablehlo.convert %iterArg_2 : tensor<i32>
      %9 = stablehlo.add %8, %7 : tensor<i32>
      stablehlo.return %1, %iterArg_1, %9 : tensor<i32>, tensor<3xi32>, tensor<i32>
    }
    return %0#0, %0#1, %0#2 : tensor<i32>, tensor<3xi32>, tensor<i32>
  }
}
>>> print(jax.jit(f)(x))
(Array(3, dtype=int32, weak_type=True), Array([1, 2, 3], dtype=int32), Array(6, dtype=int32))
```

Could you please verify with jax-metal 0.1.1 whether the issue still persists? Thank you.
Yes, this appears to be solved in 0.1.1. The example from @abrasumente233 also works. Thanks :)
Description
HLO
The loop above computes the sum of the tensor's elements. Running the code results in a segmentation fault.
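For reference, the repro (as quoted in the comments above) can be run on a non-Metal backend such as CPU, where the loop's result matches the expected sum:

```python
import jax
import jax.numpy as jnp

# The original repro from this issue: sum the tensor's elements with
# jax.lax.while_loop. On CPU this matches jnp.sum; on jax-metal 0.0.7
# it segfaulted.
def f(x):
    def cond(carry):
        i, x, acc = carry
        return i < x.shape[0]

    def body(carry):
        i, x, acc = carry
        return (i + 1, x, acc + x[i])

    return jax.lax.while_loop(cond, body, (jnp.array(0), x, jnp.array(0)))

x = jnp.array([1, 2, 3])
i, _, acc = jax.jit(f)(x)
print(acc)  # Array(6, dtype=int32)
```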
System info (python version, jaxlib version, accelerator, etc.)
jax-metal 0.0.7