diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index b28c3a3de4..d849b23999 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,7 @@ namespace detail { * @param core_op the core accumulation operation lambda * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda + * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ template m; tile_idx_m += grid_stride_m) { this->ldgXY(tile_idx_m, grid_offset_n, 0); for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { + // Prolog: reset_accumulator(); this->stsXY(); __syncthreads(); this->switch_write_buffer(); + // Main loop: for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { this->ldgXY(tile_idx_m, tile_idx_n, kidx); // Process all data in shared memory (previous k-block) and @@ -150,12 +153,12 @@ struct PairwiseDistances : public BaseClass { this->switch_read_buffer(); } accumulate(); // last iteration - // This is needed for making sure next grid stride of - // non-norm based metrics uses previously accumulated buffer so - // it doesn't make shmem dirty until previous iteration - // is complete. + // The pre-condition for the loop over tile_idx_n is that write_buffer + // and read_buffer point to the same buffer. This flips read_buffer back + // so that it satisfies the pre-condition of this loop. this->switch_read_buffer(); + // Epilog: if (useNorms) { DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; load_norms(tile_idx_m, tile_idx_n, regxn, regyn);