Skip to content

Commit

Permalink
Fix integer overflow in distances (#490)
Browse files Browse the repository at this point in the history
Fix for rapidsai/cuml#4552.

Authors:
  - Rory Mitchell (https://github.com/RAMitchell)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #490
  • Loading branch information
RAMitchell authored Feb 8, 2022
1 parent 66a60b9 commit 23e1650
Show file tree
Hide file tree
Showing 15 changed files with 101 additions and 16 deletions.
5 changes: 3 additions & 2 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -276,7 +276,8 @@ struct PairwiseDistances : public BaseClass {
for (int j = 0; j < P::AccColsPerTh; ++j) {
auto colId = startx + j * P::AccThCols;
if (rowId < this->m && colId < this->n) {
dOutput[rowId * this->n + colId] = fin_op(acc[i][j], 0);
// Promote to 64 bit index for final write, as output array can be > 2^31
dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0);
}
}
}
Expand Down
6 changes: 5 additions & 1 deletion cpp/test/distance/dist_canberra.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,5 +64,9 @@ TEST_P(DistanceCanberraD, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCanberraD, ::testing::ValuesIn(inputsd));

class BigMatrixCanberra : public BigMatrixDistanceTest<raft::distance::DistanceType::Canberra> {
};
TEST_F(BigMatrixCanberra, Result) {}

} // end namespace distance
} // end namespace raft
6 changes: 5 additions & 1 deletion cpp/test/distance/dist_chebyshev.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,5 +64,9 @@ TEST_P(DistanceLinfD, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLinfD, ::testing::ValuesIn(inputsd));

class BigMatrixLinf : public BigMatrixDistanceTest<raft::distance::DistanceType::Linf> {
};
TEST_F(BigMatrixLinf, Result) {}

} // end namespace distance
} // end namespace raft
6 changes: 5 additions & 1 deletion cpp/test/distance/dist_correlation.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,5 +65,9 @@ TEST_P(DistanceCorrelationD, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceCorrelationD, ::testing::ValuesIn(inputsd));

class BigMatrixCorrelation
: public BigMatrixDistanceTest<raft::distance::DistanceType::CorrelationExpanded> {
};
TEST_F(BigMatrixCorrelation, Result) {}
} // end namespace distance
} // end namespace raft
6 changes: 5 additions & 1 deletion cpp/test/distance/dist_cos.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,5 +64,9 @@ TEST_P(DistanceExpCosD, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceExpCosD, ::testing::ValuesIn(inputsd));

class BigMatrixCos : public BigMatrixDistanceTest<raft::distance::DistanceType::CosineExpanded> {
};
TEST_F(BigMatrixCos, Result) {}

} // end namespace distance
} // end namespace raft
5 changes: 4 additions & 1 deletion cpp/test/distance/dist_euc_exp.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,5 +64,8 @@ TEST_P(DistanceEucExpTestD, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucExpTestD, ::testing::ValuesIn(inputsd));

class BigMatrixEucExp : public BigMatrixDistanceTest<raft::distance::DistanceType::L2Expanded> {
};
TEST_F(BigMatrixEucExp, Result) {}
} // end namespace distance
} // end namespace raft
5 changes: 4 additions & 1 deletion cpp/test/distance/dist_euc_unexp.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,5 +65,8 @@ TEST_P(DistanceEucUnexpTestD, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceEucUnexpTestD, ::testing::ValuesIn(inputsd));

class BigMatrixEucUnexp : public BigMatrixDistanceTest<raft::distance::DistanceType::L2Unexpanded> {
};
TEST_F(BigMatrixEucUnexp, Result) {}
} // end namespace distance
} // end namespace raft
6 changes: 5 additions & 1 deletion cpp/test/distance/dist_hamming.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,5 +65,9 @@ TEST_P(DistanceHammingD, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHammingD, ::testing::ValuesIn(inputsd));

class BigMatrixHamming
: public BigMatrixDistanceTest<raft::distance::DistanceType::HammingUnexpanded> {
};
TEST_F(BigMatrixHamming, Result) {}
} // end namespace distance
} // end namespace raft
6 changes: 5 additions & 1 deletion cpp/test/distance/dist_hellinger.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,5 +65,9 @@ TEST_P(DistanceHellingerExpD, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceHellingerExpD, ::testing::ValuesIn(inputsd));

class BigMatrixHellingerExp
: public BigMatrixDistanceTest<raft::distance::DistanceType::HellingerExpanded> {
};
TEST_F(BigMatrixHellingerExp, Result) {}
} // end namespace distance
} // end namespace raft
6 changes: 5 additions & 1 deletion cpp/test/distance/dist_jensen_shannon.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,5 +65,9 @@ TEST_P(DistanceJensenShannonD, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceJensenShannonD, ::testing::ValuesIn(inputsd));

class BigMatrixJensenShannon
: public BigMatrixDistanceTest<raft::distance::DistanceType::JensenShannon> {
};
TEST_F(BigMatrixJensenShannon, Result) {}
} // end namespace distance
} // end namespace raft
6 changes: 5 additions & 1 deletion cpp/test/distance/dist_kl_divergence.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,5 +65,9 @@ TEST_P(DistanceKLDivergenceD, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceKLDivergenceD, ::testing::ValuesIn(inputsd));

class BigMatrixKLDivergence
: public BigMatrixDistanceTest<raft::distance::DistanceType::KLDivergence> {
};
TEST_F(BigMatrixKLDivergence, Result) {}
} // end namespace distance
} // end namespace raft
6 changes: 5 additions & 1 deletion cpp/test/distance/dist_l1.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,5 +64,9 @@ TEST_P(DistanceUnexpL1D, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceUnexpL1D, ::testing::ValuesIn(inputsd));

class BigMatrixUnexpL1 : public BigMatrixDistanceTest<raft::distance::DistanceType::L1> {
};
TEST_F(BigMatrixUnexpL1, Result) {}

} // end namespace distance
} // end namespace raft
5 changes: 4 additions & 1 deletion cpp/test/distance/dist_minkowski.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -64,5 +64,8 @@ TEST_P(DistanceLpUnexpD, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceLpUnexpD, ::testing::ValuesIn(inputsd));

class BigMatrixLpUnexp : public BigMatrixDistanceTest<raft::distance::DistanceType::LpUnexpanded> {
};
TEST_F(BigMatrixLpUnexp, Result) {}
} // end namespace distance
} // end namespace raft
6 changes: 5 additions & 1 deletion cpp/test/distance/dist_russell_rao.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
* Copyright (c) 2021-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -65,5 +65,9 @@ TEST_P(DistanceRussellRaoD, Result)
}
INSTANTIATE_TEST_CASE_P(DistanceTests, DistanceRussellRaoD, ::testing::ValuesIn(inputsd));

class BigMatrixRussellRao
: public BigMatrixDistanceTest<raft::distance::DistanceType::RusselRaoExpanded> {
};
TEST_F(BigMatrixRussellRao, Result) {}
} // end namespace distance
} // end namespace raft
37 changes: 36 additions & 1 deletion cpp/test/distance/distance_base.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -476,5 +476,40 @@ class DistanceTest : public ::testing::TestWithParam<DistanceInputs<DataType>> {
rmm::device_uvector<DataType> x, y, dist_ref, dist, dist2;
};

template <raft::distance::DistanceType distanceType>
class BigMatrixDistanceTest : public ::testing::Test {
public:
BigMatrixDistanceTest()
: x(m * k, handle.get_stream()), dist(std::size_t(m) * m, handle.get_stream()){};
void SetUp() override
{
auto testInfo = testing::UnitTest::GetInstance()->current_test_info();
common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name());

size_t worksize = raft::distance::getWorkspaceSize<distanceType, float, float, float>(
x.data(), x.data(), m, n, k);
rmm::device_uvector<char> workspace(worksize, handle.get_stream());
raft::distance::distance<distanceType, float, float, float>(x.data(),
x.data(),
dist.data(),
m,
n,
k,
workspace.data(),
worksize,
handle.get_stream(),
true,
0.0f);

RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream()));
}

protected:
int m = 48000;
int n = 48000;
int k = 1;
raft::handle_t handle;
rmm::device_uvector<float> x, dist;
};
} // end namespace distance
} // end namespace raft

0 comments on commit 23e1650

Please sign in to comment.