diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 5fcaf07539..ae8230984a 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -16,16 +16,23 @@ #include "../test_utils.cuh" #include -#include -#include -#include -#include -#include -#include +#include // common::nvtx::range + +#include // raft::device_resources +#include // raft::sqrt +#include // raft::distance::DistanceType +#include +#include // rmm::device_uvector + +// When the distance library is precompiled, include only the raft_runtime +// headers. This way, a small change in one of the kernel internals does not +// trigger a rebuild of the test files (it of course still triggers a rebuild of +// the raft specializations) #if defined RAFT_DISTANCE_COMPILED -#include +#include +#else +#include #endif -#include namespace raft { namespace distance { @@ -409,6 +416,25 @@ template return os; } +// TODO: Remove when mdspan-based raft::runtime::distance::pairwise_distance is +// implemented. +// +// Context: +// https://github.com/rapidsai/raft/issues/1338 +template +constexpr bool layout_to_row_major(); + +template <> +constexpr bool layout_to_row_major() +{ + return true; +} +template <> +constexpr bool layout_to_row_major() +{ + return false; +} + template void distanceLauncher(raft::device_resources const& handle, DataType* x, @@ -422,12 +448,23 @@ void distanceLauncher(raft::device_resources const& handle, DataType threshold, DataType metric_arg = 2.0f) { +#if defined RAFT_DISTANCE_COMPILED + // TODO: Implement and use mdspan-based + // raft::runtime::distance::pairwise_distance here. + // + // Context: + // https://github.com/rapidsai/raft/issues/1338 + bool row_major = layout_to_row_major(); + raft::runtime::distance::pairwise_distance( + handle, x, y, dist, m, n, k, distanceType, row_major, metric_arg); +#else auto x_v = make_device_matrix_view(x, m, k); auto y_v = make_device_matrix_view(y, n, k); auto dist_v = make_device_matrix_view(dist, m, n); raft::distance::distance( handle, x_v, y_v, dist_v, metric_arg); +#endif } template @@ -523,9 +560,25 @@ class BigMatrixDistanceTest : public ::testing::Test { auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); + void pairwise_distance(raft::device_resources const& handle, + float* x, + float* y, + float* dists, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + bool row_major = true; + float metric_arg = 0.0f; +#if defined RAFT_DISTANCE_COMPILED + raft::runtime::distance::pairwise_distance( + handle, x.data(), x.data(), dist.data(), m, n, k, distanceType, row_major, metric_arg); +#else raft::distance::distance( - handle, x.data(), x.data(), dist.data(), m, n, k, true, 0.0f); - + handle, x.data(), x.data(), dist.data(), m, n, k, row_major, metric_arg); +#endif RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); }