Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[HOTFIX] Fix distance metrics L2/cosine/correlation when X & Y are same buffer but with different shape and add unit test for such case. #1571

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions cpp/include/raft/distance/detail/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,9 @@ void distance_impl(raft::resources const& handle,
bool is_row_major,
DataT) // unused
{
ASSERT(
!(((x != y) && (worksize < 2 * (m + n) * sizeof(AccT))) || (worksize < 2 * m * sizeof(AccT))),
"workspace size error");
ASSERT(!((((x != y) || ((x == y) && (m != n))) && (worksize < 2 * (m + n) * sizeof(AccT))) ||
(worksize < 2 * m * sizeof(AccT))),
"workspace size error");
ASSERT(workspace != nullptr, "workspace is null");

cudaStream_t stream = raft::resource::get_cuda_stream(handle);
Expand All @@ -137,7 +137,7 @@ void distance_impl(raft::resources const& handle,
AccT* y_norm = workspace;
AccT* sq_x_norm = workspace;
AccT* sq_y_norm = workspace;
if (x != y) {
if ((x != y) || ((x == y) && (m != n))) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This conditional is getting a little complicated and I see it in multiple places in this PR. Can we create a helper function that can give this a name and be reused throughout the code? It would make it a lot easier to read and maintain.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have removed this conditional.

y_norm += m;

raft::linalg::reduce(x_norm,
Expand Down Expand Up @@ -210,15 +210,16 @@ void distance_impl(raft::resources const& handle,
"OutT can be uint8_t, float, double,"
"if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT).");

ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))),
ASSERT(!((((x != y) || ((x == y) && (m != n))) && (worksize < (m + n) * sizeof(AccT))) ||
(worksize < m * sizeof(AccT))),
"workspace size error");
ASSERT(workspace != nullptr, "workspace is null");

cudaStream_t stream = raft::resource::get_cuda_stream(handle);

DataT* x_norm = workspace;
DataT* y_norm = workspace;
if (x != y) {
if ((x != y) || ((x == y) && (m != n))) {
tfeher marked this conversation as resolved.
Show resolved Hide resolved
y_norm += m;
raft::linalg::rowNorm(
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{});
Expand Down Expand Up @@ -453,13 +454,14 @@ void distance_impl_l2_expanded( // NOTE: different name
"OutT can be uint8_t, float, double,"
"if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT).");

ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) || (worksize < m * sizeof(AccT))),
ASSERT(!((((x != y) || ((x == y) && (m != n))) && (worksize < (m + n) * sizeof(AccT))) ||
(worksize < m * sizeof(AccT))),
"workspace size error");
ASSERT(workspace != nullptr, "workspace is null");

DataT* x_norm = workspace;
DataT* y_norm = workspace;
if (x != y) {
if ((x != y) || ((x == y) && (m != n))) {
y_norm += m;
raft::linalg::rowNorm(
x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{});
Expand Down Expand Up @@ -790,7 +792,7 @@ size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, In

if (is_allocated) {
worksize += numOfBuffers * m * sizeof(AccType);
if (x != y) worksize += numOfBuffers * n * sizeof(AccType);
if ((x != y) || ((x == y) && (m != n))) worksize += numOfBuffers * n * sizeof(AccType);
}

return worksize;
Expand Down
23 changes: 23 additions & 0 deletions cpp/test/distance/dist_correlation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ template <typename DataType>
class DistanceCorrelation
: public DistanceTest<raft::distance::DistanceType::CorrelationExpanded, DataType> {};

// Correlation-distance fixture built on DistanceTestSameBuffer, i.e. the variant of the
// distance test in which x and y are passed as the same buffer. It adds no members of its
// own; it only pins the metric to DistanceType::CorrelationExpanded.
template <typename DataType>
class DistanceCorrelationXequalY
: public DistanceTestSameBuffer<raft::distance::DistanceType::CorrelationExpanded, DataType> {};

const std::vector<DistanceInputs<float>> inputsf = {
{0.001f, 1024, 1024, 32, true, 1234ULL},
{0.001f, 1024, 32, 1024, true, 1234ULL},
Expand All @@ -44,6 +48,25 @@ TEST_P(DistanceCorrelationF, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationF, ::testing::ValuesIn(inputsf));

// Single-precision instantiation of the same-buffer correlation fixture.
using DistanceCorrelationXequalYF = DistanceCorrelationXequalY<float>;
TEST_P(DistanceCorrelationXequalYF, Result)
{
  const int rows          = params.m;
  const auto approx_equal = raft::CompareApprox<float>(params.tolerance);

  // Result 0 is compared over a rows x rows extent.
  ASSERT_TRUE(
    raft::devArrMatch(dist_ref[0].data(), dist[0].data(), rows, rows, approx_equal, stream));

  // Result 1 is compared over a (rows / 2) x rows extent.
  ASSERT_TRUE(
    raft::devArrMatch(dist_ref[1].data(), dist[1].data(), rows / 2, rows, approx_equal, stream));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationXequalYF, ::testing::ValuesIn(inputsf));

const std::vector<DistanceInputs<double>> inputsd = {
{0.001, 1024, 1024, 32, true, 1234ULL},
{0.001, 1024, 32, 1024, true, 1234ULL},
Expand Down
23 changes: 23 additions & 0 deletions cpp/test/distance/dist_cos.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ template <typename DataType>
class DistanceExpCos : public DistanceTest<raft::distance::DistanceType::CosineExpanded, DataType> {
};

// Cosine-distance fixture built on DistanceTestSameBuffer, i.e. the variant of the
// distance test in which x and y are passed as the same buffer. It adds no members of its
// own; it only pins the metric to DistanceType::CosineExpanded.
template <typename DataType>
class DistanceExpCosXequalY
: public DistanceTestSameBuffer<raft::distance::DistanceType::CosineExpanded, DataType> {};

const std::vector<DistanceInputs<float>> inputsf = {
{0.001f, 1024, 1024, 32, true, 1234ULL},
{0.001f, 1024, 32, 1024, true, 1234ULL},
Expand All @@ -44,6 +48,25 @@ TEST_P(DistanceExpCosF, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosF, ::testing::ValuesIn(inputsf));

// Single-precision instantiation of the same-buffer cosine fixture.
using DistanceExpCosXequalYF = DistanceExpCosXequalY<float>;
TEST_P(DistanceExpCosXequalYF, Result)
{
  const int rows          = params.m;
  const auto approx_equal = raft::CompareApprox<float>(params.tolerance);

  // Result 0 is compared over a rows x rows extent.
  ASSERT_TRUE(
    raft::devArrMatch(dist_ref[0].data(), dist[0].data(), rows, rows, approx_equal, stream));

  // Result 1 is compared over a (rows / 2) x rows extent.
  ASSERT_TRUE(
    raft::devArrMatch(dist_ref[1].data(), dist[1].data(), rows / 2, rows, approx_equal, stream));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosXequalYF, ::testing::ValuesIn(inputsf));

const std::vector<DistanceInputs<double>> inputsd = {
{0.001, 1024, 1024, 32, true, 1234ULL},
{0.001, 1024, 32, 1024, true, 1234ULL},
Expand Down
23 changes: 23 additions & 0 deletions cpp/test/distance/dist_l2_exp.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ template <typename DataType>
class DistanceEucExpTest : public DistanceTest<raft::distance::DistanceType::L2Expanded, DataType> {
};

// Expanded-L2 fixture built on DistanceTestSameBuffer, i.e. the variant of the distance
// test in which x and y are passed as the same buffer. It adds no members of its own; it
// only pins the metric to DistanceType::L2Expanded.
template <typename DataType>
class DistanceEucExpTestXequalY
: public DistanceTestSameBuffer<raft::distance::DistanceType::L2Expanded, DataType> {};

const std::vector<DistanceInputs<float>> inputsf = {
{0.001f, 2048, 4096, 128, true, 1234ULL},
{0.001f, 1024, 1024, 32, true, 1234ULL},
Expand All @@ -47,6 +51,25 @@ TEST_P(DistanceEucExpTestF, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestF, ::testing::ValuesIn(inputsf));

// Single-precision instantiation of the same-buffer expanded-L2 fixture.
using DistanceEucExpTestXequalYF = DistanceEucExpTestXequalY<float>;
TEST_P(DistanceEucExpTestXequalYF, Result)
{
  const int rows          = params.m;
  const auto approx_equal = raft::CompareApprox<float>(params.tolerance);

  // Result 0 is compared over a rows x rows extent.
  ASSERT_TRUE(
    raft::devArrMatch(dist_ref[0].data(), dist[0].data(), rows, rows, approx_equal, stream));

  // Result 1 is compared over a (rows / 2) x rows extent.
  ASSERT_TRUE(
    raft::devArrMatch(dist_ref[1].data(), dist[1].data(), rows / 2, rows, approx_equal, stream));
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestXequalYF, ::testing::ValuesIn(inputsf));

const std::vector<DistanceInputs<double>> inputsd = {
{0.001, 1024, 1024, 32, true, 1234ULL},
{0.001, 1024, 32, 1024, true, 1234ULL},
Expand Down
102 changes: 102 additions & 0 deletions cpp/test/distance/distance_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,108 @@ class DistanceTest : public ::testing::TestWithParam<DistanceInputs<DataType>> {
rmm::device_uvector<DataType> x, y, dist_ref, dist, dist2;
};

/*
 * This test suite verifies the code path taken when X and Y are the same buffer.
 * Distance metrics that require norms (L2 expanded, cosine, correlation) take a
 * more optimal path in that case and skip the norm calculation for the Y buffer.
 * The same buffer may nevertheless be passed with different dimensions for X and Y,
 * as happens in tiled_brute_force_knn, so the suite runs the distance twice:
 *   - run 0: x and y are the same buffer, both with params.m rows;
 *   - run 1: x and y are the same buffer, x with params.m / 2 rows and y with params.m rows.
 * The reference (naiveDistance) and tested (distanceLauncher) outputs of run i are
 * stored in dist_ref[i] and dist[i] / dist2[i] respectively.
 */
template <raft::distance::DistanceType distanceType, typename DataType>
class DistanceTestSameBuffer : public ::testing::TestWithParam<DistanceInputs<DataType>> {
 public:
  using dev_vector = rmm::device_uvector<DataType>;

  // Members are initialized in declaration order (handle, stream, params, x, ...);
  // the initializer list is written in that same order so it reads the way it executes.
  DistanceTestSameBuffer()
    : stream(resource::get_cuda_stream(handle)),
      params(::testing::TestWithParam<DistanceInputs<DataType>>::GetParam()),
      x(params.m * params.k, stream),
      dist_ref({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}),
      dist({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)}),
      dist2({dev_vector(params.m * params.m, stream), dev_vector(params.m * params.m, stream)})
  {
  }

  // Fills x with random data suited to the metric, then computes the reference and the
  // tested distances for each of the N runs described in the class comment.
  void SetUp() override
  {
    auto testInfo = testing::UnitTest::GetInstance()->current_test_info();
    common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name());

    raft::random::RngState r(params.seed);
    const int n               = params.m;
    const int k               = params.k;
    const int num_elems       = params.m * k;
    const DataType metric_arg = params.metric_arg;
    const bool isRowMajor     = params.isRowMajor;

    if (distanceType == raft::distance::DistanceType::HellingerExpanded ||
        distanceType == raft::distance::DistanceType::JensenShannon ||
        distanceType == raft::distance::DistanceType::KLDivergence) {
      // Hellinger works only on positive numbers
      uniform(handle, r, x.data(), num_elems, DataType(0.0), DataType(1.0));
    } else if (distanceType == raft::distance::DistanceType::RusselRaoExpanded) {
      uniform(handle, r, x.data(), num_elems, DataType(0.0), DataType(1.0));
      // Russel rao works on boolean values; this call writes to the same range of x
      // as the uniform call above.
      bernoulli(handle, r, x.data(), num_elems, 0.5f);
    } else {
      uniform(handle, r, x.data(), num_elems, DataType(-1.0), DataType(1.0));
    }

    // Loop-invariant argument forwarded to distanceLauncher.
    const DataType threshold = -10000.f;

    for (int i = 0; i < N; i++) {
      // Both X and Y are the same buffer; run 0 uses params.m rows for x, run 1 uses
      // params.m / 2, so that different dimensions are passed for x and y.
      const int m = params.m / (i + 1);

      naiveDistance(dist_ref[i].data(),
                    x.data(),
                    x.data(),
                    m,
                    n,
                    k,
                    distanceType,
                    isRowMajor,
                    metric_arg,
                    stream);

      if (isRowMajor) {
        distanceLauncher<distanceType, DataType, layout_c_contiguous>(handle,
                                                                      x.data(),
                                                                      x.data(),
                                                                      dist[i].data(),
                                                                      dist2[i].data(),
                                                                      m,
                                                                      n,
                                                                      k,
                                                                      params,
                                                                      threshold,
                                                                      metric_arg);
      } else {
        distanceLauncher<distanceType, DataType, layout_f_contiguous>(handle,
                                                                      x.data(),
                                                                      x.data(),
                                                                      dist[i].data(),
                                                                      dist2[i].data(),
                                                                      m,
                                                                      n,
                                                                      k,
                                                                      params,
                                                                      threshold,
                                                                      metric_arg);
      }
    }
    resource::sync_stream(handle, stream);
  }

 protected:
  raft::resources handle;
  cudaStream_t stream;

  DistanceInputs<DataType> params;
  dev_vector x;
  // Number of runs (and of result buffers per array).
  static const int N = 2;
  std::array<dev_vector, N> dist_ref, dist, dist2;
};

template <raft::distance::DistanceType distanceType>
class BigMatrixDistanceTest : public ::testing::Test {
public:
Expand Down