Skip to content

Commit

Permalink
Fixing select_k specializations (rapidsai#1330)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Frederickson (https://github.com/benfred)

URL: rapidsai#1330
  • Loading branch information
cjnolet authored and lowener committed Mar 15, 2023
1 parent a53dd52 commit 045dd3c
Show file tree
Hide file tree
Showing 156 changed files with 390 additions and 173 deletions.
2 changes: 2 additions & 0 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then
$CMAKE_TARGET == *"DISTANCE_TEST"* || \
$CMAKE_TARGET == *"SPARSE_DIST_TEST" || \
$CMAKE_TARGET == *"SPARSE_NEIGHBORS_TEST"* || \
$CMAKE_TARGET == *"MATRIX_TEST"* || \
$CMAKE_TARGET == *"NEIGHBORS_TEST" || \
$CMAKE_TARGET == *"STATS_TEST"* ]]; then
echo "-- Enabling distance lib for gtests"
Expand All @@ -329,6 +330,7 @@ if hasArg bench || (( ${NUMARGS} == 0 )); then

# Force compile distance library when needed benchmark targets are specified
if [[ $CMAKE_TARGET == *"CLUSTER_BENCH"* || \
$CMAKE_TARGET == *"MATRIX_BENCH"* || \
$CMAKE_TARGET == *"NEIGHBORS_BENCH"* ]]; then
echo "-- Enabling distance lib for benchmarks"
COMPILE_DIST_LIBRARY=ON
Expand Down
4 changes: 2 additions & 2 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,9 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/distance/specializations/fused_l2_nn_float_int.cu
src/distance/distance/specializations/fused_l2_nn_float_int64.cu
src/distance/matrix/specializations/detail/select_k_float_uint32_t.cu
src/distance/matrix/specializations/detail/select_k_float_uint64_t.cu
src/distance/matrix/specializations/detail/select_k_float_int64_t.cu
src/distance/matrix/specializations/detail/select_k_half_uint32_t.cu
src/distance/matrix/specializations/detail/select_k_half_uint64_t.cu
src/distance/matrix/specializations/detail/select_k_half_int64_t.cu
src/distance/neighbors/ivfpq_build.cu
src/distance/neighbors/ivfpq_deserialize.cu
src/distance/neighbors/ivfpq_serialize.cu
Expand Down
2 changes: 1 addition & 1 deletion cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ if(BUILD_BENCH)

ConfigureBench(
NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/matrix/select_k.cu
bench/main.cpp
bench/main.cpp OPTIONAL DIST
)

ConfigureBench(
Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/cluster/kmeans.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,7 @@
#include <raft/cluster/kmeans_types.hpp>

#if defined RAFT_DISTANCE_COMPILED
#include <raft/distance/specializations.cuh>
#include <raft/cluster/specializations.cuh>
#endif

namespace raft::bench::cluster {
Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/cluster/kmeans_balanced.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,7 +19,7 @@
#include <raft/random/rng.cuh>

#if defined RAFT_DISTANCE_COMPILED
#include <raft/distance/specializations.cuh>
#include <raft/cluster/specializations.cuh>
#endif

namespace raft::bench::cluster {
Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/distance/masked_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
#include <raft/random/rng.cuh>
#include <raft/util/cudart_utils.hpp>

#if defined RAFT_NN_COMPILED
#include <raft/spatial/knn/specializations.hpp>
#ifdef RAFT_DISTANCE_COMPILED
#include <raft/distance/specializations.cuh>
#endif

namespace raft::bench::distance::masked_nn {
Expand Down
42 changes: 23 additions & 19 deletions cpp/bench/matrix/select_k.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
#include <raft/sparse/detail/utils.h>
#include <raft/util/cudart_utils.hpp>

#if defined RAFT_DISTANCE_COMPILED
#include <raft/matrix/specializations.cuh>
#endif

#include <raft/matrix/detail/select_radix.cuh>
#include <raft/matrix/detail/select_warpsort.cuh>
#include <raft/matrix/select_k.cuh>
Expand Down Expand Up @@ -105,24 +109,24 @@ const std::vector<select::params> kInputs{
RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \
}

SELECTION_REGISTER(float, int, kPublicApi); // NOLINT
SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT
SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT
SELECTION_REGISTER(float, int, kWarpAuto); // NOLINT
SELECTION_REGISTER(float, int, kWarpImmediate); // NOLINT
SELECTION_REGISTER(float, int, kWarpFiltered); // NOLINT
SELECTION_REGISTER(float, int, kWarpDistributed); // NOLINT
SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT

SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT

SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT
SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpAuto); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm); // NOLINT

SELECTION_REGISTER(double, uint32_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, uint32_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, uint32_t, kWarpAuto); // NOLINT

SELECTION_REGISTER(double, int64_t, kRadix8bits); // NOLINT
SELECTION_REGISTER(double, int64_t, kRadix11bits); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpImmediate); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT
SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT

} // namespace raft::matrix
6 changes: 4 additions & 2 deletions cpp/bench/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
#include <raft/spatial/knn/knn.cuh>

#if defined RAFT_DISTANCE_COMPILED
#include <raft/distance/specializations.cuh>
#include <raft/neighbors/specializations/ivf_pq.cuh>
#include <raft/neighbors/specializations.cuh>

// TODO: Legacy. Remove when FAISS is removed
#include <raft/spatial/knn/specializations.cuh>
#endif

#if defined RAFT_NN_COMPILED
Expand Down
1 change: 0 additions & 1 deletion cpp/bench/neighbors/refine_float_int64_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include <common/benchmark.hpp>

#if defined RAFT_DISTANCE_COMPILED
#include <raft/distance/specializations.cuh>
#include <raft/neighbors/specializations/refine.cuh>
#endif

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/cluster/specializations.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -19,6 +19,6 @@
#pragma once

#include <raft/distance/specializations.cuh>
#include <raft/spatial/knn/specializations.cuh>
#include <raft/neighbors/specializations.cuh>

#endif
19 changes: 19 additions & 0 deletions cpp/include/raft/matrix/specializations.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/matrix/specializations/detail/select_k.cuh>
4 changes: 2 additions & 2 deletions cpp/include/raft/matrix/specializations/detail/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,8 @@ namespace raft::matrix::detail {
rmm::mr::device_memory_resource*);

// Commonly used types
RAFT_INST(float, uint64_t);
RAFT_INST(half, uint64_t);
RAFT_INST(float, int64_t);
RAFT_INST(half, int64_t);

// These instances are used in the ivf_pq::search parameterized by the internal_distance_dtype
RAFT_INST(float, uint32_t);
Expand Down
11 changes: 4 additions & 7 deletions cpp/include/raft/neighbors/specializations.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,11 @@
* limitations under the License.
*/

#ifndef __KNN_SPECIALIZATIONS_H
#define __KNN_SPECIALIZATIONS_H

#pragma once

#include <raft/neighbors/specializations/ball_cover.cuh>
#include <raft/neighbors/specializations/fused_l2_knn.cuh>
#include <raft/neighbors/specializations/ivf_pq.cuh>
#include <raft/neighbors/specializations/knn.cuh>
#include <raft/neighbors/specializations/refine.cuh>

#endif
#include <raft/cluster/specializations.cuh>
#include <raft/distance/specializations.cuh>
#include <raft/matrix/specializations.cuh>
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/specializations/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <raft/neighbors/ball_cover.cuh>
#include <raft/neighbors/ball_cover_types.hpp>
#include <raft/neighbors/specializations/detail/ball_cover_lowdim.hpp>

#include <cstdint>

Expand Down
4 changes: 3 additions & 1 deletion cpp/include/raft/neighbors/specializations/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#pragma once

#include <raft/matrix/specializations/detail/select_k.cuh>
#include <raft/cluster/specializations.cuh>
#include <raft/distance/specializations.cuh>
#include <raft/matrix/specializations.cuh>
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh>

Expand Down
4 changes: 2 additions & 2 deletions cpp/internal/raft_internal/matrix/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
#include <raft/matrix/detail/select_warpsort.cuh>
#include <raft/matrix/select_k.cuh>

#if defined RAFT_DISTANCE_COMPILED
#include <raft/matrix/specializations/detail/select_k.cuh>
#ifdef RAFT_DISTANCE_COMPILED
#include <raft/matrix/specializations.cuh>
#endif

#include <raft/core/device_resources.hpp>
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/cluster/cluster_cost_double.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
*/

#include "cluster_cost.cuh"
#include <raft/cluster/specializations.cuh>
#include <raft/core/device_resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/specializations.cuh>

namespace raft::runtime::cluster::kmeans {

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/cluster/cluster_cost_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
*/

#include "cluster_cost.cuh"
#include <raft/cluster/specializations.cuh>
#include <raft/core/device_resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/specializations.cuh>

namespace raft::runtime::cluster::kmeans {

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/cluster/kmeans_fit_double.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/

#include <raft/cluster/kmeans.cuh>
#include <raft/cluster/specializations.cuh>
#include <raft/core/device_resources.hpp>
#include <raft/distance/specializations.cuh>

namespace raft::runtime::cluster::kmeans {

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/cluster/kmeans_fit_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/

#include <raft/cluster/kmeans.cuh>
#include <raft/cluster/specializations.cuh>
#include <raft/core/device_resources.hpp>
#include <raft/distance/specializations.cuh>

namespace raft::runtime::cluster::kmeans {

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/cluster/kmeans_init_plus_plus_double.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/

#include <raft/cluster/kmeans.cuh>
#include <raft/cluster/specializations.cuh>
#include <raft/core/device_resources.hpp>
#include <raft/distance/specializations.cuh>

namespace raft::runtime::cluster::kmeans {

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/cluster/kmeans_init_plus_plus_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/

#include <raft/cluster/kmeans.cuh>
#include <raft/cluster/specializations.cuh>
#include <raft/core/device_resources.hpp>
#include <raft/distance/specializations.cuh>

namespace raft::runtime::cluster::kmeans {

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/cluster/update_centroids.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
*/

#include <raft/cluster/kmeans.cuh>
#include <raft/cluster/specializations.cuh>
#include <raft/core/device_resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/specializations.cuh>
#include <raft/linalg/norm.cuh>

namespace raft::runtime::cluster::kmeans {
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/cluster/update_centroids_double.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
*/

#include "update_centroids.cuh"
#include <raft/cluster/specializations.cuh>
#include <raft/core/device_resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/specializations.cuh>

namespace raft::runtime::cluster::kmeans {

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/cluster/update_centroids_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
*/

#include "update_centroids.cuh"
#include <raft/cluster/specializations.cuh>
#include <raft/core/device_resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/specializations.cuh>

namespace raft::runtime::cluster::kmeans {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <raft/distance/detail/distance.cuh>
#include <raft/distance/specializations.cuh>

namespace raft {
namespace distance {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <raft/distance/detail/distance.cuh>
#include <raft/distance/specializations.cuh>

namespace raft {
namespace distance {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <raft/distance/detail/distance.cuh>
#include <raft/distance/specializations.cuh>

namespace raft {
namespace distance {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <raft/distance/detail/distance.cuh>
#include <raft/distance/specializations.cuh>

namespace raft {
namespace distance {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <raft/distance/detail/distance.cuh>
#include <raft/distance/specializations.cuh>

namespace raft {
namespace distance {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <raft/distance/detail/distance.cuh>
#include <raft/distance/specializations.cuh>

namespace raft {
namespace distance {
Expand Down
Loading

0 comments on commit 045dd3c

Please sign in to comment.