Skip to content

Commit

Permalink
Limit loop unrolling for expensive distance ops
Browse files Browse the repository at this point in the history
We have four distance operations with expensive inner loops. They are
already limited to veclen == 1. In this commit, we are also limiting the
loop unrolling in accumulate().

In general, an unroll count of 1 is best. The only exception is canberra, which can
be sped up with unlimited unrolling.

Results below are for veclen = 1 on H100 with locked 2619 MHz memory and 1980 MHz
SM clock. For comparison, I have added L1 results as well.

| distance_op    | veclen | unroll 1 | unroll 2 | unlimited |
|----------------+--------+----------+----------+-----------|
| lp_unexp       |      1 | 248 G/s  |  248 G/s |   235 G/s |
| canberra       |      1 | 371      |      367 |       447 |
| canberra       |      4 | 377      |      369 |       378 |
| jensen shannon |      1 | 512 G/s  |      512 |       449 |
| kl divergence  |      1 | 659      |      391 |       265 |
|----------------+--------+----------+----------+-----------|
| l1             |      1 | 8.9 T/s  |  8.6 T/s |  9.95 T/s |
| l1             |      4 | 11.1 T/s | 11.5 T/s |  11.7 T/s |

Compile time impact:

pairwise_test                                                                                                               0.5 seconds
build.ninja                                                                                                                 4.0 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/canberra_double_double_double_int.cu.o            6.8 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/canberra_float_float_float_int.cu.o               7.2 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/russel_rao_double_double_double_int.cu.o          7.4 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/russel_rao_float_float_float_int.cu.o             7.7 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu.o       8.2 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu.o  8.8 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l1_double_double_double_int.cu.o                  9.0 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu.o  9.1 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/correlation_double_double_double_int.cu.o         9.2 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l_inf_double_double_double_int.cu.o              10.0 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu.o         10.1 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu.o      10.6 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/jensen_shannon_float_float_float_int.cu.o        10.6 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu.o    12.8 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu.o         13.0 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kl_divergence_float_float_float_int.cu.o         13.3 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l_inf_float_float_float_int.cu.o                 13.6 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l1_float_float_float_int.cu.o                    13.8 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu.o    14.4 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/jensen_shannon_double_double_double_int.cu.o     14.5 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/correlation_float_float_float_int.cu.o           14.6 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kl_divergence_double_double_double_int.cu.o      23.4 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_expanded_double_double_double_int.cu.o        24.6 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/cosine_double_double_double_int.cu.o             31.6 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/inner_product_double_double_double_int.cu.o      34.6 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/inner_product_float_float_float_int.cu.o         35.5 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/fused_l2_nn_double_int64.cu.o                           37.7 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/fused_l2_nn_double_int.cu.o                             38.8 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/gram_matrix_base_double.cu.o             40.3 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_double.cu.o                  41.3 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/l2_expanded_float_float_float_int.cu.o           42.1 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu.o        42.5 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/fused_l2_nn_float_int.cu.o                              43.7 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_hellinger.cu.o                                                             44.4 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_russell_rao.cu.o                                                           44.6 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_hamming.cu.o                                                               44.7 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/fused_l2_nn_float_int64.cu.o                            45.0 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_inner_product.cu.o                                                         45.7 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_cos.cu.o                                                                   45.9 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/gram_matrix_base_float.cu.o              46.2 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_kl_divergence.cu.o                                                         46.4 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_l2_unexp.cu.o                                                              46.7 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_l1.cu.o                                                                    47.1 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_lp_unexp.cu.o                                                              47.3 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_l2_exp.cu.o                                                                47.4 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_canberra.cu.o                                                              47.6 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_correlation.cu.o                                                           47.6 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_jensen_shannon.cu.o                                                        48.3 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_l_inf.cu.o                                                                 48.6 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu.o         48.9 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_l2_sqrt_exp.cu.o                                                           49.3 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/cosine_float_float_float_int.cu.o                50.6 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/specializations/detail/kernels/tanh_kernel_float.cu.o                   51.9 seconds
CMakeFiles/pairwise_test.dir/src/distance/distance/pairwise_distance.cu.o                                                  54.5 seconds
CMakeFiles/pairwise_test.dir/test/distance/gram.cu.o                                                                       56.8 seconds
CMakeFiles/pairwise_test.dir/test/distance/fused_l2_nn.cu.o                                                                67.5 seconds
CMakeFiles/pairwise_test.dir/test/distance/dist_adj.cu.o                                                                  123.3 seconds
  • Loading branch information
ahendriksen committed Mar 15, 2023
1 parent f54e7a4 commit 35a2ad4
Showing 1 changed file with 24 additions and 7 deletions.
31 changes: 24 additions & 7 deletions cpp/include/raft/distance/detail/pairwise_distance_base.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -205,24 +205,41 @@ struct PairwiseDistances : public BaseClass {
}
}

DI void accumulate()
DI void accumulate_reg_tile(DataT (&reg_x)[P::AccRowsPerTh][P::Veclen],
DataT (&reg_y)[P::AccColsPerTh][P::Veclen])
{
#pragma unroll
for (int ki = 0; ki < P::Kblk; ki += P::Veclen) {
this->ldsXY(ki);
for (int v = 0; v < P::Veclen; ++v) {
#pragma unroll
for (int i = 0; i < P::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < P::AccColsPerTh; ++j) {
#pragma unroll
for (int v = 0; v < P::Veclen; ++v) {
distance_op.core(acc[i][j], this->regx[i][v], this->regy[j][v]);
}
distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]);
}
}
}
}

// Accumulates the contribution of one K-block into the accumulators by
// alternating ldsXY() calls with accumulate_reg_tile() calls on
// this->regx / this->regy.
//
// Call order: ldsXY(0) first, then each loop iteration runs
// accumulate_reg_tile() followed by ldsXY(ki) for the next offset, and a
// final accumulate_reg_tile() after the loop. The number of
// accumulate_reg_tile() calls is therefore one more than the loop trip count.
//
// NOTE(review): ldsXY is defined outside this view; presumably it fills
// this->regx / this->regy for the given k-offset -- confirm against the base
// class.
DI void accumulate()
{
// We have a separate ldsXY and accumulate_reg_tile outside the loop body,
// so that these separated calls can be interspersed with preceding and
// following instructions, thereby hiding latency.
this->ldsXY(0);

// If expensive inner loop, do not unroll loop.
// The loop below starts at ki = Veclen rather than 0 because offset 0 was
// already handled by the ldsXY(0) call above, hence the "- 1".
// NOTE(review): assumes P::Kblk is a multiple of P::Veclen -- confirm.
constexpr int num_iterations = P::Kblk / P::Veclen - 1;
// An unroll count of 1 asks the compiler not to unroll; otherwise request
// as many copies as there are iterations (i.e. a full unroll).
// NOTE(review): if Kblk == Veclen the count is 0; per the CUDA docs a
// non-positive count makes the pragma a no-op, and the loop body never
// executes in that case anyway.
constexpr int unroll_count = decltype(distance_op)::expensive_inner_loop ? 1 : num_iterations;
#pragma unroll unroll_count
for (int ki = P::Veclen; ki < P::Kblk; ki += P::Veclen) {
// Consume the values currently in regx/regy before ldsXY(ki) is called for
// the next offset.
accumulate_reg_tile(this->regx, this->regy);
this->ldsXY(ki);
}

// Accumulate last loaded tile.
accumulate_reg_tile(this->regx, this->regy);
}

DI void load_norms(IdxT tile_idx_m,
IdxT tile_idx_n,
DataT (&regxn)[P::AccRowsPerTh],
Expand Down

0 comments on commit 35a2ad4

Please sign in to comment.