From 538440897bae7c967c842da37476daaf534993f6 Mon Sep 17 00:00:00 2001 From: Mahesh Doijade Date: Fri, 16 Feb 2024 07:15:54 -0800 Subject: [PATCH] remove double datatype API, code cleanup and other review comments --- .../fused_distance_nn/persistent_gemm.h | 1 - .../predicated_tile_iterator_reduced_vec.h | 16 --------------- .../raft/distance/fused_distance_nn-ext.cuh | 10 +--------- .../raft/distance/fused_distance_nn-inl.cuh | 2 +- .../raft/distance/fused_distance_nn.cuh | 2 +- .../distance/fused_distance_nn_helpers.cuh | 1 - .../distance/fused_distance_nn.hpp | 12 ----------- cpp/src/distance/fused_distance_nn.cu | 8 -------- .../distance/fused_distance_min_arg.cu | 20 ------------------- cpp/test/distance/fused_cosine_nn.cu | 2 ++ 10 files changed, 5 insertions(+), 69 deletions(-) diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h index a04fe36b79..223af7eb58 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h @@ -205,7 +205,6 @@ struct FusedDistanceNNPersistent { void const* ptr_B, void const* ptr_C, void* ptr_Vector, - // volatile void* ptr_Tensor, void* ptr_Tensor, typename LayoutA::Stride::Index lda, typename LayoutB::Stride::Index ldb, diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h index 4591fa7855..81e7819223 100644 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h @@ -321,11 +321,8 @@ class PredicatedTileIteratorReducedVec { /// Parameters structure containing reference and precomputed state. Params params_; - /// Byte-level pointer - // uint8_t* byte_pointer_; /// Byte-level pointer first tile offset of this threadblock. volatile uint8_t* first_tile_byte_pointer_; - // uint8_t* first_tile_byte_pointer_; /// Array of boolean values to contain steady-state predicates Mask mask_; @@ -421,8 +418,6 @@ class PredicatedTileIteratorReducedVec { first_tile_byte_pointer_ = reinterpret_cast(pointer) + LongIndex(block_offset.row()) * LongIndex(params_.stride); - // first_tile_byte_pointer_ = reinterpret_cast(pointer) + - // LongIndex(block_offset.row()) * LongIndex(params_.stride); // Initialize internal state counter state_[0] = state_[1] = state_[2] = 0; } @@ -468,13 +463,6 @@ class PredicatedTileIteratorReducedVec { CUTLASS_DEVICE ~PredicatedTileIteratorReducedVec() {} - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - // byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } - /// Performs reduction and Stores a reduced output to memory CUTLASS_DEVICE void store_with_byte_offset(Fragment& frag, int64_t byte_offset) const @@ -574,14 +562,12 @@ class PredicatedTileIteratorReducedVec { PredicatedTileIteratorReducedVec& operator++() { ++state_[0]; - // if (!ScatterD) { byte_pointer_ += params_.advance_row; } thread_start_row_ += ThreadMap::Shape::kRow; if (state_[0] == ThreadMap::Count::kRow) { state_[0] = 0; ++state_[1]; - // byte_pointer_ += params_.advance_group; thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; @@ -589,14 +575,12 @@ class PredicatedTileIteratorReducedVec { if (state_[1] == ThreadMap::Count::kGroup) { state_[1] = 0; ++state_[2]; - // byte_pointer_ += params_.advance_cluster; thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * ThreadMap::Shape::kRow; if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; - // byte_pointer_ += params_.advance_tile; } } } diff --git a/cpp/include/raft/distance/fused_distance_nn-ext.cuh b/cpp/include/raft/distance/fused_distance_nn-ext.cuh index 9dd236a3bd..0b9096423a 100644 --- a/cpp/include/raft/distance/fused_distance_nn-ext.cuh +++ b/cpp/include/raft/distance/fused_distance_nn-ext.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -67,20 +67,12 @@ void fusedDistanceNNMinReduce(OutT* min, float metric_arg, \ cudaStream_t stream) -instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int); -instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int64_t); instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); // We can't have comma's in the macro expansion, so we use the COMMA macro: #define COMMA , -instantiate_raft_distance_fusedDistanceNNMinReduce(double, - raft::KeyValuePair, - int); -instantiate_raft_distance_fusedDistanceNNMinReduce(double, - raft::KeyValuePair, - int64_t); instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, diff --git a/cpp/include/raft/distance/fused_distance_nn-inl.cuh b/cpp/include/raft/distance/fused_distance_nn-inl.cuh index 342bde828d..5ec4b8c5cf 100644 --- a/cpp/include/raft/distance/fused_distance_nn-inl.cuh +++ b/cpp/include/raft/distance/fused_distance_nn-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/fused_distance_nn.cuh b/cpp/include/raft/distance/fused_distance_nn.cuh index 0c22df72f1..04c42e49a1 100755 --- a/cpp/include/raft/distance/fused_distance_nn.cuh +++ b/cpp/include/raft/distance/fused_distance_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. + * Copyright (c) 2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/distance/fused_distance_nn_helpers.cuh b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh index e70d098d09..3a570c681c 100644 --- a/cpp/include/raft/distance/fused_distance_nn_helpers.cuh +++ b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh @@ -17,7 +17,6 @@ #pragma once #include -// #include #include namespace raft::distance { diff --git a/cpp/include/raft_runtime/distance/fused_distance_nn.hpp b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp index 6580cfa639..09d8d401e4 100644 --- a/cpp/include/raft_runtime/distance/fused_distance_nn.hpp +++ b/cpp/include/raft_runtime/distance/fused_distance_nn.hpp @@ -57,18 +57,6 @@ void fused_distance_nn_min_arg(raft::resources const& handle, bool isRowMajor, float metric_arg); -void fused_distance_nn_min_arg(raft::resources const& handle, - int* min, - const double* x, - const double* y, - int m, - int n, - int k, - bool sqrt, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg); - /** @} */ // end group fused_distance_nn_min_arg_runtime } // end namespace raft::runtime::distance diff --git a/cpp/src/distance/fused_distance_nn.cu b/cpp/src/distance/fused_distance_nn.cu index c3d1301e29..fc8a6cb26d 100644 --- a/cpp/src/distance/fused_distance_nn.cu +++ b/cpp/src/distance/fused_distance_nn.cu @@ -36,20 +36,12 @@ float metric_arg, \ cudaStream_t stream) -instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int); -instantiate_raft_distance_fusedDistanceNNMinReduce(double, double, int64_t); instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); // We can't have comma's in the macro expansion, so we use the COMMA macro: #define COMMA , -instantiate_raft_distance_fusedDistanceNNMinReduce(double, - raft::KeyValuePair, - int); -instantiate_raft_distance_fusedDistanceNNMinReduce(double, - raft::KeyValuePair, - int64_t); instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, diff --git a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu index 90d00d9f6b..1899b1616f 100644 --- a/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_distance_min_arg.cu @@ -114,24 +114,4 @@ void fused_distance_nn_min_arg(raft::resources const& handle, } } -void fused_distance_nn_min_arg(raft::resources const& handle, - int* min, - const double* x, - const double* y, - int m, - int n, - int k, - bool sqrt, - raft::distance::DistanceType metric, - bool isRowMajor, - float metric_arg) -{ - switch (metric) { - case raft::distance::DistanceType::CosineExpanded: - compute_fused_cosine_nn_min_arg(handle, min, x, y, m, n, k, sqrt); - break; - default: assert("only cosine metric is supported with fusedDistanceNN\n"); break; - } -} - } // end namespace raft::runtime::distance diff --git a/cpp/test/distance/fused_cosine_nn.cu b/cpp/test/distance/fused_cosine_nn.cu index 5a89e71608..e87692094f 100644 --- a/cpp/test/distance/fused_cosine_nn.cu +++ b/cpp/test/distance/fused_cosine_nn.cu @@ -14,6 +14,8 @@ * limitations under the License. */ +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Search with filter instantiation + #include "../test_utils.cuh" #include #include