Skip to content

Commit

Permalink
Work in progress, build defined for float inputs
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Nov 1, 2022
1 parent 0b35bf7 commit cf40569
Show file tree
Hide file tree
Showing 8 changed files with 517 additions and 157 deletions.
124 changes: 63 additions & 61 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -245,57 +245,57 @@ if(RAFT_COMPILE_DIST_LIBRARY)
add_library(raft_distance_lib
src/distance/pairwise_distance.cu
src/distance/fused_l2_min_arg.cu
# src/distance/specializations/detail/canberra.cu
# src/distance/specializations/detail/chebyshev.cu
# src/distance/specializations/detail/correlation.cu
# src/distance/specializations/detail/cosine.cu
# src/distance/specializations/detail/cosine.cu
# src/distance/specializations/detail/hamming_unexpanded.cu
# src/distance/specializations/detail/hellinger_expanded.cu
# src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu
# src/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu
# src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu
# src/distance/specializations/detail/kernels/gram_matrix_base_double.cu
# src/distance/specializations/detail/kernels/gram_matrix_base_float.cu
# src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu
# src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu
# # These are somehow missing a kernel definition which is causing a compile error.
# # src/distance/specializations/detail/kernels/rbf_kernel_double.cu
# # src/distance/specializations/detail/kernels/rbf_kernel_float.cu
# src/distance/specializations/detail/kernels/tanh_kernel_double.cu
# src/distance/specializations/detail/kernels/tanh_kernel_float.cu
# src/distance/specializations/detail/kl_divergence_float_float_float_int.cu
# src/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu
# src/distance/specializations/detail/kl_divergence_double_double_double_int.cu
# src/distance/specializations/detail/l1_float_float_float_int.cu
# src/distance/specializations/detail/l1_float_float_float_uint32.cu
# src/distance/specializations/detail/l1_double_double_double_int.cu
src/distance/specializations/detail/canberra.cu
src/distance/specializations/detail/chebyshev.cu
src/distance/specializations/detail/correlation.cu
src/distance/specializations/detail/cosine.cu
src/distance/specializations/detail/cosine.cu
src/distance/specializations/detail/hamming_unexpanded.cu
src/distance/specializations/detail/hellinger_expanded.cu
src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu
src/distance/specializations/detail/jensen_shannon_float_float_float_uint32.cu
src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu
src/distance/specializations/detail/kernels/gram_matrix_base_double.cu
src/distance/specializations/detail/kernels/gram_matrix_base_float.cu
src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu
src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu
# These are somehow missing a kernel definition which is causing a compile error.
# src/distance/specializations/detail/kernels/rbf_kernel_double.cu
# src/distance/specializations/detail/kernels/rbf_kernel_float.cu
src/distance/specializations/detail/kernels/tanh_kernel_double.cu
src/distance/specializations/detail/kernels/tanh_kernel_float.cu
src/distance/specializations/detail/kl_divergence_float_float_float_int.cu
src/distance/specializations/detail/kl_divergence_float_float_float_uint32.cu
src/distance/specializations/detail/kl_divergence_double_double_double_int.cu
src/distance/specializations/detail/l1_float_float_float_int.cu
src/distance/specializations/detail/l1_float_float_float_uint32.cu
src/distance/specializations/detail/l1_double_double_double_int.cu
src/distance/specializations/detail/l2_expanded_float_float_float_int.cu
src/distance/specializations/detail/l2_expanded_float_float_float_uint32.cu
src/distance/specializations/detail/l2_expanded_double_double_double_int.cu
# src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu
# src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_uint32.cu
# src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu
# src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu
# src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_uint32.cu
# src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu
# src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu
# src/distance/specializations/detail/l2_unexpanded_float_float_float_uint32.cu
# src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu
# src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu
# src/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu
# src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu
# src/distance/specializations/detail/russel_rao_double_double_double_int.cu
# src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu
# src/distance/specializations/detail/russel_rao_float_float_float_int.cu
# src/distance/specializations/fused_l2_nn_double_int.cu
# src/distance/specializations/fused_l2_nn_double_int64.cu
# src/distance/specializations/fused_l2_nn_float_int.cu
# src/distance/specializations/fused_l2_nn_float_int64.cu
# src/random/specializations/rmat_rectangular_generator_int_double.cu
# src/random/specializations/rmat_rectangular_generator_int64_double.cu
# src/random/specializations/rmat_rectangular_generator_int_float.cu
# src/random/specializations/rmat_rectangular_generator_int64_float.cu
src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu
src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_uint32.cu
src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu
src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu
src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_uint32.cu
src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu
src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu
src/distance/specializations/detail/l2_unexpanded_float_float_float_uint32.cu
src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu
src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu
src/distance/specializations/detail/lp_unexpanded_float_float_float_uint32.cu
src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu
src/distance/specializations/detail/russel_rao_double_double_double_int.cu
src/distance/specializations/detail/russel_rao_float_float_float_uint32.cu
src/distance/specializations/detail/russel_rao_float_float_float_int.cu
src/distance/specializations/fused_l2_nn_double_int.cu
src/distance/specializations/fused_l2_nn_double_int64.cu
src/distance/specializations/fused_l2_nn_float_int.cu
src/distance/specializations/fused_l2_nn_float_int64.cu
src/random/specializations/rmat_rectangular_generator_int_double.cu
src/random/specializations/rmat_rectangular_generator_int64_double.cu
src/random/specializations/rmat_rectangular_generator_int_float.cu
src/random/specializations/rmat_rectangular_generator_int64_float.cu
)
set_target_properties(
raft_distance_lib
Expand Down Expand Up @@ -354,23 +354,25 @@ if(RAFT_COMPILE_NN_LIBRARY)
src/nn/specializations/detail/ivfpq_compute_similarity_float_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu
# src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu
# src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu
# src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu
# src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu
# src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu
# src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu
# src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu
# src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
# src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
# src/nn/specializations/detail/ivfpq_search_float_int64_t.cu
# src/nn/specializations/detail/ivfpq_search_float_uint32_t.cu
# src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_build.cu
src/nn/specializations/detail/ivfpq_search.cu
src/nn/specializations/detail/ivfpq_search_float_int64_t.cu
src/nn/specializations/detail/ivfpq_search_float_uint32_t.cu
src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu
src/nn/specializations/fused_l2_knn_long_float_true.cu
src/nn/specializations/fused_l2_knn_long_float_false.cu
src/nn/specializations/fused_l2_knn_int_float_true.cu
src/nn/specializations/fused_l2_knn_int_float_false.cu
src/nn/specializations/knn.cu
# src/nn/specializations/knn.cu
)
set_target_properties(
raft_nn_lib
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

namespace raft::neighbors ::ivf_pq {

#define RAFT_INST(T, IdxT) \
#define RAFT_INST_SEARCH(T, IdxT) \
void search(const handle_t&, \
const search_params&, \
const index<IdxT>&, \
Expand All @@ -31,10 +31,33 @@ namespace raft::neighbors ::ivf_pq {
float*, \
rmm::mr::device_memory_resource*);

RAFT_INST(float, int64_t);
RAFT_INST(float, uint32_t);
RAFT_INST(float, uint64_t);
RAFT_INST_SEARCH(float, int64_t);
RAFT_INST_SEARCH(float, uint32_t);
RAFT_INST_SEARCH(float, uint64_t);

#undef RAFT_INST
#undef RAFT_INST_SEARCH

// Declares the `build` and `extend` signatures for one (T, IdxT) combination.
//   T    - element type pointed to by `dataset` / `new_vectors`
//   IdxT - type of `n_rows`, of the `new_indices` elements and of the
//          `index<>` template argument
// The macro body already ends with ';', so the invocations below need none.
// NOTE(review): as written these expand to plain (non-template) function
// declarations rather than `extern template` declarations of the
// build/extend templates - confirm this is intended while the build is WIP.
#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
auto build(const handle_t& handle, \
const index_params& params, \
const T* dataset, \
IdxT n_rows, \
uint32_t dim) \
->index<IdxT>; \
\
auto extend(const handle_t& handle, \
const index<IdxT>& orig_index, \
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows) \
->index<IdxT>;

// Combinations currently declared: float, int8_t and uint8_t data, all with
// int64_t as the index type.
RAFT_INST_BUILD_EXTEND(float, int64_t)
RAFT_INST_BUILD_EXTEND(int8_t, int64_t)
RAFT_INST_BUILD_EXTEND(uint8_t, int64_t)
// Disabled for now: float data with uint32_t / uint64_t index types.
// RAFT_INST_BUILD_EXTEND(float, uint32_t);
// RAFT_INST_BUILD_EXTEND(float, uint64_t);

// Keep the helper macro local to this header.
#undef RAFT_INST_BUILD_EXTEND

} // namespace raft::neighbors::ivf_pq
46 changes: 46 additions & 0 deletions cpp/src/nn/specializations/detail/ivfpq_build.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/specializations/ivf_pq_specialization.hpp>

namespace raft::neighbors::ivf_pq {

/*
 * Explicit instantiation definitions of `build` and `extend`.
 *
 * RAFT_INST_BUILD_EXTEND(DataT, IndexT) emits both instantiations for one
 * combination of dataset element type (DataT) and index type (IndexT).
 * The invocation supplies the terminating semicolon.
 */
#define RAFT_INST_BUILD_EXTEND(DataT, IndexT)                    \
  template auto build<DataT, IndexT>(const handle_t&,            \
                                     const index_params&,        \
                                     const DataT*,               \
                                     IndexT,                     \
                                     uint32_t)                   \
    ->index<IndexT>;                                             \
                                                                 \
  template auto extend<DataT, IndexT>(const handle_t&,           \
                                      const index<IndexT>&,      \
                                      const DataT*,              \
                                      const IndexT*,             \
                                      IndexT)                    \
    ->index<IndexT>

// Instantiated data types; every one of them uses int64_t as the index type.
RAFT_INST_BUILD_EXTEND(float, int64_t);
RAFT_INST_BUILD_EXTEND(int8_t, int64_t);
RAFT_INST_BUILD_EXTEND(uint8_t, int64_t);

// Not instantiated (kept for reference):
// RAFT_INST_BUILD_EXTEND(float, uint32_t);
// RAFT_INST_BUILD_EXTEND(float, uint64_t);

// The helper macro must not leak past this translation unit section.
#undef RAFT_INST_BUILD_EXTEND

}  // namespace raft::neighbors::ivf_pq
39 changes: 39 additions & 0 deletions cpp/src/nn/specializations/detail/ivfpq_search.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/specializations/ivf_pq_specialization.hpp>

namespace raft::neighbors::ivf_pq {

/*
 * Explicit instantiation definitions of `search`.
 *
 * RAFT_INST_SEARCH(T, IdxT) instantiates `search<T, IdxT>`; T is the element
 * type of the `const T*` argument and IdxT is used both for the `index<IdxT>`
 * argument and for the `IdxT*` argument.
 *
 * The macro is named RAFT_INST_SEARCH so that the `#undef` at the end of this
 * block actually removes it (previously the macro was defined under a
 * different name than the one being undefined, so it stayed defined).
 *
 * NOTE(review): the library source list also contains
 * ivfpq_search_float_<IdxT>.cu files; confirm they do not instantiate the
 * same `search` specializations as this file.
 */
#define RAFT_INST_SEARCH(T, IdxT) \
template void search<T, IdxT>(const handle_t&, \
const search_params&, \
const index<IdxT>&, \
const T*, \
uint32_t, \
uint32_t, \
IdxT*, \
float*, \
rmm::mr::device_memory_resource*);

// Only float data is instantiated, for three index types.
RAFT_INST_SEARCH(float, int64_t);
RAFT_INST_SEARCH(float, uint32_t);
RAFT_INST_SEARCH(float, uint64_t);

#undef RAFT_INST_SEARCH

} // namespace raft::neighbors::ivf_pq
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/neighbors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# =============================================================================

# Set the list of Cython files to build
set(cython_sources neighbors.pyx)
set(cython_sources ivf_pq.pyx)
set(linked_libraries raft::raft raft::nn)

# Build all of the Cython targets
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Re-export IvfPq from the sibling ivf_pq extension module. The import must be
# explicitly relative: Python 3 does not resolve a bare `ivf_pq` against the
# enclosing package, so `from ivf_pq import ...` fails when the package is
# imported.
from .ivf_pq import IvfPq
Loading

0 comments on commit cf40569

Please sign in to comment.