Skip to content

Commit

Permalink
Remove type punning from TxN_t
Browse files Browse the repository at this point in the history
  • Loading branch information
wphicks committed Aug 5, 2022
1 parent f974c7b commit 540eeb9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
10 changes: 5 additions & 5 deletions cpp/include/raft/matrix/detail/linewise_op.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ struct Linewise {
Vec v, w;
bool update = true;
for (; in < in_end; in += AlignWarp::Value, out += AlignWarp::Value, rowMod += warpPad) {
v.val.internal = __ldcv(in);
*v.vectorized_data() = __ldcv(in);
while (rowMod >= rowLen) {
rowMod -= rowLen;
rowDiv++;
Expand All @@ -105,7 +105,7 @@ struct Linewise {
int l = 0;
w.val.data[k] = op(v.val.data[k], (std::ignore = vecs, args[l++])...);
}
*out = w.val.internal;
*out = *w.vectorized_data();
}
}

Expand Down Expand Up @@ -138,11 +138,11 @@ struct Linewise {
Vec v;
const IdxType d = BlockSize * gridDim.x;
for (IdxType i = threadIdx.x + blockIdx.x * BlockSize; i < len; i += d) {
v.val.internal = __ldcv(in + i);
*v.vectorized_data() = __ldcv(in + i);
#pragma unroll VecElems
for (int k = 0; k < VecElems; k++)
v.val.data[k] = op(v.val.data[k], args.val.data[k]...);
__stwt(out + i, v.val.internal);
__stwt(out + i, *v.vectorized_data());
}
}

Expand Down Expand Up @@ -172,7 +172,7 @@ struct Linewise {
__syncthreads();
{
Vec out;
out.val.internal = reinterpret_cast<typename Vec::io_t*>(shm)[threadIdx.x];
*out.vectorized_data() = reinterpret_cast<typename Vec::io_t*>(shm)[threadIdx.x];
return out;
}
}
Expand Down
12 changes: 7 additions & 5 deletions cpp/include/raft/vectorized.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,12 @@ struct TxN_t {
/**
 * Element storage of this vector: `Ratio` values of `math_t`, overlaid in a
 * union with a single `io_t`.
 *
 * NOTE(review): as a union member, `internal` also gives `val` at least the
 * alignment of `io_t`. If `internal` is ever removed in favour of casting
 * `data` to `io_t*`, that alignment has to be requested explicitly (e.g. with
 * `alignas(io_t)`) - confirm before dropping the member.
 */
union {
/** the vectorized data that is used for subsequent operations */
math_t data[Ratio];
/** internal data used to ensure vectorized loads/stores */
io_t internal;
} val;

/**
 * @brief View the element storage as one vectorized word.
 *
 * Takes the address of the first element of `val.data` and reinterprets it as
 * a pointer to `io_t`, so callers can read or write the storage with a single
 * `io_t` access instead of element by element.
 *
 * NOTE(review): dereferencing the result presumably requires `val` to be
 * aligned for `io_t` - confirm against the declaration of `val`.
 *
 * @return `io_t*` pointing at the start of `val.data`
 */
__device__ auto* vectorized_data()
{
  return reinterpret_cast<io_t*>(&val.data[0]);
}

///@todo: add default constructor

/**
Expand Down Expand Up @@ -311,21 +313,21 @@ struct TxN_t {
DI void load(const math_t* ptr, idx_t idx)
{
const io_t* bptr = reinterpret_cast<const io_t*>(&ptr[idx]);
val.internal = __ldg(bptr);
*vectorized_data() = __ldg(bptr);
}

/**
 * @brief Load one vectorized word from a mutable buffer into this object.
 *
 * Reads a single `io_t` located at `&ptr[idx]` and writes it into the element
 * storage through vectorized_data().
 *
 * @tparam idx_t index type
 * @param ptr base pointer of the buffer, addressed in units of `math_t`
 * @param idx offset, in `math_t` elements, of the first element to load
 *
 * @note `&ptr[idx]` is dereferenced as `io_t`; presumably it must be aligned
 *       for `io_t` - confirm with callers.
 */
template <typename idx_t = int>
DI void load(math_t* ptr, idx_t idx)
{
  io_t* bptr = reinterpret_cast<io_t*>(&ptr[idx]);
  // One write through vectorized_data() is sufficient; a second write of the
  // same value through the union member `val.internal` would be redundant and
  // would go through union type punning.
  *vectorized_data() = *bptr;
}

/**
 * @brief Store this object's contents to a buffer as one vectorized word.
 *
 * Reads the element storage through vectorized_data() and writes it as a
 * single `io_t` to `&ptr[idx]`.
 *
 * @tparam idx_t index type
 * @param ptr base pointer of the destination buffer, addressed in units of
 *            `math_t`
 * @param idx offset, in `math_t` elements, of the first element to write
 *
 * @note `&ptr[idx]` is written as `io_t`; presumably it must be aligned for
 *       `io_t` - confirm with callers.
 */
template <typename idx_t = int>
DI void store(math_t* ptr, idx_t idx)
{
  io_t* bptr = reinterpret_cast<io_t*>(&ptr[idx]);
  // One store sourced from vectorized_data() is sufficient; storing the same
  // value a second time from the union member `val.internal` would be
  // redundant and would go through union type punning.
  *bptr = *vectorized_data();
}
/** @} */
};
Expand Down

0 comments on commit 540eeb9

Please sign in to comment.