From 4d1d2010dded199cc46ec81b3157259ac6a3a0af Mon Sep 17 00:00:00 2001
From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com>
Date: Wed, 17 Nov 2021 18:55:50 +0100
Subject: [PATCH] Fix coalesced access checks in matrix_vector_op (#372)

One of the conditions in [`test_aligned_access`](https://github.com/rapidsai/raft/blob/branch-21.12/cpp/include/raft/linalg/matrix_vector_op.cuh#L106) in `linalg/matrix_vector_op.cuh` was inverted: it required `ptr % elem_size` to be non-zero, whereas an aligned pointer yields zero. Due to that typo, the `matrixVectorOp` function never used vectorized load/store instructions. This PR fixes the problem and adds a new helper struct, `Pow2`, to simplify such checks in the future.
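
To illustrate the intent, here is the corrected predicate spelled out for a vectorized access of `VecBytes` bytes (a sketch of the logic that the new `AlignedAccess` helper encapsulates, not a verbatim excerpt from the patch):

```cpp
// All three conditions must hold before taking a VecBytes-wide load/store path:
bool ok = reinterpret_cast<uintptr_t>(matrix) % VecBytes == 0  // rows start on a vector boundary
       && stride_bytes % VecBytes == 0                         // ...and every row stays on one
       && VecBytes % sizeof(Type) == 0;                        // vectors hold whole elements
// The old lambda's last conjunct read `ptr % elem_size` with no `== 0`,
// which is truthy only when the pointer is NOT element-aligned; since the
// matrix pointer is effectively always element-aligned, the predicate was
// always false and the scalar fallback was always chosen.
```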
Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: https://github.com/rapidsai/raft/pull/372
---
 cpp/include/raft/linalg/matrix_vector_op.cuh |  45 +++---
 cpp/include/raft/pow2_utils.cuh              | 161 +++++++++++++++++++
 cpp/test/CMakeLists.txt                      |   1 +
 cpp/test/pow2_utils.cu                       | 109 +++++++++++++
 4 files changed, 295 insertions(+), 21 deletions(-)
 create mode 100644 cpp/include/raft/pow2_utils.cuh
 create mode 100644 cpp/test/pow2_utils.cu

diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh
index e948c3e673..93f2d746fa 100644
--- a/cpp/include/raft/linalg/matrix_vector_op.cuh
+++ b/cpp/include/raft/linalg/matrix_vector_op.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2018-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2021, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -17,11 +17,24 @@
 #pragma once
 
 #include <raft/cuda_utils.cuh>
+#include <raft/pow2_utils.cuh>
 #include <raft/vectorized.cuh>
 
 namespace raft {
 namespace linalg {
 
+namespace {
+template <size_t VecBytes>
+struct AlignedAccess {
+  template <typename T>
+  static inline bool test(const T *matrix, size_t strideBytes) {
+    return Pow2<VecBytes>::isAligned(matrix) &&
+           Pow2<VecBytes>::isAligned(strideBytes) &&
+           Pow2<sizeof(T)>::isAligned(VecBytes);
+  }
+};
+};  // namespace
+
 template <typename Type, int veclen_, typename Lambda, typename IdxType,
           int TPB>
 __global__ void matrixVectorOpKernel(Type *out, const Type *matrix,
                                      const Type *vector, IdxType D, IdxType N,
@@ -101,24 +114,19 @@ void matrixVectorOp(Type *out, const Type *matrix, const Type *vec, IdxType D,
   IdxType stride = rowMajor ? D : N;
   size_t stride_bytes = stride * sizeof(Type);
 
-  auto test_aligned_access = [stride_bytes, matrix](const int n_bytes) {
-    return n_bytes / sizeof(Type) && stride_bytes % n_bytes == 0 &&
-           reinterpret_cast<uintptr_t>(matrix) % sizeof(Type);
-  };
-
-  if (test_aligned_access(16)) {
+  if (AlignedAccess<16>::test(matrix, stride_bytes)) {
     matrixVectorOpImpl<Type, 16 / sizeof(Type), Lambda, IdxType, TPB>(
       out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (test_aligned_access(8)) {
+  } else if (AlignedAccess<8>::test(matrix, stride_bytes)) {
     matrixVectorOpImpl<Type, 8 / sizeof(Type), Lambda, IdxType, TPB>(
       out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (test_aligned_access(4)) {
+  } else if (AlignedAccess<4>::test(matrix, stride_bytes)) {
     matrixVectorOpImpl<Type, 4 / sizeof(Type), Lambda, IdxType, TPB>(
       out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (test_aligned_access(2)) {
+  } else if (AlignedAccess<2>::test(matrix, stride_bytes)) {
     matrixVectorOpImpl<Type, 2 / sizeof(Type), Lambda, IdxType, TPB>(
       out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (1 / sizeof(Type)) {
+  } else if (AlignedAccess<1>::test(matrix, stride_bytes)) {
     matrixVectorOpImpl<Type, 1, Lambda, IdxType, TPB>(
       out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
   } else {
@@ -209,24 +217,19 @@ void matrixVectorOp(Type *out, const Type *matrix, const Type *vec1,
   IdxType stride = rowMajor ? D : N;
   size_t stride_bytes = stride * sizeof(Type);
 
-  auto test_aligned_access = [stride_bytes, matrix](const int n_bytes) {
-    return n_bytes / sizeof(Type) && stride_bytes % n_bytes == 0 &&
-           reinterpret_cast<uintptr_t>(matrix) % sizeof(Type);
-  };
-
-  if (test_aligned_access(16)) {
+  if (AlignedAccess<16>::test(matrix, stride_bytes)) {
     matrixVectorOpImpl<Type, 16 / sizeof(Type), Lambda, IdxType, TPB>(
       out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (test_aligned_access(8)) {
+  } else if (AlignedAccess<8>::test(matrix, stride_bytes)) {
     matrixVectorOpImpl<Type, 8 / sizeof(Type), Lambda, IdxType, TPB>(
       out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (test_aligned_access(4)) {
+  } else if (AlignedAccess<4>::test(matrix, stride_bytes)) {
     matrixVectorOpImpl<Type, 4 / sizeof(Type), Lambda, IdxType, TPB>(
      out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (test_aligned_access(2)) {
+  } else if (AlignedAccess<2>::test(matrix, stride_bytes)) {
     matrixVectorOpImpl<Type, 2 / sizeof(Type), Lambda, IdxType, TPB>(
       out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (1 / sizeof(Type)) {
+  } else if (AlignedAccess<1>::test(matrix, stride_bytes)) {
     matrixVectorOpImpl<Type, 1, Lambda, IdxType, TPB>(
       out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
   } else {
diff --git a/cpp/include/raft/pow2_utils.cuh b/cpp/include/raft/pow2_utils.cuh
new file mode 100644
index 0000000000..de5fc46452
--- /dev/null
+++ b/cpp/include/raft/pow2_utils.cuh
@@ -0,0 +1,161 @@
+/*
+ * Copyright (c) 2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "cuda_utils.cuh"
+
+namespace raft {
+
+/**
+ * @brief Fast arithmetic and alignment checks for power-of-two values known
+ *        at compile time.
+ *
+ * @tparam Value_ a compile-time value representable as a power-of-two.
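+ *
+ * Usage sketch (illustrative only; `p` stands for any valid pointer):
+ *
+ *   typedef Pow2<16> P;   // Value == 16, Log2 == 4, Mask == 15
+ *   P::quot(35);          // ==  2  (truncated division, as `35 / 16` in C++)
+ *   P::mod(-3);           // == 13  (floored remainder, as `-3 % 16` in Python)
+ *   P::roundUp(17);       // == 32  (the next multiple of 16)
+ *   P::isAligned(p);      // true iff the address of `p` is 16-byte aligned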
+ */
+template <auto Value_>
+struct Pow2 {
+  typedef decltype(Value_) Type;
+  static constexpr Type Value = Value_;
+  static constexpr Type Log2 = log2(Value);
+  static constexpr Type Mask = Value - 1;
+
+  static_assert(std::is_integral<Type>::value, "Value must be integral.");
+  static_assert(Value && !(Value & Mask), "Value must be power of two.");
+
+#define Pow2_IsRepresentableAs(I) \
+  (std::is_integral<I>::value && Type(I(Value)) == Value)
+
+  /**
+   * Integer division by Value truncated toward zero
+   * (same as `x / Value` in C++).
+   *
+   * Invariant: `x = Value * quot(x) + rem(x)`.
+   */
+  template <typename I>
+  static constexpr HDI std::enable_if_t<std::is_integral<I>::value, I> quot(
+    I x) noexcept {
+    if constexpr (std::is_signed<I>::value)
+      return (x >> I(Log2)) + (x < 0 && (x & I(Mask)));
+    if constexpr (std::is_unsigned<I>::value) return x >> I(Log2);
+  }
+
+  /**
+   * Remainder of integer division by Value truncated toward zero
+   * (same as `x % Value` in C++).
+   *
+   * Invariant: `x = Value * quot(x) + rem(x)`.
+   */
+  template <typename I>
+  static constexpr HDI std::enable_if_t<std::is_integral<I>::value, I> rem(
+    I x) noexcept {
+    if constexpr (std::is_signed<I>::value)
+      return x < 0 ? -((-x) & I(Mask)) : (x & I(Mask));
+    if constexpr (std::is_unsigned<I>::value) return x & I(Mask);
+  }
+
+  /**
+   * Integer division by Value truncated toward negative infinity
+   * (same as `x // Value` in Python).
+   *
+   * Invariant: `x = Value * div(x) + mod(x)`.
+   *
+   * Note: `div` and `mod` are slightly faster for negative values than
+   * `quot` and `rem`, but behave slightly differently from the normal
+   * C++ operators `/` and `%`.
+   */
+  template <typename I>
+  static constexpr HDI std::enable_if_t<std::is_integral<I>::value, I> div(
+    I x) noexcept {
+    return x >> I(Log2);
+  }
+
+  /**
+   * x modulo Value (the remainder of `div(x)`;
+   * same as `x % Value` in Python).
+   *
+   * Invariant: `mod(x) >= 0`.
+   * Invariant: `x = Value * div(x) + mod(x)`.
+   *
+   * Note: `div` and `mod` are slightly faster for negative values than
+   * `quot` and `rem`, but behave slightly differently from the normal
+   * C++ operators `/` and `%`.
+   */
+  template <typename I>
+  static constexpr HDI std::enable_if_t<std::is_integral<I>::value, I> mod(
+    I x) noexcept {
+    return x & I(Mask);
+  }
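+
+  // Worked example of the difference (illustrative, with Value = 16):
+  //   quot(-1) ==  0, rem(-1) == -1  (truncated, like C++ `/` and `%`);
+  //   div(-1)  == -1, mod(-1) == 15  (floored, like Python `//` and `%`).
+  // Both pairs satisfy `x == Value * q + r` for x == -1.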
+
+#define Pow2_CHECK_TYPE(T)                                                \
+  static_assert(std::is_pointer<T>::value || std::is_integral<T>::value, \
+                "Only pointer or integral types make sense here")
+
+  /**
+   * Tell whether the pointer or integral is Value-aligned.
+   * NB: for pointers, the alignment is checked in bytes, not in elements.
+   */
+  template <typename PtrT>
+  static constexpr HDI bool isAligned(PtrT p) noexcept {
+    Pow2_CHECK_TYPE(PtrT);
+    if constexpr (Pow2_IsRepresentableAs(PtrT)) return mod(p) == 0;
+    if constexpr (!Pow2_IsRepresentableAs(PtrT))
+      return mod(reinterpret_cast<Type>(p)) == 0;
+  }
+
+  /** Tell whether two pointers have the same address modulo Value. */
+  template <typename PtrT, typename PtrS>
+  static constexpr HDI bool areSameAlignOffsets(PtrT a, PtrS b) noexcept {
+    Pow2_CHECK_TYPE(PtrT);
+    Pow2_CHECK_TYPE(PtrS);
+    Type x, y;
+    if constexpr (Pow2_IsRepresentableAs(PtrT))
+      x = Type(mod(a));
+    else
+      x = mod(reinterpret_cast<Type>(a));
+    if constexpr (Pow2_IsRepresentableAs(PtrS))
+      y = Type(mod(b));
+    else
+      y = mod(reinterpret_cast<Type>(b));
+    return x == y;
+  }
+
+  /** Get this or the next Value-aligned address (in bytes) or integral. */
+  template <typename PtrT>
+  static constexpr HDI PtrT roundUp(PtrT p) noexcept {
+    Pow2_CHECK_TYPE(PtrT);
+    if constexpr (Pow2_IsRepresentableAs(PtrT))
+      return p + PtrT(Mask) - mod(p + PtrT(Mask));
+    if constexpr (!Pow2_IsRepresentableAs(PtrT)) {
+      auto x = reinterpret_cast<Type>(p);
+      return reinterpret_cast<PtrT>(x + Mask - mod(x + Mask));
+    }
+  }
+
+  /** Get this or the previous Value-aligned address (in bytes) or integral. */
+  template <typename PtrT>
+  static constexpr HDI PtrT roundDown(PtrT p) noexcept {
+    Pow2_CHECK_TYPE(PtrT);
+    if constexpr (Pow2_IsRepresentableAs(PtrT)) return p - mod(p);
+    if constexpr (!Pow2_IsRepresentableAs(PtrT)) {
+      auto x = reinterpret_cast<Type>(p);
+      return reinterpret_cast<PtrT>(x - mod(x));
+    }
+  }
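+
+  // Examples (illustrative, with Value = 16):
+  //   areSameAlignOffsets(1, 17) -> true  (both equal 1 modulo 16);
+  //   roundUp(17) == 32, roundUp(32) == 32, roundDown(17) == 16.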
+#undef Pow2_CHECK_TYPE
+#undef Pow2_IsRepresentableAs
+};
+
+};  // namespace raft
diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt
index 43e1c65695..4a89fd3273 100644
--- a/cpp/test/CMakeLists.txt
+++ b/cpp/test/CMakeLists.txt
@@ -36,6 +36,7 @@ add_executable(test_raft
   test/eigen_solvers.cu
   test/handle.cpp
   test/integer_utils.cpp
+  test/pow2_utils.cu
   test/label/label.cu
   test/label/merge_labels.cu
   test/lap/lap.cu
diff --git a/cpp/test/pow2_utils.cu b/cpp/test/pow2_utils.cu
new file mode 100644
index 0000000000..92976e5c61
--- /dev/null
+++ b/cpp/test/pow2_utils.cu
@@ -0,0 +1,109 @@
+/*
+ * Copyright (c) 2020-2021, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <gtest/gtest.h>
+#include <raft/pow2_utils.cuh>
+
+namespace raft {
+
+template <auto Value, typename TargetT>
+struct Pow2Test : public ::testing::Test {
+  typedef Pow2<Value> P;
+  std::vector<TargetT> data;
+
+  void SetUp() override {
+    std::vector<TargetT> pos = {0, 1, 2, 7, 15, 16, 17, 31, 35, 1024, 1623};
+    data.insert(data.end(), pos.begin(), pos.end());
+    if constexpr (std::is_signed<TargetT>::value) {
+      std::vector<TargetT> neg = {-0, -1, -2, -5, -15, -16, -17, -156};
+      data.insert(data.end(), neg.begin(), neg.end());
+    }
+    data.push_back(std::numeric_limits<TargetT>::min());
+    data.push_back(std::numeric_limits<TargetT>::max());
+  }
+
+  void quotRem() {
+    for (auto x : data) {
+      ASSERT_EQ(P::quot(x), x / P::Value) << " where x = " << x;
+      ASSERT_EQ(P::rem(x), x % P::Value) << " where x = " << x;
+      ASSERT_EQ(x, P::quot(x) * P::Value + P::rem(x));
+    }
+  }
+
+  void divMod() {
+    for (auto x : data) {
+      ASSERT_GE(P::mod(x), 0) << " where x = " << x;
+      ASSERT_EQ(x, P::div(x) * P::Value + P::mod(x));
+    }
+  }
+
+  void round() {
+    for (auto x : data) {
+      if (x <= std::numeric_limits<TargetT>::max() - TargetT(P::Value))
+        ASSERT_GE(P::roundUp(x), x);
+      if (x >= std::numeric_limits<TargetT>::min() + TargetT(P::Value))
+        ASSERT_LE(P::roundDown(x), x);
+      ASSERT_EQ(x - P::roundDown(x), P::mod(x)) << " where x = " << x;
+      ASSERT_EQ(P::mod(P::roundUp(x) + P::mod(x) - x), 0)
+        << " where x = " << x;
+    }
+  }
+
+  void alignment() {
+    for (auto x : data) {
+      ASSERT_TRUE(P::areSameAlignOffsets(x, x));
+      if (x <= std::numeric_limits<TargetT>::max() - TargetT(P::Value)) {
+        ASSERT_TRUE(P::areSameAlignOffsets(x, x + TargetT(P::Value)));
+        int aligned_count = 0;
+        int same_aligned_count = 0;
+        for (int i = 0; i < int(P::Value); i++) {
+          aligned_count += P::isAligned(x + i);
+          same_aligned_count += P::areSameAlignOffsets(x, x + i);
+        }
+        ASSERT_EQ(aligned_count, 1) << " where x = " << x;
+        ASSERT_EQ(same_aligned_count, 1) << " where x = " << x;
+      }
+    }
+  }
+};
+
+#define TEST_IT(T)                  \
+  TEST_F(T, quotRem) { quotRem(); } \
+  TEST_F(T, divMod) { divMod(); }   \
+  TEST_F(T, round) { round(); }     \
+  TEST_F(T, alignment) { alignment(); }
+
+typedef Pow2Test<16, int> Pow2_i32_i32_16;
+typedef Pow2Test<1UL, uint64_t> Pow2_u64_u64_1;
+typedef Pow2Test<128UL, int> Pow2_u64_i32_128;
+typedef Pow2Test<32LL, uint16_t> Pow2_ll_u16_32;
+typedef Pow2Test<16, uint64_t> Pow2_i32_u64_16;
+TEST_IT(Pow2_i32_i32_16);
+TEST_IT(Pow2_u64_u64_1);
+TEST_IT(Pow2_u64_i32_128);
+TEST_IT(Pow2_ll_u16_32);
+TEST_IT(Pow2_i32_u64_16);
+
+TEST(Pow2, pointers) {
+  typedef Pow2<32UL> P;
+  for (ptrdiff_t i = 0; i <= ptrdiff_t(P::Value); i++) {
+    auto *p = reinterpret_cast<float *>(16345 + i);
+    ASSERT_GE(P::roundUp(p), p);
+    ASSERT_LE(P::roundDown(p), p);
+  }
+}
+
+}  // namespace raft