Skip to content

Commit

Permalink
use matrix::select_k in brute_force::knn call (#1463)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1463
  • Loading branch information
benfred authored May 17, 2023
1 parent 618dc23 commit 6fdb041
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 26 deletions.
9 changes: 3 additions & 6 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ if(RAFT_COMPILE_LIBRARY)
src/matrix/detail/select_k_double_uint32_t.cu
src/matrix/detail/select_k_float_int64_t.cu
src/matrix/detail/select_k_float_uint32_t.cu
src/matrix/detail/select_k_float_int32.cu
src/matrix/detail/select_k_half_int64_t.cu
src/matrix/detail/select_k_half_uint32_t.cu
src/neighbors/ball_cover.cu
Expand Down Expand Up @@ -600,9 +601,7 @@ target_link_libraries(raft::raft INTERFACE
# Use `rapids_export` for 22.04 as it will have COMPONENT support
rapids_export(
INSTALL raft
EXPORT_SET raft-exports
COMPONENTS ${raft_components}
COMPONENTS_EXPORT_SET ${raft_export_sets}
EXPORT_SET raft-exports COMPONENTS ${raft_components} COMPONENTS_EXPORT_SET ${raft_export_sets}
GLOBAL_TARGETS raft compiled distributed
NAMESPACE raft::
DOCUMENTATION doc_string
Expand All @@ -613,9 +612,7 @@ rapids_export(
# * build export -------------------------------------------------------------
rapids_export(
BUILD raft
EXPORT_SET raft-exports
COMPONENTS ${raft_components}
COMPONENTS_EXPORT_SET ${raft_export_sets}
EXPORT_SET raft-exports COMPONENTS ${raft_components} COMPONENTS_EXPORT_SET ${raft_export_sets}
GLOBAL_TARGETS raft compiled distributed
DOCUMENTATION doc_string
NAMESPACE raft::
Expand Down
2 changes: 2 additions & 0 deletions cpp/include/raft/matrix/detail/select_k-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ instantiate_raft_matrix_detail_select_k(__half, uint32_t);
instantiate_raft_matrix_detail_select_k(__half, int64_t);
instantiate_raft_matrix_detail_select_k(float, int64_t);
instantiate_raft_matrix_detail_select_k(float, uint32_t);
// needed for brute force knn
instantiate_raft_matrix_detail_select_k(float, int);
// We did not have these two for double before, but there are tests for them. We
// therefore include them here.
instantiate_raft_matrix_detail_select_k(double, int64_t);
Expand Down
42 changes: 22 additions & 20 deletions cpp/include/raft/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,9 @@
#include <raft/linalg/map.cuh>
#include <raft/linalg/transpose.cuh>
#include <raft/matrix/init.cuh>
#include <raft/matrix/select_k.cuh>
#include <raft/neighbors/detail/faiss_select/DistanceUtils.h>
#include <raft/neighbors/detail/faiss_select/Select.cuh>
#include <raft/neighbors/detail/knn_merge_parts.cuh>
#include <raft/neighbors/detail/selection_faiss.cuh>
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <raft/spatial/knn/detail/haversine_distance.cuh>
#include <raft/spatial/knn/detail/processing.cuh>
Expand Down Expand Up @@ -230,15 +229,16 @@ void tiled_brute_force_knn(const raft::resources& handle,
}
}

select_k<IndexType, ElementType>(temp_distances.data(),
nullptr,
current_query_size,
current_centroid_size,
distances + i * k,
indices + i * k,
select_min,
current_k,
stream);
matrix::select_k<ElementType, IndexType>(
handle,
raft::make_device_matrix_view<const ElementType, int64_t, row_major>(
temp_distances.data(), current_query_size, current_centroid_size),
std::nullopt,
raft::make_device_matrix_view<ElementType, int64_t, row_major>(
distances + i * k, current_query_size, current_k),
raft::make_device_matrix_view<IndexType, int64_t, row_major>(
indices + i * k, current_query_size, current_k),
select_min);

// if we're tiling over columns, we need to do a couple things to fix up
// the output of select_k
Expand Down Expand Up @@ -270,15 +270,17 @@ void tiled_brute_force_knn(const raft::resources& handle,

if (tile_cols != n) {
// select the actual top-k items here from the temporary output
select_k<IndexType, ElementType>(temp_out_distances.data(),
temp_out_indices.data(),
current_query_size,
temp_out_cols,
distances + i * k,
indices + i * k,
select_min,
k,
stream);
matrix::select_k<ElementType, IndexType>(
handle,
raft::make_device_matrix_view<const ElementType, int64_t, row_major>(
temp_out_distances.data(), current_query_size, temp_out_cols),
raft::make_device_matrix_view<const IndexType, int64_t, row_major>(
temp_out_indices.data(), current_query_size, temp_out_cols),
raft::make_device_matrix_view<ElementType, int64_t, row_major>(
distances + i * k, current_query_size, k),
raft::make_device_matrix_view<IndexType, int64_t, row_major>(
indices + i * k, current_query_size, k),
select_min);
}
}
}
Expand Down
33 changes: 33 additions & 0 deletions cpp/src/matrix/detail/select_k_float_int32.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/matrix/detail/select_k-inl.cuh>

#define instantiate_raft_matrix_detail_select_k(T, IdxT) \
template void raft::matrix::detail::select_k(const T* in_val, \
const IdxT* in_idx, \
size_t batch_size, \
size_t len, \
int k, \
T* out_val, \
IdxT* out_idx, \
bool select_min, \
rmm::cuda_stream_view stream, \
rmm::mr::device_memory_resource* mr)

instantiate_raft_matrix_detail_select_k(float, int);

#undef instantiate_raft_matrix_detail_select_k

0 comments on commit 6fdb041

Please sign in to comment.