Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REVIEW] Update RAFT test directory #359

Merged
merged 3 commits into from
Oct 14, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions cpp/test/distance/dist_adj.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <raft/cuda_utils.cuh>
#include <raft/distance/distance.cuh>
#include <raft/random/rng.cuh>
#include <rmm/device_uvector.hpp>
#include "../test_utils.h"

namespace raft {
Expand Down Expand Up @@ -70,47 +71,52 @@ template <typename DataType>
class DistanceAdjTest
: public ::testing::TestWithParam<DistanceAdjInputs<DataType>> {
public:
DistanceAdjTest()
: params(::testing::TestWithParam<DistanceAdjInputs<DataType>>::GetParam()),
stream(handle.get_stream()),
dist(params.m * params.n, stream),
dist_ref(params.m * params.n, stream) {}

void SetUp() override {
params = ::testing::TestWithParam<DistanceAdjInputs<DataType>>::GetParam();
raft::random::Rng r(params.seed);
int m = params.m;
int n = params.n;
int k = params.k;
bool isRowMajor = params.isRowMajor;
CUDA_CHECK(cudaStreamCreate(&stream));
raft::allocate(x, m * k, stream);
raft::allocate(y, n * k, stream);
raft::allocate(dist_ref, m * n, stream);
raft::allocate(dist, m * n, stream);
r.uniform(x, m * k, DataType(-1.0), DataType(1.0), stream);
r.uniform(y, n * k, DataType(-1.0), DataType(1.0), stream);

rmm::device_uvector<DataType> x(m * k, stream);
rmm::device_uvector<DataType> y(n * k, stream);

r.uniform(x.data(), m * k, DataType(-1.0), DataType(1.0), stream);
r.uniform(y.data(), n * k, DataType(-1.0), DataType(1.0), stream);

DataType threshold = params.eps;

naiveDistanceAdj(dist_ref, x, y, m, n, k, threshold, isRowMajor);
char *workspace = nullptr;
naiveDistanceAdj(dist_ref.data(), x.data(), y.data(), m, n, k, threshold,
isRowMajor);
size_t worksize =
raft::distance::getWorkspaceSize<raft::distance::DistanceType::L2Expanded,
DataType, DataType, bool>(x, y, m, n, k);
if (worksize != 0) {
raft::allocate(workspace, worksize, stream);
}
DataType, DataType, bool>(
x.data(), y.data(), m, n, k);
rmm::device_uvector<char> workspace(worksize, stream);

auto fin_op = [threshold] __device__(DataType d_val, int g_d_idx) {
return d_val <= threshold;
};
raft::distance::distance<raft::distance::DistanceType::L2Expanded, DataType,
DataType, bool>(
x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor);
x.data(), y.data(), dist.data(), m, n, k, workspace.data(),
workspace.size(), fin_op, stream, isRowMajor);
CUDA_CHECK(cudaStreamSynchronize(stream));
}

void TearDown() override {}

protected:
DistanceAdjInputs<DataType> params;
DataType *x, *y;
bool *dist_ref, *dist;
rmm::device_uvector<bool> dist_ref;
rmm::device_uvector<bool> dist;
raft::handle_t handle;
cudaStream_t stream;
};

Expand All @@ -128,7 +134,8 @@ typedef DistanceAdjTest<float> DistanceAdjTestF;
TEST_P(DistanceAdjTestF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(devArrMatch(dist_ref, dist, m, n, raft::Compare<bool>()));
ASSERT_TRUE(
devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare<bool>()));
}
INSTANTIATE_TEST_CASE_P(DistanceAdjTests, DistanceAdjTestF,
::testing::ValuesIn(inputsf));
Expand All @@ -147,7 +154,8 @@ typedef DistanceAdjTest<double> DistanceAdjTestD;
TEST_P(DistanceAdjTestD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(devArrMatch(dist_ref, dist, m, n, raft::Compare<bool>()));
ASSERT_TRUE(
devArrMatch(dist_ref.data(), dist.data(), m, n, raft::Compare<bool>()));
}
INSTANTIATE_TEST_CASE_P(DistanceAdjTests, DistanceAdjTestD,
::testing::ValuesIn(inputsd));
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_canberra.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ typedef DistanceCanberra<float> DistanceCanberraF;
TEST_P(DistanceCanberraF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraF,
Expand All @@ -58,7 +58,7 @@ typedef DistanceCanberra<double> DistanceCanberraD;
TEST_P(DistanceCanberraD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraD,
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_chebyshev.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ typedef DistanceLinf<float> DistanceLinfF;
TEST_P(DistanceLinfF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfF,
Expand All @@ -58,7 +58,7 @@ typedef DistanceLinf<double> DistanceLinfD;
TEST_P(DistanceLinfD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfD,
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_correlation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ typedef DistanceCorrelation<float> DistanceCorrelationF;
TEST_P(DistanceCorrelationF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationF,
Expand All @@ -59,7 +59,7 @@ typedef DistanceCorrelation<double> DistanceCorrelationD;
TEST_P(DistanceCorrelationD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationD,
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_cos.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ typedef DistanceExpCos<float> DistanceExpCosF;
TEST_P(DistanceExpCosF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosF,
Expand All @@ -59,7 +59,7 @@ typedef DistanceExpCos<double> DistanceExpCosD;
TEST_P(DistanceExpCosD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosD,
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_euc_exp.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ typedef DistanceEucExpTest<float> DistanceEucExpTestF;
TEST_P(DistanceEucExpTestF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestF,
Expand All @@ -58,7 +58,7 @@ typedef DistanceEucExpTest<double> DistanceEucExpTestD;
TEST_P(DistanceEucExpTestD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestD,
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_euc_unexp.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ typedef DistanceEucUnexpTest<float> DistanceEucUnexpTestF;
TEST_P(DistanceEucUnexpTestF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucUnexpTestF,
Expand All @@ -59,7 +59,7 @@ typedef DistanceEucUnexpTest<double> DistanceEucUnexpTestD;
TEST_P(DistanceEucUnexpTestD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucUnexpTestD,
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_hamming.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ typedef DistanceHamming<float> DistanceHammingF;
TEST_P(DistanceHammingF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingF,
Expand All @@ -59,7 +59,7 @@ typedef DistanceHamming<double> DistanceHammingD;
TEST_P(DistanceHammingD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingD,
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_hellinger.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ typedef DistanceHellingerExp<float> DistanceHellingerExpF;
TEST_P(DistanceHellingerExpF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpF,
Expand All @@ -59,7 +59,7 @@ typedef DistanceHellingerExp<double> DistanceHellingerExpD;
TEST_P(DistanceHellingerExpD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpD,
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_jensen_shannon.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ typedef DistanceJensenShannon<float> DistanceJensenShannonF;
TEST_P(DistanceJensenShannonF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonF,
Expand All @@ -59,7 +59,7 @@ typedef DistanceJensenShannon<double> DistanceJensenShannonD;
TEST_P(DistanceJensenShannonD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonD,
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_kl_divergence.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ typedef DistanceKLDivergence<float> DistanceKLDivergenceF;
TEST_P(DistanceKLDivergenceF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceF,
Expand All @@ -59,7 +59,7 @@ typedef DistanceKLDivergence<double> DistanceKLDivergenceD;
TEST_P(DistanceKLDivergenceD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceD,
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_l1.cu
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ typedef DistanceUnexpL1<float> DistanceUnexpL1F;
TEST_P(DistanceUnexpL1F, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceUnexpL1F,
Expand All @@ -58,7 +58,7 @@ typedef DistanceUnexpL1<double> DistanceUnexpL1D;
TEST_P(DistanceUnexpL1D, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceUnexpL1D,
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_minkowski.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ typedef DistanceLpUnexp<float> DistanceLpUnexpF;
TEST_P(DistanceLpUnexpF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpF,
Expand All @@ -59,7 +59,7 @@ typedef DistanceLpUnexp<double> DistanceLpUnexpD;
TEST_P(DistanceLpUnexpD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpD,
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/distance/dist_russell_rao.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ typedef DistanceRussellRao<float> DistanceRussellRaoF;
TEST_P(DistanceRussellRaoF, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<float>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoF,
Expand All @@ -59,7 +59,7 @@ typedef DistanceRussellRao<double> DistanceRussellRaoD;
TEST_P(DistanceRussellRaoD, Result) {
int m = params.isRowMajor ? params.m : params.n;
int n = params.isRowMajor ? params.n : params.m;
ASSERT_TRUE(raft::devArrMatch(dist_ref, dist, m, n,
ASSERT_TRUE(raft::devArrMatch(dist_ref.data(), dist.data(), m, n,
raft::CompareApprox<double>(params.tolerance)));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoD,
Expand Down
Loading