-
Notifications
You must be signed in to change notification settings - Fork 197
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding
fused_l2_nn_argmin
wrapper to Pylibraft (#924)
This coincides pretty well w/ the `pairwise_distance_armin` building block that's being exposed in Scikit-learn, except it's faster and saves a lot of gpu memory by fusing the argmin w/ the pairwise distances so we don't ever have to store the n^2 distances. cc @betatim Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Dante Gama Dessavre (https://github.com/dantegd) URL: #924
- Loading branch information
Showing
8 changed files
with
358 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <raft/core/handle.hpp> | ||
#include <raft/distance/distance_types.hpp> | ||
|
||
namespace raft::distance::runtime { | ||
|
||
/** | ||
* @brief Wrapper around fusedL2NN with minimum reduction operators. | ||
* | ||
* fusedL2NN cannot be compiled in the distance library due to the lambda | ||
* operators, so this wrapper covers the most common case (minimum). | ||
* | ||
* @param[in] handle raft handle | ||
* @param[out] min will contain the reduced output (Length = `m`) | ||
* (on device) | ||
* @param[in] x first matrix. Row major. Dim = `m x k`. | ||
* (on device). | ||
* @param[in] y second matrix. Row major. Dim = `n x k`. | ||
* (on device). | ||
* @param[in] m gemm m | ||
* @param[in] n gemm n | ||
* @param[in] k gemm k | ||
* @param[in] sqrt Whether the output `minDist` should contain L2-sqrt | ||
*/ | ||
void fused_l2_nn_min_arg(raft::handle_t const& handle, | ||
int* min, | ||
const float* x, | ||
const float* y, | ||
int m, | ||
int n, | ||
int k, | ||
bool sqrt); | ||
|
||
void fused_l2_nn_min_arg(raft::handle_t const& handle, | ||
int* min, | ||
const double* x, | ||
const double* y, | ||
int m, | ||
int n, | ||
int k, | ||
bool sqrt); | ||
|
||
} // end namespace raft::distance::runtime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
/* | ||
* Copyright (c) 2022, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <raft/core/device_mdarray.hpp> | ||
#include <raft/core/handle.hpp> | ||
#include <raft/core/kvp.hpp> | ||
#include <raft/distance/distance_types.hpp> | ||
#include <raft/distance/fused_l2_nn.cuh> | ||
#include <raft/distance/specializations.cuh> | ||
#include <thrust/for_each.h> | ||
#include <thrust/tuple.h> | ||
|
||
namespace raft::distance::runtime { | ||
|
||
template <typename IndexT, typename DataT> | ||
struct KeyValueIndexOp { | ||
__host__ __device__ __forceinline__ IndexT | ||
operator()(const raft::KeyValuePair<IndexT, DataT>& a) const | ||
{ | ||
return a.key; | ||
} | ||
}; | ||
|
||
template <typename value_t, typename idx_t> | ||
void compute_fused_l2_nn_min_arg(raft::handle_t const& handle, | ||
idx_t* min, | ||
const value_t* x, | ||
const value_t* y, | ||
idx_t m, | ||
idx_t n, | ||
idx_t k, | ||
bool sqrt) | ||
{ | ||
rmm::device_uvector<int> workspace(m, handle.get_stream()); | ||
auto kvp = raft::make_device_vector<raft::KeyValuePair<idx_t, value_t>>(handle, m); | ||
|
||
rmm::device_uvector<value_t> x_norms(m, handle.get_stream()); | ||
rmm::device_uvector<value_t> y_norms(n, handle.get_stream()); | ||
raft::linalg::rowNorm(x_norms.data(), x, k, m, raft::linalg::L2Norm, true, handle.get_stream()); | ||
raft::linalg::rowNorm(y_norms.data(), y, k, n, raft::linalg::L2Norm, true, handle.get_stream()); | ||
|
||
fusedL2NNMinReduce(kvp.data_handle(), | ||
x, | ||
y, | ||
x_norms.data(), | ||
y_norms.data(), | ||
m, | ||
n, | ||
k, | ||
(void*)workspace.data(), | ||
sqrt, | ||
true, | ||
handle.get_stream()); | ||
|
||
KeyValueIndexOp<idx_t, value_t> conversion_op; | ||
thrust::transform( | ||
handle.get_thrust_policy(), kvp.data_handle(), kvp.data_handle() + m, min, conversion_op); | ||
handle.sync_stream(); | ||
} | ||
|
||
void fused_l2_nn_min_arg(raft::handle_t const& handle, | ||
int* min, | ||
const float* x, | ||
const float* y, | ||
int m, | ||
int n, | ||
int k, | ||
bool sqrt) | ||
{ | ||
compute_fused_l2_nn_min_arg<float, int>(handle, min, x, y, m, n, k, sqrt); | ||
} | ||
|
||
void fused_l2_nn_min_arg(raft::handle_t const& handle, | ||
int* min, | ||
const double* x, | ||
const double* y, | ||
int m, | ||
int n, | ||
int k, | ||
bool sqrt) | ||
{ | ||
compute_fused_l2_nn_min_arg<double, int>(handle, min, x, y, m, n, k, sqrt); | ||
} | ||
|
||
} // end namespace raft::distance::runtime |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
# | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
# cython: profile=False | ||
# distutils: language = c++ | ||
# cython: embedsignature = True | ||
# cython: language_level = 3 | ||
|
||
import numpy as np | ||
|
||
from libc.stdint cimport uintptr_t | ||
from cython.operator cimport dereference as deref | ||
|
||
from libcpp cimport bool | ||
from .distance_type cimport DistanceType | ||
from pylibraft.common.handle cimport handle_t | ||
|
||
|
||
def is_c_cont(cai, dt): | ||
return "strides" not in cai or \ | ||
cai["strides"] is None or \ | ||
cai["strides"][1] == dt.itemsize | ||
|
||
|
||
cdef extern from "raft_distance/fused_l2_min_arg.hpp" \ | ||
namespace "raft::distance::runtime": | ||
|
||
void fused_l2_nn_min_arg( | ||
const handle_t &handle, | ||
int* min, | ||
const float* x, | ||
const float* y, | ||
int m, | ||
int n, | ||
int k, | ||
bool sqrt) | ||
|
||
void fused_l2_nn_min_arg( | ||
const handle_t &handle, | ||
int* min, | ||
const double* x, | ||
const double* y, | ||
int m, | ||
int n, | ||
int k, | ||
bool sqrt) | ||
|
||
|
||
def fused_l2_nn_argmin(X, Y, output, sqrt=True): | ||
""" | ||
Compute the 1-nearest neighbors between X and Y using the L2 distance | ||
Parameters | ||
---------- | ||
X : CUDA array interface compliant matrix shape (m, k) | ||
Y : CUDA array interface compliant matrix shape (n, k) | ||
output : Writable CUDA array interface matrix shape (m, 1) | ||
Examples | ||
-------- | ||
.. code-block:: python | ||
import cupy as cp | ||
from pylibraft.distance import fused_l2_nn | ||
n_samples = 5000 | ||
n_clusters = 5 | ||
n_features = 50 | ||
in1 = cp.random.random_sample((n_samples, n_features), | ||
dtype=cp.float32) | ||
in2 = cp.random.random_sample((n_clusters, n_features), | ||
dtype=cp.float32) | ||
output = cp.empty((n_samples, 1), dtype=cp.int32) | ||
fused_l2_nn_argmin(in1, in2, output) | ||
""" | ||
|
||
x_cai = X.__cuda_array_interface__ | ||
y_cai = Y.__cuda_array_interface__ | ||
output_cai = output.__cuda_array_interface__ | ||
|
||
m = x_cai["shape"][0] | ||
n = y_cai["shape"][0] | ||
|
||
x_k = x_cai["shape"][1] | ||
y_k = y_cai["shape"][1] | ||
|
||
if x_k != y_k: | ||
raise ValueError("Inputs must have same number of columns. " | ||
"a=%s, b=%s" % (x_k, y_k)) | ||
|
||
x_ptr = <uintptr_t>x_cai["data"][0] | ||
y_ptr = <uintptr_t>y_cai["data"][0] | ||
|
||
d_ptr = <uintptr_t>output_cai["data"][0] | ||
|
||
cdef handle_t *h = new handle_t() | ||
|
||
x_dt = np.dtype(x_cai["typestr"]) | ||
y_dt = np.dtype(y_cai["typestr"]) | ||
d_dt = np.dtype(output_cai["typestr"]) | ||
|
||
x_c_contiguous = is_c_cont(x_cai, x_dt) | ||
y_c_contiguous = is_c_cont(y_cai, y_dt) | ||
|
||
if x_c_contiguous != y_c_contiguous: | ||
raise ValueError("Inputs must have matching strides") | ||
|
||
print(x_dt) | ||
if x_dt != y_dt: | ||
raise ValueError("Inputs must have the same dtypes") | ||
if d_dt != np.int32: | ||
raise ValueError("Output array must be int32") | ||
|
||
if x_dt == np.float32: | ||
fused_l2_nn_min_arg(deref(h), | ||
<int*> d_ptr, | ||
<float*> x_ptr, | ||
<float*> y_ptr, | ||
<int>m, | ||
<int>n, | ||
<int>x_k, | ||
<bool>sqrt) | ||
elif x_dt == np.float64: | ||
fused_l2_nn_min_arg(deref(h), | ||
<int*> d_ptr, | ||
<double*> x_ptr, | ||
<double*> y_ptr, | ||
<int>m, | ||
<int>n, | ||
<int>x_k, | ||
<bool>sqrt) | ||
else: | ||
raise ValueError("dtype %s not supported" % x_dt) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
from scipy.spatial.distance import cdist | ||
import pytest | ||
import numpy as np | ||
|
||
from pylibraft.distance import fused_l2_nn_argmin | ||
from pylibraft.testing.utils import TestDeviceBuffer | ||
|
||
|
||
@pytest.mark.parametrize("n_rows", [10, 100]) | ||
@pytest.mark.parametrize("n_clusters", [5, 10]) | ||
@pytest.mark.parametrize("n_cols", [3, 5]) | ||
@pytest.mark.parametrize("dtype", [np.float32, np.float64]) | ||
def test_fused_l2_nn_minarg(n_rows, n_cols, n_clusters, dtype): | ||
input1 = np.random.random_sample((n_rows, n_cols)) | ||
input1 = np.asarray(input1, order="C").astype(dtype) | ||
|
||
input2 = np.random.random_sample((n_clusters, n_cols)) | ||
input2 = np.asarray(input2, order="C").astype(dtype) | ||
|
||
output = np.zeros((n_rows), dtype="int32") | ||
expected = cdist(input1, input2, metric="euclidean") | ||
|
||
expected = expected.argmin(axis=1) | ||
|
||
input1_device = TestDeviceBuffer(input1, "C") | ||
input2_device = TestDeviceBuffer(input2, "C") | ||
output_device = TestDeviceBuffer(output, "C") | ||
|
||
fused_l2_nn_argmin(input1_device, input2_device, output_device, True) | ||
actual = output_device.copy_to_host() | ||
|
||
assert np.allclose(expected, actual, rtol=1e-4) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters