StridedMemoryView fails with Jax arrays #285
Comments
This appears to be an issue specific to Jax's implementation of `ArrayImpl(basearray.Array).__dlpack__`. I am going to keep digging to get a more specific diagnosis, but I am wondering at which point I should call it a Jax bug and move on (@leofang).
Could you share what you found here? Then we can determine whether we need to file a bug against Jax, whether we just need to work around it, or whether it's simply a user error.
For context, I have reproduced this locally using CUDA 12.6 and jax.numpy for cuda12. Using the same logic with a CuPy array works. Next I am planning to find the implementation of `xla_client._xla.buffer_to_dlpack_managed_tensor` to see what is going on, and also possibly build CuPy locally so I can investigate how CuPy handles `__dlpack__()` and the stream argument.
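For reference, the handshake under discussion looks roughly like the sketch below: the consumer first tries the DLPack 1.x call with `max_version` and falls back to the legacy call on `TypeError`. This is an illustrative sketch of the pattern used by `view_as_dlpack` (see the patch further down), not the exact implementation; the function name and the literal version tuple are made up for illustration.

```python
# Illustrative sketch (not the exact cuda.core implementation) of the
# consumer-side DLPack handshake being discussed: try the versioned
# (DLPack 1.x) protocol first, then fall back to the legacy one.
def export_capsule(obj, stream_ptr):
    try:
        # Newer producers accept max_version and return a versioned capsule.
        capsule = obj.__dlpack__(stream=stream_ptr, max_version=(1, 0))
        versioned = True
    except TypeError:
        # Older producers only understand the stream keyword.
        capsule = obj.__dlpack__(stream=stream_ptr)
        versioned = False
    return capsule, versioned
```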
@yangcal I think passing […]

Hi @jakevdp, we hit a possibly non-compliant DLPack implementation in either Jax or XLA.
FWIW, -1 is handled around here (the logic is a bit convoluted).
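For context, the array API standard gives `stream=-1` a special meaning on CUDA devices: the consumer does not need any synchronization. A rough, hypothetical sketch of what a compliant producer is expected to do (the helper functions below do not exist and are named only for illustration):

```python
# Hypothetical sketch of how a compliant producer interprets the stream
# argument for a CUDA device (array API / DLPack semantics); the helper
# functions below are made up for illustration.
def producer_dlpack(array, stream=None, max_version=None):
    if stream == -1:
        pass  # -1: consumer explicitly requests no synchronization
    elif stream is None or stream in (1, 2):
        sync_with_default_stream(array)             # hypothetical helper
    else:
        sync_with_consumer_stream(array, stream)    # hypothetical helper
    return make_dlpack_capsule(array, max_version)  # hypothetical helper
```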
There are […]. So in short, even if Yang tested this […], it would not work because we'd then reach an […].
+1. All of that is consistent with what I have found.
@yangcal could you check if this patch works on your side? Let's work around the Jax bugs, assuming they can be fixed in the next version (cc @jakevdp for visibility):

```diff
diff --git a/cuda_core/cuda/core/experimental/_memoryview.pyx b/cuda_core/cuda/core/experimental/_memoryview.pyx
index d8eba46..1d4d977 100644
--- a/cuda_core/cuda/core/experimental/_memoryview.pyx
+++ b/cuda_core/cuda/core/experimental/_memoryview.pyx
@@ -7,6 +7,7 @@ cimport cython
 from ._dlpack cimport *
 
 import functools
+import importlib.metadata
 from typing import Any, Optional
 
 from cuda import cuda
@@ -181,6 +182,13 @@ cdef StridedMemoryView view_as_dlpack(obj, stream_ptr, view=None):
                 stream=stream_ptr,
                 max_version=(DLPACK_MAJOR_VERSION, DLPACK_MINOR_VERSION))
             versioned = True
+            try:
+                if "jax.numpy" in str(obj.__array_namespace__()):
+                    ver = tuple(int(i) for i in importlib.metadata.version("jax").split("."))
+                    if ver <= (0, 4, 38):
+                        versioned = False
+            except AttributeError:
+                pass
         except TypeError:
             capsule = obj.__dlpack__(
                 stream=stream_ptr)
```

You'd also need to avoid passing […]:

```python
import cupy as cp
import numpy as np
import jax.numpy as jnp

from cuda.core.experimental import Device
from cuda.core.experimental.utils import args_viewable_as_strided_memory


@args_viewable_as_strided_memory((0,))
def parse_tensor(arr, s, mod):
    view = arr.view(s.handle if mod is not np else -1)
    print(type(arr), type(view))
    print(f"shape={view.shape}")
    print(f"strides={view.strides}")


dev = Device(0)
dev.set_current()
s = dev.create_stream()

for module in (np, cp, jnp):
    arr = module.eye(2)
    parse_tensor(arr, s, module)
```
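In short, the workaround detects a jax.numpy producer at or below version 0.4.38 and forces the legacy (unversioned) DLPack path for it, on the assumption that the versioned path is where the non-compliant behavior surfaces.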
It could also be that the bug is on our side... e.g. on the […], in which case we're the consumer (to create a view).
See #292.
Sorry for the slow response. I can confirm that the fix works. A follow-up question: does the script above require knowledge of which stream was used to populate the operand on the device, or would any stream on that device work? I tried a random integer as the handle input and hit a segfault, but I also tried generating a new stream object with […]
Hi Yang, no, it's not any random stream; it's the stream that you will be using to access the content in the decorated function (assuming you're using the decorator). We will order it properly after the stream on which the data is being generated/processed.
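In other words, a hedged sketch building on the earlier script (not an official example): the handle you pass to `view()` should be the stream you will consume the data on, and cuda.core establishes the ordering against the producer's stream for you.

```python
# Hedged sketch based on the explanation above; reuses names and the API
# pattern from the earlier script. The key point: pass the handle of the
# stream you will *consume* the data on, not an arbitrary integer.
import cupy as cp
from cuda.core.experimental import Device
from cuda.core.experimental.utils import args_viewable_as_strided_memory


@args_viewable_as_strided_memory((0,))
def consume(arr, stream):
    # The view is made safe to access on `stream`.
    view = arr.view(stream.handle)
    print(view.shape, view.strides)
    # ...launch kernels that read from the view on `stream` here...


dev = Device(0)
dev.set_current()
consumer_stream = dev.create_stream()  # the stream our own work runs on
consume(cp.arange(4.0), consumer_stream)
```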
@yangcal was trying the StridedMemoryView, but it doesn't seem to work for a jax array, at least with cuda 12.2. Below is the script and the error log:

[…]

Is it a semantics error on my side or a cuda version issue?