diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index ae91d75b31..d09e1b329b 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -249,6 +249,7 @@ if(BUILD_RAFT_TESTS) # keep the files in alphabetical order! add_executable(test_raft test/cudart_utils.cpp + test/distance/fused_l2_nn.cu test/handle.cpp test/integer_utils.cpp test/lap/lap.cu diff --git a/cpp/include/raft/common/device_loads_stores.cuh b/cpp/include/raft/common/device_loads_stores.cuh new file mode 100644 index 0000000000..7f9462d9ea --- /dev/null +++ b/cpp/include/raft/common/device_loads_stores.cuh @@ -0,0 +1,120 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +namespace raft { + +/** + * @defgroup SmemStores Shared memory store operations + * @{ + * @brief Stores to shared memory (both vectorized and non-vectorized forms) + * @param[out] addr shared memory address + * @param[in] x data to be stored at this address + */ +DI void sts(float* addr, const float& x) { *addr = x; } +DI void sts(float* addr, const float (&x)[1]) { *addr = x[0]; } +DI void sts(float* addr, const float (&x)[2]) { + float2 v2 = make_float2(x[0], x[1]); + auto* s2 = reinterpret_cast(addr); + *s2 = v2; +} +DI void sts(float* addr, const float (&x)[4]) { + float4 v4 = make_float4(x[0], x[1], x[2], x[3]); + auto* s4 = reinterpret_cast(addr); + *s4 = v4; +} +DI void sts(double* addr, const double& x) { *addr = x; } +DI void sts(double* addr, const double (&x)[1]) { *addr = x[0]; } +DI void sts(double* addr, const double (&x)[2]) { + double2 v2 = make_double2(x[0], x[1]); + auto* s2 = reinterpret_cast(addr); + *s2 = v2; +} +/** @} */ + +/** + * @defgroup SmemLoads Shared memory load operations + * @{ + * @brief Loads from shared memory (both vectorized and non-vectorized forms) + * @param[out] x the data to be loaded + * @param[in] addr shared memory address from where to load + */ +DI void lds(float& x, float* addr) { x = *addr; } +DI void lds(float (&x)[1], float* addr) { x[0] = *addr; } +DI void lds(float (&x)[2], float* addr) { + auto* s2 = reinterpret_cast(addr); + auto v2 = *s2; + x[0] = v2.x; + x[1] = v2.y; +} +DI void lds(float (&x)[4], float* addr) { + auto* s4 = reinterpret_cast(addr); + auto v4 = *s4; + x[0] = v4.x; + x[1] = v4.y; + x[2] = v4.z; + x[3] = v4.w; +} +DI void lds(double& x, double* addr) { x = *addr; } +DI void lds(double (&x)[1], double* addr) { x[0] = *addr; } +DI void lds(double (&x)[2], double* addr) { + auto* s2 = reinterpret_cast(addr); + auto v2 = *s2; + x[0] = v2.x; + x[1] = v2.y; +} +/** @} */ + +/** + * @defgroup GlobalLoads Global cached load operations + * @{ + * @brief Load from 
global memory with caching at L2 level (`ld.global.cg`, not cached in L1) + * @param[out] x data to be loaded from global memory + * @param[in] addr address in global memory from where to load + */ +DI void ldg(float& x, const float* addr) { + asm volatile("ld.global.cg.f32 %0, [%1];" : "=f"(x) : "l"(addr)); +} +DI void ldg(float (&x)[1], const float* addr) { + asm volatile("ld.global.cg.f32 %0, [%1];" : "=f"(x[0]) : "l"(addr)); +} +DI void ldg(float (&x)[2], const float* addr) { + asm volatile("ld.global.cg.v2.f32 {%0, %1}, [%2];" + : "=f"(x[0]), "=f"(x[1]) + : "l"(addr)); +} +DI void ldg(float (&x)[4], const float* addr) { + asm volatile("ld.global.cg.v4.f32 {%0, %1, %2, %3}, [%4];" + : "=f"(x[0]), "=f"(x[1]), "=f"(x[2]), "=f"(x[3]) + : "l"(addr)); +} +DI void ldg(double& x, const double* addr) { + asm volatile("ld.global.cg.f64 %0, [%1];" : "=d"(x) : "l"(addr)); +} +DI void ldg(double (&x)[1], const double* addr) { + asm volatile("ld.global.cg.f64 %0, [%1];" : "=d"(x[0]) : "l"(addr)); +} +DI void ldg(double (&x)[2], const double* addr) { + asm volatile("ld.global.cg.v2.f64 {%0, %1}, [%2];" + : "=d"(x[0]), "=d"(x[1]) + : "l"(addr)); +} +/** @} */ + +} // namespace raft diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh new file mode 100644 index 0000000000..000d856841 --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -0,0 +1,423 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { + +#if (ENABLE_MEMCPY_ASYNC == 1) +#include +using namespace nvcuda::experimental; +#endif + +template +struct KVPMinReduce { + typedef cub::KeyValuePair KVP; + + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { + return b.value < a.value ? b : a; + } + +}; // KVPMinReduce + +template +struct MinAndDistanceReduceOp { + typedef typename cub::KeyValuePair KVP; + DI void operator()(LabelT rid, KVP* out, const KVP& other) { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + + DI void init(KVP* out, DataT maxVal) { + out->key = -1; + out->value = maxVal; + } +}; + +template +struct MinReduceOp { + typedef typename cub::KeyValuePair KVP; + DI void operator()(LabelT rid, DataT* out, const KVP& other) { + if (other.value < *out) { + *out = other.value; + } + } + + DI void init(DataT* out, DataT maxVal) { *out = maxVal; } +}; + +template > +struct FusedL2NN : public BaseClass { + private: + typedef Policy P; + + const DataT* xn; + const DataT* yn; + OutT* min; + int* mutex; + + DataT *sxNorm, *syNorm; + cub::KeyValuePair* sRed; + + DataT maxVal; + + DataT acc[P::AccRowsPerTh][P::AccColsPerTh]; + + ReduceOpT redOp; + KVPReduceOpT pairRedOp; + +#if (ENABLE_MEMCPY_ASYNC == 1) + DataT zeros[P::Veclen]; + nvcuda::experimental::pipeline pipe; +#endif + + static const DataT Two = (DataT)2.0; + static constexpr size_t SizeAndAlign = P::Veclen * sizeof(DataT); + + public: + DI FusedL2NN(OutT* _min, const DataT* _x, const DataT* _y, const DataT* _xn, + const DataT* _yn, IdxT _m, IdxT _n, IdxT _k, char* _smem, + DataT _mv, int* _mut, ReduceOpT op, KVPReduceOpT pair_op) + : BaseClass(_x, _y, _m, _n, _k, _smem), + xn(_xn), + yn(_yn), + min(_min), + mutex(_mut), + sxNorm((DataT*)_smem), + syNorm(&(sxNorm[P::Mblk])), + sRed((cub::KeyValuePair*)_smem), + maxVal(_mv), + redOp(op), + pairRedOp(pair_op) { +#if 
(ENABLE_MEMCPY_ASYNC == 1) +#pragma unroll + for (int i = 0; i < P::Veclen; ++i) { + zeros[i] = BaseClass::Zero; + } +#endif + } + + DI void run() { + prolog(); + loop(); + __syncthreads(); // so that we can safely reuse smem + epilog(); + } + + private: + DI void prolog() { + this->ldgXY(0); +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = BaseClass::Zero; + } + } + this->stsXY(); + __syncthreads(); + this->pageWr ^= 1; + } + + DI void loop() { + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(kidx); + accumulate(); // on the previous k-block + this->stsXY(); + __syncthreads(); + this->pageWr ^= 1; + this->pageRd ^= 1; + } + accumulate(); // last iteration + } + + DI void epilog() { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = blockIdx.x * P::Mblk + i; + sxNorm[i] = idx < this->m ? xn[idx] : maxVal; + } + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = blockIdx.y * P::Nblk + i; + syNorm[i] = idx < this->n ? 
yn[idx] : maxVal; + } + __syncthreads(); + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + this->accrowid]; + } +#pragma unroll + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + this->acccolid]; + } +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = regxn[i] + regyn[j] - Two * acc[i][j]; + } + } + if (Sqrt) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = raft::mySqrt(acc[i][j]); + } + } + } + // reduce + cub::KeyValuePair val[P::AccRowsPerTh]; + auto lid = raft::laneId(); +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = this->acccolid + j * P::AccThCols + blockIdx.y * P::Nblk; + cub::KeyValuePair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < this->n) + val[i] = + pairRedOp(this->accrowid + i * P::AccThRows + blockIdx.x * P::Mblk, + tmp, val[i]); + } + __syncthreads(); +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + auto tmpkey = raft::shfl(val[i].key, lid + j); + auto tmpvalue = raft::shfl(val[i].value, lid + j); + cub::KeyValuePair tmp = {tmpkey, tmpvalue}; + val[i] = + pairRedOp(this->accrowid + i * P::AccThRows + blockIdx.x * P::Mblk, + tmp, val[i]); + } + } + if (lid % P::AccThCols == 0) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + sRed[i * P::AccThRows + this->accrowid] = val[i]; /* tile-row index: same mapping as the sxNorm reads and pairRedOp row ids above */ + } + } + __syncthreads(); + updateResults(); + } + + /* + * todo: From Volta onwards see if "coalesced" atomicCAS approach as + * written below helps improve perf + * ``` + * auto tid = threadIdx.x; + * auto rid = IdxT(blockIdx.x) * P::Mblk + tid; + * if (rid < m) { + * auto val = sRed[i]; + * while (atomicCAS(mutex + 
rid, 0, 1) == 1) + * ; + * __threadfence(); + * redOp(rid, min + rid, val); + * __threadfence(); + * atomicCAS(mutex + rid, 1, 0); + * } + * ``` + */ + DI void updateResults() { + // for now have first lane from each warp update a unique output row. This + // will resolve hang issues with pre-Volta architectures + auto nWarps = blockDim.x / raft::WarpSize; + auto lid = raft::laneId(); + auto ridx = IdxT(blockIdx.x) * P::Mblk; + if (lid == 0) { + for (int i = threadIdx.x / raft::WarpSize; i < P::Mblk; i += nWarps) { + auto rid = ridx + i; + if (rid < this->m) { + auto val = sRed[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + redOp(rid, min + rid, val); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } + } + } + } + + DI void accumulate() { +#pragma unroll + for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { + this->ldsXY(ki); +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { +#pragma unroll + for (int v = 0; v < P::Veclen; ++v) { + acc[i][j] += this->regx[i][v] * this->regy[j][v]; + } + } + } + } + } + +#if (ENABLE_MEMCPY_ASYNC == 1) + DI void ldgXY(IdxT kidx) { + auto koffset = kidx + this->scolid; + auto offset = + this->pageWr * P::SmemPage + this->srowid * P::SmemStride + this->scolid; + auto* saddrx = this->sx + offset; + for (int i = 0; i < P::LdgPerThX; ++i) { + auto* sax = saddrx + i * P::LdgRowsX * P::SmemStride; + auto* gax = this->x + i * P::LdgRowsX * this->k + koffset; + auto inside = + koffset < this->k && (this->xrowid + i * P::LdgRowsX) < this->m; + __pipeline_memcpy_async(sax, inside ? gax : nullptr, SizeAndAlign, + inside ? 
0 : SizeAndAlign); + } + auto* saddry = this->sy + offset; + for (int i = 0; i < P::LdgPerThY; ++i) { + auto* say = saddry + i * P::LdgRowsY * P::SmemStride; + auto* gay = this->y + i * P::LdgRowsY * this->k + koffset; + auto inside = + koffset < this->k && (this->yrowid + i * P::LdgRowsY) < this->n; + __pipeline_memcpy_async(say, inside ? gay : nullptr, SizeAndAlign, + inside ? 0 : SizeAndAlign); + } + pipe.commit(); + } + + DI void stsXY() { pipe.wait_prior<0>(); } +#endif // ENABLE_MEMCPY_ASYNC +}; // struct FusedL2NN + +template +__global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2NNkernel( + OutT* min, const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, + IdxT m, IdxT n, IdxT k, DataT maxVal, int* mutex, ReduceOpT redOp, + KVPReduceOpT pairRedOp) { + extern __shared__ char smem[]; + FusedL2NN obj( + min, x, y, xn, yn, m, n, k, smem, maxVal, mutex, redOp, pairRedOp); + obj.run(); +} + +template +__global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { + redOp.init(min + tid, maxVal); + } +} + +template +void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, + const DataT* yn, IdxT m, IdxT n, IdxT k, int* workspace, + ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, + bool initOutBuffer, cudaStream_t stream) { + typedef typename linalg::Policy4x4::Policy Policy; + dim3 grid(raft::ceildiv(m, Policy::Mblk), + raft::ceildiv(n, Policy::Nblk)); + dim3 blk(Policy::Nthreads); + auto nblks = raft::ceildiv(m, Policy::Nthreads); + auto maxVal = std::numeric_limits::max(); + CUDA_CHECK(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + CUDA_CHECK(cudaGetLastError()); + } + if (sqrt) { + fusedL2NNkernel + <<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp); + } else { + fusedL2NNkernel + <<>>( + min, x, y, xn, yn, m, n, k, maxVal, 
workspace, redOp, pairRedOp); + } + CUDA_CHECK(cudaGetLastError()); +} + +/** + * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. + * + * The benefits of such a call are 2-fold: 1) eliminate the need for an + * intermediate buffer to store the output of gemm 2) reduce the memory read + * traffic on this intermediate buffer, otherwise needed during the reduction + * phase for 1-NN. + * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * + * @param[out] min will contain the reduced output (Length = `m`) + * (on device) + * @param[in] x first matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] m gemm m + * @param[in] n gemm n + * @param[in] k gemm k + * @param[in] workspace temp workspace. Size = sizeof(int)*m. 
(on device) + * @param[in] redOp reduction operator in the epilogue + * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt + * @param[in] initOutBuffer whether to initialize the output buffer before the + * main kernel launch + * @param[in] stream cuda stream + */ +template +void fusedL2NN(OutT* min, const DataT* x, const DataT* y, const DataT* xn, + const DataT* yn, IdxT m, IdxT n, IdxT k, void* workspace, + ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, + bool initOutBuffer, cudaStream_t stream) { + size_t bytes = sizeof(DataT) * k; + if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) { + fusedL2NNImpl( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, + initOutBuffer, stream); + } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0) { + fusedL2NNImpl( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, + initOutBuffer, stream); + } else { + fusedL2NNImpl( + min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, + initOutBuffer, stream); + } +} + +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/linalg/contractions.cuh b/cpp/include/raft/linalg/contractions.cuh new file mode 100644 index 0000000000..c82bb761e7 --- /dev/null +++ b/cpp/include/raft/linalg/contractions.cuh @@ -0,0 +1,303 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#pragma once + +#include + +namespace raft { +namespace linalg { + +/** + * @brief This is the central enum that should be used to configure the perf + * landscape of the Contraction kernel. + * + * Main goal of this Policy struct is to provide sufficient knobs to tune the + * perf of Contraction kernel, as and when we see matrices of different shapes. + * + * @tparam DataT the IO and math datatype + * @tparam _veclen number of k-elements loaded by each thread for every LDG call + * it makes. This should be configured based on the input 'k' + * value and the input data type. E.g.: if DataT = float and + * k is a multiple of 4, then setting this to 4 gives the best + * LDG pattern. Possible values are {1, 2, 4}. + * @tparam _kblk number of k-elements operated upon per main-loop iteration. + * Therefore total number of main-loop iterations will be + * `ceil(k/_kblk)`. This must be a multiple of `_veclen`. Note + * that the bigger this value, the greater the shared mem requirement. + * @tparam _rpt Defines the number of rows that a given thread accumulates on. + * This directly results in increased register pressure. This + * also is used to compute the number of m-elements worked upon + * by each thread block. + * @tparam _cpt Defines the number of cols that a given thread accumulates on. + * This directly results in increased register pressure. This + * also is used to compute the number of n-elements worked upon + * by each thread block. + * @tparam _tr Number of threads working on the same output column. This is + * used to compute the number of m-elements worked upon by each + * thread block. This also determines the number of threads per + * thread block + * @tparam _tc Number of threads working on the same output row. This is + * used to compute the number of n-elements worked upon by each + * thread block. 
This also determines the number of threads per + * thread block + */ +template +struct KernelPolicy { + enum { + /** number of elements along K worked upon per main loop iteration */ + Kblk = _kblk, + /** number of elements loaded per LDG */ + Veclen = _veclen, + /** number of rows a thread works on for accumulation */ + AccRowsPerTh = _rpt, + /** number of cols a thread works on for accumulation */ + AccColsPerTh = _cpt, + /** number of threads working the same output col */ + AccThRows = _tr, + /** number of threads working the same output row */ + AccThCols = _tc, + /** total threads per block */ + Nthreads = AccThRows * AccThCols, + /** output tile size along rows */ + Mblk = AccRowsPerTh * AccThRows, + /** output tile size along cols */ + Nblk = AccColsPerTh * AccThCols, + /** number of threads loading a single row */ + LdgThK = Kblk / Veclen, + /** number of LDGs issued by a single thread for X */ + LdgPerThX = Mblk * LdgThK / Nthreads, + /** number of LDGs issued by a single thread for Y */ + LdgPerThY = Nblk * LdgThK / Nthreads, + /** number of rows of X covered per LDG */ + LdgRowsX = Mblk / LdgPerThX, + /** number of rows of Y covered per LDG */ + LdgRowsY = Nblk / LdgPerThY, + /** stride for accessing X/Y data in shared mem */ + SmemStride = Kblk + Veclen, + /** size of one page for storing X data */ + SmemPageX = SmemStride * Mblk, + /** size of one page for storing Y data */ + SmemPageY = SmemStride * Nblk, + /** size of one smem page */ + SmemPage = SmemPageX + SmemPageY, + /** size (in B) for smem needed */ + SmemSize = 2 * SmemPage * sizeof(DataT), + }; // enum +}; // struct KernelPolicy + +/** + * @defgroup Policy4x4 16 elements per thread Policy with k-block = 32 + * @{ + */ +template +struct Policy4x4 {}; + +template +struct Policy4x4 { + typedef KernelPolicy Policy; +}; + +template +struct Policy4x4 { + typedef KernelPolicy Policy; +}; +/** @} */ + +/** + * @brief Base class for gemm-like NT contractions + * + * This class does not provide any 
arithmetic operations, but only provides the + * memory-related operations of loading the `x` and `y` matrix blocks from the + * global memory into shared memory and then from shared into registers. Thus, + * this class acts as a basic building block for further composing gemm-like NT + * contractions on input matrices which are row-major (and so does the output) + * + * @tparam DataT IO and math data type + * @tparam IdxT indexing type + * @tparam Policy policy used to customize memory access behavior. + * See documentation for `KernelPolicy` to know more. + */ +template +struct Contractions_NT { + protected: + typedef Policy P; + + /** number of rows in X */ + IdxT m; + /** number of rows in Y */ + IdxT n; + /** number of columns in X and Y */ + IdxT k; + /** current thread's global mem row id for X data */ + IdxT xrowid; + /** current thread's global mem row id for Y data */ + IdxT yrowid; + /** global memory pointer to X matrix */ + const DataT* x; + /** global memory pointer to Y matrix */ + const DataT* y; + + /** current thread's smem row id */ + int srowid; + /** current thread's smem column id */ + int scolid; + /** current thread's accumulation row id */ + int accrowid; + /** current thread's accumulation column id */ + int acccolid; + + /** base smem pointer for X data storage */ + DataT* sx; + /** base smem pointer for Y data storage */ + DataT* sy; + /** index pointing the correct smem page for writing after `ldgXY()` */ + int pageWr; + /** index pointing the correct smem page for reading during `ldsXY()` */ + int pageRd; + + /** block of X data loaded from smem after `ldsXY()` */ + DataT regx[P::AccRowsPerTh][P::Veclen]; + /** block of Y data loaded from smem after `ldsXY()` */ + DataT regy[P::AccColsPerTh][P::Veclen]; + /** block of X data loaded from global mem after `ldgXY()` */ + DataT ldgDataX[P::LdgPerThX][P::Veclen]; + /** block of Y data loaded from global mem after `ldgXY()` */ + DataT ldgDataY[P::LdgPerThY][P::Veclen]; + + static const DataT 
Zero = (DataT)0; + + public: + /** + * @brief Ctor + * @param[in] _x X matrix. [on device] [dim = _m x _k] [row-major] + * @param[in] _y Y matrix. [on device] [dim = _n x _k] [row-major] + * @param[in] _m number of rows of X + * @param[in] _n number of rows of Y + * @param[in] _k number of cols of X and Y + * @param[in] _smem shared memory region used during computations + */ + DI Contractions_NT(const DataT* _x, const DataT* _y, IdxT _m, IdxT _n, + IdxT _k, char* _smem) + : m(_m), + n(_n), + k(_k), + xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThK), + yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThK), + x(_x + xrowid * k), + y(_y + yrowid * k), + srowid(threadIdx.x / P::LdgThK), + scolid((threadIdx.x % P::LdgThK) * P::Veclen), + accrowid(threadIdx.x / P::AccThCols), + acccolid(threadIdx.x % P::AccThCols), + sx((DataT*)_smem), + sy(&(sx[P::SmemPageX])), + pageWr(0), + pageRd(0) {} + + protected: + /** + * @brief Load current block of X/Y from global memory to registers + * @param[in] kidx current start index of k to be loaded + */ + DI void ldgXY(IdxT kidx) { + ldgX(kidx); + ldgY(kidx); + } + + /** + * @brief Store current block of X/Y from registers to smem + * @note takes no arguments; the destination smem page is selected by `pageWr` + */ + DI void stsXY() { + stsX(sx + pageWr * P::SmemPage); + stsY(sy + pageWr * P::SmemPage); + } + + /** + * @brief Load X and Y block from shared memory to registers + * @param[in] kidx k offset within the smem page selected by `pageRd` + */ + DI void ldsXY(int kidx) { + ldsX(kidx, sx + pageRd * P::SmemPage); + ldsY(kidx, sy + pageRd * P::SmemPage); + } + + private: + DI void ldgX(IdxT kidx) { + auto koffset = kidx + scolid; + for (int i = 0; i < P::LdgPerThX; ++i) { + if (koffset < k && (xrowid + i * P::LdgRowsX) < m) { + ldg(ldgDataX[i], x + i * P::LdgRowsX * k + koffset); + } else { +#pragma unroll + for (int j = 0; j < P::Veclen; ++j) { + ldgDataX[i][j] = Zero; + } + } + } + } + + DI void ldgY(IdxT kidx) { + auto 
koffset = kidx + scolid; + for (int i = 0; i < P::LdgPerThY; ++i) { + if (koffset < k && (yrowid + i * P::LdgRowsY) < n) { + ldg(ldgDataY[i], y + i * P::LdgRowsY * k + koffset); + } else { +#pragma unroll + for (int j = 0; j < P::Veclen; ++j) { + ldgDataY[i][j] = Zero; + } + } + } + } + + DI void stsX(DataT* smem) { + auto* saddr = smem + srowid * P::SmemStride + scolid; +#pragma unroll + for (int i = 0; i < P::LdgPerThX; ++i) { + sts(saddr + i * P::LdgRowsX * P::SmemStride, ldgDataX[i]); + } + } + + DI void stsY(DataT* smem) { + auto* saddr = smem + srowid * P::SmemStride + scolid; +#pragma unroll + for (int i = 0; i < P::LdgPerThY; ++i) { + sts(saddr + i * P::LdgRowsY * P::SmemStride, ldgDataY[i]); + } + } + + DI void ldsX(int kidx, DataT* smem) { + auto* saddr = smem + accrowid * P::SmemStride + kidx; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + lds(regx[i], saddr + i * P::AccThRows * P::SmemStride); + } + } + + DI void ldsY(int kidx, DataT* smem) { + auto* saddr = smem + acccolid * P::SmemStride + kidx; +#pragma unroll + for (int i = 0; i < P::AccColsPerTh; ++i) { + lds(regy[i], saddr + i * P::AccThCols * P::SmemStride); + } + } +}; // struct Contractions_NT + +} // namespace linalg +} // namespace raft diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu new file mode 100644 index 0000000000..d4e39a0b5e --- /dev/null +++ b/cpp/test/distance/fused_l2_nn.cu @@ -0,0 +1,346 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include "../test_utils.h" + +namespace raft { +namespace distance { + +template +struct CubKVPMinReduce { + typedef cub::KeyValuePair KVP; + + DI KVP operator()(LabelT rit, const KVP &a, const KVP &b) { + return b.value < a.value ? b : a; + } + + DI KVP operator()(const KVP &a, const KVP &b) { + return b.value < a.value ? b : a; + } + +}; // KVPMinReduce + +template +__global__ void naiveKernel(cub::KeyValuePair *min, DataT *x, + DataT *y, int m, int n, int k, int *workspace, + DataT maxVal) { + int midx = threadIdx.y + blockIdx.y * blockDim.y; + int nidx = threadIdx.x + blockIdx.x * blockDim.x; + DataT acc = DataT(0); + for (int i = 0; i < k; ++i) { + int xidx = i + midx * k; + int yidx = i + nidx * k; + auto diff = midx >= m || nidx >= n ? DataT(0) : x[xidx] - y[yidx]; + acc += diff * diff; + } + if (Sqrt) { + acc = raft::mySqrt(acc); + } + ReduceOpT redOp; + typedef cub::WarpReduce> WarpReduce; + __shared__ typename WarpReduce::TempStorage temp[NWARPS]; + int warpId = threadIdx.x / raft::WarpSize; + cub::KeyValuePair tmp; + tmp.key = nidx; + tmp.value = midx >= m || nidx >= n ? 
maxVal : acc; + tmp = WarpReduce(temp[warpId]).Reduce(tmp, CubKVPMinReduce()); + if (threadIdx.x % raft::WarpSize == 0 && midx < m) { + while (atomicCAS(workspace + midx, 0, 1) == 1) + ; + __threadfence(); + redOp(midx, min + midx, tmp); + __threadfence(); + atomicCAS(workspace + midx, 1, 0); + } +} + +template +void naive(cub::KeyValuePair *min, DataT *x, DataT *y, int m, int n, + int k, int *workspace, cudaStream_t stream) { + static const dim3 TPB(32, 16, 1); + dim3 nblks(raft::ceildiv(n, (int)TPB.x), raft::ceildiv(m, (int)TPB.y), 1); + CUDA_CHECK(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + auto blks = raft::ceildiv(m, 256); + MinAndDistanceReduceOp op; + initKernel, int> + <<>>(min, m, std::numeric_limits::max(), op); + CUDA_CHECK(cudaGetLastError()); + naiveKernel, 16> + <<>>(min, x, y, m, n, k, workspace, + std::numeric_limits::max()); + CUDA_CHECK(cudaGetLastError()); +} + +template +struct Inputs { + DataT tolerance; + int m, n, k; + unsigned long long int seed; +}; + +template +class FusedL2NNTest : public ::testing::TestWithParam> { + public: + void SetUp() override { + params = ::testing::TestWithParam>::GetParam(); + raft::random::Rng r(params.seed); + int m = params.m; + int n = params.n; + int k = params.k; + CUDA_CHECK(cudaStreamCreate(&stream)); + raft::allocate(x, m * k); + raft::allocate(y, n * k); + raft::allocate(xn, m); + raft::allocate(yn, n); + raft::allocate(workspace, sizeof(int) * m); + raft::allocate(min, m); + raft::allocate(min_ref, m); + r.uniform(x, m * k, DataT(-1.0), DataT(1.0), stream); + r.uniform(y, n * k, DataT(-1.0), DataT(1.0), stream); + generateGoldenResult(); + raft::linalg::rowNorm(xn, x, k, m, raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm(yn, y, k, n, raft::linalg::L2Norm, true, stream); + } + + void TearDown() override { + CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaStreamDestroy(stream)); + CUDA_CHECK(cudaFree(x)); + CUDA_CHECK(cudaFree(y)); + CUDA_CHECK(cudaFree(xn)); + 
CUDA_CHECK(cudaFree(yn)); + CUDA_CHECK(cudaFree(workspace)); + CUDA_CHECK(cudaFree(min_ref)); + CUDA_CHECK(cudaFree(min)); + } + + protected: + Inputs params; + DataT *x, *y, *xn, *yn; + char *workspace; + cub::KeyValuePair *min, *min_ref; + cudaStream_t stream; + + virtual void generateGoldenResult() { + int m = params.m; + int n = params.n; + int k = params.k; + naive(min_ref, x, y, m, n, k, (int *)workspace, stream); + } + + void runTest(cub::KeyValuePair *out) { + int m = params.m; + int n = params.n; + int k = params.k; + MinAndDistanceReduceOp redOp; + fusedL2NN, int>( + out, x, y, xn, yn, m, n, k, (void *)workspace, redOp, + raft::distance::KVPMinReduce(), Sqrt, true, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + } +}; + /* CompareApproxAbsKVP: comparator for key-value pairs. The keys must be equal, and the magnitudes of the two values must agree within eps. The error is relative (difference divided by the larger magnitude) when the larger magnitude is at least eps, and absolute otherwise. Both values pass through raft::abs before they are subtracted, so a difference in sign alone is not detected. */ +template +struct CompareApproxAbsKVP { + typedef typename cub::KeyValuePair KVP; + CompareApproxAbsKVP(T eps_) : eps(eps_) {} + bool operator()(const KVP &a, const KVP &b) const { + if (a.key != b.key) return false; + T diff = raft::abs(raft::abs(a.value) - raft::abs(b.value)); + T m = std::max(raft::abs(a.value), raft::abs(b.value)); + T ratio = m >= eps ? 
diff / m : diff; + return (ratio <= eps); + } + + private: + T eps; +}; + /* CompareExactKVP: comparator for key-value pairs that requires both the key and the value to compare equal with no tolerance. */ +template +struct CompareExactKVP { + typedef typename cub::KeyValuePair KVP; + bool operator()(const KVP &a, const KVP &b) const { + if (a.key != b.key) return false; + if (a.value != b.value) return false; + return true; + } +}; + /* devArrMatch: element-wise comparison of two arrays of size key-value pairs. Both arrays are copied into host buffers with raft::update_host on the given stream, and the stream is synchronized before any element is read. Returns AssertionSuccess when eq_compare(expected[i], actual[i]) holds for every i; otherwise returns AssertionFailure at the first mismatch, reporting the actual pair, the expected pair and the index. NOTE(review): the host buffers come from new[] but are held by std::shared_ptr, whose template arguments are not visible in this view; confirm that an array-aware type or deleter is used. */ +template +::testing::AssertionResult devArrMatch(const cub::KeyValuePair *expected, + const cub::KeyValuePair *actual, + size_t size, L eq_compare, + cudaStream_t stream = 0) { + typedef typename cub::KeyValuePair KVP; + std::shared_ptr exp_h(new KVP[size]); + std::shared_ptr act_h(new KVP[size]); + raft::update_host(exp_h.get(), expected, size, stream); + raft::update_host(act_h.get(), actual, size, stream); + CUDA_CHECK(cudaStreamSynchronize(stream)); + for (size_t i(0); i < size; ++i) { + auto exp = exp_h.get()[i]; + auto act = act_h.get()[i]; + if (!eq_compare(exp, act)) { + return ::testing::AssertionFailure() + << "actual=" << act.key << "," << act.value + << " != expected=" << exp.key << "," << exp.value << " @" << i; + } + } + return ::testing::AssertionSuccess(); +} + /* inputsf: parameter sets consumed by the F_ test types below. NOTE(review): the element type is declared outside this view; the fields are presumably tolerance, m, n, k and seed in that order - confirm against the Inputs definition. Under that reading k covers 32, 64 and 128 as well as the off-by-one and off-by-two sizes 33, 34, 65, 66, 129 and 130, and the last entry is a 1805 x 134 case with k = 2 and a looser 0.006 tolerance. */ +const std::vector> inputsf = { + {0.001f, 32, 32, 32, 1234ULL}, {0.001f, 32, 64, 32, 1234ULL}, + {0.001f, 64, 32, 32, 1234ULL}, {0.001f, 64, 64, 32, 1234ULL}, + {0.001f, 128, 32, 32, 1234ULL}, {0.001f, 128, 64, 32, 1234ULL}, + {0.001f, 128, 128, 64, 1234ULL}, {0.001f, 64, 128, 128, 1234ULL}, + + {0.001f, 32, 32, 34, 1234ULL}, {0.001f, 32, 64, 34, 1234ULL}, + {0.001f, 64, 32, 34, 1234ULL}, {0.001f, 64, 64, 34, 1234ULL}, + {0.001f, 128, 32, 34, 1234ULL}, {0.001f, 128, 64, 34, 1234ULL}, + {0.001f, 128, 128, 66, 1234ULL}, {0.001f, 64, 128, 130, 1234ULL}, + + {0.001f, 32, 32, 33, 1234ULL}, {0.001f, 32, 64, 33, 1234ULL}, + {0.001f, 64, 32, 33, 1234ULL}, {0.001f, 64, 64, 33, 1234ULL}, + {0.001f, 128, 32, 33, 1234ULL}, {0.001f, 128, 64, 33, 1234ULL}, + {0.001f, 128, 128, 65, 1234ULL}, {0.001f, 64, 128, 129, 1234ULL}, + + {0.006f, 1805, 134, 2, 1234ULL}, +}; +typedef FusedL2NNTest 
FusedL2NNTestF_Sq; /* Each Result test below writes the output of runTest into min and requires it to match min_ref over params.m entries under CompareApproxAbsKVP built from params.tolerance. */ +TEST_P(FusedL2NNTestF_Sq, Result) { + runTest(min); + ASSERT_TRUE(devArrMatch(min_ref, min, params.m, + CompareApproxAbsKVP(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(FusedL2NNTests, FusedL2NNTestF_Sq, + ::testing::ValuesIn(inputsf)); +typedef FusedL2NNTest FusedL2NNTestF_Sqrt; +TEST_P(FusedL2NNTestF_Sqrt, Result) { + runTest(min); + ASSERT_TRUE(devArrMatch(min_ref, min, params.m, + CompareApproxAbsKVP(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(FusedL2NNTests, FusedL2NNTestF_Sqrt, + ::testing::ValuesIn(inputsf)); + /* inputsd: parameter sets consumed by the D_ test types. The size fields repeat those of inputsf entry for entry, with 0.00001 as the first field throughout. */ +const std::vector> inputsd = { + {0.00001, 32, 32, 32, 1234ULL}, {0.00001, 32, 64, 32, 1234ULL}, + {0.00001, 64, 32, 32, 1234ULL}, {0.00001, 64, 64, 32, 1234ULL}, + {0.00001, 128, 32, 32, 1234ULL}, {0.00001, 128, 64, 32, 1234ULL}, + {0.00001, 128, 128, 64, 1234ULL}, {0.00001, 64, 128, 128, 1234ULL}, + + {0.00001, 32, 32, 34, 1234ULL}, {0.00001, 32, 64, 34, 1234ULL}, + {0.00001, 64, 32, 34, 1234ULL}, {0.00001, 64, 64, 34, 1234ULL}, + {0.00001, 128, 32, 34, 1234ULL}, {0.00001, 128, 64, 34, 1234ULL}, + {0.00001, 128, 128, 66, 1234ULL}, {0.00001, 64, 128, 130, 1234ULL}, + + {0.00001, 32, 32, 33, 1234ULL}, {0.00001, 32, 64, 33, 1234ULL}, + {0.00001, 64, 32, 33, 1234ULL}, {0.00001, 64, 64, 33, 1234ULL}, + {0.00001, 128, 32, 33, 1234ULL}, {0.00001, 128, 64, 33, 1234ULL}, + {0.00001, 128, 128, 65, 1234ULL}, {0.00001, 64, 128, 129, 1234ULL}, + + {0.00001, 1805, 134, 2, 1234ULL}, +}; +typedef FusedL2NNTest FusedL2NNTestD_Sq; +TEST_P(FusedL2NNTestD_Sq, Result) { + runTest(min); + ASSERT_TRUE(devArrMatch(min_ref, min, params.m, + CompareApproxAbsKVP(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(FusedL2NNTests, FusedL2NNTestD_Sq, + ::testing::ValuesIn(inputsd)); +typedef FusedL2NNTest FusedL2NNTestD_Sqrt; +TEST_P(FusedL2NNTestD_Sqrt, Result) { + runTest(min); + ASSERT_TRUE(devArrMatch(min_ref, min, params.m, + CompareApproxAbsKVP(params.tolerance))); +} +INSTANTIATE_TEST_CASE_P(FusedL2NNTests, FusedL2NNTestD_Sqrt, + 
::testing::ValuesIn(inputsd)); + +/// This is to test output determinism of the prim /* FusedL2NNDetTest extends the base fixture with a second output buffer min1: SetUp allocates params.m entries for it after the base SetUp, and TearDown frees it after the base TearDown. generateGoldenResult is overridden with an empty body, and NumRepeats is fixed at 100. */ +template +class FusedL2NNDetTest : public FusedL2NNTest { + void SetUp() override { + FusedL2NNTest::SetUp(); + int m = this->params.m; + raft::allocate(min1, m); + } + + void TearDown() override { + FusedL2NNTest::TearDown(); + CUDA_CHECK(cudaFree(min1)); + } + + protected: + cub::KeyValuePair *min1; + + static const int NumRepeats = 100; + + void generateGoldenResult() override {} +}; + /* Each determinism test below treats a first run into min as the reference, then repeats runTest into min1 NumRepeats times and requires every repeat to match the reference exactly in both key and value (CompareExactKVP). */ +typedef FusedL2NNDetTest FusedL2NNDetTestF_Sq; +TEST_P(FusedL2NNDetTestF_Sq, Result) { + runTest(min); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1); + ASSERT_TRUE(devArrMatch(min, min1, params.m, CompareExactKVP())); + } +} +INSTANTIATE_TEST_CASE_P(FusedL2NNDetTests, FusedL2NNDetTestF_Sq, + ::testing::ValuesIn(inputsf)); +typedef FusedL2NNDetTest FusedL2NNDetTestF_Sqrt; +TEST_P(FusedL2NNDetTestF_Sqrt, Result) { + runTest(min); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1); + ASSERT_TRUE(devArrMatch(min, min1, params.m, CompareExactKVP())); + } +} +INSTANTIATE_TEST_CASE_P(FusedL2NNDetTests, FusedL2NNDetTestF_Sqrt, + ::testing::ValuesIn(inputsf)); + +typedef FusedL2NNDetTest FusedL2NNDetTestD_Sq; +TEST_P(FusedL2NNDetTestD_Sq, Result) { + runTest(min); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1); + ASSERT_TRUE(devArrMatch(min, min1, params.m, CompareExactKVP())); + } +} +INSTANTIATE_TEST_CASE_P(FusedL2NNDetTests, FusedL2NNDetTestD_Sq, + ::testing::ValuesIn(inputsd)); +typedef FusedL2NNDetTest FusedL2NNDetTestD_Sqrt; +TEST_P(FusedL2NNDetTestD_Sqrt, Result) { + runTest(min); // assumed to be golden + for (int i = 0; i < NumRepeats; ++i) { + runTest(min1); + ASSERT_TRUE(devArrMatch(min, min1, params.m, CompareExactKVP())); + } +} +INSTANTIATE_TEST_CASE_P(FusedL2NNDetTests, FusedL2NNDetTestD_Sqrt, + ::testing::ValuesIn(inputsd)); + +} // end namespace distance +} // end namespace raft