Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving TestDeviceBuffer to pylibraft.common.device_ndarray #1008

Merged
merged 14 commits into from
Nov 14, 2022
Merged
67 changes: 65 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,73 @@ auto metric = raft::distance::DistanceType::L2SqrtExpanded;
raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric);
```

It's also possible to create `raft::device_mdspan` views to invoke the same API with raw pointers and shape information:

```c++
#include <raft/core/handle.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/random/make_blobs.cuh>
#include <raft/distance/distance.cuh>

raft::handle_t handle;

int n_samples = 5000;
int n_features = 50;

float *input;
int *labels;
float *output;

...
// Allocate input, labels, and output pointers
...

auto input_view = raft::make_device_matrix_view(input, n_samples, n_features);
auto labels_view = raft::make_device_vector_view(labels, n_samples);
auto output_view = raft::make_device_matrix_view(output, n_samples, n_samples);

raft::random::make_blobs(handle, input_view, labels_view);

auto metric = raft::distance::DistanceType::L2SqrtExpanded;
raft::distance::pairwise_distance(handle, input_view, input_view, output_view, metric);
```


### Python Example

The `pylibraft` package contains a Python API for RAFT algorithms and primitives. `pylibraft` integrates nicely into other libraries by being very lightweight with minimal dependencies and accepting any object that supports the `__cuda_array_interface__`, such as [CuPy's ndarray](https://docs.cupy.dev/en/stable/user_guide/interoperability.html#rmm). The number of RAFT algorithms exposed in this package is continuing to grow from release to release.

The example below demonstrates computing the pairwise Euclidean distances between CuPy arrays. `pylibraft` is a low-level API that prioritizes efficiency and simplicity over being pythonic, which is shown here by pre-allocating the output memory before invoking the `pairwise_distance` function. Note that CuPy is not a required dependency for `pylibraft`.
The example below demonstrates computing the pairwise Euclidean distances between CuPy arrays. Note that CuPy is not a required dependency for `pylibraft`.

```python
import cupy as cp

from pylibraft.distance import pairwise_distance

n_samples = 5000
n_features = 50

in1 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32)
in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32)

output = pairwise_distance(in1, in2, metric="euclidean")
```

The `output` array supports [__cuda_array_interface__](https://numba.pydata.org/numba-doc/dev/cuda/cuda_array_interface.html#cuda-array-interface-version-2) so it is interoperable with other libraries like CuPy, Numba, and PyTorch that also support it.

Below is an example of converting the output `pylibraft.device_ndarray` to a CuPy array:
```python
cupy_array = cp.asarray(output)
```

And converting to a PyTorch tensor:
```python
import torch

torch_tensor = torch.as_tensor(output, device='cuda')
```

`pylibraft` also supports writing to a pre-allocated output array so any `__cuda_array_interface__` supported array can be written to in-place:

```python
import cupy as cp
Expand All @@ -95,9 +157,10 @@ in1 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32)
in2 = cp.random.random_sample((n_samples, n_features), dtype=cp.float32)
output = cp.empty((n_samples, n_samples), dtype=cp.float32)

pairwise_distance(in1, in2, output, metric="euclidean")
pairwise_distance(in1, in2, out=output, metric="euclidean")
```


## Installing

RAFT itself can be installed through conda, [Cmake Package Manager (CPM)](https://github.com/cpm-cmake/CPM.cmake), or by building the repository from source. Please refer to the [build instructions](docs/source/build.md) for more a comprehensive guide on building RAFT and using it in downstream projects.
Expand Down
2 changes: 2 additions & 0 deletions python/pylibraft/pylibraft/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,7 @@
# limitations under the License.
#


from .cuda import Stream
from .device_ndarray import device_ndarray
from .handle import Handle
158 changes: 158 additions & 0 deletions python/pylibraft/pylibraft/common/device_ndarray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import numpy as np

import rmm


class device_ndarray:
"""
pylibraft.common.device_ndarray is meant to be a very lightweight
__cuda_array_interface__ wrapper around a numpy.ndarray.
"""

def __init__(self, np_ndarray):
"""
Construct a pylibraft.common.device_ndarray wrapper around a
numpy.ndarray

Parameters
----------
ndarray : A numpy.ndarray which will be copied and moved to the device

Examples
--------
The device_ndarray is __cuda_array_interface__ compliant so it is
interoperable with other libraries that also support it, such as
CuPy and PyTorch.

The following usage example demonstrates
converting a pylibraft.common.device_ndarray to a cupy.ndarray:
.. code-block:: python

import cupy as cp
from pylibraft.common import device_ndarray

raft_array = device_ndarray.empty((100, 50))
cupy_array = cp.asarray(raft_array)

And the converting pylibraft.common.device_ndarray to a PyTorch tensor:
.. code-block:: python

import torch
from pylibraft.common import device_ndarray

raft_array = device_ndarray.empty((100, 50))
torch_tensor = torch.as_tensor(raft_array, device='cuda')
"""
self.ndarray_ = np_ndarray
order = "C" if self.c_contiguous else "F"
self.device_buffer_ = rmm.DeviceBuffer.to_device(
self.ndarray_.tobytes(order=order)
)

@classmethod
def empty(cls, shape, dtype=np.float32, order="C"):
"""
Return a new device_ndarray of given shape and type, without
initializing entries.

Parameters
----------
shape : int or tuple of int
Shape of the empty array, e.g., (2, 3) or 2.
dtype : data-type, optional
Desired output data-type for the array, e.g, numpy.int8.
Default is numpy.float32.
order : {'C', 'F'}, optional (default: 'C')
Whether to store multi-dimensional dat ain row-major (C-style)
or column-major (Fortran-style) order in memory
"""
arr = np.empty(shape, dtype=dtype, order=order)
return cls(arr)

@property
def c_contiguous(self):
"""
Is the current device_ndarray laid out in row-major format?
"""
array_interface = self.ndarray_.__array_interface__
strides = self.strides
return (
strides is None
or array_interface["strides"][1] == self.dtype.itemsize
)

@property
def f_contiguous(self):
"""
Is the current device_ndarray laid out in column-major format?
"""
return not self.c_contiguous

@property
def dtype(self):
"""
Datatype of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
return np.dtype(array_interface["typestr"])

@property
def shape(self):
"""
Shape of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
return array_interface["shape"]

@property
def strides(self):
"""
Strides of the current device_ndarray instance
"""
array_interface = self.ndarray_.__array_interface__
return (
None
if "strides" not in array_interface
else array_interface["strides"]
)

@property
def __cuda_array_interface__(self):
"""
Returns the __cuda_array_interface__ compliant dict for
integrating with other device-enabled libraries using
zero-copy semantics.
"""
device_cai = self.device_buffer_.__cuda_array_interface__
host_cai = self.ndarray_.__array_interface__.copy()
host_cai["data"] = (device_cai["data"][0], device_cai["data"][1])

return host_cai

def copy_to_host(self):
"""
Returns a new numpy.ndarray object on host with the current contents of
this device_ndarray
"""
ret = np.frombuffer(
self.device_buffer_.tobytes(),
dtype=self.dtype,
like=self.ndarray_,
).astype(self.dtype)
ret = np.lib.stride_tricks.as_strided(ret, self.shape, self.strides)
return ret
51 changes: 45 additions & 6 deletions python/pylibraft/pylibraft/distance/fused_l2_nn.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ from libcpp cimport bool

from .distance_type cimport DistanceType

from pylibraft.common import Handle
from pylibraft.common import Handle, device_ndarray
from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.handle cimport handle_t

Expand Down Expand Up @@ -62,7 +62,7 @@ cdef extern from "raft_distance/fused_l2_min_arg.hpp" \


@auto_sync_handle
def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None):
def fused_l2_nn_argmin(X, Y, out=None, sqrt=True, handle=None):
"""
Compute the 1-nearest neighbors between X and Y using the L2 distance

Expand All @@ -77,6 +77,35 @@ def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None):
Examples
--------

To compute the 1-nearest neighbors argmin:
.. code-block:: python

import cupy as cp

from pylibraft.common import Handle
from pylibraft.distance import fused_l2_nn_argmin

n_samples = 5000
n_clusters = 5
n_features = 50

in1 = cp.random.random_sample((n_samples, n_features),
dtype=cp.float32)
in2 = cp.random.random_sample((n_clusters, n_features),
dtype=cp.float32)

# A single RAFT handle can optionally be reused across
# pylibraft functions.
handle = Handle()
...
output = fused_l2_nn_argmin(in1, in2, output, handle=handle)
...
# pylibraft functions are often asynchronous so the
# handle needs to be explicitly synchronized
handle.sync()

The output can also be computed in-place on a preallocated
array:
.. code-block:: python

import cupy as cp
Expand All @@ -98,20 +127,30 @@ def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None):
# pylibraft functions.
handle = Handle()
...
fused_l2_nn_argmin(in1, in2, output, handle=handle)
fused_l2_nn_argmin(in1, in2, out=output, handle=handle)
...
# pylibraft functions are often asynchronous so the
# handle needs to be explicitly synchronized
handle.sync()

"""

x_cai = X.__cuda_array_interface__
y_cai = Y.__cuda_array_interface__
output_cai = output.__cuda_array_interface__

x_dt = np.dtype(x_cai["typestr"])
y_dt = np.dtype(y_cai["typestr"])

m = x_cai["shape"][0]
n = y_cai["shape"][0]

if out is None:
output = device_ndarray.empty((m,), dtype="int32")
else:
output = out

output_cai = output.__cuda_array_interface__

x_k = x_cai["shape"][1]
y_k = y_cai["shape"][1]

Expand All @@ -127,8 +166,6 @@ def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None):
handle = handle if handle is not None else Handle()
cdef handle_t *h = <handle_t*><size_t>handle.getHandle()

x_dt = np.dtype(x_cai["typestr"])
y_dt = np.dtype(y_cai["typestr"])
d_dt = np.dtype(output_cai["typestr"])

x_c_contiguous = is_c_cont(x_cai, x_dt)
Expand Down Expand Up @@ -162,3 +199,5 @@ def fused_l2_nn_argmin(X, Y, output, sqrt=True, handle=None):
<bool>sqrt)
else:
raise ValueError("dtype %s not supported" % x_dt)

return output
Loading