From f2bc24dabd6fad9e4ef1f62183466073ce3fb176 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 19 Jan 2023 16:16:59 -0800 Subject: [PATCH 1/9] Remove faiss ANN code from knnIndex (#1121) Authors: - Ben Frederickson (https://github.com/benfred) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1121 --- cpp/include/raft/spatial/knn/ann_common.h | 26 +--- .../raft/spatial/knn/detail/ann_quantized.cuh | 130 ++++-------------- cpp/test/neighbors/ann_ivf_flat.cu | 2 - 3 files changed, 31 insertions(+), 127 deletions(-) diff --git a/cpp/include/raft/spatial/knn/ann_common.h b/cpp/include/raft/spatial/knn/ann_common.h index a0d79a1b77..0e9e323b84 100644 --- a/cpp/include/raft/spatial/knn/ann_common.h +++ b/cpp/include/raft/spatial/knn/ann_common.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,12 +22,10 @@ #include "detail/processing.hpp" #include "ivf_flat_types.hpp" +#include #include -#include -#include - namespace raft { namespace spatial { namespace knn { @@ -36,13 +34,14 @@ struct knnIndex { raft::distance::DistanceType metric; float metricArg; int nprobe; - std::unique_ptr index; std::unique_ptr> metric_processor; + std::unique_ptr> ivf_flat_float_; std::unique_ptr> ivf_flat_uint8_t_; std::unique_ptr> ivf_flat_int8_t_; - std::unique_ptr gpu_res; + std::unique_ptr> ivf_pq; + int device; template @@ -70,16 +69,6 @@ inline auto knnIndex::ivf_flat() return ivf_flat_int8_t_; } -enum QuantizerType : unsigned int { - QT_8bit, - QT_4bit, - QT_8bit_uniform, - QT_4bit_uniform, - QT_fp16, - QT_8bit_direct, - QT_6bit -}; - struct knnIndexParam { virtual ~knnIndexParam() {} }; @@ -98,11 +87,6 @@ struct IVFPQParam : IVFParam { bool usePrecomputedTables; }; -struct IVFSQParam : IVFParam { - QuantizerType qtype; - bool encodeResidual; -}; - inline auto from_legacy_index_params(const IVFFlatParam& legacy, raft::distance::DistanceType metric, float metric_arg) diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 975f1a0f89..f651e943e3 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -18,9 +18,7 @@ #include "../ann_common.h" #include "../ivf_flat.cuh" -#include "knn_brute_force_faiss.cuh" -#include "common_faiss.h" #include "processing.cuh" #include #include @@ -29,83 +27,14 @@ #include #include #include -#include +#include #include -#include -#include -#include -#include - #include namespace raft::spatial::knn::detail { -inline faiss::ScalarQuantizer::QuantizerType build_faiss_qtype(QuantizerType qtype) -{ - switch (qtype) { - case QuantizerType::QT_8bit: return faiss::ScalarQuantizer::QuantizerType::QT_8bit; - case QuantizerType::QT_8bit_uniform: - return faiss::ScalarQuantizer::QuantizerType::QT_8bit_uniform; - case QuantizerType::QT_4bit_uniform: - return faiss::ScalarQuantizer::QuantizerType::QT_4bit_uniform; - case QuantizerType::QT_fp16: return faiss::ScalarQuantizer::QuantizerType::QT_fp16; - case QuantizerType::QT_8bit_direct: - return faiss::ScalarQuantizer::QuantizerType::QT_8bit_direct; - case QuantizerType::QT_6bit: return faiss::ScalarQuantizer::QuantizerType::QT_6bit; - default: return 
(faiss::ScalarQuantizer::QuantizerType)qtype; - } -} - -template -void approx_knn_ivfflat_build_index(knnIndex* index, - const IVFFlatParam& params, - IntType n, - IntType D) -{ - faiss::gpu::GpuIndexIVFFlatConfig config; - config.device = index->device; - faiss::MetricType faiss_metric = build_faiss_metric(index->metric); - index->index.reset( - new faiss::gpu::GpuIndexIVFFlat(index->gpu_res.get(), D, params.nlist, faiss_metric, config)); -} - -template -void approx_knn_ivfpq_build_index(knnIndex* index, const IVFPQParam& params, IntType n, IntType D) -{ - faiss::gpu::GpuIndexIVFPQConfig config; - config.device = index->device; - config.usePrecomputedTables = params.usePrecomputedTables; - config.interleavedLayout = params.n_bits != 8; - faiss::MetricType faiss_metric = build_faiss_metric(index->metric); - index->index.reset(new faiss::gpu::GpuIndexIVFPQ( - index->gpu_res.get(), D, params.nlist, params.M, params.n_bits, faiss_metric, config)); -} - -template -void approx_knn_ivfsq_build_index(knnIndex* index, const IVFSQParam& params, IntType n, IntType D) -{ - faiss::gpu::GpuIndexIVFScalarQuantizerConfig config; - config.device = index->device; - faiss::MetricType faiss_metric = build_faiss_metric(index->metric); - faiss::ScalarQuantizer::QuantizerType faiss_qtype = build_faiss_qtype(params.qtype); - index->index.reset(new faiss::gpu::GpuIndexIVFScalarQuantizer( - index->gpu_res.get(), D, params.nlist, faiss_qtype, faiss_metric, params.encodeResidual)); -} - -inline bool ivf_flat_supported_metric(raft::distance::DistanceType metric) -{ - switch (metric) { - case raft::distance::DistanceType::L2Unexpanded: - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - case raft::distance::DistanceType::InnerProduct: return true; - default: return false; - } -} - template void approx_knn_build_index(const handle_t& handle, knnIndex* index, @@ -117,7 +46,6 @@ void approx_knn_build_index(const handle_t& handle, IntType D) { auto stream = handle.get_stream(); - index->index = nullptr; index->metric = metric; index->metricArg = metricArg; if (dynamic_cast(params)) { @@ -125,37 +53,35 @@ void approx_knn_build_index(const handle_t& handle, } auto ivf_ft_pams = dynamic_cast(params); auto ivf_pq_pams = dynamic_cast(params); - auto ivf_sq_pams = dynamic_cast(params); if constexpr (std::is_same_v) { index->metric_processor = create_processor(metric, n, D, 0, false, stream); + // For cosine/correlation distance, the metric processor translates distance + // to inner product via pre/post processing - pass the translated metric to + // ANN index + if (metric == raft::distance::DistanceType::CosineExpanded || + metric == raft::distance::DistanceType::CorrelationExpanded) { + metric = index->metric = raft::distance::DistanceType::InnerProduct; + } } if constexpr (std::is_same_v) { index->metric_processor->preprocess(index_array); } - if (ivf_ft_pams && ivf_flat_supported_metric(metric)) { + if (ivf_ft_pams) { auto new_params = from_legacy_index_params(*ivf_ft_pams, metric, metricArg); index->ivf_flat() = std::make_unique>( ivf_flat::build(handle, new_params, index_array, int64_t(n), D)); + } else if (ivf_pq_pams) { + neighbors::ivf_pq::index_params params; + params.metric = metric; + params.metric_arg = metricArg; + params.n_lists = ivf_pq_pams->nlist; + params.pq_bits = ivf_pq_pams->n_bits; + params.pq_dim = ivf_pq_pams->M; + // TODO: handle ivf_pq_pams.usePrecomputedTables ? 
+ index->ivf_pq = std::make_unique>( + neighbors::ivf_pq::build(handle, params, index_array, int64_t(n), D)); } else { - RAFT_CUDA_TRY(cudaGetDevice(&(index->device))); - index->gpu_res.reset(new raft::spatial::knn::RmmGpuResources()); - index->gpu_res->noTempMemory(); - index->gpu_res->setDefaultStream(index->device, stream); - if (ivf_ft_pams) { - approx_knn_ivfflat_build_index(index, *ivf_ft_pams, n, D); - } else if (ivf_pq_pams) { - approx_knn_ivfpq_build_index(index, *ivf_pq_pams, n, D); - } else if (ivf_sq_pams) { - approx_knn_ivfsq_build_index(index, *ivf_sq_pams, n, D); - } else { - RAFT_FAIL("Unrecognized index type."); - } - if constexpr (std::is_same_v) { - index->index->train(n, index_array); - index->index->add(n, index_array); - } else { - RAFT_FAIL("FAISS-based index supports only float data."); - } + RAFT_FAIL("Unrecognized index type."); } if constexpr (std::is_same_v) { index->metric_processor->revert(index_array); } @@ -170,26 +96,22 @@ void approx_knn_search(const handle_t& handle, T* query_array, IntType n) { - auto faiss_ivf = dynamic_cast(index->index.get()); - if (faiss_ivf) { faiss_ivf->setNumProbes(index->nprobe); } - if constexpr (std::is_same_v) { index->metric_processor->preprocess(query_array); index->metric_processor->set_num_queries(k); } // search - if (faiss_ivf) { - if constexpr (std::is_same_v) { - faiss_ivf->search(n, query_array, k, distances, indices); - } else { - RAFT_FAIL("FAISS-based index supports only float data."); - } - } else if (index->ivf_flat()) { + if (index->ivf_flat()) { ivf_flat::search_params params; params.n_probes = index->nprobe; ivf_flat::search( handle, params, *(index->ivf_flat()), query_array, n, k, indices, distances); + } else if (index->ivf_pq) { + neighbors::ivf_pq::search_params params; + params.n_probes = index->nprobe; + neighbors::ivf_pq::search( + handle, params, *index->ivf_pq, query_array, n, k, indices, distances); } else { RAFT_FAIL("The model is not trained"); } diff --git a/cpp/test/neighbors/ann_ivf_flat.cu b/cpp/test/neighbors/ann_ivf_flat.cu index 86a62bb487..080e7551fa 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cu +++ b/cpp/test/neighbors/ann_ivf_flat.cu @@ -107,8 +107,6 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { ivfParams.nprobe = ps.nprobe; ivfParams.nlist = ps.nlist; raft::spatial::knn::knnIndex index; - index.index = nullptr; - index.gpu_res = nullptr; approx_knn_build_index(handle_, &index, From d233a2cba9108b37727440e88d0ad6e406d28d5f Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 19 Jan 2023 19:34:52 -0800 Subject: [PATCH 2/9] Use squeuclidean for metric name in ivf_pq python bindings (#1160) Use sqeuclidean instead of l2_expanded for the distance name in the ivf_pq python bindings. This matches both sklearn, and the RAFT pairwise_distance api - and should be less confusing for our users Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. 
Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1160 --- .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 23 ++++++++++------ .../pylibraft/pylibraft/neighbors/refine.pyx | 6 ++--- .../pylibraft/pylibraft/test/test_ivf_pq.py | 26 +++++++++---------- .../pylibraft/pylibraft/test/test_refine.py | 8 +++--- 4 files changed, 35 insertions(+), 28 deletions(-) diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index ee30864193..8f8a49fb63 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -18,6 +18,8 @@ # cython: embedsignature = True # cython: language_level = 3 +import warnings + import numpy as np from cython.operator cimport dereference as deref @@ -63,17 +65,22 @@ from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( def _get_metric(metric): SUPPORTED_DISTANCES = { - "l2_expanded": DistanceType.L2Expanded, + "sqeuclidean": DistanceType.L2Expanded, "euclidean": DistanceType.L2SqrtExpanded, "inner_product": DistanceType.InnerProduct } if metric not in SUPPORTED_DISTANCES: + if metric == "l2_expanded": + warnings.warn("Using l2_expanded as a metric name is deprecated," + " use sqeuclidean instead", FutureWarning) + return DistanceType.L2Expanded + raise ValueError("metric %s is not supported" % metric) return SUPPORTED_DISTANCES[metric] cdef _get_metric_string(DistanceType metric): - return {DistanceType.L2Expanded : "l2_expanded", + return {DistanceType.L2Expanded : "sqeuclidean", DistanceType.InnerProduct: "inner_product", DistanceType.L2SqrtExpanded: "euclidean"}[metric] @@ -118,7 +125,7 @@ cdef class IndexParams: def __init__(self, *, n_lists=1024, - metric="l2_expanded", + metric="sqeuclidean", kmeans_n_iters=20, kmeans_trainset_fraction=0.5, pq_bits=8, @@ -133,10 +140,10 @@ cdef class IndexParams: ---------- n_list : int, default = 1024 The number of clusters used in the coarse quantizer. - metric : string denoting the metric type, default="l2_expanded" - Valid values for metric: ["l2_expanded", "inner_product", + metric : string denoting the metric type, default="sqeuclidean" + Valid values for metric: ["sqeuclidean", "inner_product", "euclidean"], where - - l2_expanded is the euclidean distance without the square root + - sqeuclidean is the euclidean distance without the square root operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2, - euclidean is the euclidean distance - inner product distance is defined as @@ -251,7 +258,7 @@ cdef class Index: # We create a placeholder object. The actual parameter values do # not matter, it will be replaced with a built index object later. self.index = new c_ivf_pq.index[uint64_t]( - deref(handle_), _get_metric("l2_expanded"), + deref(handle_), _get_metric("sqeuclidean"), c_ivf_pq.codebook_gen.PER_SUBSPACE, 1, 4, @@ -347,7 +354,7 @@ def build(IndexParams index_params, dataset, handle=None): >>> handle = Handle() >>> index_params = ivf_pq.IndexParams( ... n_lists=1024, - ... metric="l2_expanded", + ... metric="sqeuclidean", ... pq_dim=10) >>> index = ivf_pq.build(index_params, dataset, handle=handle) diff --git a/python/pylibraft/pylibraft/neighbors/refine.pyx b/python/pylibraft/pylibraft/neighbors/refine.pyx index 37ef69e7b5..b8f1bd0caa 100644 --- a/python/pylibraft/pylibraft/neighbors/refine.pyx +++ b/python/pylibraft/pylibraft/neighbors/refine.pyx @@ -1,5 +1,5 @@ # -# Copyright (c) 2022, NVIDIA CORPORATION. 
+# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -215,7 +215,7 @@ cdef host_matrix_view[int8_t, uint64_t, row_major] \ @auto_sync_handle @auto_convert_output def refine(dataset, queries, candidates, k=None, indices=None, distances=None, - metric="l2_expanded", handle=None): + metric="sqeuclidean", handle=None): """ Refine nearest neighbor search. @@ -271,7 +271,7 @@ def refine(dataset, queries, candidates, k=None, indices=None, distances=None, >>> dataset = cp.random.random_sample((n_samples, n_features), ... dtype=cp.float32) >>> handle = Handle() - >>> index_params = ivf_pq.IndexParams(n_lists=1024, metric="l2_expanded", + >>> index_params = ivf_pq.IndexParams(n_lists=1024, metric="sqeuclidean", ... pq_dim=10) >>> index = ivf_pq.build(index_params, dataset, handle=handle) diff --git a/python/pylibraft/pylibraft/test/test_ivf_pq.py b/python/pylibraft/pylibraft/test/test_ivf_pq.py index db1389c6cd..6952408c02 100644 --- a/python/pylibraft/pylibraft/test/test_ivf_pq.py +++ b/python/pylibraft/pylibraft/test/test_ivf_pq.py @@ -58,7 +58,7 @@ def check_distances(dataset, queries, metric, out_idx, out_dist, eps=None): for i in range(queries.shape[0]): X = queries[np.newaxis, i, :] Y = dataset[out_idx[i, :], :] - if metric == "l2_expanded": + if metric == "sqeuclidean": dist[i, :] = pairwise_distances(X, Y, "sqeuclidean") elif metric == "euclidean": dist[i, :] = pairwise_distances(X, Y, "euclidean") @@ -177,7 +177,7 @@ def run_ivf_pq_build_search_test( # Calculate reference values with sklearn skl_metric = { - "l2_expanded": "sqeuclidean", + "sqeuclidean": "sqeuclidean", "inner_product": "cosine", "euclidean": "euclidean", }[metric] @@ -204,14 +204,14 @@ def test_ivf_pq_dtypes( n_rows, n_cols, n_queries, n_lists, dtype, inplace, array_type ): # Note that inner_product tests use normalized input which we cannot - # represent in int8, therefore we test only l2_expanded metric here. + # represent in int8, therefore we test only sqeuclidean metric here. 
run_ivf_pq_build_search_test( n_rows=n_rows, n_cols=n_cols, n_queries=n_queries, k=10, n_lists=n_lists, - metric="l2_expanded", + metric="sqeuclidean", dtype=dtype, inplace=inplace, array_type=array_type, @@ -246,14 +246,14 @@ def test_ivf_pq_n(params): n_queries=params["n_queries"], k=params["k"], n_lists=params["n_lists"], - metric="l2_expanded", + metric="sqeuclidean", dtype=np.float32, compare=False, ) @pytest.mark.parametrize( - "metric", ["l2_expanded", "inner_product", "euclidean"] + "metric", ["sqeuclidean", "inner_product", "euclidean"] ) @pytest.mark.parametrize("dtype", [np.float32]) @pytest.mark.parametrize("codebook_kind", ["subspace", "cluster"]) @@ -298,7 +298,7 @@ def test_ivf_pq_params(params): n_queries=1000, k=10, n_lists=params["n_lists"], - metric="l2_expanded", + metric="sqeuclidean", dtype=np.float32, pq_bits=params["pq_bits"], pq_dim=params["pq_dims"], @@ -344,7 +344,7 @@ def test_ivf_pq_search_params(params): k=params["k"], n_lists=100, n_probes=params["n_probes"], - metric="l2_expanded", + metric="sqeuclidean", dtype=np.float32, lut_dtype=params["lut"], internal_distance_dtype=params["idd"], @@ -360,7 +360,7 @@ def test_extend(dtype, array_type): n_queries=100, k=10, n_lists=100, - metric="l2_expanded", + metric="sqeuclidean", dtype=dtype, add_data_on_build=False, array_type=array_type, @@ -375,7 +375,7 @@ def test_build_assertions(): n_queries=100, k=10, n_lists=100, - metric="l2_expanded", + metric="sqeuclidean", dtype=np.float64, ) @@ -388,7 +388,7 @@ def test_build_assertions(): index_params = ivf_pq.IndexParams( n_lists=50, - metric="l2_expanded", + metric="sqeuclidean", kmeans_n_iters=20, kmeans_trainset_fraction=1, add_data_on_build=False, @@ -482,7 +482,7 @@ def test_search_inputs(params): out_dist_device = device_ndarray(out_dist) index_params = ivf_pq.IndexParams( - n_lists=50, metric="l2_expanded", add_data_on_build=True + n_lists=50, metric="sqeuclidean", add_data_on_build=True ) dataset = generate_data((n_rows, n_cols), dtype) @@ -511,7 +511,7 @@ def test_save_load(): dataset = generate_data((n_rows, n_cols), dtype) dataset_device = device_ndarray(dataset) - build_params = ivf_pq.IndexParams(n_lists=100, metric="l2_expanded") + build_params = ivf_pq.IndexParams(n_lists=100, metric="sqeuclidean") index = ivf_pq.build(build_params, dataset_device) assert index.trained diff --git a/python/pylibraft/pylibraft/test/test_refine.py b/python/pylibraft/pylibraft/test/test_refine.py index 2f3bef2e0c..8502d0575c 100644 --- a/python/pylibraft/pylibraft/test/test_refine.py +++ b/python/pylibraft/pylibraft/test/test_refine.py @@ -27,7 +27,7 @@ def run_refine( n_rows=500, n_cols=50, n_queries=100, - metric="l2_expanded", + metric="sqeuclidean", k0=40, k=10, inplace=False, @@ -49,7 +49,7 @@ def run_refine( queries_device = device_ndarray(queries) # Calculate reference values with sklearn - skl_metric = {"l2_expanded": "euclidean", "inner_product": "cosine"}[ + skl_metric = {"sqeuclidean": "euclidean", "inner_product": "cosine"}[ metric ] nn_skl = NearestNeighbors( @@ -106,7 +106,7 @@ def run_refine( if recall <= 0.999: # We did not find the same neighbor indices. # We could have found other neighbor with same distance. 
- if metric == "l2_expanded": + if metric == "sqeuclidean": skl_dist = np.power(skl_dist[:, :k], 2) elif metric == "inner_product": skl_dist = 1 - skl_dist[:, :k] @@ -120,7 +120,7 @@ def run_refine( @pytest.mark.parametrize("n_queries", [100, 1024, 37]) @pytest.mark.parametrize("inplace", [True, False]) -@pytest.mark.parametrize("metric", ["l2_expanded", "inner_product"]) +@pytest.mark.parametrize("metric", ["sqeuclidean", "inner_product"]) @pytest.mark.parametrize("dtype", [np.float32, np.int8, np.uint8]) @pytest.mark.parametrize("memory_type", ["device", "host"]) def test_refine_dtypes(n_queries, dtype, inplace, metric, memory_type): From 102a4f036a27bff30749f67b1133611575ee0603 Mon Sep 17 00:00:00 2001 From: Sevag H Date: Fri, 20 Jan 2023 09:00:21 -0500 Subject: [PATCH 3/9] Make cutlass use static ctk (#1155) Cutlass links to the CTK in ways that cause problems for downstream pip wheel builds (especially in cugraph). This might help. Authors: - Sevag H (https://github.com/sevagh) Approvers: - Robert Maynard (https://github.com/robertmaynard) URL: https://github.com/rapidsai/raft/pull/1155 --- cpp/cmake/thirdparty/get_cutlass.cmake | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index 811a5466c3..3e02ce064e 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -30,6 +30,10 @@ function(find_and_configure_cutlass) CACHE BOOL "Disable CUTLASS to build with cuBLAS library." ) + if (CUDA_STATIC_RUNTIME) + set(CUDART_LIBRARY "${CUDA_cudart_static_LIBRARY}" CACHE FILEPATH "fixing cutlass cmake code" FORCE) + endif() + rapids_cpm_find( NvidiaCutlass ${PKG_VERSION} GLOBAL_TARGETS nvidia::cutlass::cutlass From b70519e631cedee4fd652215fb71a1a6c0545c85 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Sat, 21 Jan 2023 02:09:15 +0100 Subject: [PATCH 4/9] Protect balanced k-means out-of-memory in some cases (#1161) There's no guarantee that our balanced k-means implementation always produces balanced clusters. In the first stage, when mesoclusters are trained, the biggest cluster can grow larger than half of all input data. This becomes a problem at the second stage, when in `build_fine_clusters`, the mesocluster data is copied in a temporary buffer. If size is too big, there may be not enough memory on the device. A quick workaround: 1. Expand the error reporting (RAFT_LOG_WARN) 2. Artificially limit the mesocluster size in the event of highly unbalanced clustering Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/1161 --- .../knn/detail/ann_kmeans_balanced.cuh | 35 +++++++++++-------- 1 file changed, 21 insertions(+), 14 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index 72df13d760..c6a3aea0cf 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -839,6 +839,10 @@ inline auto arrange_fine_clusters(uint32_t n_clusters, * As a result, the fine clusters are what is returned by `build_hierarchical`; * this function returns the total number of fine clusters, which can be checked to be * the same as the requested number of clusters. 
+ * + * Note: this function uses at most `fine_clusters_nums_max` points per mesocluster for training; + * if one of the clusters is larger than that (as given by `mesocluster_sizes`), the extra data + * is ignored and a warning is reported. */ template auto build_fine_clusters(const handle_t& handle, @@ -880,8 +884,8 @@ auto build_fine_clusters(const handle_t& handle, uint32_t n_clusters_done = 0; for (uint32_t i = 0; i < n_mesoclusters; i++) { uint32_t k = 0; - for (IdxT j = 0; j < n_rows; j++) { - if (labels_mptr[j] == (LabelT)i) { mc_trainset_ids[k++] = j; } + for (IdxT j = 0; j < n_rows && k < mesocluster_size_max; j++) { + if (labels_mptr[j] == LabelT(i)) { mc_trainset_ids[k++] = j; } } if (k != mesocluster_sizes[i]) RAFT_LOG_WARN("Incorrect mesocluster size at %d. %d vs %d", i, k, mesocluster_sizes[i]); @@ -896,19 +900,13 @@ auto build_fine_clusters(const handle_t& handle, "Number of fine clusters must be non-zero for a non-empty mesocluster"); } - utils::copy_selected((IdxT)mesocluster_sizes[i], - (IdxT)dim, - dataset_mptr, - mc_trainset_ids, - (IdxT)dim, - mc_trainset, - (IdxT)dim, - stream); + utils::copy_selected( + (IdxT)k, (IdxT)dim, dataset_mptr, mc_trainset_ids, (IdxT)dim, mc_trainset, (IdxT)dim, stream); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { thrust::gather(handle.get_thrust_policy(), mc_trainset_ids, - mc_trainset_ids + mesocluster_sizes[i], + mc_trainset_ids + k, dataset_norm_mptr, mc_trainset_norm); } @@ -917,7 +915,7 @@ auto build_fine_clusters(const handle_t& handle, n_iters, dim, mc_trainset, - mesocluster_sizes[i], + k, fine_clusters_nums[i], mc_trainset_ccenters.data(), mc_trainset_labels.data(), @@ -1036,10 +1034,19 @@ void build_hierarchical(const handle_t& handle, auto [mesocluster_size_max, fine_clusters_nums_max, fine_clusters_nums, fine_clusters_csum] = arrange_fine_clusters(n_clusters, n_mesoclusters, n_rows, mesocluster_sizes); - if (mesocluster_size_max * n_mesoclusters > 2 * n_rows) { - RAFT_LOG_WARN("build_hierarchical: built unbalanced mesoclusters"); + const auto mesocluster_size_max_balanced = uint32_t(div_rounding_up_safe( + 2lu * size_t(n_rows), std::max(size_t(n_mesoclusters), 1lu))); + if (mesocluster_size_max > mesocluster_size_max_balanced) { + RAFT_LOG_WARN( + "build_hierarchical: built unbalanced mesoclusters (max_mesocluster_size == %u > %u). " + "At most %u points will be used for training within each mesocluster. " + "Consider increasing the number of training iterations `n_iters`.", + mesocluster_size_max, + mesocluster_size_max_balanced, + mesocluster_size_max_balanced); RAFT_LOG_TRACE_VEC(mesocluster_sizes, n_mesoclusters); RAFT_LOG_TRACE_VEC(fine_clusters_nums.data(), n_mesoclusters); + mesocluster_size_max = mesocluster_size_max_balanced; } auto n_clusters_done = build_fine_clusters(handle, From a9e1adc6a55f03fb98199ab8c7f4bc82e9849a73 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Sat, 21 Jan 2023 21:40:14 +0100 Subject: [PATCH 5/9] Improvement of the math API wrappers (#1146) Solves #1025 Provides a centralized collection of host- and device-friendly wrappers around common math operations, with generalizations when useful. Deprecates former `myXxx` wrappers. Those wrappers are mostly intended to future-proof the API as well as simplify the definition of host-device functions. Authors: - Louis Sugy (https://github.com/Nyrio) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. 
Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1146 --- cpp/include/raft/core/math.hpp | 320 ++++++++++++++++ cpp/include/raft/core/operators.hpp | 27 +- cpp/include/raft/distance/detail/canberra.cuh | 4 +- .../raft/distance/detail/chebyshev.cuh | 4 +- .../raft/distance/detail/correlation.cuh | 4 +- .../raft/distance/detail/euclidean.cuh | 8 +- .../raft/distance/detail/fused_l2_nn.cuh | 4 +- .../raft/distance/detail/hellinger.cuh | 4 +- .../raft/distance/detail/jensen_shannon.cuh | 8 +- .../raft/distance/detail/kl_divergence.cuh | 14 +- cpp/include/raft/distance/detail/l1.cuh | 2 +- .../raft/distance/detail/minkowski.cuh | 8 +- cpp/include/raft/linalg/detail/lstsq.cuh | 4 +- cpp/include/raft/matrix/detail/math.cuh | 10 +- .../raft/random/detail/make_regression.cuh | 6 +- cpp/include/raft/random/detail/rng_device.cuh | 24 +- .../sparse/distance/detail/l2_distance.cuh | 12 +- .../sparse/distance/detail/lp_distance.cuh | 6 +- .../spatial/knn/detail/ball_cover/common.cuh | 4 +- .../spatial/knn/detail/haversine_distance.cuh | 8 +- .../raft/spectral/detail/spectral_util.cuh | 4 +- cpp/include/raft/stats/detail/stddev.cuh | 6 +- cpp/include/raft/util/cuda_utils.cuh | 90 ++--- cpp/test/CMakeLists.txt | 2 + cpp/test/core/math_device.cu | 352 ++++++++++++++++++ cpp/test/core/math_host.cpp | 195 ++++++++++ cpp/test/distance/distance_base.cuh | 22 +- cpp/test/distance/fused_l2_nn.cu | 2 +- cpp/test/linalg/matrix_vector.cu | 4 +- cpp/test/linalg/norm.cu | 10 +- cpp/test/linalg/power.cu | 6 +- cpp/test/linalg/sqrt.cu | 4 +- cpp/test/matrix/math.cu | 4 +- cpp/test/neighbors/ann_utils.cuh | 2 +- cpp/test/random/rng.cu | 14 +- 35 files changed, 1034 insertions(+), 164 deletions(-) create mode 100644 cpp/include/raft/core/math.hpp create mode 100644 cpp/test/core/math_device.cu create mode 100644 cpp/test/core/math_host.cpp diff --git a/cpp/include/raft/core/math.hpp b/cpp/include/raft/core/math.hpp new file mode 100644 index 0000000000..c5f08b84b7 --- /dev/null +++ b/cpp/include/raft/core/math.hpp @@ -0,0 +1,320 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include + +namespace raft { + +/** + * @defgroup Absolute Absolute value + * @{ + */ +template +RAFT_INLINE_FUNCTION auto abs(T x) + -> std::enable_if_t || std::is_same_v || + std::is_same_v || std::is_same_v || + std::is_same_v, + T> +{ +#ifdef __CUDA_ARCH__ + return ::abs(x); +#else + return std::abs(x); +#endif +} +template +constexpr RAFT_INLINE_FUNCTION auto abs(T x) + -> std::enable_if_t && !std::is_same_v && + !std::is_same_v && !std::is_same_v && + !std::is_same_v, + T> +{ + return x < T{0} ? 
-x : x; +} +/** @} */ + +/** + * @defgroup Trigonometry Trigonometry functions + * @{ + */ +/** Inverse cosine */ +template +RAFT_INLINE_FUNCTION auto acos(T x) +{ +#ifdef __CUDA_ARCH__ + return ::acos(x); +#else + return std::acos(x); +#endif +} + +/** Inverse sine */ +template +RAFT_INLINE_FUNCTION auto asin(T x) +{ +#ifdef __CUDA_ARCH__ + return ::asin(x); +#else + return std::asin(x); +#endif +} + +/** Inverse hyperbolic tangent */ +template +RAFT_INLINE_FUNCTION auto atanh(T x) +{ +#ifdef __CUDA_ARCH__ + return ::atanh(x); +#else + return std::atanh(x); +#endif +} + +/** Cosine */ +template +RAFT_INLINE_FUNCTION auto cos(T x) +{ +#ifdef __CUDA_ARCH__ + return ::cos(x); +#else + return std::cos(x); +#endif +} + +/** Sine */ +template +RAFT_INLINE_FUNCTION auto sin(T x) +{ +#ifdef __CUDA_ARCH__ + return ::sin(x); +#else + return std::sin(x); +#endif +} + +/** Sine and cosine */ +template +RAFT_INLINE_FUNCTION std::enable_if_t || std::is_same_v> sincos( + const T& x, T* s, T* c) +{ +#ifdef __CUDA_ARCH__ + ::sincos(x, s, c); +#else + *s = std::sin(x); + *c = std::cos(x); +#endif +} + +/** Hyperbolic tangent */ +template +RAFT_INLINE_FUNCTION auto tanh(T x) +{ +#ifdef __CUDA_ARCH__ + return ::tanh(x); +#else + return std::tanh(x); +#endif +} +/** @} */ + +/** + * @defgroup Exponential Exponential and logarithm + * @{ + */ +/** Exponential function */ +template +RAFT_INLINE_FUNCTION auto exp(T x) +{ +#ifdef __CUDA_ARCH__ + return ::exp(x); +#else + return std::exp(x); +#endif +} + +/** Natural logarithm */ +template +RAFT_INLINE_FUNCTION auto log(T x) +{ +#ifdef __CUDA_ARCH__ + return ::log(x); +#else + return std::log(x); +#endif +} +/** @} */ + +/** + * @defgroup Maximum Maximum of two or more values. + * + * The CUDA Math API has overloads for all combinations of float/double. We provide similar + * functionality while wrapping around std::max, which only supports arguments of the same type. + * However, though the CUDA Math API supports combinations of unsigned and signed integers, this is + * very error-prone so we do not support that and require the user to cast instead. (e.g the max of + * -1 and 1u is 4294967295u...) + * + * When no overload matches, we provide a generic implementation but require that both types be the + * same (and that the less-than operator be defined). + * @{ + */ +template +RAFT_INLINE_FUNCTION auto max(const T1& x, const T2& y) +{ +#ifdef __CUDA_ARCH__ + // Combinations of types supported by the CUDA Math API + if constexpr ((std::is_integral_v && std::is_integral_v && std::is_same_v) || + ((std::is_same_v || std::is_same_v)&&( + std::is_same_v || std::is_same_v))) { + return ::max(x, y); + } + // Else, check that the types are the same and provide a generic implementation + else { + static_assert( + std::is_same_v, + "No native max overload for these types. Both argument types must be the same to use " + "the generic max. Please cast appropriately."); + return (x < y) ? y : x; + } +#else + if constexpr (std::is_same_v && std::is_same_v) { + return std::max(static_cast(x), y); + } else if constexpr (std::is_same_v && std::is_same_v) { + return std::max(x, static_cast(y)); + } else { + static_assert( + std::is_same_v, + "std::max requires that both argument types be the same. Please cast appropriately."); + return std::max(x, y); + } +#endif +} + +/** Many-argument overload to avoid verbose nested calls or use with variadic arguments */ +template +RAFT_INLINE_FUNCTION auto max(const T1& x, const T2& y, Args&&... 
args) +{ + return raft::max(x, raft::max(y, std::forward(args)...)); +} + +/** One-argument overload for convenience when using with variadic arguments */ +template +constexpr RAFT_INLINE_FUNCTION auto max(const T& x) +{ + return x; +} +/** @} */ + +/** + * @defgroup Minimum Minimum of two or more values. + * + * The CUDA Math API has overloads for all combinations of float/double. We provide similar + * functionality while wrapping around std::min, which only supports arguments of the same type. + * However, though the CUDA Math API supports combinations of unsigned and signed integers, this is + * very error-prone so we do not support that and require the user to cast instead. (e.g the min of + * -1 and 1u is 1u...) + * + * When no overload matches, we provide a generic implementation but require that both types be the + * same (and that the less-than operator be defined). + * @{ + */ +template +RAFT_INLINE_FUNCTION auto min(const T1& x, const T2& y) +{ +#ifdef __CUDA_ARCH__ + // Combinations of types supported by the CUDA Math API + if constexpr ((std::is_integral_v && std::is_integral_v && std::is_same_v) || + ((std::is_same_v || std::is_same_v)&&( + std::is_same_v || std::is_same_v))) { + return ::min(x, y); + } + // Else, check that the types are the same and provide a generic implementation + else { + static_assert( + std::is_same_v, + "No native min overload for these types. Both argument types must be the same to use " + "the generic min. Please cast appropriately."); + return (y < x) ? y : x; + } +#else + if constexpr (std::is_same_v && std::is_same_v) { + return std::min(static_cast(x), y); + } else if constexpr (std::is_same_v && std::is_same_v) { + return std::min(x, static_cast(y)); + } else { + static_assert( + std::is_same_v, + "std::min requires that both argument types be the same. Please cast appropriately."); + return std::min(x, y); + } +#endif +} + +/** Many-argument overload to avoid verbose nested calls or use with variadic arguments */ +template +RAFT_INLINE_FUNCTION auto min(const T1& x, const T2& y, Args&&... args) +{ + return raft::min(x, raft::min(y, std::forward(args)...)); +} + +/** One-argument overload for convenience when using with variadic arguments */ +template +constexpr RAFT_INLINE_FUNCTION auto min(const T& x) +{ + return x; +} +/** @} */ + +/** + * @defgroup Power Power and root functions + * @{ + */ +/** Power */ +template +RAFT_INLINE_FUNCTION auto pow(T1 x, T2 y) +{ +#ifdef __CUDA_ARCH__ + return ::pow(x, y); +#else + return std::pow(x, y); +#endif +} + +/** Square root */ +template +RAFT_INLINE_FUNCTION auto sqrt(T x) +{ +#ifdef __CUDA_ARCH__ + return ::sqrt(x); +#else + return std::sqrt(x); +#endif +} +/** @} */ + +/** Sign */ +template +RAFT_INLINE_FUNCTION auto sgn(T val) -> int +{ + return (T(0) < val) - (val < T(0)); +} + +} // namespace raft diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index 398354df46..de27c2b271 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -23,6 +23,7 @@ #include #include +#include namespace raft { @@ -75,9 +76,9 @@ struct value_op { struct sqrt_op { template - constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const + RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const { - return std::sqrt(in); + return raft::sqrt(in); } }; @@ -91,9 +92,9 @@ struct nz_op { struct abs_op { template - constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) 
const + RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const { - return std::abs(in); + return raft::abs(in); } }; @@ -148,27 +149,25 @@ struct div_checkzero_op { struct pow_op { template - constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const { - return std::pow(a, b); + return raft::pow(a, b); } }; struct min_op { - template - constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + template + RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const { - if (a > b) { return b; } - return a; + return raft::min(std::forward(args)...); } }; struct max_op { - template - constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + template + RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const { - if (b > a) { return b; } - return a; + return raft::max(std::forward(args)...); } }; diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index 43a904edba..f17a26dc4b 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -73,8 +73,8 @@ static void canberraImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::myAbs(x - y); - const auto add = raft::myAbs(x) + raft::myAbs(y); + const auto diff = raft::abs(x - y); + const auto add = raft::abs(x) + raft::abs(y); // deal with potential for 0 in denominator by // forcing 1/0 instead acc += ((add != 0) * diff / (add + (add == 0))); diff --git a/cpp/include/raft/distance/detail/chebyshev.cuh b/cpp/include/raft/distance/detail/chebyshev.cuh index 52573bd170..43b36e7921 100644 --- a/cpp/include/raft/distance/detail/chebyshev.cuh +++ b/cpp/include/raft/distance/detail/chebyshev.cuh @@ -73,8 +73,8 @@ static void chebyshevImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::myAbs(x - y); - acc = raft::myMax(acc, diff); + const auto diff = raft::abs(x - y); + acc = raft::max(acc, diff); }; // epilogue operation lambda for final value calculation diff --git a/cpp/include/raft/distance/detail/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh index 9bdbbf112c..f7fe3678e6 100644 --- a/cpp/include/raft/distance/detail/correlation.cuh +++ b/cpp/include/raft/distance/detail/correlation.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -125,7 +125,7 @@ static void correlationImpl(const DataT* x, auto Q_denom = k * regx2n[i] - (regxn[i] * regxn[i]); auto R_denom = k * regy2n[j] - (regyn[j] * regyn[j]); - acc[i][j] = 1 - (numer / raft::mySqrt(Q_denom * R_denom)); + acc[i][j] = 1 - (numer / raft::sqrt(Q_denom * R_denom)); } } }; diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 4184810fff..1a2db63f5c 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ struct L2ExpandedOp { __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; - return sqrt ? raft::mySqrt(outVal) : outVal; + return sqrt ? raft::sqrt(outVal) : outVal; } __device__ AccT operator()(DataT aData) const noexcept { return aData; } @@ -130,7 +130,7 @@ void euclideanExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + acc[i][j] = raft::sqrt(acc[i][j]); } } } @@ -350,7 +350,7 @@ void euclideanUnExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + acc[i][j] = raft::sqrt(acc[i][j]); } } } diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index c9750df8ad..447359ffe6 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -175,7 +175,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { auto acc_ij = acc[i][j]; - acc[i][j] = acc_ij > DataT{0} ? raft::mySqrt(acc_ij) : DataT{0}; + acc[i][j] = acc_ij > DataT{0} ? raft::sqrt(acc_ij) : DataT{0}; } } } diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh index 51f462ab36..13507fe84f 100644 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -105,7 +105,7 @@ static void hellingerImpl(const DataT* x, // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative const auto finalVal = (1 - acc[i][j]); const auto rectifier = (!signbit(finalVal)); - acc[i][j] = raft::mySqrt(rectifier * finalVal); + acc[i][j] = raft::sqrt(rectifier * finalVal); } } }; diff --git a/cpp/include/raft/distance/detail/jensen_shannon.cuh b/cpp/include/raft/distance/detail/jensen_shannon.cuh index 92ee071cf5..f96da01b87 100644 --- a/cpp/include/raft/distance/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/jensen_shannon.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -78,11 +78,11 @@ static void jensenShannonImpl(const DataT* x, auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { const DataT m = 0.5f * (x + y); const bool m_zero = (m == 0); - const auto logM = (!m_zero) * raft::myLog(m + m_zero); + const auto logM = (!m_zero) * raft::log(m + m_zero); const bool x_zero = (x == 0); const bool y_zero = (y == 0); - acc += (-x * (logM - raft::myLog(x + x_zero))) + (-y * (logM - raft::myLog(y + y_zero))); + acc += (-x * (logM - raft::log(x + x_zero))) + (-y * (logM - raft::log(y + y_zero))); }; // epilogue operation lambda for final value calculation @@ -95,7 +95,7 @@ static void jensenShannonImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(0.5 * acc[i][j]); + acc[i][j] = raft::sqrt(0.5 * acc[i][j]); } } }; diff --git a/cpp/include/raft/distance/detail/kl_divergence.cuh b/cpp/include/raft/distance/detail/kl_divergence.cuh index 4c0c4b6ace..7ebeaf4de9 100644 --- a/cpp/include/raft/distance/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/kl_divergence.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,10 +81,10 @@ static void klDivergenceImpl(const DataT* x, auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { if (isRowMajor) { const bool x_zero = (x == 0); - acc += x * (raft::myLog(x + x_zero) - y); + acc += x * (raft::log(x + x_zero) - y); } else { const bool y_zero = (y == 0); - acc += y * (raft::myLog(y + y_zero) - x); + acc += y * (raft::log(y + y_zero) - x); } }; @@ -92,23 +92,23 @@ static void klDivergenceImpl(const DataT* x, if (isRowMajor) { const bool x_zero = (x == 0); const bool y_zero = (y == 0); - acc += x * (raft::myLog(x + x_zero) - (!y_zero) * raft::myLog(y + y_zero)); + acc += x * (raft::log(x + x_zero) - (!y_zero) * raft::log(y + y_zero)); } else { const bool y_zero = (y == 0); const bool x_zero = (x == 0); - acc += y * (raft::myLog(y + y_zero) - (!x_zero) * raft::myLog(x + x_zero)); + acc += y * (raft::log(y + y_zero) - (!x_zero) * raft::log(x + x_zero)); } }; auto unaryOp_lambda = [] __device__(DataT input) { const bool x_zero = (input == 0); - return (!x_zero) * raft::myLog(input + x_zero); + return (!x_zero) * raft::log(input + x_zero); }; auto unaryOp_lambda_reverse = [] __device__(DataT input) { // reverse previous log (x) back to x using (e ^ log(x)) const bool x_zero = (input == 0); - return (!x_zero) * raft::myExp(input); + return (!x_zero) * raft::exp(input); }; // epilogue operation lambda for final value calculation diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 87893bab7c..bf10651b60 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -71,7 +71,7 @@ static void l1Impl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::myAbs(x - y); + const auto diff = raft::abs(x - y); acc += diff; }; diff --git a/cpp/include/raft/distance/detail/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh index bda83babf1..42af8cd281 100644 --- a/cpp/include/raft/distance/detail/minkowski.cuh +++ b/cpp/include/raft/distance/detail/minkowski.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 
2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -74,8 +74,8 @@ void minkowskiUnExpImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [p] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::myAbs(x - y); - acc += raft::myPow(diff, p); + const auto diff = raft::abs(x - y); + acc += raft::pow(diff, p); }; // epilogue operation lambda for final value calculation @@ -89,7 +89,7 @@ void minkowskiUnExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::myPow(acc[i][j], one_over_p); + acc[i][j] = raft::pow(acc[i][j], one_over_p); } } }; diff --git a/cpp/include/raft/linalg/detail/lstsq.cuh b/cpp/include/raft/linalg/detail/lstsq.cuh index 1273956b21..f0cf300e2f 100644 --- a/cpp/include/raft/linalg/detail/lstsq.cuh +++ b/cpp/include/raft/linalg/detail/lstsq.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -104,7 +104,7 @@ struct DivideByNonZero { operator()(const math_t a, const math_t b) const { - return raft::myAbs(b) >= eps ? a / b : a; + return raft::abs(b) >= eps ? a / b : a; } }; diff --git a/cpp/include/raft/matrix/detail/math.cuh b/cpp/include/raft/matrix/detail/math.cuh index c559da3942..f5c33d1cf6 100644 --- a/cpp/include/raft/matrix/detail/math.cuh +++ b/cpp/include/raft/matrix/detail/math.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -87,10 +87,10 @@ void seqRoot(math_t* in, if (a < math_t(0)) { return math_t(0); } else { - return sqrt(a * scalar); + return raft::sqrt(a * scalar); } } else { - return sqrt(a * scalar); + return raft::sqrt(a * scalar); } }, stream); @@ -278,7 +278,7 @@ void matrixVectorBinaryDivSkipZero(Type* data, rowMajor, bcastAlongRows, [] __device__(Type a, Type b) { - if (raft::myAbs(b) < Type(1e-10)) + if (raft::abs(b) < Type(1e-10)) return Type(0); else return a / b; @@ -294,7 +294,7 @@ void matrixVectorBinaryDivSkipZero(Type* data, rowMajor, bcastAlongRows, [] __device__(Type a, Type b) { - if (raft::myAbs(b) < Type(1e-10)) + if (raft::abs(b) < Type(1e-10)) return a; else return a / b; diff --git a/cpp/include/raft/random/detail/make_regression.cuh b/cpp/include/raft/random/detail/make_regression.cuh index cb0949c458..057196cd74 100644 --- a/cpp/include/raft/random/detail/make_regression.cuh +++ b/cpp/include/raft/random/detail/make_regression.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -44,8 +44,8 @@ static __global__ void _singular_profile_kernel(DataT* out, IdxT n, DataT tail_s IdxT tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid < n) { DataT sval = static_cast(tid) / rank; - DataT low_rank = ((DataT)1.0 - tail_strength) * raft::myExp(-sval * sval); - DataT tail = tail_strength * raft::myExp((DataT)-0.1 * sval); + DataT low_rank = ((DataT)1.0 - tail_strength) * raft::exp(-sval * sval); + DataT tail = tail_strength * raft::exp((DataT)-0.1 * sval); out[tid] = low_rank + tail; } } diff --git a/cpp/include/raft/random/detail/rng_device.cuh b/cpp/include/raft/random/detail/rng_device.cuh index 6c75a4fa78..7f994fb07f 100644 --- a/cpp/include/raft/random/detail/rng_device.cuh +++ b/cpp/include/raft/random/detail/rng_device.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -143,10 +143,10 @@ DI void box_muller_transform(Type& val1, Type& val2, Type sigma1, Type mu1, Type { constexpr Type twoPi = Type(2.0) * Type(3.141592654); constexpr Type minus2 = -Type(2.0); - Type R = raft::mySqrt(minus2 * raft::myLog(val1)); + Type R = raft::sqrt(minus2 * raft::log(val1)); Type theta = twoPi * val2; Type s, c; - raft::mySinCos(theta, s, c); + raft::sincos(theta, &s, &c); val1 = R * c * sigma1 + mu1; val2 = R * s * sigma2 + mu2; } @@ -323,7 +323,7 @@ DI void custom_next( gen.next(res); } while (res == OutType(0.0)); - *val = params.mu - params.beta * raft::myLog(-raft::myLog(res)); + *val = params.mu - params.beta * raft::log(-raft::log(res)); } template @@ -340,8 +340,8 @@ DI void custom_next(GenType& gen, gen.next(res2); box_muller_transform(res1, res2, params.sigma, params.mu); - *val = raft::myExp(res1); - *(val + 1) = raft::myExp(res2); + *val = raft::exp(res1); + *(val + 1) = raft::exp(res2); } template @@ -358,7 +358,7 @@ DI void custom_next(GenType& gen, } while (res == OutType(0.0)); constexpr OutType one = (OutType)1.0; - *val = params.mu - params.scale * raft::myLog(one / res - one); + *val = params.mu - params.scale * raft::log(one / res - one); } template @@ -371,7 +371,7 @@ DI void custom_next(GenType& gen, OutType res; gen.next(res); constexpr OutType one = (OutType)1.0; - *val = -raft::myLog(one - res) / params.lambda; + *val = -raft::log(one - res) / params.lambda; } template @@ -386,7 +386,7 @@ DI void custom_next(GenType& gen, constexpr OutType one = (OutType)1.0; constexpr OutType two = (OutType)2.0; - *val = raft::mySqrt(-two * raft::myLog(one - res)) * params.sigma; + *val = raft::sqrt(-two * raft::log(one - res)) * params.sigma; } template @@ -409,9 +409,9 @@ DI void custom_next(GenType& gen, // The <= comparison here means, number of samples going in `if` branch are more by 1 than `else` // branch. However it does not matter as for 0.5 both branches evaluate to same result. 
if (res <= oneHalf) { - out = params.mu + params.scale * raft::myLog(two * res); + out = params.mu + params.scale * raft::log(two * res); } else { - out = params.mu - params.scale * raft::myLog(two * (one - res)); + out = params.mu - params.scale * raft::log(two * (one - res)); } *val = out; } @@ -424,7 +424,7 @@ DI void custom_next( gen.next(res); params.inIdxPtr[idx] = idx; constexpr OutType one = (OutType)1.0; - auto exp = -raft::myLog(one - res); + auto exp = -raft::log(one - res); if (params.wts != nullptr) { *val = exp / params.wts[idx]; } else { diff --git a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh index 3c852235df..2f165b3ff2 100644 --- a/cpp/include/raft/sparse/distance/detail/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/l2_distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -112,7 +112,7 @@ __global__ void compute_correlation_warp_kernel(value_t* __restrict__ C, value_t Q_denom = n * Q_l2 - (Q_l1 * Q_l1); value_t R_denom = n * R_l2 - (R_l1 * R_l1); - value_t val = 1 - (numer / sqrt(Q_denom * R_denom)); + value_t val = 1 - (numer / raft::sqrt(Q_denom * R_denom)); // correct for small instabilities C[(size_t)i * n_cols + j] = val * (fabs(val) >= 0.0001); @@ -292,7 +292,7 @@ class l2_sqrt_expanded_distances_t : public l2_expanded_distances_tconfig_->a_nrows * this->config_->b_nrows, [] __device__(value_t input) { int neg = input < 0 ? -1 : 1; - return sqrt(abs(input) * neg); + return raft::sqrt(abs(input) * neg); }, this->config_->handle.get_stream()); } @@ -379,7 +379,7 @@ class cosine_expanded_distances_t : public distances_t { config_->b_nrows, config_->handle.get_stream(), [] __device__ __host__(value_t dot, value_t q_norm, value_t r_norm) { - value_t norms = sqrt(q_norm) * sqrt(r_norm); + value_t norms = raft::sqrt(q_norm) * raft::sqrt(r_norm); // deal with potential for 0 in denominator by forcing 0/1 instead value_t cos = ((norms != 0) * dot) / ((norms == 0) + norms); @@ -429,7 +429,7 @@ class hellinger_expanded_distances_t : public distances_t { out_dists, *config_, coo_rows.data(), - [] __device__(value_t a, value_t b) { return sqrt(a) * sqrt(b); }, + [] __device__(value_t a, value_t b) { return raft::sqrt(a) * raft::sqrt(b); }, raft::add_op(), raft::atomic_add_op()); @@ -440,7 +440,7 @@ class hellinger_expanded_distances_t : public distances_t { [=] __device__(value_t input) { // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative bool rectifier = (1 - input) > 0; - return sqrt(rectifier * (1 - input)); + return raft::sqrt(rectifier * (1 - input)); }, config_->handle.get_stream()); } diff --git a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh index a973aebbab..f67109afbc 100644 --- a/cpp/include/raft/sparse/distance/detail/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/detail/lp_distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -132,7 +132,7 @@ class l2_sqrt_unexpanded_distances_t : public l2_unexpanded_distances_tconfig_->a_nrows * this->config_->b_nrows, [] __device__(value_t input) { int neg = input < 0 ? -1 : 1; - return sqrt(abs(input) * neg); + return raft::sqrt(abs(input) * neg); }, this->config_->handle.get_stream()); } @@ -274,7 +274,7 @@ class jensen_shannon_unexpanded_distances_t : public distances_t { out_dists, out_dists, config_->a_nrows * config_->b_nrows, - [=] __device__(value_t input) { return sqrt(0.5 * input); }, + [=] __device__(value_t input) { return raft::sqrt(0.5 * input); }, config_->handle.get_stream()); } diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh index b09cf0da10..0a6718f5a5 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -71,7 +71,7 @@ struct EuclideanFunc : public DistFunc { sum_sq += diff * diff; } - return sqrt(sum_sq); + return raft::sqrt(sum_sq); } }; diff --git a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh index e073841dd3..9cecc0adf4 100644 --- a/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh +++ b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh @@ -32,11 +32,11 @@ namespace detail { template DI value_t compute_haversine(value_t x1, value_t y1, value_t x2, value_t y2) { - value_t sin_0 = sin(0.5 * (x1 - y1)); - value_t sin_1 = sin(0.5 * (x2 - y2)); - value_t rdist = sin_0 * sin_0 + cos(x1) * cos(y1) * sin_1 * sin_1; + value_t sin_0 = raft::sin(0.5 * (x1 - y1)); + value_t sin_1 = raft::sin(0.5 * (x2 - y2)); + value_t rdist = sin_0 * sin_0 + raft::cos(x1) * raft::cos(y1) * sin_1 * sin_1; - return 2 * asin(sqrt(rdist)); + return 2 * raft::asin(raft::sqrt(rdist)); } /** diff --git a/cpp/include/raft/spectral/detail/spectral_util.cuh b/cpp/include/raft/spectral/detail/spectral_util.cuh index 3a0ad1f96f..5991e71ec6 100644 --- a/cpp/include/raft/spectral/detail/spectral_util.cuh +++ b/cpp/include/raft/spectral/detail/spectral_util.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,7 +72,7 @@ static __global__ void scale_obs_kernel(index_type_t m, index_type_t n, value_ty // scale by alpha alpha = __shfl_sync(warp_full_mask(), alpha, blockDim.x - 1, blockDim.x); - alpha = std::sqrt(alpha); + alpha = raft::sqrt(alpha); for (j = threadIdx.y + blockIdx.y * blockDim.y; j < n; j += blockDim.y * gridDim.y) { for (i = threadIdx.x; i < m; i += blockDim.x) { // blockDim.x=32 index = i + j * m; diff --git a/cpp/include/raft/stats/detail/stddev.cuh b/cpp/include/raft/stats/detail/stddev.cuh index ccea2ea5da..2f7e22ca8a 100644 --- a/cpp/include/raft/stats/detail/stddev.cuh +++ b/cpp/include/raft/stats/detail/stddev.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -63,7 +63,7 @@ __global__ void stddevKernelColMajor( thread_data += diff * diff; } Type acc = BlockReduce(temp_storage).Sum(thread_data); - if (threadIdx.x == 0) { std[blockIdx.x] = raft::mySqrt(acc / N); } + if (threadIdx.x == 0) { std[blockIdx.x] = raft::sqrt(acc / N); } } template @@ -126,7 +126,7 @@ void stddev(Type* std, std, mu, D, - [ratio] __device__(Type a, Type b) { return raft::mySqrt(a * ratio - b * b); }, + [ratio] __device__(Type a, Type b) { return raft::sqrt(a * ratio - b * b); }, stream); } else { stddevKernelColMajor<<>>(std, data, mu, D, N); diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index 61dd6e0ad8..5be9dc999a 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ #include #include +#include #include #ifndef ENABLE_MEMCPY_ASYNC @@ -259,12 +260,14 @@ DI double myAtomicMax(double* address, double val) template HDI T myMax(T x, T y); template <> -HDI float myMax(float x, float y) +[[deprecated("use raft::max from raft/core/math.hpp instead")]] HDI float myMax(float x, + float y) { return fmaxf(x, y); } template <> -HDI double myMax(double x, double y) +[[deprecated("use raft::max from raft/core/math.hpp instead")]] HDI double myMax(double x, + double y) { return fmax(x, y); } @@ -277,12 +280,14 @@ HDI double myMax(double x, double y) template HDI T myMin(T x, T y); template <> -HDI float myMin(float x, float y) +[[deprecated("use raft::min from raft/core/math.hpp instead")]] HDI float myMin(float x, + float y) { return fminf(x, y); } template <> -HDI double myMin(double x, double y) +[[deprecated("use raft::min from raft/core/math.hpp instead")]] HDI double myMin(double x, + double y) { return fmin(x, y); } @@ -298,7 +303,7 @@ HDI double myMin(double x, double y) template DI T myAtomicMin(T* address, T val) { - myAtomicReduce(address, val, myMin); + myAtomicReduce(address, val, raft::min_op{}); return *address; } @@ -312,19 +317,10 @@ DI T myAtomicMin(T* address, T val) template DI T myAtomicMax(T* address, T val) { - myAtomicReduce(address, val, myMax); + myAtomicReduce(address, val, raft::max_op{}); return *address; } -/** - * Sign function - */ -template -HDI int sgn(const T val) -{ - return (T(0) < val) - (val < T(0)); -} - /** * @defgroup Exp Exponential function * @{ @@ -332,14 +328,14 @@ HDI int sgn(const T val) template HDI T myExp(T x); template <> -HDI float myExp(float x) +[[deprecated("use raft::exp from raft/core/math.hpp instead")]] HDI float myExp(float x) { return expf(x); } template <> -HDI double myExp(double x) +[[deprecated("use raft::exp from raft/core/math.hpp instead")]] HDI double myExp(double x) { - return exp(x); + return ::exp(x); } /** @} */ @@ -368,14 +364,14 @@ inline __device__ double myInf() template HDI T myLog(T x); template <> -HDI float myLog(float x) +[[deprecated("use raft::log from raft/core/math.hpp instead")]] HDI float myLog(float x) { return logf(x); } template <> -HDI double myLog(double x) +[[deprecated("use raft::log from raft/core/math.hpp instead")]] HDI double myLog(double x) { - return log(x); + return ::log(x); } /** @} */ @@ -386,14 +382,14 @@ HDI double myLog(double x) template HDI T mySqrt(T x); template <> -HDI float mySqrt(float x) +[[deprecated("use raft::sqrt from 
raft/core/math.hpp instead")]] HDI float mySqrt(float x) { return sqrtf(x); } template <> -HDI double mySqrt(double x) +[[deprecated("use raft::sqrt from raft/core/math.hpp instead")]] HDI double mySqrt(double x) { - return sqrt(x); + return ::sqrt(x); } /** @} */ @@ -404,14 +400,18 @@ HDI double mySqrt(double x) template DI void mySinCos(T x, T& s, T& c); template <> -DI void mySinCos(float x, float& s, float& c) +[[deprecated("use raft::sincos from raft/core/math.hpp instead")]] DI void mySinCos(float x, + float& s, + float& c) { sincosf(x, &s, &c); } template <> -DI void mySinCos(double x, double& s, double& c) +[[deprecated("use raft::sincos from raft/core/math.hpp instead")]] DI void mySinCos(double x, + double& s, + double& c) { - sincos(x, &s, &c); + ::sincos(x, &s, &c); } /** @} */ @@ -422,14 +422,14 @@ DI void mySinCos(double x, double& s, double& c) template DI T mySin(T x); template <> -DI float mySin(float x) +[[deprecated("use raft::sin from raft/core/math.hpp instead")]] DI float mySin(float x) { return sinf(x); } template <> -DI double mySin(double x) +[[deprecated("use raft::sin from raft/core/math.hpp instead")]] DI double mySin(double x) { - return sin(x); + return ::sin(x); } /** @} */ @@ -443,12 +443,12 @@ DI T myAbs(T x) return x < 0 ? -x : x; } template <> -DI float myAbs(float x) +[[deprecated("use raft::abs from raft/core/math.hpp instead")]] DI float myAbs(float x) { return fabsf(x); } template <> -DI double myAbs(double x) +[[deprecated("use raft::abs from raft/core/math.hpp instead")]] DI double myAbs(double x) { return fabs(x); } @@ -461,14 +461,16 @@ DI double myAbs(double x) template HDI T myPow(T x, T power); template <> -HDI float myPow(float x, float power) +[[deprecated("use raft::pow from raft/core/math.hpp instead")]] HDI float myPow(float x, + float power) { return powf(x, power); } template <> -HDI double myPow(double x, double power) +[[deprecated("use raft::pow from raft/core/math.hpp instead")]] HDI double myPow(double x, + double power) { - return pow(x, power); + return ::pow(x, power); } /** @} */ @@ -479,14 +481,14 @@ HDI double myPow(double x, double power) template HDI T myTanh(T x); template <> -HDI float myTanh(float x) +[[deprecated("use raft::tanh from raft/core/math.hpp instead")]] HDI float myTanh(float x) { return tanhf(x); } template <> -HDI double myTanh(double x) +[[deprecated("use raft::tanh from raft/core/math.hpp instead")]] HDI double myTanh(double x) { - return tanh(x); + return ::tanh(x); } /** @} */ @@ -497,14 +499,14 @@ HDI double myTanh(double x) template HDI T myATanh(T x); template <> -HDI float myATanh(float x) +[[deprecated("use raft::atanh from raft/core/math.hpp instead")]] HDI float myATanh(float x) { return atanhf(x); } template <> -HDI double myATanh(double x) +[[deprecated("use raft::atanh from raft/core/math.hpp instead")]] HDI double myATanh(double x) { - return atanh(x); + return ::atanh(x); } /** @} */ @@ -526,7 +528,7 @@ struct SqrtOp { [[deprecated("SqrtOp is deprecated. Use sqrt_op instead.")]] HDI Type operator()(Type in, IdxType i = 0) const { - return mySqrt(in); + return raft::sqrt(in); } }; @@ -544,7 +546,7 @@ struct L1Op { [[deprecated("L1Op is deprecated. 
Use abs_op instead.")]] HDI Type operator()(Type in, IdxType i = 0) const { - return myAbs(in); + return raft::abs(in); } }; diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 8ca30a5c82..a4b3758faa 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -86,6 +86,8 @@ if(BUILD_TESTS) CORE_TEST PATH test/core/logger.cpp + test/core/math_device.cu + test/core/math_host.cpp test/core/operators_device.cu test/core/operators_host.cpp test/core/handle.cpp diff --git a/cpp/test/core/math_device.cu b/cpp/test/core/math_device.cu new file mode 100644 index 0000000000..ff4b343d9e --- /dev/null +++ b/cpp/test/core/math_device.cu @@ -0,0 +1,352 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "../test_utils.h" +#include +#include +#include + +template +__global__ void math_eval_kernel(OutT* out, OpT op, Args... args) +{ + out[0] = op(std::forward(args)...); +} + +template +auto math_eval(OpT op, Args&&... args) +{ + typedef decltype(op(args...)) OutT; + auto stream = rmm::cuda_stream_default; + rmm::device_scalar result(stream); + math_eval_kernel<<<1, 1, 0, stream>>>(result.data(), op, std::forward(args)...); + return result.value(stream); +} + +struct abs_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in) const + { + return raft::abs(in); + } +}; + +TEST(MathDevice, Abs) +{ + // Integer abs + ASSERT_TRUE( + raft::match(int8_t{123}, math_eval(abs_test_op{}, int8_t{-123}), raft::Compare())); + ASSERT_TRUE(raft::match(12345, math_eval(abs_test_op{}, -12345), raft::Compare())); + ASSERT_TRUE(raft::match(12345l, math_eval(abs_test_op{}, -12345l), raft::Compare())); + ASSERT_TRUE(raft::match(123451234512345ll, + math_eval(abs_test_op{}, -123451234512345ll), + raft::Compare())); + // Floating-point abs + ASSERT_TRUE( + raft::match(12.34f, math_eval(abs_test_op{}, -12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(12.34, math_eval(abs_test_op{}, -12.34), raft::CompareApprox(0.000001))); +} + +struct acos_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in) const + { + return raft::acos(in); + } +}; + +TEST(MathDevice, Acos) +{ + ASSERT_TRUE(raft::match( + std::acos(0.123f), math_eval(acos_test_op{}, 0.123f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match( + std::acos(0.123), math_eval(acos_test_op{}, 0.123), raft::CompareApprox(0.000001))); +} + +struct asin_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in) const + { + return raft::asin(in); + } +}; + +TEST(MathDevice, Asin) +{ + ASSERT_TRUE(raft::match( + std::asin(0.123f), math_eval(asin_test_op{}, 0.123f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match( + std::asin(0.123), math_eval(asin_test_op{}, 0.123), raft::CompareApprox(0.000001))); +} + +struct atanh_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in) const + { + return raft::atanh(in); + } +}; + 
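All of the deprecation notices above point to the same destination: the overloaded wrappers in raft/core/math.hpp, which are callable from both host and device code and pick the float or double overload from the argument type. As a minimal sketch of the intended migration for downstream code (the helper below is hypothetical and assumes raft/core/math.hpp also makes the RAFT_INLINE_FUNCTION macro available, as the new tests rely on):

    #include <raft/core/math.hpp>

    // Before this change a caller would have written raft::mySqrt(raft::myMax(x, lo));
    // the replacement overload set works unchanged for float and double,
    // in both host and device code.
    template <typename T>
    RAFT_INLINE_FUNCTION T clamped_sqrt(T x, T lo)
    {
      return raft::sqrt(raft::max(x, lo));
    }
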
+TEST(MathDevice, Atanh) +{ + ASSERT_TRUE(raft::match( + std::atanh(0.123f), math_eval(atanh_test_op{}, 0.123f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match( + std::atanh(0.123), math_eval(atanh_test_op{}, 0.123), raft::CompareApprox(0.000001))); +} + +struct cos_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in) const + { + return raft::cos(in); + } +}; + +TEST(MathDevice, Cos) +{ + ASSERT_TRUE(raft::match( + std::cos(12.34f), math_eval(cos_test_op{}, 12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match( + std::cos(12.34), math_eval(cos_test_op{}, 12.34), raft::CompareApprox(0.000001))); +} + +struct exp_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in) const + { + return raft::exp(in); + } +}; + +TEST(MathDevice, Exp) +{ + ASSERT_TRUE(raft::match( + std::exp(12.34f), math_eval(exp_test_op{}, 12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match( + std::exp(12.34), math_eval(exp_test_op{}, 12.34), raft::CompareApprox(0.000001))); +} + +struct log_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in) const + { + return raft::log(in); + } +}; + +TEST(MathDevice, Log) +{ + ASSERT_TRUE(raft::match( + std::log(12.34f), math_eval(log_test_op{}, 12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match( + std::log(12.34), math_eval(log_test_op{}, 12.34), raft::CompareApprox(0.000001))); +} + +struct max_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const + { + return raft::max(std::forward(args)...); + } +}; + +TEST(MathDevice, Max2) +{ + ASSERT_TRUE(raft::match(1234, math_eval(max_test_op{}, -1234, 1234), raft::Compare())); + ASSERT_TRUE( + raft::match(1234u, math_eval(max_test_op{}, 1234u, 123u), raft::Compare())); + ASSERT_TRUE( + raft::match(1234ll, math_eval(max_test_op{}, -1234ll, 1234ll), raft::Compare())); + ASSERT_TRUE(raft::match( + 1234ull, math_eval(max_test_op{}, 1234ull, 123ull), raft::Compare())); + + ASSERT_TRUE( + raft::match(12.34f, math_eval(max_test_op{}, -12.34f, 12.34f), raft::Compare())); + ASSERT_TRUE(raft::match(12.34, math_eval(max_test_op{}, -12.34, 12.34), raft::Compare())); + ASSERT_TRUE(raft::match( + 12.34, math_eval(max_test_op{}, -12.34f, 12.34), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match( + 12.34, math_eval(max_test_op{}, -12.34, 12.34f), raft::CompareApprox(0.000001))); +} + +TEST(MathDevice, Max3) +{ + ASSERT_TRUE(raft::match(1234, math_eval(max_test_op{}, 1234, 0, -1234), raft::Compare())); + ASSERT_TRUE(raft::match(1234, math_eval(max_test_op{}, -1234, 1234, 0), raft::Compare())); + ASSERT_TRUE(raft::match(1234, math_eval(max_test_op{}, 0, -1234, 1234), raft::Compare())); + + ASSERT_TRUE(raft::match( + 12.34, math_eval(max_test_op{}, 12.34f, 0., -12.34), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match( + 12.34, math_eval(max_test_op{}, -12.34, 12.34f, 0.), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match( + 12.34, math_eval(max_test_op{}, 0., -12.34, 12.34f), raft::CompareApprox(0.000001))); +} + +struct min_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(Args&&... 
args) const + { + return raft::min(std::forward(args)...); + } +}; + +TEST(MathDevice, Min2) +{ + ASSERT_TRUE(raft::match(-1234, math_eval(min_test_op{}, -1234, 1234), raft::Compare())); + ASSERT_TRUE( + raft::match(123u, math_eval(min_test_op{}, 1234u, 123u), raft::Compare())); + ASSERT_TRUE(raft::match( + -1234ll, math_eval(min_test_op{}, -1234ll, 1234ll), raft::Compare())); + ASSERT_TRUE(raft::match( + 123ull, math_eval(min_test_op{}, 1234ull, 123ull), raft::Compare())); + + ASSERT_TRUE( + raft::match(-12.34f, math_eval(min_test_op{}, -12.34f, 12.34f), raft::Compare())); + ASSERT_TRUE( + raft::match(-12.34, math_eval(min_test_op{}, -12.34, 12.34), raft::Compare())); + ASSERT_TRUE(raft::match( + -12.34, math_eval(min_test_op{}, -12.34f, 12.34), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match( + -12.34, math_eval(min_test_op{}, -12.34, 12.34f), raft::CompareApprox(0.000001))); +} + +TEST(MathDevice, Min3) +{ + ASSERT_TRUE(raft::match(-1234, math_eval(min_test_op{}, 1234, 0, -1234), raft::Compare())); + ASSERT_TRUE(raft::match(-1234, math_eval(min_test_op{}, -1234, 1234, 0), raft::Compare())); + ASSERT_TRUE(raft::match(-1234, math_eval(min_test_op{}, 0, -1234, 1234), raft::Compare())); + + ASSERT_TRUE(raft::match( + -12.34, math_eval(min_test_op{}, 12.34f, 0., -12.34), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match( + -12.34, math_eval(min_test_op{}, -12.34, 12.34f, 0.), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match( + -12.34, math_eval(min_test_op{}, 0., -12.34, 12.34f), raft::CompareApprox(0.000001))); +} + +struct pow_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& x, const Type& y) const + { + return raft::pow(x, y); + } +}; + +TEST(MathDevice, Pow) +{ + ASSERT_TRUE(raft::match(std::pow(12.34f, 2.f), + math_eval(pow_test_op{}, 12.34f, 2.f), + raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match(std::pow(12.34, 2.), + math_eval(pow_test_op{}, 12.34, 2.), + raft::CompareApprox(0.000001))); +} + +struct sgn_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in) const + { + return raft::sgn(in); + } +}; + +TEST(MathDevice, Sgn) +{ + ASSERT_TRUE(raft::match(-1, math_eval(sgn_test_op{}, -1234), raft::Compare())); + ASSERT_TRUE(raft::match(0, math_eval(sgn_test_op{}, 0), raft::Compare())); + ASSERT_TRUE(raft::match(1, math_eval(sgn_test_op{}, 1234), raft::Compare())); + ASSERT_TRUE(raft::match(-1, math_eval(sgn_test_op{}, -12.34f), raft::Compare())); + ASSERT_TRUE(raft::match(0, math_eval(sgn_test_op{}, 0.f), raft::Compare())); + ASSERT_TRUE(raft::match(1, math_eval(sgn_test_op{}, 12.34f), raft::Compare())); +} + +struct sin_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in) const + { + return raft::sin(in); + } +}; + +TEST(MathDevice, Sin) +{ + ASSERT_TRUE(raft::match( + std::sin(12.34f), math_eval(sin_test_op{}, 12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match( + std::sin(12.34), math_eval(sin_test_op{}, 12.34), raft::CompareApprox(0.000001))); +} + +struct sincos_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& x, Type* s, Type* c) const + { + raft::sincos(x, s, c); + return x; // unused, just to avoid creating another helper + } +}; + +TEST(MathDevice, SinCos) +{ + auto stream = rmm::cuda_stream_default; + float xf = 12.34f; + rmm::device_scalar sf(stream); + rmm::device_scalar cf(stream); + math_eval(sincos_test_op{}, xf, sf.data(), cf.data()); + ASSERT_TRUE(raft::match(std::sin(12.34f), 
sf.value(stream), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match(std::cos(12.34f), cf.value(stream), raft::CompareApprox(0.0001f))); + double xd = 12.34f; + rmm::device_scalar sd(stream); + rmm::device_scalar cd(stream); + math_eval(sincos_test_op{}, xd, sd.data(), cd.data()); + ASSERT_TRUE(raft::match(std::sin(12.34), sd.value(stream), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match(std::cos(12.34), cd.value(stream), raft::CompareApprox(0.0001f))); +} + +struct sqrt_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in) const + { + return raft::sqrt(in); + } +}; + +TEST(MathDevice, Sqrt) +{ + ASSERT_TRUE(raft::match( + std::sqrt(12.34f), math_eval(sqrt_test_op{}, 12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match( + std::sqrt(12.34), math_eval(sqrt_test_op{}, 12.34), raft::CompareApprox(0.000001))); +} + +struct tanh_test_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in) const + { + return raft::tanh(in); + } +}; + +TEST(MathDevice, Tanh) +{ + ASSERT_TRUE(raft::match( + std::tanh(12.34f), math_eval(tanh_test_op{}, 12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match( + std::tanh(12.34), math_eval(tanh_test_op{}, 12.34), raft::CompareApprox(0.000001))); +} diff --git a/cpp/test/core/math_host.cpp b/cpp/test/core/math_host.cpp new file mode 100644 index 0000000000..5808905713 --- /dev/null +++ b/cpp/test/core/math_host.cpp @@ -0,0 +1,195 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include "../test_utils.h" +#include + +TEST(MathHost, Abs) +{ + // Integer abs + ASSERT_TRUE(raft::match(int8_t{123}, raft::abs(int8_t{-123}), raft::Compare())); + ASSERT_TRUE(raft::match(12345, raft::abs(-12345), raft::Compare())); + ASSERT_TRUE(raft::match(12345l, raft::abs(-12345l), raft::Compare())); + ASSERT_TRUE( + raft::match(123451234512345ll, raft::abs(-123451234512345ll), raft::Compare())); + // Floating-point abs + ASSERT_TRUE(raft::match(12.34f, raft::abs(-12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match(12.34, raft::abs(-12.34), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Acos) +{ + ASSERT_TRUE( + raft::match(std::acos(0.123f), raft::acos(0.123f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(std::acos(0.123), raft::acos(0.123), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Asin) +{ + ASSERT_TRUE( + raft::match(std::asin(0.123f), raft::asin(0.123f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(std::asin(0.123), raft::asin(0.123), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Atanh) +{ + ASSERT_TRUE( + raft::match(std::atanh(0.123f), raft::atanh(0.123f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(std::atanh(0.123), raft::atanh(0.123), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Cos) +{ + ASSERT_TRUE( + raft::match(std::cos(12.34f), raft::cos(12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(std::cos(12.34), raft::cos(12.34), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Exp) +{ + ASSERT_TRUE( + raft::match(std::exp(12.34f), raft::exp(12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(std::exp(12.34), raft::exp(12.34), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Log) +{ + ASSERT_TRUE( + raft::match(std::log(12.34f), raft::log(12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(std::log(12.34), raft::log(12.34), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Max2) +{ + ASSERT_TRUE(raft::match(1234, raft::max(-1234, 1234), raft::Compare())); + ASSERT_TRUE(raft::match(1234u, raft::max(1234u, 123u), raft::Compare())); + ASSERT_TRUE(raft::match(1234ll, raft::max(-1234ll, 1234ll), raft::Compare())); + ASSERT_TRUE( + raft::match(1234ull, raft::max(1234ull, 123ull), raft::Compare())); + + ASSERT_TRUE(raft::match(12.34f, raft::max(-12.34f, 12.34f), raft::Compare())); + ASSERT_TRUE(raft::match(12.34, raft::max(-12.34, 12.34), raft::Compare())); + ASSERT_TRUE(raft::match(12.34, raft::max(-12.34f, 12.34), raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match(12.34, raft::max(-12.34, 12.34f), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Max3) +{ + ASSERT_TRUE(raft::match(1234, raft::max(1234, 0, -1234), raft::Compare())); + ASSERT_TRUE(raft::match(1234, raft::max(-1234, 1234, 0), raft::Compare())); + ASSERT_TRUE(raft::match(1234, raft::max(0, -1234, 1234), raft::Compare())); + + ASSERT_TRUE( + raft::match(12.34, raft::max(12.34f, 0., -12.34), raft::CompareApprox(0.000001))); + ASSERT_TRUE( + raft::match(12.34, raft::max(-12.34, 12.34f, 0.), raft::CompareApprox(0.000001))); + ASSERT_TRUE( + raft::match(12.34, raft::max(0., -12.34, 12.34f), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Min2) +{ + ASSERT_TRUE(raft::match(-1234, raft::min(-1234, 1234), raft::Compare())); + ASSERT_TRUE(raft::match(123u, raft::min(1234u, 123u), raft::Compare())); + ASSERT_TRUE(raft::match(-1234ll, raft::min(-1234ll, 1234ll), raft::Compare())); + ASSERT_TRUE( + raft::match(123ull, 
raft::min(1234ull, 123ull), raft::Compare())); + + ASSERT_TRUE(raft::match(-12.34f, raft::min(-12.34f, 12.34f), raft::Compare())); + ASSERT_TRUE(raft::match(-12.34, raft::min(-12.34, 12.34), raft::Compare())); + ASSERT_TRUE( + raft::match(-12.34, raft::min(-12.34f, 12.34), raft::CompareApprox(0.000001))); + ASSERT_TRUE( + raft::match(-12.34, raft::min(-12.34, 12.34f), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Min3) +{ + ASSERT_TRUE(raft::match(-1234, raft::min(1234, 0, -1234), raft::Compare())); + ASSERT_TRUE(raft::match(-1234, raft::min(-1234, 1234, 0), raft::Compare())); + ASSERT_TRUE(raft::match(-1234, raft::min(0, -1234, 1234), raft::Compare())); + + ASSERT_TRUE( + raft::match(-12.34, raft::min(12.34f, 0., -12.34), raft::CompareApprox(0.000001))); + ASSERT_TRUE( + raft::match(-12.34, raft::min(-12.34, 12.34f, 0.), raft::CompareApprox(0.000001))); + ASSERT_TRUE( + raft::match(-12.34, raft::min(0., -12.34, 12.34f), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Pow) +{ + ASSERT_TRUE(raft::match( + std::pow(12.34f, 2.f), raft::pow(12.34f, 2.f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(std::pow(12.34, 2.), raft::pow(12.34, 2.), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Sgn) +{ + ASSERT_TRUE(raft::match(-1, raft::sgn(-1234), raft::Compare())); + ASSERT_TRUE(raft::match(0, raft::sgn(0), raft::Compare())); + ASSERT_TRUE(raft::match(1, raft::sgn(1234), raft::Compare())); + ASSERT_TRUE(raft::match(-1, raft::sgn(-12.34f), raft::Compare())); + ASSERT_TRUE(raft::match(0, raft::sgn(0.f), raft::Compare())); + ASSERT_TRUE(raft::match(1, raft::sgn(12.34f), raft::Compare())); +} + +TEST(MathHost, Sin) +{ + ASSERT_TRUE( + raft::match(std::sin(12.34f), raft::sin(12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(std::sin(12.34), raft::sin(12.34), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, SinCos) +{ + float xf = 12.34f; + float sf, cf; + raft::sincos(xf, &sf, &cf); + ASSERT_TRUE(raft::match(std::sin(12.34f), sf, raft::CompareApprox(0.0001f))); + ASSERT_TRUE(raft::match(std::cos(12.34f), cf, raft::CompareApprox(0.0001f))); + double xd = 12.34f; + double sd, cd; + raft::sincos(xd, &sd, &cd); + ASSERT_TRUE(raft::match(std::sin(12.34), sd, raft::CompareApprox(0.000001))); + ASSERT_TRUE(raft::match(std::cos(12.34), cd, raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Sqrt) +{ + ASSERT_TRUE( + raft::match(std::sqrt(12.34f), raft::sqrt(12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(std::sqrt(12.34), raft::sqrt(12.34), raft::CompareApprox(0.000001))); +} + +TEST(MathHost, Tanh) +{ + ASSERT_TRUE( + raft::match(std::tanh(12.34f), raft::tanh(12.34f), raft::CompareApprox(0.0001f))); + ASSERT_TRUE( + raft::match(std::tanh(12.34), raft::tanh(12.34), raft::CompareApprox(0.000001))); +} diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index cbfd97ebc6..fedbee919d 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -52,7 +52,7 @@ __global__ void naiveDistanceKernel(DataType* dist, } if (type == raft::distance::DistanceType::L2SqrtExpanded || type == raft::distance::DistanceType::L2SqrtUnexpanded) - acc = raft::mySqrt(acc); + acc = raft::sqrt(acc); int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; dist[outidx] = acc; } @@ -79,9 +79,9 @@ __global__ void naiveL1_Linf_CanberraDistanceKernel(DataType* dist, auto b = y[yidx]; auto diff = (a > b) ? 
(a - b) : (b - a); if (type == raft::distance::DistanceType::Linf) { - acc = raft::myMax(acc, diff); + acc = raft::max(acc, diff); } else if (type == raft::distance::DistanceType::Canberra) { - const auto add = raft::myAbs(a) + raft::myAbs(b); + const auto add = raft::abs(a) + raft::abs(b); // deal with potential for 0 in denominator by // forcing 1/0 instead acc += ((add != 0) * diff / (add + (add == 0))); @@ -119,7 +119,7 @@ __global__ void naiveCosineDistanceKernel( int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; // Use 1.0 - (cosine similarity) to calc the distance - dist[outidx] = (DataType)1.0 - acc_ab / (raft::mySqrt(acc_a) * raft::mySqrt(acc_b)); + dist[outidx] = (DataType)1.0 - acc_ab / (raft::sqrt(acc_a) * raft::sqrt(acc_b)); } template @@ -137,7 +137,7 @@ __global__ void naiveHellingerDistanceKernel( int yidx = isRowMajor ? i + nidx * k : i * n + nidx; auto a = x[xidx]; auto b = y[yidx]; - acc_ab += raft::mySqrt(a) * raft::mySqrt(b); + acc_ab += raft::sqrt(a) * raft::sqrt(b); } int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; @@ -145,7 +145,7 @@ __global__ void naiveHellingerDistanceKernel( // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative acc_ab = 1 - acc_ab; auto rectifier = (!signbit(acc_ab)); - dist[outidx] = raft::mySqrt(rectifier * acc_ab); + dist[outidx] = raft::sqrt(rectifier * acc_ab); } template @@ -167,11 +167,11 @@ __global__ void naiveLpUnexpDistanceKernel(DataType* dist, int yidx = isRowMajor ? i + nidx * k : i * n + nidx; auto a = x[xidx]; auto b = y[yidx]; - auto diff = raft::myAbs(a - b); - acc += raft::myPow(diff, p); + auto diff = raft::abs(a - b); + acc += raft::pow(diff, p); } auto one_over_p = 1 / p; - acc = raft::myPow(acc, one_over_p); + acc = raft::pow(acc, one_over_p); int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; dist[outidx] = acc; } @@ -222,7 +222,7 @@ __global__ void naiveJensenShannonDistanceKernel( acc += (-a * (!p_zero * log(p + p_zero))) + (-b * (!q_zero * log(q + q_zero))); } - acc = raft::mySqrt(0.5f * acc); + acc = raft::sqrt(0.5f * acc); int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; dist[outidx] = acc; } @@ -297,7 +297,7 @@ __global__ void naiveCorrelationDistanceKernel( auto Q_denom = k * a_sq_norm - (a_norm * a_norm); auto R_denom = k * b_sq_norm - (b_norm * b_norm); - acc = 1 - (numer / raft::mySqrt(Q_denom * R_denom)); + acc = 1 - (numer / raft::sqrt(Q_denom * R_denom)); int outidx = isRowMajor ? midx * n + nidx : midx + m * nidx; dist[outidx] = acc; diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index e746a2382d..54de12307a 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -60,7 +60,7 @@ __global__ void naiveKernel(raft::KeyValuePair* min, auto diff = midx >= m || nidx >= n ? DataT(0) : x[xidx] - y[yidx]; acc += diff * diff; } - if (Sqrt) { acc = raft::mySqrt(acc); } + if (Sqrt) { acc = raft::sqrt(acc); } ReduceOpT redOp; typedef cub::WarpReduce> WarpReduce; __shared__ typename WarpReduce::TempStorage temp[NWARPS]; diff --git a/cpp/test/linalg/matrix_vector.cu b/cpp/test/linalg/matrix_vector.cu index 7018e1da96..fb1e2235f9 100644 --- a/cpp/test/linalg/matrix_vector.cu +++ b/cpp/test/linalg/matrix_vector.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -116,7 +116,7 @@ void naive_matrix_vector_op_launch(const raft::handle_t& handle, } }; auto operation_bin_div_skip_zero = [] __device__(T mat_element, T vec_element) { - if (raft::myAbs(vec_element) < T(1e-10)) + if (raft::abs(vec_element) < T(1e-10)) return T(0); else return mat_element / vec_element; diff --git a/cpp/test/linalg/norm.cu b/cpp/test/linalg/norm.cu index 94540b9ff6..90cfbd8f89 100644 --- a/cpp/test/linalg/norm.cu +++ b/cpp/test/linalg/norm.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -56,10 +56,10 @@ __global__ void naiveRowNormKernel( if (type == L2Norm) { acc += data[rowStart * D + i] * data[rowStart * D + i]; } else { - acc += raft::myAbs(data[rowStart * D + i]); + acc += raft::abs(data[rowStart * D + i]); } } - dots[rowStart] = do_sqrt ? raft::mySqrt(acc) : acc; + dots[rowStart] = do_sqrt ? raft::sqrt(acc) : acc; } } @@ -131,10 +131,10 @@ __global__ void naiveColNormKernel( Type acc = 0; for (IdxT i = 0; i < N; i++) { Type v = data[colID + i * D]; - acc += type == L2Norm ? v * v : raft::myAbs(v); + acc += type == L2Norm ? v * v : raft::abs(v); } - dots[colID] = do_sqrt ? raft::mySqrt(acc) : acc; + dots[colID] = do_sqrt ? raft::sqrt(acc) : acc; } template diff --git a/cpp/test/linalg/power.cu b/cpp/test/linalg/power.cu index 54c2e2a7aa..5cb63a5697 100644 --- a/cpp/test/linalg/power.cu +++ b/cpp/test/linalg/power.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ template __global__ void naivePowerElemKernel(Type* out, const Type* in1, const Type* in2, int len) { int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < len) { out[idx] = raft::myPow(in1[idx], in2[idx]); } + if (idx < len) { out[idx] = raft::pow(in1[idx], in2[idx]); } } template @@ -43,7 +43,7 @@ template __global__ void naivePowerScalarKernel(Type* out, const Type* in1, const Type in2, int len) { int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < len) { out[idx] = raft::myPow(in1[idx], in2); } + if (idx < len) { out[idx] = raft::pow(in1[idx], in2); } } template diff --git a/cpp/test/linalg/sqrt.cu b/cpp/test/linalg/sqrt.cu index 9008313b58..93150ca77d 100644 --- a/cpp/test/linalg/sqrt.cu +++ b/cpp/test/linalg/sqrt.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -27,7 +27,7 @@ template __global__ void naiveSqrtElemKernel(Type* out, const Type* in1, int len) { int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < len) { out[idx] = raft::mySqrt(in1[idx]); } + if (idx < len) { out[idx] = raft::sqrt(in1[idx]); } } template diff --git a/cpp/test/matrix/math.cu b/cpp/test/matrix/math.cu index f2c1a6249c..9dcbfc8899 100644 --- a/cpp/test/matrix/math.cu +++ b/cpp/test/matrix/math.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -51,7 +51,7 @@ template __global__ void naiveSqrtKernel(Type* in, Type* out, int len) { int idx = threadIdx.x + blockIdx.x * blockDim.x; - if (idx < len) { out[idx] = std::sqrt(in[idx]); } + if (idx < len) { out[idx] = raft::sqrt(in[idx]); } } template diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index b88b6abd9e..bb2f334db4 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -131,7 +131,7 @@ __global__ void naive_distance_kernel(EvalT* dist, } if (type == raft::distance::DistanceType::L2SqrtExpanded || type == raft::distance::DistanceType::L2SqrtUnexpanded) - acc = raft::mySqrt(acc); + acc = raft::sqrt(acc); dist[midx * n + nidx] = acc; } } diff --git a/cpp/test/random/rng.cu b/cpp/test/random/rng.cu index bdce79b76e..0bf494b624 100644 --- a/cpp/test/random/rng.cu +++ b/cpp/test/random/rng.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -145,8 +145,8 @@ class RngTest : public ::testing::TestWithParam> { case RNG_LogNormal: { auto var = params.end * params.end; auto mu = params.start; - meanvar[0] = raft::myExp(mu + var * T(0.5)); - meanvar[1] = (raft::myExp(var) - T(1.0)) * raft::myExp(T(2.0) * mu + var); + meanvar[0] = raft::exp(mu + var * T(0.5)); + meanvar[1] = (raft::exp(var) - T(1.0)) * raft::exp(T(2.0) * mu + var); break; } case RNG_Uniform: @@ -169,7 +169,7 @@ class RngTest : public ::testing::TestWithParam> { meanvar[1] = meanvar[0] * meanvar[0]; break; case RNG_Rayleigh: - meanvar[0] = params.start * raft::mySqrt(T(3.1415 / 2.0)); + meanvar[0] = params.start * raft::sqrt(T(3.1415 / 2.0)); meanvar[1] = ((T(4.0) - T(3.1415)) / T(2.0)) * params.start * params.start; break; case RNG_Laplace: @@ -239,8 +239,8 @@ class RngMdspanTest : public ::testing::TestWithParam> { case RNG_LogNormal: { auto var = params.end * params.end; auto mu = params.start; - meanvar[0] = raft::myExp(mu + var * T(0.5)); - meanvar[1] = (raft::myExp(var) - T(1.0)) * raft::myExp(T(2.0) * mu + var); + meanvar[0] = raft::exp(mu + var * T(0.5)); + meanvar[1] = (raft::exp(var) - T(1.0)) * raft::exp(T(2.0) * mu + var); break; } case RNG_Uniform: @@ -263,7 +263,7 @@ class RngMdspanTest : public ::testing::TestWithParam> { meanvar[1] = meanvar[0] * meanvar[0]; break; case RNG_Rayleigh: - meanvar[0] = params.start * raft::mySqrt(T(3.1415 / 2.0)); + meanvar[0] = params.start * raft::sqrt(T(3.1415 / 2.0)); meanvar[1] = ((T(4.0) - T(3.1415)) / T(2.0)) * params.start * params.start; break; case RNG_Laplace: From 0e96662f9b4fc77cd4ac6e528fe6103c81715287 Mon Sep 17 00:00:00 2001 From: Louis Sugy Date: Sat, 21 Jan 2023 21:55:43 +0100 Subject: [PATCH 6/9] Improvements in `matrix::gather`: test coverage, compilation errors, performance (#1126) In order to deprecate `copy_selected` from `ann_utils.cuh`, I wanted to make sure that the performance of `matrix::gather` was on par. But in the process I discovered that: - Map transforms and conditional copy were not tested at all. - In fact, most of the API in `gather.cuh` wasn't covered in tests and some of the functions didn't even compile. - The same type `MatrixIteratorT` was used for the input and output iterators, which made it impossible to take advantage of custom iterators, as is needed in `kmeans_balanced` to convert the dataset from `T` to `float` and gather in a single step. 
- The performance was really poor when `D` is small because the kernel assigns one block per row (so a block could be working on only 2 or 3 elements...) This PR addresses all the aforementioned issues. Authors: - Louis Sugy (https://github.com/Nyrio) Approvers: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1126 --- cpp/bench/CMakeLists.txt | 2 +- cpp/bench/matrix/argmin.cu | 17 +- cpp/bench/matrix/gather.cu | 101 +++++ .../raft/cluster/detail/kmeans_common.cuh | 2 +- cpp/include/raft/core/operators.hpp | 51 ++- cpp/include/raft/matrix/detail/gather.cuh | 236 ++++++----- cpp/include/raft/matrix/gather.cuh | 371 ++++++++---------- cpp/test/matrix/gather.cu | 208 ++++++---- 8 files changed, 578 insertions(+), 410 deletions(-) create mode 100644 cpp/bench/matrix/gather.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 813483adc5..8dcdb325e9 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -103,7 +103,7 @@ if(BUILD_BENCH) bench/main.cpp ) - ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/main.cpp) + ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/main.cpp) ConfigureBench( NAME RANDOM_BENCH PATH bench/random/make_blobs.cu bench/random/permute.cu bench/random/rng.cu diff --git a/cpp/bench/matrix/argmin.cu b/cpp/bench/matrix/argmin.cu index 0d0dea0fdb..52f5aab7f3 100644 --- a/cpp/bench/matrix/argmin.cu +++ b/cpp/bench/matrix/argmin.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,10 +17,11 @@ #include #include #include +#include #include -namespace raft::bench::linalg { +namespace raft::bench::matrix { template struct ArgminParams { @@ -57,15 +58,11 @@ struct Argmin : public fixture { raft::device_vector indices; }; // struct Argmin -const std::vector> argmin_inputs_i64{ - {1000, 64}, {1000, 128}, {1000, 256}, {1000, 512}, {1000, 1024}, - {10000, 64}, {10000, 128}, {10000, 256}, {10000, 512}, {10000, 1024}, - {100000, 64}, {100000, 128}, {100000, 256}, {100000, 512}, {100000, 1024}, - {1000000, 64}, {1000000, 128}, {1000000, 256}, {1000000, 512}, {1000000, 1024}, - {10000000, 64}, {10000000, 128}, {10000000, 256}, {10000000, 512}, {10000000, 1024}, -}; +const std::vector> argmin_inputs_i64 = + raft::util::itertools::product>({1000, 10000, 100000, 1000000, 10000000}, + {64, 128, 256, 512, 1024}); RAFT_BENCH_REGISTER((Argmin), "", argmin_inputs_i64); RAFT_BENCH_REGISTER((Argmin), "", argmin_inputs_i64); -} // namespace raft::bench::linalg +} // namespace raft::bench::matrix diff --git a/cpp/bench/matrix/gather.cu b/cpp/bench/matrix/gather.cu new file mode 100644 index 0000000000..97812c20a1 --- /dev/null +++ b/cpp/bench/matrix/gather.cu @@ -0,0 +1,101 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
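A concrete payoff of splitting the single MatrixIteratorT into separate input and output iterator types (the third point in the description above) is that a dataset can be converted and gathered in one pass, as needed by kmeans_balanced. A minimal sketch, assuming the pointer-based raft::matrix::gather overload updated later in this patch; the wrapper function and the cast functor are hypothetical, not part of the change:

    #include <raft/matrix/gather.cuh>
    #include <thrust/iterator/transform_iterator.h>
    #include <cstdint>

    struct cast_to_float {
      __host__ __device__ float operator()(int8_t x) const { return static_cast<float>(x); }
    };

    // Gather selected rows of an int8_t dataset directly into a float output matrix.
    void gather_as_float(const int8_t* dataset,
                         int64_t D,
                         int64_t N,
                         const int64_t* row_indices,
                         int64_t n_rows_out,
                         float* out,
                         cudaStream_t stream)
    {
      // The transform iterator casts each element on the fly inside the gather kernel,
      // so no intermediate float copy of the dataset is needed.
      auto in = thrust::make_transform_iterator(dataset, cast_to_float{});
      raft::matrix::gather(in, D, N, row_indices, n_rows_out, out, stream);
    }
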
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +namespace raft::bench::matrix { + +template +struct GatherParams { + IdxT rows, cols, map_length; +}; + +template +inline auto operator<<(std::ostream& os, const GatherParams& p) -> std::ostream& +{ + os << p.rows << "#" << p.cols << "#" << p.map_length; + return os; +} + +template +struct Gather : public fixture { + Gather(const GatherParams& p) : params(p) {} + + void allocate_data(const ::benchmark::State& state) override + { + matrix = raft::make_device_matrix(handle, params.rows, params.cols); + map = raft::make_device_vector(handle, params.map_length); + out = raft::make_device_matrix(handle, params.map_length, params.cols); + stencil = raft::make_device_vector(handle, Conditional ? params.map_length : IdxT(0)); + + raft::random::RngState rng{1234}; + raft::random::uniform( + rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream); + raft::random::uniformInt( + handle, rng, map.data_handle(), params.map_length, (MapT)0, (MapT)params.rows); + if constexpr (Conditional) { + raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream); + } + handle.sync_stream(stream); + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + loop_on_state(state, [this]() { + auto matrix_const_view = raft::make_device_matrix_view( + matrix.data_handle(), matrix.extent(0), matrix.extent(1)); + auto map_const_view = + raft::make_device_vector_view(map.data_handle(), map.extent(0)); + if constexpr (Conditional) { + auto stencil_const_view = + raft::make_device_vector_view(stencil.data_handle(), stencil.extent(0)); + auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op()); + raft::matrix::gather_if( + handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op); + } else { + raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + } + }); + } + + private: + GatherParams params; + raft::device_matrix matrix, out; + raft::device_vector stencil; + raft::device_vector map; +}; // struct Gather + +template +using GatherIf = Gather; + +const std::vector> gather_inputs_i64 = + raft::util::itertools::product>( + {1000000}, {10, 20, 50, 100, 200, 500}, {1000, 10000, 100000, 1000000}); + +RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); +RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); +RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); +RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); +} // namespace raft::bench::matrix diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 2fd33ac759..559793442f 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -335,7 +335,7 @@ void shuffleAndGather(const raft::handle_t& handle, in.extent(1), in.extent(0), indices.data_handle(), - n_samples_to_gather, + static_cast(n_samples_to_gather), out.data_handle(), stream); } diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp index de27c2b271..edb437c880 100644 --- a/cpp/include/raft/core/operators.hpp +++ b/cpp/include/raft/core/operators.hpp @@ -147,6 +147,14 @@ struct div_checkzero_op { } }; +struct modulo_op { + template + constexpr RAFT_INLINE_FUNCTION auto 
operator()(const T1& a, const T2& b) const + { + return a % b; + } +}; + struct pow_op { template RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const @@ -189,17 +197,49 @@ struct argmax_op { } }; +struct greater_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a > b; + } +}; + +struct less_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a < b; + } +}; + +struct greater_or_equal_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a >= b; + } +}; + +struct less_or_equal_op { + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const + { + return a <= b; + } +}; + struct equal_op { - template - constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const { return a == b; } }; struct notequal_op { - template - constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const + template + constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const { return a != b; } @@ -267,6 +307,9 @@ using div_const_op = plug_const_op; template using div_checkzero_const_op = plug_const_op; +template +using modulo_const_op = plug_const_op; + template using pow_const_op = plug_const_op; diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index c006f69e47..a8efc2d0d0 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
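Because the new comparison functors accept two independently typed arguments, they compose with plug_const_op to bind one side to a constant, which is exactly how the gather benchmark above builds its stencil predicate. A small host-only sketch, assuming nothing beyond raft/core/operators.hpp:

    #include <raft/core/operators.hpp>
    #include <cassert>

    int main()
    {
      // Bind 0.0f as the second argument of greater_op, producing a unary "x > 0" predicate.
      auto is_positive = raft::plug_const_op(0.0f, raft::greater_op{});
      assert(is_positive(1.5f));
      assert(!is_positive(-2.0f));
      return 0;
    }

The new modulo_const_op alias above follows the same plug_const_op pattern as the existing div_const_op and pow_const_op.
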
@@ -17,41 +17,63 @@ #pragma once #include +#include namespace raft { namespace matrix { namespace detail { -// gatherKernel conditionally copies rows from the source matrix 'in' into the destination matrix -// 'out' according to a map (or a transformed map) -template +struct gather_policy { + static constexpr int n_threads = tpb; + static constexpr int work_per_thread = wpt; + static constexpr int stride = tpb * wpt; +}; + +/** Conditionally copies rows from the source matrix 'in' into the destination matrix + * 'out' according to a map (or a transformed map) */ +template -__global__ void gatherKernel(const MatrixIteratorT in, - IndexT D, - IndexT N, - MapIteratorT map, - StencilIteratorT stencil, - MatrixIteratorT out, - PredicateOp pred_op, - MapTransformOp transform_op) + typename OutputIteratorT, + typename IndexT> +__global__ void gather_kernel(const InputIteratorT in, + IndexT D, + IndexT len, + const MapIteratorT map, + StencilIteratorT stencil, + OutputIteratorT out, + PredicateOp pred_op, + MapTransformOp transform_op) { typedef typename std::iterator_traits::value_type MapValueT; typedef typename std::iterator_traits::value_type StencilValueT; - IndexT outRowStart = blockIdx.x * D; - MapValueT map_val = map[blockIdx.x]; - StencilValueT stencil_val = stencil[blockIdx.x]; +#pragma unroll + for (IndexT wid = 0; wid < Policy::work_per_thread; wid++) { + IndexT tid = threadIdx.x + (Policy::work_per_thread * static_cast(blockIdx.x) + wid) * + Policy::n_threads; + if (tid < len) { + IndexT i_dst = tid / D; + IndexT j = tid % D; + + MapValueT map_val = map[i_dst]; + StencilValueT stencil_val = stencil[i_dst]; - bool predicate = pred_op(stencil_val); - if (predicate) { - IndexT inRowStart = transform_op(map_val) * D; - for (int i = threadIdx.x; i < D; i += TPB) { - out[outRowStart + i] = in[inRowStart + i]; + bool predicate = pred_op(stencil_val); + if (predicate) { + IndexT i_src = transform_op(map_val); + out[tid] = in[i_src * D + j]; + } } } } @@ -60,7 +82,7 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @brief gather conditionally copies rows from a source matrix into a destination matrix according * to a transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -69,7 +91,10 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -83,18 +108,20 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gatherImpl(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename MapTransformOp, + typename OutputIteratorT, + typename IndexT> +void gatherImpl(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, MapTransformOp transform_op, cudaStream_t stream) @@ -102,9 +129,6 @@ void gatherImpl(const MatrixIteratorT in, // skip in case of 0 length input if (map_length <= 0 || N <= 0 || D <= 0) return; - // signed integer type for indexing or global offsets - typedef int IndexT; - // map value type typedef typename std::iterator_traits::value_type MapValueT; @@ -121,38 +145,26 @@ void gatherImpl(const MatrixIteratorT in, static_assert((std::is_convertible::value), "UnaryPredicateOp's result type must be convertible to bool type"); - if (D <= 32) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); - } else if (D <= 64) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); - } else if (D <= 128) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + IndexT len = map_length * D; + constexpr int TPB = 128; + const int n_sm = raft::getMultiProcessorCount(); + // The following empirical heuristics enforce that we keep a good balance between having enough + // blocks and enough work per thread. + if (len < 32 * TPB * n_sm) { + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); + } else if (len < 32 * 4 * TPB * n_sm) { + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); } else { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); } RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -160,10 +172,13 @@ void gatherImpl(const MatrixIteratorT in, /** * @brief gather copies rows from a source matrix into a destination matrix according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
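As a concrete illustration of the launch heuristic above: with TPB = 128 and, say, 108 SMs (the SM count is queried at runtime through raft::getMultiProcessorCount(); 108 is only an assumption for this example), the two thresholds are 32 * 128 * 108 = 442,368 and four times that, 1,769,472 elements. Gathering 10,000 rows of 50 columns gives len = 500,000, which falls between the two, so the work_per_thread = 4 policy is chosen and ceildiv(500,000, 512) = 977 blocks are launched, instead of the 10,000 blocks of at most 50 useful lanes each that the previous one-block-per-row kernel would have produced.
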
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -174,13 +189,13 @@ void gatherImpl(const MatrixIteratorT in, * @param out Pointer to the output matrix (assumed to be row-major) * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; @@ -192,12 +207,15 @@ void gather(const MatrixIteratorT in, * @brief gather copies rows from a source matrix into a destination matrix according to a * transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. * * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -209,13 +227,17 @@ void gather(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, MapTransformOp transform_op, cudaStream_t stream) { @@ -227,7 +249,7 @@ void gather(const MatrixIteratorT in, * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -235,6 +257,9 @@ void gather(const MatrixIteratorT in, * simple pointer type). * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -247,17 +272,19 @@ void gather(const MatrixIteratorT in, * @param pred_op Predicate to apply to the stencil values * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename UnaryPredicateOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, cudaStream_t stream) { @@ -269,7 +296,7 @@ void gather_if(const MatrixIteratorT in, * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -278,7 +305,10 @@ void gather_if(const MatrixIteratorT in, * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT type. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. * * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -292,18 +322,20 @@ void gather_if(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename MapTransformOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, MapTransformOp transform_op, cudaStream_t stream) diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 6a923fb0cc..9487da35b5 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ #include #include #include +#include namespace raft::matrix { @@ -28,62 +29,68 @@ namespace raft::matrix { */ /** - * @brief gather copies rows from a source matrix into a destination matrix according to a map. + * @brief Copies rows from a source matrix into a destination matrix according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a - * simple pointer type). 
- * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple - * pointer type). + * For each output row, read the index in the input matrix from the map and copy the row. * - * @param in Pointer to the input matrix (assumed to be row-major) - * @param D Leading dimension of the input matrix 'in', which in-case of row-major - * storage is the number of columns - * @param N Second dimension - * @param map Pointer to the input sequence of gather locations - * @param map_length The length of 'map' and 'stencil' - * @param out Pointer to the output matrix (assumed to be row-major) + * @tparam InputIteratorT Input iterator type, for the input matrix (may be a pointer type). + * @tparam MapIteratorT Input iterator type, for the map (may be a pointer type). + * @tparam OutputIteratorT Output iterator type, for the output matrix (may be a pointer type). + * @tparam IndexT Index type. + * + * @param in Input matrix, dim = [N, D] (row-major) + * @param D Number of columns of the input/output matrices + * @param N Number of rows of the input matrix + * @param map Map of row indices to gather, dim = [map_length] + * @param map_length The length of 'map', number of rows of the output matrix + * @param out Output matrix, dim = [map_length, D] (row-major) * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, cudaStream_t stream) { detail::gather(in, D, N, map, map_length, out, stream); } /** - * @brief gather copies rows from a source matrix into a destination matrix according to a - * transformed map. + * @brief Copies rows from a source matrix into a destination matrix according to a transformed map. + * + * For each output row, read the index in the input matrix from the map, apply a transformation to + * this input index and copy the row. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a - * simple pointer type). - * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple - * pointer type). - * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * @tparam InputIteratorT Input iterator type, for the input matrix (may be a pointer type). + * @tparam MapIteratorT Input iterator type, for the map (may be a pointer type). + * @tparam MapTransformOp Unary lambda expression or operator type. MapTransformOp's result type + * must be convertible to IndexT. + * @tparam OutputIteratorT Output iterator type, for the output matrix (may be a pointer type). + * @tparam IndexT Index type. 
* - * @param in Pointer to the input matrix (assumed to be row-major) - * @param D Leading dimension of the input matrix 'in', which in-case of row-major - * storage is the number of columns - * @param N Second dimension - * @param map Pointer to the input sequence of gather locations - * @param map_length The length of 'map' and 'stencil' - * @param out Pointer to the output matrix (assumed to be row-major) - * @param transform_op The transformation operation, transforms the map values to IndexT + * @param in Input matrix, dim = [N, D] (row-major) + * @param D Number of columns of the input/output matrices + * @param N Number of rows of the input matrix + * @param map Map of row indices to gather, dim = [map_length] + * @param map_length The length of 'map', number of rows of the output matrix + * @param out Output matrix, dim = [map_length, D] (row-major) + * @param transform_op Transformation to apply to map values * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, MapTransformOp transform_op, cudaStream_t stream) { @@ -91,40 +98,42 @@ void gather(const MatrixIteratorT in, } /** - * @brief gather_if conditionally copies rows from a source matrix into a destination matrix - * according to a map. + * @brief Conditionally copies rows from a source matrix into a destination matrix. + * + * For each output row, read the index in the input matrix from the map, read a stencil value, apply + * a predicate to the stencil value, and if true, copy the row. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a - * simple pointer type). - * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple - * pointer type). - * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a - * simple pointer type). - * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result - * type must be convertible to bool type. + * @tparam InputIteratorT Input iterator type, for the input matrix (may be a pointer type). + * @tparam MapIteratorT Input iterator type, for the map (may be a pointer type). + * @tparam StencilIteratorT Input iterator type, for the stencil (may be a pointer type). + * @tparam UnaryPredicateOp Unary lambda expression or operator type. UnaryPredicateOp's result type + * must be convertible to bool type. + * @tparam OutputIteratorT Output iterator type, for the output matrix (may be a pointer type). + * @tparam IndexT Index type. 
* - * @param in Pointer to the input matrix (assumed to be row-major) - * @param D Leading dimension of the input matrix 'in', which in-case of row-major - * storage is the number of columns - * @param N Second dimension - * @param map Pointer to the input sequence of gather locations - * @param stencil Pointer to the input sequence of stencil or predicate values - * @param map_length The length of 'map' and 'stencil' - * @param out Pointer to the output matrix (assumed to be row-major) + * @param in Input matrix, dim = [N, D] (row-major) + * @param D Number of columns of the input/output matrices + * @param N Number of rows of the input matrix + * @param map Map of row indices to gather, dim = [map_length] + * @param stencil Sequence of stencil values, dim = [map_length] + * @param map_length The length of 'map' and 'stencil', number of rows of the output matrix + * @param out Output matrix, dim = [map_length, D] (row-major) * @param pred_op Predicate to apply to the stencil values * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename UnaryPredicateOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, cudaStream_t stream) { @@ -132,44 +141,47 @@ void gather_if(const MatrixIteratorT in, } /** - * @brief gather_if conditionally copies rows from a source matrix into a destination matrix - * according to a transformed map. + * @brief Conditionally copies rows according to a transformed map. + * + * For each output row, read the index in the input matrix from the map, read a stencil value, + * apply a predicate to the stencil value, and if true, apply a transformation to the input index + * and copy the row. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a - * simple pointer type). - * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple - * pointer type). - * @tparam StencilIteratorT Random-access iterator type, for reading input stencil (may be a - * simple pointer type). - * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result - * type must be convertible to bool type. - * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * @tparam InputIteratorT Input iterator type, for the input matrix (may be a pointer type). + * @tparam MapIteratorT Input iterator type, for the map (may be a pointer type). + * @tparam MapTransformOp Unary lambda expression or operator type. MapTransformOp's result type + * must be convertible to IndexT. + * @tparam StencilIteratorT Input iterator type, for the stencil (may be a pointer type). + * @tparam UnaryPredicateOp Unary lambda expression or operator type. UnaryPredicateOp's result type + * must be convertible to bool type. + * @tparam OutputIteratorT Output iterator type, for the output matrix (may be a pointer type). + * @tparam IndexT Index type. 
* - * @param in Pointer to the input matrix (assumed to be row-major) - * @param D Leading dimension of the input matrix 'in', which in-case of row-major - * storage is the number of columns - * @param N Second dimension - * @param map Pointer to the input sequence of gather locations - * @param stencil Pointer to the input sequence of stencil or predicate values - * @param map_length The length of 'map' and 'stencil' - * @param out Pointer to the output matrix (assumed to be row-major) + * @param in Input matrix, dim = [N, D] (row-major) + * @param D Number of columns of the input/output matrices + * @param N Number of rows of the input matrix + * @param map Map of row indices to gather, dim = [map_length] + * @param stencil Sequence of stencil values, dim = [map_length] + * @param map_length The length of 'map' and 'stencil', number of rows of the output matrix + * @param out Output matrix, dim = [map_length, D] (row-major) * @param pred_op Predicate to apply to the stencil values - * @param transform_op The transformation operation, transforms the map values to IndexT + * @param transform_op Transformation to apply to map values * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename MapTransformOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, MapTransformOp transform_op, cudaStream_t stream) @@ -178,58 +190,31 @@ void gather_if(const MatrixIteratorT in, } /** - * @brief gather copies rows from a source matrix into a destination matrix according to a map. + * @brief Copies rows from a source matrix into a destination matrix according to a transformed map. * - * @tparam matrix_t Matrix element type - * @tparam map_t Map vector type - * @tparam idx_t integer type used for indexing - * @param[in] handle raft handle for managing resources - * @param[in] in Input matrix (assumed to be row-major) - * @param[in] map Vector of gather locations - * @param[out] out Output matrix (assumed to be row-major) - */ -template -void gather(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_vector_view map, - raft::device_matrix_view out) -{ - RAFT_EXPECTS(out.extent(0) == map.extent(0), - "Number of rows in output matrix must equal the size of the map vector"); - RAFT_EXPECTS(out.extent(1) == in.extent(1), - "Number of columns in input and output matrices must be equal."); - - raft::matrix::detail::gather( - const_cast(in.data_handle()), // TODO: There's a better way to handle this - static_cast(in.extent(1)), - static_cast(in.extent(0)), - map.data_handle(), - static_cast(map.extent(0)), - out.data_handle(), - handle.get_stream()); -} - -/** - * @brief gather copies rows from a source matrix into a destination matrix according to a - * transformed map. + * For each output row, read the index in the input matrix from the map, apply a transformation to + * this input index if specified, and copy the row. * - * @tparam matrix_t Matrix type - * @tparam map_t Map vector type - * @tparam map_xform_t Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to idx_t (= int) type. 
- * @tparam idx_t integer type for indexing - * @param[in] handle raft handle for managing resources - * @param[in] in Input matrix (assumed to be row-major) - * @param[in] map Input vector of gather locations - * @param[out] out Output matrix (assumed to be row-major) - * @param[in] transform_op The transformation operation, transforms the map values to idx_t + * @tparam matrix_t Matrix element type + * @tparam map_t Integer type of map elements + * @tparam idx_t Integer type used for indexing + * @tparam map_xform_t Unary lambda expression or operator type. MapTransformOp's result type must + * be convertible to idx_t. + * @param[in] handle raft handle for managing resources + * @param[in] in Input matrix, dim = [N, D] (row-major) + * @param[in] map Map of row indices to gather, dim = [map_length] + * @param[out] out Output matrix, dim = [map_length, D] (row-major) + * @param[in] transform_op (optional) Transformation to apply to map values */ -template +template void gather(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_vector_view map, - raft::device_matrix_view out, - map_xform_t transform_op) + raft::device_matrix_view out, + map_xform_t transform_op = raft::identity_op()) { RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); @@ -238,95 +223,51 @@ void gather(const raft::handle_t& handle, detail::gather( const_cast(in.data_handle()), // TODO: There's a better way to handle this - static_cast(in.extent(1)), - static_cast(in.extent(0)), - map, - static_cast(map.extent(0)), + in.extent(1), + in.extent(0), + map.data_handle(), + map.extent(0), out.data_handle(), transform_op, handle.get_stream()); } /** - * @brief gather_if conditionally copies rows from a source matrix into a destination matrix - * according to a map. + * @brief Conditionally copies rows according to a transformed map. + * + * For each output row, read the index in the input matrix from the map, read a stencil value, + * apply a predicate to the stencil value, and if true, apply a transformation if specified to the + * input index, and copy the row. * - * @tparam matrix_t Matrix value type - * @tparam map_t Map vector type - * @tparam stencil_t Stencil vector type - * @tparam unary_pred_t Unary lambda expression or operator type, unary_pred_t's result - * type must be convertible to bool type. - * @tparam idx_t integer type for indexing - * @param[in] handle raft handle for managing resources - * @param[in] in Input matrix (assumed to be row-major) - * @param[in] map Input vector of gather locations - * @param[in] stencil Input vector of stencil or predicate values - * @param[out] out Output matrix (assumed to be row-major) - * @param[in] pred_op Predicate to apply to the stencil values + * @tparam matrix_t Matrix element type + * @tparam map_t Integer type of map elements + * @tparam stencil_t Value type for stencil (input type for the pred_op) + * @tparam unary_pred_t Unary lambda expression or operator type. unary_pred_t's result + * type must be convertible to bool type. + * @tparam map_xform_t Unary lambda expression or operator type. MapTransformOp's result type must + * be convertible to idx_t. 
+ * @tparam idx_t Integer type used for indexing + * @param[in] handle raft handle for managing resources + * @param[in] in Input matrix, dim = [N, D] (row-major) + * @param[in] map Map of row indices to gather, dim = [map_length] + * @param[in] stencil Vector of stencil values, dim = [map_length] + * @param[out] out Output matrix, dim = [map_length, D] (row-major) + * @param[in] pred_op Predicate to apply to the stencil values + * @param[in] transform_op (optional) Transformation to apply to map values */ template + typename idx_t, + typename map_xform_t = raft::identity_op> void gather_if(const raft::handle_t& handle, raft::device_matrix_view in, raft::device_matrix_view out, raft::device_vector_view map, raft::device_vector_view stencil, - unary_pred_t pred_op) -{ - RAFT_EXPECTS(out.extent(0) == map.extent(0), - "Number of rows in output matrix must equal the size of the map vector"); - RAFT_EXPECTS(out.extent(1) == in.extent(1), - "Number of columns in input and output matrices must be equal."); - RAFT_EXPECTS(map.extent(0) == stencil.extent(0), - "Number of elements in stencil must equal number of elements in map"); - - detail::gather_if(const_cast(in.data_handle()), - out.extent(1), - out.extent(0), - map.data_handle(), - stencil.data_handle(), - map.extent(0), - out.data_handle(), - pred_op, - handle.get_stream()); -} - -/** - * @brief gather_if conditionally copies rows from a source matrix into a destination matrix - * according to a transformed map. - * - * @tparam matrix_t Matrix value type, for reading input matrix - * @tparam map_t Vector value type for map - * @tparam stencil_t Vector value type for stencil - * @tparam unary_pred_t Unary lambda expression or operator type, unary_pred_t's result - * type must be convertible to bool type. - * @tparam map_xform_t Unary lambda expression or operator type, map_xform_t's result - * type must be convertible to idx_t (= int) type. - * @tparam idx_t integer type for indexing - * @param[in] handle raft handle for managing resources - * @param[in] in Input matrix (assumed to be row-major) - * @param[in] map Vector of gather locations - * @param[in] stencil Vector of stencil or predicate values - * @param[out] out Output matrix (assumed to be row-major) - * @param[in] pred_op Predicate to apply to the stencil values - * @param[in] transform_op The transformation operation, transforms the map values to idx_t - */ -template -void gather_if(const raft::handle_t& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - raft::device_vector_view map, - raft::device_vector_view stencil, unary_pred_t pred_op, - map_xform_t transform_op) + map_xform_t transform_op = raft::identity_op()) { RAFT_EXPECTS(out.extent(0) == map.extent(0), "Number of rows in output matrix must equal the size of the map vector"); diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index 0bea62e9cf..3659265e84 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
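A minimal usage sketch of the consolidated mdspan-based `raft::matrix::gather` / `raft::matrix::gather_if` overloads above; buffer names, sizes and fill values are illustrative only (not taken from the patch), and the operators mirror the ones exercised in the updated test that follows:

```cpp
#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/core/operators.hpp>
#include <raft/matrix/gather.cuh>

#include <rmm/device_uvector.hpp>

void gather_usage_sketch(const raft::handle_t& handle)
{
  auto stream = handle.get_stream();

  // Illustrative sizes, not taken from the patch.
  int64_t n_rows = 1024, n_cols = 32, map_len = 128;

  rmm::device_uvector<float> in(n_rows * n_cols, stream);
  rmm::device_uvector<float> out(map_len * n_cols, stream);
  rmm::device_uvector<int64_t> map(map_len, stream);
  rmm::device_uvector<float> stencil(map_len, stream);
  // ... fill `in`, `map` and `stencil` ...

  auto in_view  = raft::make_device_matrix_view<const float, int64_t>(in.data(), n_rows, n_cols);
  auto out_view = raft::make_device_matrix_view<float, int64_t>(out.data(), map_len, n_cols);
  auto map_view = raft::make_device_vector_view<const int64_t, int64_t>(map.data(), map_len);
  auto stencil_view =
    raft::make_device_vector_view<const float, int64_t>(stencil.data(), map_len);

  // Plain gather: transform_op defaults to raft::identity_op.
  raft::matrix::gather(handle, in_view, map_view, out_view);

  // Gather through a transformed map (here: wrap the map values around n_rows).
  raft::matrix::gather(handle, in_view, map_view, out_view, raft::modulo_const_op(n_rows));

  // Conditional gather: row i is copied only when pred_op(stencil[i]) is true,
  // i.e. when the stencil value is greater than zero.
  raft::matrix::gather_if(handle,
                          in_view,
                          out_view,
                          map_view,
                          stencil_view,
                          raft::plug_const_op(0.0f, raft::greater_op()));
}
```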
@@ -18,50 +18,72 @@ #include #include #include +#include #include #include #include +#include #include namespace raft { -template -void naiveGatherImpl( - MatrixIteratorT in, int D, int N, MapIteratorT map, int map_length, MatrixIteratorT out) +template +void naiveGather(InputIteratorT in, + IdxT D, + IdxT N, + MapIteratorT map, + StencilIteratorT stencil, + IdxT map_length, + OutputIteratorT out, + UnaryPredicateOp pred_op, + MapTransformOp transform_op) { - for (int outRow = 0; outRow < map_length; ++outRow) { + for (IdxT outRow = 0; outRow < map_length; ++outRow) { + if constexpr (Conditional) { + auto stencil_val = stencil[outRow]; + if (!pred_op(stencil_val)) continue; + } typename std::iterator_traits::value_type map_val = map[outRow]; - int inRowStart = map_val * D; - int outRowStart = outRow * D; - for (int i = 0; i < D; ++i) { + IdxT transformed_val; + if constexpr (MapTransform) { + transformed_val = transform_op(map_val); + } else { + transformed_val = map_val; + } + IdxT inRowStart = transformed_val * D; + IdxT outRowStart = outRow * D; + for (IdxT i = 0; i < D; ++i) { out[outRowStart + i] = in[inRowStart + i]; } } } -template -void naiveGather( - MatrixIteratorT in, int D, int N, MapIteratorT map, int map_length, MatrixIteratorT out) -{ - naiveGatherImpl(in, D, N, map, map_length, out); -} - +template struct GatherInputs { - uint32_t nrows; - uint32_t ncols; - uint32_t map_length; + IdxT nrows; + IdxT ncols; + IdxT map_length; unsigned long long int seed; }; -template -class GatherTest : public ::testing::TestWithParam { +template +class GatherTest : public ::testing::TestWithParam> { protected: GatherTest() : stream(handle.get_stream()), - params(::testing::TestWithParam::GetParam()), + params(::testing::TestWithParam>::GetParam()), d_in(0, stream), d_out_exp(0, stream), d_out_act(0, stream), + d_stencil(0, stream), d_map(0, stream) { } @@ -71,44 +93,71 @@ class GatherTest : public ::testing::TestWithParam { raft::random::RngState r(params.seed); raft::random::RngState r_int(params.seed); - uint32_t nrows = params.nrows; - uint32_t ncols = params.ncols; - uint32_t map_length = params.map_length; - uint32_t len = nrows * ncols; + IdxT map_length = params.map_length; + IdxT len = params.nrows * params.ncols; // input matrix setup - d_in.resize(nrows * ncols, stream); - h_in.resize(nrows * ncols); + d_in.resize(params.nrows * params.ncols, stream); + h_in.resize(params.nrows * params.ncols); raft::random::uniform(handle, r, d_in.data(), len, MatrixT(-1.0), MatrixT(1.0)); raft::update_host(h_in.data(), d_in.data(), len, stream); // map setup d_map.resize(map_length, stream); h_map.resize(map_length); - raft::random::uniformInt(handle, r_int, d_map.data(), map_length, (MapT)0, nrows); + raft::random::uniformInt(handle, r_int, d_map.data(), map_length, (MapT)0, (MapT)params.nrows); raft::update_host(h_map.data(), d_map.data(), map_length, stream); - // expected and actual output matrix setup - h_out.resize(map_length * ncols); - d_out_exp.resize(map_length * ncols, stream); - d_out_act.resize(map_length * ncols, stream); + // stencil setup + if (Conditional) { + d_stencil.resize(map_length, stream); + h_stencil.resize(map_length); + raft::random::uniform(handle, r, d_stencil.data(), map_length, MatrixT(-1.0), MatrixT(1.0)); + raft::update_host(h_stencil.data(), d_stencil.data(), map_length, stream); + } - // launch gather on the host and copy the results to device - naiveGather(h_in.data(), ncols, nrows, h_map.data(), map_length, h_out.data()); - 
raft::update_device(d_out_exp.data(), h_out.data(), map_length * ncols, stream); + // unary predicate op (used only when Conditional is true) + auto pred_op = raft::plug_const_op(MatrixT(0.0), raft::greater_op()); - auto in_view = raft::make_device_matrix_view( - d_in.data(), nrows, ncols); - auto out_view = - raft::make_device_matrix_view(d_out_act.data(), map_length, ncols); - auto map_view = - raft::make_device_vector_view(d_map.data(), map_length); + // map transform op (used only when MapTransform is true) + auto transform_op = + raft::compose_op(raft::modulo_const_op(params.nrows), raft::add_const_op(10)); - raft::matrix::gather(handle, in_view, map_view, out_view); + // expected and actual output matrix setup + h_out.resize(map_length * params.ncols); + d_out_exp.resize(map_length * params.ncols, stream); + d_out_act.resize(map_length * params.ncols, stream); - // // launch device version of the kernel - // gatherLaunch( - // handle, d_in.data(), ncols, nrows, d_map.data(), map_length, d_out_act.data(), stream); + // launch gather on the host and copy the results to device + naiveGather(h_in.data(), + params.ncols, + params.nrows, + h_map.data(), + h_stencil.data(), + map_length, + h_out.data(), + pred_op, + transform_op); + raft::update_device(d_out_exp.data(), h_out.data(), map_length * params.ncols, stream); + + auto in_view = raft::make_device_matrix_view( + d_in.data(), params.nrows, params.ncols); + auto out_view = raft::make_device_matrix_view( + d_out_act.data(), map_length, params.ncols); + auto map_view = raft::make_device_vector_view(d_map.data(), map_length); + auto stencil_view = + raft::make_device_vector_view(d_stencil.data(), map_length); + + if (Conditional && MapTransform) { + raft::matrix::gather_if( + handle, in_view, out_view, map_view, stencil_view, pred_op, transform_op); + } else if (Conditional) { + raft::matrix::gather_if(handle, in_view, out_view, map_view, stencil_view, pred_op); + } else if (MapTransform) { + raft::matrix::gather(handle, in_view, map_view, out_view, transform_op); + } else { + raft::matrix::gather(handle, in_view, map_view, out_view); + } handle.sync_stream(stream); } @@ -116,41 +165,46 @@ class GatherTest : public ::testing::TestWithParam { protected: raft::handle_t handle; cudaStream_t stream = 0; - GatherInputs params; - std::vector h_in, h_out; + GatherInputs params; + std::vector h_in, h_out, h_stencil; std::vector h_map; - rmm::device_uvector d_in, d_out_exp, d_out_act; + rmm::device_uvector d_in, d_out_exp, d_out_act, d_stencil; rmm::device_uvector d_map; }; -const std::vector inputs = {{1024, 32, 128, 1234ULL}, - {1024, 32, 256, 1234ULL}, - {1024, 32, 512, 1234ULL}, - {1024, 32, 1024, 1234ULL}, - {1024, 64, 128, 1234ULL}, - {1024, 64, 256, 1234ULL}, - {1024, 64, 512, 1234ULL}, - {1024, 64, 1024, 1234ULL}, - {1024, 128, 128, 1234ULL}, - {1024, 128, 256, 1234ULL}, - {1024, 128, 512, 1234ULL}, - {1024, 128, 1024, 1234ULL}}; - -typedef GatherTest GatherTestF; -TEST_P(GatherTestF, Result) -{ - ASSERT_TRUE(devArrMatch( - d_out_exp.data(), d_out_act.data(), params.map_length * params.ncols, raft::Compare())); -} - -typedef GatherTest GatherTestD; -TEST_P(GatherTestD, Result) -{ - ASSERT_TRUE(devArrMatch( - d_out_exp.data(), d_out_act.data(), params.map_length * params.ncols, raft::Compare())); -} - -INSTANTIATE_TEST_CASE_P(GatherTests, GatherTestF, ::testing::ValuesIn(inputs)); -INSTANTIATE_TEST_CASE_P(GatherTests, GatherTestD, ::testing::ValuesIn(inputs)); +#define GATHER_TEST(test_type, test_name, test_inputs) \ + typedef 
RAFT_DEPAREN(test_type) test_name; \ + TEST_P(test_name, Result) \ + { \ + ASSERT_TRUE(devArrMatch(d_out_exp.data(), \ + d_out_act.data(), \ + params.map_length* params.ncols, \ + raft::Compare())); \ + } \ + INSTANTIATE_TEST_CASE_P(GatherTests, test_name, ::testing::ValuesIn(test_inputs)) + +const std::vector> inputs_i32 = + raft::util::itertools::product>({25, 2000}, {6, 31, 129}, {11, 999}, {1234ULL}); +const std::vector> inputs_i64 = + raft::util::itertools::product>( + {25, 2000}, {6, 31, 129}, {11, 999}, {1234ULL}); + +GATHER_TEST((GatherTest), GatherTestFU32I32, inputs_i32); +GATHER_TEST((GatherTest), + GatherTransformTestFU32I32, + inputs_i32); +GATHER_TEST((GatherTest), GatherIfTestFU32I32, inputs_i32); +GATHER_TEST((GatherTest), + GatherIfTransformTestFU32I32, + inputs_i32); +GATHER_TEST((GatherTest), + GatherIfTransformTestDU32I32, + inputs_i32); +GATHER_TEST((GatherTest), + GatherIfTransformTestFU32I64, + inputs_i64); +GATHER_TEST((GatherTest), + GatherIfTransformTestFI64I64, + inputs_i64); } // end namespace raft \ No newline at end of file From 5a6cb097fcdb9e781a21d3adddcf6d4443ce6650 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Mon, 23 Jan 2023 16:56:50 +0100 Subject: [PATCH 7/9] ANN tests: make the min_recall check strict (#1156) In #1135, we adjusted the min_recall values to report if any regressions happen in ivf-pq. However, `eval_neighbours` function, which is used in several ANN test suites, doesn't fail unless the regression is really large (it prints a warning if the calculated recall is "slightly" smaller than the expected recall). In this PR, I make `eval_neighbours` always fail if the calculated recall is smaller than the expected recall. Slightly adjust the tests and do a small refactoring along the way. Authors: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. 
Nolet (https://github.com/cjnolet) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/raft/pull/1156 --- cpp/test/neighbors/ann_ivf_pq.cuh | 13 +++++--- cpp/test/neighbors/ann_utils.cuh | 52 +++++++++++++++++-------------- 2 files changed, 36 insertions(+), 29 deletions(-) diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index b5671b74b0..719f429f13 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -139,8 +139,8 @@ class ivf_pq_test : public ::testing::TestWithParam { protected: void gen_data() { - database.resize(ps.num_db_vecs * ps.dim, stream_); - search_queries.resize(ps.num_queries * ps.dim, stream_); + database.resize(size_t{ps.num_db_vecs} * size_t{ps.dim}, stream_); + search_queries.resize(size_t{ps.num_queries} * size_t{ps.dim}, stream_); raft::random::Rng r(1234ULL); if constexpr (std::is_same{}) { @@ -155,7 +155,7 @@ class ivf_pq_test : public ::testing::TestWithParam { void calc_ref() { - size_t queries_size = ps.num_queries * ps.k; + size_t queries_size = size_t{ps.num_queries} * size_t{ps.k}; rmm::device_uvector distances_naive_dev(queries_size, stream_); rmm::device_uvector indices_naive_dev(queries_size, stream_); naiveBfKnn(distances_naive_dev.data(), @@ -463,7 +463,7 @@ inline auto enum_variety() -> test_cases_t }); ADD_CASE({ x.search_params.lut_dtype = CUDA_R_8U; - x.min_recall = 0.85; + x.min_recall = 0.84; }); ADD_CASE({ @@ -496,7 +496,10 @@ inline auto enum_variety_ip() -> test_cases_t // InnerProduct score is signed, // thus we're forced to used signed 8-bit representation, // thus we have one bit less precision - y.min_recall = y.min_recall.value() * 0.95; + y.min_recall = y.min_recall.value() * 0.90; + } else { + // In other cases it seems to perform a little bit better, still worse than L2 + y.min_recall = y.min_recall.value() * 0.94; } } y.index_params.metric = distance::DistanceType::InnerProduct; diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index bb2f334db4..551ebd767f 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -110,28 +110,39 @@ __global__ void naive_distance_kernel(EvalT* dist, IdxT m, IdxT n, IdxT k, - raft::distance::DistanceType type) + raft::distance::DistanceType metric) { - IdxT midx = threadIdx.x + blockIdx.x * blockDim.x; + IdxT midx = IdxT(threadIdx.x) + IdxT(blockIdx.x) * IdxT(blockDim.x); if (midx >= m) return; - for (IdxT nidx = threadIdx.y + blockIdx.y * blockDim.y; nidx < n; - nidx += blockDim.y * gridDim.y) { + IdxT grid_size = IdxT(blockDim.y) * IdxT(gridDim.y); + for (IdxT nidx = threadIdx.y + blockIdx.y * blockDim.y; nidx < n; nidx += grid_size) { EvalT acc = EvalT(0); for (IdxT i = 0; i < k; ++i) { IdxT xidx = i + midx * k; IdxT yidx = i + nidx * k; - EvalT xv = (EvalT)x[xidx]; - EvalT yv = (EvalT)y[yidx]; - if (type == raft::distance::DistanceType::InnerProduct) { - acc += xv * yv; - } else { - EvalT diff = xv - yv; - acc += diff * diff; + auto xv = EvalT(x[xidx]); + auto yv = EvalT(y[yidx]); + switch (metric) { + case raft::distance::DistanceType::InnerProduct: { + acc += xv * yv; + } break; + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2Unexpanded: { + auto diff = xv - yv; + acc += diff * diff; + } break; + default: break; } } - if (type == 
raft::distance::DistanceType::L2SqrtExpanded || - type == raft::distance::DistanceType::L2SqrtUnexpanded) - acc = raft::sqrt(acc); + switch (metric) { + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: { + acc = raft::sqrt(acc); + } break; + default: break; + } dist[midx * n + nidx] = acc; } } @@ -241,16 +252,9 @@ auto eval_neighbours(const std::vector& expected_idx, error_margin < 0 ? "above" : "below", eps); if (actual_recall < min_recall - eps) { - if (actual_recall < min_recall * min_recall - eps) { - RAFT_LOG_ERROR("Recall is much lower than the minimum (%f < %f)", actual_recall, min_recall); - } else { - RAFT_LOG_WARN("Recall is suspiciously too low (%f < %f)", actual_recall, min_recall); - } - if (match_count == 0 || actual_recall < min_recall * std::min(min_recall, 0.5) - eps) { - return testing::AssertionFailure() - << "actual recall (" << actual_recall - << ") is much smaller than the minimum expected recall (" << min_recall << ")."; - } + return testing::AssertionFailure() + << "actual recall (" << actual_recall << ") is lower than the minimum expected recall (" + << min_recall << "); eps = " << eps << ". "; } return testing::AssertionSuccess(); } From 0076101e69c03bab03c9cb022d2e4c519bce60af Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Mon, 23 Jan 2023 21:17:02 +0100 Subject: [PATCH 8/9] matrix::select_k: move selection and warp-sort primitives (#1085) Refactor and move a set of implementations for batch-selecting top K largest/smallest values: - Move device warp-wide primitives `bitonic_sort.cuh` to the public `raft::util` namespace, add tests. - Create a new public `matrix::select_k` interface. - Deprecate the legacy public `raft::spatial::knn::select_k` interface. - Copy/adapt `select_k` tests. - Move/adapt `select_k` benchmarks. - Rework the internals of `select_warpsort.cuh` to enable more implementations. Closes https://github.com/rapidsai/raft/issues/853 Authors: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. 
Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1085 --- cpp/bench/CMakeLists.txt | 6 +- cpp/bench/matrix/select_k.cu | 133 +++++ cpp/bench/neighbors/selection.cu | 123 ----- .../topk.cuh => matrix/detail/select_k.cuh} | 58 +-- .../detail/select_radix.cuh} | 113 +++-- .../detail/select_warpsort.cuh} | 415 +++++++++++----- cpp/include/raft/matrix/select_k.cuh | 110 +++++ cpp/include/raft/neighbors/detail/refine.cuh | 4 +- .../spatial/knn/detail/ivf_flat_search.cuh | 75 +-- .../raft/spatial/knn/detail/ivf_pq_search.cuh | 79 ++- cpp/include/raft/spatial/knn/knn.cuh | 38 +- .../knn/detail/topk => util}/bitonic_sort.cuh | 83 ++-- cpp/include/raft/util/integer_utils.hpp | 34 +- cpp/test/CMakeLists.txt | 5 +- cpp/test/matrix/select_k.cu | 459 ++++++++++++++++++ cpp/test/matrix/select_k.cuh | 127 +++++ cpp/test/neighbors/ann_ivf_flat.cu | 8 +- cpp/test/neighbors/ann_utils.cuh | 23 +- cpp/test/neighbors/selection.cu | 2 +- cpp/test/util/bitonic_sort.cu | 200 ++++++++ docs/source/cpp_api/matrix_ordering.rst | 12 + 21 files changed, 1631 insertions(+), 476 deletions(-) create mode 100644 cpp/bench/matrix/select_k.cu delete mode 100644 cpp/bench/neighbors/selection.cu rename cpp/include/raft/{spatial/knn/detail/topk.cuh => matrix/detail/select_k.cuh} (59%) rename cpp/include/raft/{spatial/knn/detail/topk/radix_topk.cuh => matrix/detail/select_radix.cuh} (87%) rename cpp/include/raft/{spatial/knn/detail/topk/warpsort_topk.cuh => matrix/detail/select_warpsort.cuh} (71%) create mode 100644 cpp/include/raft/matrix/select_k.cuh rename cpp/include/raft/{spatial/knn/detail/topk => util}/bitonic_sort.cuh (68%) create mode 100644 cpp/test/matrix/select_k.cu create mode 100644 cpp/test/matrix/select_k.cuh create mode 100644 cpp/test/util/bitonic_sort.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 8dcdb325e9..6b985acfc3 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -103,7 +103,10 @@ if(BUILD_BENCH) bench/main.cpp ) - ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/main.cpp) + ConfigureBench( + NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/matrix/select_k.cu + bench/main.cpp + ) ConfigureBench( NAME RANDOM_BENCH PATH bench/random/make_blobs.cu bench/random/permute.cu bench/random/rng.cu @@ -127,7 +130,6 @@ if(BUILD_BENCH) bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu bench/neighbors/knn/ivf_pq_uint8_t_uint32_t.cu bench/neighbors/refine.cu - bench/neighbors/selection.cu bench/main.cpp OPTIONAL DIST diff --git a/cpp/bench/matrix/select_k.cu b/cpp/bench/matrix/select_k.cu new file mode 100644 index 0000000000..452a50ba50 --- /dev/null +++ b/cpp/bench/matrix/select_k.cu @@ -0,0 +1,133 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +/** + * TODO: reconsider how to organize shared test+bench files better + * Related Issue: https://github.com/rapidsai/raft/issues/1153 + * (although this header does not depend on any gtest headers) + */ +#include "../../test/matrix/select_k.cuh" + +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace raft::matrix { + +using namespace raft::bench; // NOLINT + +template +struct selection : public fixture { + explicit selection(const select::params& p) + : params_(p), + in_dists_(p.batch_size * p.len, stream), + in_ids_(p.batch_size * p.len, stream), + out_dists_(p.batch_size * p.k, stream), + out_ids_(p.batch_size * p.k, stream) + { + raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream); + raft::random::RngState state{42}; + raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); + } + + void run_benchmark(::benchmark::State& state) override // NOLINT + { + handle_t handle{stream}; + using_pool_memory_res res; + try { + std::ostringstream label_stream; + label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k; + state.SetLabel(label_stream.str()); + loop_on_state(state, [this, &handle]() { + select::select_k_impl(handle, + Algo, + in_dists_.data(), + in_ids_.data(), + params_.batch_size, + params_.len, + params_.k, + out_dists_.data(), + out_ids_.data(), + params_.select_min); + }); + } catch (raft::exception& e) { + state.SkipWithError(e.what()); + } + } + + private: + const select::params params_; + rmm::device_uvector in_dists_, out_dists_; + rmm::device_uvector in_ids_, out_ids_; +}; + +const std::vector kInputs{ + {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, + {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, + {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, + + {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, + {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, + {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, + + {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, + {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, + {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, + + {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, + {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, + {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true}, +}; + +#define SELECTION_REGISTER(KeyT, IdxT, A) \ + namespace BENCHMARK_PRIVATE_NAME(selection) \ + { \ + using SelectK = selection; \ + RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ + } + +SELECTION_REGISTER(float, int, kPublicApi); // NOLINT +SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT +SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT +SELECTION_REGISTER(float, int, kWarpAuto); // NOLINT +SELECTION_REGISTER(float, int, kWarpImmediate); // NOLINT +SELECTION_REGISTER(float, int, kWarpFiltered); // NOLINT +SELECTION_REGISTER(float, int, kWarpDistributed); // NOLINT +SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT + +SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT + +SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT 
+SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT + +} // namespace raft::matrix diff --git a/cpp/bench/neighbors/selection.cu b/cpp/bench/neighbors/selection.cu deleted file mode 100644 index 1f116c199f..0000000000 --- a/cpp/bench/neighbors/selection.cu +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#if defined RAFT_NN_COMPILED -#include -#endif - -#include -#include - -#include -#include - -namespace raft::bench::spatial { - -struct params { - int n_inputs; - int input_len; - int k; - int select_min; -}; - -template -struct selection : public fixture { - explicit selection(const params& p) - : params_(p), - in_dists_(p.n_inputs * p.input_len, stream), - in_ids_(p.n_inputs * p.input_len, stream), - out_dists_(p.n_inputs * p.k, stream), - out_ids_(p.n_inputs * p.k, stream) - { - raft::sparse::iota_fill(in_ids_.data(), IdxT(p.n_inputs), IdxT(p.input_len), stream); - raft::random::RngState state{42}; - raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); - } - - void run_benchmark(::benchmark::State& state) override - { - using_pool_memory_res res; - try { - std::ostringstream label_stream; - label_stream << params_.n_inputs << "#" << params_.input_len << "#" << params_.k; - state.SetLabel(label_stream.str()); - loop_on_state(state, [this]() { - raft::spatial::knn::select_k(in_dists_.data(), - in_ids_.data(), - params_.n_inputs, - params_.input_len, - out_dists_.data(), - out_ids_.data(), - params_.select_min, - params_.k, - stream, - Algo); - }); - } catch (raft::exception& e) { - state.SkipWithError(e.what()); - } - } - - private: - const params params_; - rmm::device_uvector in_dists_, out_dists_; - rmm::device_uvector in_ids_, out_ids_; -}; - -const std::vector kInputs{ - {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, - {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, - {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, - - {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, - {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, - {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, - - {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, - {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, - {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, - - {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, - {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, - {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, 
true}, -}; - -#define SELECTION_REGISTER(KeyT, IdxT, Algo) \ - namespace BENCHMARK_PRIVATE_NAME(selection) \ - { \ - using SelectK = selection; \ - RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \ - } - -SELECTION_REGISTER(float, int, FAISS); -SELECTION_REGISTER(float, int, RADIX_8_BITS); -SELECTION_REGISTER(float, int, RADIX_11_BITS); -SELECTION_REGISTER(float, int, WARP_SORT); - -SELECTION_REGISTER(double, int, FAISS); -SELECTION_REGISTER(double, int, RADIX_8_BITS); -SELECTION_REGISTER(double, int, RADIX_11_BITS); -SELECTION_REGISTER(double, int, WARP_SORT); - -SELECTION_REGISTER(double, size_t, FAISS); -SELECTION_REGISTER(double, size_t, RADIX_8_BITS); -SELECTION_REGISTER(double, size_t, RADIX_11_BITS); -SELECTION_REGISTER(double, size_t, WARP_SORT); - -} // namespace raft::bench::spatial diff --git a/cpp/include/raft/spatial/knn/detail/topk.cuh b/cpp/include/raft/matrix/detail/select_k.cuh similarity index 59% rename from cpp/include/raft/spatial/knn/detail/topk.cuh rename to cpp/include/raft/matrix/detail/select_k.cuh index f4dcb53088..ac1ba3dfa3 100644 --- a/cpp/include/raft/spatial/knn/detail/topk.cuh +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,34 +16,34 @@ #pragma once -#include "topk/radix_topk.cuh" -#include "topk/warpsort_topk.cuh" +#include "select_radix.cuh" +#include "select_warpsort.cuh" #include #include #include -namespace raft::spatial::knn::detail { +namespace raft::matrix::detail { /** * Select k smallest or largest key/values from each row in the input data. * - * If you think of the input data `in_keys` as a row-major matrix with len columns and - * batch_size rows, then this function selects k smallest/largest values in each row and fills - * in the row-major matrix `out` of size (batch_size, k). + * If you think of the input data `in_val` as a row-major matrix with `len` columns and + * `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills + * in the row-major matrix `out_val` of size (batch_size, k). * * @tparam T * the type of the keys (what is being compared). * @tparam IdxT * the index type (what is being selected together with the keys). * - * @param[in] in + * @param[in] in_val * contiguous device array of inputs of size (len * batch_size); * these are compared and selected. * @param[in] in_idx * contiguous device array of inputs of size (len * batch_size); - * typically, these are indices of the corresponding in_keys. + * typically, these are indices of the corresponding in_val. * @param batch_size * number of input rows, i.e. the batch size. * @param len @@ -51,12 +51,12 @@ namespace raft::spatial::knn::detail { * Invariant: len >= k. * @param k * the number of outputs to select in each input row. - * @param[out] out + * @param[out] out_val * contiguous device array of outputs of size (k * batch_size); - * the k smallest/largest values from each row of the `in_keys`. + * the k smallest/largest values from each row of the `in_val`. * @param[out] out_idx * contiguous device array of outputs of size (k * batch_size); - * the payload selected together with `out`. + * the payload selected together with `out_val`. * @param select_min * whether to select k smallest (true) or largest (false) keys. 
* @param stream @@ -64,28 +64,28 @@ namespace raft::spatial::knn::detail { * memory pool here to avoid memory allocations within the call). */ template -void select_topk(const T* in, - const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, - T* out, - IdxT* out_idx, - bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) +void select_k(const T* in_val, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out_val, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { common::nvtx::range fun_scope( - "matrix::select_topk(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); + "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); // TODO (achirkin): investigate the trade-off for a wider variety of inputs. const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128; - if (k <= raft::spatial::knn::detail::topk::kMaxCapacity && !radix_faster) { - topk::warp_sort_topk( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); + if (k <= select::warpsort::kMaxCapacity && !radix_faster) { + select::warpsort::select_k( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); } else { - topk::radix_topk= 4 ? 11 : 8), 512>( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); + select::radix::select_k= 4 ? 11 : 8), 512>( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); } } -} // namespace raft::spatial::knn::detail +} // namespace raft::matrix::detail diff --git a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh similarity index 87% rename from cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh rename to cpp/include/raft/matrix/detail/select_radix.cuh index 9c0f20b706..de19e63a4c 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -27,29 +28,29 @@ #include #include -#include +#include #include -namespace raft::spatial::knn::detail::topk { +namespace raft::matrix::detail::select::radix { constexpr int ITEM_PER_THREAD = 32; constexpr int VECTORIZED_READ_SIZE = 16; template -__host__ __device__ constexpr int calc_num_buckets() +_RAFT_HOST_DEVICE constexpr int calc_num_buckets() { return 1 << BitsPerPass; } template -__host__ __device__ constexpr int calc_num_passes() +_RAFT_HOST_DEVICE constexpr int calc_num_passes() { return ceildiv(sizeof(T) * 8, BitsPerPass); } // Minimum reasonable block size for the given radix size. template -__host__ __device__ constexpr int calc_min_block_size() +_RAFT_HOST_DEVICE constexpr int calc_min_block_size() { return 1 << std::max(BitsPerPass - 4, Pow2::Log2 + 1); } @@ -62,7 +63,7 @@ __host__ __device__ constexpr int calc_min_block_size() * NB: Use pass=-1 for calc_mask(). 
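 *
 * Worked example (editor's illustration, not in the original patch): for T = float (32 bits)
 * and BitsPerPass = 11, calc_num_passes() returns 3 and the passes cover the bit ranges
 * [21, 32), [10, 21) and [0, 10); the most significant bits are processed first, and the
 * start bit of the last pass is clamped to 0.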
*/ template -__device__ constexpr int calc_start_bit(int pass) +_RAFT_DEVICE constexpr int calc_start_bit(int pass) { int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass; if (start_bit < 0) { start_bit = 0; } @@ -70,7 +71,7 @@ __device__ constexpr int calc_start_bit(int pass) } template -__device__ constexpr unsigned calc_mask(int pass) +_RAFT_DEVICE constexpr unsigned calc_mask(int pass) { static_assert(BitsPerPass <= 31); int num_bits = calc_start_bit(pass - 1) - calc_start_bit(pass); @@ -82,7 +83,7 @@ __device__ constexpr unsigned calc_mask(int pass) * as of integers. */ template -__device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) +_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) { auto bits = reinterpret_cast::UnsignedBits&>(key); bits = cub::Traits::TwiddleIn(bits); @@ -91,7 +92,7 @@ __device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) } template -__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater) +_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater) { static_assert(BitsPerPass <= sizeof(int) * 8 - 1); // so return type can be int return (twiddle_in(x, greater) >> start_bit) & mask; @@ -112,7 +113,7 @@ __device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater) * @param f the lambda taking two arguments (T x, IdxT idx) */ template -__device__ void vectorized_process(const T* in, IdxT len, Func f) +_RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) { const IdxT stride = blockDim.x * gridDim.x; const int tid = blockIdx.x * blockDim.x + threadIdx.x; @@ -167,18 +168,18 @@ struct Counter { * (see steps 4-1 in `radix_kernel` description). */ template -__device__ void filter_and_histogram(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - IdxT len, - Counter* counter, - IdxT* histogram, - bool greater, - int pass, - int k) +_RAFT_DEVICE void filter_and_histogram(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + IdxT len, + Counter* counter, + IdxT* histogram, + bool greater, + int pass, + int k) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; @@ -260,10 +261,10 @@ __device__ void filter_and_histogram(const T* in_buf, * (step 2 in `radix_kernel` description) */ template -__device__ void scan(volatile IdxT* histogram, - const int start, - const int num_buckets, - const IdxT current) +_RAFT_DEVICE void scan(volatile IdxT* histogram, + const int start, + const int num_buckets, + const IdxT current) { typedef cub::BlockScan BlockScan; __shared__ typename BlockScan::TempStorage temp_storage; @@ -284,7 +285,7 @@ __device__ void scan(volatile IdxT* histogram, * (steps 2-3 in `radix_kernel` description) */ template -__device__ void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k) +_RAFT_DEVICE void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k) { constexpr int num_buckets = calc_num_buckets(); int index = threadIdx.x; @@ -547,21 +548,21 @@ inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) * memory pool here to avoid memory allocations within the call). 
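 *
 * Minimal usage sketch (editor's illustration, not part of the original patch; the template
 * arguments below are the ones matrix::detail::select_k passes for float keys and int indices):
 * @code{.cpp}
 *   // batch of 100 rows with 100000 candidates each; keep the 64 smallest values per row
 *   bool select_min = true;
 *   raft::matrix::detail::select::radix::select_k<float, int, 11, 512>(
 *     in_dists, in_ids, 100, 100000, 64, out_dists, out_ids, select_min, stream);
 * @endcode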
*/ template -void radix_topk(const T* in, - const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, - T* out, - IdxT* out_idx, - bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) +void select_k(const T* in, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { // reduce the block size if the input length is too small. if constexpr (BlockSize > calc_min_block_size()) { if (BlockSize * ITEM_PER_THREAD > len) { - return radix_topk( + return select_k( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); } } @@ -573,23 +574,33 @@ void radix_topk(const T* in, dim3 blocks = get_optimal_grid_size(batch_size, len); size_t max_chunk_size = blocks.y; - auto pool_guard = raft::get_pool_memory_resource( - mr, - max_chunk_size * (sizeof(Counter) // counters - + sizeof(IdxT) * (num_buckets + 2) // histograms and IdxT bufs - + sizeof(T) * 2 // T bufs - )); + size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); + size_t req_buf = max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)); + size_t mem_req = req_aux + req_buf; + size_t mem_free, mem_total; + RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total)); + std::optional managed_memory; + rmm::mr::device_memory_resource* mr_buf = nullptr; + if (mem_req > mem_free) { + // if there's not enough memory for buffers on the device, resort to the managed memory. + mem_req = req_aux; + managed_memory.emplace(); + mr_buf = &managed_memory.value(); + } + + auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); if (pool_guard) { - RAFT_LOG_DEBUG("radix_topk: using pool memory resource with initial size %zu bytes", + RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); } + if (mr_buf == nullptr) { mr_buf = mr; } rmm::device_uvector> counters(max_chunk_size, stream, mr); - rmm::device_uvector histograms(num_buckets * max_chunk_size, stream, mr); - rmm::device_uvector buf1(len * max_chunk_size, stream, mr); - rmm::device_uvector idx_buf1(len * max_chunk_size, stream, mr); - rmm::device_uvector buf2(len * max_chunk_size, stream, mr); - rmm::device_uvector idx_buf2(len * max_chunk_size, stream, mr); + rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); + rmm::device_uvector buf1(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector idx_buf1(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector buf2(max_chunk_size * len, stream, mr_buf); + rmm::device_uvector idx_buf2(max_chunk_size * len, stream, mr_buf); for (size_t offset = 0; offset < batch_size; offset += max_chunk_size) { blocks.y = std::min(max_chunk_size, batch_size - offset); @@ -646,4 +657,4 @@ void radix_topk(const T* in, } } -} // namespace raft::spatial::knn::detail::topk +} // namespace raft::matrix::detail::select::radix diff --git a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh similarity index 71% rename from cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh rename to cpp/include/raft/matrix/detail/select_warpsort.cuh index c06aa04aea..d362b73792 100644 --- a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -16,10 +16,11 @@ #pragma once -#include "bitonic_sort.cuh" - +#include #include +#include #include +#include #include 
#include @@ -31,12 +32,12 @@ /* Three APIs of different scopes are provided: - 1. host function: warp_sort_topk() + 1. host function: select_k() 2. block-wide API: class block_sort 3. warp-wide API: several implementations of warp_sort_* - 1. warp_sort_topk() + 1. select_k() (see the docstring) 2. class block_sort @@ -74,7 +75,7 @@ These two classes can be regarded as fixed size priority queue for a warp. Usage is similar to class block_sort. No shared memory is needed. - The host function (warp_sort_topk) uses a heuristic to choose between these two classes for + The host function (select_k) uses a heuristic to choose between these two classes for sorting, warp_sort_immediate being chosen when the number of inputs per warp is somewhat small (see the usage of LaunchThreshold::len_factor_for_choosing). @@ -94,7 +95,7 @@ } */ -namespace raft::spatial::knn::detail::topk { +namespace raft::matrix::detail::select::warpsort { static constexpr int kMaxCapacity = 256; @@ -102,18 +103,12 @@ namespace { /** Whether 'left` should indeed be on the left w.r.t. `right`. */ template -__device__ __forceinline__ auto is_ordered(T left, T right) -> bool +_RAFT_DEVICE _RAFT_FORCEINLINE auto is_ordered(T left, T right) -> bool { if constexpr (Ascending) { return left < right; } if constexpr (!Ascending) { return left > right; } } -constexpr auto calc_capacity(int k) -> int -{ - int capacity = isPo2(k) ? k : (1 << (log2(k) + 1)); - return capacity; -} - } // namespace /** @@ -134,7 +129,7 @@ constexpr auto calc_capacity(int k) -> int */ template class warp_sort { - static_assert(isPo2(Capacity)); + static_assert(is_a_power_of_two(Capacity)); static_assert(std::is_default_constructible_v); public: @@ -148,13 +143,16 @@ class warp_sort { /** The number of elements to select. */ const int k; + /** Extra memory required per-block for keeping the state (shared or global). */ + constexpr static auto mem_required(uint32_t block_size) -> size_t { return 0; } + /** * Construct the warp_sort empty queue. * * @param k * number of elements to select. */ - __device__ warp_sort(int k) : k(k) + _RAFT_DEVICE warp_sort(int k) : k(k) { #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { @@ -182,7 +180,7 @@ class warp_sort { * It serves as a conditional; when `false` the function does nothing. * We need it to ensure threads within a full warp don't diverge calling `bitonic::merge()`. */ - __device__ void load_sorted(const T* in, const IdxT* in_idx, bool do_merge = true) + _RAFT_DEVICE void load_sorted(const T* in, const IdxT* in_idx, bool do_merge = true) { if (do_merge) { int idx = Pow2::mod(laneId()) ^ Pow2::Mask; @@ -198,7 +196,7 @@ class warp_sort { } } if (kWarpWidth < WarpSize || do_merge) { - topk::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); + util::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); } } @@ -211,15 +209,23 @@ class warp_sort { * @param[out] out_idx * device pointer to a contiguous array, unique per-subwarp of size `kWarpWidth` * (length: k <= kWarpWidth * kMaxArrLen). 
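For orientation, here is a caller sketch that follows the warp-wide API documented above: construct a queue with k, feed add(), call done(), then store(). It assumes the header path introduced by this rename and a <Capacity, Ascending, T, IdxT> template parameter list; the kernel name, row layout, and the tail padding with kDummy are illustrative assumptions, not part of this patch.

#include <raft/matrix/detail/select_warpsort.cuh>

constexpr int kWarpSize = 32;

// One input row per warp; selects the k smallest values of each row (Ascending = true).
// Capacity is assumed to be a power of two with k <= Capacity <= kMaxCapacity.
template <int Capacity, typename T, typename IdxT>
__global__ void per_warp_select_k(
  const T* in, const IdxT* in_idx, IdxT len, int k, T* out, IdxT* out_idx)
{
  namespace ws  = raft::matrix::detail::select::warpsort;
  using queue_t = ws::warp_sort_immediate<Capacity, true, T, IdxT>;
  queue_t queue(k);

  const int warp_id = threadIdx.x / kWarpSize;
  const int lane_id = threadIdx.x % kWarpSize;
  const IdxT row    = IdxT(blockIdx.x) * (blockDim.x / kWarpSize) + warp_id;

  // Pad the tail so every lane issues the same number of add() calls.
  const IdxT len_padded = ((len + kWarpSize - 1) / kWarpSize) * kWarpSize;
  for (IdxT i = lane_id; i < len_padded; i += kWarpSize) {
    T    val = (i < len) ? in[row * len + i] : queue_t::kDummy;
    IdxT idx = (i < len) ? in_idx[row * len + i] : IdxT(i);
    queue.add(val, idx);
  }
  queue.done();
  // Each warp writes its own k results; no shared memory is needed for this variant.
  queue.store(out + row * k, out_idx + row * k);
}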
+ * @param valF (optional) postprocess values (T -> OutT) + * @param idxF (optional) postprocess indices (IdxT -> OutIdxT) */ - template - __device__ void store(T* out, IdxT* out_idx, Lambda post_process = raft::identity_op()) const + template + _RAFT_DEVICE void store(OutT* out, + OutIdxT* out_idx, + ValF valF = raft::identity_op{}, + IdxF idxF = raft::identity_op{}) const { int idx = Pow2::mod(laneId()); #pragma unroll kMaxArrLen for (int i = 0; i < kMaxArrLen && idx < k; i++, idx += kWarpWidth) { - out[idx] = post_process(val_arr_[i]); - out_idx[idx] = idx_arr_[i]; + out[idx] = valF(val_arr_[i]); + out_idx[idx] = idxF(idx_arr_[i]); } } @@ -246,8 +252,8 @@ class warp_sort { * the associated indices of the elements in the same format as `keys_in`. */ template - __device__ __forceinline__ void merge_in(const T* __restrict__ keys_in, - const IdxT* __restrict__ ids_in) + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_in(const T* __restrict__ keys_in, + const IdxT* __restrict__ ids_in) { #pragma unroll for (int i = std::min(kMaxArrLen, PerThreadSizeIn); i > 0; i--) { @@ -258,7 +264,7 @@ class warp_sort { idx_arr_[kMaxArrLen - i] = ids_in[PerThreadSizeIn - i]; } } - topk::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); + util::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_); } }; @@ -276,8 +282,9 @@ class warp_sort_filtered : public warp_sort { using warp_sort::kDummy; using warp_sort::kWarpWidth; using warp_sort::k; + using warp_sort::mem_required; - __device__ warp_sort_filtered(int k, T limit) + explicit _RAFT_DEVICE warp_sort_filtered(int k, T limit = kDummy) : warp_sort(k), buf_len_(0), k_th_(limit) { #pragma unroll @@ -287,12 +294,14 @@ class warp_sort_filtered : public warp_sort { } } - __device__ __forceinline__ explicit warp_sort_filtered(int k) - : warp_sort_filtered(k, kDummy) + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, + uint8_t* = nullptr, + T limit = kDummy) { + return warp_sort_filtered{k, limit}; } - __device__ void add(T val, IdxT idx) + _RAFT_DEVICE void add(T val, IdxT idx) { // comparing for k_th should reduce the total amount of updates: // `false` means the input value is surely not in the top-k values. @@ -310,22 +319,22 @@ class warp_sort_filtered : public warp_sort { if (do_add) { add_to_buf_(val, idx); } } - __device__ void done() + _RAFT_DEVICE void done() { if (any(buf_len_ != 0)) { merge_buf_(); } } private: - __device__ __forceinline__ void set_k_th_() + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() { // NB on using srcLane: it's ok if it is outside the warp size / width; // the modulo op will be done inside the __shfl_sync. k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); } - __device__ __forceinline__ void merge_buf_() + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() { - topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); this->merge_in(val_buf_, idx_buf_); buf_len_ = 0; set_k_th_(); // contains warp sync @@ -335,7 +344,7 @@ class warp_sort_filtered : public warp_sort { } } - __device__ __forceinline__ void add_to_buf_(T val, IdxT idx) + _RAFT_DEVICE _RAFT_FORCEINLINE void add_to_buf_(T val, IdxT idx) { // NB: the loop is used here to ensure the constant indexing, // to not force the buffers spill into the local memory. 
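The cheap filtering in warp_sort_filtered::add hinges on every lane holding a copy of the current k-th value, which set_k_th_ obtains with a single warp shuffle from the owning lane. A self-contained illustration of that broadcast, simplified to one register element per lane and with made-up names:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void broadcast_kth_demo(const float* sorted_vals, int k, float* out)
{
  const unsigned lane = threadIdx.x & 31u;
  float my_val = sorted_vals[lane];  // register-resident: one element of a warp-wide sorted array
  // srcLane = k - 1; __shfl_sync wraps the index modulo the width, as noted in set_k_th_ above
  float k_th = __shfl_sync(0xffffffffu, my_val, k - 1, 32);
  // every lane now knows k_th and can discard inputs that cannot enter the top-k
  if (lane == 0) { out[0] = k_th; }
}

int main()
{
  float h_vals[32];
  for (int i = 0; i < 32; i++) { h_vals[i] = float(i); }
  float *d_vals = nullptr, *d_out = nullptr;
  cudaMalloc(&d_vals, sizeof(h_vals));
  cudaMalloc(&d_out, sizeof(float));
  cudaMemcpy(d_vals, h_vals, sizeof(h_vals), cudaMemcpyHostToDevice);
  broadcast_kth_demo<<<1, 32>>>(d_vals, 5, d_out);
  float h_out = 0.f;
  cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("k-th smallest so far: %.1f\n", h_out);  // prints 4.0 for k = 5
  cudaFree(d_vals);
  cudaFree(d_out);
  return 0;
}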
@@ -374,8 +383,9 @@ class warp_sort_distributed : public warp_sort { using warp_sort::kDummy; using warp_sort::kWarpWidth; using warp_sort::k; + using warp_sort::mem_required; - __device__ warp_sort_distributed(int k, T limit) + explicit _RAFT_DEVICE warp_sort_distributed(int k, T limit = kDummy) : warp_sort(k), buf_val_(kDummy), buf_idx_(IdxT{}), @@ -384,12 +394,14 @@ class warp_sort_distributed : public warp_sort { { } - __device__ __forceinline__ explicit warp_sort_distributed(int k) - : warp_sort_distributed(k, kDummy) + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, + uint8_t* = nullptr, + T limit = kDummy) { + return warp_sort_distributed{k, limit}; } - __device__ void add(T val, IdxT idx) + _RAFT_DEVICE void add(T val, IdxT idx) { // mask tells which lanes in the warp have valid items to be added uint32_t mask = ballot(is_ordered(val, k_th_)); @@ -429,7 +441,7 @@ class warp_sort_distributed : public warp_sort { } } - __device__ void done() + _RAFT_DEVICE void done() { if (buf_len_ != 0) { merge_buf_(); @@ -438,16 +450,16 @@ class warp_sort_distributed : public warp_sort { } private: - __device__ __forceinline__ void set_k_th_() + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() { // NB on using srcLane: it's ok if it is outside the warp size / width; // the modulo op will be done inside the __shfl_sync. k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); } - __device__ __forceinline__ void merge_buf_() + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() { - topk::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val_, buf_idx_); + util::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val_, buf_idx_); this->merge_in<1>(&buf_val_, &buf_idx_); set_k_th_(); // contains warp sync buf_val_ = kDummy; @@ -464,6 +476,117 @@ class warp_sort_distributed : public warp_sort { T k_th_; }; +/** + * The same as `warp_sort_distributed`, but keeps the temporary value and index buffers + * in the given external pointers (normally, a shared memory pointer should be passed in). 
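Both distributed queue variants turn the per-lane filter decisions into dense buffer slots by combining the warp ballot computed in add() above with a population count over the lower lanes (the dst_ix arithmetic visible in the next hunk). A standalone sketch of that warp-level compaction primitive; the predicate and names are arbitrary:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_compact_demo(const int* in, int* out, int* out_count)
{
  const unsigned lane = threadIdx.x & 31u;
  const int v          = in[lane];
  const bool keep      = (v % 3 == 0);                          // arbitrary filter predicate
  const unsigned mask  = __ballot_sync(0xffffffffu, keep);      // which lanes passed the filter
  const int dst        = __popc(mask & ((1u << lane) - 1u));    // passing lanes below this one
  if (keep) { out[dst] = v; }                                   // dense, order-preserving write
  if (lane == 0) { *out_count = __popc(mask); }
}

int main()
{
  int h_in[32], h_out[32] = {0}, h_count = 0;
  for (int i = 0; i < 32; i++) { h_in[i] = i; }
  int *d_in = nullptr, *d_out = nullptr, *d_count = nullptr;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMalloc(&d_count, sizeof(int));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  warp_compact_demo<<<1, 32>>>(d_in, d_out, d_count);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  cudaMemcpy(&h_count, d_count, sizeof(int), cudaMemcpyDeviceToHost);
  std::printf("kept %d values; first three: %d %d %d\n", h_count, h_out[0], h_out[1], h_out[2]);
  cudaFree(d_in); cudaFree(d_out); cudaFree(d_count);
  return 0;
}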
+ */ +template +class warp_sort_distributed_ext : public warp_sort { + public: + using warp_sort::kDummy; + using warp_sort::kWarpWidth; + using warp_sort::k; + + constexpr static auto mem_required(uint32_t block_size) -> size_t + { + return (sizeof(T) + sizeof(IdxT)) * block_size; + } + + _RAFT_DEVICE warp_sort_distributed_ext(int k, T* val_buf, IdxT* idx_buf, T limit = kDummy) + : warp_sort(k), + val_buf_(val_buf), + idx_buf_(idx_buf), + buf_len_(0), + k_th_(limit) + { + val_buf_[laneId()] = kDummy; + } + + _RAFT_DEVICE static auto init_blockwide(int k, uint8_t* shmem, T limit = kDummy) + { + T* val_buf = nullptr; + IdxT* idx_buf = nullptr; + if constexpr (alignof(T) >= alignof(IdxT)) { + val_buf = reinterpret_cast(shmem); + idx_buf = reinterpret_cast(val_buf + blockDim.x); + } else { + idx_buf = reinterpret_cast(shmem); + val_buf = reinterpret_cast(idx_buf + blockDim.x); + } + auto warp_offset = Pow2::roundDown(threadIdx.x); + val_buf += warp_offset; + idx_buf += warp_offset; + return warp_sort_distributed_ext{k, val_buf, idx_buf, limit}; + } + + _RAFT_DEVICE void add(T val, IdxT idx) + { + bool do_add = is_ordered(val, k_th_); + // mask tells which lanes in the warp have valid items to be added + uint32_t mask = ballot(do_add); + if (mask == 0) { return; } + // where to put the element in the tmp buffer + int dst_ix = buf_len_ + __popc(mask & ((1u << laneId()) - 1u)); + // put all elements, which fit into the current tmp buffer + if (do_add && dst_ix < WarpSize) { + val_buf_[dst_ix] = val; + idx_buf_[dst_ix] = idx; + do_add = false; + } + // Total number of elements to be added + buf_len_ += __popc(mask); + // If the buffer is still not full, we can return + if (buf_len_ < WarpSize) { return; } + // Otherwise, merge the warp tmp buffer into the queue + merge_buf_(); // implies warp sync + buf_len_ -= WarpSize; + // save the inputs that couldn't fit before the merge + if (do_add) { + dst_ix -= WarpSize; + val_buf_[dst_ix] = val; + idx_buf_[dst_ix] = idx; + } + } + + _RAFT_DEVICE void done() + { + if (buf_len_ != 0) { + merge_buf_(); + buf_len_ = 0; + } + __syncthreads(); + } + + private: + _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_() + { + // NB on using srcLane: it's ok if it is outside the warp size / width; + // the modulo op will be done inside the __shfl_sync. + k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth); + } + + _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_() + { + __syncwarp(); // make sure the threads are aware of the data written by others + T buf_val = val_buf_[laneId()]; + IdxT buf_idx = idx_buf_[laneId()]; + val_buf_[laneId()] = kDummy; + util::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val, buf_idx); + this->merge_in<1>(&buf_val, &buf_idx); + set_k_th_(); // contains warp sync + } + + using warp_sort::kMaxArrLen; + using warp_sort::val_arr_; + using warp_sort::idx_arr_; + + T* val_buf_; + IdxT* idx_buf_; + uint32_t buf_len_; // 0 <= buf_len_ < WarpSize + + T k_th_; +}; + /** * This version of warp_sort adds every input element into the intermediate sorting * buffer, and thus does the sorting step every `Capacity` input elements. 
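init_blockwide above carves a raw dynamic shared memory buffer into a value array and an index array, putting the more strictly aligned type first and then offsetting each warp to its own 32-element slice. A reduced standalone sketch of that layout, with the RAFT helpers (Pow2, laneId) replaced by plain arithmetic and all names hypothetical:

#include <cstdint>
#include <cuda_runtime.h>

// Launch with (sizeof(T) + sizeof(IdxT)) * blockDim.x bytes of dynamic shared memory,
// which matches what mem_required(block_size) reports for warp_sort_distributed_ext.
template <typename T, typename IdxT>
__global__ void shmem_carve_demo(T* dump_vals, IdxT* dump_idxs)
{
  extern __shared__ __align__(256) uint8_t smem[];
  T*    val_buf = nullptr;
  IdxT* idx_buf = nullptr;
  if constexpr (alignof(T) >= alignof(IdxT)) {       // more strictly aligned type goes first
    val_buf = reinterpret_cast<T*>(smem);
    idx_buf = reinterpret_cast<IdxT*>(val_buf + blockDim.x);
  } else {
    idx_buf = reinterpret_cast<IdxT*>(smem);
    val_buf = reinterpret_cast<T*>(idx_buf + blockDim.x);
  }
  const unsigned warp_offset = threadIdx.x & ~31u;   // round down to the warp boundary
  val_buf += warp_offset;                            // each warp gets a private 32-wide slice
  idx_buf += warp_offset;

  const unsigned lane = threadIdx.x & 31u;
  val_buf[lane] = T(lane);
  idx_buf[lane] = IdxT(lane);
  __syncwarp();
  const unsigned gid = blockIdx.x * blockDim.x + threadIdx.x;
  dump_vals[gid] = val_buf[lane];
  dump_idxs[gid] = idx_buf[lane];
}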
@@ -476,8 +599,10 @@ class warp_sort_immediate : public warp_sort { using warp_sort::kDummy; using warp_sort::kWarpWidth; using warp_sort::k; + using warp_sort::mem_required; - __device__ warp_sort_immediate(int k) : warp_sort(k), buf_len_(0) + explicit _RAFT_DEVICE warp_sort_immediate(int k) + : warp_sort(k), buf_len_(0) { #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { @@ -486,7 +611,12 @@ class warp_sort_immediate : public warp_sort { } } - __device__ void add(T val, IdxT idx) + _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, uint8_t* = nullptr) + { + return warp_sort_immediate{k}; + } + + _RAFT_DEVICE void add(T val, IdxT idx) { // NB: the loop is used here to ensure the constant indexing, // to not force the buffers spill into the local memory. @@ -500,7 +630,7 @@ class warp_sort_immediate : public warp_sort { ++buf_len_; if (buf_len_ == kMaxArrLen) { - topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); this->merge_in(val_buf_, idx_buf_); #pragma unroll for (int i = 0; i < kMaxArrLen; i++) { @@ -510,10 +640,10 @@ class warp_sort_immediate : public warp_sort { } } - __device__ void done() + _RAFT_DEVICE void done() { if (buf_len_ != 0) { - topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); + util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_); this->merge_in(val_buf_, idx_buf_); } } @@ -545,15 +675,11 @@ class block_sort { using queue_t = WarpSortWarpWide; template - __device__ block_sort(int k, uint8_t* smem_buf, Args... args) : queue_(k, args...) + _RAFT_DEVICE block_sort(int k, Args... args) : queue_(queue_t::init_blockwide(k, args...)) { - val_smem_ = reinterpret_cast(smem_buf); - const int num_of_warp = subwarp_align::div(blockDim.x); - idx_smem_ = reinterpret_cast( - smem_buf + Pow2<256>::roundUp(ceildiv(num_of_warp, 2) * sizeof(T) * k)); } - __device__ void add(T val, IdxT idx) { queue_.add(val, idx); } + _RAFT_DEVICE void add(T val, IdxT idx) { queue_.add(val, idx); } /** * At the point of calling this function, the warp-level queues consumed all input @@ -561,22 +687,26 @@ class block_sort { * * Here we tree-merge the results using the shared memory and block sync. */ - __device__ void done() + _RAFT_DEVICE void done(uint8_t* smem_buf) { queue_.done(); + int nwarps = subwarp_align::div(blockDim.x); + auto val_smem = reinterpret_cast(smem_buf); + auto idx_smem = reinterpret_cast( + smem_buf + Pow2<256>::roundUp(ceildiv(nwarps, 2) * sizeof(T) * queue_.k)); + const int warp_id = subwarp_align::div(threadIdx.x); // NB: there is no need for the second __synchthreads between .load_sorted and .store: // we shift the pointers every iteration, such that individual warps either access the same // locations or do not overlap with any of the other warps. The access patterns within warps // are different for the two functions, but .load_sorted implies warp sync at the end, so // there is no need for __syncwarp either. 
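The tree merge in done() stages the upper half of the warps' queues in shared memory each round, so a launch must reserve dynamic shared memory for ceil(nwarps/2) rows of values, rounded up to the 256-byte boundary used for the index section above, plus the matching rows of indices. A host-side sizing sketch under those assumptions, with full 32-wide sub-warps and a hypothetical helper name:

#include <cstddef>

constexpr std::size_t round_up_256(std::size_t x) { return (x + 255) / 256 * 256; }
constexpr std::size_t ceil_div(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

// Bytes of dynamic shared memory needed by block_sort's tree merge for one block.
template <typename T, typename IdxT>
constexpr std::size_t block_sort_smem_size(int block_dim, int k)
{
  const std::size_t num_warps = ceil_div(static_cast<std::size_t>(block_dim), 32);
  const std::size_t staged    = ceil_div(num_warps, 2);  // only the upper half stores per round
  return round_up_256(staged * sizeof(T) * k) + staged * sizeof(IdxT) * k;
}

// Example: block_sort_smem_size<float, int>(256, 64)
//   -> 8 warps, 4 staged rows: round_up_256(4 * 4 * 64) + 4 * 4 * 64 = 1024 + 1024 = 2048 bytes.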
- for (int shift_mask = ~0, nwarps = subwarp_align::div(blockDim.x), split = (nwarps + 1) >> 1; - nwarps > 1; + for (int shift_mask = ~0, split = (nwarps + 1) >> 1; nwarps > 1; nwarps = split, split = (nwarps + 1) >> 1) { if (warp_id < nwarps && warp_id >= split) { int dst_warp_shift = (warp_id - (split & shift_mask)) * queue_.k; - queue_.store(val_smem_ + dst_warp_shift, idx_smem_ + dst_warp_shift); + queue_.store(val_smem + dst_warp_shift, idx_smem + dst_warp_shift); } __syncthreads(); @@ -586,23 +716,27 @@ class block_sort { // The last argument serves as a condition for loading // -- to make sure threads within a full warp do not diverge on `bitonic::merge()` queue_.load_sorted( - val_smem_ + src_warp_shift, idx_smem_ + src_warp_shift, warp_id < nwarps - split); + val_smem + src_warp_shift, idx_smem + src_warp_shift, warp_id < nwarps - split); } } } /** Save the content by the pointer location. */ - template - __device__ void store(T* out, IdxT* out_idx, Lambda post_process = raft::identity_op()) const + template + _RAFT_DEVICE void store(OutT* out, + OutIdxT* out_idx, + ValF valF = raft::identity_op{}, + IdxF idxF = raft::identity_op{}) const { - if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx, post_process); } + if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx, valF, idxF); } } private: using subwarp_align = Pow2; queue_t queue_; - T* val_smem_; - IdxT* idx_smem_; }; /** @@ -620,7 +754,10 @@ __launch_bounds__(256) __global__ void block_kernel(const T* in, const IdxT* in_idx, IdxT len, int k, T* out, IdxT* out_idx) { extern __shared__ __align__(256) uint8_t smem_buf_bytes[]; - block_sort queue(k, smem_buf_bytes); + using bq_t = block_sort; + uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem_buf_bytes : nullptr; + bq_t queue(k, warp_smem); + in += blockIdx.y * len; if (in_idx != nullptr) { in_idx += blockIdx.y * len; } @@ -631,7 +768,7 @@ __launch_bounds__(256) __global__ (i < len && in_idx != nullptr) ? __ldcs(in_idx + i) : i); } - queue.done(); + queue.done(smem_buf_bytes); const int block_id = blockIdx.x + gridDim.x * blockIdx.y; queue.store(out + block_id * k, out_idx + block_id * k); } @@ -658,7 +795,7 @@ struct launch_setup { int* min_grid_size, int block_size_limit = 0) { - const int capacity = calc_capacity(k); + const int capacity = bound_by_power_of_two(k); if constexpr (Capacity > 1) { if (capacity < Capacity) { return launch_setup::calc_optimal_params( @@ -691,7 +828,7 @@ struct launch_setup { IdxT* out_idx, rmm::cuda_stream_view stream) { - const int capacity = calc_capacity(k); + const int capacity = bound_by_power_of_two(k); if constexpr (Capacity > 1) { if (capacity < Capacity) { return launch_setup::kernel(k, @@ -742,6 +879,18 @@ struct LaunchThreshold { static constexpr int len_factor_for_single_block = 32; }; +template <> +struct LaunchThreshold { + static constexpr int len_factor_for_multi_block = 2; + static constexpr int len_factor_for_single_block = 32; +}; + +template <> +struct LaunchThreshold { + static constexpr int len_factor_for_multi_block = 2; + static constexpr int len_factor_for_single_block = 32; +}; + template <> struct LaunchThreshold { static constexpr int len_factor_for_choosing = 4; @@ -753,7 +902,7 @@ template