Skip to content

Commit

Permalink
IVF-PQ Python wrappers (#970)
Browse files Browse the repository at this point in the history
This PR adds python wrappers to IVF-PQ.

Authors:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #970
  • Loading branch information
tfeher authored Nov 15, 2022
1 parent a5cce11 commit 355f693
Show file tree
Hide file tree
Showing 16 changed files with 1,710 additions and 1 deletion.
17 changes: 17 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -332,6 +332,21 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/specializations/fused_l2_nn_double_int64.cu
src/distance/specializations/fused_l2_nn_float_int.cu
src/distance/specializations/fused_l2_nn_float_int64.cu
src/nn/specializations/detail/ivfpq_build.cu
src/nn/specializations/detail/ivfpq_compute_similarity_float_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_float_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_float_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8s_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_fp8u_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_search.cu
src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu
src/random/specializations/rmat_rectangular_generator_int_double.cu
src/random/specializations/rmat_rectangular_generator_int64_double.cu
src/random/specializations/rmat_rectangular_generator_int_float.cu
Expand Down Expand Up @@ -400,6 +415,8 @@ if(RAFT_COMPILE_NN_LIBRARY)
src/nn/specializations/detail/ivfpq_compute_similarity_half_fast.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_basediff.cu
src/nn/specializations/detail/ivfpq_compute_similarity_half_no_smem_lut.cu
src/nn/specializations/detail/ivfpq_build.cu
src/nn/specializations/detail/ivfpq_search.cu
src/nn/specializations/detail/ivfpq_search_float_int64_t.cu
src/nn/specializations/detail/ivfpq_search_float_uint32_t.cu
src/nn/specializations/detail/ivfpq_search_float_uint64_t.cu
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/neighbors/ivf_pq_types.hpp>

namespace raft::neighbors ::ivf_pq {

// Declares the non-template search() overload for one (element type, index type) pair.
// The statement-terminating semicolon is supplied at each use of the macro.
#define RAFT_INST_SEARCH(DataT, IndexT)                    \
  void search(const handle_t& handle,                      \
              const search_params& params,                 \
              const index<IndexT>& idx,                    \
              const DataT* queries,                        \
              uint32_t n_queries,                          \
              uint32_t k,                                  \
              IndexT* neighbors,                           \
              float* distances,                            \
              rmm::mr::device_memory_resource* mr)

RAFT_INST_SEARCH(float, uint64_t);
RAFT_INST_SEARCH(int8_t, uint64_t);
RAFT_INST_SEARCH(uint8_t, uint64_t);

// The helper is only needed for the declarations above.
#undef RAFT_INST_SEARCH

// Non-template build()/extend() declarations for each supported (element type, index type) pair.
// Besides the value-returning forms, every pair also gets a void-returning form that stores the
// result through an index pointer. Those are used by the Cython wrappers, whose exception
// handling is not compatible with a return type that has a nontrivial constructor.
// The final semicolon is supplied at each use of the macro.
#define RAFT_INST_BUILD_EXTEND(DataT, IndexT)                  \
  index<IndexT> build(const handle_t& handle,                  \
                      const index_params& params,              \
                      const DataT* dataset,                    \
                      IndexT n_rows,                           \
                      uint32_t dim);                           \
                                                               \
  index<IndexT> extend(const handle_t& handle,                 \
                       const index<IndexT>& orig_index,        \
                       const DataT* new_vectors,               \
                       const IndexT* new_indices,              \
                       IndexT n_rows);                         \
                                                               \
  void build(const handle_t& handle,                           \
             const index_params& params,                       \
             const DataT* dataset,                             \
             IndexT n_rows,                                    \
             uint32_t dim,                                     \
             index<IndexT>* idx);                              \
                                                               \
  void extend(const handle_t& handle,                          \
              index<IndexT>* idx,                              \
              const DataT* new_vectors,                        \
              const IndexT* new_indices,                       \
              IndexT n_rows)

RAFT_INST_BUILD_EXTEND(float, uint64_t);
RAFT_INST_BUILD_EXTEND(int8_t, uint64_t);
RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t);

// The helper is only needed for the declarations above.
#undef RAFT_INST_BUILD_EXTEND

} // namespace raft::neighbors::ivf_pq
1 change: 0 additions & 1 deletion cpp/include/raft/spatial/knn/detail/haversine_distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

#include <raft/core/handle.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/spatial/knn/faiss_mr.hpp>

namespace raft {
namespace spatial {
Expand Down
66 changes: 66 additions & 0 deletions cpp/src/nn/specializations/detail/ivfpq_build.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/specializations/ivf_pq_specialization.hpp>

namespace raft::neighbors::ivf_pq {

#define RAFT_INST_BUILD_EXTEND(T, IdxT) \
auto build(const handle_t& handle, \
const index_params& params, \
const T* dataset, \
IdxT n_rows, \
uint32_t dim) \
->index<IdxT> \
{ \
return build<T, IdxT>(handle, params, dataset, n_rows, dim); \
} \
auto extend(const handle_t& handle, \
const index<IdxT>& orig_index, \
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows) \
->index<IdxT> \
{ \
return extend<T, IdxT>(handle, orig_index, new_vectors, new_indices, n_rows); \
} \
\
void build(const handle_t& handle, \
const index_params& params, \
const T* dataset, \
IdxT n_rows, \
uint32_t dim, \
index<IdxT>* idx) \
{ \
*idx = build<T, IdxT>(handle, params, dataset, n_rows, dim); \
} \
void extend(const handle_t& handle, \
index<IdxT>* idx, \
const T* new_vectors, \
const IdxT* new_indices, \
IdxT n_rows) \
{ \
extend<T, IdxT>(handle, idx, new_vectors, new_indices, n_rows); \
}

RAFT_INST_BUILD_EXTEND(float, uint64_t);
RAFT_INST_BUILD_EXTEND(int8_t, uint64_t);
RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t);

#undef RAFT_INST_BUILD_EXTEND

} // namespace raft::neighbors::ivf_pq
43 changes: 43 additions & 0 deletions cpp/src/nn/specializations/detail/ivfpq_search.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/specializations/detail/ivf_pq_search.cuh>
#include <raft/neighbors/specializations/ivf_pq_specialization.hpp>

namespace raft::neighbors::ivf_pq {

// Defines the non-template search() overload for one (element type, index type) pair by
// forwarding every argument to the search<T, IdxT> function template.
#define RAFT_SEARCH_INST(T, IdxT)                                                          \
  void search(const handle_t& handle,                                                      \
              const search_params& params,                                                 \
              const index<IdxT>& idx,                                                      \
              const T* queries,                                                            \
              uint32_t n_queries,                                                          \
              uint32_t k,                                                                  \
              IdxT* neighbors,                                                             \
              float* distances,                                                            \
              rmm::mr::device_memory_resource* mr)                                         \
  {                                                                                        \
    search<T, IdxT>(handle, params, idx, queries, n_queries, k, neighbors, distances, mr); \
  }

RAFT_SEARCH_INST(float, uint64_t);
RAFT_SEARCH_INST(int8_t, uint64_t);
RAFT_SEARCH_INST(uint8_t, uint64_t);

// Undefine the macro that was actually defined above. The previous `#undef RAFT_INST_SEARCH`
// named a different identifier, so RAFT_SEARCH_INST stayed defined for the rest of the
// translation unit.
#undef RAFT_SEARCH_INST

} // namespace raft::neighbors::ivf_pq
1 change: 1 addition & 0 deletions python/pylibraft/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ rapids_cython_init()

add_subdirectory(pylibraft/common)
add_subdirectory(pylibraft/distance)
add_subdirectory(pylibraft/neighbors)
add_subdirectory(pylibraft/random)
add_subdirectory(pylibraft/cluster)

Expand Down
19 changes: 19 additions & 0 deletions python/pylibraft/pylibraft/common/input_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
# cython: embedsignature = True
# cython: language_level = 3

import numpy as np


def do_dtypes_match(*cais):
last_dtype = cais[0].__cuda_array_interface__["typestr"]
Expand Down Expand Up @@ -57,3 +59,20 @@ def do_shapes_match(*cais):
return False
last_shape = shape
return True


def is_c_contiguous(cai):
    """
    Checks whether an array is C contiguous.

    Parameters
    ----------
    cai : dict
        CUDA array interface of the array. ``strides`` is looked up first;
        ``typestr`` and ``shape`` are only read when strides are given.

    Returns
    -------
    bool
        True when ``strides`` is absent or None, or when every stride matches
        the row-major stride implied by ``shape`` and the item size.
        If ``shape`` is missing, only the innermost stride is compared with
        the item size.
    """
    strides = cai.get("strides")
    if strides is None:
        return True

    itemsize = np.dtype(cai["typestr"]).itemsize
    shape = cai.get("shape")
    if shape is None:
        # Without extents only the innermost stride can be validated. Index
        # from the end so that 1-D arrays do not raise IndexError.
        return len(strides) == 0 or strides[-1] == itemsize

    if 0 in shape:
        # An array with no elements has no layout to violate.
        return True

    expected = itemsize
    for extent, stride in zip(reversed(shape), reversed(strides)):
        # The stride of a length-1 axis is never used to step through memory,
        # so any value is acceptable there.
        if extent != 1 and stride != expected:
            return False
        expected *= extent
    return True
29 changes: 29 additions & 0 deletions python/pylibraft/pylibraft/neighbors/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# =============================================================================
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
# =============================================================================

# No Cython sources are compiled at this level (SOURCE_FILES below is empty); only the
# libraries to link against are listed here.
set(linked_libraries raft::raft raft::distance)

# Build all of the Cython targets
rapids_cython_create_modules(
CXX
SOURCE_FILES ""
LINKED_LIBRARIES "${linked_libraries}" MODULE_PREFIX neighbors_
)

# Set the install RPATH of every created module to its own directory plus the sibling
# ../library directory (presumably where shared libraries are installed -- confirm).
foreach(cython_module IN LISTS RAPIDS_CYTHON_CREATED_TARGETS)
set_target_properties(${cython_module} PROPERTIES INSTALL_RPATH "\$ORIGIN;\$ORIGIN/../library")
endforeach()

# Descend into the ivf_pq subdirectory.
add_subdirectory(ivf_pq)
14 changes: 14 additions & 0 deletions python/pylibraft/pylibraft/neighbors/__init__.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
14 changes: 14 additions & 0 deletions python/pylibraft/pylibraft/neighbors/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
28 changes: 28 additions & 0 deletions python/pylibraft/pylibraft/neighbors/ivf_pq/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# =============================================================================
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
# or implied. See the License for the specific language governing permissions and limitations under
# the License.
# =============================================================================

# Set the list of Cython files to build and the libraries the resulting modules link against
set(cython_sources ivf_pq.pyx)
set(linked_libraries raft::raft raft::distance)

# Build all of the Cython targets
rapids_cython_create_modules(
CXX
SOURCE_FILES "${cython_sources}"
LINKED_LIBRARIES "${linked_libraries}" MODULE_PREFIX neighbors_ivfpq_
)

# Set the install RPATH of every created module to its own directory plus the sibling
# ../library directory (presumably where shared libraries are installed -- confirm).
foreach(cython_module IN LISTS RAPIDS_CYTHON_CREATED_TARGETS)
set_target_properties(${cython_module} PROPERTIES INSTALL_RPATH "\$ORIGIN;\$ORIGIN/../library")
endforeach()
Empty file.
16 changes: 16 additions & 0 deletions python/pylibraft/pylibraft/neighbors/ivf_pq/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) 2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from .ivf_pq import Index, IndexParams, SearchParams, build, extend, search
Loading

0 comments on commit 355f693

Please sign in to comment.