Skip to content

Commit

Permalink
Add specializations to the matrix::detail::select_k
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Feb 18, 2023
1 parent 02cfacf commit 49dfdfa
Show file tree
Hide file tree
Showing 12 changed files with 210 additions and 1 deletion.
4 changes: 4 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,10 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/distance/specializations/fused_l2_nn_double_int64.cu
src/distance/distance/specializations/fused_l2_nn_float_int.cu
src/distance/distance/specializations/fused_l2_nn_float_int64.cu
src/distance/matrix/specializations/detail/select_k_float_uint32_t.cu
src/distance/matrix/specializations/detail/select_k_float_uint64_t.cu
src/distance/matrix/specializations/detail/select_k_half_uint32_t.cu
src/distance/matrix/specializations/detail/select_k_half_uint64_t.cu
src/distance/neighbors/ivfpq_build.cu
src/distance/neighbors/ivfpq_deserialize.cu
src/distance/neighbors/ivfpq_serialize.cu
Expand Down
47 changes: 47 additions & 0 deletions cpp/include/raft/matrix/specializations/detail/select_k.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/matrix/detail/select_k.cuh>

#include <cuda_fp16.h>

namespace raft::matrix::detail {

#define RAFT_INST(T, IdxT) \
extern template void select_k<T, IdxT>(const T*, \
const IdxT*, \
size_t, \
size_t, \
int, \
T*, \
IdxT*, \
bool, \
rmm::cuda_stream_view, \
rmm::mr::device_memory_resource*);

// Commonly used types
RAFT_INST(float, uint64_t);
RAFT_INST(half, uint64_t);

// These instances are used in the ivf_pq::search parameterized by the internal_distance_dtype
RAFT_INST(float, uint32_t);
RAFT_INST(half, uint32_t);

#undef RAFT_INST

} // namespace raft::matrix::detail
3 changes: 2 additions & 1 deletion cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@
#include <raft/linalg/map.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/linalg/unary_op.cuh>
#include <raft/matrix/matrix.cuh>
#include <raft/matrix/gather.cuh>
#include <raft/matrix/linewise_op.cuh>
#include <raft/random/rng.cuh>
#include <raft/stats/histogram.cuh>
#include <raft/util/cuda_utils.cuh>
Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/specializations/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <raft/matrix/specializations/detail/select_k.cuh>
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh>

Expand Down
6 changes: 6 additions & 0 deletions cpp/internal/raft_internal/matrix/select_k.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@
* limitations under the License.
*/

#pragma once

#include <raft/matrix/detail/select_radix.cuh>
#include <raft/matrix/detail/select_warpsort.cuh>
#include <raft/matrix/select_k.cuh>

#if defined RAFT_DISTANCE_COMPILED
#include <raft/matrix/specializations/detail/select_k.cuh>
#endif

#include <raft/core/device_resources.hpp>

namespace raft::matrix::select {
Expand Down
4 changes: 4 additions & 0 deletions cpp/internal/raft_internal/neighbors/naive_knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#include <raft/spatial/knn/detail/ann_utils.cuh>
#include <raft/util/cuda_utils.cuh>

#if defined RAFT_DISTANCE_COMPILED
#include <raft/matrix/specializations/detail/select_k.cuh>
#endif

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/matrix/specializations/detail/select_k.cuh>

namespace raft::matrix::detail {

#define RAFT_INST(T, IdxT) \
template void select_k<T, IdxT>(const T*, \
const IdxT*, \
size_t, \
size_t, \
int, \
T*, \
IdxT*, \
bool, \
rmm::cuda_stream_view, \
rmm::mr::device_memory_resource*);

RAFT_INST(float, uint32_t);

} // namespace raft::matrix::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/matrix/specializations/detail/select_k.cuh>

namespace raft::matrix::detail {

#define RAFT_INST(T, IdxT) \
template void select_k<T, IdxT>(const T*, \
const IdxT*, \
size_t, \
size_t, \
int, \
T*, \
IdxT*, \
bool, \
rmm::cuda_stream_view, \
rmm::mr::device_memory_resource*);

RAFT_INST(float, uint64_t);

} // namespace raft::matrix::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/matrix/specializations/detail/select_k.cuh>

namespace raft::matrix::detail {

#define RAFT_INST(T, IdxT) \
template void select_k<T, IdxT>(const T*, \
const IdxT*, \
size_t, \
size_t, \
int, \
T*, \
IdxT*, \
bool, \
rmm::cuda_stream_view, \
rmm::mr::device_memory_resource*);

RAFT_INST(half, uint32_t);

} // namespace raft::matrix::detail
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/matrix/specializations/detail/select_k.cuh>

namespace raft::matrix::detail {

#define RAFT_INST(T, IdxT) \
template void select_k<T, IdxT>(const T*, \
const IdxT*, \
size_t, \
size_t, \
int, \
T*, \
IdxT*, \
bool, \
rmm::cuda_stream_view, \
rmm::mr::device_memory_resource*);

RAFT_INST(half, uint64_t);

} // namespace raft::matrix::detail
3 changes: 3 additions & 0 deletions cpp/test/neighbors/ann_ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@

#include <thrust/sequence.h>

#if defined RAFT_DISTANCE_COMPILED
#include <raft/matrix/specializations/detail/select_k.cuh>
#endif
#if defined RAFT_DISTANCE_COMPILED && defined RAFT_NN_COMPILED
#include <raft/cluster/specializations.cuh>
#endif
Expand Down
3 changes: 3 additions & 0 deletions cpp/test/neighbors/selection.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

#include <raft/sparse/detail/utils.h>
#include <raft/spatial/knn/knn.cuh>
#if defined RAFT_DISTANCE_COMPILED
#include <raft/matrix/specializations/detail/select_k.cuh>
#endif
#if defined RAFT_NN_COMPILED
#include <raft/neighbors/specializations.cuh>
#endif
Expand Down

0 comments on commit 49dfdfa

Please sign in to comment.