Skip to content

Commit

Permalink
pairwise_distance_base: Move all logic into run loop
Browse files Browse the repository at this point in the history
This commit removes the epilog function and moves its functionality into
the run loop. The next step might be to see if the ldgNextGridStride()
method has to be called at the current location, or if performance is the
same if it's called at the start of the loop.
  • Loading branch information
ahendriksen committed Sep 22, 2022
1 parent e118560 commit b5aef79
Showing 1 changed file with 57 additions and 71 deletions.
128 changes: 57 additions & 71 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ struct PairwiseDistances : public BaseClass {
FinalLambda fin_op;
rowEpilogueLambda rowEpilog_op;


const IdxT grid_stride_m;
const IdxT grid_stride_n;
const IdxT grid_offset_m;
Expand Down Expand Up @@ -140,14 +139,14 @@ struct PairwiseDistances : public BaseClass {
this->switch_write_buffer();

for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) {
this->ldgXY(tile_idx_m, tile_idx_n, kidx);
// Process all data in shared memory (previous k-block) and
// accumulate in registers.
accumulate();
this->stsXY();
__syncthreads();
this->switch_write_buffer();
this->switch_read_buffer();
this->ldgXY(tile_idx_m, tile_idx_n, kidx);
// Process all data in shared memory (previous k-block) and
// accumulate in registers.
accumulate();
this->stsXY();
__syncthreads();
this->switch_write_buffer();
this->switch_read_buffer();
}
accumulate(); // last iteration
// This is needed for making sure next grid stride of
Expand All @@ -156,14 +155,25 @@ struct PairwiseDistances : public BaseClass {
// is complete.
this->switch_read_buffer();

epilog(tile_idx_n, tile_idx_m);
if (useNorms) {
DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
load_norms(tile_idx_m, tile_idx_n, regxn, regyn);
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
}
if (writeOut) { store_output(tile_idx_m, tile_idx_n); }
}
rowEpilog_op(tile_idx_m);
}
}

private:
DI void ldgNextGridStride(IdxT tile_idx_n, IdxT tile_idx_m)
DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n)
{
// Fetch next grid stride ldg if within range
const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n;
Expand All @@ -175,24 +185,8 @@ struct PairwiseDistances : public BaseClass {
}
}

DI void prolog(IdxT tile_idx_n, IdxT tile_idx_m)
DI void reset_accumulator()
{
if (tile_idx_n == blockIdx.x * P::Nblk) { this->ldgXY(0); }

#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
acc[i][j] = BaseClass::Zero;
}
}

this->stsXY();
__syncthreads();
this->switch_write_buffer();
}

DI void reset_accumulator() {
// Reset accumulator registers to zero.
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
Expand Down Expand Up @@ -221,60 +215,52 @@ struct PairwiseDistances : public BaseClass {
}
}

DI void epilog(IdxT tile_idx_n, IdxT tile_idx_m)
DI void load_norms(IdxT tile_idx_m,
IdxT tile_idx_n,
DataT (&regxn)[P::AccRowsPerTh],
DataT (&regyn)[P::AccColsPerTh])
{
if (useNorms) {
DataT* sxNorm = (DataT*)(&smem[P::SmemSize]);
DataT* syNorm = (&sxNorm[P::Mblk]);

// Load x & y norms required by this threadblock in shmem buffer
if (tile_idx_n == blockIdx.x * P::Nblk) {
for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) {
auto idx = tile_idx_m + i;
sxNorm[i] = idx < this->m ? xn[idx] : 0;
}
}

for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) {
auto idx = tile_idx_n + i;
syNorm[i] = idx < this->n ? yn[idx] : 0;
DataT* sxNorm = (DataT*)(&smem[P::SmemSize]);
DataT* syNorm = (&sxNorm[P::Mblk]);

// Load x & y norms required by this threadblock in shmem buffer
if (tile_idx_n == blockIdx.x * P::Nblk) {
for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) {
auto idx = tile_idx_m + i;
sxNorm[i] = idx < this->m ? xn[idx] : 0;
}
}

__syncthreads();
for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) {
auto idx = tile_idx_n + i;
syNorm[i] = idx < this->n ? yn[idx] : 0;
}
__syncthreads();

DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)];
}
for (int i = 0; i < P::AccRowsPerTh; ++i) {
regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)];
}
#pragma unroll
for (int i = 0; i < P::AccColsPerTh; ++i) {
regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)];
}

// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_n, tile_idx_m);
epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_n, tile_idx_m);
epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
for (int i = 0; i < P::AccColsPerTh; ++i) {
regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)];
}
}

if (writeOut) {
IdxT starty = tile_idx_m + this->accrowid;
IdxT startx = tile_idx_n + this->acccolid;
DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n)
{
IdxT starty = tile_idx_m + this->accrowid;
IdxT startx = tile_idx_n + this->acccolid;

#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto rowId = starty + i * P::AccThRows;
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto rowId = starty + i * P::AccThRows;
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
auto colId = startx + j * P::AccThCols;
if (rowId < this->m && colId < this->n) {
// Promote to 64 bit index for final write, as output array can be > 2^31
dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0);
}
for (int j = 0; j < P::AccColsPerTh; ++j) {
auto colId = startx + j * P::AccThCols;
if (rowId < this->m && colId < this->n) {
// Promote to 64 bit index for final write, as output array can be > 2^31
dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0);
}
}
}
Expand Down

0 comments on commit b5aef79

Please sign in to comment.