From 5c2c9bcac145ed794d7fc78c374b6dc874b763b4 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 22 Jul 2021 17:54:52 +0530 Subject: [PATCH 01/18] add hamming dist metric support and test case for it --- cpp/include/raft/distance/distance.cuh | 18 +++ cpp/include/raft/distance/hamming.cuh | 171 +++++++++++++++++++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_hamming.cu | 69 ++++++++++ cpp/test/distance/distance_base.cuh | 24 ++++ 5 files changed, 283 insertions(+) create mode 100644 cpp/include/raft/distance/hamming.cuh create mode 100644 cpp/test/distance/dist_hamming.cu diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 1b39a6ec18..f46fa1bafc 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -171,6 +172,18 @@ struct DistanceImpl +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType metric_arg) { + raft::distance::hammingUnexpandedImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); + } +}; + } // anonymous namespace /** @@ -366,6 +379,11 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, raft::distance::DistanceType::Canberra>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; + case raft::distance::DistanceType::HammingUnexpanded: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; diff --git a/cpp/include/raft/distance/hamming.cuh b/cpp/include/raft/distance/hamming.cuh new file mode 100644 index 0000000000..4b29708db8 --- /dev/null +++ b/cpp/include/raft/distance/hamming.cuh @@ -0,0 +1,171 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft { +namespace distance { + +/** + * @brief the Hamming distance matrix using the unexpanded form: + * It computes the following equation: + Cij = sum(x_i != y_i) / k + * + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam Veclen number of k-elements loaded by each thread + for every LDG call. 
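 *        (e.g. Veclen = 4 with a 4-byte DataT makes each LDG one 16-byte
 *        vectorized load, matching the alignment dispatch below); further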
details in contractions.cuh + * @tparam FinalLambda final lambda called on final distance value + * @tparam isRowMajor true if input/output is row major, + false for column major + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of A and C/D + * @param[in] n number of rows of B and C/D + * @param[in] k number of cols of A and B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] dOutput output matrix + * @param[in] fin_op the final gemm epilogue lambda + * @param[in] stream cuda stream to launch work + */ +template +static void hammingUnexpandedImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, + IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef + typename std::conditional::type KPolicy; + + dim3 blk(KPolicy::Nthreads); + + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { + acc += (x != y); + }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [k] __device__( + AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { + const DataT one_over_k = DataT(1.0) / k; +#pragma unroll + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { + acc[i][j] *= one_over_k; + } + } + }; + + if (isRowMajor) { + auto hammingUnexpandedRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + hammingUnexpandedRowMajor); + + hammingUnexpandedRowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } else { + auto hammingUnexpandedColMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + hammingUnexpandedColMajor); + hammingUnexpandedColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } + + CUDA_CHECK(cudaGetLastError()); +} + +template +void hammingUnexpanded(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + hammingUnexpandedImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + hammingUnexpandedImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else { + hammingUnexpandedImpl( + x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); + } +} + +/** + * @brief the Hamming Unexpanded distance matrix calculation + * It computes the following equation: + Cij = sum(x_i != y_i) / k + * + * @tparam InType input data-type (for A and B matrices) + * @tparam AccType accumulation data-type + * @tparam OutType output data-type (for C and D matrices) + * @tparam FinalLambda user-defined epilogue lamba + * @tparam Index_ Index type + * @param m number of rows of A and C/D + * @param n number of columns of B and C/D + * @param k number of cols of A and rows of B + * @param pA input matrix + * @param pB input matrix + * @param pD output 
matrix + * @param fin_op the final element-wise epilogue lambda + * @param stream cuda stream where to launch work + * @param isRowMajor whether the input and output matrices are row major + */ +template +void hammingUnexpandedImpl(int m, int n, int k, const InType *pA, const InType *pB, + OutType *pD, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor) { + typedef std::is_same is_bool; + typedef typename std::conditional::type + hammingUnexpandedOutType; + Index_ lda, ldb, ldd; + hammingUnexpandedOutType *pDcast = reinterpret_cast(pD); + if (isRowMajor) { + lda = k, ldb = k, ldd = n; + hammingUnexpanded( + m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + + } else { + lda = n, ldb = m, ldd = m; + hammingUnexpanded( + n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); + } +} +} // namespace distance +} // namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index f94a8d9525..69c23c1cd4 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -24,6 +24,7 @@ add_executable(test_raft test/distance/dist_cos.cu test/distance/dist_euc_exp.cu test/distance/dist_euc_unexp.cu + test/distance/dist_hamming.cu test/distance/dist_hellinger.cu test/distance/dist_l1.cu test/distance/dist_minkowski.cu diff --git a/cpp/test/distance/dist_hamming.cu b/cpp/test/distance/dist_hamming.cu new file mode 100644 index 0000000000..47febd825b --- /dev/null +++ b/cpp/test/distance/dist_hamming.cu @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include "distance_base.cuh" + +namespace raft { +namespace distance { + +template +class DistanceHamming + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceHamming DistanceHammingF; +TEST_P(DistanceHammingF, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingF, + ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceHamming DistanceHammingD; +TEST_P(DistanceHammingD, Result) { + int m = params.isRowMajor ? 
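  // For reference, a minimal host-side sketch of the unexpanded Hamming
  // distance these tests validate (a hypothetical helper, not the library
  // API), for row-major inputs:
  //
  //   for (int i = 0; i < m; ++i)
  //     for (int j = 0; j < n; ++j) {
  //       float acc = 0.f;
  //       for (int l = 0; l < k; ++l)
  //         acc += (x[i * k + l] != y[j * k + l]);
  //       dist[i * n + j] = acc / k;  // Cij = sum(x_i != y_i) / k
  //     }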
params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingD, + ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index fc7b064205..d653ba0e53 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -160,6 +160,26 @@ __global__ void naiveLpUnexpDistanceKernel(DataType *dist, const DataType *x, dist[outidx] = acc; } +template +__global__ void naiveHammingDistanceKernel(DataType *dist, const DataType *x, + const DataType *y, int m, int n, + int k, bool isRowMajor) { + int midx = threadIdx.x + blockIdx.x * blockDim.x; + int nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + DataType acc = DataType(0); + for (int i = 0; i < k; ++i) { + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + acc += (a != b); + } + acc = acc / k; + int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + template void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, int n, int k, raft::distance::DistanceType type, @@ -193,6 +213,10 @@ void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, naiveLpUnexpDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor, metric_arg); break; + case raft::distance::DistanceType::HammingUnexpanded: + naiveHammingDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; default: FAIL() << "should be here\n"; } From f2f26bf046b2bcf79c32863cc1d9c41d3d307f6f Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 28 Jul 2021 16:51:28 +0530 Subject: [PATCH 02/18] add jensen shannon distance metric support --- cpp/include/raft/distance/distance.cuh | 18 ++ cpp/include/raft/distance/jensen_shannon.cuh | 191 +++++++++++++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_jensen_shannon.cu | 69 +++++++ cpp/test/distance/distance_base.cuh | 39 +++- 5 files changed, 317 insertions(+), 1 deletion(-) create mode 100644 cpp/include/raft/distance/jensen_shannon.cuh create mode 100644 cpp/test/distance/dist_jensen_shannon.cu diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index f46fa1bafc..6fe756f4eb 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include #include #include @@ -184,6 +185,18 @@ struct DistanceImpl +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType metric_arg) { + raft::distance::jensenShannonImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); + } +}; + } // anonymous namespace /** @@ -384,6 +397,11 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, raft::distance::DistanceType::HammingUnexpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; + case raft::distance::DistanceType::JensenShannon: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; diff --git 
a/cpp/include/raft/distance/jensen_shannon.cuh b/cpp/include/raft/distance/jensen_shannon.cuh new file mode 100644 index 0000000000..8217ca7a39 --- /dev/null +++ b/cpp/include/raft/distance/jensen_shannon.cuh @@ -0,0 +1,191 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft { +namespace distance { + +template +DI T fastLog(T x); +template <> +DI float fastLog(float x) { + return __logf(x); +} +template <> +DI double fastLog(double x) { + return log(x); +} + +/** + * @brief the Jensen Shannon distance matrix: + * It computes the following equation: + Cij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) + + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) + * + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam Veclen number of k-elements loaded by each thread + for every LDG call. details in contractions.cuh + * @tparam FinalLambda final lambda called on final distance value + * @tparam isRowMajor true if input/output is row major, + false for column major + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of A and C/D + * @param[in] n number of rows of B and C/D + * @param[in] k number of cols of A and B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] dOutput output matrix + * @param[in] fin_op the final gemm epilogue lambda + * @param[in] stream cuda stream to launch work + */ +template +static void jensenShannonImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, + IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef + typename std::conditional::type KPolicy; + + dim3 blk(KPolicy::Nthreads); + + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { + const DataT m = 0.5f * (x + y); + + const bool m_zero = (m == 0); + const auto logM = (!m_zero) * fastLog(m + m_zero); + const bool x_zero = x == 0; + const bool y_zero = y == 0; + + acc += (-x * (logM - (!x_zero) * fastLog(x_zero + x))) + + (-y * (logM - (!y_zero) * fastLog(y_zero + y))); + }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [] __device__( + AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { +#pragma unroll + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { + acc[i][j] = raft::mySqrt(0.5 * acc[i][j]); + } + } + }; + + if (isRowMajor) { + auto jensenShannonRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = 
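  // Note: per element the accumulator above adds
  //   -x * (log(m) - log(x)) - y * (log(m) - log(y)),  m = 0.5 * (x + y),
  // with the m_zero/x_zero/y_zero guards only keeping log() away from 0; the
  // epilogue then applies sqrt(0.5 * acc). A scalar host-side illustration of
  // one such term (a sketch, not library code):
  //
  //   float js_term(float x, float y) {
  //     const float mid = 0.5f * (x + y);
  //     const float logM = (mid > 0.f) ? logf(mid) : 0.f;
  //     float t = 0.f;
  //     if (x > 0.f) t += -x * (logM - logf(x));
  //     if (y > 0.f) t += -y * (logM - logf(y));
  //     return t;  // accumulate over k, then take sqrt(0.5 * sum)
  //   }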
launchConfigGenerator(m, n, KPolicy::SmemSize, + jensenShannonRowMajor); + + jensenShannonRowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } else { + auto jensenShannonColMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + jensenShannonColMajor); + jensenShannonColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } + + CUDA_CHECK(cudaGetLastError()); +} + +template +void jensenShannon(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + jensenShannonImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + jensenShannonImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else { + jensenShannonImpl( + x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); + } +} + +/** + * @brief the Jensen Shannon distance matrix calculation + * It computes the following equation: + Cij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) + + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) + * + * @tparam InType input data-type (for A and B matrices) + * @tparam AccType accumulation data-type + * @tparam OutType output data-type (for C and D matrices) + * @tparam FinalLambda user-defined epilogue lamba + * @tparam Index_ Index type + * @param m number of rows of A and C/D + * @param n number of columns of B and C/D + * @param k number of cols of A and rows of B + * @param pA input matrix + * @param pB input matrix + * @param pD output matrix + * @param fin_op the final element-wise epilogue lambda + * @param stream cuda stream where to launch work + * @param isRowMajor whether the input and output matrices are row major + */ +template +void jensenShannonImpl(int m, int n, int k, const InType *pA, const InType *pB, + OutType *pD, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor) { + typedef std::is_same is_bool; + typedef typename std::conditional::type + jensenShannonOutType; + Index_ lda, ldb, ldd; + jensenShannonOutType *pDcast = reinterpret_cast(pD); + if (isRowMajor) { + lda = k, ldb = k, ldd = n; + jensenShannon( + m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + + } else { + lda = n, ldb = m, ldd = m; + jensenShannon( + n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); + } +} +} // namespace distance +} // namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 69c23c1cd4..67ab8c2bbb 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -26,6 +26,7 @@ add_executable(test_raft test/distance/dist_euc_unexp.cu test/distance/dist_hamming.cu test/distance/dist_hellinger.cu + test/distance/dist_jensen_shannon.cu test/distance/dist_l1.cu test/distance/dist_minkowski.cu test/distance/fused_l2_nn.cu diff --git a/cpp/test/distance/dist_jensen_shannon.cu b/cpp/test/distance/dist_jensen_shannon.cu new file mode 100644 index 0000000000..a6fa954042 --- /dev/null +++ b/cpp/test/distance/dist_jensen_shannon.cu @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include "distance_base.cuh" + +namespace raft { +namespace distance { + +template +class DistanceJensenShannon + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceJensenShannon DistanceJensenShannonF; +TEST_P(DistanceJensenShannonF, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonF, + ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceJensenShannon DistanceJensenShannonD; +TEST_P(DistanceJensenShannonD, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonD, + ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index d653ba0e53..c731d353d8 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -180,6 +180,38 @@ __global__ void naiveHammingDistanceKernel(DataType *dist, const DataType *x, dist[outidx] = acc; } +template +__global__ void naiveJensenShannonDistanceKernel(DataType *dist, + const DataType *x, const DataType *y, + int m, int n, int k, bool isRowMajor) { + int midx = threadIdx.x + blockIdx.x * blockDim.x; + int nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + DataType acc = DataType(0); + for (int i = 0; i < k; ++i) { + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + + DataType m = 0.5f * (a + b); + bool a_zero = a == 0; + bool b_zero = b == 0; + + DataType p = (!a_zero * m) / (a_zero + a); + DataType q = (!b_zero * m) / (b_zero + b); + + bool p_zero = p == 0; + bool q_zero = q == 0; + + acc += (-a * (!p_zero * log(p + p_zero))) + + (-b * (!q_zero * log(q + q_zero))); + } + acc = raft::mySqrt(0.5f * acc); + int outidx = isRowMajor ? 
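  // Layout note: row-major stores element i of point midx at x[midx * k + i],
  // so output (midx, nidx) lands at midx * n + nidx; column-major stores the
  // same output element at midx + m * nidx, as computed below.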
midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + template void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, int n, int k, raft::distance::DistanceType type, @@ -217,6 +249,10 @@ void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, naiveHammingDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; + case raft::distance::DistanceType::JensenShannon: + naiveJensenShannonDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; default: FAIL() << "should be here\n"; } @@ -271,7 +307,8 @@ class DistanceTest : public ::testing::TestWithParam> { raft::allocate(dist_ref, m * n); raft::allocate(dist, m * n); raft::allocate(dist2, m * n); - if (distanceType == raft::distance::DistanceType::HellingerExpanded) { + if (distanceType == raft::distance::DistanceType::HellingerExpanded || + distanceType == raft::distance::DistanceType::JensenShannon) { // Hellinger works only on positive numbers r.uniform(x, m * k, DataType(0.0), DataType(1.0), stream); r.uniform(y, n * k, DataType(0.0), DataType(1.0), stream); From a4bcb45ddffbeead3be6abccb200b40ed04484c9 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 28 Jul 2021 20:14:50 +0530 Subject: [PATCH 03/18] remove x_zero & y_zero multiplying factor to log() as x & y are already a multiplying factor so inferring 0s to avoid log is already in place --- cpp/include/raft/distance/jensen_shannon.cuh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/distance/jensen_shannon.cuh b/cpp/include/raft/distance/jensen_shannon.cuh index 8217ca7a39..62a61b4c9c 100644 --- a/cpp/include/raft/distance/jensen_shannon.cuh +++ b/cpp/include/raft/distance/jensen_shannon.cuh @@ -80,8 +80,8 @@ static void jensenShannonImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, const bool x_zero = x == 0; const bool y_zero = y == 0; - acc += (-x * (logM - (!x_zero) * fastLog(x_zero + x))) + - (-y * (logM - (!y_zero) * fastLog(y_zero + y))); + acc += (-x * (logM - fastLog(x_zero + x))) + + (-y * (logM - fastLog(y_zero + y))); }; // epilogue operation lambda for final value calculation From 99cf19a1504a0f96218d56a21090c00bf97b8450 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 28 Jul 2021 21:39:55 +0530 Subject: [PATCH 04/18] add russell rao distance metric support --- cpp/include/raft/distance/distance.cuh | 18 +++ cpp/include/raft/distance/russell_rao.cuh | 174 ++++++++++++++++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_russell_rao.cu | 69 +++++++++ cpp/test/distance/distance_base.cuh | 30 ++++ 5 files changed, 292 insertions(+) create mode 100644 cpp/include/raft/distance/russell_rao.cuh create mode 100644 cpp/test/distance/dist_russell_rao.cu diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 6fe756f4eb..d743390495 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -28,6 +28,7 @@ #include #include #include +#include #include namespace raft { @@ -197,6 +198,18 @@ struct DistanceImpl +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType metric_arg) { + raft::distance::russellRaoImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); + } +}; + } // anonymous namespace /** @@ -402,6 +415,11 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, 
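For context, a hedged caller-side sketch of reaching this switch for one of the
newly added metrics (the workspace argument and exact parameter order are
assumed from the surrounding hunks, not spelled out in this diff):

    // sketch: x, y, dist are device pointers; the unexpanded metrics added
    // here do not consume the workspace
    raft::distance::pairwise_distance(
      x, y, dist, m, n, k, workspace,
      raft::distance::DistanceType::JensenShannon, stream,
      /*isRowMajor=*/true);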
raft::distance::DistanceType::JensenShannon>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; + case raft::distance::DistanceType::RusselRaoExpanded: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; diff --git a/cpp/include/raft/distance/russell_rao.cuh b/cpp/include/raft/distance/russell_rao.cuh new file mode 100644 index 0000000000..ab3ed8cb9a --- /dev/null +++ b/cpp/include/raft/distance/russell_rao.cuh @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include + +namespace raft { +namespace distance { + + +/** + * @brief the Russell Rao distance matrix: + * It computes the following equation: + Cij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) + + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) + * + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam Veclen number of k-elements loaded by each thread + for every LDG call. details in contractions.cuh + * @tparam FinalLambda final lambda called on final distance value + * @tparam isRowMajor true if input/output is row major, + false for column major + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of A and C/D + * @param[in] n number of rows of B and C/D + * @param[in] k number of cols of A and B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] dOutput output matrix + * @param[in] fin_op the final gemm epilogue lambda + * @param[in] stream cuda stream to launch work + */ +template +static void russellRaoImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, + IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef + typename std::conditional::type KPolicy; + + dim3 blk(KPolicy::Nthreads); + + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { + acc += x * y; + }; + + const float one_over_k = 1.0 / k; + // epilogue operation lambda for final value calculation + auto epilog_lambda = [k, one_over_k] __device__( + AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { +#pragma unroll + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { + acc[i][j] = (k - acc[i][j]) * one_over_k; + } + } + }; + + if (isRowMajor) { + constexpr auto russellRaoRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + 
russellRaoRowMajor); + + russellRaoRowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } else { + constexpr auto russellRaoColMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + russellRaoColMajor); + russellRaoColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } + + CUDA_CHECK(cudaGetLastError()); +} + +template +void russellRao(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + russellRaoImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + russellRaoImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else { + russellRaoImpl( + x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); + } +} + +/** + * @brief the Russell Rao distance matrix calculation + * It computes the following equation: + Cij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) + + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) + * + * @tparam InType input data-type (for A and B matrices) + * @tparam AccType accumulation data-type + * @tparam OutType output data-type (for C and D matrices) + * @tparam FinalLambda user-defined epilogue lamba + * @tparam Index_ Index type + * @param m number of rows of A and C/D + * @param n number of columns of B and C/D + * @param k number of cols of A and rows of B + * @param pA input matrix + * @param pB input matrix + * @param pD output matrix + * @param fin_op the final element-wise epilogue lambda + * @param stream cuda stream where to launch work + * @param isRowMajor whether the input and output matrices are row major + */ +template +void russellRaoImpl(int m, int n, int k, const InType *pA, const InType *pB, + OutType *pD, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor) { + typedef std::is_same is_bool; + typedef typename std::conditional::type + russellRaoOutType; + Index_ lda, ldb, ldd; + russellRaoOutType *pDcast = reinterpret_cast(pD); + if (isRowMajor) { + lda = k, ldb = k, ldd = n; + russellRao( + m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + + } else { + lda = n, ldb = m, ldd = m; + russellRao( + n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); + } +} +} // namespace distance +} // namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 67ab8c2bbb..3507f1d301 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -29,6 +29,7 @@ add_executable(test_raft test/distance/dist_jensen_shannon.cu test/distance/dist_l1.cu test/distance/dist_minkowski.cu + test/distance/dist_russell_rao.cu test/distance/fused_l2_nn.cu test/eigen_solvers.cu test/handle.cpp diff --git a/cpp/test/distance/dist_russell_rao.cu b/cpp/test/distance/dist_russell_rao.cu new file mode 100644 index 0000000000..74ccfb0c2e --- /dev/null +++ b/cpp/test/distance/dist_russell_rao.cu @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include "distance_base.cuh" + +namespace raft { +namespace distance { + +template +class DistanceRussellRao + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceRussellRao DistanceRussellRaoF; +TEST_P(DistanceRussellRaoF, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoF, + ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceRussellRao DistanceRussellRaoD; +TEST_P(DistanceRussellRaoD, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoD, + ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index c731d353d8..008feb0683 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -212,6 +212,26 @@ __global__ void naiveJensenShannonDistanceKernel(DataType *dist, dist[outidx] = acc; } +template +__global__ void naiveRussellRaoDistanceKernel(OutType *dist, const DataType *x, + const DataType *y, int m, int n, + int k, bool isRowMajor) { + int midx = threadIdx.x + blockIdx.x * blockDim.x; + int nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + OutType acc = OutType(0); + for (int i = 0; i < k; ++i) { + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + acc += (a * b); + } + acc = (k - acc) / k; + int outidx = isRowMajor ? 
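  // Russell-Rao semantics: with boolean inputs, acc = sum(x_i * y_i) counts
  // the attributes where both vectors are 1, so (k - acc) / k is the fraction
  // of attributes that are not a shared positive match (0 for identical
  // all-ones vectors, 1 when no attribute is shared).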
midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + template void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, int n, int k, raft::distance::DistanceType type, @@ -253,6 +273,10 @@ void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, naiveJensenShannonDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; + case raft::distance::DistanceType::RusselRaoExpanded: + naiveRussellRaoDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; default: FAIL() << "should be here\n"; } @@ -312,6 +336,12 @@ class DistanceTest : public ::testing::TestWithParam> { // Hellinger works only on positive numbers r.uniform(x, m * k, DataType(0.0), DataType(1.0), stream); r.uniform(y, n * k, DataType(0.0), DataType(1.0), stream); + } else if (distanceType == raft::distance::DistanceType::RusselRaoExpanded) { + r.uniform(x, m * k, DataType(0.0), DataType(1.0), stream); + r.uniform(y, n * k, DataType(0.0), DataType(1.0), stream); + // Russel rao works on boolean values. + r.bernoulli(x, m * k, 0.5f, stream); + r.bernoulli(y, n * k, 0.5f, stream); } else { r.uniform(x, m * k, DataType(-1.0), DataType(1.0), stream); r.uniform(y, n * k, DataType(-1.0), DataType(1.0), stream); From 90775365cfbe22fdd9612e84c9d39d47c28c70f8 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 28 Jul 2021 23:19:28 +0530 Subject: [PATCH 05/18] remove addition of m_zero, x_zero, y_zero in the log() as it is not needed as we are multiplying with a 0 in case any of the inputs to log() is 0 so there is no resultant inf. this improves perf further now 3.7x compared to naive kernel version --- cpp/include/raft/distance/jensen_shannon.cuh | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/distance/jensen_shannon.cuh b/cpp/include/raft/distance/jensen_shannon.cuh index 62a61b4c9c..cf04397db7 100644 --- a/cpp/include/raft/distance/jensen_shannon.cuh +++ b/cpp/include/raft/distance/jensen_shannon.cuh @@ -74,14 +74,10 @@ static void jensenShannonImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { const DataT m = 0.5f * (x + y); - const bool m_zero = (m == 0); - const auto logM = (!m_zero) * fastLog(m + m_zero); - const bool x_zero = x == 0; - const bool y_zero = y == 0; + const auto logM = (!m_zero) * fastLog(m); - acc += (-x * (logM - fastLog(x_zero + x))) + - (-y * (logM - fastLog(y_zero + y))); + acc += (-x * (logM - fastLog(x))) + (-y * (logM - fastLog(y))); }; // epilogue operation lambda for final value calculation From 862db66fa330083c6a76e54ffc514598da47b663 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 29 Jul 2021 13:38:57 +0530 Subject: [PATCH 06/18] correct equation description of russell rao --- cpp/include/raft/distance/russell_rao.cuh | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cpp/include/raft/distance/russell_rao.cuh b/cpp/include/raft/distance/russell_rao.cuh index ab3ed8cb9a..6c6c60fde9 100644 --- a/cpp/include/raft/distance/russell_rao.cuh +++ b/cpp/include/raft/distance/russell_rao.cuh @@ -24,8 +24,7 @@ namespace distance { /** * @brief the Russell Rao distance matrix: * It computes the following equation: - Cij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) - + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) + Cij = (k - sum(x_i * y_i)) / k * * @tparam DataT input data-type (for A and B matrices) * @tparam AccT accumulation data-type @@ -131,8 +130,7 @@ 
void russellRao(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, /** * @brief the Russell Rao distance matrix calculation * It computes the following equation: - Cij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) - + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) + Cij = (k - sum(x_i * y_i)) / k * * @tparam InType input data-type (for A and B matrices) * @tparam AccType accumulation data-type From 44046ee940202f3338c97c1f3dc404479ed9601f Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 29 Jul 2021 17:36:30 +0530 Subject: [PATCH 07/18] add kl-divergence distance metric support --- cpp/include/raft/distance/distance.cuh | 19 +++ cpp/include/raft/distance/kl_divergence.cuh | 178 ++++++++++++++++++++ cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_kl_divergence.cu | 69 ++++++++ cpp/test/distance/distance_base.cuh | 30 +++- 5 files changed, 296 insertions(+), 1 deletion(-) create mode 100644 cpp/include/raft/distance/kl_divergence.cuh create mode 100644 cpp/test/distance/dist_kl_divergence.cu diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index d743390495..4a6cfacea7 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -210,6 +211,19 @@ struct DistanceImpl +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType metric_arg) { + raft::distance::klDivergenceImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); + } +}; + + } // anonymous namespace /** @@ -420,6 +434,11 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, raft::distance::DistanceType::RusselRaoExpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; + case raft::distance::DistanceType::KLDivergence: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; diff --git a/cpp/include/raft/distance/kl_divergence.cuh b/cpp/include/raft/distance/kl_divergence.cuh new file mode 100644 index 0000000000..85a273e456 --- /dev/null +++ b/cpp/include/raft/distance/kl_divergence.cuh @@ -0,0 +1,178 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +namespace raft { +namespace distance { + + +/** + * @brief the KL Divergence distance matrix: + * It computes the following equation: + Cij = 0.5 * sum(x * log (x / y)); + * + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam Veclen number of k-elements loaded by each thread + for every LDG call. 
details in contractions.cuh + * @tparam FinalLambda final lambda called on final distance value + * @tparam isRowMajor true if input/output is row major, + false for column major + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of A and C/D + * @param[in] n number of rows of B and C/D + * @param[in] k number of cols of A and B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] dOutput output matrix + * @param[in] fin_op the final gemm epilogue lambda + * @param[in] stream cuda stream to launch work + */ +template +static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, + IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef + typename std::conditional::type KPolicy; + + dim3 blk(KPolicy::Nthreads); + + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { + if (isRowMajor) { + const bool y_zero = (y == 0); + acc += x * (fastLog(x) - (!y_zero) * fastLog(y)); + } else { + const bool x_zero = (x == 0); + acc += y * (fastLog(y) - (!x_zero) * fastLog(x)); + } + }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [] __device__( + AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { +#pragma unroll + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { + acc[i][j] = (0.5f * acc[i][j]); + } + } + }; + + if (isRowMajor) { + constexpr auto klDivergenceRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + klDivergenceRowMajor); + + klDivergenceRowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } else { + constexpr auto klDivergenceColMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + klDivergenceColMajor); + klDivergenceColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } + + CUDA_CHECK(cudaGetLastError()); +} + +template +void klDivergence(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + klDivergenceImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + klDivergenceImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); + } else { + klDivergenceImpl( + x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); + } +} + +/** + * @brief the KL Divergence distance matrix calculation + * It computes the following equation: + Cij = 0.5 * sum(x * log (x / y)); + * + * @tparam InType input data-type (for A and B matrices) + * @tparam AccType accumulation data-type + * @tparam OutType output data-type (for C and D matrices) + * @tparam FinalLambda user-defined epilogue lamba + * @tparam Index_ Index type + * @param m number of rows of A and C/D + * @param n number of columns of B and C/D + * 
@param k number of cols of A and rows of B + * @param pA input matrix + * @param pB input matrix + * @param pD output matrix + * @param fin_op the final element-wise epilogue lambda + * @param stream cuda stream where to launch work + * @param isRowMajor whether the input and output matrices are row major + */ +template +void klDivergenceImpl(int m, int n, int k, const InType *pA, const InType *pB, + OutType *pD, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor) { + typedef std::is_same is_bool; + typedef typename std::conditional::type + klDivergenceOutType; + Index_ lda, ldb, ldd; + klDivergenceOutType *pDcast = reinterpret_cast(pD); + if (isRowMajor) { + lda = k, ldb = k, ldd = n; + klDivergence( + m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + + } else { + lda = n, ldb = m, ldd = m; + klDivergence( + n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); + } +} +} // namespace distance +} // namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 3507f1d301..204f40ca22 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -27,6 +27,7 @@ add_executable(test_raft test/distance/dist_hamming.cu test/distance/dist_hellinger.cu test/distance/dist_jensen_shannon.cu + test/distance/dist_kl_divergence.cu test/distance/dist_l1.cu test/distance/dist_minkowski.cu test/distance/dist_russell_rao.cu diff --git a/cpp/test/distance/dist_kl_divergence.cu b/cpp/test/distance/dist_kl_divergence.cu new file mode 100644 index 0000000000..d5182e90ba --- /dev/null +++ b/cpp/test/distance/dist_kl_divergence.cu @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include "distance_base.cuh" + +namespace raft { +namespace distance { + +template +class DistanceKLDivergence + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceKLDivergence DistanceKLDivergenceF; +TEST_P(DistanceKLDivergenceF, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? 
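  // The m/n swap in these two lines mirrors the implementation: for
  // column-major inputs klDivergenceImpl computes the transposed problem
  // (it invokes the kernel with pA/pB and m/n exchanged), so the expected
  // output extents swap here as well.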
params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceF, + ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceKLDivergence DistanceKLDivergenceD; +TEST_P(DistanceKLDivergenceD, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceD, + ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 008feb0683..9d458d9bf2 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -232,6 +232,29 @@ __global__ void naiveRussellRaoDistanceKernel(OutType *dist, const DataType *x, dist[outidx] = acc; } +template +__global__ void naiveKLDivergenceDistanceKernel(OutType *dist, const DataType *x, + const DataType *y, int m, int n, + int k, bool isRowMajor) { + int midx = threadIdx.x + blockIdx.x * blockDim.x; + int nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + OutType acc = OutType(0); + for (int i = 0; i < k; ++i) { + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + bool b_zero = (b == 0); + const auto m = (!b_zero) * (a / b); + const bool m_zero = (m == 0); + acc += (a * (!m_zero) * log(m + m_zero)); + } + acc = 0.5f * acc; + int outidx = isRowMajor ? 
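  // Guarded log above: b_zero zeroes the ratio and m_zero shifts the log
  // argument to 1 (log(1) = 0), so zero terms contribute nothing instead of
  // -inf. Note that (!b_zero) * (a / b) still evaluates a / b, so the guard
  // is only airtight for b > 0, which the harness arranges by drawing
  // KLDivergence inputs from uniform(0, 1).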
midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + template void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, int n, int k, raft::distance::DistanceType type, @@ -277,6 +300,10 @@ void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, naiveRussellRaoDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; + case raft::distance::DistanceType::KLDivergence: + naiveKLDivergenceDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; default: FAIL() << "should be here\n"; } @@ -332,7 +359,8 @@ class DistanceTest : public ::testing::TestWithParam> { raft::allocate(dist, m * n); raft::allocate(dist2, m * n); if (distanceType == raft::distance::DistanceType::HellingerExpanded || - distanceType == raft::distance::DistanceType::JensenShannon) { + distanceType == raft::distance::DistanceType::JensenShannon || + distanceType == raft::distance::DistanceType::KLDivergence) { // Hellinger works only on positive numbers r.uniform(x, m * k, DataType(0.0), DataType(1.0), stream); r.uniform(y, n * k, DataType(0.0), DataType(1.0), stream); From 59207b907cd663b098850a9b1dcedeb152f19590 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 30 Jul 2021 20:53:36 +0530 Subject: [PATCH 08/18] Add correlation distance metric support --- cpp/include/raft/distance/correlation.cuh | 243 ++++++++++++++++++++++ cpp/include/raft/distance/distance.cuh | 29 ++- cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_correlation.cu | 69 ++++++ cpp/test/distance/distance_base.cuh | 40 ++++ 5 files changed, 379 insertions(+), 3 deletions(-) create mode 100644 cpp/include/raft/distance/correlation.cuh create mode 100644 cpp/test/distance/dist_correlation.cu diff --git a/cpp/include/raft/distance/correlation.cuh b/cpp/include/raft/distance/correlation.cuh new file mode 100644 index 0000000000..5f993bc894 --- /dev/null +++ b/cpp/include/raft/distance/correlation.cuh @@ -0,0 +1,243 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include +#include + +namespace raft { +namespace distance { + + +/** + * @brief the Correlation distance matrix: + * + * @tparam DataT input data-type (for A and B matrices) + * @tparam AccT accumulation data-type + * @tparam OutT output data-type (for C and D matrices) + * @tparam IdxT index data-type + * @tparam Veclen number of k-elements loaded by each thread + for every LDG call. 
details in contractions.cuh + * @tparam FinalLambda final lambda called on final distance value + * @tparam isRowMajor true if input/output is row major, + false for column major + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of A and C/D + * @param[in] n number of rows of B and C/D + * @param[in] k number of cols of A and B + * @param[in] lda leading dimension of A + * @param[in] ldb leading dimension of B + * @param[in] ldd leading dimension of C/D + * @param[output] dOutput output matrix + * @param[in] fin_op the final gemm epilogue lambda + * @param[in] stream cuda stream to launch work + */ +template +static void correlationImpl(const DataT *x, const DataT *y, const DataT *xn, + const DataT *yn, const DataT *x2n, const DataT *y2n, + IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + OutT *dOutput, FinalLambda fin_op, + cudaStream_t stream) { + typedef typename raft::linalg::Policy4x4::Policy RowPolicy; + typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; + + typedef + typename std::conditional::type KPolicy; + + dim3 blk(KPolicy::Nthreads); + + // Accumulation operation lambda + auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { + acc += x * y; + }; + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [x2n, y2n, m, n, k] __device__( + AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { + DataT regx2n[KPolicy::AccRowsPerTh], regy2n[KPolicy::AccColsPerTh]; + + extern __shared__ char smem[]; + DataT* sx2Norm = (DataT*)(&smem[KPolicy::SmemSize + + (KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)]); + DataT* sy2Norm = (&sx2Norm[KPolicy::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + if (gridStrideX == blockIdx.x * KPolicy::Nblk) { + for (int i = threadIdx.x; i < KPolicy::Mblk; i += KPolicy::Nthreads) { + auto idx = gridStrideY + i; + sx2Norm[i] = idx < m ? x2n[idx] : 0; + } + } + + for (int i = threadIdx.x; i < KPolicy::Nblk; i += KPolicy::Nthreads) { + auto idx = gridStrideX + i; + sy2Norm[i] = idx < n ? 
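  // Assuming xn/yn carry per-row sums and x2n/y2n per-row sums of squares
  // (consistent with the 2 * (m + n) workspace sizing below), the loop that
  // follows evaluates the sample-correlation distance
  //   C_ij = 1 - (k * dot(x_i, y_j) - sum(x_i) * sum(y_j)) /
  //          sqrt((k * sum(x_i^2) - sum(x_i)^2) * (k * sum(y_j^2) - sum(y_j)^2))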
y2n[idx] : 0; + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { + regx2n[i] = sx2Norm[i * KPolicy::AccThRows + (threadIdx.x / KPolicy::AccThCols)]; + } +#pragma unroll + for (int i = 0; i < KPolicy::AccColsPerTh; ++i) { + regy2n[i] = sy2Norm[i * KPolicy::AccThCols + (threadIdx.x % KPolicy::AccThCols)]; + } + +#pragma unroll + for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { + auto numer = k * acc[i][j] - (regxn[i] * regyn[j]); + auto Q_denom = k * regx2n[i] - (regxn[i] * regxn[i]); + auto R_denom = k * regy2n[j] - (regyn[j] * regyn[j]); + + acc[i][j] = 1 - (numer / raft::mySqrt(Q_denom * R_denom)); + + // correct for small instabilities + acc[i][j] = acc[i][j] * (fabs(acc[i][j]) >= 0.0001); + } + } + }; + + constexpr size_t shmemSize = + KPolicy::SmemSize + (2 * (KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)); + if (isRowMajor) { + constexpr auto correlationRowMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + correlationRowMajor); + correlationRowMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } else { + constexpr auto correlationColMajor = + pairwiseDistanceMatKernel; + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + correlationColMajor); + correlationColMajor<<>>( + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + } + + CUDA_CHECK(cudaGetLastError()); +} + +template +void correlation(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + const DataT *x, const DataT *y, const DataT *xn, const DataT *yn, + const DataT *x2n, const DataT *y2n, + OutT *dOutput, FinalLambda fin_op, cudaStream_t stream) { + size_t bytesA = sizeof(DataT) * lda; + size_t bytesB = sizeof(DataT) * ldb; + if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { + correlationImpl(x, y, xn, yn, x2n, y2n, m, n, k, lda, ldb, ldd, + dOutput, fin_op, stream); + } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { + correlationImpl(x, y, xn, yn, x2n, y2n, m, n, k, lda, ldb, ldd, + dOutput, fin_op, stream); + } else { + correlationImpl( + x, y, xn, yn, x2n, y2n, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); + } +} + +/** + * @brief the Correlation distance matrix calculation + * + * @tparam InType input data-type (for A and B matrices) + * @tparam AccType accumulation data-type + * @tparam OutType output data-type (for C and D matrices) + * @tparam FinalLambda user-defined epilogue lamba + * @tparam Index_ Index type + * @param m number of rows of A and C/D + * @param n number of columns of B and C/D + * @param k number of cols of A and rows of B + * @param pA input matrix + * @param pB input matrix + * @param pD output matrix + * @param fin_op the final element-wise epilogue lambda + * @param stream cuda stream where to launch work + * @param isRowMajor whether the input and output matrices are row major + */ +template +void correlationImpl(int m, int n, int k, const InType *pA, const InType *pB, + OutType *pD, AccType *workspace, size_t &worksize, + FinalLambda fin_op, cudaStream_t stream, bool isRowMajor) { + typedef std::is_same is_bool; + typedef typename std::conditional::type + correlationOutType; + Index_ lda, ldb, ldd; + correlationOutType *pDcast = reinterpret_cast(pD); + + ASSERT(!(((pA != pB) && (worksize < 2 * (m + n) * sizeof(AccType))) || + (worksize < 2 * m * sizeof(AccType))), "workspace size error"); + 
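// Workspace layout, in AccType elements, matching the pointer arithmetic
+  // just below. When pA != pB:
+  //   [0, m)             per-row sums of pA         (norm_col_vec)
+  //   [m, m + n)         per-row sums of pB         (norm_row_vec)
+  //   [m + n, 2m + n)    squared L2 norms of pA     (sq_norm_col_vec)
+  //   [2m + n, 2m + 2n)  squared L2 norms of pB     (sq_norm_row_vec)
+  // When pA == pB the pB vectors alias the pA ones, which is why the
+  // worksize ASSERT above distinguishes the two cases.
+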
ASSERT(workspace != nullptr, "workspace is null"); + + AccType *norm_col_vec = workspace; + AccType *norm_row_vec = workspace; + AccType *sq_norm_col_vec = workspace; + AccType *sq_norm_row_vec = workspace; + if (pA != pB) { + norm_row_vec += m; + + raft::linalg::reduce(norm_col_vec, pA, k, m, (AccType)0, isRowMajor, true, + stream, false, raft::Nop(), raft::Sum()); + raft::linalg::reduce(norm_row_vec, pB, k, n, (AccType)0, isRowMajor, true, + stream, false, raft::Nop(), raft::Sum()); + + sq_norm_col_vec += (m + n); + sq_norm_row_vec = sq_norm_col_vec + m; + raft::linalg::rowNorm(sq_norm_col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, + stream); + raft::linalg::rowNorm(sq_norm_row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, + stream); + } else { + raft::linalg::reduce(norm_col_vec, pA, k, m, (AccType)0, isRowMajor, true, + stream, false, raft::Nop(), raft::Sum()); + sq_norm_col_vec += m; + sq_norm_row_vec = sq_norm_col_vec; + raft::linalg::rowNorm(sq_norm_col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, + stream); + } + + if (isRowMajor) { + lda = k, ldb = k, ldd = n; + correlation( + m, n, k, lda, ldb, ldd, pA, pB, norm_col_vec, norm_row_vec, + sq_norm_col_vec, sq_norm_row_vec, pDcast, fin_op, stream); + } else { + lda = n, ldb = m, ldd = m; + correlation( + n, m, k, lda, ldb, ldd, pB, pA, norm_row_vec, norm_col_vec, + sq_norm_row_vec, sq_norm_col_vec, pDcast, fin_op, stream); + } +} +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 4a6cfacea7..46e90f813d 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include #include #include @@ -223,6 +224,18 @@ struct DistanceImpl +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType metric_arg) { + raft::distance::correlationImpl(m, n, k, x, y, dist, (AccType *)workspace, worksize, + fin_op, stream, isRowMajor); + } +}; } // anonymous namespace @@ -248,11 +261,16 @@ size_t getWorkspaceSize(const InType *x, const InType *y, Index_ m, Index_ n, Index_ k) { size_t worksize = 0; constexpr bool is_allocated = - distanceType <= raft::distance::DistanceType::CosineExpanded; + (distanceType <= raft::distance::DistanceType::CosineExpanded) || + (distanceType == raft::distance::DistanceType::CorrelationExpanded); + constexpr int numOfBuffers = + (distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 
2 : 1; + if (is_allocated) { - worksize += m * sizeof(AccType); - if (x != y) worksize += n * sizeof(AccType); + worksize += numOfBuffers * m * sizeof(AccType); + if (x != y) worksize += numOfBuffers * n * sizeof(AccType); } + return worksize; } @@ -439,6 +457,11 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, raft::distance::DistanceType::KLDivergence>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; + case raft::distance::DistanceType::CorrelationExpanded: + pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); }; diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 204f40ca22..0428e09142 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -21,6 +21,7 @@ add_executable(test_raft test/distance/dist_adj.cu test/distance/dist_canberra.cu test/distance/dist_chebyshev.cu + test/distance/dist_correlation.cu test/distance/dist_cos.cu test/distance/dist_euc_exp.cu test/distance/dist_euc_unexp.cu diff --git a/cpp/test/distance/dist_correlation.cu b/cpp/test/distance/dist_correlation.cu new file mode 100644 index 0000000000..5d84f18e52 --- /dev/null +++ b/cpp/test/distance/dist_correlation.cu @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.h" +#include "distance_base.cuh" + +namespace raft { +namespace distance { + +template +class DistanceCorrelation + : public DistanceTest {}; + +const std::vector> inputsf = { + {0.001f, 1024, 1024, 32, true, 1234ULL}, + {0.001f, 1024, 32, 1024, true, 1234ULL}, + {0.001f, 32, 1024, 1024, true, 1234ULL}, + {0.003f, 1024, 1024, 1024, true, 1234ULL}, + {0.001f, 1024, 1024, 32, false, 1234ULL}, + {0.001f, 1024, 32, 1024, false, 1234ULL}, + {0.001f, 32, 1024, 1024, false, 1234ULL}, + {0.003f, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceCorrelation DistanceCorrelationF; +TEST_P(DistanceCorrelationF, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationF, + ::testing::ValuesIn(inputsf)); + +const std::vector> inputsd = { + {0.001, 1024, 1024, 32, true, 1234ULL}, + {0.001, 1024, 32, 1024, true, 1234ULL}, + {0.001, 32, 1024, 1024, true, 1234ULL}, + {0.003, 1024, 1024, 1024, true, 1234ULL}, + {0.001, 1024, 1024, 32, false, 1234ULL}, + {0.001, 1024, 32, 1024, false, 1234ULL}, + {0.001, 32, 1024, 1024, false, 1234ULL}, + {0.003, 1024, 1024, 1024, false, 1234ULL}, +}; +typedef DistanceCorrelation DistanceCorrelationD; +TEST_P(DistanceCorrelationD, Result) { + int m = params.isRowMajor ? params.m : params.n; + int n = params.isRowMajor ? 
params.n : params.m; + ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n, + raft::CompareApprox(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationD, + ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 9d458d9bf2..a2754395c0 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -255,6 +255,42 @@ __global__ void naiveKLDivergenceDistanceKernel(OutType *dist, const DataType *x dist[outidx] = acc; } +template +__global__ void naiveCorrelationDistanceKernel(OutType *dist, const DataType *x, + const DataType *y, int m, int n, + int k, bool isRowMajor) { + int midx = threadIdx.x + blockIdx.x * blockDim.x; + int nidx = threadIdx.y + blockIdx.y * blockDim.y; + if (midx >= m || nidx >= n) return; + OutType acc = OutType(0); + auto a_norm = DataType(0); + auto b_norm = DataType(0); + auto a_sq_norm = DataType(0); + auto b_sq_norm = DataType(0); + for (int i = 0; i < k; ++i) { + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + a_norm += a; + b_norm += b; + a_sq_norm += (a * a); + b_sq_norm += (b * b); + acc += (a * b); + } + + auto numer = k * acc - (a_norm * b_norm); + auto Q_denom = k * a_sq_norm - (a_norm * a_norm); + auto R_denom = k * b_sq_norm - (b_norm * b_norm); + + acc = 1 - (numer / raft::mySqrt(Q_denom * R_denom)); + acc = acc * (fabs(acc) >= 0.0001); + + int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; + dist[outidx] = acc; +} + + template void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, int n, int k, raft::distance::DistanceType type, @@ -304,6 +340,10 @@ void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, naiveKLDivergenceDistanceKernel <<>>(dist, x, y, m, n, k, isRowMajor); break; + case raft::distance::DistanceType::CorrelationExpanded: + naiveCorrelationDistanceKernel + <<>>(dist, x, y, m, n, k, isRowMajor); + break; default: FAIL() << "should be here\n"; } From 74f49a17c7400395c3d0f222fb4314e1efa2f82d Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 30 Jul 2021 21:03:20 +0530 Subject: [PATCH 09/18] fix clang format issues --- cpp/include/raft/distance/correlation.cuh | 83 +++++++++++--------- cpp/include/raft/distance/distance.cuh | 51 ++++++------ cpp/include/raft/distance/hamming.cuh | 42 +++++----- cpp/include/raft/distance/jensen_shannon.cuh | 30 +++---- cpp/include/raft/distance/kl_divergence.cuh | 32 ++++---- cpp/include/raft/distance/russell_rao.cuh | 21 +++-- cpp/include/raft/sparse/cusparse_wrappers.h | 2 +- cpp/test/distance/dist_jensen_shannon.cu | 4 +- cpp/test/distance/dist_kl_divergence.cu | 4 +- cpp/test/distance/distance_base.cuh | 29 ++++--- 10 files changed, 159 insertions(+), 139 deletions(-) diff --git a/cpp/include/raft/distance/correlation.cuh b/cpp/include/raft/distance/correlation.cuh index 5f993bc894..03be13c6c8 100644 --- a/cpp/include/raft/distance/correlation.cuh +++ b/cpp/include/raft/distance/correlation.cuh @@ -15,14 +15,13 @@ */ #pragma once +#include #include #include -#include namespace raft { namespace distance { - /** * @brief the Correlation distance matrix: * @@ -50,10 +49,10 @@ namespace distance { template static void correlationImpl(const DataT *x, const DataT *y, const DataT *xn, - const DataT *yn, const DataT *x2n, const DataT *y2n, - IdxT m, IdxT n, IdxT k, IdxT 
lda, IdxT ldb, IdxT ldd, - OutT *dOutput, FinalLambda fin_op, - cudaStream_t stream) { + const DataT *yn, const DataT *x2n, const DataT *y2n, + IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, + IdxT ldd, OutT *dOutput, FinalLambda fin_op, + cudaStream_t stream) { typedef typename raft::linalg::Policy4x4::Policy RowPolicy; typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; @@ -75,9 +74,10 @@ static void correlationImpl(const DataT *x, const DataT *y, const DataT *xn, DataT regx2n[KPolicy::AccRowsPerTh], regy2n[KPolicy::AccColsPerTh]; extern __shared__ char smem[]; - DataT* sx2Norm = (DataT*)(&smem[KPolicy::SmemSize + + DataT *sx2Norm = + (DataT *)(&smem[KPolicy::SmemSize + (KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)]); - DataT* sy2Norm = (&sx2Norm[KPolicy::Mblk]); + DataT *sy2Norm = (&sx2Norm[KPolicy::Mblk]); // Load x & y norms required by this threadblock in shmem buffer if (gridStrideX == blockIdx.x * KPolicy::Nblk) { @@ -95,11 +95,13 @@ static void correlationImpl(const DataT *x, const DataT *y, const DataT *xn, #pragma unroll for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { - regx2n[i] = sx2Norm[i * KPolicy::AccThRows + (threadIdx.x / KPolicy::AccThCols)]; + regx2n[i] = + sx2Norm[i * KPolicy::AccThRows + (threadIdx.x / KPolicy::AccThCols)]; } #pragma unroll for (int i = 0; i < KPolicy::AccColsPerTh; ++i) { - regy2n[i] = sy2Norm[i * KPolicy::AccThCols + (threadIdx.x % KPolicy::AccThCols)]; + regy2n[i] = + sy2Norm[i * KPolicy::AccThCols + (threadIdx.x % KPolicy::AccThCols)]; } #pragma unroll @@ -128,8 +130,8 @@ static void correlationImpl(const DataT *x, const DataT *y, const DataT *xn, dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, correlationRowMajor); correlationRowMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, + fin_op); } else { constexpr auto correlationColMajor = pairwiseDistanceMatKernel(m, n, KPolicy::SmemSize, correlationColMajor); correlationColMajor<<>>( - x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, + fin_op); } CUDA_CHECK(cudaGetLastError()); @@ -148,19 +150,19 @@ static void correlationImpl(const DataT *x, const DataT *y, const DataT *xn, template void correlation(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, - const DataT *x, const DataT *y, const DataT *xn, const DataT *yn, - const DataT *x2n, const DataT *y2n, - OutT *dOutput, FinalLambda fin_op, cudaStream_t stream) { + const DataT *x, const DataT *y, const DataT *xn, + const DataT *yn, const DataT *x2n, const DataT *y2n, + OutT *dOutput, FinalLambda fin_op, cudaStream_t stream) { size_t bytesA = sizeof(DataT) * lda; size_t bytesB = sizeof(DataT) * ldb; if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { correlationImpl(x, y, xn, yn, x2n, y2n, m, n, k, lda, ldb, ldd, - dOutput, fin_op, stream); + isRowMajor>(x, y, xn, yn, x2n, y2n, m, n, k, lda, ldb, ldd, + dOutput, fin_op, stream); } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { correlationImpl(x, y, xn, yn, x2n, y2n, m, n, k, lda, ldb, ldd, - dOutput, fin_op, stream); + isRowMajor>(x, y, xn, yn, x2n, y2n, m, n, k, lda, ldb, ldd, + dOutput, fin_op, stream); } else { correlationImpl( x, y, xn, yn, x2n, y2n, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); @@ -188,8 +190,8 @@ void correlation(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, 
template void correlationImpl(int m, int n, int k, const InType *pA, const InType *pB, - OutType *pD, AccType *workspace, size_t &worksize, - FinalLambda fin_op, cudaStream_t stream, bool isRowMajor) { + OutType *pD, AccType *workspace, size_t &worksize, + FinalLambda fin_op, cudaStream_t stream, bool isRowMajor) { typedef std::is_same is_bool; typedef typename std::conditional::type correlationOutType; @@ -197,7 +199,8 @@ void correlationImpl(int m, int n, int k, const InType *pA, const InType *pB, correlationOutType *pDcast = reinterpret_cast(pD); ASSERT(!(((pA != pB) && (worksize < 2 * (m + n) * sizeof(AccType))) || - (worksize < 2 * m * sizeof(AccType))), "workspace size error"); + (worksize < 2 * m * sizeof(AccType))), + "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); AccType *norm_col_vec = workspace; @@ -208,35 +211,39 @@ void correlationImpl(int m, int n, int k, const InType *pA, const InType *pB, norm_row_vec += m; raft::linalg::reduce(norm_col_vec, pA, k, m, (AccType)0, isRowMajor, true, - stream, false, raft::Nop(), raft::Sum()); + stream, false, raft::Nop(), + raft::Sum()); raft::linalg::reduce(norm_row_vec, pB, k, n, (AccType)0, isRowMajor, true, - stream, false, raft::Nop(), raft::Sum()); + stream, false, raft::Nop(), + raft::Sum()); sq_norm_col_vec += (m + n); sq_norm_row_vec = sq_norm_col_vec + m; - raft::linalg::rowNorm(sq_norm_col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, - stream); - raft::linalg::rowNorm(sq_norm_row_vec, pB, k, n, raft::linalg::L2Norm, isRowMajor, - stream); + raft::linalg::rowNorm(sq_norm_col_vec, pA, k, m, raft::linalg::L2Norm, + isRowMajor, stream); + raft::linalg::rowNorm(sq_norm_row_vec, pB, k, n, raft::linalg::L2Norm, + isRowMajor, stream); } else { raft::linalg::reduce(norm_col_vec, pA, k, m, (AccType)0, isRowMajor, true, - stream, false, raft::Nop(), raft::Sum()); + stream, false, raft::Nop(), + raft::Sum()); sq_norm_col_vec += m; sq_norm_row_vec = sq_norm_col_vec; - raft::linalg::rowNorm(sq_norm_col_vec, pA, k, m, raft::linalg::L2Norm, isRowMajor, - stream); + raft::linalg::rowNorm(sq_norm_col_vec, pA, k, m, raft::linalg::L2Norm, + isRowMajor, stream); } if (isRowMajor) { lda = k, ldb = k, ldd = n; correlation( - m, n, k, lda, ldb, ldd, pA, pB, norm_col_vec, norm_row_vec, - sq_norm_col_vec, sq_norm_row_vec, pDcast, fin_op, stream); + m, n, k, lda, ldb, ldd, pA, pB, norm_col_vec, norm_row_vec, + sq_norm_col_vec, sq_norm_row_vec, pDcast, fin_op, stream); } else { lda = n, ldb = m, ldd = m; - correlation( - n, m, k, lda, ldb, ldd, pB, pA, norm_row_vec, norm_col_vec, - sq_norm_row_vec, sq_norm_col_vec, pDcast, fin_op, stream); + correlation(n, m, k, lda, ldb, ldd, pB, pA, norm_row_vec, + norm_col_vec, sq_norm_row_vec, sq_norm_col_vec, pDcast, + fin_op, stream); } } } // namespace distance diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 46e90f813d..074f14064a 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -27,8 +27,8 @@ #include #include #include -#include #include +#include #include #include #include @@ -178,37 +178,40 @@ struct DistanceImpl -struct DistanceImpl { +struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, cudaStream_t stream, bool isRowMajor, InType metric_arg) { - raft::distance::hammingUnexpandedImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); + 
raft::distance::hammingUnexpandedImpl(m, n, k, x, y, dist, fin_op, + stream, isRowMajor); } }; template -struct DistanceImpl { +struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, cudaStream_t stream, bool isRowMajor, InType metric_arg) { - raft::distance::jensenShannonImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); + raft::distance::jensenShannonImpl(m, n, k, x, y, dist, fin_op, + stream, isRowMajor); } }; template -struct DistanceImpl { +struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, cudaStream_t stream, bool isRowMajor, InType metric_arg) { - raft::distance::russellRaoImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); + raft::distance::russellRaoImpl(m, n, k, x, y, dist, fin_op, stream, + isRowMajor); } }; @@ -219,21 +222,23 @@ struct DistanceImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); + raft::distance::klDivergenceImpl(m, n, k, x, y, dist, fin_op, + stream, isRowMajor); } }; template -struct DistanceImpl { +struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, cudaStream_t stream, bool isRowMajor, InType metric_arg) { - raft::distance::correlationImpl(m, n, k, x, y, dist, (AccType *)workspace, worksize, - fin_op, stream, isRowMajor); + raft::distance::correlationImpl(m, n, k, x, y, dist, + (AccType *)workspace, worksize, + fin_op, stream, isRowMajor); } }; @@ -261,10 +266,10 @@ size_t getWorkspaceSize(const InType *x, const InType *y, Index_ m, Index_ n, Index_ k) { size_t worksize = 0; constexpr bool is_allocated = - (distanceType <= raft::distance::DistanceType::CosineExpanded) || - (distanceType == raft::distance::DistanceType::CorrelationExpanded); + (distanceType <= raft::distance::DistanceType::CosineExpanded) || + (distanceType == raft::distance::DistanceType::CorrelationExpanded); constexpr int numOfBuffers = - (distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 2 : 1; + (distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 
2 : 1; if (is_allocated) { worksize += numOfBuffers * m * sizeof(AccType); diff --git a/cpp/include/raft/distance/hamming.cuh b/cpp/include/raft/distance/hamming.cuh index 4b29708db8..08f1020b85 100644 --- a/cpp/include/raft/distance/hamming.cuh +++ b/cpp/include/raft/distance/hamming.cuh @@ -48,9 +48,10 @@ namespace distance { */ template -static void hammingUnexpandedImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, - IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, - FinalLambda fin_op, cudaStream_t stream) { +static void hammingUnexpandedImpl(const DataT *x, const DataT *y, IdxT m, + IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + OutT *dOutput, FinalLambda fin_op, + cudaStream_t stream) { typedef typename raft::linalg::Policy4x4::Policy RowPolicy; typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; @@ -108,18 +109,18 @@ static void hammingUnexpandedImpl(const DataT *x, const DataT *y, IdxT m, IdxT n template void hammingUnexpanded(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, - const DataT *x, const DataT *y, OutT *dOutput, - FinalLambda fin_op, cudaStream_t stream) { + const DataT *x, const DataT *y, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { size_t bytesA = sizeof(DataT) * lda; size_t bytesB = sizeof(DataT) * ldb; if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { - hammingUnexpandedImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, - stream); + hammingUnexpandedImpl(x, y, m, n, k, lda, ldb, ldd, + dOutput, fin_op, stream); } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { - hammingUnexpandedImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, - stream); + hammingUnexpandedImpl(x, y, m, n, k, lda, ldb, ldd, + dOutput, fin_op, stream); } else { hammingUnexpandedImpl( x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); @@ -148,23 +149,26 @@ void hammingUnexpanded(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, */ template -void hammingUnexpandedImpl(int m, int n, int k, const InType *pA, const InType *pB, - OutType *pD, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor) { +void hammingUnexpandedImpl(int m, int n, int k, const InType *pA, + const InType *pB, OutType *pD, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor) { typedef std::is_same is_bool; typedef typename std::conditional::type hammingUnexpandedOutType; Index_ lda, ldb, ldd; - hammingUnexpandedOutType *pDcast = reinterpret_cast(pD); + hammingUnexpandedOutType *pDcast = + reinterpret_cast(pD); if (isRowMajor) { lda = k, ldb = k, ldd = n; - hammingUnexpanded( - m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + hammingUnexpanded(m, n, k, lda, ldb, ldd, pA, pB, pDcast, + fin_op, stream); } else { lda = n, ldb = m, ldd = m; - hammingUnexpanded( - n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); + hammingUnexpanded(n, m, k, lda, ldb, ldd, pB, pA, + pDcast, fin_op, stream); } } } // namespace distance diff --git a/cpp/include/raft/distance/jensen_shannon.cuh b/cpp/include/raft/distance/jensen_shannon.cuh index cf04397db7..6b65a86ca8 100644 --- a/cpp/include/raft/distance/jensen_shannon.cuh +++ b/cpp/include/raft/distance/jensen_shannon.cuh @@ -61,8 +61,9 @@ DI double fastLog(double x) { template static void jensenShannonImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, - IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, - FinalLambda fin_op, cudaStream_t stream) { + IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + OutT *dOutput, FinalLambda fin_op, + cudaStream_t stream) { typedef typename 
raft::linalg::Policy4x4::Policy RowPolicy; typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; @@ -123,18 +124,18 @@ static void jensenShannonImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, template void jensenShannon(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, - const DataT *x, const DataT *y, OutT *dOutput, - FinalLambda fin_op, cudaStream_t stream) { + const DataT *x, const DataT *y, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { size_t bytesA = sizeof(DataT) * lda; size_t bytesB = sizeof(DataT) * ldb; if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { jensenShannonImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, - stream); + isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { jensenShannonImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, - stream); + isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); } else { jensenShannonImpl( x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); @@ -165,8 +166,8 @@ void jensenShannon(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, template void jensenShannonImpl(int m, int n, int k, const InType *pA, const InType *pB, - OutType *pD, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor) { + OutType *pD, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor) { typedef std::is_same is_bool; typedef typename std::conditional::type jensenShannonOutType; @@ -174,13 +175,14 @@ void jensenShannonImpl(int m, int n, int k, const InType *pA, const InType *pB, jensenShannonOutType *pDcast = reinterpret_cast(pD); if (isRowMajor) { lda = k, ldb = k, ldd = n; - jensenShannon( - m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + jensenShannon(m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); } else { lda = n, ldb = m, ldd = m; - jensenShannon( - n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); + jensenShannon(n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, + stream); } } } // namespace distance diff --git a/cpp/include/raft/distance/kl_divergence.cuh b/cpp/include/raft/distance/kl_divergence.cuh index 85a273e456..c0a0e15825 100644 --- a/cpp/include/raft/distance/kl_divergence.cuh +++ b/cpp/include/raft/distance/kl_divergence.cuh @@ -15,13 +15,12 @@ */ #pragma once -#include #include +#include namespace raft { namespace distance { - /** * @brief the KL Divergence distance matrix: * It computes the following equation: @@ -51,8 +50,9 @@ namespace distance { template static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, - IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, - FinalLambda fin_op, cudaStream_t stream) { + IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + OutT *dOutput, FinalLambda fin_op, + cudaStream_t stream) { typedef typename raft::linalg::Policy4x4::Policy RowPolicy; typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; @@ -115,18 +115,18 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, template void klDivergence(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, - const DataT *x, const DataT *y, OutT *dOutput, - FinalLambda fin_op, cudaStream_t stream) { + const DataT *x, const DataT *y, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { size_t bytesA = sizeof(DataT) * lda; size_t bytesB = sizeof(DataT) * ldb; if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { klDivergenceImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, - stream); + isRowMajor>(x, 
y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { klDivergenceImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, - stream); + isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); } else { klDivergenceImpl( x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); @@ -156,8 +156,8 @@ void klDivergence(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, template void klDivergenceImpl(int m, int n, int k, const InType *pA, const InType *pB, - OutType *pD, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor) { + OutType *pD, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor) { typedef std::is_same is_bool; typedef typename std::conditional::type klDivergenceOutType; @@ -165,13 +165,13 @@ void klDivergenceImpl(int m, int n, int k, const InType *pA, const InType *pB, klDivergenceOutType *pDcast = reinterpret_cast(pD); if (isRowMajor) { lda = k, ldb = k, ldd = n; - klDivergence( - m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); + klDivergence(m, n, k, lda, ldb, ldd, pA, pB, pDcast, fin_op, stream); } else { lda = n, ldb = m, ldd = m; - klDivergence( - n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); + klDivergence(n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } } // namespace distance diff --git a/cpp/include/raft/distance/russell_rao.cuh b/cpp/include/raft/distance/russell_rao.cuh index 6c6c60fde9..417fb73b94 100644 --- a/cpp/include/raft/distance/russell_rao.cuh +++ b/cpp/include/raft/distance/russell_rao.cuh @@ -20,7 +20,6 @@ namespace raft { namespace distance { - /** * @brief the Russell Rao distance matrix: * It computes the following equation: @@ -50,8 +49,8 @@ namespace distance { template static void russellRaoImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, - IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, - FinalLambda fin_op, cudaStream_t stream) { + IdxT k, IdxT lda, IdxT ldb, IdxT ldd, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { typedef typename raft::linalg::Policy4x4::Policy RowPolicy; typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; @@ -109,18 +108,18 @@ static void russellRaoImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, template void russellRao(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, - const DataT *x, const DataT *y, OutT *dOutput, - FinalLambda fin_op, cudaStream_t stream) { + const DataT *x, const DataT *y, OutT *dOutput, + FinalLambda fin_op, cudaStream_t stream) { size_t bytesA = sizeof(DataT) * lda; size_t bytesB = sizeof(DataT) * ldb; if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) { russellRaoImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, - stream); + isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); } else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) { russellRaoImpl(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, - stream); + isRowMajor>(x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, + stream); } else { russellRaoImpl( x, y, m, n, k, lda, ldb, ldd, dOutput, fin_op, stream); @@ -150,8 +149,8 @@ void russellRao(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, template void russellRaoImpl(int m, int n, int k, const InType *pA, const InType *pB, - OutType *pD, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor) { + OutType *pD, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor) { typedef std::is_same is_bool; typedef typename std::conditional::type russellRaoOutType; diff --git 
a/cpp/include/raft/sparse/cusparse_wrappers.h b/cpp/include/raft/sparse/cusparse_wrappers.h index 360832f557..d072100672 100644 --- a/cpp/include/raft/sparse/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/cusparse_wrappers.h @@ -55,7 +55,7 @@ namespace detail { inline const char* cusparse_error_to_string(cusparseStatus_t err) { #if defined(CUDART_VERSION) && CUDART_VERSION >= 10100 return cusparseGetErrorString(err); -#else // CUDART_VERSION +#else // CUDART_VERSION switch (err) { _CUSPARSE_ERR_TO_STR(CUSPARSE_STATUS_SUCCESS); _CUSPARSE_ERR_TO_STR(CUSPARSE_STATUS_NOT_INITIALIZED); diff --git a/cpp/test/distance/dist_jensen_shannon.cu b/cpp/test/distance/dist_jensen_shannon.cu index a6fa954042..bc0b56f506 100644 --- a/cpp/test/distance/dist_jensen_shannon.cu +++ b/cpp/test/distance/dist_jensen_shannon.cu @@ -22,8 +22,8 @@ namespace distance { template class DistanceJensenShannon - : public DistanceTest {}; + : public DistanceTest { +}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, diff --git a/cpp/test/distance/dist_kl_divergence.cu b/cpp/test/distance/dist_kl_divergence.cu index d5182e90ba..884ac4b948 100644 --- a/cpp/test/distance/dist_kl_divergence.cu +++ b/cpp/test/distance/dist_kl_divergence.cu @@ -22,8 +22,8 @@ namespace distance { template class DistanceKLDivergence - : public DistanceTest {}; + : public DistanceTest { +}; const std::vector> inputsf = { {0.001f, 1024, 1024, 32, true, 1234ULL}, diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index a2754395c0..a98fc6e89f 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -182,8 +182,10 @@ __global__ void naiveHammingDistanceKernel(DataType *dist, const DataType *x, template __global__ void naiveJensenShannonDistanceKernel(DataType *dist, - const DataType *x, const DataType *y, - int m, int n, int k, bool isRowMajor) { + const DataType *x, + const DataType *y, int m, + int n, int k, + bool isRowMajor) { int midx = threadIdx.x + blockIdx.x * blockDim.x; int nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) return; @@ -204,8 +206,8 @@ __global__ void naiveJensenShannonDistanceKernel(DataType *dist, bool p_zero = p == 0; bool q_zero = q == 0; - acc += (-a * (!p_zero * log(p + p_zero))) + - (-b * (!q_zero * log(q + q_zero))); + acc += + (-a * (!p_zero * log(p + p_zero))) + (-b * (!q_zero * log(q + q_zero))); } acc = raft::mySqrt(0.5f * acc); int outidx = isRowMajor ? 
midx * n + nidx : midx + m * nidx; @@ -214,8 +216,8 @@ __global__ void naiveJensenShannonDistanceKernel(DataType *dist, template __global__ void naiveRussellRaoDistanceKernel(OutType *dist, const DataType *x, - const DataType *y, int m, int n, - int k, bool isRowMajor) { + const DataType *y, int m, int n, + int k, bool isRowMajor) { int midx = threadIdx.x + blockIdx.x * blockDim.x; int nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) return; @@ -233,9 +235,10 @@ __global__ void naiveRussellRaoDistanceKernel(OutType *dist, const DataType *x, } template -__global__ void naiveKLDivergenceDistanceKernel(OutType *dist, const DataType *x, - const DataType *y, int m, int n, - int k, bool isRowMajor) { +__global__ void naiveKLDivergenceDistanceKernel(OutType *dist, + const DataType *x, + const DataType *y, int m, int n, + int k, bool isRowMajor) { int midx = threadIdx.x + blockIdx.x * blockDim.x; int nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) return; @@ -257,8 +260,8 @@ __global__ void naiveKLDivergenceDistanceKernel(OutType *dist, const DataType *x template __global__ void naiveCorrelationDistanceKernel(OutType *dist, const DataType *x, - const DataType *y, int m, int n, - int k, bool isRowMajor) { + const DataType *y, int m, int n, + int k, bool isRowMajor) { int midx = threadIdx.x + blockIdx.x * blockDim.x; int nidx = threadIdx.y + blockIdx.y * blockDim.y; if (midx >= m || nidx >= n) return; @@ -290,7 +293,6 @@ __global__ void naiveCorrelationDistanceKernel(OutType *dist, const DataType *x, dist[outidx] = acc; } - template void naiveDistance(DataType *dist, const DataType *x, const DataType *y, int m, int n, int k, raft::distance::DistanceType type, @@ -404,7 +406,8 @@ class DistanceTest : public ::testing::TestWithParam> { // Hellinger works only on positive numbers r.uniform(x, m * k, DataType(0.0), DataType(1.0), stream); r.uniform(y, n * k, DataType(0.0), DataType(1.0), stream); - } else if (distanceType == raft::distance::DistanceType::RusselRaoExpanded) { + } else if (distanceType == + raft::distance::DistanceType::RusselRaoExpanded) { r.uniform(x, m * k, DataType(0.0), DataType(1.0), stream); r.uniform(y, n * k, DataType(0.0), DataType(1.0), stream); // Russel rao works on boolean values. 
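Patches 08 and 09 complete the wiring: every new metric (hamming, jensen-shannon, kl-divergence, russell rao, correlation) now flows through the same pairwise_distance switch and the shared pairwiseDistanceMatKernel machinery shown in the hunks above. As an orientation for reviewers, here is a minimal host-side sketch of how these kernels are reached. It is illustrative only: the buffer names (d_x, d_y, d_dist) are hypothetical, and it assumes the rmm-backed pairwise_distance overload in distance.cuh, which grows the workspace internally via getWorkspaceSize().

    #include <raft/distance/distance.cuh>
    #include <rmm/device_uvector.hpp>

    void correlation_example(const float *d_x,   // m x k, row major, on device
                             const float *d_y,   // n x k, row major, on device
                             float *d_dist,      // m x n output, on device
                             int m, int n, int k, cudaStream_t stream) {
      // Workspace is resized inside pairwise_distance based on the metric.
      rmm::device_uvector<char> workspace(0, stream);
      raft::distance::pairwise_distance(
        d_x, d_y, d_dist, m, n, k, workspace,
        raft::distance::DistanceType::CorrelationExpanded, stream,
        /*isRowMajor=*/true);
    }

For DistanceType::CorrelationExpanded the workspace ends up holding 2 * (m + n) AccType values when x != y (the two-buffer case introduced in patch 08); metrics without expanded norms need no workspace at all.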
From 8595961a17444c0a40370e86b5aa30bea2e27ec9 Mon Sep 17 00:00:00 2001
From: Mahesh Doijade
Date: Fri, 30 Jul 2021 22:35:17 +0530
Subject: [PATCH 10/18] fix clang format issue in cusparse_wrappers.h

--- cpp/include/raft/sparse/cusparse_wrappers.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/cpp/include/raft/sparse/cusparse_wrappers.h b/cpp/include/raft/sparse/cusparse_wrappers.h
index d072100672..360832f557 100644 --- a/cpp/include/raft/sparse/cusparse_wrappers.h +++ b/cpp/include/raft/sparse/cusparse_wrappers.h @@ -55,7 +55,7 @@ namespace detail { inline const char* cusparse_error_to_string(cusparseStatus_t err) { #if defined(CUDART_VERSION) && CUDART_VERSION >= 10100 return cusparseGetErrorString(err); -#else // CUDART_VERSION +#else // CUDART_VERSION switch (err) { _CUSPARSE_ERR_TO_STR(CUSPARSE_STATUS_SUCCESS); _CUSPARSE_ERR_TO_STR(CUSPARSE_STATUS_NOT_INITIALIZED);

From ddd0f3c66b0f1d9ce7c9bd3bc5a67bbaf2f8c334 Mon Sep 17 00:00:00 2001
From: Mahesh Doijade
Date: Mon, 2 Aug 2021 21:01:39 +0530
Subject: [PATCH 11/18] replace fast math logf() with slow but more accurate log()

--- cpp/include/raft/distance/jensen_shannon.cuh | 15 ++------------- cpp/include/raft/distance/kl_divergence.cuh | 5 ++--- 2 files changed, 4 insertions(+), 16 deletions(-)

diff --git a/cpp/include/raft/distance/jensen_shannon.cuh b/cpp/include/raft/distance/jensen_shannon.cuh
index 6b65a86ca8..cb0ac59226 100644 --- a/cpp/include/raft/distance/jensen_shannon.cuh +++ b/cpp/include/raft/distance/jensen_shannon.cuh @@ -20,17 +20,6 @@ namespace raft { namespace distance { -template <typename T> -DI T fastLog(T x); -template <> -DI float fastLog(float x) { - return __logf(x); -} -template <> -DI double fastLog(double x) { - return log(x); -} - /** * @brief the Jensen Shannon distance matrix: * It computes the following equation: @@ -76,9 +65,9 @@ static void jensenShannonImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { const DataT m = 0.5f * (x + y); const bool m_zero = (m == 0); - const auto logM = (!m_zero) * fastLog(m); + const auto logM = (!m_zero) * raft::myLog(m); - acc += (-x * (logM - fastLog(x))) + (-y * (logM - fastLog(y))); + acc += (-x * (logM - raft::myLog(x))) + (-y * (logM - raft::myLog(y))); }; // epilogue operation lambda for final value calculation diff --git a/cpp/include/raft/distance/kl_divergence.cuh b/cpp/include/raft/distance/kl_divergence.cuh index 85a273e456..1cf989110b 100644 --- a/cpp/include/raft/distance/kl_divergence.cuh +++ b/cpp/include/raft/distance/kl_divergence.cuh @@ -15,13 +15,12 @@ */ #pragma once -#include <raft/distance/jensen_shannon.cuh> #include <raft/distance/pairwise_distance_base.cuh> namespace raft { namespace distance { - /** * @brief the KL Divergence distance matrix: * It computes the following equation: @@ -65,10 +64,10 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { if (isRowMajor) { const bool y_zero = (y == 0); - acc += x * (fastLog(x) - (!y_zero) * fastLog(y)); + acc += x * (raft::myLog(x) - (!y_zero) * raft::myLog(y)); } else { const bool x_zero = (x == 0); - acc += y * (fastLog(y) - (!x_zero) * fastLog(x)); + acc += y * (raft::myLog(y) - (!x_zero) * raft::myLog(x)); } };

From af76b0374d90eeeb924e3e2083d8801755785827 Mon Sep 17 00:00:00 2001
From: Mahesh Doijade
Date: Thu, 5 Aug 2021 20:51:21 +0530
Subject: [PATCH 12/18] improve kl-divergence perf by 2x: precompute log(x) on one input so the inner loop needs only a single log(), then post-process with exp() to revert log(x) back to x

--- cpp/include/raft/distance/kl_divergence.cuh | 43 ++++++++++++++++++--- 1 file changed, 38 insertions(+), 5 deletions(-)
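The optimization this patch records deserves a closer look. Before it, the accumulation lambda evaluated a log for both operands of every (i, j, t) term, i.e. O(m*n*k) transcendentals; taking log once per element of one input in an O(n*k) pre-pass (and restoring it with exp afterwards) leaves a single log() in the hot loop. A CPU-side sketch of the same idea follows; kl_precompute_sketch is a hypothetical name, not a RAFT API, and it includes the zero guards that later patches in this series settle on:

    #include <cmath>
    #include <vector>

    // x: m*k, y: n*k, out: m*n, all row major; KL(i,j) = 0.5 * sum_t x*log(x/y)
    void kl_precompute_sketch(const std::vector<double> &x, std::vector<double> &y,
                              std::vector<double> &out, int m, int n, int k) {
      // pre-pass, O(n*k): y <- log(y), keeping zeros at zero
      // (mirrors unaryOp_lambda in the diff below)
      for (auto &v : y) v = (v == 0) ? 0 : std::log(v);
      // main loop, O(m*n*k): one log() per term instead of two
      for (int i = 0; i < m; ++i) {
        for (int j = 0; j < n; ++j) {
          double acc = 0;
          for (int t = 0; t < k; ++t) {
            const double xv = x[i * k + t];
            if (xv != 0) acc += xv * (std::log(xv) - y[j * k + t]);
          }
          out[i * n + j] = 0.5 * acc;
        }
      }
      // post-pass, O(n*k): restore y via exp(log(y))
      // (mirrors unaryOp_lambda_reverse in the diff below)
      for (auto &v : y) v = (v == 0) ? 0 : std::exp(v);
    }

Note that in the diff the x == y path gets no pre-pass, yet core_lambda still treats y as already log-transformed; patch 16 below revisits exactly that case with a dedicated lambda.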
diff --git a/cpp/include/raft/distance/kl_divergence.cuh b/cpp/include/raft/distance/kl_divergence.cuh
index 1cf989110b..afb616cf70 100644 --- a/cpp/include/raft/distance/kl_divergence.cuh +++ b/cpp/include/raft/distance/kl_divergence.cuh @@ -24,6 +24,10 @@ namespace distance { * @brief the KL Divergence distance matrix: * It computes the following equation: Cij = 0.5 * sum(x * log (x / y)); + * This distance computation modifies A or B by computing a log(x) + * and then performing a `pow(e, log(x))` to convert it back. Because of this, + * it is possible that the values in A or B might differ slightly + * after this is invoked. + * * @tparam DataT input data-type (for A and B matrices) * @tparam AccT accumulation data-type @@ -63,14 +67,31 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { if (isRowMajor) { - const bool y_zero = (y == 0); - acc += x * (raft::myLog(x) - (!y_zero) * raft::myLog(y)); + acc += x * (raft::myLog(x) - y); } else { - const bool x_zero = (x == 0); - acc += y * (raft::myLog(y) - (!x_zero) * raft::myLog(x)); + acc += y * (raft::myLog(y) - x); } }; + auto unaryOp_lambda = [] __device__(DataT input) { + const bool x_zero = (input == 0); + return (!x_zero) * raft::myLog(input); + }; + + auto unaryOp_lambda_reverse = [] __device__(DataT input) { + // reverse previous log (x) back to x using (e ^ log(x)) + const bool x_zero = (input == 0); + return (!x_zero) * raft::myPow((DataT)M_E, input); + }; + + if (x != y && isRowMajor) { + raft::linalg::unaryOp<DataT, decltype(unaryOp_lambda)>( + (DataT *)y, y, n * k, unaryOp_lambda, stream); + } else if (x != y && !isRowMajor) { + raft::linalg::unaryOp<DataT, decltype(unaryOp_lambda)>( + (DataT *)x, x, m * k, unaryOp_lambda, stream); + } + // epilogue operation lambda for final value calculation auto epilog_lambda = [] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], @@ -108,6 +129,15 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, epilog_lambda, fin_op); } + // Now reverse previous log (x) back to x using (e ^ log(x)) + if (x != y && isRowMajor) { + raft::linalg::unaryOp<DataT, decltype(unaryOp_lambda_reverse)>( + (DataT *)y, y, n * k, unaryOp_lambda_reverse, stream); + } else if (x != y && !isRowMajor) { + raft::linalg::unaryOp<DataT, decltype(unaryOp_lambda_reverse)>( + (DataT *)x, x, m * k, unaryOp_lambda_reverse, stream); + } + CUDA_CHECK(cudaGetLastError()); } @@ -136,7 +166,10 @@ void klDivergence(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, * @brief the KL Divergence distance matrix calculation * It computes the following equation: Cij = 0.5 * sum(x * log (x / y)); - * + * This distance computation modifies A or B by computing a log(x) + * and then performing a `pow(e, log(x))` to convert it back. Because of this, + * it is possible that the values in A or B might differ slightly + * after this is invoked.
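+ * (Concretely, the reverse pass maps 0 back to 0, so an input value of
+ * exactly 1 becomes log(1) = 0 in the pre-pass and is left at 0 rather
+ * than being restored to 1.)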
* @tparam InType input data-type (for A and B matrices) * @tparam AccType accumulation data-type * @tparam OutType output data-type (for C and D matrices)

From efb1e79937a703857e19b7794c763f0597d4bda2 Mon Sep 17 00:00:00 2001
From: Mahesh Doijade
Date: Mon, 9 Aug 2021 23:49:08 +0530
Subject: [PATCH 13/18] remove the instability correction in the correlation kernel, as with it the results don't match sklearn

--- cpp/include/raft/distance/correlation.cuh | 3 --- cpp/test/distance/distance_base.cuh | 1 - 2 files changed, 4 deletions(-)

diff --git a/cpp/include/raft/distance/correlation.cuh b/cpp/include/raft/distance/correlation.cuh
index 03be13c6c8..ed3b7a5464 100644 --- a/cpp/include/raft/distance/correlation.cuh +++ b/cpp/include/raft/distance/correlation.cuh @@ -113,9 +113,6 @@ static void correlationImpl(const DataT *x, const DataT *y, const DataT *xn, auto R_denom = k * regy2n[j] - (regyn[j] * regyn[j]); acc[i][j] = 1 - (numer / raft::mySqrt(Q_denom * R_denom)); - - // correct for small instabilities - acc[i][j] = acc[i][j] * (fabs(acc[i][j]) >= 0.0001); } } }; diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index a98fc6e89f..9e3290593d 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -287,7 +287,6 @@ __global__ void naiveCorrelationDistanceKernel(OutType *dist, const DataType *x, auto R_denom = k * b_sq_norm - (b_norm * b_norm); acc = 1 - (numer / raft::mySqrt(Q_denom * R_denom)); - acc = acc * (fabs(acc) >= 0.0001); int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; dist[outidx] = acc;

From 4bb719f0b16fa3b3a660357e10e1ed24f093f355 Mon Sep 17 00:00:00 2001
From: Mahesh Doijade
Date: Tue, 10 Aug 2021 23:42:44 +0530
Subject: [PATCH 14/18] fix NaNs in jensenshannon caused by log(0)

--- cpp/include/raft/distance/jensen_shannon.cuh | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/cpp/include/raft/distance/jensen_shannon.cuh b/cpp/include/raft/distance/jensen_shannon.cuh
index cb0ac59226..4501069e59 100644 --- a/cpp/include/raft/distance/jensen_shannon.cuh +++ b/cpp/include/raft/distance/jensen_shannon.cuh @@ -65,9 +65,11 @@ static void jensenShannonImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { const DataT m = 0.5f * (x + y); const bool m_zero = (m == 0); - const auto logM = (!m_zero) * raft::myLog(m); + const auto logM = (!m_zero) * raft::myLog(m + m_zero); - acc += (-x * (logM - raft::myLog(x))) + (-y * (logM - raft::myLog(y))); + const bool x_zero = (x == 0); + const bool y_zero = (y == 0); + acc += (-x * (logM - raft::myLog(x + x_zero))) + (-y * (logM - raft::myLog(y + 
y_zero))); + acc += (-x * (logM - raft::myLog(x + x_zero))) + + (-y * (logM - raft::myLog(y + y_zero))); }; // epilogue operation lambda for final value calculation From 09632ecb4a47bc2a21687faa68458a2514f2e464 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Wed, 11 Aug 2021 19:11:13 +0530 Subject: [PATCH 16/18] fix issues of nans in kl-divergence caused by log(0) --- cpp/include/raft/distance/kl_divergence.cuh | 96 ++++++++++++++------- 1 file changed, 64 insertions(+), 32 deletions(-) diff --git a/cpp/include/raft/distance/kl_divergence.cuh b/cpp/include/raft/distance/kl_divergence.cuh index afb616cf70..3197b73d10 100644 --- a/cpp/include/raft/distance/kl_divergence.cuh +++ b/cpp/include/raft/distance/kl_divergence.cuh @@ -67,31 +67,39 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { if (isRowMajor) { - acc += x * (raft::myLog(x) - y); + const bool x_zero = (x == 0); + acc += x * (raft::myLog(x + x_zero) - y); } else { - acc += y * (raft::myLog(y) - x); + const bool y_zero = (y == 0); + acc += y * (raft::myLog(y + y_zero) - x); + } + }; + + auto core_lambda_x_equal_y = [] __device__(AccT & acc, DataT & x, DataT & y) { + if (isRowMajor) { + const bool x_zero = (x == 0); + const bool y_zero = (y == 0); + acc += + x * (raft::myLog(x + x_zero) - (!y_zero) * raft::myLog(y + y_zero)); + } else { + const bool y_zero = (y == 0); + const bool x_zero = (x == 0); + acc += + y * (raft::myLog(y + y_zero) - (!x_zero) * raft::myLog(x + x_zero)); } }; auto unaryOp_lambda = [] __device__(DataT input) { const bool x_zero = (input == 0); - return (!x_zero) * raft::myLog(input); + return (!x_zero) * raft::myLog(input + x_zero); }; auto unaryOp_lambda_reverse = [] __device__(DataT input) { // reverse previous log (x) back to x using (e ^ log(x)) const bool x_zero = (input == 0); - return (!x_zero) * raft::myPow((DataT)M_E, input); + return (!x_zero) * raft::myExp(input); }; - if (x != y && isRowMajor) { - raft::linalg::unaryOp( - (DataT *)y, y, n * k, unaryOp_lambda, stream); - } else if (x != y && !isRowMajor) { - raft::linalg::unaryOp( - (DataT *)x, x, m * k, unaryOp_lambda, stream); - } - // epilogue operation lambda for final value calculation auto epilog_lambda = [] __device__( AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], @@ -111,31 +119,55 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - klDivergenceRowMajor); - - klDivergenceRowMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); + constexpr auto klDivergenceRowMajorXequalY = + pairwiseDistanceMatKernel; + if (x != y) { + raft::linalg::unaryOp( + (DataT *)y, y, n * k, unaryOp_lambda, stream); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + klDivergenceRowMajor); + klDivergenceRowMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + // Now reverse previous log (x) back to x using (e ^ log(x)) + raft::linalg::unaryOp( + (DataT *)y, y, n * k, unaryOp_lambda_reverse, stream); + } else { + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + klDivergenceRowMajorXequalY); + klDivergenceRowMajorXequalY<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, + core_lambda_x_equal_y, epilog_lambda, fin_op); + } } else { constexpr auto klDivergenceColMajor 
= pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - klDivergenceColMajor); - klDivergenceColMajor<<>>( - x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, - epilog_lambda, fin_op); - } - - // Now reverse previous log (x) back to x using (e ^ log(x)) - if (x != y && isRowMajor) { - raft::linalg::unaryOp( - (DataT *)y, y, n * k, unaryOp_lambda_reverse, stream); - } else if (x != y && !isRowMajor) { - raft::linalg::unaryOp( - (DataT *)x, x, m * k, unaryOp_lambda_reverse, stream); + constexpr auto klDivergenceColMajorXequalY = + pairwiseDistanceMatKernel; + if (x != y) { + raft::linalg::unaryOp( + (DataT *)x, x, m * k, unaryOp_lambda, stream); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + klDivergenceColMajor); + klDivergenceColMajor<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, + epilog_lambda, fin_op); + // Now reverse previous log (x) back to x using (e ^ log(x)) + raft::linalg::unaryOp( + (DataT *)x, x, m * k, unaryOp_lambda_reverse, stream); + } else { + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + klDivergenceColMajorXequalY); + klDivergenceColMajorXequalY<<>>( + x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, + core_lambda_x_equal_y, epilog_lambda, fin_op); + } } CUDA_CHECK(cudaGetLastError()); From 13e2ad5554fe49180449a32c0a6235c028f96747 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 12 Aug 2021 19:00:10 +0530 Subject: [PATCH 17/18] fix build warnings reported as error in cuda 11.4 --- cpp/include/raft/distance/distance.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index cecf712773..c4d6b52acf 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -182,7 +182,7 @@ struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType metric_arg) { + cudaStream_t stream, bool isRowMajor, InType) { raft::distance::hammingUnexpandedImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); @@ -195,7 +195,7 @@ struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType metric_arg) { + cudaStream_t stream, bool isRowMajor, InType) { raft::distance::jensenShannonImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); @@ -208,7 +208,7 @@ struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType metric_arg) { + cudaStream_t stream, bool isRowMajor, InType) { raft::distance::russellRaoImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); @@ -221,7 +221,7 @@ struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType metric_arg) { + cudaStream_t stream, bool isRowMajor, InType) { raft::distance::klDivergenceImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); @@ -234,7 +234,7 @@ struct DistanceImpl { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t 
worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType metric_arg) { + cudaStream_t stream, bool isRowMajor, InType) { raft::distance::correlationImpl(m, n, k, x, y, dist, (AccType *)workspace, worksize, From 8b3e7ab4dc74622673b95c0a6fcaa34fb99b4341 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Thu, 12 Aug 2021 20:52:38 +0530 Subject: [PATCH 18/18] further fix warnings reported as error in 11.4 --- cpp/include/raft/distance/distance.cuh | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index c4d6b52acf..02d8fb6d03 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -181,8 +181,8 @@ template { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType) { + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { raft::distance::hammingUnexpandedImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); @@ -194,8 +194,8 @@ template { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType) { + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { raft::distance::jensenShannonImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); @@ -207,8 +207,8 @@ template { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType) { + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { raft::distance::russellRaoImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor); @@ -220,8 +220,8 @@ template { void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType) { + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { raft::distance::klDivergenceImpl(m, n, k, x, y, dist, fin_op, stream, isRowMajor);