Skip to content

Commit

Permalink
Use ivf_pq and ivf_flat from raft (#5119)
Browse files Browse the repository at this point in the history
We are removing the faiss ANN code in rapidsai/raft#1121, in favour of
using the ivf_flat and ivf_pq implementations included with raft.

After this change, RAFT can be updated to remove the faiss ANN methods  - which is implemented in
rapidsai/raft#1121

Note that this removes the `ivf_sq` option , since there is no corresponding implementation in raft.

Closes #5131

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #5119
  • Loading branch information
benfred authored Jan 20, 2023
1 parent bd899e1 commit 272f856
Show file tree
Hide file tree
Showing 7 changed files with 16 additions and 116 deletions.
11 changes: 0 additions & 11 deletions cpp/src_prims/selection/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,6 @@
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <faiss/gpu/GpuDistance.h>
#include <faiss/gpu/GpuIndexFlat.h>
#include <faiss/gpu/GpuIndexIVFFlat.h>
#include <faiss/gpu/GpuIndexIVFPQ.h>
#include <faiss/gpu/GpuIndexIVFScalarQuantizer.h>
#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/Limits.cuh>
#include <faiss/gpu/utils/Select.cuh>
#include <faiss/gpu/utils/Tensor.cuh>
#include <faiss/utils/Heap.h>

#include <thrust/device_vector.h>
#include <thrust/iterator/transform_iterator.h>

Expand Down
4 changes: 0 additions & 4 deletions python/cuml/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,6 @@
"l2", "euclidean", "sqeuclidean",
"inner_product", "cosine", "correlation"
]),
"ivfsq": set([
"l2", "euclidean", "sqeuclidean",
"inner_product", "cosine", "correlation"
])
}

VALID_METRICS_SPARSE = {
Expand Down
7 changes: 0 additions & 7 deletions python/cuml/neighbors/ann.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ cdef extern from "raft/spatial/knn/ann_common.h" \
int n_bits
bool usePrecomputedTables

cdef cppclass IVFSQParam(IVFParam):
QuantizerType qtype
bool encodeResidual


cdef check_algo_params(algo, params)

Expand All @@ -65,9 +61,6 @@ cdef build_ivfflat_algo_params(params, automated)
cdef build_ivfpq_algo_params(params, automated, additional_info)


cdef build_ivfsq_algo_params(params, automated)


cdef build_algo_params(algo, params, additional_info)


Expand Down
35 changes: 1 addition & 34 deletions python/cuml/neighbors/ann.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@ cdef check_algo_params(algo, params):
elif algo == "ivfpq":
check_param_list(params, ['nlist', 'nprobe', 'M', 'n_bits',
'usePrecomputedTables'])
elif algo == "ivfsq":
check_param_list(params, ['nlist', 'nprobe', 'qtype',
'encodeResidual'])


cdef build_ivfflat_algo_params(params, automated):
Expand Down Expand Up @@ -77,7 +74,7 @@ cdef build_ivfpq_algo_params(params, automated, additional_info):

# n_bits should be in set {4, 5, 6, 8} since FAISS 1.7
params['n_bits'] = 4
for n_bits in [5, 6, 8]:
for n_bits in [4, 5, 6, 8]:
min_train_points = (2 ** n_bits) * 39
if N >= min_train_points:
params['n_bits'] = n_bits
Expand All @@ -92,33 +89,6 @@ cdef build_ivfpq_algo_params(params, automated, additional_info):
return <uintptr_t>algo_params


cdef build_ivfsq_algo_params(params, automated):
cdef IVFSQParam* algo_params = new IVFSQParam()
if automated:
params = {
'nlist': 8,
'nprobe': 2,
'qtype': 'QT_8bit',
'encodeResidual': True
}

quantizer_type = {
'QT_8bit': <int> QuantizerType.QT_8bit,
'QT_4bit': <int> QuantizerType.QT_4bit,
'QT_8bit_uniform': <int> QuantizerType.QT_8bit_uniform,
'QT_4bit_uniform': <int> QuantizerType.QT_4bit_uniform,
'QT_fp16': <int> QuantizerType.QT_fp16,
'QT_8bit_direct': <int> QuantizerType.QT_8bit_direct,
'QT_6bit': <int> QuantizerType.QT_6bit,
}

algo_params.nlist = <int> params['nlist']
algo_params.nprobe = <int> params['nprobe']
algo_params.qtype = <QuantizerType> quantizer_type[params['qtype']]
algo_params.encodeResidual = <bool> params['encodeResidual']
return <uintptr_t>algo_params


cdef build_algo_params(algo, params, additional_info):
automated = params is None or params == 'auto'
if not automated:
Expand All @@ -131,9 +101,6 @@ cdef build_algo_params(algo, params, additional_info):
if algo == 'ivfpq':
algo_params = <knnIndexParam*><uintptr_t> \
build_ivfpq_algo_params(params, automated, additional_info)
elif algo == 'ivfsq':
algo_params = <knnIndexParam*><uintptr_t> \
build_ivfsq_algo_params(params, automated)

return <uintptr_t>algo_params

Expand Down
4 changes: 0 additions & 4 deletions python/cuml/neighbors/kneighbors_regressor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,6 @@ class KNeighborsRegressor(RegressorMixin,
in n_features/M sub-vectors that will be encoded thanks
to intermediary k-means clusterings. This encoding provide
partial information allowing faster distances calculations
- ``'ivfsq'``: for inverted file and scalar quantization,
same as inverted list, in addition vectors components
are quantized into reduced binary representation allowing
faster distances calculations
metric : string (default='euclidean').
Distance metric to use.
weights : string (default='uniform')
Expand Down
21 changes: 2 additions & 19 deletions python/cuml/neighbors/nearest_neighbors.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,6 @@ class NearestNeighbors(UniversalBase,
in n_features/M sub-vectors that will be encoded thanks
to intermediary k-means clusterings. This encoding provide
partial information allowing faster distances calculations
- ``'ivfsq'``: for inverted file and scalar quantization,
same as inverted list, in addition vectors components
are quantized into reduced binary representation allowing
faster distances calculations
metric : string (default='euclidean').
Distance metric to use. Supported distances are ['l1, 'cityblock',
Expand Down Expand Up @@ -228,15 +224,6 @@ class NearestNeighbors(UniversalBase,
- n_bits: (int) bits allocated per subquantizer
- usePrecomputedTables : (bool) wether to use precomputed tables
Parameters for algorithm ``'ivfsq'``:
- nlist: (int) number of cells to partition dataset into
- nprobe: (int) at query time, number of cells used for search
- qtype: (string) quantizer type (among QT_8bit, QT_4bit,
QT_8bit_uniform, QT_4bit_uniform, QT_fp16, QT_8bit_direct,
QT_6bit)
- encodeResidual: (bool) wether to encode residuals
metric_expanded : bool
Can increase performance in Minkowski-based (Lp) metrics (for p > 1)
by using the expanded form and not computing the n-th roots.
Expand Down Expand Up @@ -397,7 +384,7 @@ class NearestNeighbors(UniversalBase,

cdef handle_t* handle_ = <handle_t*><uintptr_t> self.handle.getHandle()
cdef knnIndexParam* algo_params = <knnIndexParam*> 0
if self._fit_method in ['ivfflat', 'ivfpq', 'ivfsq']:
if self._fit_method in ['ivfflat', 'ivfpq']:
warnings.warn("\nWarning: Approximate Nearest Neighbor methods "
"might be unstable in this version of cuML. "
"This is due to a known issue in the FAISS "
Expand Down Expand Up @@ -919,7 +906,7 @@ class NearestNeighbors(UniversalBase,
kidx = self.__dict__['knn_index'] \
if 'knn_index' in self.__dict__ else None
if kidx is not None:
if self._fit_method in ["ivfflat", "ivfpq", "ivfsq"]:
if self._fit_method in ["ivfflat", "ivfpq"]:
knn_index = <knnIndex*><uintptr_t>kidx
del knn_index
else:
Expand Down Expand Up @@ -979,10 +966,6 @@ def kneighbors_graph(X=None, n_neighbors=5, mode='connectivity', verbose=False,
in n_features/M sub-vectors that will be encoded thanks
to intermediary k-means clusterings. This encoding provide
partial information allowing faster distances calculations
- ``'ivfsq'``: for inverted file and scalar quantization,
same as inverted list, in addition vectors components
are quantized into reduced binary representation allowing
faster distances calculations
metric : string (default='euclidean').
Distance metric to use. Supported distances are ['l1, 'cityblock',
Expand Down
50 changes: 13 additions & 37 deletions python/cuml/tests/test_nearest_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_self_neighboring(datatype, metric_p, nrows):
metric, p = metric_p

if not has_scipy():
pytest.skip('Skipping test_neighborhood_predictions because ' +
pytest.skip('Skipping test_self_neighboring because ' +
'Scipy is missing')

X, y = make_blobs(n_samples=nrows, centers=n_clusters,
Expand Down Expand Up @@ -166,8 +166,7 @@ def test_self_neighboring(datatype, metric_p, nrows):
@pytest.mark.parametrize("algo,datatype",
[("brute", "dataframe"),
("ivfflat", "numpy"),
("ivfpq", "dataframe"),
("ivfsq", "numpy")])
("ivfpq", "dataframe")])
def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters,
datatype, algo):
if not has_scipy():
Expand Down Expand Up @@ -206,7 +205,7 @@ def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters,
def test_ivfflat_pred(nrows, ncols, n_neighbors, nlist):
algo_params = {
'nlist': nlist,
'nprobe': nlist * 0.25
'nprobe': nlist * 0.5
}

X, y = make_blobs(n_samples=nrows, centers=5,
Expand Down Expand Up @@ -257,39 +256,16 @@ def test_ivfpq_pred(nrows, ncols, n_neighbors,
assert array_equal(labels, y)


@pytest.mark.parametrize("qtype,encodeResidual,nrows,ncols,n_neighbors,nlist",
[('QT_4bit', False, 10000, 128, 8, 4),
('QT_8bit', True, 1000, 512, 7, 4),
('QT_fp16', False, 3000, 301, 5, 8)])
def test_ivfsq_pred(qtype, encodeResidual, nrows, ncols, n_neighbors, nlist):
algo_params = {
'nlist': nlist,
'nprobe': nlist * 0.25,
'qtype': qtype,
'encodeResidual': encodeResidual
}

X, y = make_blobs(n_samples=nrows, centers=5,
n_features=ncols, random_state=0)

logger.set_level(logger.level_debug)
knn_cu = cuKNN(algorithm="ivfsq", algo_params=algo_params)
knn_cu.fit(X)
neigh_ind = knn_cu.kneighbors(X, n_neighbors=n_neighbors,
return_distance=False)
del knn_cu
gc.collect()

labels, probs = predict(neigh_ind, y, n_neighbors)

assert array_equal(labels, y)


@pytest.mark.parametrize("algo", ["brute", "ivfflat", "ivfpq", "ivfsq"])
@pytest.mark.parametrize("metric", [
"l2", "euclidean", "sqeuclidean",
"cosine", "correlation"
])
@pytest.mark.parametrize(
"algo, metric",
[
(algo, metric)
for algo in ["brute", "ivfflat", "ivfpq"]
for metric in ["l2", "euclidean", "sqeuclidean", "cosine",
"correlation"]
if metric in cuml.neighbors.VALID_METRICS[algo]
],
)
def test_ann_distances_metrics(algo, metric):
X, y = make_blobs(n_samples=500, centers=2,
n_features=128, random_state=0)
Expand Down

0 comments on commit 272f856

Please sign in to comment.