
StridedMemoryView fails with Jax arrays #285

Closed
leofang opened this issue Dec 11, 2024 · 15 comments · Fixed by #292

Labels: bug (Something isn't working) · cuda.core (Everything related to the cuda.core module) · P0 (High priority - Must do!)

Comments

leofang (Member) commented Dec 11, 2024

@yangcal was trying StridedMemoryView, but it doesn't seem to work for JAX arrays, at least with CUDA 12.2. Below are the script and the error log:

import cupy as cp
import numpy as np
import jax.numpy as jnp

from cuda.core.experimental.utils import args_viewable_as_strided_memory

@args_viewable_as_strided_memory((0,))
def parse_tensor(arr):
    view = arr.view(-1)  # -1: consumer claims no stream ordering is needed
    print(type(arr), type(view))
    print(f"shape={view.shape}")
    print(f"strides={view.strides}")

for module in (np, cp, jnp):
    arr = module.eye(2)
    print(f"module={module.__name__}")
    parse_tensor(arr)

Error Log:

E1211 13:52:21.674104 2256128 ptx_compiler_helpers.cc:71] *** WARNING *** Invoking ptxas with version 12.2.140, which corresponds to a CUDA version <=12.6.2. CUDA versions 12.x.y up to and including 12.6.2 miscompile certain edge cases around clamping.
Please upgrade to CUDA 12.6.3 or newer.
module=jax.numpy
Traceback (most recent call last):
  File "/home/scratch.yangg_sw/software/cuda-python/cuda_core/examples/tmp.py", line 18, in <module>
    parse_tensor(arr)
  File "cuda/core/experimental/_memoryview.pyx", line 372, in cuda.core.experimental._memoryview.args_viewable_as_strided_memory.wrapped_func_with_indices.wrapped_func
  File "/home/scratch.yangg_sw/software/cuda-python/cuda_core/examples/tmp.py", line 10, in parse_tensor
    view = arr.view(-1)
           ^^^^^^^^^^^^
  File "cuda/core/experimental/_memoryview.pyx", line 146, in cuda.core.experimental._memoryview._StridedMemoryViewProxy.view
  File "cuda/core/experimental/_memoryview.pyx", line 148, in cuda.core.experimental._memoryview._StridedMemoryViewProxy.view
  File "cuda/core/experimental/_memoryview.pyx", line 180, in cuda.core.experimental._memoryview.view_as_dlpack
  File "/home/Self/marie/miniconda3/envs/jax/lib/python3.12/site-packages/jax/_src/array.py", line 446, in __dlpack__
    return to_dlpack(self, stream=stream,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/Self/marie/miniconda3/envs/jax/lib/python3.12/site-packages/jax/_src/dlpack.py", line 134, in to_dlpack
    return _to_dlpack(
           ^^^^^^^^^^^
  File "/home/Self/marie/miniconda3/envs/jax/lib/python3.12/site-packages/jax/_src/dlpack.py", line 65, in _to_dlpack
    return xla_client._xla.buffer_to_dlpack_managed_tensor(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: CUDA error: : CUDA_ERROR_INVALID_HANDLE: invalid resource handle

Is this a semantics error on my side, or a CUDA version issue?

@leofang leofang added this to the cuda.core beta 2 milestone Dec 11, 2024
@leofang leofang added triage Needs the team's attention P0 High priority - Must do! cuda.core Everything related to the cuda.core module labels Dec 11, 2024
yangcal commented Dec 12, 2024

On a side note, while args_viewable_as_strided_memory is a very helpful decorator, it would be nice if there were APIs that let users directly construct StridedMemoryView instances from ndarray objects of various packages. This offers more freedom, and users could write their own decorators.

leofang (Member, Author) commented Dec 12, 2024

> I think it would be nice if there are APIs to allow users to directly construct StridedMemoryView instances from ndarray objects of various packages.

Check out the StridedMemoryView docstring; it is supported! But we want to encourage the decorator use case, because it allows scoped access (as defined by the decorated function) rather than having the view dangle forever.
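
For reference, a minimal sketch of the direct-construction path (check the docstring for the exact signature; here arr is any object supporting DLPack or the CUDA Array Interface, and s is a cuda.core Stream):

from cuda.core.experimental.utils import StridedMemoryView

# Construct the view directly, ordered after stream s; the decorator
# effectively does this for each flagged argument.
view = StridedMemoryView(arr, stream_ptr=s.handle)
print(view.shape, view.strides, view.ptr)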

ksimpson-work (Contributor) commented Dec 12, 2024

This appears to be an issue specific to Jax's implementation of ArrayImpl(basearray.Array).__dlpack__. I am going to keep digging to get a more specific diagnosis, but I am wondering at what point I should call it a Jax bug and move on @leofang

leofang (Member, Author) commented Dec 12, 2024

Could you share what you found here? Then we can determine whether we need to file a bug with Jax, work around it ourselves, or whether it's simply a user error.

ksimpson-work (Contributor) commented Dec 12, 2024

For context, I have repro'd this locally using CUDA 12.6 and jax.numpy for CUDA 12.

Using the same logic with a CuPy array works.
Jax is failing here: https://github.com/jax-ml/jax/blob/99d675ac25f583c0c2e61631355e0d11ba3abf18/jax/_src/dlpack.py#L65
It would appear that the CUDA_ERROR_INVALID_HANDLE is raised because -1 is being passed all the way through. If I create a stream and pass that as the handle, the call hangs and then incurs a segmentation fault. I ran this test quickly just to see the behaviour of a stream value != -1. I may be responsible for the segfault; I need to look into that as well.

Next I am planning to find the implementation of xla_client._xla.buffer_to_dlpack_managed_tensor to see what is going on, and possibly build CuPy locally so I can investigate how CuPy handles __dlpack__() and the stream argument.

leofang (Member, Author) commented Dec 12, 2024

@yangcal I think passing -1 in view = arr.view(-1) could be problematic for Jax arrays. In a normal (stream-ordered) CUDA program you should know which stream to use. Could you modify your toy code and see if it works?

Hi @jakevdp, we hit a possibly non-compliant DLPack implementation in either Jax or XLA: stream=-1 is a valid input as described in the __dlpack__ docs. Could you advise on which repo is the best place to raise this issue? Maybe https://github.com/openxla/xla/?

leofang (Member, Author) commented Dec 12, 2024

> and also possibly build CuPy locally so I can investigate how CuPy handles __dlpack__() and the stream argument

FWIW, -1 is handled around here in CuPy (the logic is a bit convoluted).
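
For completeness, what the spec expects from a producer looks roughly like this (a sketch with hypothetical helper names, not CuPy's or Jax's actual code):

# Hypothetical producer-side handling of the DLPack stream argument
# for data on a CUDA device:
def __dlpack__(self, stream=None, max_version=None):
    if stream == -1:
        # The consumer explicitly requests no synchronization; the
        # value must not be interpreted as a stream handle.
        pass
    elif stream is not None:
        # stream is the consumer's stream (1 = legacy default stream,
        # 2 = per-thread default stream, >2 = a real stream handle);
        # order it after the producer's stream, e.g. via an event.
        wait_on_producer_stream(stream)  # hypothetical helper
    return make_dlpack_capsule(self)  # hypothetical helper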

leofang (Member, Author) commented Dec 12, 2024

There are 2 bugs in Jax/XLA that cause the DLPack exchange to not work:

  1. Jax/XLA does not support DLPack 1.0, but this path exists and returns a (legacy, pre-1.0) capsule.
  2. In XLA, stream is passed as-is, without checking whether it's -1.

So, in short, even if Yang tested this

> I think passing -1 in view = arr.view(-1) could be problematic for Jax arrays. In a normal (stream-ordered) CUDA program you should know which stream to use. Could you modify your toy code and see if it works?

it would not work, because we'd then hit an AssertionError (I tried).

ksimpson-work (Contributor) commented

+1. All of that is consistent with what I have found.

leofang (Member, Author) commented Dec 12, 2024

@yangcal, could you check whether this patch works on your side? Let's work around the Jax bugs, assuming they can be fixed in the next version (cc @jakevdp for visibility):

diff --git a/cuda_core/cuda/core/experimental/_memoryview.pyx b/cuda_core/cuda/core/experimental/_memoryview.pyx
index d8eba46..1d4d977 100644
--- a/cuda_core/cuda/core/experimental/_memoryview.pyx
+++ b/cuda_core/cuda/core/experimental/_memoryview.pyx
@@ -7,6 +7,7 @@ cimport cython
 from ._dlpack cimport *
 
 import functools
+import importlib.metadata
 from typing import Any, Optional
 
 from cuda import cuda
@@ -181,6 +182,13 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
             stream=stream_ptr,
             max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
         versioned = True
+        try:
+            if "jax.numpy" in str(obj.__array_namespace__()):
+                ver = tuple(int(i) for i in importlib.metadata.version("jax").split("."))
+                if ver <= (0, 4, 38):
+                    versioned = False
+        except AttributeError:
+            pass
     except TypeError:
         capsule = obj.__dlpack__(
             stream=stream_ptr)

You'd also need to avoid passing -1 as the stream, like this:

import cupy as cp
import numpy as np
import jax.numpy as jnp

from cuda.core.experimental import Device
from cuda.core.experimental.utils import args_viewable_as_strided_memory

@args_viewable_as_strided_memory((0,))
def parse_tensor(arr, s, mod):
    view = arr.view(s.handle if mod is not np else -1)  # NumPy is CPU-only, so no stream is needed there
    print(type(arr), type(view))
    print(f"shape={view.shape}")
    print(f"strides={view.strides}")


dev = Device(0)
dev.set_current()
s = dev.create_stream()
for module in (np, cp, jnp):
    arr = module.eye(2)
    parse_tensor(arr, s, module)

leofang (Member, Author) commented Dec 12, 2024

It could also be that the bug is on our side... For example, the docs for the max_version argument say

> This means the consumer must verify the version even when max_version is passed.

and in this case we're the consumer (we create the view).
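
Concretely, the consumer can tell a versioned capsule from a legacy one by its name; something like this sketch using the CPython capsule API:

import ctypes

_get_name = ctypes.pythonapi.PyCapsule_GetName
_get_name.restype = ctypes.c_char_p
_get_name.argtypes = [ctypes.py_object]

def is_versioned_capsule(capsule):
    # DLPack >= 1.0 producers name the capsule "dltensor_versioned";
    # legacy producers use "dltensor" even if max_version was passed.
    return _get_name(capsule) == b"dltensor_versioned"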

leofang (Member, Author) commented Dec 12, 2024

It could also be that the bug is on our side...

See #292.

@leofang leofang added bug Something isn't working and removed triage Needs the team's attention labels Dec 12, 2024
ksimpson-work (Contributor) commented

I have verified locally that #292 solves the bug on our side. Jax still does not handle -1 correctly; consider passing a stream handle explicitly until they (hopefully) resolve it on their end. Thanks for bringing this to our attention @yangcal!

yangcal commented Dec 17, 2024

Sorry for the slow response. I can confirm that the fix works. A follow-up question: does the script above need to know which stream was used to populate the operand on the device, or would any stream on that device work? I tried a random integer as the handle input and hit a segfault, but when I generated a new stream object with Device.create_stream() the script worked. So does the script just need a dummy stream/handle?

leofang (Member, Author) commented Dec 17, 2024

Hi Yang, no, it's not any random stream; it's the stream that you will be using to access the content in the decorated function (assuming you're using the decorator). We order it properly after the stream on which the data is being generated/processed.
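
In other words, the intended pattern looks like this (a sketch reusing the names from the example above):

@args_viewable_as_strided_memory((0,))
def consume(arr, stream):
    # Pass the stream on which *you* will consume the data; cuda.core
    # orders it after the producer's stream, so kernels launched on
    # stream below see fully materialized data.
    view = arr.view(stream.handle)
    # ... launch kernels on stream that read from view.ptr ...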
