diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index c6850b290f..a6341f6dda 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -284,7 +284,18 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/cluster/update_centroids_double.cu src/distance/cluster/cluster_cost_float.cu src/distance/cluster/cluster_cost_double.cu - src/distance/neighbors/refine.cu + src/distance/neighbors/refine_d_uint64_t_float.cu + src/distance/neighbors/refine_d_uint64_t_int8_t.cu + src/distance/neighbors/refine_d_uint64_t_uint8_t.cu + src/distance/neighbors/refine_h_uint64_t_float.cu + src/distance/neighbors/refine_h_uint64_t_int8_t.cu + src/distance/neighbors/refine_h_uint64_t_uint8_t.cu + src/distance/neighbors/specializations/refine_d_uint64_t_float.cu + src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu + src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu + src/distance/neighbors/specializations/refine_h_uint64_t_float.cu + src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu + src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu src/distance/neighbors/ivfpq_search.cu src/distance/cluster/kmeans_fit_float.cu src/distance/cluster/kmeans_fit_double.cu diff --git a/cpp/bench/distance/distance_common.cuh b/cpp/bench/distance/distance_common.cuh index 73faacce37..1be00ec0c7 100644 --- a/cpp/bench/distance/distance_common.cuh +++ b/cpp/bench/distance/distance_common.cuh @@ -24,14 +24,14 @@ namespace raft::bench::distance { -struct distance_inputs { +struct distance_params { int m, n, k; bool isRowMajor; -}; // struct distance_inputs +}; // struct distance_params template struct distance : public fixture { - distance(const distance_inputs& p) + distance(const distance_params& p) : params(p), x(p.m * p.k, stream), y(p.n * p.k, stream), @@ -63,13 +63,13 @@ struct distance : public fixture { } private: - distance_inputs params; + distance_params params; rmm::device_uvector x, y, out; rmm::device_uvector workspace; size_t worksize; }; // struct Distance -const std::vector dist_input_vecs{ +const std::vector dist_input_vecs{ {32, 16384, 16384, true}, {64, 16384, 16384, true}, {128, 16384, 16384, true}, {256, 16384, 16384, true}, {512, 16384, 16384, true}, {1024, 16384, 16384, true}, {16384, 32, 16384, true}, {16384, 64, 16384, true}, {16384, 128, 16384, true}, diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index 60eb8c257d..eec1cba99e 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -32,6 +32,9 @@ #include #if defined RAFT_DISTANCE_COMPILED #include +#include +#else +#pragma message("NN / Distance specializations are not enabled; expect very long building times.") #endif #endif diff --git a/cpp/bench/neighbors/refine.cu b/cpp/bench/neighbors/refine.cu index 3349b8b6ae..f32af3a57e 100644 --- a/cpp/bench/neighbors/refine.cu +++ b/cpp/bench/neighbors/refine.cu @@ -27,6 +27,7 @@ #if defined RAFT_DISTANCE_COMPILED #include +#include #endif #if defined RAFT_NN_COMPILED @@ -52,7 +53,7 @@ inline auto operator<<(std::ostream& os, const RefineInputs& p) -> std::os return os; } -RefineInputs p; +RefineInputs p; template class RefineAnn : public fixture { @@ -98,24 +99,24 @@ class RefineAnn : public fixture { RefineHelper data; }; -std::vector> getInputs() +std::vector> getInputs() { - std::vector> out; + std::vector> out; raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; for (bool host_data : {true, false}) { - for (int64_t n_queries : {1000, 10000}) { - for (int64_t dim : {128, 512}) { - out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); - out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); + for (uint64_t n_queries : {1000, 10000}) { + for (uint64_t dim : {128, 512}) { + out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); + out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); } } } return out; } -using refine_float_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs()); +using refine_float_uint64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_float_uint64, "", getInputs()); -using refine_uint8_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs()); +using refine_uint8_uint64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_uint8_uint64, "", getInputs()); } // namespace raft::bench::neighbors diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 69bb83d29a..d849b23999 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,7 @@ namespace detail { * @param core_op the core accumulation operation lambda * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda + * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ template m; - gridStrideY += P::Mblk * gridDim.y) { - for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n; - gridStrideX += P::Nblk * gridDim.x) { - prolog(gridStrideX, gridStrideY); - loop(); - epilog(gridStrideX, gridStrideY); + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + this->ldgXY(tile_idx_m, grid_offset_n, 0); + for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { + // Prolog: + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + // Main loop: + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + accumulate(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + accumulate(); // last iteration + // The pre-condition for the loop over tile_idx_n is that write_buffer + // and read_buffer point to the same buffer. This flips read_buffer back + // so that it satisfies the pre-condition of this loop. + this->switch_read_buffer(); + + // Epilog: + if (useNorms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + load_norms(tile_idx_m, tile_idx_n, regxn, regyn); + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); + } else { + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + } + if (writeOut) { store_output(tile_idx_m, tile_idx_n); } } - rowEpilog_op(gridStrideY); + rowEpilog_op(tile_idx_m); } } private: - DI void updateIndicesY() - { - const auto stride = P::Nblk * gridDim.x; - if (isRowMajor) { - this->y += stride * this->ldb; - } else { - this->y += stride; - } - this->yrowid += stride; - } - - DI void updateIndicesXY() - { - const auto stride = P::Mblk * gridDim.y; - if (isRowMajor) { - this->x += stride * this->lda; - this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid; - this->y = yBase + this->yrowid * this->ldb; - } else { - this->x += stride; - this->yrowid = IdxT(blockIdx.x) * P::Nblk; - this->y = yBase + this->yrowid + this->srowid * this->ldb; - } - this->xrowid += stride; - } - - DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) + DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n) { // Fetch next grid stride ldg if within range - if ((gridStrideX + gridDim.x * P::Nblk) < this->n) { - updateIndicesY(); - this->ldgXY(0); - } else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) { - updateIndicesXY(); - this->ldgXY(0); + const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; + const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m; + if ((next_tile_tile_idx_n) < this->n) { + this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0); + } else if ((next_tile_tile_idx_m) < this->m) { + this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0); } } - DI void prolog(IdxT gridStrideX, IdxT gridStrideY) + DI void reset_accumulator() { - if (gridStrideX == blockIdx.x * P::Nblk) { this->ldgXY(0); } - + // Reset accumulator registers to zero. #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -184,28 +199,6 @@ struct PairwiseDistances : public BaseClass { acc[i][j] = BaseClass::Zero; } } - - this->stsXY(); - __syncthreads(); - this->pageWr ^= 1; - } - - DI void loop() - { - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); - accumulate(); // on the previous k-block - this->stsXY(); - __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; - } - accumulate(); // last iteration - // This is needed for making sure next grid stride of - // non-norm based metrics uses previously accumulated buffer so - // it doesn't make shmem dirty until previous iteration - // is complete. - this->pageRd ^= 1; } DI void accumulate() @@ -226,60 +219,52 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog(IdxT gridStrideX, IdxT gridStrideY) + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + DataT (®xn)[P::AccRowsPerTh], + DataT (®yn)[P::AccColsPerTh]) { - if (useNorms) { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (gridStrideX == blockIdx.x * P::Nblk) { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = gridStrideY + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; - } - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = gridStrideX + i; - syNorm[i] = idx < this->n ? yn[idx] : 0; + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + if (tile_idx_n == blockIdx.x * P::Nblk) { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; } + } - __syncthreads(); + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < this->n ? yn[idx] : 0; + } + __syncthreads(); - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } #pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - - // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); - } else { - // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } + } - if (writeOut) { - IdxT starty = gridStrideY + this->accrowid; - IdxT startx = gridStrideX + this->acccolid; + DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n) + { + IdxT starty = tile_idx_m + this->accrowid; + IdxT startx = tile_idx_n + this->acccolid; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = starty + i * P::AccThRows; + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rowId = starty + i * P::AccThRows; #pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = startx + j * P::AccThCols; - if (rowId < this->m && colId < this->n) { - // Promote to 64 bit index for final write, as output array can be > 2^31 - dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); - } + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto colId = startx + j * P::AccThCols; + if (rowId < this->m && colId < this->n) { + // Promote to 64 bit index for final write, as output array can be > 2^31 + dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); } } } diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 5d83f88e71..f2d71117f7 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,14 +40,10 @@ struct Contractions_NT { /** leading dimension in Output D */ IdxT ldd; - /** current thread's global mem row id for X data */ - IdxT xrowid; - /** current thread's global mem row id for Y data */ - IdxT yrowid; /** global memory pointer to X matrix */ - const DataT* x; + const DataT* x_base; /** global memory pointer to Y matrix */ - const DataT* y; + const DataT* y_base; /** current thread's smem row id */ int srowid; @@ -94,10 +90,8 @@ struct Contractions_NT { k(_k), lda(_k), ldb(_k), - xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThRow), - yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThRow), - x(_x + xrowid * lda), - y(_y + yrowid * ldb), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -133,6 +127,8 @@ struct Contractions_NT { lda(_lda), ldb(_ldb), ldd(_ldd), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -142,17 +138,6 @@ struct Contractions_NT { pageWr(0), pageRd(0) { - if (isRowMajor) { - xrowid = IdxT(blockIdx.y) * P::Mblk + srowid; - yrowid = IdxT(blockIdx.x) * P::Nblk + srowid; - x = _x + xrowid * lda; - y = _y + yrowid * ldb; - } else { - xrowid = IdxT(blockIdx.y) * P::Mblk; - yrowid = IdxT(blockIdx.x) * P::Nblk; - x = _x + xrowid + srowid * lda; - y = _y + yrowid + srowid * ldb; - } } protected: @@ -160,10 +145,10 @@ struct Contractions_NT { * @brief Load current block of X/Y from global memory to registers * @param[in] kidx current start index of k to be loaded */ - DI void ldgXY(IdxT kidx) + DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx) { - ldgX(kidx); - ldgY(kidx); + ldgX(tile_idx_m, kidx); + ldgY(tile_idx_n, kidx); } /** @@ -186,9 +171,16 @@ struct Contractions_NT { ldsY(kidx, sy + pageRd * P::SmemPage); } + DI void switch_read_buffer() { this->pageRd ^= 1; } + + DI void switch_write_buffer() { this->pageWr ^= 1; } + private: - DI void ldgX(IdxT kidx) + DI void ldgX(IdxT tile_idx_m, IdxT kidx) { + IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m; + auto x = isRowMajor ? x_base + xrowid * lda : x_base + xrowid + srowid * lda; + if (isRowMajor) { auto numRows = m; auto koffset = kidx + scolid; @@ -220,8 +212,11 @@ struct Contractions_NT { } } - DI void ldgY(IdxT kidx) + DI void ldgY(IdxT tile_idx_n, IdxT kidx) { + IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; + auto y = isRowMajor ? y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; + if (isRowMajor) { auto numRows = n; auto koffset = kidx + scolid; @@ -315,4 +310,4 @@ struct Contractions_NT { } // namespace detail } // namespace linalg -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/neighbors/specializations.cuh b/cpp/include/raft/neighbors/specializations.cuh index 0511bbbf6c..d17467c8a7 100644 --- a/cpp/include/raft/neighbors/specializations.cuh +++ b/cpp/include/raft/neighbors/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,9 +20,8 @@ #pragma once #include +#include #include #include - -#include - +#include #endif diff --git a/cpp/include/raft/neighbors/specializations/refine.cuh b/cpp/include/raft/neighbors/specializations/refine.cuh new file mode 100644 index 0000000000..71e83a26f3 --- /dev/null +++ b/cpp/include/raft/neighbors/specializations/refine.cuh @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::neighbors { + +#ifdef RAFT_INST +#undef RAFT_INST +#endif + +#define RAFT_INST(T, IdxT) \ + extern template void refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + distance::DistanceType metric); \ + \ + extern template void refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + distance::DistanceType metric); + +RAFT_INST(float, uint64_t); +RAFT_INST(uint8_t, uint64_t); +RAFT_INST(int8_t, uint64_t); + +#undef RAFT_INST +} // namespace raft::neighbors diff --git a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh index 19862d743d..7616083796 100644 --- a/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh +++ b/cpp/include/raft/spatial/knn/detail/epsilon_neighborhood.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,7 +64,7 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { private: DI void prolog() { - this->ldgXY(0); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, 0); #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -74,18 +74,18 @@ struct EpsUnexpL2SqNeighborhood : public BaseClass { } this->stsXY(); __syncthreads(); - this->pageWr ^= 1; + this->switch_write_buffer(); } DI void loop() { for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); + this->ldgXY(IdxT(blockIdx.x) * P::Mblk, IdxT(blockIdx.y) * P::Nblk, kidx); accumulate(); // on the previous k-block this->stsXY(); __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; + this->switch_write_buffer(); + this->switch_read_buffer(); } accumulate(); // last iteration } diff --git a/cpp/src/distance/neighbors/refine.cu b/cpp/src/distance/neighbors/refine.cu deleted file mode 100644 index 83e3383cba..0000000000 --- a/cpp/src/distance/neighbors/refine.cu +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::runtime::neighbors { - -#define RAFT_INST_REFINE(IDX_T, DATA_T) \ - void refine(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbor_candidates, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - distance::DistanceType metric) \ - { \ - raft::neighbors::detail::refine_device( \ - handle, dataset, queries, neighbor_candidates, indices, distances, metric); \ - } \ - \ - void refine(raft::device_resources const& handle, \ - raft::host_matrix_view dataset, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbor_candidates, \ - raft::host_matrix_view indices, \ - raft::host_matrix_view distances, \ - distance::DistanceType metric) \ - { \ - raft::neighbors::detail::refine_host( \ - dataset, queries, neighbor_candidates, indices, distances, metric); \ - } - -RAFT_INST_REFINE(uint64_t, float); -RAFT_INST_REFINE(uint64_t, uint8_t); -RAFT_INST_REFINE(uint64_t, int8_t); - -#undef RAFT_INST_REFINE - -} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu new file mode 100644 index 0000000000..d7b460180a --- /dev/null +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu new file mode 100644 index 0000000000..3db07f0cdb --- /dev/null +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..2ce43d5800 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu new file mode 100644 index 0000000000..2a2dcff3bf --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu new file mode 100644 index 0000000000..d7c60b62a5 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors { +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..e9c4345e97 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu new file mode 100644 index 0000000000..6bb1985d94 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu new file mode 100644 index 0000000000..7e70ee5e29 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..53de106ef9 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu new file mode 100644 index 0000000000..b473924741 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu new file mode 100644 index 0000000000..c8b0e4c1c2 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors { +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu new file mode 100644 index 0000000000..b9e0f58ef6 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index 98933046b9..a78f5cfe5c 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -31,7 +31,7 @@ #include -#if defined RAFT_NN_COMPILED +#if defined RAFT_DISTANCE_COMPILED #include #endif @@ -107,26 +107,26 @@ class RefineTest : public ::testing::TestWithParam> { RefineHelper data; }; -const std::vector> inputs = - raft::util::itertools::product>( - {137}, - {1000}, - {16}, - {1, 10, 33}, - {33}, +const std::vector> inputs = + raft::util::itertools::product>( + {static_cast(137)}, + {static_cast(1000)}, + {static_cast(16)}, + {static_cast(1), static_cast(10), static_cast(33)}, + {static_cast(33)}, {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false, true}); -typedef RefineTest RefineTestF; +typedef RefineTest RefineTestF; TEST_P(RefineTestF, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF, ::testing::ValuesIn(inputs)); -typedef RefineTest RefineTestF_uint8; +typedef RefineTest RefineTestF_uint8; TEST_P(RefineTestF_uint8, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_uint8, ::testing::ValuesIn(inputs)); -typedef RefineTest RefineTestF_int8; +typedef RefineTest RefineTestF_int8; TEST_P(RefineTestF_int8, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_int8, ::testing::ValuesIn(inputs)); } // namespace raft::neighbors