Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove type punning from TxN_t #781

Merged
merged 11 commits into from
Aug 23, 2022
10 changes: 5 additions & 5 deletions cpp/include/raft/matrix/detail/linewise_op.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ struct Linewise {
Vec v, w;
bool update = true;
for (; in < in_end; in += AlignWarp::Value, out += AlignWarp::Value, rowMod += warpPad) {
v.val.internal = __ldcv(in);
*v.vectorized_data() = __ldcv(in);
while (rowMod >= rowLen) {
rowMod -= rowLen;
rowDiv++;
Expand All @@ -105,7 +105,7 @@ struct Linewise {
int l = 0;
w.val.data[k] = op(v.val.data[k], (std::ignore = vecs, args[l++])...);
}
*out = w.val.internal;
*out = *w.vectorized_data();
}
}

Expand Down Expand Up @@ -138,11 +138,11 @@ struct Linewise {
Vec v;
const IdxType d = BlockSize * gridDim.x;
for (IdxType i = threadIdx.x + blockIdx.x * BlockSize; i < len; i += d) {
v.val.internal = __ldcv(in + i);
*v.vectorized_data() = __ldcv(in + i);
#pragma unroll VecElems
for (int k = 0; k < VecElems; k++)
v.val.data[k] = op(v.val.data[k], args.val.data[k]...);
__stwt(out + i, v.val.internal);
__stwt(out + i, *v.vectorized_data());
}
}

Expand Down Expand Up @@ -172,7 +172,7 @@ struct Linewise {
__syncthreads();
{
Vec out;
out.val.internal = reinterpret_cast<typename Vec::io_t*>(shm)[threadIdx.x];
*out.vectorized_data() = reinterpret_cast<typename Vec::io_t*>(shm)[threadIdx.x];
return out;
}
}
Expand Down
16 changes: 8 additions & 8 deletions cpp/include/raft/vectorized.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2021, NVIDIA CORPORATION.
* Copyright (c) 2018-2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -272,10 +272,10 @@ struct TxN_t {
union {
/** the vectorized data that is used for subsequent operations */
math_t data[Ratio];
/** internal data used to ensure vectorized loads/stores */
io_t internal;
} val;

__device__ auto* vectorized_data() { return reinterpret_cast<io_t*>(val.data); }

///@todo: add default constructor

/**
Expand Down Expand Up @@ -310,22 +310,22 @@ struct TxN_t {
template <typename idx_t = int>
DI void load(const math_t* ptr, idx_t idx)
{
const io_t* bptr = reinterpret_cast<const io_t*>(&ptr[idx]);
val.internal = __ldg(bptr);
const io_t* bptr = reinterpret_cast<const io_t*>(&ptr[idx]);
wphicks marked this conversation as resolved.
Show resolved Hide resolved
*vectorized_data() = __ldg(bptr);
}

template <typename idx_t = int>
DI void load(math_t* ptr, idx_t idx)
{
io_t* bptr = reinterpret_cast<io_t*>(&ptr[idx]);
val.internal = *bptr;
io_t* bptr = reinterpret_cast<io_t*>(&ptr[idx]);
*vectorized_data() = *bptr;
}

template <typename idx_t = int>
DI void store(math_t* ptr, idx_t idx)
{
io_t* bptr = reinterpret_cast<io_t*>(&ptr[idx]);
*bptr = val.internal;
*bptr = *vectorized_data();
}
/** @} */
};
Expand Down