From 272f856a226590df4ccd1d5e04e524d014364637 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 19 Jan 2023 19:41:08 -0800 Subject: [PATCH] Use ivf_pq and ivf_flat from raft (#5119) We are removing the faiss ANN code in https://github.com/rapidsai/raft/pull/1121, in favour of using the ivf_flat and ivf_pq implementations included with raft. After this change, RAFT can be updated to remove the faiss ANN methods - which is implemented in https://github.com/rapidsai/raft/pull/1121 Note that this removes the `ivf_sq` option, since there is no corresponding implementation in raft. Closes #5131 Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/cuml/pull/5119 --- cpp/src_prims/selection/knn.cuh | 11 ---- python/cuml/neighbors/__init__.py | 4 -- python/cuml/neighbors/ann.pxd | 7 --- python/cuml/neighbors/ann.pyx | 35 +------------ .../cuml/neighbors/kneighbors_regressor.pyx | 4 -- python/cuml/neighbors/nearest_neighbors.pyx | 21 +------- python/cuml/tests/test_nearest_neighbors.py | 50 +++++-------------- 7 files changed, 16 insertions(+), 116 deletions(-) diff --git a/cpp/src_prims/selection/knn.cuh b/cpp/src_prims/selection/knn.cuh index c4500b0ed7..52ba52cc9e 100644 --- a/cpp/src_prims/selection/knn.cuh +++ b/cpp/src_prims/selection/knn.cuh @@ -25,17 +25,6 @@ #include #include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - #include #include diff --git a/python/cuml/neighbors/__init__.py b/python/cuml/neighbors/__init__.py index e277b1100c..14e392a34e 100644 --- a/python/cuml/neighbors/__init__.py +++ b/python/cuml/neighbors/__init__.py @@ -47,10 +47,6 @@ "l2", "euclidean", "sqeuclidean", "inner_product", "cosine", "correlation" ]), - "ivfsq": set([ - "l2", "euclidean", "sqeuclidean", - "inner_product", "cosine", "correlation" - ]) } VALID_METRICS_SPARSE = { diff --git a/python/cuml/neighbors/ann.pxd 
b/python/cuml/neighbors/ann.pxd index 1d667fba84..8819794b8f 100644 --- a/python/cuml/neighbors/ann.pxd +++ b/python/cuml/neighbors/ann.pxd @@ -51,10 +51,6 @@ cdef extern from "raft/spatial/knn/ann_common.h" \ int n_bits bool usePrecomputedTables - cdef cppclass IVFSQParam(IVFParam): - QuantizerType qtype - bool encodeResidual - cdef check_algo_params(algo, params) @@ -65,9 +61,6 @@ cdef build_ivfflat_algo_params(params, automated) cdef build_ivfpq_algo_params(params, automated, additional_info) -cdef build_ivfsq_algo_params(params, automated) - - cdef build_algo_params(algo, params, additional_info) diff --git a/python/cuml/neighbors/ann.pyx b/python/cuml/neighbors/ann.pyx index 7fe0a68a58..4a763b8ae2 100644 --- a/python/cuml/neighbors/ann.pyx +++ b/python/cuml/neighbors/ann.pyx @@ -32,9 +32,6 @@ cdef check_algo_params(algo, params): elif algo == "ivfpq": check_param_list(params, ['nlist', 'nprobe', 'M', 'n_bits', 'usePrecomputedTables']) - elif algo == "ivfsq": - check_param_list(params, ['nlist', 'nprobe', 'qtype', - 'encodeResidual']) cdef build_ivfflat_algo_params(params, automated): @@ -77,7 +74,7 @@ cdef build_ivfpq_algo_params(params, automated, additional_info): # n_bits should be in set {4, 5, 6, 8} since FAISS 1.7 params['n_bits'] = 4 - for n_bits in [5, 6, 8]: + for n_bits in [4, 5, 6, 8]: min_train_points = (2 ** n_bits) * 39 if N >= min_train_points: params['n_bits'] = n_bits @@ -92,33 +89,6 @@ cdef build_ivfpq_algo_params(params, automated, additional_info): return algo_params -cdef build_ivfsq_algo_params(params, automated): - cdef IVFSQParam* algo_params = new IVFSQParam() - if automated: - params = { - 'nlist': 8, - 'nprobe': 2, - 'qtype': 'QT_8bit', - 'encodeResidual': True - } - - quantizer_type = { - 'QT_8bit': QuantizerType.QT_8bit, - 'QT_4bit': QuantizerType.QT_4bit, - 'QT_8bit_uniform': QuantizerType.QT_8bit_uniform, - 'QT_4bit_uniform': QuantizerType.QT_4bit_uniform, - 'QT_fp16': QuantizerType.QT_fp16, - 'QT_8bit_direct': 
QuantizerType.QT_8bit_direct, - 'QT_6bit': QuantizerType.QT_6bit, - } - - algo_params.nlist = params['nlist'] - algo_params.nprobe = params['nprobe'] - algo_params.qtype = quantizer_type[params['qtype']] - algo_params.encodeResidual = params['encodeResidual'] - return algo_params - - cdef build_algo_params(algo, params, additional_info): automated = params is None or params == 'auto' if not automated: @@ -131,9 +101,6 @@ cdef build_algo_params(algo, params, additional_info): if algo == 'ivfpq': algo_params = \ build_ivfpq_algo_params(params, automated, additional_info) - elif algo == 'ivfsq': - algo_params = \ - build_ivfsq_algo_params(params, automated) return algo_params diff --git a/python/cuml/neighbors/kneighbors_regressor.pyx b/python/cuml/neighbors/kneighbors_regressor.pyx index 950cbda893..57cdb8f896 100644 --- a/python/cuml/neighbors/kneighbors_regressor.pyx +++ b/python/cuml/neighbors/kneighbors_regressor.pyx @@ -99,10 +99,6 @@ class KNeighborsRegressor(RegressorMixin, in n_features/M sub-vectors that will be encoded thanks to intermediary k-means clusterings. This encoding provide partial information allowing faster distances calculations - - ``'ivfsq'``: for inverted file and scalar quantization, - same as inverted list, in addition vectors components - are quantized into reduced binary representation allowing - faster distances calculations metric : string (default='euclidean'). Distance metric to use. weights : string (default='uniform') diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index 8b0ac5926a..bfee30c9bc 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -192,10 +192,6 @@ class NearestNeighbors(UniversalBase, in n_features/M sub-vectors that will be encoded thanks to intermediary k-means clusterings. 
This encoding provide partial information allowing faster distances calculations - - ``'ivfsq'``: for inverted file and scalar quantization, - same as inverted list, in addition vectors components - are quantized into reduced binary representation allowing - faster distances calculations metric : string (default='euclidean'). Distance metric to use. Supported distances are ['l1, 'cityblock', @@ -228,15 +224,6 @@ class NearestNeighbors(UniversalBase, - n_bits: (int) bits allocated per subquantizer - usePrecomputedTables : (bool) wether to use precomputed tables - Parameters for algorithm ``'ivfsq'``: - - - nlist: (int) number of cells to partition dataset into - - nprobe: (int) at query time, number of cells used for search - - qtype: (string) quantizer type (among QT_8bit, QT_4bit, - QT_8bit_uniform, QT_4bit_uniform, QT_fp16, QT_8bit_direct, - QT_6bit) - - encodeResidual: (bool) wether to encode residuals - metric_expanded : bool Can increase performance in Minkowski-based (Lp) metrics (for p > 1) by using the expanded form and not computing the n-th roots. @@ -397,7 +384,7 @@ class NearestNeighbors(UniversalBase, cdef handle_t* handle_ = self.handle.getHandle() cdef knnIndexParam* algo_params = 0 - if self._fit_method in ['ivfflat', 'ivfpq', 'ivfsq']: + if self._fit_method in ['ivfflat', 'ivfpq']: warnings.warn("\nWarning: Approximate Nearest Neighbor methods " "might be unstable in this version of cuML. " "This is due to a known issue in the FAISS " @@ -919,7 +906,7 @@ class NearestNeighbors(UniversalBase, kidx = self.__dict__['knn_index'] \ if 'knn_index' in self.__dict__ else None if kidx is not None: - if self._fit_method in ["ivfflat", "ivfpq", "ivfsq"]: + if self._fit_method in ["ivfflat", "ivfpq"]: knn_index = kidx del knn_index else: @@ -979,10 +966,6 @@ def kneighbors_graph(X=None, n_neighbors=5, mode='connectivity', verbose=False, in n_features/M sub-vectors that will be encoded thanks to intermediary k-means clusterings. 
This encoding provide partial information allowing faster distances calculations - - ``'ivfsq'``: for inverted file and scalar quantization, - same as inverted list, in addition vectors components - are quantized into reduced binary representation allowing - faster distances calculations metric : string (default='euclidean'). Distance metric to use. Supported distances are ['l1, 'cityblock', diff --git a/python/cuml/tests/test_nearest_neighbors.py b/python/cuml/tests/test_nearest_neighbors.py index 8c31ea328c..3a30af925a 100644 --- a/python/cuml/tests/test_nearest_neighbors.py +++ b/python/cuml/tests/test_nearest_neighbors.py @@ -118,7 +118,7 @@ def test_self_neighboring(datatype, metric_p, nrows): metric, p = metric_p if not has_scipy(): - pytest.skip('Skipping test_neighborhood_predictions because ' + + pytest.skip('Skipping test_self_neighboring because ' + 'Scipy is missing') X, y = make_blobs(n_samples=nrows, centers=n_clusters, @@ -166,8 +166,7 @@ def test_self_neighboring(datatype, metric_p, nrows): @pytest.mark.parametrize("algo,datatype", [("brute", "dataframe"), ("ivfflat", "numpy"), - ("ivfpq", "dataframe"), - ("ivfsq", "numpy")]) + ("ivfpq", "dataframe")]) def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters, datatype, algo): if not has_scipy(): @@ -206,7 +205,7 @@ def test_neighborhood_predictions(nrows, ncols, n_neighbors, n_clusters, def test_ivfflat_pred(nrows, ncols, n_neighbors, nlist): algo_params = { 'nlist': nlist, - 'nprobe': nlist * 0.25 + 'nprobe': nlist * 0.5 } X, y = make_blobs(n_samples=nrows, centers=5, @@ -257,39 +256,16 @@ def test_ivfpq_pred(nrows, ncols, n_neighbors, assert array_equal(labels, y) -@pytest.mark.parametrize("qtype,encodeResidual,nrows,ncols,n_neighbors,nlist", - [('QT_4bit', False, 10000, 128, 8, 4), - ('QT_8bit', True, 1000, 512, 7, 4), - ('QT_fp16', False, 3000, 301, 5, 8)]) -def test_ivfsq_pred(qtype, encodeResidual, nrows, ncols, n_neighbors, nlist): - algo_params = { - 'nlist': nlist, - 
'nprobe': nlist * 0.25, - 'qtype': qtype, - 'encodeResidual': encodeResidual - } - - X, y = make_blobs(n_samples=nrows, centers=5, - n_features=ncols, random_state=0) - - logger.set_level(logger.level_debug) - knn_cu = cuKNN(algorithm="ivfsq", algo_params=algo_params) - knn_cu.fit(X) - neigh_ind = knn_cu.kneighbors(X, n_neighbors=n_neighbors, - return_distance=False) - del knn_cu - gc.collect() - - labels, probs = predict(neigh_ind, y, n_neighbors) - - assert array_equal(labels, y) - - -@pytest.mark.parametrize("algo", ["brute", "ivfflat", "ivfpq", "ivfsq"]) -@pytest.mark.parametrize("metric", [ - "l2", "euclidean", "sqeuclidean", - "cosine", "correlation" -]) +@pytest.mark.parametrize( + "algo, metric", + [ + (algo, metric) + for algo in ["brute", "ivfflat", "ivfpq"] + for metric in ["l2", "euclidean", "sqeuclidean", "cosine", + "correlation"] + if metric in cuml.neighbors.VALID_METRICS[algo] + ], +) def test_ann_distances_metrics(algo, metric): X, y = make_blobs(n_samples=500, centers=2, n_features=128, random_state=0)