[WIP] Custom fusedL2NN kernel for kmeans prediction #2050

Closed

Changes from all commits
82 changes: 49 additions & 33 deletions cpp/include/raft/cluster/detail/kmeans_balanced.cuh
@@ -98,43 +98,59 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
switch (params.metric) {
case raft::distance::DistanceType::L2Expanded:
case raft::distance::DistanceType::L2SqrtExpanded: {
auto workspace = raft::make_device_mdarray<char, IdxT>(
handle, mr, make_extents<IdxT>((sizeof(int)) * n_rows));

auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
handle, mr, make_extents<IdxT>(n_rows));
raft::KeyValuePair<IdxT, MathT> initial_value(0, std::numeric_limits<MathT>::max());
thrust::fill(resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
initial_value);

auto centroidsNorm =
raft::make_device_mdarray<MathT, IdxT>(handle, mr, make_extents<IdxT>(n_clusters));
raft::linalg::rowNorm<MathT, IdxT>(
centroidsNorm.data_handle(), centers, dim, n_clusters, raft::linalg::L2Norm, true, stream);

raft::distance::fusedL2NNMinReduce<MathT, raft::KeyValuePair<IdxT, MathT>, IdxT>(
minClusterAndDistance.data_handle(),
dataset,
centers,
dataset_norm,
centroidsNorm.data_handle(),
n_rows,
n_clusters,
dim,
(void*)workspace.data_handle(),
(params.metric == raft::distance::DistanceType::L2Expanded) ? false : true,
false,
stream);

// todo(lsugy): use KVP + iterator in caller.
// Copy keys to output labels
thrust::transform(resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + n_rows,
labels,
raft::compose_op<raft::cast_op<LabelT>, raft::key_op>());
// Use the custom fusedL2NN kernel if both n_clusters and dim are relatively small.
bool use_custom_fusedL2NN_kernel = n_clusters * dim <= 256;
// TODO: unify the output types of fusedL2NNMinReduceCustomKernel and fusedL2NNMinReduce
if (use_custom_fusedL2NN_kernel) {
raft::distance::fusedL2NNMinReduceCustomKernel<MathT, LabelT, IdxT>(
labels,
dataset,
centers,
dataset_norm,
centroidsNorm.data_handle(),
n_rows,
n_clusters,
dim,
(params.metric == raft::distance::DistanceType::L2Expanded) ? false : true,
stream);
} else {
auto workspace = raft::make_device_mdarray<char, IdxT>(
handle, mr, make_extents<IdxT>((sizeof(int)) * n_rows));

auto minClusterAndDistance = raft::make_device_mdarray<raft::KeyValuePair<IdxT, MathT>, IdxT>(
handle, mr, make_extents<IdxT>(n_rows));
raft::KeyValuePair<IdxT, MathT> initial_value(0, std::numeric_limits<MathT>::max());
thrust::fill(resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + minClusterAndDistance.size(),
initial_value);

raft::distance::fusedL2NNMinReduce<MathT, raft::KeyValuePair<IdxT, MathT>, IdxT>(
minClusterAndDistance.data_handle(),
dataset,
centers,
dataset_norm,
centroidsNorm.data_handle(),
n_rows,
n_clusters,
dim,
(void*)workspace.data_handle(),
(params.metric == raft::distance::DistanceType::L2Expanded) ? false : true,
false,
stream);

// todo(lsugy): use KVP + iterator in caller.
// Copy keys to output labels
thrust::transform(resource::get_thrust_policy(handle),
minClusterAndDistance.data_handle(),
minClusterAndDistance.data_handle() + n_rows,
labels,
raft::compose_op<raft::cast_op<LabelT>, raft::key_op>());
}
break;
}
case raft::distance::DistanceType::InnerProduct: {
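For context on the dispatch heuristic above: the custom kernel stages every centroid plus a block-sized tile of dataset rows in dynamic shared memory, so its footprint is sizeof(DataT) * dim * (n_clusters + block_size) bytes. Note that n_clusters * dim <= 256 bounds only the centroid share; the dataset tile grows with dim, so a small cluster count with a large dim could still exceed the 48 KiB default per-block limit. A minimal caller-side sketch of that sizing check (the helper name and the device query are illustrative, not part of this PR):

#include <cuda_runtime.h>

// Illustrative helper (not part of this PR): mirrors the dynamic shared-memory
// sizing used by fusedL2NNMinReduceCustomKernelImpl and checks it against the
// device's per-block opt-in limit before selecting the custom kernel.
template <typename DataT>
bool custom_kernel_smem_fits(int n_clusters, int dim, int device_id)
{
  constexpr int block_size = 256;  // matches the launcher in fused_l2_nn.cuh below
  // One copy of all centroids plus a block_size-row tile of the dataset.
  size_t smem_bytes = sizeof(DataT) * static_cast<size_t>(dim) * (n_clusters + block_size);

  int max_smem = 0;  // 48 KiB by default; larger values require opt-in
  cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id);
  return smem_bytes <= static_cast<size_t>(max_smem);
}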
87 changes: 87 additions & 0 deletions cpp/include/raft/distance/detail/fused_l2_nn.cuh
@@ -21,11 +21,13 @@
#include <raft/core/kvp.hpp> // raft::KeyValuePair
#include <raft/core/operators.hpp> // raft::identity_op
#include <raft/distance/detail/distance_ops/l2_exp.cuh> // ops::l2_exp_distance_op
// raft::distance::detail::ops::get_clamp_precision
#include <raft/distance/detail/fused_distance_nn/cutlass_base.cuh>
#include <raft/distance/detail/pairwise_distance_base.cuh> // PairwiseDistances
#include <raft/linalg/contractions.cuh> // Policy
#include <raft/util/arch.cuh> // raft::util::arch::SM_*
#include <raft/util/cuda_utils.cuh> // raft::ceildiv, raft::shfl
#include <raft/core/math.hpp> // raft::sqrt

namespace raft {
namespace distance {
@@ -380,6 +382,91 @@ void fusedL2NNImpl(OutT* min,
}
}

template <bool sqrt, typename DataT, typename IdxT, typename LabelT>
Contributor comment:
nitpick: template parameters should start with a capital

Suggested change:
- template <bool sqrt, typename DataT, typename IdxT, typename LabelT>
+ template <bool Sqrt, typename DataT, typename IdxT, typename LabelT>
RAFT_KERNEL fusedL2NNKernelSmallInput(const DataT* dataset,
const DataT* centers,
const DataT* dataset_norm,
const DataT* centers_norm,
IdxT n_rows,
IdxT n_clusters,
IdxT dim,
LabelT* labels)
{
extern __shared__ char smem[];
DataT *centers_shared = reinterpret_cast<DataT*>(smem);
DataT *dataset_shared = centers_shared + n_clusters * dim;
int starting_row = blockDim.x * blockIdx.x;
int curr_row = blockDim.x * blockIdx.x + threadIdx.x;

int shmem_loading_idx = threadIdx.x;
while (shmem_loading_idx < n_clusters * dim) {
centers_shared[shmem_loading_idx] = centers[shmem_loading_idx];
shmem_loading_idx += blockDim.x;
}

shmem_loading_idx = threadIdx.x;
while (shmem_loading_idx < blockDim.x * dim) {
if (starting_row * dim + shmem_loading_idx < n_rows * dim)
dataset_shared[shmem_loading_idx] = dataset[starting_row * dim + shmem_loading_idx];
shmem_loading_idx += blockDim.x;
}

__syncthreads();

if (curr_row < n_rows) {
DataT min_distance = std::numeric_limits<DataT>::max();
IdxT location = 0;
#pragma unroll 16
for (int curr_n = 0; curr_n < n_clusters; curr_n++) {
DataT curr_distance = dataset_norm[curr_row] + centers_norm[curr_n];
#pragma unroll 2
    for (int curr_k = 0; curr_k < dim; curr_k++) {
      curr_distance -=
        2 * dataset_shared[threadIdx.x * dim + curr_k] * centers_shared[curr_n * dim + curr_k];
    }
    /**
     * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product
     * (accVal) can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal
     * instead. Clamp such near-zero results once the full distance has been accumulated.
     */
    curr_distance =
      curr_distance * !((curr_distance * curr_distance <
                         raft::distance::detail::ops::get_clamp_precision<DataT>()) *
                        (dataset_norm[curr_row] == centers_norm[curr_n]));
    if (sqrt) {
Contributor comment:
Change it to constexpr to make it more clear for both the reader and the compiler

Suggested change:
- if (sqrt) {
+ if constexpr (Sqrt) {
      curr_distance = raft::sqrt(curr_distance * (curr_distance > 0));
    }
    if (curr_distance < min_distance) {
      min_distance = curr_distance;
      location     = curr_n;
    }
  }
labels[curr_row] = location;
}
}

template <typename DataT, typename LabelT, typename IdxT>
void fusedL2NNMinReduceCustomKernelImpl(LabelT* label,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
IdxT m,
IdxT n,
IdxT k,
bool sqrt,
cudaStream_t stream)
{
constexpr int block_size = 256;
dim3 threads_per_block(block_size, 1, 1);
dim3 num_blocks(raft::ceildiv<IdxT>(m, block_size), 1, 1);
int shmem_size = sizeof(DataT) * k * (n + block_size);
if (sqrt) {
fusedL2NNKernelSmallInput<true><<<num_blocks, threads_per_block, shmem_size, stream>>>(
x, y, xn, yn, m, n, k, label
);
} else {
fusedL2NNKernelSmallInput<false><<<num_blocks, threads_per_block, shmem_size, stream>>>(
x, y, xn, yn, m, n, k, label
);
}
}

} // namespace detail
} // namespace distance
} // namespace raft
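The kernel above computes, for each dataset row, the argmin over clusters of the expanded L2 distance ||x||^2 + ||c||^2 - 2<x, c>, clamping near-zero results that arise from round-off on self-neighboring points. A host-side reference of the same reduction can be useful for testing; a sketch assuming row-major inputs (the function name is illustrative, not part of this PR):

#include <cmath>
#include <cstdint>
#include <limits>
#include <vector>

// Illustrative host reference (not part of this PR): brute-force 1-NN labels
// over the L2 distance, matching what fusedL2NNKernelSmallInput computes.
template <typename DataT>
std::vector<uint32_t> fused_l2_nn_reference(const std::vector<DataT>& x,  // m x k, row major
                                            const std::vector<DataT>& y,  // n x k, row major
                                            int m, int n, int k, bool sqrt)
{
  std::vector<uint32_t> labels(m);
  for (int i = 0; i < m; i++) {
    DataT best      = std::numeric_limits<DataT>::max();
    uint32_t best_j = 0;
    for (int j = 0; j < n; j++) {
      DataT d = 0;
      for (int c = 0; c < k; c++) {
        DataT diff = x[i * k + c] - y[j * k + c];
        d += diff * diff;  // equals ||x||^2 + ||y||^2 - 2<x,y> up to round-off
      }
      if (sqrt) { d = std::sqrt(d); }
      if (d < best) {
        best   = d;
        best_j = static_cast<uint32_t>(j);
      }
    }
    labels[i] = best_j;
  }
  return labels;
}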
31 changes: 31 additions & 0 deletions cpp/include/raft/distance/fused_l2_nn-ext.cuh
@@ -41,6 +41,18 @@ void fusedL2NNMinReduce(OutT* min,
bool initOutBuffer,
cudaStream_t stream) RAFT_EXPLICIT;

template <typename DataT, typename LabelT, typename IdxT>
void fusedL2NNMinReduceCustomKernel(LabelT* label,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
IdxT m,
IdxT n,
IdxT k,
bool sqrt,
cudaStream_t stream) RAFT_EXPLICIT;

} // namespace distance
} // namespace raft

@@ -80,3 +92,22 @@ instantiate_raft_distance_fusedL2NNMinReduce(float,
#undef COMMA

#undef instantiate_raft_distance_fusedL2NNMinReduce

#define instantiate_raft_distance_fusedL2NNMinReduceCustomKernel(DataT, LabelT, IdxT) \
extern template void raft::distance::fusedL2NNMinReduceCustomKernel<DataT, LabelT, IdxT>(LabelT* label, \
const DataT* x, \
const DataT* y, \
const DataT* xn, \
const DataT* yn, \
IdxT m, \
IdxT n, \
IdxT k, \
bool sqrt, \
cudaStream_t stream)

instantiate_raft_distance_fusedL2NNMinReduceCustomKernel(float, uint32_t, int);
instantiate_raft_distance_fusedL2NNMinReduceCustomKernel(float, uint32_t, int64_t);
instantiate_raft_distance_fusedL2NNMinReduceCustomKernel(double, uint32_t, int);
instantiate_raft_distance_fusedL2NNMinReduceCustomKernel(double, uint32_t, int64_t);

#undef instantiate_raft_distance_fusedL2NNMinReduceCustomKernel
38 changes: 38 additions & 0 deletions cpp/include/raft/distance/fused_l2_nn-inl.cuh
@@ -198,6 +198,44 @@ void fusedL2NNMinReduce(OutT* min,
min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream);
}

/**
* @brief Wrapper around custom fusedL2NN kernel with minimum reduction operators.
* This custom kernel is designed to optimize cases with small inputs.
*
* Like fusedL2NN this wrapper covers the most common case (minimum).
*
* @tparam DataT data type
* @tparam LabelT label type to store 1-NN indices
* @tparam IdxT indexing arithmetic type
* @param[out] label will contain the output label (Length = `m`)
* (on device)
* @param[in] x first matrix. Row major. Dim = `m x k`.
* (on device).
* @param[in] y second matrix. Row major. Dim = `n x k`.
* (on device).
* @param[in] xn L2 squared norm of `x`. Length = `m`. (on device).
* @param[in] yn L2 squared norm of `y`. Length = `n`. (on device)
* @param[in] m gemm m
* @param[in] n gemm n
* @param[in] k gemm k
* @param[in] sqrt Whether the distances should be computed as L2-sqrt before the min reduction
* @param[in] stream cuda stream
*/
template <typename DataT, typename LabelT, typename IdxT>
void fusedL2NNMinReduceCustomKernel(LabelT* label,
const DataT* x,
const DataT* y,
const DataT* xn,
const DataT* yn,
IdxT m,
IdxT n,
IdxT k,
bool sqrt,
cudaStream_t stream)
{
detail::fusedL2NNMinReduceCustomKernelImpl(label, x, y, xn, yn, m, n, k, sqrt, stream);
}

/** @} */

} // namespace distance
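A caller-side sketch of the new entry point, mirroring its use in kmeans_balanced.cuh above (assumes x, y, x_norm, y_norm, and labels are pre-allocated device pointers, with x and y row major; the variable names are illustrative):

// Compute the squared L2 norms expected by the kernel, as predict_core does.
raft::linalg::rowNorm<float, int>(x_norm, x, k, m, raft::linalg::L2Norm, true, stream);
raft::linalg::rowNorm<float, int>(y_norm, y, k, n, raft::linalg::L2Norm, true, stream);

// One call produces the 1-NN labels directly, with no workspace or KVP buffer.
raft::distance::fusedL2NNMinReduceCustomKernel<float, uint32_t, int>(
  labels, x, y, x_norm, y_norm, m, n, k, /*sqrt=*/false, stream);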
20 changes: 20 additions & 0 deletions cpp/src/distance/fused_l2_nn.cu
@@ -52,3 +52,23 @@ instantiate_raft_distance_fusedL2NNMinReduce(float,
#undef COMMA

#undef instantiate_raft_distance_fusedL2NNMinReduce


#define instantiate_raft_distance_fusedL2NNMinReduceCustomKernel(DataT, LabelT, IdxT) \
template void raft::distance::fusedL2NNMinReduceCustomKernel<DataT, LabelT, IdxT>(LabelT* label, \
const DataT* x, \
const DataT* y, \
const DataT* xn, \
const DataT* yn, \
IdxT m, \
IdxT n, \
IdxT k, \
bool sqrt, \
cudaStream_t stream)

instantiate_raft_distance_fusedL2NNMinReduceCustomKernel(float, uint32_t, int);
instantiate_raft_distance_fusedL2NNMinReduceCustomKernel(float, uint32_t, int64_t);
instantiate_raft_distance_fusedL2NNMinReduceCustomKernel(double, uint32_t, int);
instantiate_raft_distance_fusedL2NNMinReduceCustomKernel(double, uint32_t, int64_t);

#undef instantiate_raft_distance_fusedL2NNMinReduceCustomKernel
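For readers unfamiliar with the split between fused_l2_nn-ext.cuh and this file: the -ext header declares the templates extern (via RAFT_EXPLICIT) so downstream translation units do not re-instantiate them, and this .cu file emits the single definition per type combination. A minimal generic sketch of the same pattern (file and function names are illustrative):

// lib.hpp: declaration visible to all consumers
template <typename T>
void f(T v);
extern template void f<float>(float);  // suppress implicit instantiation

// lib.cpp: the one translation unit that emits the instantiation
template <typename T>
void f(T v)
{
  // ...
}
template void f<float>(float);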