diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 07643bc4ea..44f93bff8a 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -255,14 +255,13 @@ __global__ void naiveKLDivergenceDistanceKernel( if (midx >= m || nidx >= n) return; OutType acc = OutType(0); for (int i = 0; i < k; ++i) { - int xidx = isRowMajor ? i + midx * k : i * m + midx; - int yidx = isRowMajor ? i + nidx * k : i * n + nidx; - auto a = x[xidx]; - auto b = y[yidx]; - bool b_zero = (b == 0); - const auto m = (!b_zero) * (a / b); - const bool m_zero = (m == 0); - acc += (a * (!m_zero) * log(m + m_zero)); + int xidx = isRowMajor ? i + midx * k : i * m + midx; + int yidx = isRowMajor ? i + nidx * k : i * n + nidx; + auto a = x[xidx]; + auto b = y[yidx]; + bool b_zero = (b == 0); + bool a_zero = (a == 0); + acc += a * (log(a + a_zero) - log(b + b_zero)); } acc = 0.5f * acc; int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; @@ -450,10 +449,6 @@ class DistanceTest : public ::testing::TestWithParam> { } naiveDistance( dist_ref.data(), x.data(), y.data(), m, n, k, distanceType, isRowMajor, metric_arg, stream); - // size_t worksize = raft::distance::getWorkspaceSize( - // x.data(), y.data(), m, n, k); - // rmm::device_uvector workspace(worksize, stream); DataType threshold = -10000.f;