From c18564f929f4f04ed63f5bddd7841633a8057591 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 6 Feb 2023 14:21:17 -0800 Subject: [PATCH 1/7] Add InnerProduct to pairwise distances api This adds InnerProduct distance to the pairwise distances api, using cublass gemm to compute the distance. Since this requires a cublas handle, this also changes the distance api to take a raft::device_resources instead of just a cuda stream --- cpp/CMakeLists.txt | 1 + cpp/include/raft/distance/detail/distance.cuh | 180 ++++++++++++------ .../distance/detail/kernels/gram_matrix.cuh | 6 +- .../detail/kernels/kernel_matrices.cuh | 5 +- cpp/include/raft/distance/distance.cuh | 62 +++--- .../specializations/detail/canberra.cuh | 5 +- .../specializations/detail/chebyshev.cuh | 4 +- .../specializations/detail/correlation.cuh | 5 +- .../specializations/detail/cosine.cuh | 5 +- .../detail/hamming_unexpanded.cuh | 5 +- .../detail/hellinger_expanded.cuh | 4 +- .../specializations/detail/inner_product.cuh | 52 +++++ .../specializations/detail/jensen_shannon.cuh | 25 +-- .../specializations/detail/kl_divergence.cuh | 4 +- .../distance/specializations/detail/l1.cuh | 4 +- .../specializations/detail/l2_expanded.cuh | 25 +-- .../detail/l2_sqrt_expanded.cuh | 4 +- .../detail/l2_sqrt_unexpanded.cuh | 4 +- .../specializations/detail/l2_unexpanded.cuh | 4 +- .../specializations/detail/lp_unexpanded.cuh | 4 +- .../specializations/detail/russel_rao.cuh | 4 +- .../distance/specializations/distance.cuh | 3 +- .../canberra_double_double_double_int.cu | 16 +- .../detail/canberra_float_float_float_int.cu | 2 +- .../chebyshev_double_double_double_int.cu | 2 +- .../detail/chebyshev_float_float_float_int.cu | 2 +- .../correlation_double_double_double_int.cu | 2 +- .../correlation_float_float_float_int.cu | 2 +- .../detail/cosine_double_double_double_int.cu | 2 +- .../detail/cosine_float_float_float_int.cu | 3 +- ...ing_unexpanded_double_double_double_int.cu | 2 +- ...amming_unexpanded_float_float_float_int.cu | 2 +- ...inger_expanded_double_double_double_int.cu | 2 +- ...ellinger_expanded_float_float_float_int.cu | 3 +- .../specializations/detail/inner_product.cu | 50 +++++ ...jensen_shannon_double_double_double_int.cu | 4 +- .../jensen_shannon_float_float_float_int.cu | 17 +- .../kl_divergence_double_double_double_int.cu | 4 +- .../kl_divergence_float_float_float_int.cu | 4 +- .../detail/l1_double_double_double_int.cu | 4 +- .../detail/l1_float_float_float_int.cu | 4 +- .../l2_expanded_double_double_double_int.cu | 4 +- .../l2_expanded_float_float_float_int.cu | 4 +- ..._sqrt_expanded_double_double_double_int.cu | 4 +- .../l2_sqrt_expanded_float_float_float_int.cu | 4 +- ...qrt_unexpanded_double_double_double_int.cu | 4 +- ...2_sqrt_unexpanded_float_float_float_int.cu | 4 +- .../l2_unexpanded_double_double_double_int.cu | 4 +- .../l2_unexpanded_float_float_float_int.cu | 4 +- .../lp_unexpanded_double_double_double_int.cu | 4 +- .../lp_unexpanded_float_float_float_int.cu | 4 +- .../russel_rao_double_double_double_int.cu | 4 +- .../russel_rao_float_float_float_int.cu | 4 +- cpp/test/distance/dist_adj.cu | 4 +- cpp/test/distance/distance_base.cuh | 30 +-- python/pylibraft/CMakeLists.txt | 2 +- .../pylibraft/distance/pairwise_distance.pyx | 2 +- .../pylibraft/pylibraft/test/test_distance.py | 6 +- 58 files changed, 381 insertions(+), 253 deletions(-) create mode 100644 cpp/include/raft/distance/specializations/detail/inner_product.cuh create mode 100644 cpp/src/distance/distance/specializations/detail/inner_product.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index cf938a5d33..0f32b1bec2 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -307,6 +307,7 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu src/distance/cluster/kmeans_fit_float.cu src/distance/cluster/kmeans_fit_double.cu + src/distance/distance/specializations/detail/inner_product.cu src/distance/distance/specializations/detail/canberra_double_double_double_int.cu src/distance/distance/specializations/detail/canberra_float_float_float_int.cu src/distance/distance/specializations/detail/chebyshev_double_double_double_int.cu diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index b459c73bee..59e71d2b2f 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -30,6 +30,7 @@ #include #include #include +#include #include #include @@ -92,7 +93,8 @@ template struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -101,7 +103,6 @@ struct DistanceImpl { void* workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType metric_arg = 2.0f) { @@ -119,7 +120,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -128,12 +130,22 @@ struct DistanceImpl( - m, n, k, x, y, dist, false, (AccType*)workspace, worksize, fin_op, stream, isRowMajor); + m, + n, + k, + x, + y, + dist, + false, + (AccType*)workspace, + worksize, + fin_op, + handle.get_stream(), + isRowMajor); } }; @@ -148,7 +160,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -157,12 +170,22 @@ struct DistanceImpl( - m, n, k, x, y, dist, true, (AccType*)workspace, worksize, fin_op, stream, isRowMajor); + m, + n, + k, + x, + y, + dist, + true, + (AccType*)workspace, + worksize, + fin_op, + handle.get_stream(), + isRowMajor); } }; @@ -177,7 +200,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -186,12 +210,49 @@ struct DistanceImpl( - m, n, k, x, y, dist, (AccType*)workspace, worksize, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, (AccType*)workspace, worksize, fin_op, handle.get_stream(), isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(raft::device_resources const& handle, + const InType* x, + const InType* y, + OutType* dist, + Index_ m, + Index_ n, + Index_ k, + void* workspace, + size_t worksize, + FinalLambda fin_op, + bool isRowMajor, + InType) + { + raft::linalg::gemm(handle, + dist, + const_cast(x), + const_cast(y), + m, + n, + k, + isRowMajor, + !isRowMajor, // transpose + isRowMajor, + handle.get_stream()); } }; @@ -206,7 +267,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -215,12 +277,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, false, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, false, fin_op, handle.get_stream(), isRowMajor); } }; @@ -235,7 +296,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -244,12 +306,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, true, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, true, fin_op, handle.get_stream(), isRowMajor); } }; @@ -264,7 +325,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -273,12 +335,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); } }; @@ -293,7 +354,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -302,12 +364,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); } }; @@ -322,7 +383,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -331,12 +393,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); } }; @@ -351,7 +412,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -360,12 +422,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor, metric_arg); + m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor, metric_arg); } }; @@ -380,7 +441,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -389,12 +451,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); } }; @@ -409,7 +470,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -418,12 +480,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); } }; @@ -438,7 +499,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -447,12 +509,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); } }; @@ -467,7 +528,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -476,12 +538,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); } }; @@ -496,7 +557,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -505,12 +567,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); } }; @@ -525,7 +586,8 @@ struct DistanceImpl { - void run(const InType* x, + void run(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -534,12 +596,11 @@ struct DistanceImpl( - m, n, k, x, y, dist, (AccType*)workspace, worksize, fin_op, stream, isRowMajor); + m, n, k, x, y, dist, (AccType*)workspace, worksize, fin_op, handle.get_stream(), isRowMajor); } }; @@ -562,7 +623,6 @@ struct DistanceImpl -void distance(const InType* x, +void distance(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -585,12 +646,11 @@ void distance(const InType* x, void* workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor = true, InType metric_arg = 2.0f) { DistanceImpl distImpl; - distImpl.run(x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg); + distImpl.run(handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -609,7 +669,6 @@ void distance(const InType* x, * @param k dimensionality * @param workspace temporary workspace needed for computations * @param worksize number of bytes of the workspace - * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major * * @note if workspace is passed as nullptr, this will return in @@ -633,7 +692,8 @@ template -void distance(const InType* x, +void distance(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -641,7 +701,6 @@ void distance(const InType* x, Index_ k, void* workspace, size_t worksize, - cudaStream_t stream, bool isRowMajor = true, InType metric_arg = 2.0f) { @@ -653,7 +712,7 @@ void distance(const InType* x, "OutType can be uint8_t, float, double," "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); distance( - x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -710,25 +769,24 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In * @param workspace temporary workspace buffer which can get resized as per the * needed workspace size * @param metric distance metric - * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major */ template -void pairwise_distance_impl(const Type* x, +void pairwise_distance_impl(raft::device_resources const& handle, + const Type* x, const Type* y, Type* dist, Index_ m, Index_ n, Index_ k, rmm::device_uvector& workspace, - cudaStream_t stream, bool isRowMajor, Type metric_arg = 2.0f) { auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); + workspace.resize(worksize, handle.get_stream()); distance( - x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); } /** @} */ }; // namespace detail diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh index 344dda693e..aaf3052892 100644 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -212,7 +212,7 @@ class GramMatrixBase { int ld_out) { raft::distance::distance( - x1, x2, out, n1, n2, n_cols, stream, is_row_major); + raft::device_resources(stream), x1, x2, out, n1, n2, n_cols, is_row_major); } }; -}; // end namespace raft::distance::kernels::detail \ No newline at end of file +}; // end namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh index b74de84d80..3d37b37293 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -359,7 +359,8 @@ class RBFKernel : public GramMatrixBase { math_t, math_t, decltype(fin_op), - index_t>(const_cast(x1), + index_t>(device_resources(stream), + const_cast(x1), const_cast(x2), out, n1, diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 93a5ce7f1a..623f788d79 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -50,7 +50,6 @@ namespace distance { * @param workspace temporary workspace needed for computations * @param worksize number of bytes of the workspace * @param fin_op the final gemm epilogue lambda - * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) * @@ -65,7 +64,8 @@ template -void distance(const InType* x, +void distance(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -74,12 +74,11 @@ void distance(const InType* x, void* workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor = true, InType metric_arg = 2.0f) { detail::distance( - x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); } /** @@ -97,7 +96,6 @@ void distance(const InType* x, * @param k dimensionality * @param workspace temporary workspace needed for computations * @param worksize number of bytes of the workspace - * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) * @@ -109,7 +107,8 @@ template -void distance(const InType* x, +void distance(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, @@ -117,12 +116,11 @@ void distance(const InType* x, Index_ k, void* workspace, size_t worksize, - cudaStream_t stream, bool isRowMajor = true, InType metric_arg = 2.0f) { detail::distance( - x, y, dist, m, n, k, workspace, worksize, stream, isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); } /** @@ -193,7 +191,6 @@ size_t getWorkspaceSize(const raft::device_matrix_view x, * @param m number of points in x * @param n number of points in y * @param k dimensionality - * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) */ @@ -202,21 +199,22 @@ template -void distance(const InType* x, +void distance(raft::device_resources const& handle, + const InType* x, const InType* y, OutType* dist, Index_ m, Index_ n, Index_ k, - cudaStream_t stream, bool isRowMajor = true, InType metric_arg = 2.0f) { + auto stream = handle.get_stream(); rmm::device_uvector workspace(0, stream); auto worksize = getWorkspaceSize(x, y, m, n, k); workspace.resize(worksize, stream); detail::distance( - x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); } /** @@ -253,64 +251,68 @@ void pairwise_distance(raft::device_resources const& handle, switch (metric) { case raft::distance::DistanceType::L2Expanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::L2SqrtExpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::CosineExpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::L1: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::L2Unexpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::L2SqrtUnexpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::Linf: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::HellingerExpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::LpUnexpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, metric_arg); + handle, x, y, dist, m, n, k, workspace, isRowMajor, metric_arg); break; case raft::distance::DistanceType::Canberra: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::HammingUnexpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::JensenShannon: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::RusselRaoExpanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::KLDivergence: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; case raft::distance::DistanceType::CorrelationExpanded: detail:: pairwise_distance_impl( - x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + handle, x, y, dist, m, n, k, workspace, isRowMajor); + break; + case raft::distance::DistanceType::InnerProduct: + detail::pairwise_distance_impl( + handle, x, y, dist, m, n, k, workspace, isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; @@ -417,13 +419,13 @@ void distance(raft::device_resources const& handle, constexpr auto is_rowmajor = std::is_same_v; - distance(x.data_handle(), + distance(handle, + x.data_handle(), y.data_handle(), dist.data_handle(), x.extent(0), y.extent(0), x.extent(1), - handle.get_stream(), is_rowmajor, metric_arg); } @@ -481,4 +483,4 @@ void pairwise_distance(raft::device_resources const& handle, }; // namespace distance }; // namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/distance/specializations/detail/canberra.cuh b/cpp/include/raft/distance/specializations/detail/canberra.cuh index 6b6364fb58..178c9047a6 100644 --- a/cpp/include/raft/distance/specializations/detail/canberra.cuh +++ b/cpp/include/raft/distance/specializations/detail/canberra.cuh @@ -23,6 +23,7 @@ namespace raft { namespace distance { namespace detail { extern template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -31,11 +32,11 @@ extern template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -44,10 +45,8 @@ extern template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -30,11 +31,11 @@ extern template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -43,7 +44,6 @@ extern template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -45,10 +46,8 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance int k, void* workspace, size_t worksize, - cudaStream_t stream, bool isRowMajor, float metric_arg); extern template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -45,10 +46,8 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -45,10 +46,8 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -45,7 +46,6 @@ distance + +namespace raft { +namespace distance { +namespace detail { +extern template void distance( + raft::device_resources const& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + void* workspace, + size_t worksize, + bool isRowMajor, + float metric_arg); + +extern template void +distance( + raft::device_resources const& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + void* workspace, + size_t worksize, + bool isRowMajor, + double metric_arg); +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh index 4e86417840..7723cce7d4 100644 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh @@ -22,20 +22,22 @@ namespace raft { namespace distance { namespace detail { extern template void -distance(const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - cudaStream_t stream, - bool isRowMajor, - float metric_arg); +distance( + raft::device_resources const& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + void* workspace, + size_t worksize, + bool isRowMajor, + float metric_arg); extern template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -44,7 +46,6 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -30,12 +31,12 @@ extern template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -44,7 +45,6 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -30,11 +31,11 @@ extern template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -43,7 +44,6 @@ extern template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -30,22 +31,22 @@ extern template void distance(const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - cudaStream_t stream, - bool isRowMajor, - double metric_arg); +distance( + raft::device_resources const& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + void* workspace, + size_t worksize, + bool isRowMajor, + double metric_arg); } // namespace detail } // namespace distance diff --git a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh index 65e48e8401..3bd65bb769 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh @@ -23,6 +23,7 @@ namespace distance { namespace detail { extern template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance int k, void* workspace, size_t worksize, - cudaStream_t stream, bool isRowMajor, float metric_arg); extern template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -45,7 +46,6 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -45,7 +46,6 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -30,12 +31,12 @@ extern template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -44,7 +45,6 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -30,12 +31,12 @@ extern template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -44,7 +45,6 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -31,12 +32,12 @@ distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -45,7 +46,6 @@ distance #include #include +#include #include #include #include diff --git a/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu index a00861f421..2674758992 100644 --- a/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu @@ -19,20 +19,8 @@ namespace raft { namespace distance { namespace detail { -template void distance( - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - cudaStream_t stream, - bool isRowMajor, - float metric_arg); - template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -41,10 +29,8 @@ template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -29,7 +30,6 @@ template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -30,7 +31,6 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -29,7 +30,6 @@ template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +30,6 @@ template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -30,7 +31,6 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -30,7 +31,6 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +30,6 @@ template void distance + +namespace raft { +namespace distance { +namespace detail { +template void distance( + raft::device_resources const& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + void* workspace, + std::size_t worksize, + bool isRowMajor, + float metric_arg); + +template void distance( + raft::device_resources const& handle, + const double* x, + const double* y, + double* dist, + int m, + int n, + int k, + void* workspace, + std::size_t worksize, + bool isRowMajor, + double metric_arg); + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu index 30fbf70322..74aa93c4af 100644 --- a/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,23 +29,9 @@ template void distance( - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - cudaStream_t stream, - bool isRowMajor, - double metric_arg); - } // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu index 0ac2b23b29..7e0f802949 100644 --- a/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -29,7 +30,6 @@ template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -29,7 +30,6 @@ template void distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance( + raft::device_resources const& handle, const double* x, const double* y, double* dist, @@ -29,7 +30,6 @@ distance( + raft::device_resources const& handle, const float* x, const float* y, float* dist, @@ -28,7 +29,6 @@ template void distance(x.data(), + threshold_final_op_>(handle, + x.data(), y.data(), dist.data(), m, @@ -141,7 +142,6 @@ class DistanceAdjTest : public ::testing::TestWithParam } template -void distanceLauncher(DataType* x, +void distanceLauncher(raft::device_resources const& handle, + DataType* x, DataType* y, DataType* dist, DataType* dist2, @@ -394,11 +395,8 @@ void distanceLauncher(DataType* x, int k, DistanceInputs& params, DataType threshold, - cudaStream_t stream, DataType metric_arg = 2.0f) { - raft::device_resources handle(stream); - auto x_v = make_device_matrix_view(x, m, k); auto y_v = make_device_matrix_view(y, n, k); auto dist_v = make_device_matrix_view(dist, m, n); @@ -454,7 +452,8 @@ class DistanceTest : public ::testing::TestWithParam> { DataType threshold = -10000.f; if (isRowMajor) { - distanceLauncher(x.data(), + distanceLauncher(handle, + x.data(), y.data(), dist.data(), dist2.data(), @@ -463,11 +462,11 @@ class DistanceTest : public ::testing::TestWithParam> { k, params, threshold, - stream, metric_arg); } else { - distanceLauncher(x.data(), + distanceLauncher(handle, + x.data(), y.data(), dist.data(), dist2.data(), @@ -476,7 +475,6 @@ class DistanceTest : public ::testing::TestWithParam> { k, params, threshold, - stream, metric_arg); } handle.sync_stream(stream); @@ -500,20 +498,8 @@ class BigMatrixDistanceTest : public ::testing::Test { auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); - size_t worksize = raft::distance::getWorkspaceSize( - x.data(), x.data(), m, n, k); - rmm::device_uvector workspace(worksize, handle.get_stream()); - raft::distance::distance(x.data(), - x.data(), - dist.data(), - m, - n, - k, - workspace.data(), - worksize, - handle.get_stream(), - true, - 0.0f); + raft::distance::distance( + handle, x.data(), x.data(), dist.data(), m, n, k, true, 0.0f); RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } diff --git a/python/pylibraft/CMakeLists.txt b/python/pylibraft/CMakeLists.txt index 98d723e27b..b12d0a63ea 100644 --- a/python/pylibraft/CMakeLists.txt +++ b/python/pylibraft/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at diff --git a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx index 9649531b61..3037b9a725 100644 --- a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx +++ b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx @@ -79,7 +79,7 @@ DISTANCE_TYPES = { "kl_divergence": DistanceType.KLDivergence, "minkowski": DistanceType.LpUnexpanded, "russellrao": DistanceType.RusselRaoExpanded, - "dice": DistanceType.DiceExpanded + "dice": DistanceType.DiceExpanded, } SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product", diff --git a/python/pylibraft/pylibraft/test/test_distance.py b/python/pylibraft/pylibraft/test/test_distance.py index dd6050a098..971f40498f 100644 --- a/python/pylibraft/pylibraft/test/test_distance.py +++ b/python/pylibraft/pylibraft/test/test_distance.py @@ -36,6 +36,7 @@ "russellrao", "cosine", "sqeuclidean", + "inner_product", ], ) @pytest.mark.parametrize("inplace", [True, False]) @@ -57,7 +58,10 @@ def test_distance(n_rows, n_cols, inplace, metric, order, dtype): output = np.zeros((n_rows, n_rows), dtype=dtype) - expected = cdist(input1, input1, metric) + if metric == "inner_product": + expected = np.matmul(input1, input1.T) + else: + expected = cdist(input1, input1, metric) expected[expected <= 1e-5] = 0.0 From f40047710c12eef4b0d4b38db31ca1f35f254c1d Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 7 Feb 2023 11:06:49 -0800 Subject: [PATCH 2/7] benchmark build fix --- cpp/bench/distance/distance_common.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/bench/distance/distance_common.cuh b/cpp/bench/distance/distance_common.cuh index 7ddecd7579..f6f6040d3c 100644 --- a/cpp/bench/distance/distance_common.cuh +++ b/cpp/bench/distance/distance_common.cuh @@ -49,7 +49,8 @@ struct distance : public fixture { void run_benchmark(::benchmark::State& state) override { loop_on_state(state, [this]() { - raft::distance::distance(x.data(), + raft::distance::distance(handle, + x.data(), y.data(), out.data(), params.m, @@ -57,7 +58,6 @@ struct distance : public fixture { params.k, (void*)workspace.data(), worksize, - stream, params.isRowMajor); }); } From 713b9d38e4108bd07413fff1401516638439fd2b Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 7 Feb 2023 12:32:29 -0800 Subject: [PATCH 3/7] Use raft::resources instead of raft::device_resources --- cpp/include/raft/distance/detail/distance.cuh | 96 ++++++++++++------- cpp/include/raft/distance/distance.cuh | 25 ++--- .../specializations/detail/canberra.cuh | 4 +- .../specializations/detail/chebyshev.cuh | 4 +- .../specializations/detail/correlation.cuh | 4 +- .../specializations/detail/cosine.cuh | 4 +- .../detail/hamming_unexpanded.cuh | 4 +- .../detail/hellinger_expanded.cuh | 4 +- .../specializations/detail/inner_product.cuh | 4 +- .../specializations/detail/jensen_shannon.cuh | 4 +- .../specializations/detail/kl_divergence.cuh | 4 +- .../distance/specializations/detail/l1.cuh | 4 +- .../specializations/detail/l2_expanded.cuh | 4 +- .../detail/l2_sqrt_expanded.cuh | 4 +- .../detail/l2_sqrt_unexpanded.cuh | 4 +- .../specializations/detail/l2_unexpanded.cuh | 4 +- .../specializations/detail/lp_unexpanded.cuh | 4 +- .../specializations/detail/russel_rao.cuh | 4 +- cpp/include/raft/linalg/detail/gemm.hpp | 17 ++-- cpp/include/raft/linalg/gemm.cuh | 8 +- .../canberra_double_double_double_int.cu | 2 +- .../detail/canberra_float_float_float_int.cu | 2 +- .../chebyshev_double_double_double_int.cu | 2 +- .../detail/chebyshev_float_float_float_int.cu | 2 +- .../correlation_double_double_double_int.cu | 2 +- .../correlation_float_float_float_int.cu | 2 +- .../detail/cosine_double_double_double_int.cu | 2 +- .../detail/cosine_float_float_float_int.cu | 2 +- ...ing_unexpanded_double_double_double_int.cu | 2 +- ...amming_unexpanded_float_float_float_int.cu | 2 +- ...inger_expanded_double_double_double_int.cu | 2 +- ...ellinger_expanded_float_float_float_int.cu | 2 +- .../specializations/detail/inner_product.cu | 4 +- ...jensen_shannon_double_double_double_int.cu | 2 +- .../jensen_shannon_float_float_float_int.cu | 2 +- .../kl_divergence_double_double_double_int.cu | 2 +- .../kl_divergence_float_float_float_int.cu | 2 +- .../detail/l1_double_double_double_int.cu | 2 +- .../detail/l1_float_float_float_int.cu | 2 +- .../l2_expanded_double_double_double_int.cu | 2 +- .../l2_expanded_float_float_float_int.cu | 2 +- ..._sqrt_expanded_double_double_double_int.cu | 2 +- .../l2_sqrt_expanded_float_float_float_int.cu | 2 +- ...qrt_unexpanded_double_double_double_int.cu | 2 +- ...2_sqrt_unexpanded_float_float_float_int.cu | 2 +- .../l2_unexpanded_double_double_double_int.cu | 2 +- .../l2_unexpanded_float_float_float_int.cu | 2 +- .../lp_unexpanded_double_double_double_int.cu | 2 +- .../lp_unexpanded_float_float_float_int.cu | 2 +- .../russel_rao_double_double_double_int.cu | 2 +- .../russel_rao_float_float_float_int.cu | 2 +- 51 files changed, 150 insertions(+), 124 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 59e71d2b2f..6fead0df49 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -93,7 +94,7 @@ template struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -120,7 +121,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -144,7 +145,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -184,7 +185,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -214,7 +215,17 @@ struct DistanceImpl( - m, n, k, x, y, dist, (AccType*)workspace, worksize, fin_op, handle.get_stream(), isRowMajor); + m, + n, + k, + x, + y, + dist, + (AccType*)workspace, + worksize, + fin_op, + raft::resource::get_cuda_stream(handle), + isRowMajor); } }; @@ -229,7 +240,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -242,6 +253,7 @@ struct DistanceImpl(x), @@ -252,7 +264,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -281,7 +293,7 @@ struct DistanceImpl( - m, n, k, x, y, dist, false, fin_op, handle.get_stream(), isRowMajor); + m, n, k, x, y, dist, false, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -296,7 +308,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -310,7 +322,7 @@ struct DistanceImpl( - m, n, k, x, y, dist, true, fin_op, handle.get_stream(), isRowMajor); + m, n, k, x, y, dist, true, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -325,7 +337,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -339,7 +351,7 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -354,7 +366,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -368,7 +380,7 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -383,7 +395,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -397,7 +409,7 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -412,7 +424,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -426,7 +438,7 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor, metric_arg); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor, metric_arg); } }; @@ -441,7 +453,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -455,7 +467,7 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -470,7 +482,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -484,7 +496,7 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -499,7 +511,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -513,7 +525,7 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -528,7 +540,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -542,7 +554,7 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -557,7 +569,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -571,7 +583,7 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, handle.get_stream(), isRowMajor); + m, n, k, x, y, dist, fin_op, raft::resource::get_cuda_stream(handle), isRowMajor); } }; @@ -586,7 +598,7 @@ struct DistanceImpl { - void run(raft::device_resources const& handle, + void run(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -600,7 +612,17 @@ struct DistanceImpl( - m, n, k, x, y, dist, (AccType*)workspace, worksize, fin_op, handle.get_stream(), isRowMajor); + m, + n, + k, + x, + y, + dist, + (AccType*)workspace, + worksize, + fin_op, + raft::resource::get_cuda_stream(handle), + isRowMajor); } }; @@ -636,7 +658,7 @@ template -void distance(raft::device_resources const& handle, +void distance(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -692,7 +714,7 @@ template -void distance(raft::device_resources const& handle, +void distance(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -772,7 +794,7 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In * @param isRowMajor whether the matrices are row-major or col-major */ template -void pairwise_distance_impl(raft::device_resources const& handle, +void pairwise_distance_impl(raft::resources const& handle, const Type* x, const Type* y, Type* dist, @@ -784,7 +806,7 @@ void pairwise_distance_impl(raft::device_resources const& handle, Type metric_arg = 2.0f) { auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, handle.get_stream()); + workspace.resize(worksize, raft::resource::get_cuda_stream(handle)); distance( handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); } diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 623f788d79..5084280fd7 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -18,7 +18,8 @@ #pragma once -#include +#include +#include #include #include #include @@ -64,7 +65,7 @@ template -void distance(raft::device_resources const& handle, +void distance(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -107,7 +108,7 @@ template -void distance(raft::device_resources const& handle, +void distance(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -199,7 +200,7 @@ template -void distance(raft::device_resources const& handle, +void distance(raft::resources const& handle, const InType* x, const InType* y, OutType* dist, @@ -209,7 +210,7 @@ void distance(raft::device_resources const& handle, bool isRowMajor = true, InType metric_arg = 2.0f) { - auto stream = handle.get_stream(); + auto stream = raft::resource::get_cuda_stream(handle); rmm::device_uvector workspace(0, stream); auto worksize = getWorkspaceSize(x, y, m, n, k); workspace.resize(worksize, stream); @@ -236,7 +237,7 @@ void distance(raft::device_resources const& handle, * @param metric_arg metric argument (used for Minkowski distance) */ template -void pairwise_distance(raft::device_resources const& handle, +void pairwise_distance(raft::resources const& handle, const Type* x, const Type* y, Type* dist, @@ -335,7 +336,7 @@ void pairwise_distance(raft::device_resources const& handle, * @param metric_arg metric argument (used for Minkowski distance) */ template -void pairwise_distance(raft::device_resources const& handle, +void pairwise_distance(raft::resources const& handle, const Type* x, const Type* y, Type* dist, @@ -346,7 +347,8 @@ void pairwise_distance(raft::device_resources const& handle, bool isRowMajor = true, Type metric_arg = 2.0f) { - rmm::device_uvector workspace(0, handle.get_stream()); + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); pairwise_distance( handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); } @@ -400,7 +402,7 @@ template -void distance(raft::device_resources const& handle, +void distance(raft::resources const& handle, raft::device_matrix_view const x, raft::device_matrix_view const y, raft::device_matrix_view dist, @@ -443,7 +445,7 @@ void distance(raft::device_resources const& handle, * @param metric_arg metric argument (used for Minkowski distance) */ template -void pairwise_distance(raft::device_resources const& handle, +void pairwise_distance(raft::resources const& handle, device_matrix_view const x, device_matrix_view const y, device_matrix_view dist, @@ -464,7 +466,8 @@ void pairwise_distance(raft::device_resources const& handle, constexpr auto rowmajor = std::is_same_v; - rmm::device_uvector workspace(0, handle.get_stream()); + auto stream = raft::resource::get_cuda_stream(handle); + rmm::device_uvector workspace(0, stream); pairwise_distance(handle, x.data_handle(), diff --git a/cpp/include/raft/distance/specializations/detail/canberra.cuh b/cpp/include/raft/distance/specializations/detail/canberra.cuh index 178c9047a6..badce715a5 100644 --- a/cpp/include/raft/distance/specializations/detail/canberra.cuh +++ b/cpp/include/raft/distance/specializations/detail/canberra.cuh @@ -23,7 +23,7 @@ namespace raft { namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -36,7 +36,7 @@ extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/chebyshev.cuh b/cpp/include/raft/distance/specializations/detail/chebyshev.cuh index 06cb20e873..9a46d7b488 100644 --- a/cpp/include/raft/distance/specializations/detail/chebyshev.cuh +++ b/cpp/include/raft/distance/specializations/detail/chebyshev.cuh @@ -22,7 +22,7 @@ namespace raft { namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -35,7 +35,7 @@ extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/correlation.cuh b/cpp/include/raft/distance/specializations/detail/correlation.cuh index 55b09d92aa..013a0d43a3 100644 --- a/cpp/include/raft/distance/specializations/detail/correlation.cuh +++ b/cpp/include/raft/distance/specializations/detail/correlation.cuh @@ -23,7 +23,7 @@ namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -37,7 +37,7 @@ distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/cosine.cuh b/cpp/include/raft/distance/specializations/detail/cosine.cuh index 0663d81f4f..c88bd1b0f6 100644 --- a/cpp/include/raft/distance/specializations/detail/cosine.cuh +++ b/cpp/include/raft/distance/specializations/detail/cosine.cuh @@ -23,7 +23,7 @@ namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -37,7 +37,7 @@ distance extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh index f48f6b396e..3c5cad3315 100644 --- a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh @@ -23,7 +23,7 @@ namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -37,7 +37,7 @@ distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh index 1ced55de57..bf214c046f 100644 --- a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh @@ -23,7 +23,7 @@ namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -37,7 +37,7 @@ distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/inner_product.cuh b/cpp/include/raft/distance/specializations/detail/inner_product.cuh index 3ea29d05aa..d97d678928 100644 --- a/cpp/include/raft/distance/specializations/detail/inner_product.cuh +++ b/cpp/include/raft/distance/specializations/detail/inner_product.cuh @@ -22,7 +22,7 @@ namespace raft { namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -36,7 +36,7 @@ extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh index 7723cce7d4..145834fb70 100644 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh @@ -23,7 +23,7 @@ namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -37,7 +37,7 @@ distance( extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh index 12b3c900d9..f0928916cd 100644 --- a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh @@ -22,7 +22,7 @@ namespace raft { namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -36,7 +36,7 @@ extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/l1.cuh b/cpp/include/raft/distance/specializations/detail/l1.cuh index f6b7826b36..23261a2571 100644 --- a/cpp/include/raft/distance/specializations/detail/l1.cuh +++ b/cpp/include/raft/distance/specializations/detail/l1.cuh @@ -22,7 +22,7 @@ namespace raft { namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -35,7 +35,7 @@ extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh index 182c02bb94..f953018b7d 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh @@ -22,7 +22,7 @@ namespace raft { namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -36,7 +36,7 @@ extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh index 3bd65bb769..9f5f6a3706 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh @@ -23,7 +23,7 @@ namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -37,7 +37,7 @@ distance extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh index 28f9348592..94531ddc33 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh @@ -23,7 +23,7 @@ namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -37,7 +37,7 @@ distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh index f384fb6c3f..224b21fce8 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh @@ -22,7 +22,7 @@ namespace raft { namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -36,7 +36,7 @@ extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh index 897dbc57b0..e05ef02c42 100644 --- a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh @@ -22,7 +22,7 @@ namespace raft { namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -36,7 +36,7 @@ extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh index c8cd60e31b..afc87997c0 100644 --- a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh +++ b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh @@ -23,7 +23,7 @@ namespace distance { namespace detail { extern template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -37,7 +37,7 @@ distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/include/raft/linalg/detail/gemm.hpp b/cpp/include/raft/linalg/detail/gemm.hpp index ba9496c3b9..c8593bd624 100644 --- a/cpp/include/raft/linalg/detail/gemm.hpp +++ b/cpp/include/raft/linalg/detail/gemm.hpp @@ -20,7 +20,8 @@ #include "cublas_wrappers.hpp" -#include +#include +#include namespace raft { namespace linalg { @@ -49,7 +50,7 @@ namespace detail { * @param [in] stream */ template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, const bool trans_a, const bool trans_b, const int m, @@ -65,7 +66,7 @@ void gemm(raft::device_resources const& handle, const int ldc, cudaStream_t stream) { - cublasHandle_t cublas_h = handle.get_cublas_handle(); + auto cublas_h = raft::resource::get_cublas_handle(handle); cublas_device_pointer_mode pmode(cublas_h); RAFT_CUBLAS_TRY(cublasgemm(cublas_h, trans_a ? CUBLAS_OP_T : CUBLAS_OP_N, @@ -103,7 +104,7 @@ void gemm(raft::device_resources const& handle, * @param stream cuda stream */ template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -117,7 +118,7 @@ void gemm(raft::device_resources const& handle, math_t beta, cudaStream_t stream) { - cublasHandle_t cublas_h = handle.get_cublas_handle(); + auto cublas_h = raft::resource::get_cublas_handle(handle); int m = n_rows_c; int n = n_cols_c; @@ -130,7 +131,7 @@ void gemm(raft::device_resources const& handle, } template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -149,7 +150,7 @@ void gemm(raft::device_resources const& handle, } template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, T* z, T* x, T* y, @@ -163,7 +164,7 @@ void gemm(raft::device_resources const& handle, T* alpha, T* beta) { - cublasHandle_t cublas_h = handle.get_cublas_handle(); + auto cublas_h = raft::resource::get_cublas_handle(handle); cublas_device_pointer_mode pmode(cublas_h); cublasOperation_t trans_a, trans_b; diff --git a/cpp/include/raft/linalg/gemm.cuh b/cpp/include/raft/linalg/gemm.cuh index d5dc5ffab5..a336f844bf 100644 --- a/cpp/include/raft/linalg/gemm.cuh +++ b/cpp/include/raft/linalg/gemm.cuh @@ -52,7 +52,7 @@ namespace linalg { * @param [in] stream */ template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, const bool trans_a, const bool trans_b, const int m, @@ -91,7 +91,7 @@ void gemm(raft::device_resources const& handle, * @param stream cuda stream */ template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -126,7 +126,7 @@ void gemm(raft::device_resources const& handle, * @param stream cuda stream */ template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -161,7 +161,7 @@ void gemm(raft::device_resources const& handle, * @param beta scalar */ template -void gemm(raft::device_resources const& handle, +void gemm(raft::resources const& handle, T* z, T* x, T* y, diff --git a/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu index 2674758992..d3a629f5a0 100644 --- a/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu index 7ee48ec656..71899c00a7 100644 --- a/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/chebyshev_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/chebyshev_double_double_double_int.cu index 652fb68fa9..c071ddc902 100644 --- a/cpp/src/distance/distance/specializations/detail/chebyshev_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/chebyshev_double_double_double_int.cu @@ -21,7 +21,7 @@ namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/chebyshev_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/chebyshev_float_float_float_int.cu index 427b2c26e0..d42d769306 100644 --- a/cpp/src/distance/distance/specializations/detail/chebyshev_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/chebyshev_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu index 3cd15146ce..4c626a0c73 100644 --- a/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu @@ -22,7 +22,7 @@ namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu index fd4eee09c9..5f23a5042e 100644 --- a/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu index 97ad944edf..b7f2eb569d 100644 --- a/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu @@ -21,7 +21,7 @@ namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu index 5f7a72a213..650c889a0d 100644 --- a/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu @@ -21,7 +21,7 @@ namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu index c410b6cd8b..282449f48b 100644 --- a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu @@ -22,7 +22,7 @@ namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu index 08ade61b61..2bed6c2401 100644 --- a/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu index 10c1f57487..b14425eec3 100644 --- a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu @@ -22,7 +22,7 @@ namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu index 61a5f38a7c..b738b3983b 100644 --- a/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu @@ -21,7 +21,7 @@ namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/inner_product.cu b/cpp/src/distance/distance/specializations/detail/inner_product.cu index c8a6424dac..5a7b1d6191 100644 --- a/cpp/src/distance/distance/specializations/detail/inner_product.cu +++ b/cpp/src/distance/distance/specializations/detail/inner_product.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, @@ -33,7 +33,7 @@ template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu index 74aa93c4af..c490033dc3 100644 --- a/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu index 476593b1b9..ae1d860b3e 100644 --- a/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu index 7e0f802949..ed2a29bccf 100644 --- a/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu index d84762a2d1..511e0fb1a6 100644 --- a/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu index d6767a0eaf..7a1c884223 100644 --- a/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l1_double_double_double_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu index 1b4e2f7d23..f307ade06f 100644 --- a/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l1_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu index a60cd5587e..18ddf10b14 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu index 3c888be766..28e7b73243 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu index afb6bf01ed..404dc5f8b4 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu index 52a6b153ed..b4f8757d4a 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu index 4544609aa6..10cafac4a7 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu @@ -21,7 +21,7 @@ namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu index a8d865ef0b..6a5f43bc46 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu index 48be358a22..174a893d36 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu index ec84034167..0cc43b099f 100644 --- a/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu index 7c014781fa..66e765b1c0 100644 --- a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu @@ -21,7 +21,7 @@ namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu index 446d0e5ba1..6fbd52ac89 100644 --- a/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, diff --git a/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu b/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu index dbf1b15a44..77e1fc70d0 100644 --- a/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu +++ b/cpp/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu @@ -21,7 +21,7 @@ namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const double* x, const double* y, double* dist, diff --git a/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu index 117e6a655d..6f898310c7 100644 --- a/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu +++ b/cpp/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu @@ -20,7 +20,7 @@ namespace raft { namespace distance { namespace detail { template void distance( - raft::device_resources const& handle, + raft::resources const& handle, const float* x, const float* y, float* dist, From 35b7398cd857c2ee1b08c920a48ebd1eb34339ac Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Tue, 7 Feb 2023 12:39:41 -0800 Subject: [PATCH 4/7] style --- cpp/include/raft/distance/distance.cuh | 2 +- cpp/include/raft/linalg/detail/gemm.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 5084280fd7..5fd5ebe98d 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -18,8 +18,8 @@ #pragma once -#include #include +#include #include #include #include diff --git a/cpp/include/raft/linalg/detail/gemm.hpp b/cpp/include/raft/linalg/detail/gemm.hpp index c8593bd624..d82c821148 100644 --- a/cpp/include/raft/linalg/detail/gemm.hpp +++ b/cpp/include/raft/linalg/detail/gemm.hpp @@ -20,8 +20,8 @@ #include "cublas_wrappers.hpp" -#include #include +#include namespace raft { namespace linalg { From 47f5fdbe0247e8667375036edd83a50156cc2b1c Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 10 Feb 2023 14:18:44 -0800 Subject: [PATCH 5/7] Add C++ test for IP distance --- cpp/include/raft/distance/detail/distance.cuh | 4 +- cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_inner_product.cu | 76 +++++++++++++++++++ cpp/test/distance/distance_base.cuh | 25 ++++++ .../pylibraft/pylibraft/test/test_distance.py | 4 +- 5 files changed, 106 insertions(+), 4 deletions(-) create mode 100644 cpp/test/distance/dist_inner_product.cu diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 6fead0df49..f9a85ea627 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -261,8 +261,8 @@ struct DistanceImpl +class DistanceInnerProduct + : public DistanceTest { +}; + +const std::vector> inputsf = { + {0.001f, 10, 5, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceInnerProduct DistanceInnerProductF; +TEST_P(DistanceInnerProductF, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceInnerProductF, ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceInnerProduct DistanceInnerProductD; +TEST_P(DistanceInnerProductD, Result) +{ + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + + ASSERT_TRUE(devArrMatch( + dist_ref.data(), dist.data(), m, n, raft::CompareApprox(params.tolerance), stream)); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceInnerProductD, ::testing::ValuesIn(inputsd)); + +class BigMatrixInnerProduct + : public BigMatrixDistanceTest { +}; +TEST_F(BigMatrixInnerProduct, Result) {} + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index a9dca21359..278fde0c45 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -122,6 +122,28 @@ __global__ void naiveCosineDistanceKernel( dist[outidx] = (DataType)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); } +template +__global__ void naiveInnerProductKernel( + DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) +{ + int midx = threadIdx.x + blockIdx.x * blockDim.x; + int nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) { return; } + + DataType acc_ab = DataType(0); + + for (int i = 0; i < k; ++i) { + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + acc_ab += a * b; + } + + int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc_ab; +} + template __global__ void naiveHellingerDistanceKernel( DataType* dist, const DataType* x, const DataType* y, int m, int n, int k, bool isRowMajor) @@ -348,6 +370,9 @@ void naiveDistance(DataType* dist, naiveHammingDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; + case raft::distance::DistanceType::InnerProduct: + naiveInnerProductKernel<<>>(dist, x, y, m, n, k, isRowMajor); + break; case raft::distance::DistanceType::JensenShannon: naiveJensenShannonDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); diff --git a/python/pylibraft/pylibraft/test/test_distance.py b/python/pylibraft/pylibraft/test/test_distance.py index 971f40498f..2c0a842fe5 100644 --- a/python/pylibraft/pylibraft/test/test_distance.py +++ b/python/pylibraft/pylibraft/test/test_distance.py @@ -21,8 +21,8 @@ from pylibraft.distance import pairwise_distance -@pytest.mark.parametrize("n_rows", [100]) -@pytest.mark.parametrize("n_cols", [100]) +@pytest.mark.parametrize("n_rows", [32, 100]) +@pytest.mark.parametrize("n_cols", [40, 100]) @pytest.mark.parametrize( "metric", [ From 1ff9fbb95c3714ac5c06a1a7e14c5a4620584d51 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 10 Feb 2023 14:24:03 -0800 Subject: [PATCH 6/7] Split IP specializations into separate files --- cpp/CMakeLists.txt | 3 +- ...inner_product_double_double_double_int.cu} | 13 ------- .../inner_product_float_float_float_int.cu | 36 +++++++++++++++++++ 3 files changed, 38 insertions(+), 14 deletions(-) rename cpp/src/distance/distance/specializations/detail/{inner_product.cu => inner_product_double_double_double_int.cu} (78%) create mode 100644 cpp/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 67a13fad5f..731a30ae03 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -321,7 +321,8 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu - src/distance/distance/specializations/detail/inner_product.cu + src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu + src/distance/distance/specializations/detail/inner_product_double_double_double_int.cu src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu src/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu diff --git a/cpp/src/distance/distance/specializations/detail/inner_product.cu b/cpp/src/distance/distance/specializations/detail/inner_product_double_double_double_int.cu similarity index 78% rename from cpp/src/distance/distance/specializations/detail/inner_product.cu rename to cpp/src/distance/distance/specializations/detail/inner_product_double_double_double_int.cu index 5a7b1d6191..41bd6c2c14 100644 --- a/cpp/src/distance/distance/specializations/detail/inner_product.cu +++ b/cpp/src/distance/distance/specializations/detail/inner_product_double_double_double_int.cu @@ -19,19 +19,6 @@ namespace raft { namespace distance { namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); - template void distance( raft::resources const& handle, const double* x, diff --git a/cpp/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu b/cpp/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu new file mode 100644 index 0000000000..f147c51b52 --- /dev/null +++ b/cpp/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft { +namespace distance { +namespace detail { +template void distance( + raft::resources const& handle, + const float* x, + const float* y, + float* dist, + int m, + int n, + int k, + void* workspace, + std::size_t worksize, + bool isRowMajor, + float metric_arg); +} // namespace detail +} // namespace distance +} // namespace raft From b589d5d9b68e1aa43e52d51f8e3a1a8ed298013d Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Fri, 10 Feb 2023 22:42:59 -0800 Subject: [PATCH 7/7] fix --- cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh index 3d37b37293..d1465efdb0 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh @@ -369,7 +369,6 @@ class RBFKernel : public GramMatrixBase { NULL, 0, fin_op, - stream, is_row_major); } };