Skip to content

Commit

Permalink
Implement reviewer feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
ahendriksen authored and Allard Hendriksen committed Jan 24, 2023
1 parent 995d2ae commit e6976c5
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -59,6 +59,7 @@ namespace detail {
* @param core_op the core accumulation operation lambda
* @param epilog_op the epilog operation lambda
* @param fin_op the final gemm epilogue lambda
* @param rowEpilog_op epilog lambda that executes when a full row has been processed
*/

template <bool useNorms,
Expand Down Expand Up @@ -134,11 +135,13 @@ struct PairwiseDistances : public BaseClass {
for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) {
this->ldgXY(tile_idx_m, grid_offset_n, 0);
for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) {
// Prolog:
reset_accumulator();
this->stsXY();
__syncthreads();
this->switch_write_buffer();

// Main loop:
for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) {
this->ldgXY(tile_idx_m, tile_idx_n, kidx);
// Process all data in shared memory (previous k-block) and
Expand All @@ -150,12 +153,12 @@ struct PairwiseDistances : public BaseClass {
this->switch_read_buffer();
}
accumulate(); // last iteration
// This is needed for making sure next grid stride of
// non-norm based metrics uses previously accumulated buffer so
// it doesn't make shmem dirty until previous iteration
// is complete.
// The pre-condition for the loop over tile_idx_n is that write_buffer
// and read_buffer point to the same buffer. This flips read_buffer back
// so that it satisfies the pre-condition of this loop.
this->switch_read_buffer();

// Epilog:
if (useNorms) {
DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
load_norms(tile_idx_m, tile_idx_n, regxn, regyn);
Expand Down

0 comments on commit e6976c5

Please sign in to comment.