
Commit

Revert "Move contractions tiling logic outside of Contractions_NT (ra…
Browse files Browse the repository at this point in the history
…pidsai#837)"

This reverts commit c58d00a.
cjnolet committed Feb 2, 2023
1 parent bff0f16 commit 776fc19
Showing 23 changed files with 230 additions and 605 deletions.
15 changes: 2 additions & 13 deletions cpp/CMakeLists.txt
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
@@ -293,18 +293,7 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/cluster/update_centroids_double.cu
src/distance/cluster/cluster_cost_float.cu
src/distance/cluster/cluster_cost_double.cu
src/distance/neighbors/refine_d_uint64_t_float.cu
src/distance/neighbors/refine_d_uint64_t_int8_t.cu
src/distance/neighbors/refine_d_uint64_t_uint8_t.cu
src/distance/neighbors/refine_h_uint64_t_float.cu
src/distance/neighbors/refine_h_uint64_t_int8_t.cu
src/distance/neighbors/refine_h_uint64_t_uint8_t.cu
src/distance/neighbors/specializations/refine_d_uint64_t_float.cu
src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu
src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu
src/distance/neighbors/specializations/refine_h_uint64_t_float.cu
src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu
src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu
src/distance/neighbors/refine.cu
src/distance/neighbors/ivfpq_search.cu
src/distance/cluster/kmeans_fit_float.cu
src/distance/cluster/kmeans_fit_double.cu
10 changes: 5 additions & 5 deletions cpp/bench/distance/distance_common.cuh
@@ -24,14 +24,14 @@

namespace raft::bench::distance {

struct distance_params {
struct distance_inputs {
int m, n, k;
bool isRowMajor;
}; // struct distance_params
}; // struct distance_inputs

template <typename T, raft::distance::DistanceType DType>
struct distance : public fixture {
distance(const distance_params& p)
distance(const distance_inputs& p)
: params(p),
x(p.m * p.k, stream),
y(p.n * p.k, stream),
@@ -63,13 +63,13 @@ struct distance : public fixture {
}

private:
distance_params params;
distance_inputs params;
rmm::device_uvector<T> x, y, out;
rmm::device_uvector<char> workspace;
size_t worksize;
}; // struct Distance

const std::vector<distance_params> dist_input_vecs{
const std::vector<distance_inputs> dist_input_vecs{
{32, 16384, 16384, true}, {64, 16384, 16384, true}, {128, 16384, 16384, true},
{256, 16384, 16384, true}, {512, 16384, 16384, true}, {1024, 16384, 16384, true},
{16384, 32, 16384, true}, {16384, 64, 16384, true}, {16384, 128, 16384, true},
3 changes: 0 additions & 3 deletions cpp/bench/neighbors/knn.cuh
@@ -32,9 +32,6 @@
#include <raft/spatial/knn/specializations.cuh>
#if defined RAFT_DISTANCE_COMPILED
#include <raft/cluster/specializations.cuh>
#include <raft/neighbors/specializations.cuh>
#else
#pragma message("NN / Distance specializations are not enabled; expect very long building times.")
#endif
#endif

23 changes: 11 additions & 12 deletions cpp/bench/neighbors/refine.cu
@@ -27,7 +27,6 @@

#if defined RAFT_DISTANCE_COMPILED
#include <raft/distance/specializations.cuh>
#include <raft/neighbors/specializations.cuh>
#endif

#if defined RAFT_NN_COMPILED
@@ -53,7 +52,7 @@ inline auto operator<<(std::ostream& os, const RefineInputs<IdxT>& p) -> std::os
return os;
}

RefineInputs<uint64_t> p;
RefineInputs<int64_t> p;

template <typename DataT, typename DistanceT, typename IdxT>
class RefineAnn : public fixture {
@@ -99,24 +98,24 @@ class RefineAnn : public fixture {
RefineHelper<DataT, DistanceT, IdxT> data;
};

std::vector<RefineInputs<uint64_t>> getInputs()
std::vector<RefineInputs<int64_t>> getInputs()
{
std::vector<RefineInputs<uint64_t>> out;
std::vector<RefineInputs<int64_t>> out;
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded;
for (bool host_data : {true, false}) {
for (uint64_t n_queries : {1000, 10000}) {
for (uint64_t dim : {128, 512}) {
out.push_back(RefineInputs<uint64_t>{n_queries, 2000000, dim, 32, 128, metric, host_data});
out.push_back(RefineInputs<uint64_t>{n_queries, 2000000, dim, 10, 40, metric, host_data});
for (int64_t n_queries : {1000, 10000}) {
for (int64_t dim : {128, 512}) {
out.push_back(RefineInputs<int64_t>{n_queries, 2000000, dim, 32, 128, metric, host_data});
out.push_back(RefineInputs<int64_t>{n_queries, 2000000, dim, 10, 40, metric, host_data});
}
}
}
return out;
}

using refine_float_uint64 = RefineAnn<float, float, uint64_t>;
RAFT_BENCH_REGISTER(refine_float_uint64, "", getInputs());
using refine_float_int64 = RefineAnn<float, float, int64_t>;
RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs());

using refine_uint8_uint64 = RefineAnn<uint8_t, float, uint64_t>;
RAFT_BENCH_REGISTER(refine_uint8_uint64, "", getInputs());
using refine_uint8_int64 = RefineAnn<uint8_t, float, int64_t>;
RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs());
} // namespace raft::bench::neighbors
207 changes: 111 additions & 96 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -59,7 +59,6 @@ namespace detail {
* @param core_op the core accumulation operation lambda
* @param epilog_op the epilog operation lambda
* @param fin_op the final gemm epilogue lambda
* @param rowEpilog_op epilog lambda that executes when a full row has been processed
*/

template <bool useNorms,
@@ -88,11 +87,6 @@ struct PairwiseDistances : public BaseClass {
FinalLambda fin_op;
rowEpilogueLambda rowEpilog_op;

const IdxT grid_stride_m;
const IdxT grid_stride_n;
const IdxT grid_offset_m;
const IdxT grid_offset_n;

AccT acc[P::AccRowsPerTh][P::AccColsPerTh];

public:
@@ -122,83 +116,96 @@ struct PairwiseDistances : public BaseClass {
core_op(_core_op),
epilog_op(_epilog_op),
fin_op(_fin_op),
rowEpilog_op(_rowEpilog_op),
grid_stride_m(P::Mblk * gridDim.y),
grid_stride_n(P::Nblk * gridDim.x),
grid_offset_m(P::Mblk * blockIdx.y),
grid_offset_n(P::Nblk * blockIdx.x)
rowEpilog_op(_rowEpilog_op)
{
}

DI void run()
{
for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) {
this->ldgXY(tile_idx_m, grid_offset_n, 0);
for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) {
// Prolog:
reset_accumulator();
this->stsXY();
__syncthreads();
this->switch_write_buffer();

// Main loop:
for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) {
this->ldgXY(tile_idx_m, tile_idx_n, kidx);
// Process all data in shared memory (previous k-block) and
// accumulate in registers.
accumulate();
this->stsXY();
__syncthreads();
this->switch_write_buffer();
this->switch_read_buffer();
}
accumulate(); // last iteration
// The pre-condition for the loop over tile_idx_n is that write_buffer
// and read_buffer point to the same buffer. This flips read_buffer back
// so that it satisfies the pre-condition of this loop.
this->switch_read_buffer();

// Epilog:
if (useNorms) {
DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
load_norms(tile_idx_m, tile_idx_n, regxn, regyn);
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(tile_idx_m, tile_idx_n);
epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m);
}
if (writeOut) { store_output(tile_idx_m, tile_idx_n); }
for (auto gridStrideY = blockIdx.y * P::Mblk; gridStrideY < this->m;
gridStrideY += P::Mblk * gridDim.y) {
for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n;
gridStrideX += P::Nblk * gridDim.x) {
prolog(gridStrideX, gridStrideY);
loop();
epilog(gridStrideX, gridStrideY);
}
rowEpilog_op(tile_idx_m);
rowEpilog_op(gridStrideY);
}
}

private:
DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n)
DI void updateIndicesY()
{
const auto stride = P::Nblk * gridDim.x;
if (isRowMajor) {
this->y += stride * this->ldb;
} else {
this->y += stride;
}
this->yrowid += stride;
}

DI void updateIndicesXY()
{
const auto stride = P::Mblk * gridDim.y;
if (isRowMajor) {
this->x += stride * this->lda;
this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid;
this->y = yBase + this->yrowid * this->ldb;
} else {
this->x += stride;
this->yrowid = IdxT(blockIdx.x) * P::Nblk;
this->y = yBase + this->yrowid + this->srowid * this->ldb;
}
this->xrowid += stride;
}

DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY)
{
// Fetch next grid stride ldg if within range
const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n;
const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m;
if ((next_tile_tile_idx_n) < this->n) {
this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0);
} else if ((next_tile_tile_idx_m) < this->m) {
this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0);
if ((gridStrideX + gridDim.x * P::Nblk) < this->n) {
updateIndicesY();
this->ldgXY(0);
} else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) {
updateIndicesXY();
this->ldgXY(0);
}
}

DI void reset_accumulator()
DI void prolog(IdxT gridStrideX, IdxT gridStrideY)
{
// Reset accumulator registers to zero.
if (gridStrideX == blockIdx.x * P::Nblk) { this->ldgXY(0); }

#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
acc[i][j] = BaseClass::Zero;
}
}

this->stsXY();
__syncthreads();
this->pageWr ^= 1;
}

DI void loop()
{
for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) {
this->ldgXY(kidx);
accumulate(); // on the previous k-block
this->stsXY();
__syncthreads();
this->pageWr ^= 1;
this->pageRd ^= 1;
}
accumulate(); // last iteration
// This is needed for making sure next grid stride of
// non-norm based metrics uses previously accumulated buffer so
// it doesn't make shmem dirty until previous iteration
// is complete.
this->pageRd ^= 1;
}

DI void accumulate()
@@ -219,52 +226,60 @@ struct PairwiseDistances : public BaseClass {
}
}

DI void load_norms(IdxT tile_idx_m,
IdxT tile_idx_n,
DataT (&regxn)[P::AccRowsPerTh],
DataT (&regyn)[P::AccColsPerTh])
DI void epilog(IdxT gridStrideX, IdxT gridStrideY)
{
DataT* sxNorm = (DataT*)(&smem[P::SmemSize]);
DataT* syNorm = (&sxNorm[P::Mblk]);

// Load x & y norms required by this threadblock in shmem buffer
if (tile_idx_n == blockIdx.x * P::Nblk) {
for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) {
auto idx = tile_idx_m + i;
sxNorm[i] = idx < this->m ? xn[idx] : 0;
if (useNorms) {
DataT* sxNorm = (DataT*)(&smem[P::SmemSize]);
DataT* syNorm = (&sxNorm[P::Mblk]);

// Load x & y norms required by this threadblock in shmem buffer
if (gridStrideX == blockIdx.x * P::Nblk) {
for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) {
auto idx = gridStrideY + i;
sxNorm[i] = idx < this->m ? xn[idx] : 0;
}
}
}

for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) {
auto idx = tile_idx_n + i;
syNorm[i] = idx < this->n ? yn[idx] : 0;
}
__syncthreads();
for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) {
auto idx = gridStrideX + i;
syNorm[i] = idx < this->n ? yn[idx] : 0;
}

__syncthreads();

DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh];
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)];
}
for (int i = 0; i < P::AccRowsPerTh; ++i) {
regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)];
}
#pragma unroll
for (int i = 0; i < P::AccColsPerTh; ++i) {
regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)];
for (int i = 0; i < P::AccColsPerTh; ++i) {
regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)];
}

// Overlap ldg with epilog computation
ldgNextGridStride(gridStrideX, gridStrideY);
epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY);
} else {
// Overlap ldg with epilog computation
ldgNextGridStride(gridStrideX, gridStrideY);
epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY);
}
}

DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n)
{
IdxT starty = tile_idx_m + this->accrowid;
IdxT startx = tile_idx_n + this->acccolid;
if (writeOut) {
IdxT starty = gridStrideY + this->accrowid;
IdxT startx = gridStrideX + this->acccolid;

#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto rowId = starty + i * P::AccThRows;
for (int i = 0; i < P::AccRowsPerTh; ++i) {
auto rowId = starty + i * P::AccThRows;
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
auto colId = startx + j * P::AccThCols;
if (rowId < this->m && colId < this->n) {
// Promote to 64 bit index for final write, as output array can be > 2^31
dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0);
for (int j = 0; j < P::AccColsPerTh; ++j) {
auto colId = startx + j * P::AccThCols;
if (rowId < this->m && colId < this->n) {
// Promote to 64 bit index for final write, as output array can be > 2^31
dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0);
}
}
}
}
(remaining file diffs not shown)
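
For context, both the reverted and restored versions of PairwiseDistances::run() walk the m x n output in grid-stride tiles: each thread block starts at the tile addressed by its block index and strides by the full grid size until the matrix is covered. The CUDA sketch below is a minimal illustration of that traversal only, under stated assumptions; Mblk, Nblk, and process_tile are placeholder names, not RAFT identifiers, and the per-tile prolog, k-loop, and epilog are stubbed out. The two versions above differ mainly in how the per-tile state is carried (precomputed grid_stride_*/grid_offset_* members and explicit tile indices versus gridStrideX/gridStrideY locals with prolog()/loop()/epilog() helpers), not in the traversal pattern itself.

// Minimal illustrative sketch (not RAFT code): the grid-stride tile traversal
// used by PairwiseDistances::run() in both versions of the file above.
// Mblk, Nblk, and process_tile are placeholder names, not RAFT identifiers.
#include <cuda_runtime.h>

constexpr int Mblk = 64;  // rows of the output covered by one tile
constexpr int Nblk = 64;  // cols of the output covered by one tile

__device__ void process_tile(int /*tile_m*/, int /*tile_n*/, int /*m*/, int /*n*/)
{
  // Stand-in for the real per-tile work: prolog (load + reset accumulators),
  // the k-loop over shared-memory pages, and the epilog that writes results.
}

__global__ void pairwise_tiles(int m, int n)
{
  // Each block starts at the tile given by its block index and strides by the
  // whole grid, so a fixed-size grid covers any m x n output.
  for (int tile_m = blockIdx.y * Mblk; tile_m < m; tile_m += Mblk * gridDim.y) {
    for (int tile_n = blockIdx.x * Nblk; tile_n < n; tile_n += Nblk * gridDim.x) {
      process_tile(tile_m, tile_n, m, n);
    }
    // A row epilog (cf. rowEpilog_op) would run here, once per row of tiles.
  }
}

int main()
{
  dim3 grid(4, 4), block(128);
  pairwise_tiles<<<grid, block>>>(1000, 1000);
  cudaDeviceSynchronize();
  return 0;
}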
