Skip to content

Commit

Permalink
Python API for brute-force KNN (#1292)
Browse files Browse the repository at this point in the history
Closes #1289

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Ben Frederickson (https://github.com/benfred)

URL: #1292
  • Loading branch information
cjnolet authored Mar 23, 2023
1 parent 419f0c2 commit 31847af
Show file tree
Hide file tree
Showing 13 changed files with 395 additions and 41 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ if(RAFT_COMPILE_LIBRARY)
# These are somehow missing a kernel definition which is causing a compile error.
# src/distance/specializations/detail/kernels/rbf_kernel_double.cu
# src/distance/specializations/detail/kernels/rbf_kernel_float.cu
src/neighbors/brute_force_knn_int64_t_float.cu
src/distance/specializations/detail/kernels/tanh_kernel_double.cu
src/distance/specializations/detail/kernels/tanh_kernel_float.cu
src/distance/specializations/detail/kl_divergence_float_float_float_int.cu
Expand Down
19 changes: 9 additions & 10 deletions cpp/include/raft_runtime/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,17 @@

namespace raft::runtime::neighbors::brute_force {

#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \
void knn(raft::device_resources const& handle, \
std::vector<raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT>> index, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, SEARCH_LAYOUT> search, \
raft::device_matrix_view<IDX_T, MATRIX_IDX_T, row_major> indices, \
raft::device_matrix_view<DATA_T, MATRIX_IDX_T, row_major> distances, \
int k, \
distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \
std::optional<float> metric_arg = std::make_optional<float>(2.0f), \
#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \
void knn(raft::device_resources const& handle, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT> index, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, SEARCH_LAYOUT> search, \
raft::device_matrix_view<IDX_T, MATRIX_IDX_T, row_major> indices, \
raft::device_matrix_view<DATA_T, MATRIX_IDX_T, row_major> distances, \
distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \
std::optional<float> metric_arg = std::make_optional<float>(2.0f), \
std::optional<IDX_T> global_id_offset = std::nullopt);

RAFT_INST_BFKNN(int64_t, float, uint32_t, raft::row_major, raft::row_major);
RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major);

#undef RAFT_INST_BFKNN

Expand Down
46 changes: 24 additions & 22 deletions cpp/src/neighbors/brute_force_knn_int64_t_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/neighbors/brute_force.cuh>
Expand All @@ -24,30 +22,34 @@

#include <raft_runtime/neighbors/brute_force.hpp>

#include <vector>

namespace raft::runtime::neighbors::brute_force {

#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \
void knn(raft::device_resources const& handle, \
std::vector<raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT>> index, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, SEARCH_LAYOUT> search, \
raft::device_matrix_view<IDX_T, MATRIX_IDX_T, row_major> indices, \
raft::device_matrix_view<DATA_T, MATRIX_IDX_T, row_major> distances, \
distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \
std::optional<float> metric_arg = std::make_optional<float>(2.0f), \
std::optional<IDX_T> global_id_offset = std::nullopt) \
{ \
raft::neighbors::brute_force::knn(handle, \
index, \
search, \
indices, \
distances, \
static_cast<int>(indices.extent(1)), \
metric, \
metric_arg, \
global_id_offset); \
#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \
void knn(raft::device_resources const& handle, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT> index, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, SEARCH_LAYOUT> search, \
raft::device_matrix_view<IDX_T, MATRIX_IDX_T, row_major> indices, \
raft::device_matrix_view<DATA_T, MATRIX_IDX_T, row_major> distances, \
distance::DistanceType metric, \
std::optional<float> metric_arg, \
std::optional<IDX_T> global_id_offset) \
{ \
std::vector<raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT>> vec; \
vec.push_back(index); \
raft::neighbors::brute_force::knn(handle, \
vec, \
search, \
indices, \
distances, \
static_cast<int>(distances.extent(1)), \
metric, \
metric_arg, \
global_id_offset); \
}

RAFT_INST_BFKNN(int64_t, float, uint32_t, raft::row_major, raft::row_major);
RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major);

#undef RAFT_INST_BFKNN

Expand Down
1 change: 0 additions & 1 deletion python/pylibraft/pylibraft/common/mdspan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ cdef device_matrix_view[float, int64_t, row_major] \
return make_device_matrix_view[float, int64_t, row_major](
<float*><uintptr_t>cai.data, shape[0], shape[1])


cdef device_matrix_view[uint8_t, int64_t, row_major] \
get_dmv_uint8(cai, check_shape) except *:
if cai.dtype != np.uint8:
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/neighbors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# =============================================================================

# Set the list of Cython files to build
set(cython_sources common.pyx refine.pyx)
set(cython_sources common.pyx refine.pyx brute_force.pyx)
set(linked_libraries raft::raft raft::compiled)

# Build all of the Cython targets
Expand Down
5 changes: 4 additions & 1 deletion python/pylibraft/pylibraft/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pylibraft.neighbors import brute_force

from .refine import refine

__all__ = ["common", "refine"]
__all__ = ["common", "refine", "brute_force"]
179 changes: 179 additions & 0 deletions python/pylibraft/pylibraft/neighbors/brute_force.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
#
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3

import numpy as np

from cython.operator cimport dereference as deref
from libcpp cimport bool, nullptr
from libcpp.vector cimport vector

from pylibraft.distance.distance_type cimport DistanceType

from pylibraft.common import (
DeviceResources,
auto_convert_output,
cai_wrapper,
device_ndarray,
)

from libc.stdint cimport int64_t, uintptr_t

from pylibraft.common.cpp.optional cimport optional
from pylibraft.common.handle cimport device_resources
from pylibraft.common.mdspan cimport get_dmv_float, get_dmv_int64

from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.input_validation import is_c_contiguous
from pylibraft.common.interruptible import cuda_interruptible

from pylibraft.distance.distance_type cimport DistanceType

# TODO: Centralize this

from pylibraft.distance.pairwise_distance import DISTANCE_TYPES

from pylibraft.common.cpp.mdspan cimport (
device_matrix_view,
host_matrix_view,
make_device_matrix_view,
make_host_matrix_view,
row_major,
)
from pylibraft.neighbors.cpp.brute_force cimport knn as c_knn


def _get_array_params(array_interface, check_dtype=None):
dtype = np.dtype(array_interface["typestr"])
if check_dtype is None and dtype != check_dtype:
raise TypeError("dtype %s not supported" % dtype)
shape = array_interface["shape"]
if len(shape) != 2:
raise ValueError("Expected a 2D array, got %d D" % len(shape))
data = array_interface["data"][0]
return (shape, dtype, data)


@auto_sync_handle
@auto_convert_output
def knn(dataset, queries, k=None, indices=None, distances=None,
        metric="sqeuclidean", metric_arg=2.0,
        global_id_offset=0, handle=None):
    """
    Perform a brute-force nearest neighbors search.

    Parameters
    ----------
    dataset : array interface compliant matrix, row-major layout,
        shape (n_samples, dim). Supported dtype [float]
    queries : array interface compliant matrix, row-major layout,
        shape (n_queries, dim) Supported dtype [float]
    k : int
        Number of neighbors to search (k <= 2048). Optional if indices or
        distances arrays are given (in which case their second dimension
        is k).
    indices : Optional array interface compliant matrix shape
        (n_queries, k), dtype int64. If supplied, neighbor
        indices will be written here in-place. (default None)
    distances : Optional array interface compliant matrix shape
        (n_queries, k), dtype float. If supplied, neighbor
        distances will be written here in-place. (default None)
    metric : str
        Name of the distance metric, used as a key into DISTANCE_TYPES.
        (default "sqeuclidean")
    metric_arg : float
        Extra argument forwarded to the metric. Presumably only
        meaningful for parameterized metrics; confirm against the
        C++ API. (default 2.0)
    global_id_offset : int
        Offset forwarded to the C++ search. Presumably added to the
        returned neighbor ids; confirm against the C++ API.
        (default 0)
    {handle_docstring}

    Returns
    -------
    distances : array interface compliant object containing resulting
        distances, shape (n_queries, k)
    indices : array interface compliant object containing resulting
        indices, shape (n_queries, k)

    Examples
    --------
    >>> import cupy as cp
    >>> from pylibraft.common import DeviceResources
    >>> from pylibraft.neighbors.brute_force import knn
    >>> n_samples = 50000
    >>> n_features = 50
    >>> n_queries = 1000
    >>> dataset = cp.random.random_sample((n_samples, n_features),
    ...                                   dtype=cp.float32)
    >>> # Search using the built index
    >>> queries = cp.random.random_sample((n_queries, n_features),
    ...                                   dtype=cp.float32)
    >>> k = 40
    >>> distances, neighbors = knn(dataset, queries, k)
    >>> distances = cp.asarray(distances)
    >>> neighbors = cp.asarray(neighbors)
    """

    if handle is None:
        handle = DeviceResources()

    dataset_cai = cai_wrapper(dataset)
    queries_cai = cai_wrapper(queries)

    # Infer k from a caller-supplied output array when it is not given.
    if k is None:
        if indices is not None:
            k = cai_wrapper(indices).shape[1]
        elif distances is not None:
            k = cai_wrapper(distances).shape[1]
        else:
            raise ValueError("Argument k must be specified if both indices "
                             "and distances arg is None")

    # Reuse the wrapper built above instead of wrapping queries again.
    n_queries = queries_cai.shape[0]

    # Allocate any output buffer the caller did not provide.
    if indices is None:
        indices = device_ndarray.empty((n_queries, k), dtype='int64')

    if distances is None:
        distances = device_ndarray.empty((n_queries, k), dtype='float32')

    cdef DistanceType c_metric = DISTANCE_TYPES[metric]

    distances_cai = cai_wrapper(distances)
    indices_cai = cai_wrapper(indices)

    cdef optional[float] c_metric_arg = <float>metric_arg
    cdef optional[int64_t] c_global_offset = <int64_t>global_id_offset

    cdef device_resources* handle_ = \
        <device_resources*><size_t>handle.getHandle()

    # Only float32 datasets are dispatched; k itself is not passed to the
    # C++ call, which receives only the output matrix views.
    if dataset_cai.dtype == np.float32:
        with cuda_interruptible():
            c_knn(deref(handle_),
                  get_dmv_float(dataset_cai, check_shape=True),
                  get_dmv_float(queries_cai, check_shape=True),
                  get_dmv_int64(indices_cai, check_shape=True),
                  get_dmv_float(distances_cai, check_shape=True),
                  c_metric,
                  c_metric_arg,
                  c_global_offset)
    else:
        raise TypeError("dtype %s not supported" % dataset_cai.dtype)

    return (distances, indices)
12 changes: 7 additions & 5 deletions python/pylibraft/pylibraft/neighbors/common.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ import warnings

from pylibraft.distance.distance_type cimport DistanceType

SUPPORTED_DISTANCES = {
"sqeuclidean": DistanceType.L2Expanded,
"euclidean": DistanceType.L2SqrtExpanded,
"inner_product": DistanceType.InnerProduct,

}


def _get_metric(metric):
SUPPORTED_DISTANCES = {
"sqeuclidean": DistanceType.L2Expanded,
"euclidean": DistanceType.L2SqrtExpanded,
"inner_product": DistanceType.InnerProduct
}
if metric not in SUPPORTED_DISTANCES:
if metric == "l2_expanded":
warnings.warn("Using l2_expanded as a metric name is deprecated,"
Expand Down
Empty file.
14 changes: 14 additions & 0 deletions python/pylibraft/pylibraft/neighbors/cpp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
55 changes: 55 additions & 0 deletions python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3

import numpy as np

import pylibraft.common.handle

from cython.operator cimport dereference as deref
from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t, uintptr_t
from libcpp cimport bool, nullptr
from libcpp.string cimport string
from libcpp.vector cimport vector

from rmm._lib.memory_resource cimport device_memory_resource

from pylibraft.common.cpp.mdspan cimport (
device_matrix_view,
host_matrix_view,
make_device_matrix_view,
make_host_matrix_view,
row_major,
)
from pylibraft.common.cpp.optional cimport optional
from pylibraft.common.handle cimport device_resources
from pylibraft.distance.distance_type cimport DistanceType


# Declaration of the precompiled brute-force k-nearest-neighbors entry point
# from the RAFT runtime header, for float data with int64 indices and
# row-major matrices. The block is declared nogil, so the function is
# callable without holding the Python GIL.
cdef extern from "raft_runtime/neighbors/brute_force.hpp" \
namespace "raft::runtime::neighbors::brute_force" nogil:

# There is no explicit k parameter; the C++ wrapper takes the neighbor count
# from the second extent of `distances`.
# `except +` makes Cython translate a thrown C++ exception into a Python
# exception.
cdef void knn(const device_resources & handle,
device_matrix_view[float, int64_t, row_major] index,
device_matrix_view[float, int64_t, row_major] search,
device_matrix_view[int64_t, int64_t, row_major] indices,
device_matrix_view[float, int64_t, row_major] distances,
DistanceType metric,
optional[float] metric_arg,
optional[int64_t] global_id_offset) except +
Loading

0 comments on commit 31847af

Please sign in to comment.