From cca81c4970233b8f8dc14f8ab7bc952cb703cad2 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 13 Feb 2023 08:38:49 -0800 Subject: [PATCH] Add innerproduct to the pairwise distance api (#1226) Authors: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1226 --- cpp/CMakeLists.txt | 2 + cpp/bench/distance/distance_common.cuh | 4 +- cpp/include/raft/distance/detail/distance.cuh | 202 ++++++++++++------ .../distance/detail/kernels/gram_matrix.cuh | 6 +- .../detail/kernels/kernel_matrices.cuh | 6 +- cpp/include/raft/distance/distance.cuh | 79 +++---- .../specializations/detail/canberra.cuh | 5 +- .../specializations/detail/chebyshev.cuh | 4 +- .../specializations/detail/correlation.cuh | 5 +- .../specializations/detail/cosine.cuh | 5 +- .../detail/hamming_unexpanded.cuh | 5 +- .../detail/hellinger_expanded.cuh | 4 +- .../specializations/detail/inner_product.cuh | 52 +++++ .../specializations/detail/jensen_shannon.cuh | 25 +-- .../specializations/detail/kl_divergence.cuh | 4 +- .../distance/specializations/detail/l1.cuh | 4 +- .../specializations/detail/l2_expanded.cuh | 25 +-- .../detail/l2_sqrt_expanded.cuh | 4 +- .../detail/l2_sqrt_unexpanded.cuh | 4 +- .../specializations/detail/l2_unexpanded.cuh | 4 +- .../specializations/detail/lp_unexpanded.cuh | 4 +- .../specializations/detail/russel_rao.cuh | 4 +- .../distance/specializations/distance.cuh | 3 +- cpp/include/raft/linalg/detail/gemm.hpp | 17 +- cpp/include/raft/linalg/gemm.cuh | 8 +- .../canberra_double_double_double_int.cu | 16 +- .../detail/canberra_float_float_float_int.cu | 2 +- .../chebyshev_double_double_double_int.cu | 2 +- .../detail/chebyshev_float_float_float_int.cu | 2 +- .../correlation_double_double_double_int.cu | 2 +- .../correlation_float_float_float_int.cu | 2 +- .../detail/cosine_double_double_double_int.cu | 2 +- .../detail/cosine_float_float_float_int.cu | 3 +- ...ing_unexpanded_double_double_double_int.cu | 2 +- ...amming_unexpanded_float_float_float_int.cu | 2 +- ...inger_expanded_double_double_double_int.cu | 2 +- ...ellinger_expanded_float_float_float_int.cu | 3 +- .../inner_product_double_double_double_int.cu | 37 ++++ .../inner_product_float_float_float_int.cu | 36 ++++ ...jensen_shannon_double_double_double_int.cu | 4 +- .../jensen_shannon_float_float_float_int.cu | 17 +- .../kl_divergence_double_double_double_int.cu | 4 +- .../kl_divergence_float_float_float_int.cu | 4 +- .../detail/l1_double_double_double_int.cu | 4 +- .../detail/l1_float_float_float_int.cu | 4 +- .../l2_expanded_double_double_double_int.cu | 4 +- .../l2_expanded_float_float_float_int.cu | 4 +- ..._sqrt_expanded_double_double_double_int.cu | 4 +- .../l2_sqrt_expanded_float_float_float_int.cu | 4 +- ...qrt_unexpanded_double_double_double_int.cu | 4 +- ...2_sqrt_unexpanded_float_float_float_int.cu | 4 +- .../l2_unexpanded_double_double_double_int.cu | 4 +- .../l2_unexpanded_float_float_float_int.cu | 4 +- .../lp_unexpanded_double_double_double_int.cu | 4 +- .../lp_unexpanded_float_float_float_int.cu | 4 +- .../russel_rao_double_double_double_int.cu | 4 +- .../russel_rao_float_float_float_int.cu | 4 +- cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_adj.cu | 4 +- cpp/test/distance/dist_inner_product.cu | 76 +++++++ cpp/test/distance/distance_base.cuh | 55 +++-- .../pylibraft/distance/pairwise_distance.pyx | 2 +- .../pylibraft/pylibraft/test/test_distance.py | 10 +- 63 files changed, 555 insertions(+), 276 deletions(-) create mode 100644 cpp/include/raft/distance/specializations/detail/inner_product.cuh create mode 100644 cpp/src/distance/distance/specializations/detail/inner_product_double_double_double_int.cu create mode 100644 cpp/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu create mode 100644 cpp/test/distance/dist_inner_product.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6ec8dada43..350306f523 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -321,6 +321,8 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu + src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu + src/distance/distance/specializations/detail/inner_product_double_double_double_int.cu src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu src/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu diff --git a/cpp/bench/distance/distance_common.cuh b/cpp/bench/distance/distance_common.cuh index 7d51bfa608..906271bf5a 100644 --- a/cpp/bench/distance/distance_common.cuh +++ b/cpp/bench/distance/distance_common.cuh @@ -49,7 +49,8 @@ struct distance : public fixture { void run_benchmark(::benchmark::State& state) override { loop_on_state(state, [this]() { - raft::distance::distance(x.data(), + raft::distance::distance(handle, + x.data(), y.data(), out.data(), params.m, @@ -57,7 +58,6 @@ struct distance : public fixture { params.k, (void*)workspace.data(), worksize, - stream, params.isRowMajor); }); } diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index b459c73bee..f9a85ea627 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -30,6 +31,7 @@ #include #include #include +#include #include #include @@ -92,7 +94,8 @@ template struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -101,7 +104,6 @@ struct DistanceImpl { void* workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType metric_arg = 2.0f) { @@ -119,7 +121,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -128,12 +131,22 @@ struct DistanceImpl( - m, n, k, x, y, dist, false, (AccType*)workspace, worksize, fin_op, stream, isRowMajor); + m, + n, + k, + x, + y, + dist, + false, + (AccType*)workspace, + worksize, + fin_op, + raft::resource::get_cuda_stream(handle), + isRowMajor); } }; @@ -148,7 +161,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -157,12 +171,22 @@ struct DistanceImpl( - m, n, k, x, y, dist, true, (AccType*)workspace, worksize, fin_op, stream, isRowMajor); + m, + n, + k, + x, + y, + dist, + true, + (AccType*)workspace, + worksize, + fin_op, + raft::resource::get_cuda_stream(handle), + isRowMajor); } }; @@ -177,7 +201,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -186,12 +211,60 @@ struct DistanceImpl( - m, n, k, x, y, dist, (AccType*)workspace, worksize, fin_op, stream, isRowMajor); + m, + n, + k, + x, + y, + dist, + (AccType*)workspace, + worksize, + fin_op, + raft::resource::get_cuda_stream(handle), + isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(raft::resources const& handle, + const InType* x, + const InType* y, + OutType* dist, + Index_ m, + Index_ n, + Index_ k, + void* workspace, + size_t worksize, + FinalLambda fin_op, + bool isRowMajor, + InType) + { + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + raft::linalg::gemm(handle, + dist, + const_cast(x), + const_cast(y), + m, + n, + k, + !isRowMajor, + !isRowMajor, + isRowMajor, + stream); } }; @@ -206,7 +279,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -215,12 +289,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, false, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, false, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -235,7 +308,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -244,12 +318,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, true, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, true, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -264,7 +337,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -273,12 +347,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -293,7 +366,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -302,12 +376,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -322,7 +395,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -331,12 +405,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -351,7 +424,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -360,12 +434,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor, metric_arg); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor, metric_arg); } }; @@ -380,7 +453,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -389,12 +463,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -409,7 +482,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -418,12 +492,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -438,7 +511,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -447,12 +521,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -467,7 +540,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -476,12 +550,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -496,7 +569,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -505,12 +579,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -525,7 +598,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -534,12 +608,21 @@ struct DistanceImpl( - m, n, k, x, y, dist, (AccType*)workspace, worksize, fin_op, stream, isRowMajor); + m, + n, + k, + x, + y, + dist, + (AccType*)workspace, + worksize, + fin_op, + raft::resource::get_cuda_stream(handle), + isRowMajor); } }; @@ -562,7 +645,6 @@ struct DistanceImpl -void distance(const InType* x, +void distance(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -585,12 +668,11 @@ void distance(const InType* x, void* workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor = true, InType metric_arg = 2.0f) { DistanceImpl distImpl; - distImpl.run(x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg); + distImpl.run(handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -609,7 +691,6 @@ void distance(const InType* x, * @param k dimensionality * @param workspace temporary workspace needed for computations * @param worksize number of bytes of the workspace - * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major * * @note if workspace is passed as nullptr, this will return in @@ -633,7 +714,8 @@ template -void distance(const InType* x, +void distance(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -641,7 +723,6 @@ void distance(const InType* x, Index_ k, void* workspace, size_t worksize, - cudaStream_t stream, bool isRowMajor = true, InType metric_arg = 2.0f) { @@ -653,7 +734,7 @@ void distance(const InType* x, "OutType can be uint8_t, float, double," "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); distance( - x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -710,25 +791,24 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In * @param workspace temporary workspace buffer which can get resized as per the * needed workspace size * @param metric distance metric - * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major */ template -void pairwise_distance_impl(const Type* x, +void pairwise_distance_impl(raft::resources const& handle, + const Type* x, const Type* y, Type* dist, Index_ m, Index_ n, Index_ k, rmm::device_uvector& workspace, - cudaStream_t stream, bool isRowMajor, Type metric_arg = 2.0f) { auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); + workspace.resize(worksize, raft::resource::get_cuda_stream(handle)); distance( - x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); } /** @} */ }; // namespace detail diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh index 344dda693e..aaf3052892 100644 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -212,7 +212,7 @@ class GramMatrixBase { int ld_out) { raft::distance::distance( - x1, x2, out, n1, n2, n_cols, stream, is_row_major); + raft::device_resources(stream), x1, x2, out, n1, n2, n_cols, is_row_major); } }; -}; // end namespace raft::distance::kernels::detail \ No newline at end of file +}; // end namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh index b74de84d80..d1465efdb0 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -359,7 +359,8 @@ class RBFKernel : public GramMatrixBase { math_t, math_t, decltype(fin_op), - index_t>(const_cast(x1), + index_t>(device_resources(stream), + const_cast(x1), const_cast(x2), out, n1, @@ -368,7 +369,6 @@ class RBFKernel : public GramMatrixBase { NULL, 0, fin_op, - stream, is_row_major); } }; diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 93a5ce7f1a..5fd5ebe98d 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -18,7 +18,8 @@ #pragma once -#include +#include +#include #include #include #include @@ -50,7 +51,6 @@ namespace distance { * @param workspace temporary workspace needed for computations * @param worksize number of bytes of the workspace * @param fin_op the final gemm epilogue lambda - * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) * @@ -65,7 +65,8 @@ template -void distance(const InType* x, +void distance(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -74,12 +75,11 @@ void distance(const InType* x, void* workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor = true, InType metric_arg = 2.0f) { detail::distance( - x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); } /** @@ -97,7 +97,6 @@ void distance(const InType* x, * @param k dimensionality * @param workspace temporary workspace needed for computations * @param worksize number of bytes of the workspace - * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) * @@ -109,7 +108,8 @@ template -void distance(const InType* x, +void distance(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -117,12 +117,11 @@ void distance(const InType* x, Index_ k, void* workspace, size_t worksize, - cudaStream_t stream, bool isRowMajor = true, InType metric_arg = 2.0f) { detail::distance( - x, y, dist, m, n, k, workspace, worksize, stream, isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); } /** @@ -193,7 +192,6 @@ size_t getWorkspaceSize(const raft::device_matrix_view x, * @param m number of points in x * @param n number of points in y * @param k dimensionality - * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) */ @@ -202,21 +200,22 @@ template -void distance(const InType* x, +void distance(raft::resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, Index_ n, Index_ k, - cudaStream_t stream, bool isRowMajor = true, InType metric_arg = 2.0f) { + auto stream = raft::resource::get_cuda_stream(handle); rmm::device_uvector workspace(0, stream); auto worksize = getWorkspaceSize(x, y, m, n, k); workspace.resize(worksize, stream); detail::distance( - x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); } /** @@ -238,7 +237,7 @@ void distance(const InType* x, * @param metric_arg metric argument (used for Minkowski distance) */ template -void pairwise_distance(raft::device_resources const& handle, +void pairwise_distance(raft::resources const& handle, const Type* x, const Type* y, Type* dist, @@ -253,64 +252,68 @@ void pairwise_distance(raft::device_resources const& handle, switch (metric) { case raft::distance::DistanceType::L2Expanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::L2SqrtExpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::CosineExpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::L1: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::L2Unexpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::L2SqrtUnexpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::Linf: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::HellingerExpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::LpUnexpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, workspace, isRowMajor, metric_arg); break; case raft::distance::DistanceType::Canberra: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::HammingUnexpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::JensenShannon: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::RusselRaoExpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::KLDivergence: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::CorrelationExpanded: detail:: pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); + break; + case raft::distance::DistanceType::InnerProduct: + detail::pairwise_distance_impl( + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; @@ -333,7 +336,7 @@ void pairwise_distance(raft::device_resources const& handle, * @param metric_arg metric argument (used for Minkowski distance) */ template -void pairwise_distance(raft::device_resources const& handle, +void pairwise_distance(raft::resources const& handle, const Type* x, const Type* y, Type* dist, @@ -344,7 +347,8 @@ void pairwise_distance(raft::device_resources const& handle, bool isRowMajor = true, Type metric_arg = 2.0f) { - rmm::device_uvector workspace(0, handle.get_stream()); + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); pairwise_distance( handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); } @@ -398,7 +402,7 @@ template -void distance(raft::device_resources const& handle, +void distance(raft::resources const& handle, raft::device_matrix_view const x, raft::device_matrix_view const y, raft::device_matrix_view dist, @@ -417,13 +421,13 @@ void distance(raft::device_resources const& handle, constexpr auto is_rowmajor = std::is_same_v; - distance(x.data_handle(), + distance(handle, + x.data_handle(), y.data_handle(), dist.data_handle(), x.extent(0), y.extent(0), x.extent(1), - handle.get_stream(), is_rowmajor, metric_arg); } @@ -441,7 +445,7 @@ void distance(raft::device_resources const& handle, * @param metric_arg metric argument (used for Minkowski distance) */ template -void pairwise_distance(raft::device_resources const& handle, +void pairwise_distance(raft::resources const& handle, device_matrix_view const x, device_matrix_view const y, device_matrix_view dist, @@ -462,7 +466,8 @@ void pairwise_distance(raft::device_resources const& handle, constexpr auto rowmajor = std::is_same_v; - rmm::device_uvector workspace(0, handle.get_stream()); + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); pairwise_distance(handle, x.data_handle(), @@ -481,4 +486,4 @@ void pairwise_distance(raft::device_resources const& handle, }; // namespace distance }; // namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/distance/specializations/detail/canberra.cuh b/cpp/include/raft/distance/specializations/detail/canberra.cuh index 6b6364fb58..badce715a5 100644 --- a/cpp/include/raft/distance/specializations/detail/canberra.cuh +++ b/cpp/include/raft/distance/specializations/detail/canberra.cuh @@ -23,6 +23,7 @@ namespace raft { namespace distance { namespace detail { extern template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -31,11 +32,11 @@ extern template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -44,10 +45,8 @@ extern template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -30,11 +31,11 @@ extern template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -43,7 +44,6 @@ extern template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -45,10 +46,8 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance int k, void* workspace, size_t worksize, - cudaStream_t stream, bool isRowMajor, float metric_arg); extern template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -45,10 +46,8 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -45,10 +46,8 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -45,7 +46,6 @@ distance + +namespace raft { +namespace distance { +namespace detail { +extern template void distance( + raft::resources const& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + void* workspace, + size_t worksize, + bool isRowMajor, + float metric_arg); + +extern template void +distance( + raft::resources const& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + void* workspace, + size_t worksize, + bool isRowMajor, + double metric_arg); +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh index 4e86417840..145834fb70 100644 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh @@ -22,20 +22,22 @@ namespace raft { namespace distance { namespace detail { extern template void -distance(const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - cudaStream_t stream, - bool isRowMajor, - float metric_arg); +distance( + raft::resources const& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + void* workspace, + size_t worksize, + bool isRowMajor, + float metric_arg); extern template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -44,7 +46,6 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -30,12 +31,12 @@ extern template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -44,7 +45,6 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -30,11 +31,11 @@ extern template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -43,7 +44,6 @@ extern template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -30,22 +31,22 @@ extern template void distance(const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - cudaStream_t stream, - bool isRowMajor, - double metric_arg); +distance( + raft::resources const& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + void* workspace, + size_t worksize, + bool isRowMajor, + double metric_arg); } // namespace detail } // namespace distance diff --git a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh index 65e48e8401..9f5f6a3706 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh @@ -23,6 +23,7 @@ namespace distance { namespace detail { extern template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance int k, void* workspace, size_t worksize, - cudaStream_t stream, bool isRowMajor, float metric_arg); extern template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -45,7 +46,6 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -45,7 +46,6 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -30,12 +31,12 @@ extern template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -44,7 +45,6 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -30,12 +31,12 @@ extern template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -44,7 +45,6 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -45,7 +46,6 @@ distance #include #include +#include #include #include #include diff --git a/cpp/include/raft/linalg/detail/gemm.hpp b/cpp/include/raft/linalg/detail/gemm.hpp index ba9496c3b9..d82c821148 100644 --- a/cpp/include/raft/linalg/detail/gemm.hpp +++ b/cpp/include/raft/linalg/detail/gemm.hpp @@ -20,7 +20,8 @@ #include "cublas_wrappers.hpp" -#include +#include +#include namespace raft { namespace linalg { @@ -49,7 +50,7 @@ namespace detail { * @param [in] stream */ template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, const bool trans_a, const bool trans_b, const int m, @@ -65,7 +66,7 @@ void gemm(raft::device_resources const& handle, const int ldc, cudaStream_t stream) { - cublasHandle_t cublas_h = handle.get_cublas_handle(); + auto cublas_h = raft::resource::get_cublas_handle(handle); cublas_device_pointer_mode pmode(cublas_h); RAFT_CUBLAS_TRY(cublasgemm(cublas_h, trans_a ? CUBLAS_OP_T : CUBLAS_OP_N, @@ -103,7 +104,7 @@ void gemm(raft::device_resources const& handle, * @param stream cuda stream */ template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -117,7 +118,7 @@ void gemm(raft::device_resources const& handle, math_t beta, cudaStream_t stream) { - cublasHandle_t cublas_h = handle.get_cublas_handle(); + auto cublas_h = raft::resource::get_cublas_handle(handle); int m = n_rows_c; int n = n_cols_c; @@ -130,7 +131,7 @@ void gemm(raft::device_resources const& handle, } template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -149,7 +150,7 @@ void gemm(raft::device_resources const& handle, } template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, T* z, T* x, T* y, @@ -163,7 +164,7 @@ void gemm(raft::device_resources const& handle, T* alpha, T* beta) { - cublasHandle_t cublas_h = handle.get_cublas_handle(); + auto cublas_h = raft::resource::get_cublas_handle(handle); cublas_device_pointer_mode pmode(cublas_h); cublasOperation_t trans_a, trans_b; diff --git a/cpp/include/raft/linalg/gemm.cuh b/cpp/include/raft/linalg/gemm.cuh index d5dc5ffab5..a336f844bf 100644 --- a/cpp/include/raft/linalg/gemm.cuh +++ b/cpp/include/raft/linalg/gemm.cuh @@ -52,7 +52,7 @@ namespace linalg { * @param [in] stream */ template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, const bool trans_a, const bool trans_b, const int m, @@ -91,7 +91,7 @@ void gemm(raft::device_resources const& handle, * @param stream cuda stream */ template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -126,7 +126,7 @@ void gemm(raft::device_resources const& handle, * @param stream cuda stream */ template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -161,7 +161,7 @@ void gemm(raft::device_resources const& handle, * @param beta scalar */ template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, T* z, T* x, T* y, diff --git a/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu index a00861f421..d3a629f5a0 100644 --- a/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu @@ -19,20 +19,8 @@ namespace raft { namespace distance { namespace detail { -template void distance( - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - cudaStream_t stream, - bool isRowMajor, - float metric_arg); - template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -41,10 +29,8 @@ template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -29,7 +30,6 @@ template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -30,7 +31,6 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -29,7 +30,6 @@ template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +30,6 @@ template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -30,7 +31,6 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -30,7 +31,6 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +30,6 @@ template void distance + +namespace raft { +namespace distance { +namespace detail { +template void distance( + raft::resources const& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + void* workspace, + std::size_t worksize, + bool isRowMajor, + double metric_arg); + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu new file mode 100644 index 0000000000..f147c51b52 --- /dev/null +++ b/cpp/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft { +namespace distance { +namespace detail { +template void distance( + raft::resources const& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + void* workspace, + std::size_t worksize, + bool isRowMajor, + float metric_arg); +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu index 30fbf70322..c490033dc3 100644 --- a/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,23 +29,9 @@ template void distance( - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - cudaStream_t stream, - bool isRowMajor, - double metric_arg); - } // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu index 0ac2b23b29..ed2a29bccf 100644 --- a/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -29,7 +30,6 @@ template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -29,7 +30,6 @@ template void distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::resources const& handle, const double* x, const double* y, double* dist, @@ -29,7 +30,6 @@ distance( + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance(x.data(), + threshold_final_op_>(handle, + x.data(), y.data(), dist.data(), m, @@ -141,7 +142,6 @@ class DistanceAdjTest : public ::testing::TestWithParam +class DistanceInnerProduct + : public DistanceTest { +}; + +const std::vector> inputsf = { + {0.001f, 10, 5, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceInnerProduct DistanceInnerProductF; +TEST_P(DistanceInnerProductF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceInnerProductF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceInnerProduct DistanceInnerProductD; +TEST_P(DistanceInnerProductD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceInnerProductD, ::testing::ValuesIn(inputsd)); + +class BigMatrixInnerProduct + : public BigMatrixDistanceTest { +}; +TEST_F(BigMatrixInnerProduct, Result) {} + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index be7b2b1de8..278fde0c45 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -122,6 +122,28 @@ __global__ void naiveCosineDistanceKernel( dist[outidx] = (DataType)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); } +template +__global__ void naiveInnerProductKernel( + DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) +{ + int midx = threadIdx.x + blockIdx.x * blockDim.x; + int nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) { return; } + + DataType acc_ab = DataType(0); + + for (int i = 0; i < k; ++i) { + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + acc_ab += a * b; + } + + int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc_ab; +} + template __global__ void naiveHellingerDistanceKernel( DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) @@ -348,6 +370,9 @@ void naiveDistance(DataType* dist, naiveHammingDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; + case raft::distance::DistanceType::InnerProduct: + naiveInnerProductKernel<<>>(dist, x, y, m, n, k, isRowMajor); + break; case raft::distance::DistanceType::JensenShannon: naiveJensenShannonDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); @@ -385,7 +410,8 @@ template } template -void distanceLauncher(DataType* x, +void distanceLauncher(raft::device_resources const& handle, + DataType* x, DataType* y, DataType* dist, DataType* dist2, @@ -394,11 +420,8 @@ void distanceLauncher(DataType* x, int k, DistanceInputs& params, DataType threshold, - cudaStream_t stream, DataType metric_arg = 2.0f) { - raft::device_resources handle(stream); - auto x_v = make_device_matrix_view(x, m, k); auto y_v = make_device_matrix_view(y, n, k); auto dist_v = make_device_matrix_view(dist, m, n); @@ -454,7 +477,8 @@ class DistanceTest : public ::testing::TestWithParam> { DataType threshold = -10000.f; if (isRowMajor) { - distanceLauncher(x.data(), + distanceLauncher(handle, + x.data(), y.data(), dist.data(), dist2.data(), @@ -463,11 +487,11 @@ class DistanceTest : public ::testing::TestWithParam> { k, params, threshold, - stream, metric_arg); } else { - distanceLauncher(x.data(), + distanceLauncher(handle, + x.data(), y.data(), dist.data(), dist2.data(), @@ -476,7 +500,6 @@ class DistanceTest : public ::testing::TestWithParam> { k, params, threshold, - stream, metric_arg); } handle.sync_stream(stream); @@ -500,20 +523,8 @@ class BigMatrixDistanceTest : public ::testing::Test { auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); - size_t worksize = raft::distance::getWorkspaceSize( - x.data(), x.data(), m, n, k); - rmm::device_uvector workspace(worksize, handle.get_stream()); - raft::distance::distance(x.data(), - x.data(), - dist.data(), - m, - n, - k, - workspace.data(), - worksize, - handle.get_stream(), - true, - 0.0f); + raft::distance::distance( + handle, x.data(), x.data(), dist.data(), m, n, k, true, 0.0f); RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } diff --git a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx index 9649531b61..3037b9a725 100644 --- a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx +++ b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx @@ -79,7 +79,7 @@ DISTANCE_TYPES = { "kl_divergence": DistanceType.KLDivergence, "minkowski": DistanceType.LpUnexpanded, "russellrao": DistanceType.RusselRaoExpanded, - "dice": DistanceType.DiceExpanded + "dice": DistanceType.DiceExpanded, } SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product", diff --git a/python/pylibraft/pylibraft/test/test_distance.py b/python/pylibraft/pylibraft/test/test_distance.py index dd6050a098..2c0a842fe5 100644 --- a/python/pylibraft/pylibraft/test/test_distance.py +++ b/python/pylibraft/pylibraft/test/test_distance.py @@ -21,8 +21,8 @@ from pylibraft.distance import pairwise_distance -@pytest.mark.parametrize("n_rows", [100]) -@pytest.mark.parametrize("n_cols", [100]) +@pytest.mark.parametrize("n_rows", [32, 100]) +@pytest.mark.parametrize("n_cols", [40, 100]) @pytest.mark.parametrize( "metric", [ @@ -36,6 +36,7 @@ "russellrao", "cosine", "sqeuclidean", + "inner_product", ], ) @pytest.mark.parametrize("inplace", [True, False]) @@ -57,7 +58,10 @@ def test_distance(n_rows, n_cols, inplace, metric, order, dtype): output = np.zeros((n_rows, n_rows), dtype=dtype) - expected = cdist(input1, input1, metric) + if metric == "inner_product": + expected = np.matmul(input1, input1.T) + else: + expected = cdist(input1, input1, metric) expected[expected <= 1e-5] = 0.0