From 316d23dfccbe7ffb5174e9fa5bb12958b9f97d4b Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 24 Jun 2022 20:15:05 +0530 Subject: [PATCH 1/2] fix nans in naive kl divergence introduced by div by 0, use similar formula as one used by main kernel --- cpp/test/distance/distance_base.cuh | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 07643bc4ea..3c2864291c 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -260,9 +260,8 @@ __global__ void naiveKLDivergenceDistanceKernel( auto a = x[xidx]; auto b = y[yidx]; bool b_zero = (b == 0); - const auto m = (!b_zero) * (a / b); - const bool m_zero = (m == 0); - acc += (a * (!m_zero) * log(m + m_zero)); + bool a_zero = (a == 0); + acc += a * (log(a + a_zero) - log(b + b_zero)); } acc = 0.5f * acc; int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; @@ -450,10 +449,6 @@ class DistanceTest : public ::testing::TestWithParam> { } naiveDistance( dist_ref.data(), x.data(), y.data(), m, n, k, distanceType, isRowMajor, metric_arg, stream); - // size_t worksize = raft::distance::getWorkspaceSize( - // x.data(), y.data(), m, n, k); - // rmm::device_uvector workspace(worksize, stream); DataType threshold = -10000.f; From d65012174e9e661581098b3ac8f8b3ae4ab93de0 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 24 Jun 2022 20:27:08 +0530 Subject: [PATCH 2/2] fix clang format issues --- cpp/test/distance/distance_base.cuh | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 3c2864291c..44f93bff8a 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -255,12 +255,12 @@ __global__ void naiveKLDivergenceDistanceKernel( if (midx >= m || nidx >= n) return; OutType acc = OutType(0); for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - bool b_zero = (b == 0); - bool a_zero = (a == 0); + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + bool b_zero = (b == 0); + bool a_zero = (a == 0); acc += a * (log(a + a_zero) - log(b + b_zero)); } acc = 0.5f * acc;