diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt
index 12bebfa2a5..d8525b057d 100644
--- a/cpp/CMakeLists.txt
+++ b/cpp/CMakeLists.txt
@@ -244,6 +244,7 @@ set_target_properties(raft_distance PROPERTIES EXPORT_NAME distance)
 if(RAFT_COMPILE_DIST_LIBRARY)
   add_library(raft_distance_lib
     src/distance/pairwise_distance.cu
+    src/distance/fused_l2_min_arg.cu
     src/distance/specializations/detail/canberra.cu
     src/distance/specializations/detail/chebyshev.cu
     src/distance/specializations/detail/correlation.cu
diff --git a/cpp/include/raft_distance/fused_l2_min_arg.hpp b/cpp/include/raft_distance/fused_l2_min_arg.hpp
new file mode 100644
index 0000000000..f7d3748666
--- /dev/null
+++ b/cpp/include/raft_distance/fused_l2_min_arg.hpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include
+#include
+
+namespace raft::distance::runtime {
+
+/**
+ * @brief Wrapper around fusedL2NN with minimum reduction operators.
+ *
+ * fusedL2NN cannot be compiled in the distance library due to the lambda
+ * operators, so this wrapper covers the most common case (minimum).
+ *
+ * @param[in] handle  raft handle
+ * @param[out] min    for each row of `x`, the index of the nearest row of `y`
+ *                    (Length = `m`) (on device)
+ * @param[in] x       first matrix. Row major. Dim = `m x k`.
+ *                    (on device).
+ * @param[in] y       second matrix. Row major. Dim = `n x k`.
+ *                    (on device).
+ * @param[in] m       gemm m
+ * @param[in] n       gemm n
+ * @param[in] k       gemm k
+ * @param[in] sqrt    Whether the underlying fusedL2NN distances are L2-sqrt
+ */
+void fused_l2_nn_min_arg(raft::handle_t const& handle,
+                         int* min,
+                         const float* x,
+                         const float* y,
+                         int m,
+                         int n,
+                         int k,
+                         bool sqrt);
+
+void fused_l2_nn_min_arg(raft::handle_t const& handle,
+                         int* min,
+                         const double* x,
+                         const double* y,
+                         int m,
+                         int n,
+                         int k,
+                         bool sqrt);
+
+}  // end namespace raft::distance::runtime
\ No newline at end of file
diff --git a/cpp/src/distance/fused_l2_min_arg.cu b/cpp/src/distance/fused_l2_min_arg.cu
new file mode 100644
index 0000000000..c722b5a566
--- /dev/null
+++ b/cpp/src/distance/fused_l2_min_arg.cu
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2022, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace raft::distance::runtime {
+
+/** Unary functor returning the `key` member of a raft::KeyValuePair. */
+template
+struct KeyValueIndexOp {
+  __host__ __device__ __forceinline__ IndexT
+  operator()(const raft::KeyValuePair& a) const
+  {
+    return a.key;
+  }
+};
+
+/**
+ * Computes the row L2 norms of `x` and `y`, runs fusedL2NNMinReduce into a
+ * temporary key/value buffer of length `m`, then copies the key of every pair
+ * into `min`. The handle's stream is synchronized before returning, so `min`
+ * is fully written once this function returns.
+ */
+template
+void compute_fused_l2_nn_min_arg(raft::handle_t const& handle,
+                                 idx_t* min,
+                                 const value_t* x,
+                                 const value_t* y,
+                                 idx_t m,
+                                 idx_t n,
+                                 idx_t k,
+                                 bool sqrt)
+{
+  rmm::device_uvector workspace(m, handle.get_stream());
+  auto kvp = raft::make_device_vector>(handle, m);
+
+  rmm::device_uvector x_norms(m, handle.get_stream());
+  rmm::device_uvector y_norms(n, handle.get_stream());
+  raft::linalg::rowNorm(x_norms.data(), x, k, m, raft::linalg::L2Norm, true, handle.get_stream());
+  raft::linalg::rowNorm(y_norms.data(), y, k, n, raft::linalg::L2Norm, true, handle.get_stream());
+
+  fusedL2NNMinReduce(kvp.data_handle(),
+                     x,
+                     y,
+                     x_norms.data(),
+                     y_norms.data(),
+                     m,
+                     n,
+                     k,
+                     (void*)workspace.data(),
+                     sqrt,
+                     true,
+                     handle.get_stream());
+
+  // Only the key of each (key, value) pair is handed back to the caller.
+  KeyValueIndexOp conversion_op;
+  thrust::transform(
+    handle.get_thrust_policy(), kvp.data_handle(), kvp.data_handle() + m, min, conversion_op);
+  handle.sync_stream();
+}
+
+void fused_l2_nn_min_arg(raft::handle_t const& handle,
+                         int* min,
+                         const float* x,
+                         const float* y,
+                         int m,
+                         int n,
+                         int k,
+                         bool sqrt)
+{
+  compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt);
+}
+
+void fused_l2_nn_min_arg(raft::handle_t const& handle,
+                         int* min,
+                         const double* x,
+                         const double* y,
+                         int m,
+                         int n,
+                         int k,
+                         bool sqrt)
+{
+  compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt);
+}
+
+}  // end namespace raft::distance::runtime
\ No newline at end of file
diff --git a/python/pylibraft/pylibraft/distance/CMakeLists.txt b/python/pylibraft/pylibraft/distance/CMakeLists.txt
index 707ea737b3..d074171e58 100644
--- a/python/pylibraft/pylibraft/distance/CMakeLists.txt
+++ 
b/python/pylibraft/pylibraft/distance/CMakeLists.txt
@@ -13,7 +13,8 @@
 # =============================================================================
 
 # Set the list of Cython files to build
-set(cython_sources pairwise_distance.pyx)
+set(cython_sources pairwise_distance.pyx
+                   fused_l2_nn.pyx)
 set(linked_libraries raft::raft raft::distance)
 
 # Build all of the Cython targets
diff --git a/python/pylibraft/pylibraft/distance/__init__.py b/python/pylibraft/pylibraft/distance/__init__.py
index ca3e6c5a2e..a3c4e2229b 100644
--- a/python/pylibraft/pylibraft/distance/__init__.py
+++ b/python/pylibraft/pylibraft/distance/__init__.py
@@ -13,4 +13,5 @@
 # limitations under the License.
 #
 
+from .fused_l2_nn import fused_l2_nn_argmin
 from .pairwise_distance import distance as pairwise_distance
\ No newline at end of file
diff --git a/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx
new file mode 100644
index 0000000000..5fb837c114
--- /dev/null
+++ b/python/pylibraft/pylibraft/distance/fused_l2_nn.pyx
@@ -0,0 +1,155 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# cython: profile=False
+# distutils: language = c++
+# cython: embedsignature = True
+# cython: language_level = 3
+
+import numpy as np
+
+from libc.stdint cimport uintptr_t
+from cython.operator cimport dereference as deref
+
+from libcpp cimport bool
+from .distance_type cimport DistanceType
+from pylibraft.common.handle cimport handle_t
+
+
+def is_c_cont(cai, dt):
+    return "strides" not in cai or \
+        cai["strides"] is None or \
+        cai["strides"][1] == dt.itemsize
+
+
+cdef extern from "raft_distance/fused_l2_min_arg.hpp" \
+        namespace "raft::distance::runtime":
+
+    void fused_l2_nn_min_arg(
+        const handle_t &handle,
+        int* min,
+        const float* x,
+        const float* y,
+        int m,
+        int n,
+        int k,
+        bool sqrt)
+
+    void fused_l2_nn_min_arg(
+        const handle_t &handle,
+        int* min,
+        const double* x,
+        const double* y,
+        int m,
+        int n,
+        int k,
+        bool sqrt)
+
+
+def fused_l2_nn_argmin(X, Y, output, sqrt=True):
+    """
+    Compute the 1-nearest neighbors between X and Y using the L2 distance
+
+    Parameters
+    ----------
+
+    X : CUDA array interface compliant matrix shape (m, k)
+    Y : CUDA array interface compliant matrix shape (n, k)
+    output : Writable CUDA array interface matrix shape (m, 1)
+    sqrt : bool, passed through to the C++ ``fused_l2_nn_min_arg`` wrapper
+           (default True)
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        import cupy as cp
+
+        from pylibraft.distance import fused_l2_nn_argmin
+
+        n_samples = 5000
+        n_clusters = 5
+        n_features = 50
+
+        in1 = cp.random.random_sample((n_samples, n_features),
+                                      dtype=cp.float32)
+        in2 = cp.random.random_sample((n_clusters, n_features),
+                                      dtype=cp.float32)
+        output = cp.empty((n_samples, 1), dtype=cp.int32)
+
+        fused_l2_nn_argmin(in1, in2, output)
+    """
+
+    x_cai = X.__cuda_array_interface__
+    y_cai = Y.__cuda_array_interface__
+    output_cai = output.__cuda_array_interface__
+
+    m = x_cai["shape"][0]
+    n = y_cai["shape"][0]
+
+    x_k = x_cai["shape"][1]
+    y_k = y_cai["shape"][1]
+
+    if x_k != y_k:
+        raise ValueError("Inputs must have same number of columns. "
+                         "a=%s, b=%s" % (x_k, y_k))
+
+    x_ptr = x_cai["data"][0]
+    y_ptr = y_cai["data"][0]
+
+    d_ptr = output_cai["data"][0]
+
+    x_dt = np.dtype(x_cai["typestr"])
+    y_dt = np.dtype(y_cai["typestr"])
+    d_dt = np.dtype(output_cai["typestr"])
+
+    x_c_contiguous = is_c_cont(x_cai, x_dt)
+    y_c_contiguous = is_c_cont(y_cai, y_dt)
+
+    if x_c_contiguous != y_c_contiguous:
+        raise ValueError("Inputs must have matching strides")
+
+    if x_dt != y_dt:
+        raise ValueError("Inputs must have the same dtypes")
+    if d_dt != np.int32:
+        raise ValueError("Output array must be int32")
+
+    # Allocate the handle only after validation so the error paths above
+    # cannot leak it; the try/finally releases it on every other path.
+    cdef handle_t *h = new handle_t()
+    try:
+        if x_dt == np.float32:
+            fused_l2_nn_min_arg(deref(h),
+                                d_ptr,
+                                x_ptr,
+                                y_ptr,
+                                m,
+                                n,
+                                x_k,
+                                sqrt)
+        elif x_dt == np.float64:
+            fused_l2_nn_min_arg(deref(h),
+                                d_ptr,
+                                x_ptr,
+                                y_ptr,
+                                m,
+                                n,
+                                x_k,
+                                sqrt)
+        else:
+            raise ValueError("dtype %s not supported" % x_dt)
+    finally:
+        del h
diff --git a/python/pylibraft/pylibraft/test/test_fused_l2_argmin.py b/python/pylibraft/pylibraft/test/test_fused_l2_argmin.py
new file mode 100644
index 0000000000..b12cc30472
--- /dev/null
+++ b/python/pylibraft/pylibraft/test/test_fused_l2_argmin.py
@@ -0,0 +1,47 @@
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from scipy.spatial.distance import cdist
+import pytest
+import numpy as np
+
+from pylibraft.distance import fused_l2_nn_argmin
+from pylibraft.testing.utils import TestDeviceBuffer
+
+
+@pytest.mark.parametrize("n_rows", [10, 100])
+@pytest.mark.parametrize("n_clusters", [5, 10])
+@pytest.mark.parametrize("n_cols", [3, 5])
+@pytest.mark.parametrize("dtype", [np.float32, np.float64])
+def test_fused_l2_nn_minarg(n_rows, n_cols, n_clusters, dtype):
+    input1 = np.random.random_sample((n_rows, n_cols))
+    input1 = np.asarray(input1, order="C").astype(dtype)
+
+    input2 = np.random.random_sample((n_clusters, n_cols))
+    input2 = np.asarray(input2, order="C").astype(dtype)
+
+    output = np.zeros((n_rows), dtype="int32")
+    expected = cdist(input1, input2, metric="euclidean")
+
+    expected = expected.argmin(axis=1)
+
+    input1_device = TestDeviceBuffer(input1, "C")
+    input2_device = TestDeviceBuffer(input2, "C")
+    output_device = TestDeviceBuffer(output, "C")
+
+    fused_l2_nn_argmin(input1_device, input2_device, output_device, True)
+    actual = output_device.copy_to_host()
+
+    assert np.allclose(expected, actual, rtol=1e-4)
diff --git a/python/pylibraft/pylibraft/testing/utils.py b/python/pylibraft/pylibraft/testing/utils.py
index 53115e991c..979fbb5672 100644
--- a/python/pylibraft/pylibraft/testing/utils.py
+++ b/python/pylibraft/pylibraft/testing/utils.py
@@ -21,6 +21,7 @@
 class TestDeviceBuffer:
 
     def __init__(self, ndarray, order):
+        self.ndarray_ = ndarray
         self.device_buffer_ = \
             rmm.DeviceBuffer.to_device(ndarray.ravel(order=order).tobytes())