Fix coalesced access checks in matrix_vector_op (#372)
One of the conditions in [`test_aligned_access`](https://github.com/rapidsai/raft/blob/branch-21.12/cpp/include/raft/linalg/matrix_vector_op.cuh#L106) in `linalg/matrix_vector_op.cuh` was incorrect: the check should require `ptr % elem_size` to be zero, not the opposite. Because of that typo, the `matrixVectorOp` function never used vectorized load/store instructions.
This PR fixes the problem and adds a new helper struct to simplify such checks in the future.
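
For illustration, here is a minimal host-side sketch of the buggy predicate next to one that is equivalent in spirit to the new `AlignedAccess` helper in the diff below (the function names here are ours, not RAFT's):

```cpp
#include <cstddef>
#include <cstdint>

// Before this PR: the last clause is non-zero only when `matrix` is NOT a
// multiple of sizeof(Type). Any well-aligned allocation therefore failed the
// test, and the scalar fallback path always ran.
template <typename Type>
bool aligned_access_buggy(const Type* matrix, size_t stride_bytes, size_t n_bytes) {
  return (n_bytes / sizeof(Type)) && stride_bytes % n_bytes == 0 &&
         reinterpret_cast<uintptr_t>(matrix) % sizeof(Type);
}

// After this PR (in spirit): the base pointer and the row stride must both be
// multiples of the vector width, and the vector width must be a whole number
// of elements.
template <typename Type>
bool aligned_access_fixed(const Type* matrix, size_t stride_bytes, size_t n_bytes) {
  return reinterpret_cast<uintptr_t>(matrix) % n_bytes == 0 &&
         stride_bytes % n_bytes == 0 && n_bytes % sizeof(Type) == 0;
}
```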

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #372
achirkin authored Nov 17, 2021
1 parent 4a8fa9f commit 4d1d201
Showing 4 changed files with 295 additions and 21 deletions.
45 changes: 24 additions & 21 deletions cpp/include/raft/linalg/matrix_vector_op.cuh
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2020, NVIDIA CORPORATION.
+ * Copyright (c) 2018-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -17,11 +17,24 @@
#pragma once

#include <raft/cuda_utils.cuh>
+#include <raft/pow2_utils.cuh>
#include <raft/vectorized.cuh>

namespace raft {
namespace linalg {

+namespace {
+template <size_t VecBytes>
+struct AlignedAccess {
+  template <typename T>
+  static inline bool test(const T *matrix, size_t strideBytes) {
+    return Pow2<VecBytes>::isAligned(matrix) &&
+           Pow2<VecBytes>::isAligned(strideBytes) &&
+           Pow2<sizeof(T)>::isAligned(VecBytes);
+  }
+};
+};  // namespace

template <typename Type, int veclen_, typename Lambda, typename IdxType>
__global__ void matrixVectorOpKernel(Type *out, const Type *matrix,
const Type *vector, IdxType D, IdxType N,
@@ -101,24 +114,19 @@ void matrixVectorOp(Type *out, const Type *matrix, const Type *vec, IdxType D,
IdxType stride = rowMajor ? D : N;
size_t stride_bytes = stride * sizeof(Type);

-  auto test_aligned_access = [stride_bytes, matrix](const int n_bytes) {
-    return n_bytes / sizeof(Type) && stride_bytes % n_bytes == 0 &&
-           reinterpret_cast<uintptr_t>(matrix) % sizeof(Type);
-  };
-
-  if (test_aligned_access(16)) {
+  if (AlignedAccess<16>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 16 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (test_aligned_access(8)) {
+  } else if (AlignedAccess<8>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 8 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (test_aligned_access(4)) {
+  } else if (AlignedAccess<4>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 4 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (test_aligned_access(2)) {
+  } else if (AlignedAccess<2>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 2 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (1 / sizeof(Type)) {
+  } else if (AlignedAccess<1>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 1 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec, D, N, rowMajor, bcastAlongRows, op, stream);
} else {
@@ -209,24 +217,19 @@ void matrixVectorOp(Type *out, const Type *matrix, const Type *vec1,
IdxType stride = rowMajor ? D : N;
size_t stride_bytes = stride * sizeof(Type);

-  auto test_aligned_access = [stride_bytes, matrix](const int n_bytes) {
-    return n_bytes / sizeof(Type) && stride_bytes % n_bytes == 0 &&
-           reinterpret_cast<uintptr_t>(matrix) % sizeof(Type);
-  };
-
-  if (test_aligned_access(16)) {
+  if (AlignedAccess<16>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 16 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (test_aligned_access(8)) {
+  } else if (AlignedAccess<8>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 8 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (test_aligned_access(4)) {
+  } else if (AlignedAccess<4>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 4 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (test_aligned_access(2)) {
+  } else if (AlignedAccess<2>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 2 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
-  } else if (1 / sizeof(Type)) {
+  } else if (AlignedAccess<1>::test(matrix, stride_bytes)) {
matrixVectorOpImpl<Type, 1 / sizeof(Type), Lambda, IdxType, TPB>(
out, matrix, vec1, vec2, D, N, rowMajor, bcastAlongRows, op, stream);
} else {
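To make the dispatch concrete, here is a standalone sketch (ours, not RAFT code) of which branch fires for a typical row-major `float` matrix, assuming `cudaMalloc`'s 256-byte alignment guarantee:

```cpp
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Stand-in for AlignedAccess<VecBytes>::test above, written with plain
// modulo arithmetic instead of the Pow2 helper.
template <size_t VecBytes, typename T>
bool aligned_access(const T* matrix, size_t stride_bytes) {
  return reinterpret_cast<uintptr_t>(matrix) % VecBytes == 0 &&
         stride_bytes % VecBytes == 0 && VecBytes % sizeof(T) == 0;
}

int main() {
  alignas(256) static float matrix[1024];   // emulates cudaMalloc alignment
  size_t d = 8;                             // columns, row-major layout
  size_t stride_bytes = d * sizeof(float);  // 32 bytes per row
  if (aligned_access<16>(matrix, stride_bytes)) {
    // veclen = 16 / sizeof(float) = 4, i.e. 128-bit (float4-style) accesses.
    std::printf("16-byte branch taken, veclen = %zu\n", 16 / sizeof(float));
  }
  return 0;
}
```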
161 changes: 161 additions & 0 deletions cpp/include/raft/pow2_utils.cuh
@@ -0,0 +1,161 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "cuda_utils.cuh"

namespace raft {

/**
* @brief Fast arithmetics and alignment checks for power-of-two values known at compile time.
*
* @tparam Value_ a compile-time value representable as a power-of-two.
*/
template <auto Value_>
struct Pow2 {
typedef decltype(Value_) Type;
static constexpr Type Value = Value_;
static constexpr Type Log2 = log2(Value);
static constexpr Type Mask = Value - 1;

static_assert(std::is_integral<Type>::value, "Value must be integral.");
static_assert(Value && !(Value & Mask), "Value must be power of two.");

#define Pow2_IsRepresentableAs(I) \
(std::is_integral<I>::value && Type(I(Value)) == Value)

/**
* Integer division by Value truncated toward zero
* (same as `x / Value` in C++).
*
* Invariant: `x = Value * quot(x) + rem(x)`
*/
template <typename I>
static constexpr HDI std::enable_if_t<Pow2_IsRepresentableAs(I), I> quot(
I x) noexcept {
if constexpr (std::is_signed<I>::value)
return (x >> I(Log2)) + (x < 0 && (x & I(Mask)));
if constexpr (std::is_unsigned<I>::value) return x >> I(Log2);
}

/**
* Remainder of integer division by Value truncated toward zero
* (same as `x % Value` in C++).
*
* Invariant: `x = Value * quot(x) + rem(x)`.
*/
template <typename I>
static constexpr HDI std::enable_if_t<Pow2_IsRepresentableAs(I), I> rem(
I x) noexcept {
if constexpr (std::is_signed<I>::value)
return x < 0 ? -((-x) & I(Mask)) : (x & I(Mask));
if constexpr (std::is_unsigned<I>::value) return x & I(Mask);
}

/**
* Integer division by Value truncated toward negative infinity
* (same as `x // Value` in Python).
*
* Invariant: `x = Value * div(x) + mod(x)`.
*
* Note, `div` and `mod` for negative values are slightly faster
* than `quot` and `rem`, but behave slightly different
* compared to normal C++ operators `/` and `%`.
*/
template <typename I>
static constexpr HDI std::enable_if_t<Pow2_IsRepresentableAs(I), I> div(
I x) noexcept {
return x >> I(Log2);
}

/**
* x modulo Value operation (remainder of the `div(x)`)
* (same as `x % Value` in Python).
*
* Invariant: `mod(x) >= 0`
* Invariant: `x = Value * div(x) + mod(x)`.
*
* Note, `div` and `mod` for negative values are slightly faster
* than `quot` and `rem`, but behave slightly different
* compared to normal C++ operators `/` and `%`.
*/
template <typename I>
static constexpr HDI std::enable_if_t<Pow2_IsRepresentableAs(I), I> mod(
I x) noexcept {
return x & I(Mask);
}

#define Pow2_CHECK_TYPE(T) \
static_assert(std::is_pointer<T>::value || std::is_integral<T>::value, \
"Only pointer or integral types make sense here")

/**
* Tell whether the pointer or integral is Value-aligned.
* NB: for pointers, the alignment is checked in bytes, not in elements.
*/
template <typename PtrT>
static constexpr HDI bool isAligned(PtrT p) noexcept {
Pow2_CHECK_TYPE(PtrT);
if constexpr (Pow2_IsRepresentableAs(PtrT)) return mod(p) == 0;
if constexpr (!Pow2_IsRepresentableAs(PtrT))
return mod(reinterpret_cast<Type>(p)) == 0;
}

/** Tell whether two pointers have the same address modulo Value. */
template <typename PtrT, typename PtrS>
static constexpr HDI bool areSameAlignOffsets(PtrT a, PtrS b) noexcept {
Pow2_CHECK_TYPE(PtrT);
Pow2_CHECK_TYPE(PtrS);
Type x, y;
if constexpr (Pow2_IsRepresentableAs(PtrT))
x = Type(mod(a));
else
x = mod(reinterpret_cast<Type>(a));
if constexpr (Pow2_IsRepresentableAs(PtrS))
y = Type(mod(b));
else
y = mod(reinterpret_cast<Type>(b));
return x == y;
}

/** Get this or next Value-aligned address (in bytes) or integral. */
template <typename PtrT>
static constexpr HDI PtrT roundUp(PtrT p) noexcept {
Pow2_CHECK_TYPE(PtrT);
if constexpr (Pow2_IsRepresentableAs(PtrT))
return p + PtrT(Mask) - mod(p + PtrT(Mask));
if constexpr (!Pow2_IsRepresentableAs(PtrT)) {
auto x = reinterpret_cast<Type>(p);
return reinterpret_cast<PtrT>(x + Mask - mod(x + Mask));
}
}

/** Get this or previous Value-aligned address (in bytes) or integral. */
template <typename PtrT>
static constexpr HDI PtrT roundDown(PtrT p) noexcept {
Pow2_CHECK_TYPE(PtrT);
if constexpr (Pow2_IsRepresentableAs(PtrT)) return p - mod(p);
if constexpr (!Pow2_IsRepresentableAs(PtrT)) {
auto x = reinterpret_cast<Type>(p);
return reinterpret_cast<PtrT>(x - mod(x));
}
}
#undef Pow2_CHECK_TYPE
#undef Pow2_IsRepresentableAs
};

}; // namespace raft
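
A short usage sketch of the new `Pow2` helper (our example, compiled as CUDA since the members are host/device functions); the negative inputs highlight how `div`/`mod` differ from `quot`/`rem`:

```cpp
#include <cassert>
#include <raft/pow2_utils.cuh>

void pow2_examples() {
  using P = raft::Pow2<16>;
  static_assert(P::Log2 == 4 && P::Mask == 15, "compile-time constants");

  // quot/rem truncate toward zero, like the C++ operators / and %.
  assert(P::quot(-17) == -1 && P::rem(-17) == -1);
  // div/mod round toward negative infinity, like Python's // and %.
  assert(P::div(-17) == -2 && P::mod(-17) == 15);

  // Rounding and alignment helpers work on integrals and, byte-wise, on pointers.
  assert(P::roundUp(33) == 48 && P::roundDown(33) == 32);
  assert(P::isAligned(48) && !P::isAligned(33));
}
```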
1 change: 1 addition & 0 deletions cpp/test/CMakeLists.txt
@@ -36,6 +36,7 @@ add_executable(test_raft
test/eigen_solvers.cu
test/handle.cpp
test/integer_utils.cpp
+  test/pow2_utils.cu
test/label/label.cu
test/label/merge_labels.cu
test/lap/lap.cu
109 changes: 109 additions & 0 deletions cpp/test/pow2_utils.cu
@@ -0,0 +1,109 @@
/*
* Copyright (c) 2020-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
#include <raft/pow2_utils.cuh>

namespace raft {

template <auto Val, typename TargetT>
struct Pow2Test : public ::testing::Test {
typedef Pow2<Val> P;
std::vector<TargetT> data;

void SetUp() override {
std::vector<TargetT> pos = {0, 1, 2, 7, 15, 16, 17, 31, 35, 1024, 1623};
data.insert(data.end(), pos.begin(), pos.end());
if constexpr (std::is_signed<TargetT>::value) {
std::vector<TargetT> neg = {-0, -1, -2, -5, -15, -16, -17, -156};
data.insert(data.end(), neg.begin(), neg.end());
}
data.push_back(std::numeric_limits<TargetT>::min());
data.push_back(std::numeric_limits<TargetT>::max());
}

void quotRem() {
for (auto x : data) {
ASSERT_EQ(P::quot(x), x / P::Value) << " where x = " << x;
ASSERT_EQ(P::rem(x), x % P::Value) << " where x = " << x;
ASSERT_EQ(x, P::quot(x) * P::Value + P::rem(x));
}
}

void divMod() {
for (auto x : data) {
ASSERT_GE(P::mod(x), 0) << " where x = " << x;
ASSERT_EQ(x, P::div(x) * P::Value + P::mod(x));
}
}

void round() {
for (auto x : data) {
if (x <= std::numeric_limits<TargetT>::max() - TargetT(P::Value))
ASSERT_GE(P::roundUp(x), x);
if (x >= std::numeric_limits<TargetT>::min() + TargetT(P::Value))
ASSERT_LE(P::roundDown(x), x);
ASSERT_EQ(x - P::roundDown(x), P::mod(x)) << " where x = " << x;
ASSERT_EQ(P::mod(P::roundUp(x) + P::mod(x) - x), 0)
<< " where x = " << x;
}
}

void alignment() {
for (auto x : data) {
ASSERT_TRUE(P::areSameAlignOffsets(x, x));
if (x <= std::numeric_limits<TargetT>::max() - TargetT(P::Value)) {
ASSERT_TRUE(P::areSameAlignOffsets(x, x + TargetT(P::Value)));
int aligned_count = 0;
int same_aligned_count = 0;
for (int i = 0; i < int(P::Value); i++) {
aligned_count += P::isAligned(x + i);
same_aligned_count += P::areSameAlignOffsets(x, x + i);
}
ASSERT_EQ(aligned_count, 1) << " where x = " << x;
ASSERT_EQ(same_aligned_count, 1) << " where x = " << x;
}
}
}
};

#define TEST_IT(T) \
TEST_F(T, quotRem) { quotRem(); } \
TEST_F(T, divMod) { divMod(); } \
TEST_F(T, round) { round(); } \
TEST_F(T, alignment) { alignment(); }

typedef Pow2Test<16, int> Pow2_i32_i32_16;
typedef Pow2Test<1UL, uint64_t> Pow2_u64_u64_1;
typedef Pow2Test<128UL, int> Pow2_u64_i32_128;
typedef Pow2Test<32LL, uint16_t> Pow2_ll_u16_32;
typedef Pow2Test<16, uint64_t> Pow2_i32_u64_16;
TEST_IT(Pow2_i32_i32_16);
TEST_IT(Pow2_u64_u64_1);
TEST_IT(Pow2_u64_i32_128);
TEST_IT(Pow2_ll_u16_32);
TEST_IT(Pow2_i32_u64_16);

TEST(Pow2, pointers) {
typedef Pow2<32UL> P;
for (ptrdiff_t i = 0; i <= ptrdiff_t(P::Value); i++) {
auto *p = reinterpret_cast<float *>(16345 + i);
ASSERT_GE(P::roundUp(p), p);
ASSERT_LE(P::roundDown(p), p);
}
}

} // namespace raft
