Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move contractions tiling logic outside of Contractions_NT #837

Merged
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
a3b8587
contractions: Concentrate tile index calculations
ahendriksen Sep 2, 2022
99e65a5
pairwise_distance_base: Remove all ldgXY(0) calls
ahendriksen Sep 2, 2022
e6d5078
pairwise_distance_base: Move all logic into run loop
ahendriksen Sep 2, 2022
995d2ae
pairwise_distance_base: Fix typo
ahendriksen Oct 5, 2022
e6976c5
Implement reviewer feedback
ahendriksen Jan 23, 2023
4947dc8
Merge branch 'branch-23.02' into wip-move-contractions-tiling-logic
cjnolet Jan 25, 2023
ba6491a
Merge branch 'branch-23.02' into wip-move-contractions-tiling-logic
cjnolet Jan 25, 2023
e52b0f9
Forcing sccache reinit.
cjnolet Jan 26, 2023
34eb76a
Merge branch 'branch-23.02' into wip-move-contractions-tiling-logic
cjnolet Jan 26, 2023
85c6294
Breaking specializations for refine into individual files
cjnolet Jan 26, 2023
0fad842
Checking in
cjnolet Jan 26, 2023
f7788af
Including just the refine specialization
cjnolet Jan 26, 2023
e626101
Merge branch 'branch-23.02' into wip-move-contractions-tiling-logic
cjnolet Jan 26, 2023
9e7b729
Proper import of speicalizations
cjnolet Jan 26, 2023
9e4b5f3
Merge branch 'wip-move-contractions-tiling-logic' of github.com:ahend…
cjnolet Jan 26, 2023
060e62c
Remove SCCACHE_RECACHE from build.sh
cjnolet Jan 26, 2023
2370c18
Fixing build errro
cjnolet Jan 26, 2023
d0a5ea4
Fixing remaining compile errors
cjnolet Jan 26, 2023
6291723
Adding specializations to cmakelists
cjnolet Jan 26, 2023
e3ea7ed
Rename distance_inputs to distance_params
ahendriksen Jan 27, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 96 additions & 111 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -59,6 +59,7 @@ namespace detail {
* @param core_op the core accumulation operation lambda
* @param epilog_op the epilog operation lambda
* @param fin_op the final gemm epilogue lambda
* @param rowEpilog_op epilog lambda that executes when a full row has been processed
*/

template <bool useNorms,
Expand Down Expand Up @@ -87,6 +88,11 @@ struct PairwiseDistances : public BaseClass {
FinalLambda fin_op;
rowEpilogueLambda rowEpilog_op;
ahendriksen marked this conversation as resolved.
Show resolved Hide resolved

/** step along m between successive tiles visited by this block: P::Mblk * gridDim.y */
const IdxT grid_stride_m;
/** step along n between successive tiles visited by this block: P::Nblk * gridDim.x */
const IdxT grid_stride_n;
/** m index of the first tile visited by this block: P::Mblk * blockIdx.y */
const IdxT grid_offset_m;
/** n index of the first tile visited by this block: P::Nblk * blockIdx.x */
const IdxT grid_offset_n;

AccT acc[P::AccRowsPerTh][P::AccColsPerTh];

public:
Expand Down Expand Up @@ -116,96 +122,83 @@ struct PairwiseDistances : public BaseClass {
core_op(_core_op),
epilog_op(_epilog_op),
fin_op(_fin_op),
rowEpilog_op(_rowEpilog_op)
rowEpilog_op(_rowEpilog_op),
grid_stride_m(P::Mblk * gridDim.y),
grid_stride_n(P::Nblk * gridDim.x),
grid_offset_m(P::Mblk * blockIdx.y),
grid_offset_n(P::Nblk * blockIdx.x)
{
}

DI void run()
{
for (auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m;
gridStrideY += P::Mblk * gridDim.y) {
for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n;
gridStrideX += P::Nblk * gridDim.x) {
prolog(gridStrideX, gridStrideY);
loop();
epilog(gridStrideX, gridStrideY);
for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) {
this->ldgXY(tile_idx_m, grid_offset_n, 0);
for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) {
// Prolog:
reset_accumulator();
ahendriksen marked this conversation as resolved.
Show resolved Hide resolved
this->stsXY();
__syncthreads();
this->switch_write_buffer();

// Main loop:
for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) {
this->ldgXY(tile_idx_m, tile_idx_n, kidx);
// Process all data in shared memory (previous k-block) and
// accumulate in registers.
accumulate();
this->stsXY();
__syncthreads();
this->switch_write_buffer();
this->switch_read_buffer();
}
accumulate(); // last iteration
// The pre-condition for the loop over tile_idx_n is that write_buffer
// and read_buffer point to the same buffer. This flips read_buffer back
// so that it satisfies the pre-condition of this loop.
this->switch_read_buffer();

// Epilog:
if (useNorms) {
DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
load_norms(tile_idx_m, tile_idx_n, regxn, regyn);
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
}
if (writeOut) { store_output(tile_idx_m, tile_idx_n); }
}
rowEpilog_op(gridStrideY);
rowEpilog_op(tile_idx_m);
}
}

private:
DI void updateIndicesY()
{
const auto stride = P::Nblk * gridDim.x;
if (isRowMajor) {
this->y += stride * this->ldb;
} else {
this->y += stride;
}
this->yrowid += stride;
}

DI void updateIndicesXY()
{
const auto stride = P::Mblk * gridDim.y;
if (isRowMajor) {
this->x += stride * this->lda;
this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid;
this->y = yBase + this->yrowid * this->ldb;
} else {
this->x += stride;
this->yrowid = IdxT(blockIdx.x) * P::Nblk;
this->y = yBase + this->yrowid + this->srowid * this->ldb;
}
this->xrowid += stride;
}

/**
* @brief Issue the k-offset-0 ldgXY for the tile this block will visit next,
*        so that the load can overlap with the epilog computation at the call site.
*
* In the tile-index form below, the next tile is the one grid_stride_n further
* along n if that is still below n; otherwise it is the tile grid_stride_m
* further along m, restarting at grid_offset_n, if that is still below m.
* If neither is in range, nothing is loaded.
*
* @param[in] tile_idx_m m index of the tile currently being processed
* @param[in] tile_idx_n n index of the tile currently being processed
*/
DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY)
DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n)
{
// Fetch next grid stride ldg if within range
if ((gridStrideX + gridDim.x * P::Nblk) < this->n) {
updateIndicesY();
this->ldgXY(0);
} else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) {
updateIndicesXY();
this->ldgXY(0);
const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n;
const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m;
if ((next_tile_tile_idx_n) < this->n) {
// Same tile row, next tile along n.
this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0);
} else if ((next_tile_tile_idx_m) < this->m) {
// Row of tiles exhausted: next tile row along m, back at this block's first n tile.
this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0);
}
}

DI void prolog(IdxT gridStrideX, IdxT gridStrideY)
DI void reset_accumulator()
{
if (gridStrideX == blockIdx.x * P::Nblk) { this->ldgXY(0); }

// Reset accumulator registers to zero.
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
acc[i][j] = BaseClass::Zero;
}
}

this->stsXY();
__syncthreads();
this->pageWr ^= 1;
}

DI void loop()
{
for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) {
this->ldgXY(kidx);
accumulate(); // on the previous k-block
this->stsXY();
__syncthreads();
this->pageWr ^= 1;
this->pageRd ^= 1;
}
accumulate(); // last iteration
// This is needed for making sure next grid stride of
// non-norm based metrics uses previously accumulated buffer so
// it doesn't make shmem dirty until previous iteration
// is complete.
this->pageRd ^= 1;
}

DI void accumulate()
Expand All @@ -226,60 +219,52 @@ struct PairwiseDistances : public BaseClass {
}
}

DI void epilog(IdxT gridStrideX, IdxT gridStrideY)
DI void load_norms(IdxT tile_idx_m,
IdxT tile_idx_n,
DataT (&regxn)[P::AccRowsPerTh],
DataT (&regyn)[P::AccColsPerTh])
{
if (useNorms) {
DataT* sxNorm = (DataT*)(&smem[P::SmemSize]);
DataT* syNorm = (&sxNorm[P::Mblk]);

// Load x & y norms required by this threadblock in shmem buffer
if (gridStrideX == blockIdx.x * P::Nblk) {
for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) {
auto idx = gridStrideY + i;
sxNorm[i] = idx < this->m ? xn[idx] : 0;
}
}

for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) {
auto idx = gridStrideX + i;
syNorm[i] = idx < this->n ? yn[idx] : 0;
DataT* sxNorm = (DataT*)(&smem[P::SmemSize]);
DataT* syNorm = (&sxNorm[P::Mblk]);

// Load x & y norms required by this threadblock in shmem buffer
if (tile_idx_n == blockIdx.x * P::Nblk) {
for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) {
auto idx = tile_idx_m + i;
sxNorm[i] = idx < this->m ? xn[idx] : 0;
}
}

__syncthreads();
for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) {
auto idx = tile_idx_n + i;
syNorm[i] = idx < this->n ? yn[idx] : 0;
}
__syncthreads();

DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)];
}
for (int i = 0; i < P::AccRowsPerTh; ++i) {
regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)];
}
#pragma unroll
for (int i = 0; i < P::AccColsPerTh; ++i) {
regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)];
}

// Overlap ldg with epilog computation
ldgNextGridStride(gridStrideX, gridStrideY);
epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(gridStrideX, gridStrideY);
epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY);
for (int i = 0; i < P::AccColsPerTh; ++i) {
regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)];
}
}

if (writeOut) {
IdxT starty = gridStrideY + this->accrowid;
IdxT startx = gridStrideX + this->acccolid;
DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n)
{
IdxT starty = tile_idx_m + this->accrowid;
IdxT startx = tile_idx_n + this->acccolid;

#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto rowId = starty + i * P::AccThRows;
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto rowId = starty + i * P::AccThRows;
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
auto colId = startx + j * P::AccThCols;
if (rowId < this->m && colId < this->n) {
// Promote to 64 bit index for final write, as output array can be > 2^31
dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0);
}
for (int j = 0; j < P::AccColsPerTh; ++j) {
auto colId = startx + j * P::AccThCols;
if (rowId < this->m && colId < this->n) {
// Promote to 64 bit index for final write, as output array can be > 2^31
dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0);
}
}
}
Expand Down
49 changes: 22 additions & 27 deletions cpp/include/raft/linalg/detail/contractions.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,10 @@ struct Contractions_NT {
/** leading dimension in Output D */
IdxT ldd;

/** current thread's global mem row id for X data */
IdxT xrowid;
/** current thread's global mem row id for Y data */
IdxT yrowid;
/** global memory pointer to X matrix */
const DataT* x;
const DataT* x_base;
/** global memory pointer to Y matrix */
const DataT* y;
const DataT* y_base;

/** current thread's smem row id */
int srowid;
Expand Down Expand Up @@ -94,10 +90,8 @@ struct Contractions_NT {
k(_k),
lda(_k),
ldb(_k),
xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThRow),
yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThRow),
x(_x + xrowid * lda),
y(_y + yrowid * ldb),
x_base(_x),
y_base(_y),
srowid(threadIdx.x / P::LdgThRow),
scolid((threadIdx.x % P::LdgThRow) * P::Veclen),
accrowid(threadIdx.x / P::AccThCols),
Expand Down Expand Up @@ -133,6 +127,8 @@ struct Contractions_NT {
lda(_lda),
ldb(_ldb),
ldd(_ldd),
x_base(_x),
y_base(_y),
srowid(threadIdx.x / P::LdgThRow),
scolid((threadIdx.x % P::LdgThRow) * P::Veclen),
accrowid(threadIdx.x / P::AccThCols),
Expand All @@ -142,28 +138,17 @@ struct Contractions_NT {
pageWr(0),
pageRd(0)
{
if (isRowMajor) {
xrowid = IdxT(blockIdx.y) * P::Mblk + srowid;
yrowid = IdxT(blockIdx.x) * P::Nblk + srowid;
x = _x + xrowid * lda;
y = _y + yrowid * ldb;
} else {
xrowid = IdxT(blockIdx.y) * P::Mblk;
yrowid = IdxT(blockIdx.x) * P::Nblk;
x = _x + xrowid + srowid * lda;
y = _y + yrowid + srowid * ldb;
}
}

protected:
/**
* @brief Load current block of X/Y from global memory to registers
* @param[in] tile_idx_m offset along m of the X tile to load (forwarded to ldgX)
* @param[in] tile_idx_n offset along n of the Y tile to load (forwarded to ldgY)
* @param[in] kidx current start index of k to be loaded
*/
DI void ldgXY(IdxT kidx)
DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx)
{
ldgX(kidx);
ldgY(kidx);
ldgX(tile_idx_m, kidx);
ldgY(tile_idx_n, kidx);
}

/**
Expand All @@ -186,9 +171,16 @@ struct Contractions_NT {
ldsY(kidx, sy + pageRd * P::SmemPage);
}

/** Flip the low bit of pageRd (initialised to 0); pageRd scales P::SmemPage to pick the shared-memory page handed to ldsY. */
DI void switch_read_buffer() { this->pageRd ^= 1; }

/**
* Flip the low bit of pageWr (initialised to 0).
* NOTE(review): presumably pageWr selects the page that stsXY writes to - stsXY is not visible here; confirm.
*/
DI void switch_write_buffer() { this->pageWr ^= 1; }

private:
DI void ldgX(IdxT kidx)
DI void ldgX(IdxT tile_idx_m, IdxT kidx)
{
IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m;
auto x = isRowMajor ? x_base + xrowid * lda : x_base + xrowid + srowid * lda;

if (isRowMajor) {
auto numRows = m;
auto koffset = kidx + scolid;
Expand Down Expand Up @@ -220,8 +212,11 @@ struct Contractions_NT {
}
}

DI void ldgY(IdxT kidx)
DI void ldgY(IdxT tile_idx_n, IdxT kidx)
{
IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n;
auto y = isRowMajor ? y_base + yrowid * ldb : y_base + yrowid + srowid * ldb;

if (isRowMajor) {
auto numRows = n;
auto koffset = kidx + scolid;
Expand Down Expand Up @@ -315,4 +310,4 @@ struct Contractions_NT {

} // namespace detail
} // namespace linalg
} // namespace raft
} // namespace raft
Loading