Skip to content

Commit

Permalink
Merge branch 'branch-24.02' into filename-plot
Browse files Browse the repository at this point in the history
  • Loading branch information
divyegala authored Dec 15, 2023
2 parents b6534c3 + 1beb556 commit b609acc
Showing 1 changed file with 39 additions and 10 deletions.
49 changes: 39 additions & 10 deletions python/pylibraft/pylibraft/common/device_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ def __init__(self, np_ndarray):
Parameters
----------
ndarray : A numpy.ndarray which will be copied and moved to the device
ndarray : Can be numpy.ndarray, array like or even directly
an __array_interface__. Only case it is a numpy.ndarray its
contents will be copied to the device.
Examples
--------
Expand All @@ -58,11 +60,38 @@ def __init__(self, np_ndarray):
raft_array = device_ndarray.empty((100, 50))
torch_tensor = torch.as_tensor(raft_array, device='cuda')
"""
self.ndarray_ = np_ndarray

if type(np_ndarray) is np.ndarray:
# np_ndarray IS an actual numpy.ndarray
self.__array_interface__ = np_ndarray.__array_interface__.copy()
self.ndarray_ = np_ndarray
copy = True
elif hasattr(np_ndarray, "__array_interface__"):
# np_ndarray HAS an __array_interface__
self.__array_interface__ = np_ndarray.__array_interface__.copy()
self.ndarray_ = np_ndarray
copy = False
elif all(
name in np_ndarray for name in {"typestr", "shape", "version"}
):
# np_ndarray IS an __array_interface__
self.__array_interface__ = np_ndarray.copy()
self.ndarray_ = None
copy = False
else:
raise ValueError(
"np_ndarray should be or contain __array_interface__"
)

order = "C" if self.c_contiguous else "F"
self.device_buffer_ = rmm.DeviceBuffer.to_device(
self.ndarray_.tobytes(order=order)
)
if copy:
self.device_buffer_ = rmm.DeviceBuffer.to_device(
self.ndarray_.tobytes(order=order)
)
else:
self.device_buffer_ = rmm.DeviceBuffer(
size=np.prod(self.shape) * self.dtype.itemsize
)

@classmethod
def empty(cls, shape, dtype=np.float32, order="C"):
Expand All @@ -82,7 +111,7 @@ def empty(cls, shape, dtype=np.float32, order="C"):
or column-major (Fortran-style) order in memory
"""
arr = np.empty(shape, dtype=dtype, order=order)
return cls(arr)
return cls(arr.__array_interface__.copy())

@property
def c_contiguous(self):
Expand All @@ -104,23 +133,23 @@ def dtype(self):
"""
Datatype of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
array_interface = self.__array_interface__
return np.dtype(array_interface["typestr"])

@property
def shape(self):
"""
Shape of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
array_interface = self.__array_interface__
return array_interface["shape"]

@property
def strides(self):
"""
Strides of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
array_interface = self.__array_interface__
return array_interface.get("strides")

@property
Expand All @@ -131,7 +160,7 @@ def __cuda_array_interface__(self):
zero-copy semantics.
"""
device_cai = self.device_buffer_.__cuda_array_interface__
host_cai = self.ndarray_.__array_interface__.copy()
host_cai = self.__array_interface__.copy()
host_cai["data"] = (device_cai["data"][0], device_cai["data"][1])

return host_cai
Expand Down

0 comments on commit b609acc

Please sign in to comment.