From d4825f7b9876c741c9b2c299c07988c0462b2baf Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 30 Sep 2021 15:36:50 -0400 Subject: [PATCH 01/22] Hiding private API for dense distances primitive --- .../raft/distance/{ => detail}/canberra.cuh | 0 .../raft/distance/{ => detail}/chebyshev.cuh | 0 .../distance/{ => detail}/correlation.cuh | 0 .../raft/distance/{ => detail}/cosine.cuh | 0 .../raft/distance/{ => detail}/euclidean.cuh | 0 .../raft/distance/{ => detail}/hamming.cuh | 0 .../raft/distance/{ => detail}/hellinger.cuh | 0 .../distance/{ => detail}/jensen_shannon.cuh | 0 .../distance/{ => detail}/kl_divergence.cuh | 0 cpp/include/raft/distance/{ => detail}/l1.cuh | 0 .../raft/distance/{ => detail}/minkowski.cuh | 0 .../{ => detail}/pairwise_distance_base.cuh | 0 .../distance/{ => detail}/russell_rao.cuh | 0 cpp/include/raft/distance/distance.cuh | 24 +++++++++---------- cpp/include/raft/distance/fused_l2_nn.cuh | 2 +- 15 files changed, 13 insertions(+), 13 deletions(-) rename cpp/include/raft/distance/{ => detail}/canberra.cuh (100%) rename cpp/include/raft/distance/{ => detail}/chebyshev.cuh (100%) rename cpp/include/raft/distance/{ => detail}/correlation.cuh (100%) rename cpp/include/raft/distance/{ => detail}/cosine.cuh (100%) rename cpp/include/raft/distance/{ => detail}/euclidean.cuh (100%) rename cpp/include/raft/distance/{ => detail}/hamming.cuh (100%) rename cpp/include/raft/distance/{ => detail}/hellinger.cuh (100%) rename cpp/include/raft/distance/{ => detail}/jensen_shannon.cuh (100%) rename cpp/include/raft/distance/{ => detail}/kl_divergence.cuh (100%) rename cpp/include/raft/distance/{ => detail}/l1.cuh (100%) rename cpp/include/raft/distance/{ => detail}/minkowski.cuh (100%) rename cpp/include/raft/distance/{ => detail}/pairwise_distance_base.cuh (100%) rename cpp/include/raft/distance/{ => detail}/russell_rao.cuh (100%) diff --git a/cpp/include/raft/distance/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh similarity index 100% rename from cpp/include/raft/distance/canberra.cuh rename to cpp/include/raft/distance/detail/canberra.cuh diff --git a/cpp/include/raft/distance/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh similarity index 100% rename from cpp/include/raft/distance/chebyshev.cuh rename to cpp/include/raft/distance/detail/chebyshev.cuh diff --git a/cpp/include/raft/distance/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh similarity index 100% rename from cpp/include/raft/distance/correlation.cuh rename to cpp/include/raft/distance/detail/correlation.cuh diff --git a/cpp/include/raft/distance/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh similarity index 100% rename from cpp/include/raft/distance/cosine.cuh rename to cpp/include/raft/distance/detail/cosine.cuh diff --git a/cpp/include/raft/distance/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh similarity index 100% rename from cpp/include/raft/distance/euclidean.cuh rename to cpp/include/raft/distance/detail/euclidean.cuh diff --git a/cpp/include/raft/distance/hamming.cuh b/cpp/include/raft/distance/detail/hamming.cuh similarity index 100% rename from cpp/include/raft/distance/hamming.cuh rename to cpp/include/raft/distance/detail/hamming.cuh diff --git a/cpp/include/raft/distance/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh similarity index 100% rename from cpp/include/raft/distance/hellinger.cuh rename to cpp/include/raft/distance/detail/hellinger.cuh diff --git a/cpp/include/raft/distance/jensen_shannon.cuh b/cpp/include/raft/distance/detail/jensen_shannon.cuh similarity index 100% rename from cpp/include/raft/distance/jensen_shannon.cuh rename to cpp/include/raft/distance/detail/jensen_shannon.cuh diff --git a/cpp/include/raft/distance/kl_divergence.cuh b/cpp/include/raft/distance/detail/kl_divergence.cuh similarity index 100% rename from cpp/include/raft/distance/kl_divergence.cuh rename to cpp/include/raft/distance/detail/kl_divergence.cuh diff --git a/cpp/include/raft/distance/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh similarity index 100% rename from cpp/include/raft/distance/l1.cuh rename to cpp/include/raft/distance/detail/l1.cuh diff --git a/cpp/include/raft/distance/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh similarity index 100% rename from cpp/include/raft/distance/minkowski.cuh rename to cpp/include/raft/distance/detail/minkowski.cuh diff --git a/cpp/include/raft/distance/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh similarity index 100% rename from cpp/include/raft/distance/pairwise_distance_base.cuh rename to cpp/include/raft/distance/detail/pairwise_distance_base.cuh diff --git a/cpp/include/raft/distance/russell_rao.cuh b/cpp/include/raft/distance/detail/russell_rao.cuh similarity index 100% rename from cpp/include/raft/distance/russell_rao.cuh rename to cpp/include/raft/distance/detail/russell_rao.cuh diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 65b4f3b830..52041a59ad 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -19,18 +19,18 @@ #include #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include #include namespace raft { diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index b96a536e38..cafb9a91ba 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -20,7 +20,7 @@ #include #include #include -#include +#include #include namespace raft { From b9fca66c39206034232e240276fb6d9629f344a3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 30 Sep 2021 15:39:43 -0400 Subject: [PATCH 02/22] Hiding private sparse distance API --- .../raft/sparse/distance/{ => detail}/bin_distance.cuh | 2 +- .../raft/sparse/distance/{ => detail}/coo_spmv.cuh | 6 +++--- .../{ => detail}/coo_spmv_strategies/base_strategy.cuh | 4 ++-- .../coo_spmv_strategies/coo_mask_row_iterators.cuh | 2 +- .../coo_spmv_strategies/dense_smem_strategy.cuh | 0 .../{ => detail}/coo_spmv_strategies/hash_strategy.cuh | 0 .../raft/sparse/distance/{ => detail}/ip_distance.cuh | 4 ++-- .../raft/sparse/distance/{ => detail}/l2_distance.cuh | 2 +- .../raft/sparse/distance/{ => detail}/lp_distance.cuh | 2 +- .../raft/sparse/distance/{ => detail}/operators.cuh | 0 cpp/include/raft/sparse/distance/{ => detail}/utils.cuh | 0 cpp/include/raft/sparse/distance/distance.cuh | 8 ++++---- 12 files changed, 15 insertions(+), 15 deletions(-) rename cpp/include/raft/sparse/distance/{ => detail}/bin_distance.cuh (99%) rename cpp/include/raft/sparse/distance/{ => detail}/coo_spmv.cuh (99%) rename cpp/include/raft/sparse/distance/{ => detail}/coo_spmv_strategies/base_strategy.cuh (98%) rename cpp/include/raft/sparse/distance/{ => detail}/coo_spmv_strategies/coo_mask_row_iterators.cuh (99%) rename cpp/include/raft/sparse/distance/{ => detail}/coo_spmv_strategies/dense_smem_strategy.cuh (100%) rename cpp/include/raft/sparse/distance/{ => detail}/coo_spmv_strategies/hash_strategy.cuh (100%) rename cpp/include/raft/sparse/distance/{ => detail}/ip_distance.cuh (95%) rename cpp/include/raft/sparse/distance/{ => detail}/l2_distance.cuh (99%) rename cpp/include/raft/sparse/distance/{ => detail}/lp_distance.cuh (99%) rename cpp/include/raft/sparse/distance/{ => detail}/operators.cuh (100%) rename cpp/include/raft/sparse/distance/{ => detail}/utils.cuh (100%) diff --git a/cpp/include/raft/sparse/distance/bin_distance.cuh b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh similarity index 99% rename from cpp/include/raft/sparse/distance/bin_distance.cuh rename to cpp/include/raft/sparse/distance/detail/bin_distance.cuh index 6885c250c0..5b59ce89ed 100644 --- a/cpp/include/raft/sparse/distance/bin_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh @@ -24,7 +24,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/include/raft/sparse/distance/coo_spmv.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh similarity index 99% rename from cpp/include/raft/sparse/distance/coo_spmv.cuh rename to cpp/include/raft/sparse/distance/detail/coo_spmv.cuh index 24be171900..a9ff55a291 100644 --- a/cpp/include/raft/sparse/distance/coo_spmv.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh @@ -24,9 +24,9 @@ #include #include -#include "../csr.cuh" -#include "../utils.h" -#include "common.h" +#include "../../csr.cuh" +#include "../../utils.h" +#include "../common.h" #include diff --git a/cpp/include/raft/sparse/distance/coo_spmv_strategies/base_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh similarity index 98% rename from cpp/include/raft/sparse/distance/coo_spmv_strategies/base_strategy.cuh rename to cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh index 3b57225350..22cfda7787 100644 --- a/cpp/include/raft/sparse/distance/coo_spmv_strategies/base_strategy.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/base_strategy.cuh @@ -16,8 +16,8 @@ #pragma once -#include "../common.h" -#include "../detail/coo_spmv_kernel.cuh" +#include "../../common.h" +#include "../coo_spmv_kernel.cuh" #include "../utils.cuh" #include "coo_mask_row_iterators.cuh" diff --git a/cpp/include/raft/sparse/distance/coo_spmv_strategies/coo_mask_row_iterators.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh similarity index 99% rename from cpp/include/raft/sparse/distance/coo_spmv_strategies/coo_mask_row_iterators.cuh rename to cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh index 74eb37bc2b..2e774e3a02 100644 --- a/cpp/include/raft/sparse/distance/coo_spmv_strategies/coo_mask_row_iterators.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh @@ -16,7 +16,7 @@ #pragma once -#include "../common.h" +#include "../../common.h" #include "../utils.cuh" #include diff --git a/cpp/include/raft/sparse/distance/coo_spmv_strategies/dense_smem_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh similarity index 100% rename from cpp/include/raft/sparse/distance/coo_spmv_strategies/dense_smem_strategy.cuh rename to cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh diff --git a/cpp/include/raft/sparse/distance/coo_spmv_strategies/hash_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh similarity index 100% rename from cpp/include/raft/sparse/distance/coo_spmv_strategies/hash_strategy.cuh rename to cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh diff --git a/cpp/include/raft/sparse/distance/ip_distance.cuh b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh similarity index 95% rename from cpp/include/raft/sparse/distance/ip_distance.cuh rename to cpp/include/raft/sparse/distance/detail/ip_distance.cuh index b1e2756671..def59d683c 100644 --- a/cpp/include/raft/sparse/distance/ip_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh @@ -27,8 +27,8 @@ #include #include #include -#include -#include +#include +#include #include #include diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh similarity index 99% rename from cpp/include/raft/sparse/distance/l2_distance.cuh rename to cpp/include/raft/sparse/distance/detail/l2_distance.cuh index 6ccfd4adcb..a32676ce4c 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh @@ -26,7 +26,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/include/raft/sparse/distance/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh similarity index 99% rename from cpp/include/raft/sparse/distance/lp_distance.cuh rename to cpp/include/raft/sparse/distance/detail/lp_distance.cuh index 885d55ee50..41a08866a2 100644 --- a/cpp/include/raft/sparse/distance/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -30,7 +30,7 @@ #include #include -#include +#include #include diff --git a/cpp/include/raft/sparse/distance/operators.cuh b/cpp/include/raft/sparse/distance/detail/operators.cuh similarity index 100% rename from cpp/include/raft/sparse/distance/operators.cuh rename to cpp/include/raft/sparse/distance/detail/operators.cuh diff --git a/cpp/include/raft/sparse/distance/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh similarity index 100% rename from cpp/include/raft/sparse/distance/utils.cuh rename to cpp/include/raft/sparse/distance/detail/utils.cuh diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index 03df396b2e..5b52a0dfbc 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -31,10 +31,10 @@ #include #include -#include -#include -#include -#include +#include +#include +#include +#include #include From a99e0f7a084e229663a65783d62045b210b8c161 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 30 Sep 2021 15:50:05 -0400 Subject: [PATCH 03/22] Updating pairwise distance base to the proper directory --- cpp/include/raft/distance/detail/canberra.cuh | 2 +- cpp/include/raft/distance/detail/chebyshev.cuh | 2 +- cpp/include/raft/distance/detail/correlation.cuh | 2 +- cpp/include/raft/distance/detail/cosine.cuh | 2 +- cpp/include/raft/distance/detail/euclidean.cuh | 2 +- cpp/include/raft/distance/detail/hamming.cuh | 2 +- cpp/include/raft/distance/detail/hellinger.cuh | 2 +- cpp/include/raft/distance/detail/jensen_shannon.cuh | 2 +- cpp/include/raft/distance/detail/kl_divergence.cuh | 2 +- cpp/include/raft/distance/detail/l1.cuh | 2 +- cpp/include/raft/distance/detail/minkowski.cuh | 2 +- cpp/include/raft/distance/detail/russell_rao.cuh | 2 +- 12 files changed, 12 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index b87c295eb0..de5bb8f584 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include namespace raft { namespace distance { diff --git a/cpp/include/raft/distance/detail/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh index 8d53408cf8..4f3ad7c7d3 100644 --- a/cpp/include/raft/distance/detail/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/chebyshev.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include namespace raft { namespace distance { diff --git a/cpp/include/raft/distance/detail/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh index ed3b7a5464..af17f4f56b 100644 --- a/cpp/include/raft/distance/detail/correlation.cuh +++ b/cpp/include/raft/distance/detail/correlation.cuh @@ -16,7 +16,7 @@ #pragma once #include -#include +#include #include namespace raft { diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index ed9bd28b7f..9ba3bbd92a 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include #include namespace raft { diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 484da0e5bf..37374aa7f2 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include #include namespace raft { diff --git a/cpp/include/raft/distance/detail/hamming.cuh b/cpp/include/raft/distance/detail/hamming.cuh index 08f1020b85..adacfa4895 100644 --- a/cpp/include/raft/distance/detail/hamming.cuh +++ b/cpp/include/raft/distance/detail/hamming.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include namespace raft { namespace distance { diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh index f7ad3ed1ba..6baf1ff5f1 100644 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include #include namespace raft { diff --git a/cpp/include/raft/distance/detail/jensen_shannon.cuh b/cpp/include/raft/distance/detail/jensen_shannon.cuh index 2a94205853..f4fe7f2d38 100644 --- a/cpp/include/raft/distance/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/jensen_shannon.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include namespace raft { namespace distance { diff --git a/cpp/include/raft/distance/detail/kl_divergence.cuh b/cpp/include/raft/distance/detail/kl_divergence.cuh index 3197b73d10..b22a3e3aaa 100644 --- a/cpp/include/raft/distance/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/kl_divergence.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include namespace raft { namespace distance { diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 6ab084f041..5e5fdf6feb 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include namespace raft { namespace distance { diff --git a/cpp/include/raft/distance/detail/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh index 803f5fc78a..d45eac6f63 100644 --- a/cpp/include/raft/distance/detail/minkowski.cuh +++ b/cpp/include/raft/distance/detail/minkowski.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include namespace raft { namespace distance { diff --git a/cpp/include/raft/distance/detail/russell_rao.cuh b/cpp/include/raft/distance/detail/russell_rao.cuh index 417fb73b94..663990bf2d 100644 --- a/cpp/include/raft/distance/detail/russell_rao.cuh +++ b/cpp/include/raft/distance/detail/russell_rao.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include namespace raft { namespace distance { From 541117fa297ffd26abccd3b4d9b73d7cfc298f6b Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 30 Sep 2021 16:07:07 -0400 Subject: [PATCH 04/22] Fixing fused_l2_knn path --- cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 9d00d9b9f4..d1d17d9389 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -17,7 +17,7 @@ #include #include #include -#include +#include #include "processing.hpp" namespace raft { From 6359cad6a6ba415dcaa92f8311011e013d25a1b1 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 30 Sep 2021 16:08:21 -0400 Subject: [PATCH 05/22] Removing internal include from knn --- cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index d1d17d9389..6f2846847e 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -17,7 +17,6 @@ #include #include #include -#include #include "processing.hpp" namespace raft { From 5e1765267f0dd652c17e31b77d0b7f28204bd4d1 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 1 Oct 2021 10:13:18 -0400 Subject: [PATCH 06/22] Builds and runs --- cpp/include/raft/distance/detail/canberra.cuh | 11 +- .../raft/distance/detail/chebyshev.cuh | 11 +- .../raft/distance/detail/correlation.cuh | 11 +- cpp/include/raft/distance/detail/cosine.cuh | 8 +- cpp/include/raft/distance/detail/distance.cuh | 396 +++++++++++++++ .../raft/distance/detail/euclidean.cuh | 18 +- .../raft/distance/detail/fused_l2_nn.cuh | 270 +++++++++++ cpp/include/raft/distance/detail/hamming.cuh | 11 +- .../raft/distance/detail/hellinger.cuh | 10 +- .../raft/distance/detail/jensen_shannon.cuh | 10 +- .../raft/distance/detail/kl_divergence.cuh | 18 +- cpp/include/raft/distance/detail/l1.cuh | 10 +- .../raft/distance/detail/minkowski.cuh | 11 +- .../detail/pairwise_distance_base.cuh | 37 +- .../raft/distance/detail/russell_rao.cuh | 10 +- cpp/include/raft/distance/distance.cuh | 459 ++++-------------- cpp/include/raft/distance/fused_l2_nn.cuh | 245 +--------- .../sparse/distance/detail/bin_distance.cuh | 3 +- .../raft/sparse/distance/detail/coo_spmv.cuh | 2 + .../distance/detail/coo_spmv_kernel.cuh | 3 +- .../coo_spmv_strategies/base_strategy.cuh | 2 + .../coo_mask_row_iterators.cuh | 2 + .../dense_smem_strategy.cuh | 2 + .../coo_spmv_strategies/hash_strategy.cuh | 2 + .../sparse/distance/detail/ip_distance.cuh | 19 +- .../sparse/distance/detail/l2_distance.cuh | 2 + .../sparse/distance/detail/lp_distance.cuh | 2 + .../raft/sparse/distance/detail/operators.cuh | 2 + .../raft/sparse/distance/detail/utils.cuh | 2 + cpp/include/raft/sparse/distance/distance.cuh | 43 +- .../raft/spatial/knn/detail/fused_l2_knn.cuh | 13 +- cpp/test/distance/fused_l2_nn.cu | 3 +- cpp/test/sparse/dist_coo_spmv.cu | 30 +- 33 files changed, 964 insertions(+), 714 deletions(-) create mode 100644 cpp/include/raft/distance/detail/distance.cuh create mode 100644 cpp/include/raft/distance/detail/fused_l2_nn.cuh diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index de5bb8f584..3bef776b2c 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -19,6 +19,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief the canberra distance matrix calculation implementer @@ -77,8 +78,8 @@ static void canberraImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = - launchConfigGenerator(m, n, KPolicy::SmemSize, canberraRowMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + canberraRowMajor); canberraRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -88,8 +89,8 @@ static void canberraImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = - launchConfigGenerator(m, n, KPolicy::SmemSize, canberraColMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + canberraColMajor); canberraColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -157,5 +158,7 @@ void canberraImpl(int m, int n, int k, const InType *pA, const InType *pB, n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } + +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/detail/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh index 4f3ad7c7d3..a2c89c5301 100644 --- a/cpp/include/raft/distance/detail/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/chebyshev.cuh @@ -19,7 +19,7 @@ namespace raft { namespace distance { - +namespace detail { /** * @brief the Chebyshev distance matrix calculation implementer * It computes the following equation: cij = max(cij, op(ai-bj)) @@ -74,8 +74,8 @@ static void chebyshevImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - chebyshevRowMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + chebyshevRowMajor); chebyshevRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -85,8 +85,8 @@ static void chebyshevImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - chebyshevColMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + chebyshevColMajor); chebyshevColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -154,5 +154,6 @@ void chebyshevImpl(int m, int n, int k, const InType *pA, const InType *pB, n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/detail/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh index af17f4f56b..0f2e6b0ca4 100644 --- a/cpp/include/raft/distance/detail/correlation.cuh +++ b/cpp/include/raft/distance/detail/correlation.cuh @@ -21,6 +21,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief the Correlation distance matrix: @@ -124,8 +125,8 @@ static void correlationImpl(const DataT *x, const DataT *y, const DataT *xn, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - correlationRowMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + correlationRowMajor); correlationRowMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -134,8 +135,8 @@ static void correlationImpl(const DataT *x, const DataT *y, const DataT *xn, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - correlationColMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + correlationColMajor); correlationColMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -243,5 +244,7 @@ void correlationImpl(int m, int n, int k, const InType *pA, const InType *pB, fin_op, stream); } } + +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index 9ba3bbd92a..71bddd12c5 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -21,6 +21,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief the cosine distance matrix calculation implementer @@ -89,7 +90,8 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineRowMajor); + dim3 grid = + detail::launchConfigGenerator(m, n, shmemSize, cosineRowMajor); cosineRowMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -98,7 +100,8 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineColMajor); + dim3 grid = + detail::launchConfigGenerator(m, n, shmemSize, cosineColMajor); cosineColMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -201,5 +204,6 @@ void cosineAlgo1(Index_ m, Index_ n, Index_ k, const InType *pA, } } +}; // end namespace detail }; // end namespace distance }; // end namespace raft diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh new file mode 100644 index 0000000000..b4d428474e --- /dev/null +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -0,0 +1,396 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { +namespace detail { + +namespace { +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType metric_arg = 2.0f) {} +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType) { + raft::distance::detail::euclideanAlgo1( + m, n, k, x, y, dist, false, (AccType *)workspace, worksize, fin_op, + stream, isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType) { + raft::distance::detail::euclideanAlgo1( + m, n, k, x, y, dist, true, (AccType *)workspace, worksize, fin_op, stream, + isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType) { + raft::distance::detail::cosineAlgo1(m, n, k, x, y, dist, + (AccType *)workspace, worksize, + fin_op, stream, isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { + raft::distance::detail::euclideanAlgo2( + m, n, k, x, y, dist, false, fin_op, stream, isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { + raft::distance::detail::euclideanAlgo2( + m, n, k, x, y, dist, true, fin_op, stream, isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { + raft::distance::detail::l1Impl(m, n, k, x, y, dist, fin_op, stream, + isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { + raft::distance::detail::chebyshevImpl(m, n, k, x, y, dist, fin_op, + stream, isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { + raft::distance::detail::hellingerImpl(m, n, k, x, y, dist, fin_op, + stream, isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType metric_arg) { + raft::distance::detail::minkowskiImpl( + m, n, k, x, y, dist, fin_op, stream, isRowMajor, metric_arg); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { + raft::distance::detail::canberraImpl(m, n, k, x, y, dist, fin_op, + stream, isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { + raft::distance::detail::hammingUnexpandedImpl( + m, n, k, x, y, dist, fin_op, stream, isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { + raft::distance::detail::jensenShannonImpl( + m, n, k, x, y, dist, fin_op, stream, isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { + raft::distance::detail::russellRaoImpl( + m, n, k, x, y, dist, fin_op, stream, isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, + bool isRowMajor, InType) { + raft::distance::detail::klDivergenceImpl( + m, n, k, x, y, dist, fin_op, stream, isRowMajor); + } +}; + +template +struct DistanceImpl { + void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, + Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, + cudaStream_t stream, bool isRowMajor, InType) { + raft::distance::detail::correlationImpl( + m, n, k, x, y, dist, (AccType *)workspace, worksize, fin_op, stream, + isRowMajor); + } +}; + +} // anonymous namespace + +/** + * @brief Evaluate pairwise distances with the user epilogue lamba allowed + * @tparam DistanceType which distance to evaluate + * @tparam InType input argument type + * @tparam AccType accumulation type + * @tparam OutType output type + * @tparam FinalLambda user-defined epilogue lamba + * @tparam Index_ Index type + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param fin_op the final gemm epilogue lambda + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + * + * @note fin_op: This is a device lambda which is supposed to operate upon the + * input which is AccType and returns the output in OutType. It's signature is + * as follows:
OutType fin_op(AccType in, int g_idx);
. If one needs + * any other parameters, feel free to pass them via closure. + */ +template +void distance(const InType *x, const InType *y, OutType *dist, Index_ m, + Index_ n, Index_ k, void *workspace, size_t worksize, + FinalLambda fin_op, cudaStream_t stream, bool isRowMajor = true, + InType metric_arg = 2.0f) { + DistanceImpl + distImpl; + distImpl.run(x, y, dist, m, n, k, workspace, worksize, fin_op, stream, + isRowMajor, metric_arg); + CUDA_CHECK(cudaPeekAtLastError()); +} + +/** + * @brief Evaluate pairwise distances for the simple use case + * @tparam DistanceType which distance to evaluate + * @tparam InType input argument type + * @tparam AccType accumulation type + * @tparam OutType output type + * @tparam Index_ Index type + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace needed for computations + * @param worksize number of bytes of the workspace + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + * + * @note if workspace is passed as nullptr, this will return in + * worksize, the number of bytes of workspace required + */ +template +void distance(const InType *x, const InType *y, OutType *dist, Index_ m, + Index_ n, Index_ k, void *workspace, size_t worksize, + cudaStream_t stream, bool isRowMajor = true, + InType metric_arg = 2.0f) { + auto default_fin_op = [] __device__(AccType d_val, Index_ g_d_idx) { + return d_val; + }; + distance(x, y, dist, m, n, k, workspace, worksize, default_fin_op, + stream, isRowMajor, metric_arg); + CUDA_CHECK(cudaPeekAtLastError()); +} + +/** + * @brief Return the exact workspace size to compute the distance + * @tparam DistanceType which distance to evaluate + * @tparam InType input argument type + * @tparam AccType accumulation type + * @tparam OutType output type + * @tparam Index_ Index type + * @param x first set of points + * @param y second set of points + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * + * @note If the specifed distanceType doesn't need the workspace at all, it + * returns 0. + */ +template +size_t getWorkspaceSize(const InType *x, const InType *y, Index_ m, Index_ n, + Index_ k) { + size_t worksize = 0; + constexpr bool is_allocated = + (distanceType <= raft::distance::DistanceType::CosineExpanded) || + (distanceType == raft::distance::DistanceType::CorrelationExpanded); + constexpr int numOfBuffers = + (distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 2 : 1; + + if (is_allocated) { + worksize += numOfBuffers * m * sizeof(AccType); + if (x != y) worksize += numOfBuffers * n * sizeof(AccType); + } + + return worksize; +} + +/** + * @defgroup pairwise_distance pairwise distance prims + * @{ + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam Index_ indexing type + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace buffer which can get resized as per the + * needed workspace size + * @param metric distance metric + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + */ +template +void pairwise_distance_impl(const Type *x, const Type *y, Type *dist, Index_ m, + Index_ n, Index_ k, + rmm::device_uvector &workspace, + cudaStream_t stream, bool isRowMajor, + Type metric_arg = 2.0f) { + auto worksize = + getWorkspaceSize(x, y, m, n, k); + workspace.resize(worksize, stream); + distance(x, y, dist, m, n, k, + workspace.data(), worksize, + stream, isRowMajor, metric_arg); +} +/** @} */ +}; // namespace detail +}; // namespace distance +}; // namespace raft diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 37374aa7f2..8592e96295 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -20,6 +20,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief the expanded euclidean distance matrix calculation implementer @@ -97,8 +98,8 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, pairwiseDistanceMatKernel; - dim3 grid = - launchConfigGenerator(m, n, shmemSize, euclideanExpRowMajor); + dim3 grid = detail::launchConfigGenerator(m, n, shmemSize, + euclideanExpRowMajor); euclideanExpRowMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, @@ -108,8 +109,8 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, pairwiseDistanceMatKernel; - dim3 grid = - launchConfigGenerator(m, n, shmemSize, euclideanExpColMajor); + dim3 grid = detail::launchConfigGenerator(m, n, shmemSize, + euclideanExpColMajor); euclideanExpColMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -267,8 +268,8 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - euclideanUnExpRowMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + euclideanUnExpRowMajor); euclideanUnExpRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -279,8 +280,8 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - euclideanUnExpColMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + euclideanUnExpColMajor); euclideanUnExpColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -352,5 +353,6 @@ void euclideanAlgo2(Index_ m, Index_ n, Index_ k, const InType *pA, } } +}; // end namespace detail }; // end namespace distance }; // end namespace raft diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh new file mode 100644 index 0000000000..1e4d921052 --- /dev/null +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -0,0 +1,270 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { +namespace detail { + +#if (ENABLE_MEMCPY_ASYNC == 1) +#include +using namespace nvcuda::experimental; +#endif + +template +struct KVPMinReduceImpl { + typedef cub::KeyValuePair KVP; + + DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { + return b.value < a.value ? b : a; + } + +}; // KVPMinReduce + +template +struct MinAndDistanceReduceOpImpl { + typedef typename cub::KeyValuePair KVP; + DI void operator()(LabelT rid, KVP* out, const KVP& other) { + if (other.value < out->value) { + out->key = other.key; + out->value = other.value; + } + } + + DI void init(KVP* out, DataT maxVal) { + out->key = -1; + out->value = maxVal; + } +}; + +template +struct MinReduceOpImpl { + typedef typename cub::KeyValuePair KVP; + DI void operator()(LabelT rid, DataT* out, const KVP& other) { + if (other.value < *out) { + *out = other.value; + } + } + + DI void init(DataT* out, DataT maxVal) { *out = maxVal; } +}; + +template +__global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { + auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; + if (tid < m) { + redOp.init(min + tid, maxVal); + } +} + +// TODO: specialize this function for MinAndDistanceReduceOp +// with atomicCAS of 64 bit which will eliminate mutex and shfls +template +DI void updateReducedVal(int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, + IdxT m, IdxT gridStrideY) { + const auto lid = threadIdx.x % raft::WarpSize; + const auto accrowid = threadIdx.x / P::AccThCols; + + // for now have first lane from each warp update a unique output row. This + // will resolve hang issues with pre-Volta architectures +#pragma unroll + for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { + if (lid == 0) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rid = gridStrideY + accrowid + j + i * P::AccThRows; + if (rid < m) { + auto value = val[i]; + while (atomicCAS(mutex + rid, 0, 1) == 1) + ; + __threadfence(); + red_op(rid, min + rid, value); + __threadfence(); + atomicCAS(mutex + rid, 1, 0); + } + } + } + if (j < (raft::WarpSize / P::AccThCols) - 1) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto tmpkey = raft::shfl(val[i].key, (j + 1) * P::AccThCols); + auto tmpvalue = raft::shfl(val[i].value, (j + 1) * P::AccThCols); + val[i] = {tmpkey, tmpvalue}; + } + } + } +} + +template +__global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel( + OutT* min, const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, + IdxT m, IdxT n, IdxT k, DataT maxVal, int* mutex, ReduceOpT redOp, + KVPReduceOpT pairRedOp, CoreLambda core_op, FinalLambda fin_op) { + extern __shared__ char smem[]; + + typedef cub::KeyValuePair KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + DataT * regxn, DataT * regyn, IdxT gridStrideX, + IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + } + } + if (Sqrt) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = raft::mySqrt(acc[i][j]); + } + } + } + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < n) { + val[i] = + pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + } + }; + + auto rowEpilog_lambda = [m, mutex, min, pairRedOp, redOp, &val, + maxVal] __device__(IdxT gridStrideY) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + auto tmpkey = raft::shfl(val[i].key, lid + j); + auto tmpvalue = raft::shfl(val[i].value, lid + j); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = + pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, + m, gridStrideY); + + // reset the val array. +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + PairwiseDistances + obj(x, y, m, n, k, lda, ldb, ldd, xn, yn, nullptr, smem, core_op, + epilog_lambda, fin_op, rowEpilog_lambda); + obj.run(); +} + +template +void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, + const DataT* yn, IdxT m, IdxT n, IdxT k, int* workspace, + ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, + bool initOutBuffer, cudaStream_t stream) { + typedef typename linalg::Policy4x4::Policy P; + + dim3 blk(P::Nthreads); + auto nblks = raft::ceildiv(m, P::Nthreads); + constexpr auto maxVal = std::numeric_limits::max(); + typedef cub::KeyValuePair KVPair; + + // Accumulation operation lambda + auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { + acc += x * y; + }; + + CUDA_CHECK(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); + if (initOutBuffer) { + initKernel + <<>>(min, m, maxVal, redOp); + CUDA_CHECK(cudaGetLastError()); + } + + auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; + + constexpr size_t shmemSize = + P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + if (sqrt) { + auto fusedL2NNSqrt = + fusedL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); + + fusedL2NNSqrt<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, + core_lambda, fin_op); + } else { + auto fusedL2NN = + fusedL2NNkernel; + dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); + fusedL2NN<<>>(min, x, y, xn, yn, m, n, k, + maxVal, workspace, redOp, + pairRedOp, core_lambda, fin_op); + } + + CUDA_CHECK(cudaGetLastError()); +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/hamming.cuh b/cpp/include/raft/distance/detail/hamming.cuh index adacfa4895..2134b8d05b 100644 --- a/cpp/include/raft/distance/detail/hamming.cuh +++ b/cpp/include/raft/distance/detail/hamming.cuh @@ -19,6 +19,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief the Hamming distance matrix using the unexpanded form: @@ -85,8 +86,8 @@ static void hammingUnexpandedImpl(const DataT *x, const DataT *y, IdxT m, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - hammingUnexpandedRowMajor); + dim3 grid = detail::launchConfigGenerator( + m, n, KPolicy::SmemSize, hammingUnexpandedRowMajor); hammingUnexpandedRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -96,8 +97,8 @@ static void hammingUnexpandedImpl(const DataT *x, const DataT *y, IdxT m, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - hammingUnexpandedColMajor); + dim3 grid = detail::launchConfigGenerator( + m, n, KPolicy::SmemSize, hammingUnexpandedColMajor); hammingUnexpandedColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -171,5 +172,7 @@ void hammingUnexpandedImpl(int m, int n, int k, const InType *pA, pDcast, fin_op, stream); } } + +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh index 6baf1ff5f1..a99f6db1e9 100644 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -20,6 +20,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief the Hellinger distance matrix using the expanded form: @@ -105,8 +106,8 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - hellingerRowMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + hellingerRowMajor); hellingerRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -116,8 +117,8 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - hellingerColMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + hellingerColMajor); hellingerColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -200,5 +201,6 @@ void hellingerImpl(int m, int n, int k, const InType *pA, const InType *pB, n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/detail/jensen_shannon.cuh b/cpp/include/raft/distance/detail/jensen_shannon.cuh index f4fe7f2d38..dbaf98562e 100644 --- a/cpp/include/raft/distance/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/jensen_shannon.cuh @@ -19,6 +19,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief the Jensen Shannon distance matrix: @@ -92,8 +93,8 @@ static void jensenShannonImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - jensenShannonRowMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + jensenShannonRowMajor); jensenShannonRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -103,8 +104,8 @@ static void jensenShannonImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - jensenShannonColMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + jensenShannonColMajor); jensenShannonColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -177,5 +178,6 @@ void jensenShannonImpl(int m, int n, int k, const InType *pA, const InType *pB, stream); } } +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/detail/kl_divergence.cuh b/cpp/include/raft/distance/detail/kl_divergence.cuh index b22a3e3aaa..da6abbf3f3 100644 --- a/cpp/include/raft/distance/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/kl_divergence.cuh @@ -19,6 +19,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief the KL Divergence distance matrix: @@ -126,8 +127,8 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, if (x != y) { raft::linalg::unaryOp( (DataT *)y, y, n * k, unaryOp_lambda, stream); - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - klDivergenceRowMajor); + dim3 grid = detail::launchConfigGenerator( + m, n, KPolicy::SmemSize, klDivergenceRowMajor); klDivergenceRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -135,8 +136,8 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, raft::linalg::unaryOp( (DataT *)y, y, n * k, unaryOp_lambda_reverse, stream); } else { - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - klDivergenceRowMajorXequalY); + dim3 grid = detail::launchConfigGenerator( + m, n, KPolicy::SmemSize, klDivergenceRowMajorXequalY); klDivergenceRowMajorXequalY<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda_x_equal_y, epilog_lambda, fin_op); @@ -153,8 +154,8 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, if (x != y) { raft::linalg::unaryOp( (DataT *)x, x, m * k, unaryOp_lambda, stream); - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - klDivergenceColMajor); + dim3 grid = detail::launchConfigGenerator( + m, n, KPolicy::SmemSize, klDivergenceColMajor); klDivergenceColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -162,8 +163,8 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, raft::linalg::unaryOp( (DataT *)x, x, m * k, unaryOp_lambda_reverse, stream); } else { - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - klDivergenceColMajorXequalY); + dim3 grid = detail::launchConfigGenerator( + m, n, KPolicy::SmemSize, klDivergenceColMajorXequalY); klDivergenceColMajorXequalY<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda_x_equal_y, epilog_lambda, fin_op); @@ -238,5 +239,6 @@ void klDivergenceImpl(int m, int n, int k, const InType *pA, const InType *pB, false>(n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 5e5fdf6feb..0e1f1bfa71 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -19,6 +19,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief the L1 distance matrix calculation implementer @@ -72,8 +73,8 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = - launchConfigGenerator(m, n, KPolicy::SmemSize, l1RowMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + l1RowMajor); l1RowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -83,8 +84,8 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = - launchConfigGenerator(m, n, KPolicy::SmemSize, l1ColMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + l1ColMajor); l1ColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -151,5 +152,6 @@ void l1Impl(int m, int n, int k, const InType *pA, const InType *pB, n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/detail/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh index d45eac6f63..2fab2ae6de 100644 --- a/cpp/include/raft/distance/detail/minkowski.cuh +++ b/cpp/include/raft/distance/detail/minkowski.cuh @@ -19,6 +19,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief the unexpanded Minkowski distance matrix calculation @@ -83,8 +84,8 @@ void minkowskiUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - minkowskiUnExpRowMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + minkowskiUnExpRowMajor); minkowskiUnExpRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -95,8 +96,8 @@ void minkowskiUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - minkowskiUnExpColMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + minkowskiUnExpColMajor); minkowskiUnExpColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -167,6 +168,6 @@ void minkowskiImpl(Index_ m, Index_ n, Index_ k, const InType *pA, n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream, metric_arg); } } - +}; // end namespace detail }; // end namespace distance }; // end namespace raft diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index e3ff9a7081..a98bda1541 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -24,6 +24,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief Device class for L1, L2 and cosine distance metrics. @@ -69,11 +70,11 @@ template -__global__ __launch_bounds__( - Policy::Nthreads, - 2) void pairwiseDistanceMatKernel(const DataT* x, const DataT* y, - const DataT* _xn, const DataT* _yn, IdxT m, - IdxT n, IdxT k, IdxT lda, IdxT ldb, - IdxT ldd, OutT* dOutput, CoreLambda core_op, - EpilogueLambda epilog_op, - FinalLambda fin_op) { +__global__ __launch_bounds__(Policy::Nthreads, 2) + + void pairwiseDistanceMatKernel(const DataT *x, const DataT *y, + const DataT *_xn, const DataT *_yn, IdxT m, + IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd, + OutT *dOutput, CoreLambda core_op, + EpilogueLambda epilog_op, FinalLambda fin_op) { extern __shared__ char smem[]; auto rowEpilog = [] __device__(IdxT starty) { return; }; @@ -337,5 +337,6 @@ dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) { return grid; } +}; // namespace detail }; // namespace distance }; // namespace raft diff --git a/cpp/include/raft/distance/detail/russell_rao.cuh b/cpp/include/raft/distance/detail/russell_rao.cuh index 663990bf2d..cc8d0a855a 100644 --- a/cpp/include/raft/distance/detail/russell_rao.cuh +++ b/cpp/include/raft/distance/detail/russell_rao.cuh @@ -19,6 +19,7 @@ namespace raft { namespace distance { +namespace detail { /** * @brief the Russell Rao distance matrix: @@ -84,8 +85,8 @@ static void russellRaoImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - russellRaoRowMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + russellRaoRowMajor); russellRaoRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -95,8 +96,8 @@ static void russellRaoImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - russellRaoColMajor); + dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + russellRaoColMajor); russellRaoColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -167,5 +168,6 @@ void russellRaoImpl(int m, int n, int k, const InType *pA, const InType *pB, n, m, k, lda, ldb, ldd, pB, pA, pDcast, fin_op, stream); } } +} // namespace detail } // namespace distance } // namespace raft diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 52041a59ad..a0c3544c92 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -16,294 +16,38 @@ #pragma once -#include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include #include namespace raft { namespace distance { -namespace { -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType metric_arg = 2.0f) {} -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType) { - raft::distance::euclideanAlgo1(m, n, k, x, y, dist, false, - (AccType *)workspace, worksize, - fin_op, stream, isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType) { - raft::distance::euclideanAlgo1(m, n, k, x, y, dist, true, - (AccType *)workspace, worksize, - fin_op, stream, isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType) { - raft::distance::cosineAlgo1( - m, n, k, x, y, dist, (AccType *)workspace, worksize, fin_op, stream, - isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor, InType) { - raft::distance::euclideanAlgo2(m, n, k, x, y, dist, false, fin_op, - stream, isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor, InType) { - raft::distance::euclideanAlgo2(m, n, k, x, y, dist, true, fin_op, - stream, isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor, InType) { - raft::distance::l1Impl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor, InType) { - raft::distance::chebyshevImpl(m, n, k, x, y, dist, fin_op, stream, - isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor, InType) { - raft::distance::hellingerImpl(m, n, k, x, y, dist, fin_op, stream, - isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor, InType metric_arg) { - raft::distance::minkowskiImpl(m, n, k, x, y, dist, fin_op, stream, - isRowMajor, metric_arg); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor, InType) { - raft::distance::canberraImpl( - m, n, k, x, y, dist, fin_op, stream, isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor, InType) { - raft::distance::hammingUnexpandedImpl(m, n, k, x, y, dist, fin_op, - stream, isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor, InType) { - raft::distance::jensenShannonImpl(m, n, k, x, y, dist, fin_op, - stream, isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor, InType) { - raft::distance::russellRaoImpl(m, n, k, x, y, dist, fin_op, stream, - isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *, size_t, FinalLambda fin_op, cudaStream_t stream, - bool isRowMajor, InType) { - raft::distance::klDivergenceImpl(m, n, k, x, y, dist, fin_op, - stream, isRowMajor); - } -}; - -template -struct DistanceImpl { - void run(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, - Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, - cudaStream_t stream, bool isRowMajor, InType) { - raft::distance::correlationImpl(m, n, k, x, y, dist, - (AccType *)workspace, worksize, - fin_op, stream, isRowMajor); - } -}; - -} // anonymous namespace - /** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points - * @param y second set of points - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * - * @note If the specifed distanceType doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(const InType *x, const InType *y, Index_ m, Index_ n, - Index_ k) { - size_t worksize = 0; - constexpr bool is_allocated = - (distanceType <= raft::distance::DistanceType::CosineExpanded) || - (distanceType == raft::distance::DistanceType::CorrelationExpanded); - constexpr int numOfBuffers = - (distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 2 : 1; - - if (is_allocated) { - worksize += numOfBuffers * m * sizeof(AccType); - if (x != y) worksize += numOfBuffers * n * sizeof(AccType); - } - - return worksize; -} - -/** - * @brief Evaluate pairwise distances with the user epilogue lamba allowed - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param fin_op the final gemm epilogue lambda - * @param stream cuda stream - * @param isRowMajor whether the matrices are row-major or col-major - * - * @note fin_op: This is a device lambda which is supposed to operate upon the - * input which is AccType and returns the output in OutType. It's signature is - * as follows:

OutType fin_op(AccType in, int g_idx);
. If one needs - * any other parameters, feel free to pass them via closure. - */ +* @brief Evaluate pairwise distances with the user epilogue lamba allowed +* @tparam DistanceType which distance to evaluate +* @tparam InType input argument type +* @tparam AccType accumulation type +* @tparam OutType output type +* @tparam FinalLambda user-defined epilogue lamba +* @tparam Index_ Index type +* @param x first set of points +* @param y second set of points +* @param dist output distance matrix +* @param m number of points in x +* @param n number of points in y +* @param k dimensionality +* @param workspace temporary workspace needed for computations +* @param worksize number of bytes of the workspace +* @param fin_op the final gemm epilogue lambda +* @param stream cuda stream +* @param isRowMajor whether the matrices are row-major or col-major +* +* @note fin_op: This is a device lambda which is supposed to operate upon the +* input which is AccType and returns the output in OutType. It's signature is +* as follows:
OutType fin_op(AccType in, int g_idx);
. If one needs +* any other parameters, feel free to pass them via closure. +*/ template @@ -311,47 +55,64 @@ void distance(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, FinalLambda fin_op, cudaStream_t stream, bool isRowMajor = true, InType metric_arg = 2.0f) { - DistanceImpl - distImpl; - distImpl.run(x, y, dist, m, n, k, workspace, worksize, fin_op, stream, - isRowMajor, metric_arg); - CUDA_CHECK(cudaPeekAtLastError()); + detail::distance( + x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, + metric_arg); } /** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param stream cuda stream - * @param isRowMajor whether the matrices are row-major or col-major - * - * @note if workspace is passed as nullptr, this will return in - * worksize, the number of bytes of workspace required - */ +* @brief Evaluate pairwise distances for the simple use case +* @tparam DistanceType which distance to evaluate +* @tparam InType input argument type +* @tparam AccType accumulation type +* @tparam OutType output type +* @tparam Index_ Index type +* @param x first set of points +* @param y second set of points +* @param dist output distance matrix +* @param m number of points in x +* @param n number of points in y +* @param k dimensionality +* @param workspace temporary workspace needed for computations +* @param worksize number of bytes of the workspace +* @param stream cuda stream +* @param isRowMajor whether the matrices are row-major or col-major +* +* @note if workspace is passed as nullptr, this will return in +* worksize, the number of bytes of workspace required +*/ template void distance(const InType *x, const InType *y, OutType *dist, Index_ m, Index_ n, Index_ k, void *workspace, size_t worksize, cudaStream_t stream, bool isRowMajor = true, InType metric_arg = 2.0f) { - auto default_fin_op = [] __device__(AccType d_val, Index_ g_d_idx) { - return d_val; - }; - distance(x, y, dist, m, n, k, workspace, worksize, default_fin_op, - stream, isRowMajor, metric_arg); - CUDA_CHECK(cudaPeekAtLastError()); + detail::distance( + x, y, dist, m, n, k, workspace, worksize, stream, isRowMajor, metric_arg); +} + +/** +* @brief Return the exact workspace size to compute the distance +* @tparam DistanceType which distance to evaluate +* @tparam InType input argument type +* @tparam AccType accumulation type +* @tparam OutType output type +* @tparam Index_ Index type +* @param x first set of points +* @param y second set of points +* @param m number of points in x +* @param n number of points in y +* @param k dimensionality +* +* @note If the specifed distanceType doesn't need the workspace at all, it +* returns 0. +*/ +template +size_t getWorkspaceSize(const InType *x, const InType *y, Index_ m, Index_ n, + Index_ k) { + return detail::getWorkspaceSize(x, y, m, n, k); } /** @@ -373,20 +134,6 @@ void distance(const InType *x, const InType *y, OutType *dist, Index_ m, * @param stream cuda stream * @param isRowMajor whether the matrices are row-major or col-major */ -template -void pairwise_distance_impl(const Type *x, const Type *y, Type *dist, Index_ m, - Index_ n, Index_ k, - rmm::device_uvector &workspace, - cudaStream_t stream, bool isRowMajor, - Type metric_arg = 2.0f) { - auto worksize = - getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); - distance(x, y, dist, m, n, k, - workspace.data(), worksize, - stream, isRowMajor, metric_arg); -} - template void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, Index_ n, Index_ k, rmm::device_uvector &workspace, @@ -394,76 +141,78 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, bool isRowMajor = true, Type metric_arg = 2.0f) { switch (metric) { case raft::distance::DistanceType::L2Expanded: - pairwise_distance_impl( + detail::pairwise_distance_impl( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::L2SqrtExpanded: - pairwise_distance_impl( + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::L2SqrtExpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::CosineExpanded: - pairwise_distance_impl( + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::CosineExpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::L1: - pairwise_distance_impl( + detail::pairwise_distance_impl( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::L2Unexpanded: - pairwise_distance_impl( + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::L2Unexpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::L2SqrtUnexpanded: - pairwise_distance_impl( + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::L2SqrtUnexpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::Linf: - pairwise_distance_impl( + detail::pairwise_distance_impl( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::HellingerExpanded: - pairwise_distance_impl( + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::HellingerExpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::LpUnexpanded: - pairwise_distance_impl( + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::LpUnexpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor, metric_arg); break; case raft::distance::DistanceType::Canberra: - pairwise_distance_impl( + detail::pairwise_distance_impl( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::HammingUnexpanded: - pairwise_distance_impl( + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::HammingUnexpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::JensenShannon: - pairwise_distance_impl( + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::JensenShannon>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::RusselRaoExpanded: - pairwise_distance_impl( + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::RusselRaoExpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::KLDivergence: - pairwise_distance_impl( + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::KLDivergence>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; case raft::distance::DistanceType::CorrelationExpanded: - pairwise_distance_impl( + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::CorrelationExpanded>( x, y, dist, m, n, k, workspace, stream, isRowMajor); break; default: diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index cafb9a91ba..a047233414 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -20,249 +20,20 @@ #include #include #include -#include -#include +#include namespace raft { namespace distance { -#if (ENABLE_MEMCPY_ASYNC == 1) -#include -using namespace nvcuda::experimental; -#endif - template -struct KVPMinReduce { - typedef cub::KeyValuePair KVP; - - DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { - return b.value < a.value ? b : a; - } - -}; // KVPMinReduce +using KVPMinReduce = detail::KVPMinReduceImpl; template -struct MinAndDistanceReduceOp { - typedef typename cub::KeyValuePair KVP; - DI void operator()(LabelT rid, KVP* out, const KVP& other) { - if (other.value < out->value) { - out->key = other.key; - out->value = other.value; - } - } - - DI void init(KVP* out, DataT maxVal) { - out->key = -1; - out->value = maxVal; - } -}; +using MinAndDistanceReduceOp = + detail::MinAndDistanceReduceOpImpl; template -struct MinReduceOp { - typedef typename cub::KeyValuePair KVP; - DI void operator()(LabelT rid, DataT* out, const KVP& other) { - if (other.value < *out) { - *out = other.value; - } - } - - DI void init(DataT* out, DataT maxVal) { *out = maxVal; } -}; - -template -__global__ void initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { - auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; - if (tid < m) { - redOp.init(min + tid, maxVal); - } -} - -// TODO: specialize this function for MinAndDistanceReduceOp -// with atomicCAS of 64 bit which will eliminate mutex and shfls -template -DI void updateReducedVal(int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, - IdxT m, IdxT gridStrideY) { - const auto lid = threadIdx.x % raft::WarpSize; - const auto accrowid = threadIdx.x / P::AccThCols; - - // for now have first lane from each warp update a unique output row. This - // will resolve hang issues with pre-Volta architectures -#pragma unroll - for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { - if (lid == 0) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rid = gridStrideY + accrowid + j + i * P::AccThRows; - if (rid < m) { - auto value = val[i]; - while (atomicCAS(mutex + rid, 0, 1) == 1) - ; - __threadfence(); - red_op(rid, min + rid, value); - __threadfence(); - atomicCAS(mutex + rid, 1, 0); - } - } - } - if (j < (raft::WarpSize / P::AccThCols) - 1) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto tmpkey = raft::shfl(val[i].key, (j + 1) * P::AccThCols); - auto tmpvalue = raft::shfl(val[i].value, (j + 1) * P::AccThCols); - val[i] = {tmpkey, tmpvalue}; - } - } - } -} - -template -__global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel( - OutT* min, const DataT* x, const DataT* y, const DataT* xn, const DataT* yn, - IdxT m, IdxT n, IdxT k, DataT maxVal, int* mutex, ReduceOpT redOp, - KVPReduceOpT pairRedOp, CoreLambda core_op, FinalLambda fin_op) { - extern __shared__ char smem[]; - - typedef cub::KeyValuePair KVPair; - KVPair val[P::AccRowsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; - } - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( - DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - DataT * regxn, DataT * regyn, IdxT gridStrideX, - IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - if (Sqrt) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); - } - } - } - - // intra thread reduce - const auto acccolid = threadIdx.x % P::AccThCols; - const auto accrowid = threadIdx.x / P::AccThCols; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; - KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < n) { - val[i] = - pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } - } - }; - - auto rowEpilog_lambda = [m, mutex, min, pairRedOp, redOp, &val, - maxVal] __device__(IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - ReduceOpT red_op(redOp); - - const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); - - // reduce -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - auto tmpkey = raft::shfl(val[i].key, lid + j); - auto tmpvalue = raft::shfl(val[i].value, lid + j); - KVPair tmp = {tmpkey, tmpvalue}; - val[i] = - pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } - - updateReducedVal(mutex, min, val, red_op, - m, gridStrideY); - - // reset the val array. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; - } - }; - - IdxT lda = k, ldb = k, ldd = n; - PairwiseDistances - obj(x, y, m, n, k, lda, ldb, ldd, xn, yn, nullptr, smem, core_op, - epilog_lambda, fin_op, rowEpilog_lambda); - obj.run(); -} - -template -void fusedL2NNImpl(OutT* min, const DataT* x, const DataT* y, const DataT* xn, - const DataT* yn, IdxT m, IdxT n, IdxT k, int* workspace, - ReduceOpT redOp, KVPReduceOpT pairRedOp, bool sqrt, - bool initOutBuffer, cudaStream_t stream) { - typedef typename linalg::Policy4x4::Policy P; - - dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); - constexpr auto maxVal = std::numeric_limits::max(); - typedef cub::KeyValuePair KVPair; - - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { - acc += x * y; - }; - - CUDA_CHECK(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); - if (initOutBuffer) { - initKernel - <<>>(min, m, maxVal, redOp); - CUDA_CHECK(cudaGetLastError()); - } - - auto fin_op = [] __device__(DataT d_val, int g_d_idx) { return d_val; }; - - constexpr size_t shmemSize = - P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - if (sqrt) { - auto fusedL2NNSqrt = - fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); - - fusedL2NNSqrt<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, - core_lambda, fin_op); - } else { - auto fusedL2NN = - fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>(min, x, y, xn, yn, m, n, k, - maxVal, workspace, redOp, - pairRedOp, core_lambda, fin_op); - } - - CUDA_CHECK(cudaGetLastError()); -} +using MinReduceOp = detail::MinReduceOpImpl; /** * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. @@ -307,15 +78,15 @@ void fusedL2NN(OutT* min, const DataT* x, const DataT* y, const DataT* xn, bool initOutBuffer, cudaStream_t stream) { size_t bytes = sizeof(DataT) * k; if (16 % sizeof(DataT) == 0 && bytes % 16 == 0) { - fusedL2NNImpl( + detail::fusedL2NNImpl( min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0) { - fusedL2NNImpl( + detail::fusedL2NNImpl( min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); } else { - fusedL2NNImpl( + detail::fusedL2NNImpl( min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); } diff --git a/cpp/include/raft/sparse/distance/detail/bin_distance.cuh b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh index 5b59ce89ed..e6dd1331ae 100644 --- a/cpp/include/raft/sparse/distance/detail/bin_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/bin_distance.cuh @@ -32,7 +32,7 @@ namespace raft { namespace sparse { namespace distance { - +namespace detail { // @TODO: Move this into sparse prims (coo_norm) template __global__ void compute_binary_row_norm_kernel( @@ -193,6 +193,7 @@ class dice_expanded_distances_t : public distances_t { ip_distances_t ip_dists; }; +} // END namespace detail }; // END namespace distance }; // END namespace sparse }; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh index a9ff55a291..83844b8c54 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv.cuh @@ -37,6 +37,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { template class coo_spmv_strategy { @@ -84,6 +85,7 @@ class coo_spmv_strategy { const distances_config_t &config; }; +} // namespace detail } // namespace distance } // namespace sparse } // namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh index 2e774e3a02..0ab7b65ac2 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/coo_mask_row_iterators.cuh @@ -24,6 +24,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { template class mask_row_it { @@ -186,6 +187,7 @@ class chunked_mask_row_it : public mask_row_it { } }; +} // namespace detail } // namespace distance } // namespace sparse } // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh index c463654a3b..79a5f154d0 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/dense_smem_strategy.cuh @@ -21,6 +21,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { template class dense_smem_strategy : public coo_spmv_strategy { @@ -92,6 +93,7 @@ class dense_smem_strategy : public coo_spmv_strategy { } }; +} // namespace detail } // namespace distance } // namespace sparse } // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh index a95c6ff85b..5ba2d5c102 100644 --- a/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh +++ b/cpp/include/raft/sparse/distance/detail/coo_spmv_strategies/hash_strategy.cuh @@ -33,6 +33,7 @@ CUCO_DECLARE_BITWISE_COMPARABLE(double); namespace raft { namespace sparse { namespace distance { +namespace detail { template class hash_strategy : public coo_spmv_strategy { @@ -217,6 +218,7 @@ class hash_strategy : public coo_spmv_strategy { int map_size; }; +} // namespace detail } // namespace distance } // namespace sparse } // namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/ip_distance.cuh b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh index def59d683c..2cd7b670d8 100644 --- a/cpp/include/raft/sparse/distance/detail/ip_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/ip_distance.cuh @@ -36,14 +36,15 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { template class ip_distances_t : public distances_t { public: /** - * Computes simple sparse inner product distances as sum(x_y * y_k) - * @param[in] config specifies inputs, outputs, and sizes - */ + * Computes simple sparse inner product distances as sum(x_y * y_k) + * @param[in] config specifies inputs, outputs, and sizes + */ ip_distances_t(const distances_config_t &config) : config_(&config), coo_rows_b(config.b_nnz, config.handle.get_stream()) { raft::sparse::convert::csr_to_coo(config_->b_indptr, config_->b_nrows, @@ -52,13 +53,13 @@ class ip_distances_t : public distances_t { } /** - * Performs pairwise distance computation and computes output distances - * @param out_distances dense output matrix (size a_nrows * b_nrows) - */ + * Performs pairwise distance computation and computes output distances + * @param out_distances dense output matrix (size a_nrows * b_nrows) + */ void compute(value_t *out_distances) { /** - * Compute pairwise distances and return dense matrix in row-major format - */ + * Compute pairwise distances and return dense matrix in row-major format + */ balanced_coo_pairwise_generalized_spmv( out_distances, *config_, coo_rows_b.data(), Product(), Sum(), AtomicAdd()); @@ -72,6 +73,8 @@ class ip_distances_t : public distances_t { const distances_config_t *config_; rmm::device_uvector coo_rows_b; }; + +}; // END namespace detail }; // END namespace distance }; // END namespace sparse }; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh index a32676ce4c..e7ac78b80a 100644 --- a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh @@ -34,6 +34,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { // @TODO: Move this into sparse prims (coo_norm) template @@ -417,6 +418,7 @@ class russelrao_expanded_distances_t : public distances_t { ip_distances_t ip_dists; }; +}; // END namespace detail }; // END namespace distance }; // END namespace sparse }; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh index 41a08866a2..c11369375b 100644 --- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -37,6 +37,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { template @@ -272,6 +273,7 @@ class kl_divergence_unexpanded_distances_t : public distances_t { const distances_config_t *config_; }; +}; // END namespace detail }; // END namespace distance }; // END namespace sparse }; // END namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/operators.cuh b/cpp/include/raft/sparse/distance/detail/operators.cuh index 89acda8b1a..9f206095bf 100644 --- a/cpp/include/raft/sparse/distance/detail/operators.cuh +++ b/cpp/include/raft/sparse/distance/detail/operators.cuh @@ -21,6 +21,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { struct Sum { template @@ -90,6 +91,7 @@ struct AbsDiff { return fabs(a - b); } }; +} // namespace detail } // namespace distance } // namespace sparse }; // namespace raft diff --git a/cpp/include/raft/sparse/distance/detail/utils.cuh b/cpp/include/raft/sparse/distance/detail/utils.cuh index 3bee1bc87d..abfb7d24ea 100644 --- a/cpp/include/raft/sparse/distance/detail/utils.cuh +++ b/cpp/include/raft/sparse/distance/detail/utils.cuh @@ -24,6 +24,7 @@ namespace raft { namespace sparse { namespace distance { +namespace detail { /** * Computes the maximum number of columns that can be stored @@ -39,6 +40,7 @@ inline int max_cols_per_block() { sizeof(value_t); } +} // namespace detail } // namespace distance } // namespace sparse } // namespace raft diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index 5b52a0dfbc..24b10420f3 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -78,70 +78,77 @@ void pairwiseDistance(value_t *out, raft::distance::DistanceType metric, float metric_arg) { switch (metric) { case raft::distance::DistanceType::L2Expanded: - l2_expanded_distances_t(input_config).compute(out); + detail::l2_expanded_distances_t(input_config) + .compute(out); break; case raft::distance::DistanceType::L2SqrtExpanded: - l2_sqrt_expanded_distances_t(input_config) + detail::l2_sqrt_expanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::InnerProduct: - ip_distances_t(input_config).compute(out); + detail::ip_distances_t(input_config).compute(out); break; case raft::distance::DistanceType::L2Unexpanded: - l2_unexpanded_distances_t(input_config).compute(out); + detail::l2_unexpanded_distances_t(input_config) + .compute(out); break; case raft::distance::DistanceType::L2SqrtUnexpanded: - l2_sqrt_unexpanded_distances_t(input_config) + detail::l2_sqrt_unexpanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::L1: - l1_unexpanded_distances_t(input_config).compute(out); + detail::l1_unexpanded_distances_t(input_config) + .compute(out); break; case raft::distance::DistanceType::LpUnexpanded: - lp_unexpanded_distances_t(input_config, metric_arg) + detail::lp_unexpanded_distances_t(input_config, + metric_arg) .compute(out); break; case raft::distance::DistanceType::Linf: - linf_unexpanded_distances_t(input_config) + detail::linf_unexpanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::Canberra: - canberra_unexpanded_distances_t(input_config) + detail::canberra_unexpanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::JaccardExpanded: - jaccard_expanded_distances_t(input_config) + detail::jaccard_expanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::CosineExpanded: - cosine_expanded_distances_t(input_config) + detail::cosine_expanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::HellingerExpanded: - hellinger_expanded_distances_t(input_config) + detail::hellinger_expanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::DiceExpanded: - dice_expanded_distances_t(input_config).compute(out); + detail::dice_expanded_distances_t(input_config) + .compute(out); break; case raft::distance::DistanceType::CorrelationExpanded: - correlation_expanded_distances_t(input_config) + detail::correlation_expanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::RusselRaoExpanded: - russelrao_expanded_distances_t(input_config) + detail::russelrao_expanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::HammingUnexpanded: - hamming_unexpanded_distances_t(input_config) + detail::hamming_unexpanded_distances_t(input_config) .compute(out); break; case raft::distance::DistanceType::JensenShannon: - jensen_shannon_unexpanded_distances_t(input_config) + detail::jensen_shannon_unexpanded_distances_t( + input_config) .compute(out); break; case raft::distance::DistanceType::KLDivergence: - kl_divergence_unexpanded_distances_t(input_config) + detail::kl_divergence_unexpanded_distances_t( + input_config) .compute(out); break; diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 6f2846847e..7f8523a587 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -17,6 +17,9 @@ #include #include #include + +// TODO: Need to hide the PairwiseDistance class impl and expose to public API +#include #include "processing.hpp" namespace raft { @@ -458,10 +461,10 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN( } }; - raft::distance::PairwiseDistances + raft::distance::detail::PairwiseDistances< + useNorms, DataT, AccT, OutT, IdxT, Policy, CoreLambda, + decltype(epilog_lambda), FinalLambda, decltype(rowEpilog_lambda), + isRowMajor, false> obj(x, y, m, n, k, lda, ldb, ldd, _xn, _yn, nullptr, smem, core_op, epilog_lambda, fin_op, rowEpilog_lambda); obj.run(); @@ -509,7 +512,7 @@ void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - dim3 grid = raft::distance::launchConfigGenerator( + dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, KPolicy::SmemSize, fusedL2kNNRowMajor); if (grid.x > 1) { const auto numMutexes = raft::ceildiv(m, KPolicy::Mblk); diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index cfea4ee2d9..241cab2260 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -81,7 +82,7 @@ void naive(cub::KeyValuePair *min, DataT *x, DataT *y, int m, int n, CUDA_CHECK(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); auto blks = raft::ceildiv(m, 256); MinAndDistanceReduceOp op; - initKernel, int> + detail::initKernel, int> <<>>(min, m, std::numeric_limits::max(), op); CUDA_CHECK(cudaGetLastError()); naiveKernel, 16> diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index 563dcf6f15..b6f6269457 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -25,8 +25,8 @@ #include #include -#include -#include +#include +#include #include "../test_utils.h" @@ -54,15 +54,16 @@ struct InputConfiguration { float metric_arg = 0.0; }; -using dense_smem_strategy_t = dense_smem_strategy; -using hash_strategy_t = hash_strategy; +using dense_smem_strategy_t = detail::dense_smem_strategy; +using hash_strategy_t = detail::hash_strategy; template struct SparseDistanceCOOSPMVInputs { InputConfiguration input_configuration; float capacity_threshold = 0.5; - int map_size = hash_strategy::get_map_size(); + int map_size = + detail::hash_strategy::get_map_size(); }; template @@ -103,7 +104,7 @@ class SparseDistanceCOOSPMVTest dist_config.handle.get_stream()); strategy_t selected_strategy = make_strategy(); - balanced_coo_pairwise_generalized_spmv( + detail::balanced_coo_pairwise_generalized_spmv( out_dists, dist_config, coo_rows.data(), reduce_func, accum_func, write_func, selected_strategy); @@ -112,7 +113,7 @@ class SparseDistanceCOOSPMVTest dist_config.a_indptr, dist_config.a_nrows, coo_rows.data(), dist_config.a_nnz, dist_config.handle.get_stream()); - balanced_coo_pairwise_generalized_spmv_rev( + detail::balanced_coo_pairwise_generalized_spmv_rev( out_dists, dist_config, coo_rows.data(), reduce_func, accum_func, write_func, selected_strategy); } @@ -121,27 +122,28 @@ class SparseDistanceCOOSPMVTest void run_spmv() { switch (params.input_configuration.metric) { case raft::distance::DistanceType::InnerProduct: - compute_dist(Product(), Sum(), AtomicAdd(), true); + compute_dist(detail::Product(), detail::Sum(), detail::AtomicAdd(), + true); break; case raft::distance::DistanceType::L2Unexpanded: - compute_dist(SqDiff(), Sum(), AtomicAdd()); + compute_dist(detail::SqDiff(), detail::Sum(), detail::AtomicAdd()); break; case raft::distance::DistanceType::Canberra: compute_dist( [] __device__(value_t a, value_t b) { return fabsf(a - b) / (fabsf(a) + fabsf(b)); }, - Sum(), AtomicAdd()); + detail::Sum(), detail::AtomicAdd()); break; case raft::distance::DistanceType::L1: - compute_dist(AbsDiff(), Sum(), AtomicAdd()); + compute_dist(detail::AbsDiff(), detail::Sum(), detail::AtomicAdd()); break; case raft::distance::DistanceType::Linf: - compute_dist(AbsDiff(), Max(), AtomicMax()); + compute_dist(detail::AbsDiff(), detail::Max(), detail::AtomicMax()); break; case raft::distance::DistanceType::LpUnexpanded: { - compute_dist(PDiff(params.input_configuration.metric_arg), Sum(), - AtomicAdd()); + compute_dist(detail::PDiff(params.input_configuration.metric_arg), + detail::Sum(), detail::AtomicAdd()); float p = 1.0f / params.input_configuration.metric_arg; raft::linalg::unaryOp( out_dists, out_dists, dist_config.a_nrows * dist_config.b_nrows, From fe892e8f37e4c8ba60a0be9a5f2cb7a4d0292c8c Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 1 Oct 2021 11:01:16 -0400 Subject: [PATCH 07/22] Removing unused include --- cpp/include/raft/spatial/knn/detail/ball_cover.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 46a97400e2..909e28708e 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -34,7 +34,6 @@ #include #include #include -#include #include #include From 56e12574c056d6ac789f451d0c0e7576d62d9e87 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 4 Oct 2021 12:46:24 -0400 Subject: [PATCH 08/22] Adding hpp --- cpp/include/raft/distance/distance.hpp | 150 +++++++++++++++++++++++++ 1 file changed, 150 insertions(+) create mode 100644 cpp/include/raft/distance/distance.hpp diff --git a/cpp/include/raft/distance/distance.hpp b/cpp/include/raft/distance/distance.hpp new file mode 100644 index 0000000000..986f6b4ede --- /dev/null +++ b/cpp/include/raft/distance/distance.hpp @@ -0,0 +1,150 @@ +/* + * Copyright (c) 2018-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft { +namespace distance { + +/** +* @brief Evaluate pairwise distances with the user epilogue lamba allowed +* @tparam DistanceType which distance to evaluate +* @tparam InType input argument type +* @tparam AccType accumulation type +* @tparam OutType output type +* @tparam FinalLambda user-defined epilogue lamba +* @tparam Index_ Index type +* @param x first set of points +* @param y second set of points +* @param dist output distance matrix +* @param m number of points in x +* @param n number of points in y +* @param k dimensionality +* @param workspace temporary workspace needed for computations +* @param worksize number of bytes of the workspace +* @param fin_op the final gemm epilogue lambda +* @param stream cuda stream +* @param isRowMajor whether the matrices are row-major or col-major +* +* @note fin_op: This is a device lambda which is supposed to operate upon the +* input which is AccType and returns the output in OutType. It's signature is +* as follows:

OutType fin_op(AccType in, int g_idx);
. If one needs +* any other parameters, feel free to pass them via closure. +*/ +template +void distance(const InType *x, const InType *y, OutType *dist, Index_ m, + Index_ n, Index_ k, void *workspace, size_t worksize, + FinalLambda fin_op, cudaStream_t stream, bool isRowMajor = true, + InType metric_arg = 2.0f) { + detail::distance( + x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, + metric_arg); +} + +/** +* @brief Evaluate pairwise distances for the simple use case +* @tparam DistanceType which distance to evaluate +* @tparam InType input argument type +* @tparam AccType accumulation type +* @tparam OutType output type +* @tparam Index_ Index type +* @param x first set of points +* @param y second set of points +* @param dist output distance matrix +* @param m number of points in x +* @param n number of points in y +* @param k dimensionality +* @param workspace temporary workspace needed for computations +* @param worksize number of bytes of the workspace +* @param stream cuda stream +* @param isRowMajor whether the matrices are row-major or col-major +* +* @note if workspace is passed as nullptr, this will return in +* worksize, the number of bytes of workspace required +*/ +template +void distance(const InType *x, const InType *y, OutType *dist, Index_ m, + Index_ n, Index_ k, void *workspace, size_t worksize, + cudaStream_t stream, bool isRowMajor = true, + InType metric_arg = 2.0f) { + detail::distance( + x, y, dist, m, n, k, workspace, worksize, stream, isRowMajor, metric_arg); +} + +/** +* @brief Return the exact workspace size to compute the distance +* @tparam DistanceType which distance to evaluate +* @tparam InType input argument type +* @tparam AccType accumulation type +* @tparam OutType output type +* @tparam Index_ Index type +* @param x first set of points +* @param y second set of points +* @param m number of points in x +* @param n number of points in y +* @param k dimensionality +* +* @note If the specifed distanceType doesn't need the workspace at all, it +* returns 0. +*/ +template +size_t getWorkspaceSize(const InType *x, const InType *y, Index_ m, Index_ n, + Index_ k) { + return detail::getWorkspaceSize(x, y, m, n, k); +} + +/** + * @defgroup pairwise_distance pairwise distance prims + * @{ + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam Index_ indexing type + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace buffer which can get resized as per the + * needed workspace size + * @param metric distance metric + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + */ +template +void pairwise_distance(const raft::handle_t &handle, const Type *x, + const Type *y, Type *dist, Index_ m, Index_ n, Index_ k, + raft::distance::DistanceType metric, + bool isRowMajor = true, Type metric_arg = 2.0f) { + raft::device_uvector workspace(0, handle.get_stream()); + raft::distance::pairwise_distance(X, y, dist, m, n, k, workspace, metric, + handle.get_stream(), isRowMajor, + metric_arg); +} +/** @} */ + +}; // namespace distance +}; // namespace raft From a43fb5c0f2d35a5971dfeb0ca2d28d6eee281a19 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 4 Oct 2021 14:54:15 -0400 Subject: [PATCH 09/22] Removing unnecessary detail:: namespace qualifier from distances --- cpp/include/raft/distance/detail/canberra.cuh | 4 ++-- cpp/include/raft/distance/detail/chebyshev.cuh | 4 ++-- cpp/include/raft/distance/detail/correlation.cuh | 4 ++-- cpp/include/raft/distance/detail/cosine.cuh | 3 +-- cpp/include/raft/distance/detail/euclidean.cuh | 8 ++++---- cpp/include/raft/distance/detail/hamming.cuh | 8 ++++---- cpp/include/raft/distance/detail/hellinger.cuh | 8 ++++---- .../raft/distance/detail/jensen_shannon.cuh | 8 ++++---- .../raft/distance/detail/kl_divergence.cuh | 16 ++++++++-------- cpp/include/raft/distance/detail/l1.cuh | 8 ++++---- cpp/include/raft/distance/detail/minkowski.cuh | 8 ++++---- cpp/include/raft/distance/detail/russell_rao.cuh | 8 ++++---- 12 files changed, 43 insertions(+), 44 deletions(-) diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index 3bef776b2c..02e46470f4 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -89,8 +89,8 @@ static void canberraImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - canberraColMajor); + dim3 grid = + launchConfigGenerator(m, n, KPolicy::SmemSize, canberraColMajor); canberraColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh index a2c89c5301..588cf3b596 100644 --- a/cpp/include/raft/distance/detail/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/chebyshev.cuh @@ -74,8 +74,8 @@ static void chebyshevImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - chebyshevRowMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + chebyshevRowMajor); chebyshevRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, diff --git a/cpp/include/raft/distance/detail/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh index 0f2e6b0ca4..cc579b7b92 100644 --- a/cpp/include/raft/distance/detail/correlation.cuh +++ b/cpp/include/raft/distance/detail/correlation.cuh @@ -125,8 +125,8 @@ static void correlationImpl(const DataT *x, const DataT *y, const DataT *xn, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - correlationRowMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + correlationRowMajor); correlationRowMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index 71bddd12c5..bdb4e221fa 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -90,8 +90,7 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, pairwiseDistanceMatKernel; - dim3 grid = - detail::launchConfigGenerator(m, n, shmemSize, cosineRowMajor); + dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineRowMajor); cosineRowMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 8592e96295..af0a006b89 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -98,8 +98,8 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, shmemSize, - euclideanExpRowMajor); + dim3 grid = + launchConfigGenerator(m, n, shmemSize, euclideanExpRowMajor); euclideanExpRowMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, @@ -109,8 +109,8 @@ void euclideanExpImpl(const DataT *x, const DataT *y, const DataT *xn, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, shmemSize, - euclideanExpColMajor); + dim3 grid = + launchConfigGenerator(m, n, shmemSize, euclideanExpColMajor); euclideanExpColMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/hamming.cuh b/cpp/include/raft/distance/detail/hamming.cuh index 2134b8d05b..0169ba33a2 100644 --- a/cpp/include/raft/distance/detail/hamming.cuh +++ b/cpp/include/raft/distance/detail/hamming.cuh @@ -86,8 +86,8 @@ static void hammingUnexpandedImpl(const DataT *x, const DataT *y, IdxT m, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator( - m, n, KPolicy::SmemSize, hammingUnexpandedRowMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + hammingUnexpandedRowMajor); hammingUnexpandedRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -97,8 +97,8 @@ static void hammingUnexpandedImpl(const DataT *x, const DataT *y, IdxT m, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator( - m, n, KPolicy::SmemSize, hammingUnexpandedColMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + hammingUnexpandedColMajor); hammingUnexpandedColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh index a99f6db1e9..933d850dbf 100644 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -106,8 +106,8 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - hellingerRowMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + hellingerRowMajor); hellingerRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -117,8 +117,8 @@ static void hellingerImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - hellingerColMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + hellingerColMajor); hellingerColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/jensen_shannon.cuh b/cpp/include/raft/distance/detail/jensen_shannon.cuh index dbaf98562e..1e39f39682 100644 --- a/cpp/include/raft/distance/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/jensen_shannon.cuh @@ -93,8 +93,8 @@ static void jensenShannonImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - jensenShannonRowMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + jensenShannonRowMajor); jensenShannonRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -104,8 +104,8 @@ static void jensenShannonImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - jensenShannonColMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + jensenShannonColMajor); jensenShannonColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/kl_divergence.cuh b/cpp/include/raft/distance/detail/kl_divergence.cuh index da6abbf3f3..5a18ba1670 100644 --- a/cpp/include/raft/distance/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/kl_divergence.cuh @@ -127,8 +127,8 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, if (x != y) { raft::linalg::unaryOp( (DataT *)y, y, n * k, unaryOp_lambda, stream); - dim3 grid = detail::launchConfigGenerator( - m, n, KPolicy::SmemSize, klDivergenceRowMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + klDivergenceRowMajor); klDivergenceRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -136,8 +136,8 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, raft::linalg::unaryOp( (DataT *)y, y, n * k, unaryOp_lambda_reverse, stream); } else { - dim3 grid = detail::launchConfigGenerator( - m, n, KPolicy::SmemSize, klDivergenceRowMajorXequalY); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + klDivergenceRowMajorXequalY); klDivergenceRowMajorXequalY<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda_x_equal_y, epilog_lambda, fin_op); @@ -154,8 +154,8 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, if (x != y) { raft::linalg::unaryOp( (DataT *)x, x, m * k, unaryOp_lambda, stream); - dim3 grid = detail::launchConfigGenerator( - m, n, KPolicy::SmemSize, klDivergenceColMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + klDivergenceColMajor); klDivergenceColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); @@ -163,8 +163,8 @@ static void klDivergenceImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, raft::linalg::unaryOp( (DataT *)x, x, m * k, unaryOp_lambda_reverse, stream); } else { - dim3 grid = detail::launchConfigGenerator( - m, n, KPolicy::SmemSize, klDivergenceColMajorXequalY); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + klDivergenceColMajorXequalY); klDivergenceColMajorXequalY<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda_x_equal_y, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 0e1f1bfa71..33e9bae206 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -73,8 +73,8 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - l1RowMajor); + dim3 grid = + launchConfigGenerator(m, n, KPolicy::SmemSize, l1RowMajor); l1RowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -84,8 +84,8 @@ static void l1Impl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - l1ColMajor); + dim3 grid = + launchConfigGenerator(m, n, KPolicy::SmemSize, l1ColMajor); l1ColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh index 2fab2ae6de..8bd3deb08f 100644 --- a/cpp/include/raft/distance/detail/minkowski.cuh +++ b/cpp/include/raft/distance/detail/minkowski.cuh @@ -84,8 +84,8 @@ void minkowskiUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - minkowskiUnExpRowMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + minkowskiUnExpRowMajor); minkowskiUnExpRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -96,8 +96,8 @@ void minkowskiUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - minkowskiUnExpColMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + minkowskiUnExpColMajor); minkowskiUnExpColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, diff --git a/cpp/include/raft/distance/detail/russell_rao.cuh b/cpp/include/raft/distance/detail/russell_rao.cuh index cc8d0a855a..8e4c4824c3 100644 --- a/cpp/include/raft/distance/detail/russell_rao.cuh +++ b/cpp/include/raft/distance/detail/russell_rao.cuh @@ -85,8 +85,8 @@ static void russellRaoImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - russellRaoRowMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + russellRaoRowMajor); russellRaoRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -96,8 +96,8 @@ static void russellRaoImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, - russellRaoColMajor); + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, + russellRaoColMajor); russellRaoColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); From 92f4c2e0954580e813bd7e789be7e8494e1a7b4a Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 4 Oct 2021 16:22:43 -0400 Subject: [PATCH 10/22] Removing some more unecessry namespaces --- cpp/include/raft/distance/detail/canberra.cuh | 2 +- cpp/include/raft/distance/detail/chebyshev.cuh | 2 +- cpp/include/raft/distance/detail/correlation.cuh | 2 +- cpp/include/raft/distance/detail/cosine.cuh | 2 +- cpp/include/raft/distance/detail/euclidean.cuh | 4 ++-- 5 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index 02e46470f4..5dab77b97b 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -78,7 +78,7 @@ static void canberraImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, canberraRowMajor); canberraRowMajor<<>>( diff --git a/cpp/include/raft/distance/detail/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh index 588cf3b596..ff0049a0a0 100644 --- a/cpp/include/raft/distance/detail/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/chebyshev.cuh @@ -85,7 +85,7 @@ static void chebyshevImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, chebyshevColMajor); chebyshevColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, diff --git a/cpp/include/raft/distance/detail/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh index cc579b7b92..de43be2864 100644 --- a/cpp/include/raft/distance/detail/correlation.cuh +++ b/cpp/include/raft/distance/detail/correlation.cuh @@ -135,7 +135,7 @@ static void correlationImpl(const DataT *x, const DataT *y, const DataT *xn, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, correlationColMajor); correlationColMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index bdb4e221fa..8a2e31751f 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -100,7 +100,7 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, decltype(core_lambda), decltype(epilog_lambda), FinalLambda, false>; dim3 grid = - detail::launchConfigGenerator(m, n, shmemSize, cosineColMajor); + launchConfigGenerator(m, n, shmemSize, cosineColMajor); cosineColMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index af0a006b89..8e16c1c462 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -268,7 +268,7 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, euclideanUnExpRowMajor); euclideanUnExpRowMajor<<>>( @@ -280,7 +280,7 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = detail::launchConfigGenerator(m, n, KPolicy::SmemSize, + dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, euclideanUnExpColMajor); euclideanUnExpColMajor<<>>( From 45e0b7035fa2db78c14724670a3cbb872282f723 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 4 Oct 2021 16:57:58 -0400 Subject: [PATCH 11/22] Style --- cpp/include/raft/distance/detail/canberra.cuh | 4 ++-- cpp/include/raft/distance/detail/chebyshev.cuh | 2 +- cpp/include/raft/distance/detail/correlation.cuh | 2 +- cpp/include/raft/distance/detail/cosine.cuh | 3 +-- cpp/include/raft/distance/detail/euclidean.cuh | 4 ++-- 5 files changed, 7 insertions(+), 8 deletions(-) diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index 5dab77b97b..c4c384c45f 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -78,8 +78,8 @@ static void canberraImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, pairwiseDistanceMatKernel; - dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - canberraRowMajor); + dim3 grid = + launchConfigGenerator(m, n, KPolicy::SmemSize, canberraRowMajor); canberraRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, diff --git a/cpp/include/raft/distance/detail/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh index ff0049a0a0..77fba28310 100644 --- a/cpp/include/raft/distance/detail/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/chebyshev.cuh @@ -86,7 +86,7 @@ static void chebyshevImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, decltype(core_lambda), decltype(epilog_lambda), FinalLambda, false>; dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - chebyshevColMajor); + chebyshevColMajor); chebyshevColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh index de43be2864..cee986997a 100644 --- a/cpp/include/raft/distance/detail/correlation.cuh +++ b/cpp/include/raft/distance/detail/correlation.cuh @@ -136,7 +136,7 @@ static void correlationImpl(const DataT *x, const DataT *y, const DataT *xn, decltype(core_lambda), decltype(epilog_lambda), FinalLambda, false>; dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - correlationColMajor); + correlationColMajor); correlationColMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/cosine.cuh b/cpp/include/raft/distance/detail/cosine.cuh index 8a2e31751f..900e045edc 100644 --- a/cpp/include/raft/distance/detail/cosine.cuh +++ b/cpp/include/raft/distance/detail/cosine.cuh @@ -99,8 +99,7 @@ void cosineImpl(const DataT *x, const DataT *y, const DataT *xn, pairwiseDistanceMatKernel; - dim3 grid = - launchConfigGenerator(m, n, shmemSize, cosineColMajor); + dim3 grid = launchConfigGenerator(m, n, shmemSize, cosineColMajor); cosineColMajor<<>>( x, y, xn, yn, m, n, k, lda, ldb, ldd, dOutput, core_lambda, epilog_lambda, fin_op); diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 8e16c1c462..8b8882c244 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -269,7 +269,7 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, decltype(core_lambda), decltype(epilog_lambda), FinalLambda, true>; dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - euclideanUnExpRowMajor); + euclideanUnExpRowMajor); euclideanUnExpRowMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, @@ -281,7 +281,7 @@ void euclideanUnExpImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k, decltype(core_lambda), decltype(epilog_lambda), FinalLambda, false>; dim3 grid = launchConfigGenerator(m, n, KPolicy::SmemSize, - euclideanUnExpColMajor); + euclideanUnExpColMajor); euclideanUnExpColMajor<<>>( x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, dOutput, core_lambda, From a016e36a8312e317674c545bf158745503c054f3 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 12 Oct 2021 13:05:21 -0400 Subject: [PATCH 12/22] Updates --- cpp/include/raft/distance/distance.cuh | 225 ------------------ cpp/include/raft/distance/distance.hpp | 127 ++++++++-- .../{fused_l2_nn.cuh => fused_l2_nn.hpp} | 0 .../distance/{distance.cuh => distance.hpp} | 0 .../sparse/selection/connect_components.cuh | 2 +- cpp/test/distance/fused_l2_nn.cu | 2 +- cpp/test/sparse/distance.cu | 2 +- 7 files changed, 104 insertions(+), 254 deletions(-) delete mode 100644 cpp/include/raft/distance/distance.cuh rename cpp/include/raft/distance/{fused_l2_nn.cuh => fused_l2_nn.hpp} (100%) rename cpp/include/raft/sparse/distance/{distance.cuh => distance.hpp} (100%) diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh deleted file mode 100644 index a0c3544c92..0000000000 --- a/cpp/include/raft/distance/distance.cuh +++ /dev/null @@ -1,225 +0,0 @@ -/* - * Copyright (c) 2018-2021, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -namespace raft { -namespace distance { - -/** -* @brief Evaluate pairwise distances with the user epilogue lamba allowed -* @tparam DistanceType which distance to evaluate -* @tparam InType input argument type -* @tparam AccType accumulation type -* @tparam OutType output type -* @tparam FinalLambda user-defined epilogue lamba -* @tparam Index_ Index type -* @param x first set of points -* @param y second set of points -* @param dist output distance matrix -* @param m number of points in x -* @param n number of points in y -* @param k dimensionality -* @param workspace temporary workspace needed for computations -* @param worksize number of bytes of the workspace -* @param fin_op the final gemm epilogue lambda -* @param stream cuda stream -* @param isRowMajor whether the matrices are row-major or col-major -* -* @note fin_op: This is a device lambda which is supposed to operate upon the -* input which is AccType and returns the output in OutType. It's signature is -* as follows:
OutType fin_op(AccType in, int g_idx);
. If one needs -* any other parameters, feel free to pass them via closure. -*/ -template -void distance(const InType *x, const InType *y, OutType *dist, Index_ m, - Index_ n, Index_ k, void *workspace, size_t worksize, - FinalLambda fin_op, cudaStream_t stream, bool isRowMajor = true, - InType metric_arg = 2.0f) { - detail::distance( - x, y, dist, m, n, k, workspace, worksize, fin_op, stream, isRowMajor, - metric_arg); -} - -/** -* @brief Evaluate pairwise distances for the simple use case -* @tparam DistanceType which distance to evaluate -* @tparam InType input argument type -* @tparam AccType accumulation type -* @tparam OutType output type -* @tparam Index_ Index type -* @param x first set of points -* @param y second set of points -* @param dist output distance matrix -* @param m number of points in x -* @param n number of points in y -* @param k dimensionality -* @param workspace temporary workspace needed for computations -* @param worksize number of bytes of the workspace -* @param stream cuda stream -* @param isRowMajor whether the matrices are row-major or col-major -* -* @note if workspace is passed as nullptr, this will return in -* worksize, the number of bytes of workspace required -*/ -template -void distance(const InType *x, const InType *y, OutType *dist, Index_ m, - Index_ n, Index_ k, void *workspace, size_t worksize, - cudaStream_t stream, bool isRowMajor = true, - InType metric_arg = 2.0f) { - detail::distance( - x, y, dist, m, n, k, workspace, worksize, stream, isRowMajor, metric_arg); -} - -/** -* @brief Return the exact workspace size to compute the distance -* @tparam DistanceType which distance to evaluate -* @tparam InType input argument type -* @tparam AccType accumulation type -* @tparam OutType output type -* @tparam Index_ Index type -* @param x first set of points -* @param y second set of points -* @param m number of points in x -* @param n number of points in y -* @param k dimensionality -* -* @note If the specifed distanceType doesn't need the workspace at all, it -* returns 0. -*/ -template -size_t getWorkspaceSize(const InType *x, const InType *y, Index_ m, Index_ n, - Index_ k) { - return detail::getWorkspaceSize(x, y, m, n, k); -} - -/** - * @defgroup pairwise_distance pairwise distance prims - * @{ - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace buffer which can get resized as per the - * needed workspace size - * @param metric distance metric - * @param stream cuda stream - * @param isRowMajor whether the matrices are row-major or col-major - */ -template -void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, - Index_ n, Index_ k, rmm::device_uvector &workspace, - raft::distance::DistanceType metric, cudaStream_t stream, - bool isRowMajor = true, Type metric_arg = 2.0f) { - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::L2SqrtExpanded: - detail::pairwise_distance_impl< - Type, Index_, raft::distance::DistanceType::L2SqrtExpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::CosineExpanded: - detail::pairwise_distance_impl< - Type, Index_, raft::distance::DistanceType::CosineExpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::L1: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::L2Unexpanded: - detail::pairwise_distance_impl< - Type, Index_, raft::distance::DistanceType::L2Unexpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::L2SqrtUnexpanded: - detail::pairwise_distance_impl< - Type, Index_, raft::distance::DistanceType::L2SqrtUnexpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::Linf: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::HellingerExpanded: - detail::pairwise_distance_impl< - Type, Index_, raft::distance::DistanceType::HellingerExpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::LpUnexpanded: - detail::pairwise_distance_impl< - Type, Index_, raft::distance::DistanceType::LpUnexpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor, metric_arg); - break; - case raft::distance::DistanceType::Canberra: - detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::HammingUnexpanded: - detail::pairwise_distance_impl< - Type, Index_, raft::distance::DistanceType::HammingUnexpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::JensenShannon: - detail::pairwise_distance_impl< - Type, Index_, raft::distance::DistanceType::JensenShannon>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::RusselRaoExpanded: - detail::pairwise_distance_impl< - Type, Index_, raft::distance::DistanceType::RusselRaoExpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::KLDivergence: - detail::pairwise_distance_impl< - Type, Index_, raft::distance::DistanceType::KLDivergence>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - case raft::distance::DistanceType::CorrelationExpanded: - detail::pairwise_distance_impl< - Type, Index_, raft::distance::DistanceType::CorrelationExpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); - break; - default: - THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - }; -} -/** @} */ - -}; // namespace distance -}; // namespace raft diff --git a/cpp/include/raft/distance/distance.hpp b/cpp/include/raft/distance/distance.hpp index 986f6b4ede..a0c3544c92 100644 --- a/cpp/include/raft/distance/distance.hpp +++ b/cpp/include/raft/distance/distance.hpp @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include namespace raft { @@ -116,33 +116,108 @@ size_t getWorkspaceSize(const InType *x, const InType *y, Index_ m, Index_ n, } /** - * @defgroup pairwise_distance pairwise distance prims - * @{ - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace buffer which can get resized as per the - * needed workspace size - * @param metric distance metric - * @param stream cuda stream - * @param isRowMajor whether the matrices are row-major or col-major - */ + * @defgroup pairwise_distance pairwise distance prims + * @{ + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam Index_ indexing type + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param workspace temporary workspace buffer which can get resized as per the + * needed workspace size + * @param metric distance metric + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + */ template -void pairwise_distance(const raft::handle_t &handle, const Type *x, - const Type *y, Type *dist, Index_ m, Index_ n, Index_ k, - raft::distance::DistanceType metric, +void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, + Index_ n, Index_ k, rmm::device_uvector &workspace, + raft::distance::DistanceType metric, cudaStream_t stream, bool isRowMajor = true, Type metric_arg = 2.0f) { - raft::device_uvector workspace(0, handle.get_stream()); - raft::distance::pairwise_distance(X, y, dist, m, n, k, workspace, metric, - handle.get_stream(), isRowMajor, - metric_arg); + switch (metric) { + case raft::distance::DistanceType::L2Expanded: + detail::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::L2SqrtExpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::L2SqrtExpanded>( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::CosineExpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::CosineExpanded>( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::L1: + detail::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::L2Unexpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::L2Unexpanded>( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::L2SqrtUnexpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::L2SqrtUnexpanded>( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::Linf: + detail::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::HellingerExpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::HellingerExpanded>( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::LpUnexpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::LpUnexpanded>( + x, y, dist, m, n, k, workspace, stream, isRowMajor, metric_arg); + break; + case raft::distance::DistanceType::Canberra: + detail::pairwise_distance_impl( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::HammingUnexpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::HammingUnexpanded>( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::JensenShannon: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::JensenShannon>( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::RusselRaoExpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::RusselRaoExpanded>( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::KLDivergence: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::KLDivergence>( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + case raft::distance::DistanceType::CorrelationExpanded: + detail::pairwise_distance_impl< + Type, Index_, raft::distance::DistanceType::CorrelationExpanded>( + x, y, dist, m, n, k, workspace, stream, isRowMajor); + break; + default: + THROW("Unknown or unsupported distance metric '%d'!", (int)metric); + }; } /** @} */ diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.hpp similarity index 100% rename from cpp/include/raft/distance/fused_l2_nn.cuh rename to cpp/include/raft/distance/fused_l2_nn.hpp diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.hpp similarity index 100% rename from cpp/include/raft/sparse/distance/distance.cuh rename to cpp/include/raft/sparse/distance/distance.hpp diff --git a/cpp/include/raft/sparse/selection/connect_components.cuh b/cpp/include/raft/sparse/selection/connect_components.cuh index 46369ca964..5313b81192 100644 --- a/cpp/include/raft/sparse/selection/connect_components.cuh +++ b/cpp/include/raft/sparse/selection/connect_components.cuh @@ -16,7 +16,7 @@ #include -#include +#include #include #include #include diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 241cab2260..b494330c57 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -18,7 +18,7 @@ #include #include #include -#include +#include #include #include #include "../test_utils.h" diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index 4b531992f0..25c1356606 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -22,7 +22,7 @@ #include #include -#include +#include #include "../test_utils.h" From c5c0eee4c20a2efbadeee31a95a5949f0517d280 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 12 Oct 2021 13:25:10 -0400 Subject: [PATCH 13/22] Updates to distance api --- cpp/include/raft/distance/distance.hpp | 66 +++++++++++++++++++------- 1 file changed, 48 insertions(+), 18 deletions(-) diff --git a/cpp/include/raft/distance/distance.hpp b/cpp/include/raft/distance/distance.hpp index a0c3544c92..e73dc29939 100644 --- a/cpp/include/raft/distance/distance.hpp +++ b/cpp/include/raft/distance/distance.hpp @@ -18,6 +18,7 @@ #include #include +#include #include namespace raft { @@ -135,85 +136,87 @@ size_t getWorkspaceSize(const InType *x, const InType *y, Index_ m, Index_ n, * @param isRowMajor whether the matrices are row-major or col-major */ template -void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, - Index_ n, Index_ k, rmm::device_uvector &workspace, - raft::distance::DistanceType metric, cudaStream_t stream, +void pairwise_distance(const raft::handle_t &handle, const Type *x, + const Type *y, Type *dist, Index_ m, Index_ n, Index_ k, + rmm::device_uvector &workspace, + raft::distance::DistanceType metric, bool isRowMajor = true, Type metric_arg = 2.0f) { switch (metric) { case raft::distance::DistanceType::L2Expanded: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::L2SqrtExpanded: detail::pairwise_distance_impl< Type, Index_, raft::distance::DistanceType::L2SqrtExpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::CosineExpanded: detail::pairwise_distance_impl< Type, Index_, raft::distance::DistanceType::CosineExpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::L1: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::L2Unexpanded: detail::pairwise_distance_impl< Type, Index_, raft::distance::DistanceType::L2Unexpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::L2SqrtUnexpanded: detail::pairwise_distance_impl< Type, Index_, raft::distance::DistanceType::L2SqrtUnexpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::Linf: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::HellingerExpanded: detail::pairwise_distance_impl< Type, Index_, raft::distance::DistanceType::HellingerExpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::LpUnexpanded: detail::pairwise_distance_impl< Type, Index_, raft::distance::DistanceType::LpUnexpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor, metric_arg); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor, + metric_arg); break; case raft::distance::DistanceType::Canberra: detail::pairwise_distance_impl( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::HammingUnexpanded: detail::pairwise_distance_impl< Type, Index_, raft::distance::DistanceType::HammingUnexpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::JensenShannon: detail::pairwise_distance_impl< Type, Index_, raft::distance::DistanceType::JensenShannon>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::RusselRaoExpanded: detail::pairwise_distance_impl< Type, Index_, raft::distance::DistanceType::RusselRaoExpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::KLDivergence: detail::pairwise_distance_impl< Type, Index_, raft::distance::DistanceType::KLDivergence>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; case raft::distance::DistanceType::CorrelationExpanded: detail::pairwise_distance_impl< Type, Index_, raft::distance::DistanceType::CorrelationExpanded>( - x, y, dist, m, n, k, workspace, stream, isRowMajor); + x, y, dist, m, n, k, workspace, handle.get_stream(), isRowMajor); break; default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); @@ -221,5 +224,32 @@ void pairwise_distance(const Type *x, const Type *y, Type *dist, Index_ m, } /** @} */ +/** + * @defgroup pairwise_distance pairwise distance prims + * @{ + * @brief Convenience wrapper around 'distance' prim to convert runtime metric + * into compile time for the purpose of dispatch + * @tparam Type input/accumulation/output data-type + * @tparam Index_ indexing type + * @param x first set of points + * @param y second set of points + * @param dist output distance matrix + * @param m number of points in x + * @param n number of points in y + * @param k dimensionality + * @param metric distance metric + * @param stream cuda stream + * @param isRowMajor whether the matrices are row-major or col-major + */ +template +void pairwise_distance(const raft::handle_t &handle, const Type *x, + const Type *y, Type *dist, Index_ m, Index_ n, Index_ k, + raft::distance::DistanceType metric, + bool isRowMajor = true, Type metric_arg = 2.0f) { + rmm::device_uvector workspace(0, handle.get_stream()); + pairwise_distance(handle, x, y, dist, m, n, k, workspace, + metric, isRowMajor, metric_arg); +} + }; // namespace distance }; // namespace raft From 7c2ac23471f7096ef803b23d59105e717daa4bc0 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 12 Oct 2021 13:52:21 -0400 Subject: [PATCH 14/22] Update --- cpp/test/distance/distance_base.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 4798d102f3..2ab6eec2ee 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -17,7 +17,7 @@ #include #include #include -#include +#include #include #include "../test_utils.h" From 563971753caf5579484080bfbbc0a5b44340db2c Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 12 Oct 2021 18:38:08 -0400 Subject: [PATCH 15/22] Updating import --- cpp/test/distance/dist_adj.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index 8d5cd68f13..3d9261162c 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -17,7 +17,7 @@ #include #include #include -#include +#include #include #include "../test_utils.h" From b3083a3e6ccfb71f715f9e590589de1580802ed4 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 12 Oct 2021 19:00:14 -0400 Subject: [PATCH 16/22] update --- cpp/include/raft/sparse/selection/knn.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/sparse/selection/knn.cuh b/cpp/include/raft/sparse/selection/knn.cuh index 49573a679d..5cc1f69a56 100644 --- a/cpp/include/raft/sparse/selection/knn.cuh +++ b/cpp/include/raft/sparse/selection/knn.cuh @@ -30,7 +30,7 @@ #include #include #include -#include +#include #include #include From 902bf6cafbabb98d2cee394c401bd1c29227093c Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 13 Oct 2021 12:40:30 -0400 Subject: [PATCH 17/22] Updsates --- cpp/include/raft/distance/detail/distance.cuh | 50 ++++++++++++++++++- cpp/include/raft/sparse/selection/knn.cuh | 9 ---- 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index b4d428474e..2d23d650f7 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -37,7 +37,55 @@ namespace raft { namespace distance { namespace detail { -namespace { + /** enum to tell how to compute distance */ + enum DistanceType : unsigned short { + + /** evaluate as dist_ij = sum(x_ik^2) + sum(y_ij)^2 - 2*sum(x_ik * y_jk) */ + L2Expanded = 0, + /** same as above, but inside the epilogue, perform square root operation */ + L2SqrtExpanded = 1, + /** cosine distance */ + CosineExpanded = 2, + /** L1 distance */ + L1 = 3, + /** evaluate as dist_ij += (x_ik - y-jk)^2 */ + L2Unexpanded = 4, + /** same as above, but inside the epilogue, perform square root operation */ + L2SqrtUnexpanded = 5, + /** basic inner product **/ + InnerProduct = 6, + /** Chebyshev (Linf) distance **/ + Linf = 7, + /** Canberra distance **/ + Canberra = 8, + /** Generalized Minkowski distance **/ + LpUnexpanded = 9, + /** Correlation distance **/ + CorrelationExpanded = 10, + /** Jaccard distance **/ + JaccardExpanded = 11, + /** Hellinger distance **/ + HellingerExpanded = 12, + /** Haversine distance **/ + Haversine = 13, + /** Bray-Curtis distance **/ + BrayCurtis = 14, + /** Jensen-Shannon distance**/ + JensenShannon = 15, + /** Hamming distance **/ + HammingUnexpanded = 16, + /** KLDivergence **/ + KLDivergence = 17, + /** RusselRao **/ + RusselRaoExpanded = 18, + /** Dice-Sorensen distance **/ + DiceExpanded = 19, + /** Precomputed (special value) **/ + Precomputed = 100 + }; + + + namespace { template diff --git a/cpp/include/raft/sparse/selection/knn.cuh b/cpp/include/raft/sparse/selection/knn.cuh index 5cc1f69a56..fc1a7c0d8d 100644 --- a/cpp/include/raft/sparse/selection/knn.cuh +++ b/cpp/include/raft/sparse/selection/knn.cuh @@ -20,7 +20,6 @@ #include #include -#include #include #include #include @@ -33,13 +32,6 @@ #include #include -#include -#include -#include -#include - -#include - namespace raft { namespace sparse { namespace selection { @@ -412,7 +404,6 @@ class sparse_knn_t { * @param[out] output_indices dense matrix for output indices (size n_query_rows * k) * @param[out] output_dists dense matrix for output distances (size n_query_rows * k) * @param[in] k the number of neighbors to query - * @param[in] cusparseHandle the initialized cusparseHandle instance to use * @param[in] handle.get_stream() CUDA handle.get_stream() to order operations with respect to * @param[in] batch_size_index maximum number of rows to use from index matrix per batch * @param[in] batch_size_query maximum number of rows to use from query matrix per batch From 9eac45bd471eb5196bd0da0c9b00ea615830688b Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 13 Oct 2021 12:41:04 -0400 Subject: [PATCH 18/22] Updates --- cpp/include/raft/distance/detail/distance.cuh | 93 +++++++++---------- 1 file changed, 46 insertions(+), 47 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index 2d23d650f7..199dc73fb6 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -37,55 +37,54 @@ namespace raft { namespace distance { namespace detail { - /** enum to tell how to compute distance */ - enum DistanceType : unsigned short { - - /** evaluate as dist_ij = sum(x_ik^2) + sum(y_ij)^2 - 2*sum(x_ik * y_jk) */ - L2Expanded = 0, - /** same as above, but inside the epilogue, perform square root operation */ - L2SqrtExpanded = 1, - /** cosine distance */ - CosineExpanded = 2, - /** L1 distance */ - L1 = 3, - /** evaluate as dist_ij += (x_ik - y-jk)^2 */ - L2Unexpanded = 4, - /** same as above, but inside the epilogue, perform square root operation */ - L2SqrtUnexpanded = 5, - /** basic inner product **/ - InnerProduct = 6, - /** Chebyshev (Linf) distance **/ - Linf = 7, - /** Canberra distance **/ - Canberra = 8, - /** Generalized Minkowski distance **/ - LpUnexpanded = 9, - /** Correlation distance **/ - CorrelationExpanded = 10, - /** Jaccard distance **/ - JaccardExpanded = 11, - /** Hellinger distance **/ - HellingerExpanded = 12, - /** Haversine distance **/ - Haversine = 13, - /** Bray-Curtis distance **/ - BrayCurtis = 14, - /** Jensen-Shannon distance**/ - JensenShannon = 15, - /** Hamming distance **/ - HammingUnexpanded = 16, - /** KLDivergence **/ - KLDivergence = 17, - /** RusselRao **/ - RusselRaoExpanded = 18, - /** Dice-Sorensen distance **/ - DiceExpanded = 19, - /** Precomputed (special value) **/ - Precomputed = 100 - }; +/** enum to tell how to compute distance */ +enum DistanceType : unsigned short { + /** evaluate as dist_ij = sum(x_ik^2) + sum(y_ij)^2 - 2*sum(x_ik * y_jk) */ + L2Expanded = 0, + /** same as above, but inside the epilogue, perform square root operation */ + L2SqrtExpanded = 1, + /** cosine distance */ + CosineExpanded = 2, + /** L1 distance */ + L1 = 3, + /** evaluate as dist_ij += (x_ik - y-jk)^2 */ + L2Unexpanded = 4, + /** same as above, but inside the epilogue, perform square root operation */ + L2SqrtUnexpanded = 5, + /** basic inner product **/ + InnerProduct = 6, + /** Chebyshev (Linf) distance **/ + Linf = 7, + /** Canberra distance **/ + Canberra = 8, + /** Generalized Minkowski distance **/ + LpUnexpanded = 9, + /** Correlation distance **/ + CorrelationExpanded = 10, + /** Jaccard distance **/ + JaccardExpanded = 11, + /** Hellinger distance **/ + HellingerExpanded = 12, + /** Haversine distance **/ + Haversine = 13, + /** Bray-Curtis distance **/ + BrayCurtis = 14, + /** Jensen-Shannon distance**/ + JensenShannon = 15, + /** Hamming distance **/ + HammingUnexpanded = 16, + /** KLDivergence **/ + KLDivergence = 17, + /** RusselRao **/ + RusselRaoExpanded = 18, + /** Dice-Sorensen distance **/ + DiceExpanded = 19, + /** Precomputed (special value) **/ + Precomputed = 100 +}; - namespace { +namespace { template From 5436bc98ec05ab43d3dd81f0ea0ae83007e1704e Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 13 Oct 2021 17:25:18 -0400 Subject: [PATCH 19/22] Updates --- cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh index 0e91b5225d..980001f166 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh @@ -27,7 +27,7 @@ #include "processing.hpp" #include