Skip to content

Commit

Permalink
Changes to NearestNeighbors to call 2d random ball cover (rapidsai#4003)
Browse files Browse the repository at this point in the history
This PR integrates the [random ball cover PoC](rapidsai/raft#213) into cuml's brute-force knn for executing the random ball cover algorithm for haversine distance.

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Paul Taylor (https://github.com/trxcllnt)

Approvers:
  - AJ Schmidt (https://github.com/ajschmidt8)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai#4003
  • Loading branch information
cjnolet authored Sep 29, 2021
1 parent ee4588d commit 0416194
Show file tree
Hide file tree
Showing 7 changed files with 170 additions and 14 deletions.
6 changes: 6 additions & 0 deletions ci/local/build.sh
100644 → 100755
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
#!/bin/bash

# Copyright (c) 2018-2021, NVIDIA CORPORATION.
##############################################
# cuML local build and test script for CI #
##############################################


GIT_DESCRIBE_TAG=`git describe --tags`
MINOR_VERSION=`echo $GIT_DESCRIBE_TAG | grep -o -E '([0-9]+\.[0-9]+)'`

Expand Down
11 changes: 11 additions & 0 deletions cpp/include/cuml/neighbors/knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <raft/linalg/distance_type.h>
#include <raft/spatial/knn/ann_common.h>
#include <raft/spatial/knn/ball_cover_common.h>

namespace raft {
class handle_t;
Expand Down Expand Up @@ -60,6 +61,16 @@ void brute_force_knn(const raft::handle_t& handle,
raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded,
float metric_arg = 2.0f);

void rbc_build_index(const raft::handle_t& handle,
raft::spatial::knn::BallCoverIndex<int64_t, float, uint32_t>& index);

void rbc_knn_query(const raft::handle_t& handle,
raft::spatial::knn::BallCoverIndex<int64_t, float, uint32_t>& index,
uint32_t k,
const float* search_items,
uint32_t n_search_items,
int64_t* out_inds,
float* out_dists);
/**
* @brief Flat C++ API function to build an approximate nearest neighbors index
* from an index array and a set of parameters.
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/knn/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include <raft/cuda_utils.cuh>
#include <raft/label/classlabels.cuh>
#include <raft/spatial/knn/ann.hpp>
#include <raft/spatial/knn/ball_cover.hpp>
#include <raft/spatial/knn/knn.hpp>
#include <rmm/device_uvector.hpp>

Expand Down Expand Up @@ -64,6 +65,24 @@ void brute_force_knn(const raft::handle_t& handle,
metric_arg);
}

void rbc_build_index(const raft::handle_t& handle,
raft::spatial::knn::BallCoverIndex<int64_t, float, uint32_t>& index)
{
raft::spatial::knn::rbc_build_index(handle, index);
}

void rbc_knn_query(const raft::handle_t& handle,
raft::spatial::knn::BallCoverIndex<int64_t, float, uint32_t>& index,
uint32_t k,
const float* search_items,
uint32_t n_search_items,
int64_t* out_inds,
float* out_dists)
{
raft::spatial::knn::rbc_knn_query(
handle, index, k, search_items, n_search_items, out_inds, out_dists);
}

void approx_knn_build_index(raft::handle_t& handle,
raft::spatial::knn::knnIndex* index,
raft::spatial::knn::knnIndexParam* params,
Expand Down
1 change: 0 additions & 1 deletion cpp/test/sg/hdbscan_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,6 @@ class HDBSCANTest : public ::testing::TestWithParam<HDBSCANInputs<T, IdxT>> {
protected:
HDBSCANInputs<T, IdxT> params;
IdxT* labels_ref;
int k;

double score;
};
Expand Down
6 changes: 5 additions & 1 deletion python/cuml/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@
"inner_product", "sqeuclidean",
"haversine"
]),
"rbc": set([
"euclidean", "haversine",
"l2"
]),
"ivfflat": set([
"l2", "euclidean", "sqeuclidean",
"inner_product", "cosine", "correlation"
Expand All @@ -45,7 +49,7 @@
"l2", "euclidean", "sqeuclidean",
"inner_product", "cosine", "correlation"
])
}
}

VALID_METRICS_SPARSE = {
"brute": set(["euclidean", "l2", "inner_product",
Expand Down
101 changes: 89 additions & 12 deletions python/cuml/neighbors/nearest_neighbors.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ from cython.operator cimport dereference as deref
from libcpp cimport bool
from libcpp.memory cimport shared_ptr

from libc.stdint cimport uintptr_t, int64_t
from libc.stdint cimport uintptr_t, int64_t, uint32_t
from libc.stdlib cimport calloc, malloc, free

from libcpp.vector cimport vector
Expand All @@ -63,6 +63,16 @@ cimport cuml.common.cuda
if has_scipy():
import scipy.sparse


cdef extern from "raft/spatial/knn/ball_cover_common.h" \
namespace "raft::spatial::knn":
cdef cppclass BallCoverIndex[int64_t, float, uint32_t]:
BallCoverIndex(const handle_t &handle,
float *X,
uint32_t n_rows,
uint32_t n_cols,
DistanceType metric) except +

cdef extern from "cuml/neighbors/knn.hpp" namespace "ML":
void brute_force_knn(
const handle_t &handle,
Expand All @@ -80,6 +90,21 @@ cdef extern from "cuml/neighbors/knn.hpp" namespace "ML":
float metric_arg
) except +

void rbc_build_index(
const handle_t &handle,
BallCoverIndex[int64_t, float, uint32_t] &index,
) except +

void rbc_knn_query(
const handle_t &handle,
BallCoverIndex[int64_t, float, uint32_t] &index,
uint32_t k,
float *search_items,
uint32_t n_search_items,
int64_t *out_inds,
float *out_dists
) except +

void approx_knn_build_index(
handle_t &handle,
knnIndex* index,
Expand All @@ -101,6 +126,7 @@ cdef extern from "cuml/neighbors/knn.hpp" namespace "ML":
int n
) except +


cdef extern from "cuml/neighbors/knn_sparse.hpp" namespace "ML::Sparse":
void brute_force_knn(handle_t &handle,
const int *idxIndptr,
Expand Down Expand Up @@ -148,6 +174,12 @@ class NearestNeighbors(Base,
algorithm : string (default='brute')
The query algorithm to use. Valid options are:
- ``'auto'``: to automatically select brute-force or
random ball cover based on data shape and metric
- ``'rbc'``: for the random ball algorithm, which partitions
the data space and uses the triangle inequality to lower the
number of potential distances. Currently, this algorithm
supports 2d Euclidean and Haversine.
- ``'brute'``: for brute-force, slow but produces exact results
- ``'ivfflat'``: for inverted file, divide the dataset in partitions
and perform search on relevant partitions only
Expand Down Expand Up @@ -299,7 +331,7 @@ class NearestNeighbors(Base,
n_neighbors=5,
verbose=False,
handle=None,
algorithm="brute",
algorithm="auto",
metric="euclidean",
p=2,
algo_params=None,
Expand All @@ -318,8 +350,10 @@ class NearestNeighbors(Base,
self.algo_params = algo_params
self.p = p
self.algorithm = algorithm
self.working_algorithm_ = self.algorithm
self.selected_algorithm_ = algorithm
self.algo_params = algo_params
self.knn_index = <uintptr_t> 0
self.knn_index = None

@generate_docstring(X='dense_sparse')
def fit(self, X, convert_dtype=True) -> "NearestNeighbors":
Expand All @@ -332,29 +366,44 @@ class NearestNeighbors(Base,

self.n_dims = X.shape[1]

if self.algorithm == "auto":
if self.n_dims == 2 and self.metric in \
cuml.neighbors.VALID_METRICS["rbc"]:
self.working_algorithm_ = "rbc"
else:
self.working_algorithm_ = "brute"

if self.algorithm == "rbc" and self.n_dims > 2:
raise ValueError("rbc algorithm currently only supports 2d data")

if is_sparse(X):
valid_metrics = cuml.neighbors.VALID_METRICS_SPARSE
value_metric_str = "_SPARSE"
self.X_m = SparseCumlArray(X, convert_to_dtype=cp.float32,
convert_format=False)
self.n_rows = self.X_m.shape[0]

else:
valid_metrics = cuml.neighbors.VALID_METRICS
valid_metric_str = ""
self.X_m, self.n_rows, n_cols, dtype = \
input_to_cuml_array(X, order='C', check_dtype=np.float32,
convert_to_dtype=(np.float32
if convert_dtype
else None))

if self.metric not in valid_metrics[self.algorithm]:
if self.metric not in \
valid_metrics[self.working_algorithm_]:
raise ValueError("Metric %s is not valid. "
"Use sorted(cuml.neighbors.VALID_METRICS[%s]) "
"Use sorted(cuml.neighbors.VALID_METRICS%s[%s]) "
"to get valid options." %
(self.metric, self.algorithm))
(valid_metric_str,
self.metric,
self.working_algorithm_))

cdef handle_t* handle_ = <handle_t*><uintptr_t> self.handle.getHandle()
cdef knnIndexParam* algo_params = <knnIndexParam*> 0
if self.algorithm in ['ivfflat', 'ivfpq', 'ivfsq']:
if self.working_algorithm_ in ['ivfflat', 'ivfpq', 'ivfsq']:
warnings.warn("\nWarning: Approximate Nearest Neighbor methods "
"might be unstable in this version of cuML. "
"This is due to a known issue in the FAISS "
Expand All @@ -370,7 +419,7 @@ class NearestNeighbors(Base,
knn_index = new knnIndex()
self.knn_index = <uintptr_t> knn_index
algo_params = <knnIndexParam*><uintptr_t> \
build_algo_params(self.algorithm, self.algo_params,
build_algo_params(self.working_algorithm_, self.algo_params,
additional_info)
metric = self._build_metric_type(self.metric)

Expand All @@ -387,6 +436,16 @@ class NearestNeighbors(Base,
destroy_algo_params(<uintptr_t>algo_params)

del self.X_m
elif self.working_algorithm_ == "rbc":
metric = self._build_metric_type(self.metric)

rbc_index = new BallCoverIndex[int64_t, float, uint32_t](
handle_[0], <float*><uintptr_t>self.X_m.ptr,
<uint32_t>self.n_rows, <uint32_t>n_cols,
<DistanceType>metric)
rbc_build_index(handle_[0],
deref(rbc_index))
self.knn_index = <uintptr_t>rbc_index

self.n_indices = 1
return self
Expand Down Expand Up @@ -654,8 +713,10 @@ class NearestNeighbors(Base,
cdef vector[float*] *inputs = new vector[float*]()
cdef vector[int] *sizes = new vector[int]()
cdef knnIndex* knn_index = <knnIndex*> 0
cdef BallCoverIndex[int64_t, float, uint32_t]* rbc_index = \
<BallCoverIndex[int64_t, float, uint32_t]*> 0

if self.algorithm == 'brute':
if self.working_algorithm_ == 'brute':
inputs.push_back(<float*><uintptr_t>self.X_m.ptr)
sizes.push_back(<int>self.X_m.shape[0])

Expand All @@ -675,6 +736,16 @@ class NearestNeighbors(Base,
# minkowski order is currently the only metric argument.
<float>self.p
)
elif self.working_algorithm_ == "rbc":
rbc_index = <BallCoverIndex[int64_t, float, uint32_t]*>\
<uintptr_t>self.knn_index
rbc_knn_query(handle_[0],
deref(rbc_index),
<uint32_t> n_neighbors,
<float*><uintptr_t>X_m.ptr,
<uint32_t> N,
<int64_t*>I_ptr,
<float*>D_ptr)
else:
knn_index = <knnIndex*><uintptr_t> self.knn_index
approx_knn_search(
Expand Down Expand Up @@ -826,9 +897,15 @@ class NearestNeighbors(Base,
return sparse_csr

def __del__(self):
cdef knnIndex* knn_index = <knnIndex*><uintptr_t>self.knn_index
if knn_index:
del knn_index
cdef knnIndex* knn_index = <knnIndex*>0
cdef BallCoverIndex* rbc_index = <BallCoverIndex*>0
if self.knn_index is not None:
if self.working_algorithm_ in ["ivfflat", "ivfpq", "ivfsq"]:
knn_index = <knnIndex*><uintptr_t>self.knn_index
del knn_index
else:
rbc_index = <BallCoverIndex*><uintptr_t>self.knn_index
del rbc_index


@cuml.internals.api_return_sparse_array()
Expand Down
40 changes: 40 additions & 0 deletions python/cuml/test/test_nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from cuml.datasets import make_blobs

from sklearn.metrics import pairwise_distances
from cuml.metrics import pairwise_distances as cuPW

from cuml.common import logger

Expand Down Expand Up @@ -511,6 +512,45 @@ def test_knn_graph(input_type, mode, output_type, as_instance,
assert isspmatrix_csr(sparse_cu)


@pytest.mark.parametrize('distance', ["euclidean", "haversine"])
@pytest.mark.parametrize('n_neighbors', [2, 12])
@pytest.mark.parametrize('nrows', [unit_param(1000), stress_param(70000)])
def test_nearest_neighbors_rbc(distance, n_neighbors, nrows):
X, y = make_blobs(n_samples=nrows,
n_features=2, random_state=0)

knn_cu = cuKNN(metric=distance, algorithm="rbc")
knn_cu.fit(X)

query_rows = int(nrows/2)

rbc_d, rbc_i = knn_cu.kneighbors(X[:query_rows, :],
n_neighbors=n_neighbors)

if distance == 'euclidean':
# Need to use unexpanded euclidean distance
pw_dists = cuPW(X, metric="l2")
brute_i = cp.argsort(pw_dists, axis=1)[:query_rows, :n_neighbors]
brute_d = cp.sort(pw_dists, axis=1)[:query_rows, :n_neighbors]
else:
knn_cu_brute = cuKNN(metric=distance, algorithm="brute")
knn_cu_brute.fit(X)

brute_d, brute_i = knn_cu_brute.kneighbors(
X[:query_rows, :], n_neighbors=n_neighbors)

cp.testing.assert_allclose(rbc_d, brute_d, atol=5e-2,
rtol=1e-3)
rbc_i = cp.sort(rbc_i, axis=1)
brute_i = cp.sort(brute_i, axis=1)

diff = rbc_i != brute_i

# Using a very small tolerance for subtle differences
# in indices that result from non-determinism
assert diff.ravel().sum() < 5


@pytest.mark.parametrize("metric", valid_metrics_sparse())
@pytest.mark.parametrize(
'nrows,ncols,density,n_neighbors,batch_size_index,batch_size_query',
Expand Down

0 comments on commit 0416194

Please sign in to comment.