From 9f746409e67b4f7b979ed381e3be9d0ac30bb9dc Mon Sep 17 00:00:00 2001
From: Matthijs Douze
Date: Thu, 4 May 2023 09:59:06 -0700
Subject: [PATCH] IVF sorting routine (#2846)

Summary:
Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2846

Adds a function to contrib.ivf_tools that sorts the inverted lists by size
without changing the search results. Also moves big_batch_search to its own
module.

Reviewed By: algoriddle

Differential Revision: D45565880

fbshipit-source-id: 091a1c1c074f860d6953bf20d04523292fb55e1a
---
 contrib/README.md                  |   5 +
 contrib/big_batch_search.py        | 459 ++++++++++++++++++++++++++++
 contrib/ivf_tools.py               | 465 ++---------------------------
 faiss/IndexFlatCodes.cpp           |  11 +
 faiss/IndexFlatCodes.h             |   3 +
 faiss/IndexHNSW.cpp                |   8 +
 faiss/IndexHNSW.h                  |   2 +
 faiss/gpu/test/test_contrib_gpu.py |   6 +-
 faiss/impl/HNSW.cpp                |  33 ++
 faiss/impl/HNSW.h                  |   2 +
 faiss/invlists/InvertedLists.cpp   |  14 +
 faiss/invlists/InvertedLists.h     |   3 +
 faiss/python/class_wrappers.py     |   8 +
 tests/test_contrib.py              |  44 ++-
 14 files changed, 613 insertions(+), 450 deletions(-)
 create mode 100644 contrib/big_batch_search.py

diff --git a/contrib/README.md b/contrib/README.md
index 4bde6e73ad..f2b7d0f845 100644
--- a/contrib/README.md
+++ b/contrib/README.md
@@ -69,3 +69,8 @@ Contains:
 - a Python implementation of kmeans, that can be used for special datatypes (eg. sparse matrices).
 
 - a 2-level clustering routine and a function that can apply it to train an IndexIVF
+
+### big_batch_search.py
+
+Search IVF indexes one centroid after another. Useful when the database is
+too large to fit in RAM *and* the number of queries is large.
diff --git a/contrib/big_batch_search.py b/contrib/big_batch_search.py
new file mode 100644
index 0000000000..ce769d0f60
--- /dev/null
+++ b/contrib/big_batch_search.py
@@ -0,0 +1,459 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+ +import time +import pickle +import os +from multiprocessing.pool import ThreadPool +import threading + +import numpy as np +import faiss + +from faiss.contrib.inspect_tools import get_invlist + + +class BigBatchSearcher: + """ + Object that manages all the data related to the computation + except the actual within-bucket matching and the organization of the + computation (parallel or not) + """ + + def __init__( + self, + index, xq, k, + verbose=0, + use_float16=False): + + # verbosity + self.verbose = verbose + self.tictoc = [] + + self.xq = xq + self.index = index + self.use_float16 = use_float16 + keep_max = faiss.is_similarity_metric(index.metric_type) + self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max) + self.t_accu = [0] * 5 + self.t_display = self.t0 = time.time() + + def start_t_accu(self): + self.t_accu_t0 = time.time() + + def stop_t_accu(self, n): + self.t_accu[n] += time.time() - self.t_accu_t0 + + def tic(self, name): + self.tictoc = (name, time.time()) + if self.verbose > 0: + print(name, end="\r", flush=True) + + def toc(self): + name, t0 = self.tictoc + dt = time.time() - t0 + if self.verbose > 0: + print(f"{name}: {dt:.3f} s") + return dt + + def report(self, l): + if self.verbose == 1 or ( + l > 1000 and time.time() < self.t_display + 1.0): + return + print( + f"[{time.time()-self.t0:.1f} s] list {l}/{self.index.nlist} " + f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} " + f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} " + f"wait {self.t_accu[4]:.3f}", + end="\r", flush=True + ) + self.t_display = time.time() + + def coarse_quantization(self): + self.tic("coarse quantization") + bs = 65536 + nq = len(self.xq) + q_assign = np.empty((nq, self.index.nprobe), dtype='int32') + for i0 in range(0, nq, bs): + i1 = min(nq, i0 + bs) + q_dis_i, q_assign_i = self.index.quantizer.search( + self.xq[i0:i1], self.index.nprobe) + # q_dis[i0:i1] = q_dis_i + q_assign[i0:i1] = q_assign_i + self.toc() + self.q_assign = q_assign + + def reorder_assign(self): + self.tic("bucket sort") + q_assign = self.q_assign + q_assign += 1 # move -1 -> 0 + self.bucket_lims = faiss.matrix_bucket_sort_inplace( + self.q_assign, nbucket=self.index.nlist + 1, nt=16) + self.query_ids = self.q_assign.ravel() + if self.verbose > 0: + print(' number of -1s:', self.bucket_lims[1]) + self.bucket_lims = self.bucket_lims[1:] # shift back to ignore -1s + del self.q_assign # inplace so let's forget about the old version... 
+ self.toc() + + def prepare_bucket(self, l): + """ prepare the queries and database items for bucket l""" + t0 = time.time() + index = self.index + # prepare queries + i0, i1 = self.bucket_lims[l], self.bucket_lims[l + 1] + q_subset = self.query_ids[i0:i1] + xq_l = self.xq[q_subset] + if self.by_residual: + xq_l = xq_l - index.quantizer.reconstruct(l) + t1 = time.time() + # prepare database side + list_ids, xb_l = get_invlist(index.invlists, l) + + if self.decode_func is None: + xb_l = xb_l.ravel() + else: + xb_l = self.decode_func(xb_l) + + if self.use_float16: + xb_l = xb_l.astype('float16') + xq_l = xq_l.astype('float16') + + t2 = time.time() + self.t_accu[0] += t1 - t0 + self.t_accu[1] += t2 - t1 + return q_subset, xq_l, list_ids, xb_l + + def add_results_to_heap(self, q_subset, D, list_ids, I): + """add the bucket results to the heap structure""" + if D is None: + return + t0 = time.time() + if I is None: + I = list_ids + else: + I = list_ids[I] + self.rh.add_result_subset(q_subset, D, I) + self.t_accu[3] += time.time() - t0 + + def sizes_in_checkpoint(self): + return (self.xq.shape, self.index.nprobe, self.index.nlist) + + def write_checkpoint(self, fname, cur_list_no): + # write to temp file then move to final file + tmpname = fname + ".tmp" + pickle.dump( + { + "sizes": self.sizes_in_checkpoint(), + "cur_list_no": cur_list_no, + "rh": (self.rh.D, self.rh.I), + }, open(tmpname, "wb"), -1 + ) + os.replace(tmpname, fname) + + def read_checkpoint(self, fname): + ckp = pickle.load(open(fname, "rb")) + assert ckp["sizes"] == self.sizes_in_checkpoint() + self.rh.D[:] = ckp["rh"][0] + self.rh.I[:] = ckp["rh"][1] + return ckp["cur_list_no"] + + +class BlockComputer: + """ computation within one bucket """ + + def __init__( + self, + index, + method="knn_function", + pairwise_distances=faiss.pairwise_distances, + knn=faiss.knn): + + self.index = index + if index.__class__ == faiss.IndexIVFFlat: + index_help = faiss.IndexFlat(index.d, index.metric_type) + decode_func = lambda x: x.view("float32") + by_residual = False + elif index.__class__ == faiss.IndexIVFPQ: + index_help = faiss.IndexPQ( + index.d, index.pq.M, index.pq.nbits, index.metric_type) + index_help.pq = index.pq + decode_func = index_help.pq.decode + index_help.is_trained = True + by_residual = index.by_residual + elif index.__class__ == faiss.IndexIVFScalarQuantizer: + index_help = faiss.IndexScalarQuantizer( + index.d, index.sq.qtype, index.metric_type) + index_help.sq = index.sq + decode_func = index_help.sq.decode + index_help.is_trained = True + by_residual = index.by_residual + else: + raise RuntimeError(f"index type {index.__class__} not supported") + self.index_help = index_help + self.decode_func = None if method == "index" else decode_func + self.by_residual = by_residual + self.method = method + self.pairwise_distances = pairwise_distances + self.knn = knn + + def block_search(self, xq_l, xb_l, list_ids, k, **extra_args): + metric_type = self.index.metric_type + if xq_l.size == 0 or xb_l.size == 0: + D = I = None + elif self.method == "index": + faiss.copy_array_to_vector(xb_l, self.index_help.codes) + self.index_help.ntotal = len(list_ids) + D, I = self.index_help.search(xq_l, k) + elif self.method == "pairwise_distances": + # TODO implement blockwise to avoid mem blowup + D = self.pairwise_distances(xq_l, xb_l, metric=metric_type) + I = None + elif self.method == "knn_function": + D, I = self.knn(xq_l, xb_l, k, metric=metric_type, **extra_args) + + return D, I + + +def big_batch_search( + index, xq, k, + 
method="knn_function", + pairwise_distances=faiss.pairwise_distances, + knn=faiss.knn, + verbose=0, + threaded=0, + use_float16=False, + prefetch_threads=8, + computation_threads=0, + q_assign=None, + checkpoint=None, + checkpoint_freq=64, + start_list=0, + end_list=None, + crash_at=-1 + ): + """ + Search queries xq in the IVF index, with a search function that collects + batches of query vectors per inverted list. This can be faster than the + regular search indexes. + Supports IVFFlat, IVFPQ and IVFScalarQuantizer. + + Supports three computation methods: + method = "index": + build a flat index and populate it separately for each index + method = "pairwise_distances": + decompress codes and compute all pairwise distances for the queries + and index and add result to heap + method = "knn_function": + decompress codes and compute knn results for the queries + + threaded=0: sequential execution + threaded=1: prefetch next bucket while computing the current one + threaded>1: prefetch this many buckets at a time. + + compute_threads>1: the knn function will get an additional thread_no that + tells which worker should handle this. + + In threaded mode, the computation is tiled with the bucket perparation and + the writeback of results (useful to maximize GPU utilization). + + use_float16: convert all matrices to float16 (faster for GPU gemm) + + q_assign: override coarse assignment, should be a matrix of size nq * nprobe + + checkpointing (only for threaded > 1): + checkpoint: file where the checkpoints are stored + checkpoint_freq: when to perform checkpoinging. Should be a multiple of threaded + + start_list, end_list: process only a subset of invlists + """ + nprobe = index.nprobe + + assert method in ("index", "pairwise_distances", "knn_function") + + mem_queries = xq.nbytes + mem_assign = len(xq) * nprobe * np.dtype('int32').itemsize + mem_res = len(xq) * k * ( + np.dtype('int64').itemsize + + np.dtype('float32').itemsize + ) + mem_tot = mem_queries + mem_assign + mem_res + if verbose > 0: + print( + f"memory: queries {mem_queries} assign {mem_assign} " + f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB" + ) + + bbs = BigBatchSearcher( + index, xq, k, + verbose=verbose, + use_float16=use_float16 + ) + + comp = BlockComputer( + index, + method=method, + pairwise_distances=pairwise_distances, + knn=knn + ) + + bbs.decode_func = comp.decode_func + bbs.by_residual = comp.by_residual + + if q_assign is None: + bbs.coarse_quantization() + else: + bbs.q_assign = q_assign + bbs.reorder_assign() + + if end_list is None: + end_list = index.nlist + + if checkpoint is not None: + assert (start_list, end_list) == (0, index.nlist) + if os.path.exists(checkpoint): + print("recovering checkpoint", checkpoint) + start_list = bbs.read_checkpoint(checkpoint) + print(" start at list", start_list) + else: + print("no checkpoint: starting from scratch") + + if threaded == 0: + # simple sequential version + + for l in range(start_list, end_list): + bbs.report(l) + q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l) + t0i = time.time() + D, I = comp.block_search(xq_l, xb_l, list_ids, k) + bbs.t_accu[2] += time.time() - t0i + bbs.add_results_to_heap(q_subset, D, list_ids, I) + + elif threaded == 1: + + # parallel version with granularity 1 + + def add_results_and_prefetch(to_add, l): + """ perform the addition for the previous bucket and + prefetch the next (if applicable) """ + if to_add is not None: + bbs.add_results_to_heap(*to_add) + if l < index.nlist: + return bbs.prepare_bucket(l) + + 
prefetched_bucket = bbs.prepare_bucket(start_list) + to_add = None + pool = ThreadPool(1) + + for l in range(start_list, end_list): + bbs.report(l) + prefetched_bucket_a = pool.apply_async( + add_results_and_prefetch, (to_add, l + 1)) + q_subset, xq_l, list_ids, xb_l = prefetched_bucket + bbs.start_t_accu() + D, I = comp.block_search(xq_l, xb_l, list_ids, k) + bbs.stop_t_accu(2) + to_add = q_subset, D, list_ids, I + bbs.start_t_accu() + prefetched_bucket = prefetched_bucket_a.get() + bbs.stop_t_accu(4) + + bbs.add_results_to_heap(*to_add) + pool.close() + else: + # run by batches with parallel prefetch and parallel comp + list_step = threaded + assert start_list % list_step == 0 + + if prefetch_threads == 0: + prefetch_map = map + else: + prefetch_pool = ThreadPool(prefetch_threads) + prefetch_map = prefetch_pool.map + + if computation_threads > 0: + comp_pool = ThreadPool(computation_threads) + + def add_results_and_prefetch_batch(to_add, l): + def add_results(to_add): + for ta in to_add: # this one cannot be run in parallel... + if ta is not None: + bbs.add_results_to_heap(*ta) + if prefetch_threads == 0: + add_results(to_add) + else: + add_a = prefetch_pool.apply_async(add_results, (to_add, )) + next_lists = range(l, min(l + list_step, index.nlist)) + res = list(prefetch_map(bbs.prepare_bucket, next_lists)) + if prefetch_threads > 0: + add_a.get() + return res + + # used only when computation_threads > 1 + thread_id_to_seq_lock = threading.Lock() + thread_id_to_seq = {} + + def do_comp(bucket): + (q_subset, xq_l, list_ids, xb_l) = bucket + try: + tid = thread_id_to_seq[threading.get_ident()] + except KeyError: + with thread_id_to_seq_lock: + tid = len(thread_id_to_seq) + thread_id_to_seq[threading.get_ident()] = tid + D, I = comp.block_search(xq_l, xb_l, list_ids, k, thread_id=tid) + return q_subset, D, list_ids, I + + prefetched_buckets = add_results_and_prefetch_batch([], start_list) + to_add = [] + pool = ThreadPool(1) + prefetched_buckets_a = None + + # loop over inverted lists + for l in range(start_list, end_list, list_step): + bbs.report(l) + buckets = prefetched_buckets + prefetched_buckets_a = pool.apply_async( + add_results_and_prefetch_batch, (to_add, l + list_step)) + + bbs.start_t_accu() + + to_add = [] + if computation_threads == 0: + for q_subset, xq_l, list_ids, xb_l in buckets: + D, I = comp.block_search(xq_l, xb_l, list_ids, k) + to_add.append((q_subset, D, list_ids, I)) + else: + to_add = list(comp_pool.map(do_comp, buckets)) + + bbs.stop_t_accu(2) + + # to test checkpointing + if l == crash_at: + 1 / 0 + + bbs.start_t_accu() + prefetched_buckets = prefetched_buckets_a.get() + bbs.stop_t_accu(4) + + if checkpoint is not None: + if (l // list_step) % checkpoint_freq == 0: + print("writing checkpoint %s" % l) + bbs.write_checkpoint(checkpoint, l) + + # flush add + for ta in to_add: + bbs.add_results_to_heap(*ta) + pool.close() + if prefetch_threads != 0: + prefetch_pool.close() + if computation_threads != 0: + comp_pool.close() + + bbs.tic("finalize heap") + bbs.rh.finalize() + bbs.toc() + + return bbs.rh.D, bbs.rh.I diff --git a/contrib/ivf_tools.py b/contrib/ivf_tools.py index fc1346a739..3ea8827315 100644 --- a/contrib/ivf_tools.py +++ b/contrib/ivf_tools.py @@ -3,16 +3,10 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
-import time -import pickle -import os -from multiprocessing.pool import ThreadPool -import threading - import numpy as np import faiss -from faiss.contrib.inspect_tools import get_invlist +from faiss.contrib.inspect_tools import get_invlist_sizes def add_preassigned(index_ivf, x, a, ids=None): @@ -120,445 +114,30 @@ def replace_ivf_quantizer(index_ivf, new_quantizer): return old_quantizer -class BigBatchSearcher: - """ - Object that manages all the data related to the computation - except the actual within-bucket matching and the organization of the - computation (parallel or not) - """ - - def __init__( - self, - index, xq, k, - verbose=0, - use_float16=False): - - # verbosity - self.verbose = verbose - self.tictoc = [] - - self.xq = xq - self.index = index - self.use_float16 = use_float16 - keep_max = faiss.is_similarity_metric(index.metric_type) - self.rh = faiss.ResultHeap(len(xq), k, keep_max=keep_max) - self.t_accu = [0] * 5 - self.t_display = self.t0 = time.time() - - def start_t_accu(self): - self.t_accu_t0 = time.time() - - def stop_t_accu(self, n): - self.t_accu[n] += time.time() - self.t_accu_t0 - - def tic(self, name): - self.tictoc = (name, time.time()) - if self.verbose > 0: - print(name, end="\r", flush=True) - - def toc(self): - name, t0 = self.tictoc - dt = time.time() - t0 - if self.verbose > 0: - print(f"{name}: {dt:.3f} s") - return dt - - def report(self, l): - if self.verbose == 1 or ( - l > 1000 and time.time() < self.t_display + 1.0): - return - print( - f"[{time.time()-self.t0:.1f} s] list {l}/{self.index.nlist} " - f"times prep q {self.t_accu[0]:.3f} prep b {self.t_accu[1]:.3f} " - f"comp {self.t_accu[2]:.3f} res {self.t_accu[3]:.3f} " - f"wait {self.t_accu[4]:.3f}", - end="\r", flush=True - ) - self.t_display = time.time() - - def coarse_quantization(self): - self.tic("coarse quantization") - bs = 65536 - nq = len(self.xq) - q_assign = np.empty((nq, self.index.nprobe), dtype='int32') - for i0 in range(0, nq, bs): - i1 = min(nq, i0 + bs) - q_dis_i, q_assign_i = self.index.quantizer.search( - self.xq[i0:i1], self.index.nprobe) - # q_dis[i0:i1] = q_dis_i - q_assign[i0:i1] = q_assign_i - self.toc() - self.q_assign = q_assign - - def reorder_assign(self): - self.tic("bucket sort") - q_assign = self.q_assign - q_assign += 1 # move -1 -> 0 - self.bucket_lims = faiss.matrix_bucket_sort_inplace( - self.q_assign, nbucket=self.index.nlist + 1, nt=16) - self.query_ids = self.q_assign.ravel() - if self.verbose > 0: - print(' number of -1s:', self.bucket_lims[1]) - self.bucket_lims = self.bucket_lims[1:] # shift back to ignore -1s - del self.q_assign # inplace so let's forget about the old version... 
- self.toc() - - def prepare_bucket(self, l): - """ prepare the queries and database items for bucket l""" - t0 = time.time() - index = self.index - # prepare queries - i0, i1 = self.bucket_lims[l], self.bucket_lims[l + 1] - q_subset = self.query_ids[i0:i1] - xq_l = self.xq[q_subset] - if self.by_residual: - xq_l = xq_l - index.quantizer.reconstruct(l) - t1 = time.time() - # prepare database side - list_ids, xb_l = get_invlist(index.invlists, l) - - if self.decode_func is None: - xb_l = xb_l.ravel() - else: - xb_l = self.decode_func(xb_l) - - if self.use_float16: - xb_l = xb_l.astype('float16') - xq_l = xq_l.astype('float16') - - t2 = time.time() - self.t_accu[0] += t1 - t0 - self.t_accu[1] += t2 - t1 - return q_subset, xq_l, list_ids, xb_l - - def add_results_to_heap(self, q_subset, D, list_ids, I): - """add the bucket results to the heap structure""" - if D is None: - return - t0 = time.time() - if I is None: - I = list_ids - else: - I = list_ids[I] - self.rh.add_result_subset(q_subset, D, I) - self.t_accu[3] += time.time() - t0 - - def sizes_in_checkpoint(self): - return (self.xq.shape, self.index.nprobe, self.index.nlist) - - def write_checkpoint(self, fname, cur_list_no): - # write to temp file then move to final file - tmpname = fname + ".tmp" - pickle.dump( - { - "sizes": self.sizes_in_checkpoint(), - "cur_list_no": cur_list_no, - "rh": (self.rh.D, self.rh.I), - }, open(tmpname, "wb"), -1 - ) - os.replace(tmpname, fname) - - def read_checkpoint(self, fname): - ckp = pickle.load(open(fname, "rb")) - assert ckp["sizes"] == self.sizes_in_checkpoint() - self.rh.D[:] = ckp["rh"][0] - self.rh.I[:] = ckp["rh"][1] - return ckp["cur_list_no"] - - -class BlockComputer: - """ computation within one bucket """ - - def __init__( - self, - index, - method="knn_function", - pairwise_distances=faiss.pairwise_distances, - knn=faiss.knn): - - self.index = index - if index.__class__ == faiss.IndexIVFFlat: - index_help = faiss.IndexFlat(index.d, index.metric_type) - decode_func = lambda x: x.view("float32") - by_residual = False - elif index.__class__ == faiss.IndexIVFPQ: - index_help = faiss.IndexPQ( - index.d, index.pq.M, index.pq.nbits, index.metric_type) - index_help.pq = index.pq - decode_func = index_help.pq.decode - index_help.is_trained = True - by_residual = index.by_residual - elif index.__class__ == faiss.IndexIVFScalarQuantizer: - index_help = faiss.IndexScalarQuantizer( - index.d, index.sq.qtype, index.metric_type) - index_help.sq = index.sq - decode_func = index_help.sq.decode - index_help.is_trained = True - by_residual = index.by_residual - else: - raise RuntimeError(f"index type {index.__class__} not supported") - self.index_help = index_help - self.decode_func = None if method == "index" else decode_func - self.by_residual = by_residual - self.method = method - self.pairwise_distances = pairwise_distances - self.knn = knn - - def block_search(self, xq_l, xb_l, list_ids, k, **extra_args): - metric_type = self.index.metric_type - if xq_l.size == 0 or xb_l.size == 0: - D = I = None - elif self.method == "index": - faiss.copy_array_to_vector(xb_l, self.index_help.codes) - self.index_help.ntotal = len(list_ids) - D, I = self.index_help.search(xq_l, k) - elif self.method == "pairwise_distances": - # TODO implement blockwise to avoid mem blowup - D = self.pairwise_distances(xq_l, xb_l, metric=metric_type) - I = None - elif self.method == "knn_function": - D, I = self.knn(xq_l, xb_l, k, metric=metric_type, **extra_args) - - return D, I - - -def big_batch_search( - index, xq, k, - 
method="knn_function", - pairwise_distances=faiss.pairwise_distances, - knn=faiss.knn, - verbose=0, - threaded=0, - use_float16=False, - prefetch_threads=8, - computation_threads=0, - q_assign=None, - checkpoint=None, - checkpoint_freq=64, - start_list=0, - end_list=None, - crash_at=-1 - ): - """ - Search queries xq in the IVF index, with a search function that collects - batches of query vectors per inverted list. This can be faster than the - regular search indexes. - Supports IVFFlat, IVFPQ and IVFScalarQuantizer. - - Supports three computation methods: - method = "index": - build a flat index and populate it separately for each index - method = "pairwise_distances": - decompress codes and compute all pairwise distances for the queries - and index and add result to heap - method = "knn_function": - decompress codes and compute knn results for the queries - - threaded=0: sequential execution - threaded=1: prefetch next bucket while computing the current one - threaded>1: prefetch this many buckets at a time. - - compute_threads>1: the knn function will get an additional thread_no that - tells which worker should handle this. - - In threaded mode, the computation is tiled with the bucket perparation and - the writeback of results (useful to maximize GPU utilization). - - use_float16: convert all matrices to float16 (faster for GPU gemm) - - q_assign: override coarse assignment, should be a matrix of size nq * nprobe - - checkpointing (only for threaded > 1): - checkpoint: file where the checkpoints are stored - checkpoint_freq: when to perform checkpoinging. Should be a multiple of threaded - - start_list, end_list: process only a subset of invlists +def permute_invlists(index_ivf, perm): + """ Apply some permutation to the inverted lists, and modify the quantizer + entries accordingly. 
+ Perm is an array of size nlist, where old_index = perm[new_index] """ - nprobe = index.nprobe - - assert method in ("index", "pairwise_distances", "knn_function") - - mem_queries = xq.nbytes - mem_assign = len(xq) * nprobe * np.dtype('int32').itemsize - mem_res = len(xq) * k * ( - np.dtype('int64').itemsize - + np.dtype('float32').itemsize - ) - mem_tot = mem_queries + mem_assign + mem_res - if verbose > 0: - print( - f"memory: queries {mem_queries} assign {mem_assign} " - f"result {mem_res} total {mem_tot} = {mem_tot / (1<<30):.3f} GiB" - ) - - bbs = BigBatchSearcher( - index, xq, k, - verbose=verbose, - use_float16=use_float16 - ) - - comp = BlockComputer( - index, - method=method, - pairwise_distances=pairwise_distances, - knn=knn - ) - - bbs.decode_func = comp.decode_func - bbs.by_residual = comp.by_residual - - if q_assign is None: - bbs.coarse_quantization() - else: - bbs.q_assign = q_assign - bbs.reorder_assign() - - if end_list is None: - end_list = index.nlist - - if checkpoint is not None: - assert (start_list, end_list) == (0, index.nlist) - if os.path.exists(checkpoint): - print("recovering checkpoint", checkpoint) - start_list = bbs.read_checkpoint(checkpoint) - print(" start at list", start_list) - else: - print("no checkpoint: starting from scratch") - - if threaded == 0: - # simple sequential version - - for l in range(start_list, end_list): - bbs.report(l) - q_subset, xq_l, list_ids, xb_l = bbs.prepare_bucket(l) - t0i = time.time() - D, I = comp.block_search(xq_l, xb_l, list_ids, k) - bbs.t_accu[2] += time.time() - t0i - bbs.add_results_to_heap(q_subset, D, list_ids, I) - - elif threaded == 1: - - # parallel version with granularity 1 - - def add_results_and_prefetch(to_add, l): - """ perform the addition for the previous bucket and - prefetch the next (if applicable) """ - if to_add is not None: - bbs.add_results_to_heap(*to_add) - if l < index.nlist: - return bbs.prepare_bucket(l) - - prefetched_bucket = bbs.prepare_bucket(start_list) - to_add = None - pool = ThreadPool(1) - - for l in range(start_list, end_list): - bbs.report(l) - prefetched_bucket_a = pool.apply_async( - add_results_and_prefetch, (to_add, l + 1)) - q_subset, xq_l, list_ids, xb_l = prefetched_bucket - bbs.start_t_accu() - D, I = comp.block_search(xq_l, xb_l, list_ids, k) - bbs.stop_t_accu(2) - to_add = q_subset, D, list_ids, I - bbs.start_t_accu() - prefetched_bucket = prefetched_bucket_a.get() - bbs.stop_t_accu(4) - - bbs.add_results_to_heap(*to_add) - pool.close() - else: - # run by batches with parallel prefetch and parallel comp - list_step = threaded - assert start_list % list_step == 0 - - if prefetch_threads == 0: - prefetch_map = map - else: - prefetch_pool = ThreadPool(prefetch_threads) - prefetch_map = prefetch_pool.map - - if computation_threads > 0: - comp_pool = ThreadPool(computation_threads) - - def add_results_and_prefetch_batch(to_add, l): - def add_results(to_add): - for ta in to_add: # this one cannot be run in parallel... 
- if ta is not None: - bbs.add_results_to_heap(*ta) - if prefetch_threads == 0: - add_results(to_add) - else: - add_a = prefetch_pool.apply_async(add_results, (to_add, )) - next_lists = range(l, min(l + list_step, index.nlist)) - res = list(prefetch_map(bbs.prepare_bucket, next_lists)) - if prefetch_threads > 0: - add_a.get() - return res - - # used only when computation_threads > 1 - thread_id_to_seq_lock = threading.Lock() - thread_id_to_seq = {} - - def do_comp(bucket): - (q_subset, xq_l, list_ids, xb_l) = bucket - try: - tid = thread_id_to_seq[threading.get_ident()] - except KeyError: - with thread_id_to_seq_lock: - tid = len(thread_id_to_seq) - thread_id_to_seq[threading.get_ident()] = tid - D, I = comp.block_search(xq_l, xb_l, list_ids, k, thread_id=tid) - return q_subset, D, list_ids, I - - prefetched_buckets = add_results_and_prefetch_batch([], start_list) - to_add = [] - pool = ThreadPool(1) - prefetched_buckets_a = None - - # loop over inverted lists - for l in range(start_list, end_list, list_step): - bbs.report(l) - buckets = prefetched_buckets - prefetched_buckets_a = pool.apply_async( - add_results_and_prefetch_batch, (to_add, l + list_step)) - - bbs.start_t_accu() - - to_add = [] - if computation_threads == 0: - for q_subset, xq_l, list_ids, xb_l in buckets: - D, I = comp.block_search(xq_l, xb_l, list_ids, k) - to_add.append((q_subset, D, list_ids, I)) - else: - to_add = list(comp_pool.map(do_comp, buckets)) - - bbs.stop_t_accu(2) - - # to test checkpointing - if l == crash_at: - 1 / 0 + nlist, = perm.shape + assert index_ivf.nlist == nlist + quantizer = faiss.downcast_index(index_ivf.quantizer) + assert quantizer.ntotal == index_ivf.nlist + perm = np.ascontiguousarray(perm, dtype='int64') - bbs.start_t_accu() - prefetched_buckets = prefetched_buckets_a.get() - bbs.stop_t_accu(4) + # just make sure it's a permutation... 
+ bc = np.bincount(perm, minlength=nlist) + assert np.all(bc == np.ones(nlist, dtype=int)) - if checkpoint is not None: - if (l // list_step) % checkpoint_freq == 0: - print("writing checkpoint %s" % l) - bbs.write_checkpoint(checkpoint, l) + # handle quantizer + quantizer.permute_entries(perm) - # flush add - for ta in to_add: - bbs.add_results_to_heap(*ta) - pool.close() - if prefetch_threads != 0: - prefetch_pool.close() - if computation_threads != 0: - comp_pool.close() + # handle inverted lists + invlists = faiss.downcast_InvertedLists(index_ivf.invlists) + invlists.permute_invlists(faiss.swig_ptr(perm)) - bbs.tic("finalize heap") - bbs.rh.finalize() - bbs.toc() - return bbs.rh.D, bbs.rh.I +def sort_invlists_by_size(index_ivf): + invlist_sizes = get_invlist_sizes(index_ivf.invlists) + perm = np.argsort(invlist_sizes) + permute_invlists(index_ivf, perm) diff --git a/faiss/IndexFlatCodes.cpp b/faiss/IndexFlatCodes.cpp index eb52c76922..caff90ff9c 100644 --- a/faiss/IndexFlatCodes.cpp +++ b/faiss/IndexFlatCodes.cpp @@ -103,4 +103,15 @@ CodePacker* IndexFlatCodes::get_CodePacker() const { return new CodePackerFlat(code_size); } +void IndexFlatCodes::permute_entries(const idx_t* perm) { + std::vector new_codes(codes.size()); + + for (idx_t i = 0; i < ntotal; i++) { + memcpy(new_codes.data() + i * code_size, + codes.data() + perm[i] * code_size, + code_size); + } + std::swap(codes, new_codes); +} + } // namespace faiss diff --git a/faiss/IndexFlatCodes.h b/faiss/IndexFlatCodes.h index 677f3eead4..687558123f 100644 --- a/faiss/IndexFlatCodes.h +++ b/faiss/IndexFlatCodes.h @@ -59,6 +59,9 @@ struct IndexFlatCodes : Index { void check_compatible_for_merge(const Index& otherIndex) const override; virtual void merge_from(Index& otherIndex, idx_t add_id = 0) override; + + // permute_entries. 
perm of size ntotal maps new to old positions + void permute_entries(const idx_t* perm); }; } // namespace faiss diff --git a/faiss/IndexHNSW.cpp b/faiss/IndexHNSW.cpp index f133baf646..2600b8bf22 100644 --- a/faiss/IndexHNSW.cpp +++ b/faiss/IndexHNSW.cpp @@ -614,6 +614,14 @@ void IndexHNSW::link_singletons() { } } +void IndexHNSW::permute_entries(const idx_t* perm) { + auto flat_storage = dynamic_cast(storage); + FAISS_THROW_IF_NOT_MSG( + flat_storage, "don't know how to permute this index"); + flat_storage->permute_entries(perm); + hnsw.permute_entries(perm); +} + /************************************************************** * ReconstructFromNeighbors implementation **************************************************************/ diff --git a/faiss/IndexHNSW.h b/faiss/IndexHNSW.h index 5878149257..f1ff609e94 100644 --- a/faiss/IndexHNSW.h +++ b/faiss/IndexHNSW.h @@ -134,6 +134,8 @@ struct IndexHNSW : Index { void reorder_links(); void link_singletons(); + + void permute_entries(const idx_t* perm); }; /** Flat index topped with with a HNSW structure to access elements diff --git a/faiss/gpu/test/test_contrib_gpu.py b/faiss/gpu/test/test_contrib_gpu.py index c5d2d8ba0e..061df7da5b 100644 --- a/faiss/gpu/test/test_contrib_gpu.py +++ b/faiss/gpu/test/test_contrib_gpu.py @@ -10,7 +10,7 @@ from common_faiss_tests import get_dataset_2 -from faiss.contrib import datasets, evaluation, ivf_tools +from faiss.contrib import datasets, evaluation, big_batch_search from faiss.contrib.exhaustive_search import knn_ground_truth, \ range_ground_truth @@ -83,7 +83,7 @@ def knn_function(xq, xb, k, metric=faiss.METRIC_L2): return faiss.knn_gpu(res, xq, xb, k, metric=faiss.METRIC_L2) for method in "pairwise_distances", "knn_function": - Dnew, Inew = ivf_tools.big_batch_search( + Dnew, Inew = big_batch_search.big_batch_search( index, ds.get_queries(), k, method=method, pairwise_distances=pairwise_distances, @@ -119,7 +119,7 @@ def knn_function(xq, xb, k, metric=faiss.METRIC_L2, thread_id=None): metric=faiss.METRIC_L2, device=thread_id ) - Dnew, Inew = ivf_tools.big_batch_search( + Dnew, Inew = big_batch_search.big_batch_search( index, ds.get_queries(), k, method="knn_function", knn=knn_function, diff --git a/faiss/impl/HNSW.cpp b/faiss/impl/HNSW.cpp index b356da9673..de70d05b48 100644 --- a/faiss/impl/HNSW.cpp +++ b/faiss/impl/HNSW.cpp @@ -823,6 +823,39 @@ void HNSW::search_level_0( } } +void HNSW::permute_entries(const idx_t* map) { + // remap levels + storage_idx_t ntotal = levels.size(); + std::vector imap(ntotal); // inverse mapping + // map: new index -> old index + // imap: old index -> new index + for (int i = 0; i < ntotal; i++) { + assert(map[i] >= 0 && map[i] < ntotal); + imap[map[i]] = i; + } + if (entry_point != -1) { + entry_point = imap[entry_point]; + } + std::vector new_levels(ntotal); + std::vector new_offsets(ntotal + 1); + std::vector new_neighbors(neighbors.size()); + size_t no = 0; + for (int i = 0; i < ntotal; i++) { + storage_idx_t o = map[i]; // corresponding "old" index + new_levels[i] = levels[o]; + for (size_t j = offsets[o]; j < offsets[o + 1]; j++) { + storage_idx_t neigh = neighbors[j]; + new_neighbors[no++] = neigh >= 0 ? 
imap[neigh] : neigh; + } + new_offsets[i + 1] = no; + } + assert(new_offsets[ntotal] == offsets[ntotal]); + // swap everyone + std::swap(levels, new_levels); + std::swap(offsets, new_offsets); + std::swap(neighbors, new_neighbors); +} + /************************************************************** * MinimaxHeap **************************************************************/ diff --git a/faiss/impl/HNSW.h b/faiss/impl/HNSW.h index 6e2524ec5c..c923e0a6ae 100644 --- a/faiss/impl/HNSW.h +++ b/faiss/impl/HNSW.h @@ -226,6 +226,8 @@ struct HNSW { std::priority_queue& input, std::vector& output, int max_size); + + void permute_entries(const idx_t* map); }; struct HNSWStats { diff --git a/faiss/invlists/InvertedLists.cpp b/faiss/invlists/InvertedLists.cpp index 26adbf4ed7..46f31e6286 100644 --- a/faiss/invlists/InvertedLists.cpp +++ b/faiss/invlists/InvertedLists.cpp @@ -287,6 +287,20 @@ void ArrayInvertedLists::update_entries( memcpy(&codes[list_no][offset * code_size], codes_in, code_size * n_entry); } +void ArrayInvertedLists::permute_invlists(const idx_t* map) { + std::vector> new_codes(nlist); + std::vector> new_ids(nlist); + + for (size_t i = 0; i < nlist; i++) { + size_t o = map[i]; + FAISS_THROW_IF_NOT(o < nlist); + std::swap(new_codes[i], codes[o]); + std::swap(new_ids[i], ids[o]); + } + std::swap(codes, new_codes); + std::swap(ids, new_ids); +} + ArrayInvertedLists::~ArrayInvertedLists() {} /***************************************************************** diff --git a/faiss/invlists/InvertedLists.h b/faiss/invlists/InvertedLists.h index b93fe665be..c4d681452b 100644 --- a/faiss/invlists/InvertedLists.h +++ b/faiss/invlists/InvertedLists.h @@ -253,6 +253,9 @@ struct ArrayInvertedLists : InvertedLists { void resize(size_t list_no, size_t new_size) override; + /// permute the inverted lists, map maps new_id to old_id + void permute_invlists(const idx_t* map); + ~ArrayInvertedLists() override; }; diff --git a/faiss/python/class_wrappers.py b/faiss/python/class_wrappers.py index 6fccd768a8..54539797f5 100644 --- a/faiss/python/class_wrappers.py +++ b/faiss/python/class_wrappers.py @@ -654,6 +654,12 @@ def replacement_add_sa_codes(self, codes, ids=None): ids = swig_ptr(ids) self.add_sa_codes_c(n, swig_ptr(codes), ids) + def replacement_permute_entries(self, perm): + n, = perm.shape + assert n == self.ntotal + perm = np.ascontiguousarray(perm, dtype='int64') + self.permute_entries_c(faiss.swig_ptr(perm)) + replace_method(the_class, 'add', replacement_add) replace_method(the_class, 'add_with_ids', replacement_add_with_ids) replace_method(the_class, 'assign', replacement_assign) @@ -675,6 +681,8 @@ def replacement_add_sa_codes(self, codes, ids=None): replace_method(the_class, 'sa_decode', replacement_sa_decode) replace_method(the_class, 'add_sa_codes', replacement_add_sa_codes, ignore_missing=True) + replace_method(the_class, 'permute_entries', replacement_permute_entries, + ignore_missing=True) # get/set state for pickle # the data is serialized to std::vector -> numpy array -> python bytes diff --git a/tests/test_contrib.py b/tests/test_contrib.py index d88582dec2..10ac3224c7 100644 --- a/tests/test_contrib.py +++ b/tests/test_contrib.py @@ -15,6 +15,7 @@ from faiss.contrib import evaluation from faiss.contrib import ivf_tools from faiss.contrib import clustering +from faiss.contrib import big_batch_search from common_faiss_tests import get_dataset_2 try: @@ -497,7 +498,7 @@ def do_test(self, factory_string, metric=faiss.METRIC_L2): # faiss.omp_set_num_threads(1) for method in 
("pairwise_distances", "knn_function", "index"): for threaded in 0, 1, 3, 8: - Dnew, Inew = ivf_tools.big_batch_search( + Dnew, Inew = big_batch_search.big_batch_search( index, ds.get_queries(), k, method=method, threaded=threaded @@ -526,12 +527,12 @@ def test_checkpoint(self): index.nprobe = 5 Dref, Iref = index.search(ds.get_queries(), k) - r = random.randrange(1<<60) + r = random.randrange(1 << 60) checkpoint = "/tmp/test_big_batch_checkpoint.%d" % r try: # First big batch search try: - Dnew, Inew = ivf_tools.big_batch_search( + Dnew, Inew = big_batch_search.big_batch_search( index, ds.get_queries(), k, method="knn_function", threaded=4, @@ -543,7 +544,7 @@ def test_checkpoint(self): else: self.assertFalse("should have crashed") # Second big batch search - Dnew, Inew = ivf_tools.big_batch_search( + Dnew, Inew = big_batch_search.big_batch_search( index, ds.get_queries(), k, method="knn_function", threaded=4, @@ -554,3 +555,38 @@ def test_checkpoint(self): finally: if os.path.exists(checkpoint): os.unlink(checkpoint) + + +class TestInvlistSort(unittest.TestCase): + + def test_sort(self): + """ make sure that the search results do not change + after sorting the inverted lists """ + ds = datasets.SyntheticDataset(32, 2000, 200, 20) + index = faiss.index_factory(ds.d, "IVF50,SQ8") + index.train(ds.get_train()) + index.add(ds.get_database()) + index.nprobe = 5 + Dref, Iref = index.search(ds.get_queries(), 5) + + ivf_tools.sort_invlists_by_size(index) + list_sizes = ivf_tools.get_invlist_sizes(index.invlists) + assert np.all(list_sizes[1:] >= list_sizes[:-1]) + + Dnew, Inew = index.search(ds.get_queries(), 5) + np.testing.assert_equal(Dnew, Dref) + np.testing.assert_equal(Inew, Iref) + + def test_hnsw_permute(self): + """ make sure HNSW permutation works (useful when used as coarse quantizer) """ + ds = datasets.SyntheticDataset(32, 0, 1000, 50) + index = faiss.index_factory(ds.d, "HNSW32,Flat") + index.add(ds.get_database()) + Dref, Iref = index.search(ds.get_queries(), 5) + rs = np.random.RandomState(1234) + perm = rs.permutation(index.ntotal) + index.permute_entries(perm) + Dnew, Inew = index.search(ds.get_queries(), 5) + np.testing.assert_equal(Dnew, Dref) + Inew_remap = perm[Inew] + np.testing.assert_equal(Inew_remap, Iref)