diff --git a/cpp/include/raft/distance/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh similarity index 98% rename from cpp/include/raft/distance/canberra.cuh rename to cpp/include/raft/distance/detail/canberra.cuh index b87c295eb0..c4c384c45f 100644 --- a/cpp/include/raft/distance/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -15,10 +15,11 @@ */ #pragma once -#include +#include namespace raft { namespace distance { +namespace detail { /** * @brief the canberra distance matrix calculation implementer @@ -157,5 +158,7 @@ void canberraImpl(int m, int n, int k, const InType *pA, const InType *pB, n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } + +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh similarity index 98% rename from cpp/include/raft/distance/chebyshev.cuh rename to cpp/include/raft/distance/detail/chebyshev.cuh index 8d53408cf8..77fba28310 100644 --- a/cpp/include/raft/distance/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/chebyshev.cuh @@ -15,11 +15,11 @@ */ #pragma once -#include +#include namespace raft { namespace distance { - +namespace detail { /** * @brief the Chebyshev distance matrix calculation implementer * It computes the following equation: cij = max(cij, op(ai-bj)) @@ -154,5 +154,6 @@ void chebyshevImpl(int m, int n, int k, const InType *pA, const InType *pB, n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh similarity index 99% rename from cpp/include/raft/distance/correlation.cuh rename to cpp/include/raft/distance/detail/correlation.cuh index ed3b7a5464..cee986997a 100644 --- a/cpp/include/raft/distance/correlation.cuh +++ b/cpp/include/raft/distance/detail/correlation.cuh @@ -16,11 +16,12 
@@ #pragma once #include -#include +#include #include namespace raft { namespace distance { +namespace detail { /** * @brief the Correlation distance matrix: @@ -243,5 +244,7 @@ void correlationImpl(int m, int n, int k, const InType *pA, const InType *pB, fin_op, stream); } } + +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh similarity index 98% rename from cpp/include/raft/distance/cosine.cuh rename to cpp/include/raft/distance/detail/cosine.cuh index ed9bd28b7f..900e045edc 100644 --- a/cpp/include/raft/distance/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -16,11 +16,12 @@ #pragma once -#include +#include #include namespace raft { namespace distance { +namespace detail { /** * @brief the cosine distance matrix calculation implementer @@ -201,5 +202,6 @@ void cosineAlgo1(Index_ m, Index_ n, Index_ k, const InType *pA, } } +}; // end namespace detail }; // end namespace distance }; // end namespace raft diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh similarity index 63% rename from cpp/include/raft/distance/distance.cuh rename to cpp/include/raft/distance/detail/distance.cuh index 65b4f3b830..199dc73fb6 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -19,22 +19,70 @@ #include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include namespace raft { namespace distance { +namespace detail { + +/** enum to tell how to compute distance */ +enum DistanceType : unsigned short { + + /** evaluate as dist_ij = sum(x_ik^2) + sum(y_ij)^2 - 2*sum(x_ik * y_jk) */ + L2Expanded = 0, + /** same as above, but inside the epilogue, 
perform square root operation */ + L2SqrtExpanded = 1, + /** cosine distance */ + CosineExpanded = 2, + /** L1 distance */ + L1 = 3, + /** evaluate as dist_ij += (x_ik - y-jk)^2 */ + L2Unexpanded = 4, + /** same as above, but inside the epilogue, perform square root operation */ + L2SqrtUnexpanded = 5, + /** basic inner product **/ + InnerProduct = 6, + /** Chebyshev (Linf) distance **/ + Linf = 7, + /** Canberra distance **/ + Canberra = 8, + /** Generalized Minkowski distance **/ + LpUnexpanded = 9, + /** Correlation distance **/ + CorrelationExpanded = 10, + /** Jaccard distance **/ + JaccardExpanded = 11, + /** Hellinger distance **/ + HellingerExpanded = 12, + /** Haversine distance **/ + Haversine = 13, + /** Bray-Curtis distance **/ + BrayCurtis = 14, + /** Jensen-Shannon distance**/ + JensenShannon = 15, + /** Hamming distance **/ + HammingUnexpanded = 16, + /** KLDivergence **/ + KLDivergence = 17, + /** RusselRao **/ + RusselRaoExpanded = 18, + /** Dice-Sorensen distance **/ + DiceExpanded = 19, + /** Precomputed (special value) **/ + Precomputed = 100 +}; namespace { template (m, n, k, x, y, dist, false, - (AccType *)workspace, worksize, - fin_op, stream, isRowMajor); + raft::distance::detail::euclideanAlgo1( + m, n, k, x, y, dist, false, (AccType *)workspace, worksize, fin_op, + stream, isRowMajor); } }; @@ -67,10 +115,10 @@ struct DistanceImpl(m, n, k, x, y, dist, true, - (AccType *)workspace, worksize, - fin_op, stream, isRowMajor); + raft::distance::detail::euclideanAlgo1( + m, n, k, x, y, dist, true, (AccType *)workspace, worksize, fin_op, stream, + isRowMajor); } }; @@ -81,9 +129,10 @@ struct DistanceImpl( - m, n, k, x, y, dist, (AccType *)workspace, worksize, fin_op, stream, - isRowMajor); + raft::distance::detail::cosineAlgo1(m, n, k, x, y, dist, + (AccType *)workspace, worksize, + fin_op, stream, isRowMajor); } }; @@ -94,9 +143,9 @@ struct DistanceImpl(m, n, k, x, y, dist, false, fin_op, - stream, isRowMajor); + 
raft::distance::detail::euclideanAlgo2( + m, n, k, x, y, dist, false, fin_op, stream, isRowMajor); } }; @@ -107,9 +156,9 @@ struct DistanceImpl(m, n, k, x, y, dist, true, fin_op, - stream, isRowMajor); + raft::distance::detail::euclideanAlgo2( + m, n, k, x, y, dist, true, fin_op, stream, isRowMajor); } }; @@ -120,8 +169,9 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + raft::distance::detail::l1Impl(m, n, k, x, y, dist, fin_op, stream, + isRowMajor); } }; @@ -132,9 +182,9 @@ struct DistanceImpl(m, n, k, x, y, dist, fin_op, stream, - isRowMajor); + raft::distance::detail::chebyshevImpl(m, n, k, x, y, dist, fin_op, + stream, isRowMajor); } }; @@ -145,9 +195,9 @@ struct DistanceImpl(m, n, k, x, y, dist, fin_op, stream, - isRowMajor); + raft::distance::detail::hellingerImpl(m, n, k, x, y, dist, fin_op, + stream, isRowMajor); } }; @@ -158,9 +208,9 @@ struct DistanceImpl(m, n, k, x, y, dist, fin_op, stream, - isRowMajor, metric_arg); + raft::distance::detail::minkowskiImpl( + m, n, k, x, y, dist, fin_op, stream, isRowMajor, metric_arg); } }; @@ -171,8 +221,9 @@ struct DistanceImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); + raft::distance::detail::canberraImpl(m, n, k, x, y, dist, fin_op, + stream, isRowMajor); } }; @@ -183,9 +234,9 @@ struct DistanceImpl(m, n, k, x, y, dist, fin_op, - stream, isRowMajor); + raft::distance::detail::hammingUnexpandedImpl( + m, n, k, x, y, dist, fin_op, stream, isRowMajor); } }; @@ -196,9 +247,9 @@ struct DistanceImpl(m, n, k, x, y, dist, fin_op, - stream, isRowMajor); + raft::distance::detail::jensenShannonImpl( + m, n, k, x, y, dist, fin_op, stream, isRowMajor); } }; @@ -209,9 +260,9 @@ struct DistanceImpl(m, n, k, x, y, dist, fin_op, stream, - isRowMajor); + raft::distance::detail::russellRaoImpl( + m, n, k, x, y, dist, fin_op, stream, isRowMajor); } }; @@ -222,9 +273,9 @@ struct DistanceImpl(m, n, k, x, y, dist, fin_op, - stream, isRowMajor); + raft::distance::detail::klDivergenceImpl( + m, n, 
k, x, y, dist, fin_op, stream, isRowMajor); } }; @@ -235,50 +286,15 @@ struct DistanceImpl(m, n, k, x, y, dist, - (AccType *)workspace, worksize, - fin_op, stream, isRowMajor); + raft::distance::detail::correlationImpl( + m, n, k, x, y, dist, (AccType *)workspace, worksize, fin_op, stream, + isRowMajor); } }; } // anonymous namespace -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points - * @param y second set of points - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * - * @note If the specifed distanceType doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(const InType *x, const InType *y, Index_ m, Index_ n, - Index_ k) { - size_t worksize = 0; - constexpr bool is_allocated = - (distanceType <= raft::distance::DistanceType::CosineExpanded) || - (distanceType == raft::distance::DistanceType::CorrelationExpanded); - constexpr int numOfBuffers = - (distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 
2 : 1; - - if (is_allocated) { - worksize += numOfBuffers * m * sizeof(AccType); - if (x != y) worksize += numOfBuffers * n * sizeof(AccType); - } - - return worksize; -} - /** * @brief Evaluate pairwise distances with the user epilogue lamba allowed * @tparam DistanceType which distance to evaluate @@ -319,7 +335,43 @@ void distance(const InType *x, const InType *y, OutType *dist, Index_ m, } /** - * @brief Evaluate pairwise distances for the simple use case + * @brief Evaluate pairwise distances for the simple use case + * @tparam DistanceType which distance to evaluate + * @tparam InType input argument type + * @tparam AccType accumulation type + * @tparam OutType output type + * @tparam Index_ Index type + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + * + * @note if workspace is passed as nullptr, this will return in + * worksize, the number of bytes of workspace required + */ +template +void distance(const InType *x, const InType *y, OutType *dist, Index_ m, + Index_ n, Index_ k, void *workspace, size_t worksize, + cudaStream_t stream, bool isRowMajor = true, + InType metric_arg = 2.0f) { + auto default_fin_op = [] __device__(AccType d_val, Index_ g_d_idx) { + return d_val; + }; + distance(x, y, dist, m, n, k, workspace, worksize, default_fin_op, + stream, isRowMajor, metric_arg); + CUDA_CHECK(cudaPeekAtLastError()); +} + +/** + * @brief Return the exact workspace size to compute the distance * @tparam DistanceType which distance to evaluate * @tparam InType input argument type * @tparam AccType accumulation type @@ -327,31 +379,30 @@ void distance(const InType *x, const InType *y, OutType 
*dist, Index_ m, * @tparam Index_ Index type * @param x first set of points * @param y second set of points - * @param dist output distance matrix * @param m number of points in x * @param n number of points in y * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param stream cuda stream - * @param isRowMajor whether the matrices are row-major or col-major * - * @note if workspace is passed as nullptr, this will return in - * worksize, the number of bytes of workspace required + * @note If the specifed distanceType doesn't need the workspace at all, it + * returns 0. */ template -void distance(const InType *x, const InType *y, OutType *dist, Index_ m, - Index_ n, Index_ k, void *workspace, size_t worksize, - cudaStream_t stream, bool isRowMajor = true, - InType metric_arg = 2.0f) { - auto default_fin_op = [] __device__(AccType d_val, Index_ g_d_idx) { - return d_val; - }; - distance(x, y, dist, m, n, k, workspace, worksize, default_fin_op, - stream, isRowMajor, metric_arg); - CUDA_CHECK(cudaPeekAtLastError()); +size_t getWorkspaceSize(const InType *x, const InType *y, Index_ m, Index_ n, + Index_ k) { + size_t worksize = 0; + constexpr bool is_allocated = + (distanceType <= raft::distance::DistanceType::CosineExpanded) || + (distanceType == raft::distance::DistanceType::CorrelationExpanded); + constexpr int numOfBuffers = + (distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 
2 : 1; + + if (is_allocated) { + worksize += numOfBuffers * m * sizeof(AccType); + if (x != y) worksize += numOfBuffers * n * sizeof(AccType); + } + + return worksize; } /** @@ -386,91 +437,7 @@ void pairwise_distance_impl(const Type *x, const Type *y, Type *dist, Index_ m, workspace.data(), worksize, stream, isRowMajor, metric_arg); } - -template -void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, - Index_ n, Index_ k, rmm::device_uvector &workspace, - raft::distance::DistanceType metric, cudaStream_t stream, - bool isRowMajor = true, Type metric_arg = 2.0f) { - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::L2SqrtExpanded: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::CosineExpanded: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::L1: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::L2Unexpanded: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::L2SqrtUnexpanded: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::Linf: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::HellingerExpanded: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::LpUnexpanded: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::Canberra: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, 
stream, isRowMajor); - break; - case raft::distance::DistanceType::HammingUnexpanded: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::JensenShannon: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::RusselRaoExpanded: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::KLDivergence: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::CorrelationExpanded: - pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - default: - THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - }; -} /** @} */ - +}; // namespace detail }; // namespace distance }; // namespace raft diff --git a/cpp/include/raft/distance/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh similarity index 99% rename from cpp/include/raft/distance/euclidean.cuh rename to cpp/include/raft/distance/detail/euclidean.cuh index 484da0e5bf..8b8882c244 100644 --- a/cpp/include/raft/distance/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -15,11 +15,12 @@ */ #pragma once -#include +#include #include namespace raft { namespace distance { +namespace detail { /** * @brief the expanded euclidean distance matrix calculation implementer @@ -352,5 +353,6 @@ void euclideanAlgo2(Index_ m, Index_ n, Index_ k, const InType *pA, } } +}; // end namespace detail }; // end namespace distance }; // end namespace raft diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh similarity index 75% rename from cpp/include/raft/distance/fused_l2_nn.cuh rename to cpp/include/raft/distance/detail/fused_l2_nn.cuh index b96a536e38..ca8f729a68 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ 
b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -20,11 +20,12 @@ #include #include #include -#include +#include #include namespace raft { namespace distance { +namespace detail { #if (ENABLE_MEMCPY_ASYNC == 1) #include @@ -32,7 +33,7 @@ using namespace nvcuda::experimental; #endif template -struct KVPMinReduce { +struct KVPMinReduceImpl { typedef cub::KeyValuePair KVP; DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { @@ -42,7 +43,7 @@ struct KVPMinReduce { }; // KVPMinReduce template -struct MinAndDistanceReduceOp { +struct MinAndDistanceReduceOpImpl { typedef typename cub::KeyValuePair KVP; DI void operator()(LabelT rid, KVP* out, const KVP& other) { if (other.value < out->value) { @@ -58,7 +59,7 @@ struct MinAndDistanceReduceOp { }; template -struct MinReduceOp { +struct MinReduceOpImpl { typedef typename cub::KeyValuePair KVP; DI void operator()(LabelT rid, DataT* out, const KVP& other) { if (other.value < *out) { @@ -77,6 +78,14 @@ __global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { } } +template +void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, + cudaStream_t stream) { + auto blks = raft::ceildiv(m, 256); + initKernel + <<>>(min, m, maxVal, redOp); +} + // TODO: specialize this function for MinAndDistanceReduceOp // with atomicCAS of 64 bit which will eliminate mutex and shfls template -void fusedL2NN(OutT* min, const DataT* x, const DataT* y, const DataT* xn, - const DataT* yn, IdxT m, IdxT n, IdxT k, void* workspace, - ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, - bool initOutBuffer, cudaStream_t stream) { - size_t bytes = sizeof(DataT) * k; - if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) { - fusedL2NNImpl( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, - initOutBuffer, stream); - } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0) { - fusedL2NNImpl( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, - initOutBuffer, stream); - } else { - 
fusedL2NNImpl( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, - initOutBuffer, stream); - } -} - +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/hamming.cuh b/cpp/include/raft/distance/detail/hamming.cuh similarity index 98% rename from cpp/include/raft/distance/hamming.cuh rename to cpp/include/raft/distance/detail/hamming.cuh index 08f1020b85..0169ba33a2 100644 --- a/cpp/include/raft/distance/hamming.cuh +++ b/cpp/include/raft/distance/detail/hamming.cuh @@ -15,10 +15,11 @@ */ #pragma once -#include +#include namespace raft { namespace distance { +namespace detail { /** * @brief the Hamming distance matrix using the unexpanded form: @@ -171,5 +172,7 @@ void hammingUnexpandedImpl(int m, int n, int k, const InType *pA, pDcast, fin_op, stream); } } + +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh similarity index 98% rename from cpp/include/raft/distance/hellinger.cuh rename to cpp/include/raft/distance/detail/hellinger.cuh index f7ad3ed1ba..933d850dbf 100644 --- a/cpp/include/raft/distance/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -15,11 +15,12 @@ */ #pragma once -#include +#include #include namespace raft { namespace distance { +namespace detail { /** * @brief the Hellinger distance matrix using the expanded form: @@ -200,5 +201,6 @@ void hellingerImpl(int m, int n, int k, const InType *pA, const InType *pB, n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/jensen_shannon.cuh b/cpp/include/raft/distance/detail/jensen_shannon.cuh similarity index 98% rename from cpp/include/raft/distance/jensen_shannon.cuh rename to cpp/include/raft/distance/detail/jensen_shannon.cuh index 2a94205853..1e39f39682 100644 --- 
a/cpp/include/raft/distance/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/jensen_shannon.cuh @@ -15,10 +15,11 @@ */ #pragma once -#include +#include namespace raft { namespace distance { +namespace detail { /** * @brief the Jensen Shannon distance matrix: @@ -177,5 +178,6 @@ void jensenShannonImpl(int m, int n, int k, const InType *pA, const InType *pB, stream); } } +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/kl_divergence.cuh b/cpp/include/raft/distance/detail/kl_divergence.cuh similarity index 99% rename from cpp/include/raft/distance/kl_divergence.cuh rename to cpp/include/raft/distance/detail/kl_divergence.cuh index 3197b73d10..5a18ba1670 100644 --- a/cpp/include/raft/distance/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/kl_divergence.cuh @@ -15,10 +15,11 @@ */ #pragma once -#include +#include namespace raft { namespace distance { +namespace detail { /** * @brief the KL Divergence distance matrix: @@ -238,5 +239,6 @@ void klDivergenceImpl(int m, int n, int k, const InType *pA, const InType *pB, false>(n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh similarity index 98% rename from cpp/include/raft/distance/l1.cuh rename to cpp/include/raft/distance/detail/l1.cuh index 6ab084f041..33e9bae206 100644 --- a/cpp/include/raft/distance/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -15,10 +15,11 @@ */ #pragma once -#include +#include namespace raft { namespace distance { +namespace detail { /** * @brief the L1 distance matrix calculation implementer @@ -151,5 +152,6 @@ void l1Impl(int m, int n, int k, const InType *pA, const InType *pB, n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } +} // namespace detail } // namespace distance } // namespace raft diff --git 
a/cpp/include/raft/distance/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh similarity index 98% rename from cpp/include/raft/distance/minkowski.cuh rename to cpp/include/raft/distance/detail/minkowski.cuh index 803f5fc78a..8bd3deb08f 100644 --- a/cpp/include/raft/distance/minkowski.cuh +++ b/cpp/include/raft/distance/detail/minkowski.cuh @@ -15,10 +15,11 @@ */ #pragma once -#include +#include namespace raft { namespace distance { +namespace detail { /** * @brief the unexpanded Minkowski distance matrix calculation @@ -167,6 +168,6 @@ void minkowskiImpl(Index_ m, Index_ n, Index_ k, const InType *pA, n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream, metric_arg); } } - +}; // end namespace detail }; // end namespace distance }; // end namespace raft diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh similarity index 92% rename from cpp/include/raft/distance/pairwise_distance_base.cuh rename to cpp/include/raft/distance/detail/pairwise_distance_base.cuh index e3ff9a7081..a98bda1541 100644 --- a/cpp/include/raft/distance/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -24,6 +24,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief Device class for L1, L2 and cosine distance metrics. 
@@ -69,11 +70,11 @@ template -__global__ __launch_bounds__( - Policy::Nthreads, - 2) void pairwiseDistanceMatKernel(const DataT* x, const DataT* y, - const DataT* _xn, const DataT* _yn, IdxT m, - IdxT n, IdxT k, IdxT lda, IdxT ldb, - IdxT ldd, OutT* dOutput, CoreLambda core_op, - EpilogueLambda epilog_op, - FinalLambda fin_op) { +__global__ __launch_bounds__(Policy::Nthreads, 2) + + void pairwiseDistanceMatKernel(const DataT *x, const DataT *y, + const DataT *_xn, const DataT *_yn, IdxT m, + IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + OutT *dOutput, CoreLambda core_op, + EpilogueLambda epilog_op, FinalLambda fin_op) { extern __shared__ char smem[]; auto rowEpilog = [] __device__(IdxT starty) { return; }; @@ -337,5 +337,6 @@ dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) { return grid; } +}; // namespace detail }; // namespace distance }; // namespace raft diff --git a/cpp/include/raft/distance/russell_rao.cuh b/cpp/include/raft/distance/detail/russell_rao.cuh similarity index 98% rename from cpp/include/raft/distance/russell_rao.cuh rename to cpp/include/raft/distance/detail/russell_rao.cuh index 417fb73b94..8e4c4824c3 100644 --- a/cpp/include/raft/distance/russell_rao.cuh +++ b/cpp/include/raft/distance/detail/russell_rao.cuh @@ -15,10 +15,11 @@ */ #pragma once -#include +#include namespace raft { namespace distance { +namespace detail { /** * @brief the Russell Rao distance matrix: @@ -167,5 +168,6 @@ void russellRaoImpl(int m, int n, int k, const InType *pA, const InType *pB, n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/distance.hpp b/cpp/include/raft/distance/distance.hpp new file mode 100644 index 0000000000..84e8af261a --- /dev/null +++ b/cpp/include/raft/distance/distance.hpp @@ -0,0 +1,291 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace raft { +namespace distance { + +/** +* @brief Evaluate pairwise distances with the user epilogue lambda allowed +* @tparam DistanceType which distance to evaluate +* @tparam InType input argument type +* @tparam AccType accumulation type +* @tparam OutType output type +* @tparam FinalLambda user-defined epilogue lambda +* @tparam Index_ Index type +* @param x first set of points +* @param y second set of points +* @param dist output distance matrix +* @param m number of points in x +* @param n number of points in y +* @param k dimensionality +* @param workspace temporary workspace needed for computations +* @param worksize number of bytes of the workspace +* @param fin_op the final gemm epilogue lambda +* @param stream cuda stream +* @param isRowMajor whether the matrices are row-major or col-major +* +* @note fin_op: This is a device lambda which is supposed to operate upon the +* input which is AccType and returns the output in OutType. Its signature is +* as follows:
OutType fin_op(AccType in, int g_idx);
. If one needs +* any other parameters, feel free to pass them via closure. +*/ +template +void distance(const InType *x, const InType *y, OutType *dist, Index_ m, + Index_ n, Index_ k, void *workspace, size_t worksize, + FinalLambda fin_op, cudaStream_t stream, bool isRowMajor = true, + InType metric_arg = 2.0f) { + detail::distance( + x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, + metric_arg); +} + +/** +* @brief Evaluate pairwise distances for the simple use case +* @tparam DistanceType which distance to evaluate +* @tparam InType input argument type +* @tparam AccType accumulation type +* @tparam OutType output type +* @tparam Index_ Index type +* @param x first set of points +* @param y second set of points +* @param dist output distance matrix +* @param m number of points in x +* @param n number of points in y +* @param k dimensionality +* @param workspace temporary workspace needed for computations +* @param worksize number of bytes of the workspace +* @param stream cuda stream +* @param isRowMajor whether the matrices are row-major or col-major +* +* @note if workspace is passed as nullptr, this will return in +* worksize, the number of bytes of workspace required +*/ +template +void distance(const InType *x, const InType *y, OutType *dist, Index_ m, + Index_ n, Index_ k, void *workspace, size_t worksize, + cudaStream_t stream, bool isRowMajor = true, + InType metric_arg = 2.0f) { + detail::distance( + x, y, dist, m, n, k, workspace, worksize, stream, isRowMajor, metric_arg); +} + +/** +* @brief Return the exact workspace size to compute the distance +* @tparam DistanceType which distance to evaluate +* @tparam InType input argument type +* @tparam AccType accumulation type +* @tparam OutType output type +* @tparam Index_ Index type +* @param x first set of points +* @param y second set of points +* @param m number of points in x +* @param n number of points in y +* @param k dimensionality +* +* @note If the specified distanceType 
doesn't need the workspace at all, it +* returns 0. +*/ +template +size_t getWorkspaceSize(const InType *x, const InType *y, Index_ m, Index_ n, + Index_ k) { + return detail::getWorkspaceSize(x, y, m, n, k); +} + +/** +* @brief Evaluate pairwise distances for the simple use case +* @tparam DistanceType which distance to evaluate +* @tparam InType input argument type +* @tparam AccType accumulation type +* @tparam OutType output type +* @tparam Index_ Index type +* @param x first set of points +* @param y second set of points +* @param dist output distance matrix +* @param m number of points in x +* @param n number of points in y +* @param k dimensionality +* @param workspace temporary workspace needed for computations +* @param worksize number of bytes of the workspace +* @param stream cuda stream +* @param isRowMajor whether the matrices are row-major or col-major +* +* @note if workspace is passed as nullptr, this will return in +* worksize, the number of bytes of workspace required +*/ +template +void distance(const InType *x, const InType *y, OutType *dist, Index_ m, + Index_ n, Index_ k, cudaStream_t stream, bool isRowMajor = true, + InType metric_arg = 2.0f) { + rmm::device_uvector workspace(0, stream); + auto worksize = + getWorkspaceSize(x, y, m, n, + k); + workspace.resize(worksize, stream); + detail::distance( + x, y, dist, m, n, k, workspace.data(), worksize, stream, isRowMajor, + metric_arg); +} + +/** + * @defgroup pairwise_distance pairwise distance prims + * @{ + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam Index_ indexing type + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace buffer which can get resized as per 
the + * needed workspace size + * @param metric distance metric + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + */ +template +void pairwise_distance(const raft::handle_t &handle, const Type *x, + const Type *y, Type *dist, Index_ m, Index_ n, Index_ k, + rmm::device_uvector &workspace, + raft::distance::DistanceType metric, + bool isRowMajor = true, Type metric_arg = 2.0f) { + switch (metric) { + case raft::distance::DistanceType::L2Expanded: + detail::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::L2SqrtExpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::L2SqrtExpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::CosineExpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::CosineExpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::L1: + detail::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::L2Unexpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::L2Unexpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::L2SqrtUnexpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::L2SqrtUnexpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::Linf: + detail::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::HellingerExpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::HellingerExpanded>( + x, y, dist, m, 
n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::LpUnexpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::LpUnexpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, + metric_arg); + break; + case raft::distance::DistanceType::Canberra: + detail::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::HammingUnexpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::HammingUnexpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::JensenShannon: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::JensenShannon>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::RusselRaoExpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::RusselRaoExpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::KLDivergence: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::KLDivergence>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + case raft::distance::DistanceType::CorrelationExpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::CorrelationExpanded>( + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + }; +} +/** @} */ + +/** + * @defgroup pairwise_distance pairwise distance prims + * @{ + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam Index_ indexing type + * 
@param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param metric distance metric + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + */ +template +void pairwise_distance(const raft::handle_t &handle, const Type *x, + const Type *y, Type *dist, Index_ m, Index_ n, Index_ k, + raft::distance::DistanceType metric, + bool isRowMajor = true, Type metric_arg = 2.0f) { + rmm::device_uvector workspace(0, handle.get_stream()); + pairwise_distance(handle, x, y, dist, m, n, k, workspace, + metric, isRowMajor, metric_arg); +} + +}; // namespace distance +}; // namespace raft diff --git a/cpp/include/raft/distance/fused_l2_nn.hpp b/cpp/include/raft/distance/fused_l2_nn.hpp new file mode 100644 index 0000000000..df9974f602 --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn.hpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { + +template +using KVPMinReduce = detail::KVPMinReduceImpl; + +template +using MinAndDistanceReduceOp = + detail::MinAndDistanceReduceOpImpl; + +template +using MinReduceOp = detail::MinReduceOpImpl; + +/** + * Initialize array using init value from reduction op + */ +template +void initialize(const raft::handle_t& handle, OutT* min, IdxT m, DataT maxVal, + ReduceOpT redOp) { + detail::initialize(min, m, maxVal, redOp, + handle.get_stream()); +} + +/** + * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. + * + * The benefits of such a call are 2-fold: 1) eliminate the need for an + * intermediate buffer to store the output of gemm 2) reduce the memory read + * traffic on this intermediate buffer, otherwise needed during the reduction + * phase for 1-NN. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. 
(on device) + * @param[in] redOp reduction operator in the epilogue + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] stream cuda stream + */ +template +void fusedL2NN(OutT* min, const DataT* x, const DataT* y, const DataT* xn, + const DataT* yn, IdxT m, IdxT n, IdxT k, void* workspace, + ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, + bool initOutBuffer, cudaStream_t stream) { + size_t bytes = sizeof(DataT) * k; + if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) { + detail::fusedL2NNImpl( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, + initOutBuffer, stream); + } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0) { + detail::fusedL2NNImpl( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, + initOutBuffer, stream); + } else { + detail::fusedL2NNImpl( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, + initOutBuffer, stream); + } +} + +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/sparse/distance/bin_distance.cuh b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh similarity index 98% rename from cpp/include/raft/sparse/distance/bin_distance.cuh rename to cpp/include/raft/sparse/distance/detail/bin_distance.cuh index 6885c250c0..e6dd1331ae 100644 --- a/cpp/include/raft/sparse/distance/bin_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include @@ -32,7 +32,7 @@ namespace raft { namespace sparse { namespace distance { - +namespace detail { // @TODO: Move this into sparse prims (coo_norm) template __global__ void compute_binary_row_norm_kernel( @@ -193,6 +193,7 @@ class dice_expanded_distances_t : public distances_t { ip_distances_t ip_dists; }; +} // END namespace detail }; // END namespace distance }; // END namespace 
sparse }; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/coo_spmv.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh similarity index 98% rename from cpp/include/raft/sparse/distance/coo_spmv.cuh rename to cpp/include/raft/sparse/distance/detail/coo_spmv.cuh index 24be171900..83844b8c54 100644 --- a/cpp/include/raft/sparse/distance/coo_spmv.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh @@ -24,9 +24,9 @@ #include #include -#include "../csr.cuh" -#include "../utils.h" -#include "common.h" +#include "../../csr.cuh" +#include "../../utils.h" +#include "../common.h" #include @@ -37,6 +37,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { template class coo_spmv_strategy { @@ -84,6 +85,7 @@ class coo_spmv_strategy { const distances_config_t &config; }; +} // namespace detail } // namespace distance } // namespace sparse } // namespace raft diff --git a/cpp/include/raft/sparse/distance/coo_spmv_strategies/coo_mask_row_iterators.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh similarity index 99% rename from cpp/include/raft/sparse/distance/coo_spmv_strategies/coo_mask_row_iterators.cuh rename to cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh index 74eb37bc2b..0ab7b65ac2 100644 --- a/cpp/include/raft/sparse/distance/coo_spmv_strategies/coo_mask_row_iterators.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../common.h" +#include "../../common.h" #include "../utils.cuh" #include @@ -24,6 +24,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { template class mask_row_it { @@ -186,6 +187,7 @@ class chunked_mask_row_it : public mask_row_it { } }; +} // namespace detail } // namespace distance } // namespace sparse } // namespace raft \ No newline at end of file diff --git 
a/cpp/include/raft/sparse/distance/coo_spmv_strategies/dense_smem_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh similarity index 98% rename from cpp/include/raft/sparse/distance/coo_spmv_strategies/dense_smem_strategy.cuh rename to cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh index c463654a3b..79a5f154d0 100644 --- a/cpp/include/raft/sparse/distance/coo_spmv_strategies/dense_smem_strategy.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh @@ -21,6 +21,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { template class dense_smem_strategy : public coo_spmv_strategy { @@ -92,6 +93,7 @@ class dense_smem_strategy : public coo_spmv_strategy { } }; +} // namespace detail } // namespace distance } // namespace sparse } // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/distance/coo_spmv_strategies/hash_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh similarity index 99% rename from cpp/include/raft/sparse/distance/coo_spmv_strategies/hash_strategy.cuh rename to cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh index a95c6ff85b..5ba2d5c102 100644 --- a/cpp/include/raft/sparse/distance/coo_spmv_strategies/hash_strategy.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh @@ -33,6 +33,7 @@ CUCO_DECLARE_BITWISE_COMPARABLE(double); namespace raft { namespace sparse { namespace distance { +namespace detail { template class hash_strategy : public coo_spmv_strategy { @@ -217,6 +218,7 @@ class hash_strategy : public coo_spmv_strategy { int map_size; }; +} // namespace detail } // namespace distance } // namespace sparse } // namespace raft diff --git a/cpp/include/raft/sparse/distance/ip_distance.cuh b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh similarity index 78% 
rename from cpp/include/raft/sparse/distance/ip_distance.cuh rename to cpp/include/raft/sparse/distance/detail/ip_distance.cuh index b1e2756671..2cd7b670d8 100644 --- a/cpp/include/raft/sparse/distance/ip_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh @@ -27,8 +27,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -36,14 +36,15 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { template class ip_distances_t : public distances_t { public: /** - * Computes simple sparse inner product distances as sum(x_y * y_k) - * @param[in] config specifies inputs, outputs, and sizes - */ + * Computes simple sparse inner product distances as sum(x_y * y_k) + * @param[in] config specifies inputs, outputs, and sizes + */ ip_distances_t(const distances_config_t &config) : config_(&config), coo_rows_b(config.b_nnz, config.handle.get_stream()) { raft::sparse::convert::csr_to_coo(config_->b_indptr, config_->b_nrows, @@ -52,13 +53,13 @@ class ip_distances_t : public distances_t { } /** - * Performs pairwise distance computation and computes output distances - * @param out_distances dense output matrix (size a_nrows * b_nrows) - */ + * Performs pairwise distance computation and computes output distances + * @param out_distances dense output matrix (size a_nrows * b_nrows) + */ void compute(value_t *out_distances) { /** - * Compute pairwise distances and return dense matrix in row-major format - */ + * Compute pairwise distances and return dense matrix in row-major format + */ balanced_coo_pairwise_generalized_spmv( out_distances, *config_, coo_rows_b.data(), Product(), Sum(), AtomicAdd()); @@ -72,6 +73,8 @@ class ip_distances_t : public distances_t { const distances_config_t *config_; rmm::device_uvector coo_rows_b; }; + +}; // END namespace detail }; // END namespace distance }; // END namespace sparse }; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh 
b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh similarity index 99% rename from cpp/include/raft/sparse/distance/l2_distance.cuh rename to cpp/include/raft/sparse/distance/detail/l2_distance.cuh index 6ccfd4adcb..e7ac78b80a 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include @@ -34,6 +34,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { // @TODO: Move this into sparse prims (coo_norm) template @@ -417,6 +418,7 @@ class russelrao_expanded_distances_t : public distances_t { ip_distances_t ip_dists; }; +}; // END namespace detail }; // END namespace distance }; // END namespace sparse }; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh similarity index 98% rename from cpp/include/raft/sparse/distance/lp_distance.cuh rename to cpp/include/raft/sparse/distance/detail/lp_distance.cuh index 885d55ee50..c11369375b 100644 --- a/cpp/include/raft/sparse/distance/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -30,13 +30,14 @@ #include #include -#include +#include #include namespace raft { namespace sparse { namespace distance { +namespace detail { template @@ -272,6 +273,7 @@ class kl_divergence_unexpanded_distances_t : public distances_t { const distances_config_t *config_; }; +}; // END namespace detail }; // END namespace distance }; // END namespace sparse }; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/operators.cuh b/cpp/include/raft/sparse/distance/detail/operators.cuh similarity index 98% rename from cpp/include/raft/sparse/distance/operators.cuh rename to cpp/include/raft/sparse/distance/detail/operators.cuh index 89acda8b1a..9f206095bf 100644 --- a/cpp/include/raft/sparse/distance/operators.cuh +++ 
b/cpp/include/raft/sparse/distance/detail/operators.cuh @@ -21,6 +21,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { struct Sum { template @@ -90,6 +91,7 @@ struct AbsDiff { return fabs(a - b); } }; +} // namespace detail } // namespace distance } // namespace sparse }; // namespace raft diff --git a/cpp/include/raft/sparse/distance/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh similarity index 96% rename from cpp/include/raft/sparse/distance/utils.cuh rename to cpp/include/raft/sparse/distance/detail/utils.cuh index 3bee1bc87d..abfb7d24ea 100644 --- a/cpp/include/raft/sparse/distance/utils.cuh +++ b/cpp/include/raft/sparse/distance/detail/utils.cuh @@ -24,6 +24,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { /** * Computes the maximum number of columns that can be stored @@ -39,6 +40,7 @@ inline int max_cols_per_block() { sizeof(value_t); } +} // namespace detail } // namespace distance } // namespace sparse } // namespace raft diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.hpp similarity index 70% rename from cpp/include/raft/sparse/distance/distance.cuh rename to cpp/include/raft/sparse/distance/distance.hpp index 03df396b2e..24b10420f3 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.hpp @@ -31,10 +31,10 @@ #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include @@ -78,70 +78,77 @@ void pairwiseDistance(value_t *out, raft::distance::DistanceType metric, float metric_arg) { switch (metric) { case raft::distance::DistanceType::L2Expanded: - l2_expanded_distances_t(input_config).compute(out); + detail::l2_expanded_distances_t(input_config) + .compute(out); break; case raft::distance::DistanceType::L2SqrtExpanded: - l2_sqrt_expanded_distances_t(input_config) + detail::l2_sqrt_expanded_distances_t(input_config) 
.compute(out); break; case raft::distance::DistanceType::InnerProduct: - ip_distances_t(input_config).compute(out); + detail::ip_distances_t(input_config).compute(out); break; case raft::distance::DistanceType::L2Unexpanded: - l2_unexpanded_distances_t(input_config).compute(out); + detail::l2_unexpanded_distances_t(input_config) + .compute(out); break; case raft::distance::DistanceType::L2SqrtUnexpanded: - l2_sqrt_unexpanded_distances_t(input_config) + detail::l2_sqrt_unexpanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::L1: - l1_unexpanded_distances_t(input_config).compute(out); + detail::l1_unexpanded_distances_t(input_config) + .compute(out); break; case raft::distance::DistanceType::LpUnexpanded: - lp_unexpanded_distances_t(input_config, metric_arg) + detail::lp_unexpanded_distances_t(input_config, + metric_arg) .compute(out); break; case raft::distance::DistanceType::Linf: - linf_unexpanded_distances_t(input_config) + detail::linf_unexpanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::Canberra: - canberra_unexpanded_distances_t(input_config) + detail::canberra_unexpanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::JaccardExpanded: - jaccard_expanded_distances_t(input_config) + detail::jaccard_expanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::CosineExpanded: - cosine_expanded_distances_t(input_config) + detail::cosine_expanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::HellingerExpanded: - hellinger_expanded_distances_t(input_config) + detail::hellinger_expanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::DiceExpanded: - dice_expanded_distances_t(input_config).compute(out); + detail::dice_expanded_distances_t(input_config) + .compute(out); break; case raft::distance::DistanceType::CorrelationExpanded: - 
correlation_expanded_distances_t(input_config) + detail::correlation_expanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::RusselRaoExpanded: - russelrao_expanded_distances_t(input_config) + detail::russelrao_expanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::HammingUnexpanded: - hamming_unexpanded_distances_t(input_config) + detail::hamming_unexpanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::JensenShannon: - jensen_shannon_unexpanded_distances_t(input_config) + detail::jensen_shannon_unexpanded_distances_t( + input_config) .compute(out); break; case raft::distance::DistanceType::KLDivergence: - kl_divergence_unexpanded_distances_t(input_config) + detail::kl_divergence_unexpanded_distances_t( + input_config) .compute(out); break; diff --git a/cpp/include/raft/sparse/selection/connect_components.cuh b/cpp/include/raft/sparse/selection/connect_components.cuh index 46369ca964..5313b81192 100644 --- a/cpp/include/raft/sparse/selection/connect_components.cuh +++ b/cpp/include/raft/sparse/selection/connect_components.cuh @@ -16,7 +16,7 @@ #include -#include +#include #include #include #include diff --git a/cpp/include/raft/sparse/selection/knn.cuh b/cpp/include/raft/sparse/selection/knn.cuh index 49573a679d..fc1a7c0d8d 100644 --- a/cpp/include/raft/sparse/selection/knn.cuh +++ b/cpp/include/raft/sparse/selection/knn.cuh @@ -20,7 +20,6 @@ #include #include -#include #include #include #include @@ -30,16 +29,9 @@ #include #include #include -#include +#include #include -#include -#include -#include -#include - -#include - namespace raft { namespace sparse { namespace selection { @@ -412,7 +404,6 @@ class sparse_knn_t { * @param[out] output_indices dense matrix for output indices (size n_query_rows * k) * @param[out] output_dists dense matrix for output distances (size n_query_rows * k) * @param[in] k the number of neighbors to query - * @param[in] 
cusparseHandle the initialized cusparseHandle instance to use * @param[in] handle.get_stream() CUDA handle.get_stream() to order operations with respect to * @param[in] batch_size_index maximum number of rows to use from index matrix per batch * @param[in] batch_size_query maximum number of rows to use from query matrix per batch diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh index 0e91b5225d..980001f166 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh @@ -27,7 +27,7 @@ #include "processing.hpp" #include