Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python API for brute-force KNN #1292

Merged
merged 114 commits into from
Mar 23, 2023
Merged
Show file tree
Hide file tree
Changes from 111 commits
Commits
Show all changes
114 commits
Select commit Hold shift + click to select a range
57cfa20
Replace faiss bfKnn
benfred Jan 27, 2023
984c641
Merge branch 'branch-23.02' into bfknn
benfred Jan 27, 2023
805abc7
fix merge
benfred Jan 27, 2023
3e21478
Merge branch 'branch-23.02' into bfknn
cjnolet Jan 27, 2023
74bd44f
Fix bug with col_tiles < K
benfred Jan 27, 2023
c69054d
Merge branch 'bfknn' of github.com:benfred/raft into bfknn
benfred Jan 27, 2023
1d9581b
Include metric_arg in bfknn
benfred Jan 30, 2023
b4cf88c
speedup compile times
benfred Jan 30, 2023
98ffb70
Merge branch 'branch-23.02' into bfknn
benfred Jan 31, 2023
5442d31
Suggestions from code review
benfred Jan 31, 2023
0f5d206
fixes
benfred Jan 31, 2023
cb2b750
Merge branch 'branch-23.04' into bfknn
benfred Feb 3, 2023
e870eb3
use pairwise_distance specialization to speed up compile times
benfred Feb 7, 2023
cd84397
Merge branch 'branch-23.04' into bfknn
benfred Feb 7, 2023
8445aed
Use distance specializations
benfred Feb 7, 2023
e87633a
Merge branch 'branch-23.04' into bfknn
cjnolet Feb 9, 2023
52bf729
Merge branch 'branch-23.04' into bfknn
benfred Feb 11, 2023
d97ddb8
Merge branch 'branch-23.04' into bfknn
cjnolet Feb 11, 2023
5905b2d
use specializations in RBC code
benfred Feb 14, 2023
d905266
Merge branch 'branch-23.04' into bfknn
benfred Feb 14, 2023
8eaba84
use pw specializations in rbc
benfred Feb 14, 2023
fe728e9
use matrix::select_k in bfknn call
benfred Feb 14, 2023
96e05e1
expose bf detail specialization
benfred Feb 15, 2023
59060b2
Revert "use pw specializations in rbc"
benfred Feb 15, 2023
c734bac
Add tests for other metrics
benfred Feb 15, 2023
c65e4bb
Fix parameter order
benfred Feb 16, 2023
3830e53
Fix Lp distance
benfred Feb 17, 2023
3f0b9a7
Revert "use matrix::select_k in bfknn call"
benfred Feb 17, 2023
3900570
re-enable failing tests
benfred Feb 17, 2023
8e71915
fix cosine/innerproduct in bfknn
benfred Feb 17, 2023
f806bf6
Test JensenShannon distance
benfred Feb 17, 2023
3315dca
support k up to 2048 in faiss select
benfred Feb 18, 2023
1b6eda2
Merge remote-tracking branch 'origin/branch-23.04' into bfknn
benfred Feb 18, 2023
a83bef3
cmake format
benfred Feb 18, 2023
3b811a1
support k up to 2048 in faiss select
benfred Feb 18, 2023
9a19456
Merge branch 'branch-23.04' into faiss_largek
cjnolet Feb 18, 2023
c39dc65
style
benfred Feb 18, 2023
c60e17f
Merge branch 'faiss_largek' of github.com:benfred/raft into faiss_largek
benfred Feb 18, 2023
2752294
code review suggestions
benfred Feb 18, 2023
84f7a42
Merge remote-tracking branch 'bf/faiss_largek' into bfknn
benfred Feb 18, 2023
901b898
Merge branch 'branch-23.04' into bfknn
cjnolet Feb 20, 2023
1548a78
Remove ENABLE_NN_DEPENDENCIES option
benfred Feb 21, 2023
3fdc712
Merge branch 'branch-23.04' into bfknn
benfred Feb 21, 2023
642f87d
Adding brute-force knn api to pylibraft
cjnolet Feb 21, 2023
1ac42af
Merge branch 'branch-23.04' into fea-2304-bfknn
cjnolet Feb 21, 2023
31c9cf2
temporarily re-add faiss build targets
benfred Feb 21, 2023
f7fd6a7
couple more files to re-add faiss
benfred Feb 21, 2023
37d66d2
re-add faiss_mr
benfred Feb 22, 2023
a61c92f
explicitly include faiss_mr
benfred Feb 23, 2023
dbd31b2
Allow col_major input to bfknn
benfred Feb 24, 2023
fddecc3
fix faiss queryempty test
benfred Feb 24, 2023
4687144
exclude LP from fused
benfred Feb 27, 2023
bd3ff51
Merge branch 'branch-23.04' into bfknn
benfred Feb 27, 2023
06c8674
use metric processor for cosine/correlation
benfred Feb 27, 2023
b44d15c
exclude cosine
benfred Feb 28, 2023
5a582cd
Merge branch 'branch-23.04' into bfknn
benfred Mar 1, 2023
eb0271a
avoid l2expanded distance
benfred Mar 1, 2023
616455c
Merge branch 'bfknn' of github.com:benfred/raft into bfknn
benfred Mar 1, 2023
4c41c63
Expanded L2 Changes
benfred Mar 6, 2023
cdf1962
correct for small instabilities in l2sqrtexpanded distance
benfred Mar 7, 2023
6a1e2d8
warp divergence
benfred Mar 7, 2023
1e2817c
clamp to 0
benfred Mar 7, 2023
4b56fac
threshold
benfred Mar 7, 2023
4b41e2c
Transpose for fusedl2knn as well
benfred Mar 8, 2023
455c952
fix
benfred Mar 8, 2023
df46b65
Fix stream handling on col-major inputs
benfred Mar 8, 2023
97a3c01
Merge branch 'bfknn' of github.com:benfred/raft into bfknn
benfred Mar 8, 2023
105bc96
Merge branch 'branch-23.04' into bfknn
benfred Mar 9, 2023
2828c3b
Merge branch 'branch-23.04' into bfknn
benfred Mar 9, 2023
6e45267
Merge branch 'branch-23.04' into bfknn
benfred Mar 11, 2023
f426510
Merge branch 'branch-23.04' into fea-2304-bfknn
benfred Mar 11, 2023
754ab9a
Merge remote-tracking branch 'ben/bfknn' into fea-2304-bfknn
cjnolet Mar 11, 2023
a35ec7b
Including specializations in bfknn runtime API
cjnolet Mar 11, 2023
28ebeef
fix build for missing symbols
benfred Mar 13, 2023
9917324
Merge branch 'branch-23.04' into bfknn
cjnolet Mar 13, 2023
65d7725
code review feedback
benfred Mar 14, 2023
e41ff88
matrix::fill and linalg::map_offset
benfred Mar 14, 2023
5eb7d22
build fix
benfred Mar 14, 2023
e36d089
fix
benfred Mar 14, 2023
50e366f
Merge branch 'branch-23.04' into bfknn
benfred Mar 14, 2023
e8f9c55
move faiss_select into raft::neighbors namespace
benfred Mar 14, 2023
9f211a0
move knn_merge parts to its own file
benfred Mar 14, 2023
f1011dc
Merge branch 'branch-23.04' into fea-2304-bfknn
cjnolet Mar 14, 2023
9cbac3c
Merge branch 'branch-23.04' into bfknn
cjnolet Mar 14, 2023
0667526
Merge remote-tracking branch 'origin/branch-23.04' into bfknn
benfred Mar 15, 2023
b48ad31
Merge remote-tracking branch 'ben/bfknn' into fea-2304-bfknn
cjnolet Mar 15, 2023
9593ae1
Use stream pool
benfred Mar 15, 2023
76d2b19
Merge branch 'bfknn' of github.com:benfred/raft into bfknn
benfred Mar 15, 2023
97753b0
use right handle for transpose
benfred Mar 15, 2023
a534538
set blas stream
benfred Mar 15, 2023
237b7e1
error handling
benfred Mar 15, 2023
07290bc
try to isolate stream failure
benfred Mar 16, 2023
2671a0e
Move transpose code out of loop
benfred Mar 16, 2023
7e0bb9b
fix
benfred Mar 16, 2023
b4c3284
try transpose inside streampool again
benfred Mar 16, 2023
92d82db
one more try with cublasSetStream
benfred Mar 17, 2023
c84e560
Merge branch 'branch-23.04' into bfknn
benfred Mar 17, 2023
21a1953
fix
benfred Mar 17, 2023
80fb76c
Merge branch 'bfknn' of github.com:benfred/raft into bfknn
benfred Mar 17, 2023
2d89994
Merge remote-tracking branch 'rapidsai/branch-23.04' into fea-2304-bfknn
cjnolet Mar 17, 2023
1b89ed1
Merge remote-tracking branch 'ben/bfknn' into fea-2304-bfknn
cjnolet Mar 17, 2023
5f743f1
Getting some tests running (no assertions yet)
cjnolet Mar 17, 2023
fa93cf7
Merge branch 'branch-23.04' into fea-2304-bfknn
cjnolet Mar 17, 2023
accfc3b
Adding pytest assertions and baseline for comparison
cjnolet Mar 22, 2023
0d7eba7
Merge branch 'fea-2304-bfknn' of github.com:cjnolet/raft into fea-230…
cjnolet Mar 22, 2023
1c9ba9d
Merge remote-tracking branch 'rapidsai/branch-23.04' into fea-2304-bfknn
cjnolet Mar 23, 2023
1e28f9f
Using correct size of indices
cjnolet Mar 23, 2023
e8879a3
Using correct number of distances too
cjnolet Mar 23, 2023
fb90502
Limiting the correct value
cjnolet Mar 23, 2023
5b407b9
Merge branch 'branch-23.04' into fea-2304-bfknn
cjnolet Mar 23, 2023
7c5c87b
Enabling more distances and fixing assertion
cjnolet Mar 23, 2023
4c9f85b
Tests are passing.
cjnolet Mar 23, 2023
7400402
Review feedback
cjnolet Mar 23, 2023
d9016f2
Using proper import
cjnolet Mar 23, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,7 @@ if(RAFT_COMPILE_LIBRARY)
# These are somehow missing a kernel definition which is causing a compile error.
# src/distance/specializations/detail/kernels/rbf_kernel_double.cu
# src/distance/specializations/detail/kernels/rbf_kernel_float.cu
src/neighbors/brute_force_knn_int64_t_float.cu
src/distance/specializations/detail/kernels/tanh_kernel_double.cu
src/distance/specializations/detail/kernels/tanh_kernel_float.cu
src/distance/specializations/detail/kl_divergence_float_float_float_int.cu
Expand Down
19 changes: 9 additions & 10 deletions cpp/include/raft_runtime/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,17 @@

namespace raft::runtime::neighbors::brute_force {

#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \
void knn(raft::device_resources const& handle, \
std::vector<raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT>> index, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, SEARCH_LAYOUT> search, \
raft::device_matrix_view<IDX_T, MATRIX_IDX_T, row_major> indices, \
raft::device_matrix_view<DATA_T, MATRIX_IDX_T, row_major> distances, \
int k, \
distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \
std::optional<float> metric_arg = std::make_optional<float>(2.0f), \
#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \
void knn(raft::device_resources const& handle, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT> index, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, SEARCH_LAYOUT> search, \
raft::device_matrix_view<IDX_T, MATRIX_IDX_T, row_major> indices, \
raft::device_matrix_view<DATA_T, MATRIX_IDX_T, row_major> distances, \
distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \
std::optional<float> metric_arg = std::make_optional<float>(2.0f), \
std::optional<IDX_T> global_id_offset = std::nullopt);

RAFT_INST_BFKNN(int64_t, float, uint32_t, raft::row_major, raft::row_major);
RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major);

#undef RAFT_INST_BFKNN

Expand Down
46 changes: 24 additions & 22 deletions cpp/src/neighbors/brute_force_knn_int64_t_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/neighbors/brute_force.cuh>
Expand All @@ -24,30 +22,34 @@

#include <raft_runtime/neighbors/brute_force.hpp>

#include <vector>

namespace raft::runtime::neighbors::brute_force {

#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \
void knn(raft::device_resources const& handle, \
std::vector<raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT>> index, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, SEARCH_LAYOUT> search, \
raft::device_matrix_view<IDX_T, MATRIX_IDX_T, row_major> indices, \
raft::device_matrix_view<DATA_T, MATRIX_IDX_T, row_major> distances, \
distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \
std::optional<float> metric_arg = std::make_optional<float>(2.0f), \
std::optional<IDX_T> global_id_offset = std::nullopt) \
{ \
raft::neighbors::brute_force::knn(handle, \
index, \
search, \
indices, \
distances, \
static_cast<int>(indices.extent(1)), \
metric, \
metric_arg, \
global_id_offset); \
#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \
void knn(raft::device_resources const& handle, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT> index, \
raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, SEARCH_LAYOUT> search, \
raft::device_matrix_view<IDX_T, MATRIX_IDX_T, row_major> indices, \
raft::device_matrix_view<DATA_T, MATRIX_IDX_T, row_major> distances, \
distance::DistanceType metric, \
std::optional<float> metric_arg, \
std::optional<IDX_T> global_id_offset) \
{ \
std::vector<raft::device_matrix_view<const DATA_T, MATRIX_IDX_T, INDEX_LAYOUT>> vec; \
vec.push_back(index); \
raft::neighbors::brute_force::knn(handle, \
vec, \
search, \
indices, \
distances, \
static_cast<int>(distances.extent(1)), \
metric, \
metric_arg, \
global_id_offset); \
}

RAFT_INST_BFKNN(int64_t, float, uint32_t, raft::row_major, raft::row_major);
RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major);

#undef RAFT_INST_BFKNN

Expand Down
1 change: 0 additions & 1 deletion python/pylibraft/pylibraft/common/mdspan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ cdef device_matrix_view[float, int64_t, row_major] \
return make_device_matrix_view[float, int64_t, row_major](
<float*><uintptr_t>cai.data, shape[0], shape[1])


cdef device_matrix_view[uint8_t, int64_t, row_major] \
get_dmv_uint8(cai, check_shape) except *:
if cai.dtype != np.uint8:
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/neighbors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# =============================================================================

# Set the list of Cython files to build
set(cython_sources common.pyx refine.pyx)
set(cython_sources common.pyx refine.pyx brute_force.pyx)
set(linked_libraries raft::raft raft::compiled)

# Build all of the Cython targets
Expand Down
182 changes: 182 additions & 0 deletions python/pylibraft/pylibraft/neighbors/brute_force.pyx
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
#
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3

import numpy as np

from cython.operator cimport dereference as deref
from libcpp cimport bool, nullptr
from libcpp.vector cimport vector

from pylibraft.distance.distance_type cimport DistanceType

from pylibraft.common import (
DeviceResources,
auto_convert_output,
cai_wrapper,
device_ndarray,
)

from libc.stdint cimport int64_t, uintptr_t

from pylibraft.common.cpp.optional cimport optional
from pylibraft.common.handle cimport device_resources
from pylibraft.common.mdspan cimport get_dmv_float, get_dmv_int64

from pylibraft.common.handle import auto_sync_handle
from pylibraft.common.input_validation import is_c_contiguous
from pylibraft.common.interruptible import cuda_interruptible

from pylibraft.distance.distance_type cimport DistanceType

# TODO: Centralize this

from pylibraft.neighbors.ivf_pq.ivf_pq import _get_metric

from pylibraft.common.cpp.mdspan cimport (
device_matrix_view,
host_matrix_view,
make_device_matrix_view,
make_host_matrix_view,
row_major,
)
from pylibraft.neighbors.cpp.brute_force cimport knn as c_knn


def _get_array_params(array_interface, check_dtype=None):
dtype = np.dtype(array_interface["typestr"])
if check_dtype is None and dtype != check_dtype:
raise TypeError("dtype %s not supported" % dtype)
shape = array_interface["shape"]
if len(shape) != 2:
raise ValueError("Expected a 2D array, got %d D" % len(shape))
data = array_interface["data"][0]
return (shape, dtype, data)


@auto_sync_handle
@auto_convert_output
def knn(dataset, queries, k=None, indices=None, distances=None,
metric="sqeuclidean", metric_arg=2.0,
global_id_offset=0, handle=None):
"""
Perform a brute-force nearest neighbors search.

Parameters
----------
dataset : array interface compliant matrix, row-major layout,
shape (n_samples, dim). Supported dtype [float]
queries : array interface compliant matrix, row-major layout,
shape (n_queries, dim) Supported dtype [float]
k : int
Number of neighbors to search (k <= 2048). Optional if indices or
distances arrays are given (in which case their second dimension
is k).
indices : Optional array interface compliant matrix shape
(n_queries, k), dtype int64_t. If supplied, neighbor
indices will be written here in-place. (default None)
Supported dtype uint64
distances : Optional array interface compliant matrix shape
(n_queries, k), dtype float. If supplied, neighbor
indices will be written here in-place. (default None)

{handle_docstring}

Returns
-------
indices: array interface compliant object containing resulting indices
shape (n_queries, k)

distances: array interface compliant object containing resulting distances
shape (n_queries, k)

Examples
--------

>>> import cupy as cp

>>> from pylibraft.common import DeviceResources
>>> from pylibraft.neighbors import ivf_pq, refine
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

>>> n_samples = 50000
>>> n_features = 50
>>> n_queries = 1000

>>> dataset = cp.random.random_sample((n_samples, n_features),
... dtype=cp.float32)
>>> # Search using the built index
>>> queries = cp.random.random_sample((n_queries, n_features),
... dtype=cp.float32)
>>> k = 40
>>> distances, neighbors = knn(dataset, queries, k)
>>> distances = cp.asarray(distances)
>>> neighbors = cp.asarray(neighbors)

>>> # pylibraft functions are often asynchronous so the
>>> # handle needs to be explicitly synchronized
"""

if handle is None:
handle = DeviceResources()

dataset_cai = cai_wrapper(dataset)
queries_cai = cai_wrapper(queries)

if k is None:
if indices is not None:
k = cai_wrapper(indices).shape[1]
elif distances is not None:
k = cai_wrapper(distances).shape[1]
else:
raise ValueError("Argument k must be specified if both indices "
"and distances arg is None")

n_queries = cai_wrapper(queries).shape[0]

if indices is None:
indices = device_ndarray.empty((n_queries, k), dtype='int64')

if distances is None:
distances = device_ndarray.empty((n_queries, k), dtype='float32')

cdef DistanceType c_metric = _get_metric(metric)

distances_cai = cai_wrapper(distances)
indices_cai = cai_wrapper(indices)

cdef optional[float] c_metric_arg = <float>metric_arg
cdef optional[int64_t] c_global_offset = <int64_t>global_id_offset

cdef device_resources* handle_ = \
<device_resources*><size_t>handle.getHandle()

if dataset_cai.dtype == np.float32:
with cuda_interruptible():
c_knn(deref(handle_),
get_dmv_float(dataset_cai, check_shape=True),
get_dmv_float(queries_cai, check_shape=True),
get_dmv_int64(indices_cai, check_shape=True),
get_dmv_float(distances_cai, check_shape=True),
c_metric,
c_metric_arg,
c_global_offset)
else:
raise TypeError("dtype %s not supported" % dataset_cai.dtype)

return (distances, indices)
Empty file.
14 changes: 14 additions & 0 deletions python/pylibraft/pylibraft/neighbors/cpp/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2022-2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
55 changes: 55 additions & 0 deletions python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# cython: profile=False
# distutils: language = c++
# cython: embedsignature = True
# cython: language_level = 3

import numpy as np

import pylibraft.common.handle

from cython.operator cimport dereference as deref
from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t, uintptr_t
from libcpp cimport bool, nullptr
from libcpp.string cimport string
from libcpp.vector cimport vector

from rmm._lib.memory_resource cimport device_memory_resource

from pylibraft.common.cpp.mdspan cimport (
device_matrix_view,
host_matrix_view,
make_device_matrix_view,
make_host_matrix_view,
row_major,
)
from pylibraft.common.cpp.optional cimport optional
from pylibraft.common.handle cimport device_resources
from pylibraft.distance.distance_type cimport DistanceType


cdef extern from "raft_runtime/neighbors/brute_force.hpp" \
namespace "raft::runtime::neighbors::brute_force" nogil:

cdef void knn(const device_resources & handle,
device_matrix_view[float, int64_t, row_major] index,
device_matrix_view[float, int64_t, row_major] search,
device_matrix_view[int64_t, int64_t, row_major] indices,
device_matrix_view[float, int64_t, row_major] distances,
DistanceType metric,
optional[float] metric_arg,
optional[int64_t] global_id_offset) except +
Loading