Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] fix empty initialization of device_ndarray in pylibraft #2061

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions python/pylibraft/pylibraft/common/device_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def __init__(self, np_ndarray):

Parameters
----------
ndarray : A numpy.ndarray which will be copied and moved to the device
ndarray : Can be numpy.ndarray, array like or even directly
an __array_interface__. Only case it is a numpy.ndarray its
contents will be copied to the device.

Examples
--------
Expand All @@ -58,11 +60,38 @@ def __init__(self, np_ndarray):
raft_array = device_ndarray.empty((100, 50))
torch_tensor = torch.as_tensor(raft_array, device='cuda')
"""
self.ndarray_ = np_ndarray

if type(np_ndarray) is np.ndarray:
# np_ndarray IS an actual numpy.ndarray
self.__array_interface__ = np_ndarray.__array_interface__.copy()
self.ndarray_ = np_ndarray
copy = True
elif hasattr(np_ndarray, "__array_interface__"):
# np_ndarray HAS an __array_interface__
self.__array_interface__ = np_ndarray.__array_interface__.copy()
self.ndarray_ = np_ndarray
copy = False
elif all(
name in np_ndarray for name in {"typestr", "shape", "version"}
):
# np_ndarray IS an __array_interface__
self.__array_interface__ = np_ndarray.copy()
self.ndarray_ = None
copy = False
else:
raise ValueError(
"np_ndarray should be or contain __array_interface__"
)

order = "C" if self.c_contiguous else "F"
self.device_buffer_ = rmm.DeviceBuffer.to_device(
self.ndarray_.tobytes(order=order)
)
if copy:
self.device_buffer_ = rmm.DeviceBuffer.to_device(
self.ndarray_.tobytes(order=order)
)
else:
self.device_buffer_ = rmm.DeviceBuffer(
size=np.prod(self.shape) * self.dtype.itemsize
)

@classmethod
def empty(cls, shape, dtype=np.float32, order="C"):
Expand All @@ -82,7 +111,7 @@ def empty(cls, shape, dtype=np.float32, order="C"):
or column-major (Fortran-style) order in memory
"""
arr = np.empty(shape, dtype=dtype, order=order)
return cls(arr)
return cls(arr.__array_interface__.copy())

@property
def c_contiguous(self):
Expand All @@ -104,23 +133,23 @@ def dtype(self):
"""
Datatype of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
array_interface = self.__array_interface__
return np.dtype(array_interface["typestr"])

@property
def shape(self):
"""
Shape of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
array_interface = self.__array_interface__
return array_interface["shape"]

@property
def strides(self):
"""
Strides of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
array_interface = self.__array_interface__
return array_interface.get("strides")

@property
Expand All @@ -131,7 +160,7 @@ def __cuda_array_interface__(self):
zero-copy semantics.
"""
device_cai = self.device_buffer_.__cuda_array_interface__
host_cai = self.ndarray_.__array_interface__.copy()
host_cai = self.__array_interface__.copy()
host_cai["data"] = (device_cai["data"][0], device_cai["data"][1])

return host_cai
Expand Down