diff --git a/cpp/include/raft/matrix/detail/linewise_op.cuh b/cpp/include/raft/matrix/detail/linewise_op.cuh index 81204bfe66..ee703c5138 100644 --- a/cpp/include/raft/matrix/detail/linewise_op.cuh +++ b/cpp/include/raft/matrix/detail/linewise_op.cuh @@ -83,7 +83,7 @@ struct Linewise { Vec v, w; bool update = true; for (; in < in_end; in += AlignWarp::Value, out += AlignWarp::Value, rowMod += warpPad) { - v.val.internal = __ldcv(in); + *v.vectorized_data() = __ldcv(in); while (rowMod >= rowLen) { rowMod -= rowLen; rowDiv++; @@ -105,7 +105,7 @@ struct Linewise { int l = 0; w.val.data[k] = op(v.val.data[k], (std::ignore = vecs, args[l++])...); } - *out = w.val.internal; + *out = *w.vectorized_data(); } } @@ -138,11 +138,11 @@ struct Linewise { Vec v; const IdxType d = BlockSize * gridDim.x; for (IdxType i = threadIdx.x + blockIdx.x * BlockSize; i < len; i += d) { - v.val.internal = __ldcv(in + i); + *v.vectorized_data() = __ldcv(in + i); #pragma unroll VecElems for (int k = 0; k < VecElems; k++) v.val.data[k] = op(v.val.data[k], args.val.data[k]...); - __stwt(out + i, v.val.internal); + __stwt(out + i, *v.vectorized_data()); } } @@ -172,7 +172,7 @@ struct Linewise { __syncthreads(); { Vec out; - out.val.internal = reinterpret_cast(shm)[threadIdx.x]; + *out.vectorized_data() = reinterpret_cast(shm)[threadIdx.x]; return out; } } diff --git a/cpp/include/raft/vectorized.cuh b/cpp/include/raft/vectorized.cuh index 44c6a74162..a1e0308642 100644 --- a/cpp/include/raft/vectorized.cuh +++ b/cpp/include/raft/vectorized.cuh @@ -272,10 +272,12 @@ struct TxN_t { union { /** the vectorized data that is used for subsequent operations */ math_t data[Ratio]; - /** internal data used to ensure vectorized loads/stores */ - io_t internal; } val; + __device__ auto* vectorized_data() { + return reinterpret_cast(val.data); + } + ///@todo: add default constructor /** @@ -311,21 +313,21 @@ struct TxN_t { DI void load(const math_t* ptr, idx_t idx) { const io_t* bptr = reinterpret_cast(&ptr[idx]); - val.internal = __ldg(bptr); + *vectorized_data() = __ldg(bptr); } template DI void load(math_t* ptr, idx_t idx) { io_t* bptr = reinterpret_cast(&ptr[idx]); - val.internal = *bptr; + *vectorized_data() = *bptr; } template DI void store(math_t* ptr, idx_t idx) { io_t* bptr = reinterpret_cast(&ptr[idx]); - *bptr = val.internal; + *bptr = *vectorized_data(); } /** @} */ };