jax-metal: segmentation fault inside jax.lax.while_loop #21552

Closed
jonatanklosko opened this issue May 31, 2024 · 8 comments
Labels: Apple GPU (Metal) plugin, bug

Comments

jonatanklosko commented May 31, 2024

Description

import jax
import jax.numpy as jnp

def f(x):
  def cond(carry):
    i, x, acc = carry
    return i < x.shape[0]

  def body(carry):
    i, x, acc = carry
    return (i + 1, x, acc + x[i])

  i = jnp.array(0)
  acc = jnp.array(0)

  return jax.lax.while_loop(cond, body, (i, x, acc))

x = jnp.array([1, 2, 3])

# Print lowered HLO
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))

Lowered HLO:

module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<i32> {jax.result_info = "[0]", mhlo.layout_mode = "default"}, tensor<3xi32> {jax.result_info = "[1]", mhlo.layout_mode = "default"}, tensor<i32> {jax.result_info = "[2]", mhlo.layout_mode = "default"}) {
    %0 = stablehlo.constant dense<0> : tensor<i32>
    %1 = stablehlo.constant dense<1> : tensor<i32>
    %2:3 = stablehlo.while(%iterArg = %0, %iterArg_0 = %arg0, %iterArg_1 = %1) : tensor<i32>, tensor<3xi32>, tensor<i32>
     cond {
      %3 = stablehlo.constant dense<3> : tensor<i32>
      %4 = stablehlo.compare  LT, %iterArg, %3,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %4 : tensor<i1>
    } do {
      %3 = stablehlo.constant dense<1> : tensor<i32>
      %4 = stablehlo.add %iterArg, %3 : tensor<i32>
      %5 = stablehlo.constant dense<0> : tensor<i32>
      %6 = stablehlo.compare  LT, %iterArg, %5,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      %7 = stablehlo.convert %iterArg : tensor<i32>
      %8 = stablehlo.constant dense<3> : tensor<i32>
      %9 = stablehlo.add %7, %8 : tensor<i32>
      %10 = stablehlo.select %6, %9, %iterArg : tensor<i1>, tensor<i32>
      %11 = stablehlo.dynamic_slice %iterArg_0, %10, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
      %12 = stablehlo.reshape %11 : (tensor<1xi32>) -> tensor<i32>
      %13 = stablehlo.convert %iterArg_1 : tensor<i32>
      %14 = stablehlo.add %13, %12 : tensor<i32>
      stablehlo.return %4, %iterArg_0, %14 : tensor<i32>, tensor<3xi32>, tensor<i32>
    }
    return %2#0, %2#1, %2#2 : tensor<i32>, tensor<3xi32>, tensor<i32>
  }
}

The loop above computes the sum of the tensor's elements. Running the code results in a segmentation fault.

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7

jonatanklosko added the bug label May 31, 2024
@jonatanklosko
Author

I'm pretty sure the issue is specific to the dynamic slice inside the while loop. I have already run into this in several places, and removing the dynamic slice from the code stops the segfault.
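
For anyone blocked on affected versions, one possible workaround is to avoid indexing with the loop counter at all. The sketch below is a hypothetical variant of the repro (the name `f_masked` is made up for illustration) that replaces `x[i]` with a masked sum, which should lower to compare/select/reduce ops instead of `dynamic_slice`. It is untested on the Metal backend and does O(n) work per iteration, so treat it as a stopgap under those assumptions rather than a confirmed fix.

```python
import jax
import jax.numpy as jnp

def f_masked(x):
  # Hypothetical variant of f from the report: same loop, but x[i] is
  # replaced by a masked sum so the body never indexes with a traced value.
  n = x.shape[0]
  idx = jnp.arange(n)

  def cond(carry):
    i, x, acc = carry
    return i < n

  def body(carry):
    i, x, acc = carry
    # Keep only the element at position i, zero out the rest, then sum.
    elem = jnp.sum(jnp.where(idx == i, x, 0))
    return (i + 1, x, acc + elem)

  i = jnp.array(0)
  acc = jnp.zeros((), dtype=x.dtype)
  return jax.lax.while_loop(cond, body, (i, x, acc))

x = jnp.array([1, 2, 3])

# Inspect the lowered module for the op suspected of causing the crash.
lowered_text = jax.jit(f_masked).lower(x).as_text()
print("dynamic_slice" in lowered_text)
print(jax.jit(f_masked)(x))
```

The carry keeps the same `(i, x, acc)` structure as the original, so the result can be compared directly against the original `f` on a backend where it runs.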

@shuhand0
Collaborator

shuhand0 commented Jun 3, 2024

The dynamic slice prevents the backend from encoding the whileOp. We are working on a fix.

@acranej

acranej commented Jun 11, 2024

I'm running into the same issue in jax-metal 0.1.0.

@abrasumente233

import jax
import jax.numpy as jnp

def f(x):
    def scan_fn(h, w):
        h_bne = w * h
        return h_bne, None

    return jax.lax.scan(scan_fn, x, jnp.array([[0.0]]))

x = jnp.ones(1)
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))

jax.lax.scan hits a segfault as well, and its lowered HLO also contains a dynamic_slice.
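
Since lowering alone does not execute the program, the lowered text can be searched before running anything that might crash. A minimal sketch of that check follows; `lowers_to_dynamic_slice`, `with_slice`, and `without_slice` are hypothetical names introduced here, and the string match is a rough heuristic, not a guarantee about what the backend will do.

```python
import jax
import jax.numpy as jnp

def lowers_to_dynamic_slice(fn, *args):
  # Hypothetical helper: lower fn for the given example arguments and
  # search the textual IR for the op suspected of triggering the crash.
  return "dynamic_slice" in jax.jit(fn).lower(*args).as_text()

def with_slice(x, i):
  # Explicit dynamic slice with a traced start index.
  return jax.lax.dynamic_slice(x, (i,), (1,))

def without_slice(x, i):
  # Plain elementwise op, no slicing involved.
  return x + i

x = jnp.array([1, 2, 3])
i = jnp.array(0)
print(lowers_to_dynamic_slice(with_slice, x, i))
print(lowers_to_dynamic_slice(without_slice, x, i))
```

The same helper could be pointed at the `while_loop` and `scan` repros in this thread to confirm which of them contain the op on a given JAX version.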

My system info:

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.11.9 | packaged by conda-forge | (main, Apr 19 2024, 18:34:54) [Clang 16.0.6 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='[redacted]', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:16:51 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T8103', machine='arm64')

@aniquetahir

I have the same issue :(

@vyeevani

I'm having the same issue as well

@rajasekharporeddy
Contributor

rajasekharporeddy commented Dec 5, 2024

Hi @jonatanklosko

The issue seems to have been fixed in jax-metal 0.1.1. I tested the provided repro on an M1 Pro Mac, and it ran without any error or segmentation fault.

>>> import jax
>>> import jax.numpy as jnp
>>> 
>>> 
>>> def f(x):
...   def cond(carry):
...     i, x, acc = carry
...     return i < x.shape[0]
...   def body(carry):
...     i, x, acc = carry
...     return (i + 1, x, acc + x[i])
...   i = jnp.array(0)
...   acc = jnp.array(0)
...   return jax.lax.while_loop(cond, body, (i, x, acc))
... 
>>> 
>>> x = jnp.array([1, 2, 3])
Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1733393163.091288 2444566 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M1 Pro

systemMemory: 16.00 GB
maxCacheSize: 5.33 GB

I0000 00:00:1733393163.106153 2444566 service.cc:145] XLA service 0x600002c95f00 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1733393163.106166 2444566 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1733393163.108507 2444566 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1733393163.108524 2444566 mps_client.cc:384] XLA backend will use up to 11452858368 bytes on device 0 for SimpleAllocator.
>>> 
>>> # Print lowered HLO
>>> print(jax.jit(f).lower(x).as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32>) -> (tensor<i32> {jax.result_info = "[0]"}, tensor<3xi32> {jax.result_info = "[1]"}, tensor<i32> {jax.result_info = "[2]"}) {
    %c = stablehlo.constant dense<0> : tensor<i32>
    %c_0 = stablehlo.constant dense<0> : tensor<i32>
    %0:3 = stablehlo.while(%iterArg = %c, %iterArg_1 = %arg0, %iterArg_2 = %c_0) : tensor<i32>, tensor<3xi32>, tensor<i32>
     cond {
      %c_3 = stablehlo.constant dense<3> : tensor<i32>
      %1 = stablehlo.compare  LT, %iterArg, %c_3,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      stablehlo.return %1 : tensor<i1>
    } do {
      %c_3 = stablehlo.constant dense<1> : tensor<i32>
      %1 = stablehlo.add %iterArg, %c_3 : tensor<i32>
      %c_4 = stablehlo.constant dense<0> : tensor<i32>
      %2 = stablehlo.compare  LT, %iterArg, %c_4,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      %3 = stablehlo.convert %iterArg : tensor<i32>
      %c_5 = stablehlo.constant dense<3> : tensor<i32>
      %4 = stablehlo.add %3, %c_5 : tensor<i32>
      %5 = stablehlo.select %2, %4, %iterArg : tensor<i1>, tensor<i32>
      %6 = stablehlo.dynamic_slice %iterArg_1, %5, sizes = [1] : (tensor<3xi32>, tensor<i32>) -> tensor<1xi32>
      %7 = stablehlo.reshape %6 : (tensor<1xi32>) -> tensor<i32>
      %8 = stablehlo.convert %iterArg_2 : tensor<i32>
      %9 = stablehlo.add %8, %7 : tensor<i32>
      stablehlo.return %1, %iterArg_1, %9 : tensor<i32>, tensor<3xi32>, tensor<i32>
    }
    return %0#0, %0#1, %0#2 : tensor<i32>, tensor<3xi32>, tensor<i32>
  }
}

>>> print(jax.jit(f)(x))
(Array(3, dtype=int32, weak_type=True), Array([1, 2, 3], dtype=int32), Array(6, dtype=int32))
>>> jax.print_environment_info()
jax:    0.4.35
jaxlib: 0.4.35
numpy:  2.1.3
python: 3.11.6 (v3.11.6:8b6ee5ba3b, Oct  2 2023, 11:18:21) [Clang 13.0.0 (clang-1300.0.29.30)]
device info: Metal-1, 1 local devices"
process_count: 1
platform: uname_result(system='Darwin', node='rajasekharp-macbookpro.roam.internal', release='24.1.0', version='Darwin Kernel Version 24.1.0: Thu Oct 10 21:03:15 PDT 2024; root:xnu-11215.41.3~2/RELEASE_ARM64_T6000', machine='arm64')

Could you please verify whether the issue still persists with jax-metal 0.1.1?

Thank you.

@jonatanklosko
Author

jonatanklosko commented Dec 5, 2024

Yes, this appears to be solved in 0.1.1. The example from @abrasumente233 also works. Thanks :)
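
For readers checking their own setup, a small sketch of how to tell whether the installed jax-metal is at or past the version reported as working in this thread. It uses only the standard library; `parse_version` and `jax_metal_has_fix` are hypothetical helpers, and the parser assumes a plain `X.Y.Z` version string.

```python
from importlib.metadata import PackageNotFoundError, version

FIXED_IN = (0, 1, 1)  # per the comments above, 0.1.1 no longer segfaults

def parse_version(text):
  # Naive parser for plain "X.Y.Z" strings; pre-release suffixes are not handled.
  return tuple(int(part) for part in text.split("."))

def jax_metal_has_fix():
  # Returns None when jax-metal is not installed in this environment.
  try:
    installed = version("jax-metal")
  except PackageNotFoundError:
    return None
  return parse_version(installed) >= FIXED_IN

print(jax_metal_has_fix())
```

The versions mentioned earlier in the thread (0.0.7 and 0.1.0) both compare below `FIXED_IN` with this check.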
