From 1d0e8d489f15f7e34f5ebba978d44e675e2e35dd Mon Sep 17 00:00:00 2001 From: Gergely Szilvasy Date: Tue, 30 Jan 2024 10:58:13 -0800 Subject: [PATCH] index optimizer (#3154) Summary: Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3154 Using the benchmark to find Pareto optimal indices, in this case on BigANN as an example. Separately optimize the coarse quantizer and the vector codec and use Pareto optimal configurations to construct IVF indices, which are then retested at various scales. See `optimize()` in `optimize.py` as the main function driving the process. The results can be interpreted with `bench_fw_notebook.ipynb`, which allows: * filtering by maximum code size * maximum time * minimum accuracy * space or time Pareto optimal options * and visualize the results and output them as a table. This version is intentionally limited to IVF(Flat|HNSW),PQ|SQ indices... Reviewed By: mdouze Differential Revision: D51781670 fbshipit-source-id: 2c0f800d374ea845255934f519cc28095c00a51f --- benchs/bench_fw/benchmark.py | 122 +++- benchs/bench_fw/benchmark_io.py | 34 +- benchs/bench_fw/descriptors.py | 5 +- benchs/bench_fw/index.py | 203 ++++-- benchs/bench_fw/optimize.py | 333 +++++++++ benchs/bench_fw/utils.py | 169 ++++- benchs/bench_fw_notebook.ipynb | 1144 ++++++++++++++----------------- benchs/bench_fw_optimize.py | 58 ++ 8 files changed, 1318 insertions(+), 750 deletions(-) create mode 100644 benchs/bench_fw/optimize.py create mode 100644 benchs/bench_fw_optimize.py diff --git a/benchs/bench_fw/benchmark.py b/benchs/bench_fw/benchmark.py index ccdbf9c5d6..1053f99388 100644 --- a/benchs/bench_fw/benchmark.py +++ b/benchs/bench_fw/benchmark.py @@ -7,19 +7,20 @@ from copy import copy from dataclasses import dataclass from operator import itemgetter -from statistics import median, mean +from statistics import mean, median from typing import Any, Dict, List, Optional -from .utils import dict_merge -from .index import Index, IndexFromCodec, IndexFromFactory -from .descriptors import DatasetDescriptor, IndexDescriptor - import faiss # @manual=//faiss/python:pyfaiss_gpu import numpy as np from scipy.optimize import curve_fit +from .descriptors import DatasetDescriptor, IndexDescriptor +from .index import Index, IndexFromCodec, IndexFromFactory + +from .utils import dict_merge + logger = logging.getLogger(__name__) @@ -274,8 +275,8 @@ def range_search( search_parameters: Optional[Dict[str, int]], radius: Optional[float] = None, gt_radius: Optional[float] = None, - range_search_metric_function = None, - gt_rsm = None, + range_search_metric_function=None, + gt_rsm=None, ): logger.info("range_search: begin") if radius is None: @@ -328,7 +329,13 @@ def knn_ground_truth(self): logger.info("knn_ground_truth: begin") flat_desc = self.get_index_desc("Flat") self.build_index_wrapper(flat_desc) - self.gt_knn_D, self.gt_knn_I, _, _, requires = flat_desc.index.knn_search( + ( + self.gt_knn_D, + self.gt_knn_I, + _, + _, + requires, + ) = flat_desc.index.knn_search( dry_run=False, search_parameters=None, query_vectors=self.query_vectors, @@ -338,13 +345,13 @@ def knn_ground_truth(self): logger.info("knn_ground_truth: end") def search_benchmark( - self, + self, name, search_func, key_func, cost_metrics, perf_metrics, - results: Dict[str, Any], + results: Dict[str, Any], index: Index, ): index_name = index.get_index_name() @@ -376,11 +383,18 @@ def experiment(parameters, cost_metric, perf_metric): logger.info(f"{name}_benchmark: end") return results, requires - def knn_search_benchmark(self, dry_run, results: Dict[str, Any], index: Index): + def knn_search_benchmark( + self, dry_run, results: Dict[str, Any], index: Index + ): return self.search_benchmark( name="knn_search", search_func=lambda parameters: index.knn_search( - dry_run, parameters, self.query_vectors, self.k, self.gt_knn_I, self.gt_knn_D, + dry_run, + parameters, + self.query_vectors, + self.k, + self.gt_knn_I, + self.gt_knn_D, )[3:], key_func=lambda parameters: index.get_knn_search_name( search_parameters=parameters, @@ -394,11 +408,17 @@ def knn_search_benchmark(self, dry_run, results: Dict[str, Any], index: Index): index=index, ) - def reconstruct_benchmark(self, dry_run, results: Dict[str, Any], index: Index): + def reconstruct_benchmark( + self, dry_run, results: Dict[str, Any], index: Index + ): return self.search_benchmark( name="reconstruct", search_func=lambda parameters: index.reconstruct( - dry_run, parameters, self.query_vectors, self.k, self.gt_knn_I, + dry_run, + parameters, + self.query_vectors, + self.k, + self.gt_knn_I, ), key_func=lambda parameters: index.get_knn_search_name( search_parameters=parameters, @@ -426,19 +446,20 @@ def range_search_benchmark( return self.search_benchmark( name="range_search", search_func=lambda parameters: self.range_search( - dry_run=dry_run, - index=index, - search_parameters=parameters, + dry_run=dry_run, + index=index, + search_parameters=parameters, radius=radius, gt_radius=gt_radius, - range_search_metric_function=range_search_metric_function, + range_search_metric_function=range_search_metric_function, gt_rsm=gt_rsm, )[4:], key_func=lambda parameters: index.get_range_search_name( search_parameters=parameters, query_vectors=self.query_vectors, radius=radius, - ) + metric_key, + ) + + metric_key, cost_metrics=["time"], perf_metrics=["range_score_max_recall"], results=results, @@ -446,11 +467,12 @@ def range_search_benchmark( ) def build_index_wrapper(self, index_desc: IndexDescriptor): - if hasattr(index_desc, 'index'): + if hasattr(index_desc, "index"): return if index_desc.factory is not None: training_vectors = copy(self.training_vectors) - training_vectors.num_vectors = index_desc.training_size + if index_desc.training_size is not None: + training_vectors.num_vectors = index_desc.training_size index = IndexFromFactory( num_threads=self.num_threads, d=self.d, @@ -481,15 +503,24 @@ def clone_one(self, index_desc): training_vectors=self.training_vectors, database_vectors=self.database_vectors, query_vectors=self.query_vectors, - index_descs = [self.get_index_desc("Flat"), index_desc], + index_descs=[self.get_index_desc("Flat"), index_desc], range_ref_index_desc=self.range_ref_index_desc, k=self.k, distance_metric=self.distance_metric, ) - benchmark.set_io(self.io) + benchmark.set_io(self.io.clone()) return benchmark - def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescriptor, train, reconstruct, knn, range): + def benchmark_one( + self, + dry_run, + results: Dict[str, Any], + index_desc: IndexDescriptor, + train, + reconstruct, + knn, + range, + ): faiss.omp_set_num_threads(self.num_threads) if not dry_run: self.knn_ground_truth() @@ -531,9 +562,12 @@ def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescr ) assert requires is None - if self.range_ref_index_desc is None or not index_desc.index.supports_range_search(): + if ( + self.range_ref_index_desc is None + or not index_desc.index.supports_range_search() + ): return results, None - + ref_index_desc = self.get_index_desc(self.range_ref_index_desc) if ref_index_desc is None: raise ValueError( @@ -550,7 +584,9 @@ def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescr coefficients, coefficients_training_data, ) = self.range_search_reference( - ref_index_desc.index, ref_index_desc.search_params, range_metric + ref_index_desc.index, + ref_index_desc.search_params, + range_metric, ) gt_rsm = self.range_ground_truth( gt_radius, range_search_metric_function @@ -583,7 +619,15 @@ def benchmark_one(self, dry_run, results: Dict[str, Any], index_desc: IndexDescr return results, None - def benchmark(self, result_file=None, local=False, train=False, reconstruct=False, knn=False, range=False): + def benchmark( + self, + result_file=None, + local=False, + train=False, + reconstruct=False, + knn=False, + range=False, + ): logger.info("begin evaluate") faiss.omp_set_num_threads(self.num_threads) @@ -656,20 +700,34 @@ def benchmark(self, result_file=None, local=False, train=False, reconstruct=Fals if current_todo: results_one = {"indices": {}, "experiments": {}} - params = [(self.clone_one(index_desc), results_one, index_desc, train, reconstruct, knn, range) for index_desc in current_todo] - for result in self.io.launch_jobs(run_benchmark_one, params, local=local): + params = [ + ( + index_desc, + self.clone_one(index_desc), + results_one, + train, + reconstruct, + knn, + range, + ) + for index_desc in current_todo + ] + for result in self.io.launch_jobs( + run_benchmark_one, params, local=local + ): dict_merge(results, result) - todo = next_todo + todo = next_todo if result_file is not None: self.io.write_json(results, result_file, overwrite=True) logger.info("end evaluate") return results + def run_benchmark_one(params): logger.info(params) - benchmark, results, index_desc, train, reconstruct, knn, range = params + index_desc, benchmark, results, train, reconstruct, knn, range = params results, requires = benchmark.benchmark_one( dry_run=False, results=results, diff --git a/benchs/bench_fw/benchmark_io.py b/benchs/bench_fw/benchmark_io.py index 483acba8c7..b39bb60290 100644 --- a/benchs/bench_fw/benchmark_io.py +++ b/benchs/bench_fw/benchmark_io.py @@ -10,13 +10,13 @@ import os import pickle from dataclasses import dataclass -import submitit from typing import Any, List, Optional from zipfile import ZipFile import faiss # @manual=//faiss/python:pyfaiss_gpu import numpy as np +import submitit from faiss.contrib.datasets import ( # @manual=//faiss/contrib:faiss_contrib_gpu dataset_from_name, ) @@ -47,6 +47,9 @@ def merge_rcq_itq( class BenchmarkIO: path: str + def clone(self): + return BenchmarkIO(path=self.path) + def __post_init__(self): self.cached_ds = {} @@ -119,18 +122,27 @@ def write_file( def get_dataset(self, dataset): if dataset not in self.cached_ds: - if dataset.namespace is not None and dataset.namespace[:4] == "std_": + if ( + dataset.namespace is not None + and dataset.namespace[:4] == "std_" + ): if dataset.tablename not in self.cached_ds: self.cached_ds[dataset.tablename] = dataset_from_name( dataset.tablename, ) p = dataset.namespace[4] if p == "t": - self.cached_ds[dataset] = self.cached_ds[dataset.tablename].get_train(dataset.num_vectors) + self.cached_ds[dataset] = self.cached_ds[ + dataset.tablename + ].get_train(dataset.num_vectors) elif p == "d": - self.cached_ds[dataset] = self.cached_ds[dataset.tablename].get_database() + self.cached_ds[dataset] = self.cached_ds[ + dataset.tablename + ].get_database() elif p == "q": - self.cached_ds[dataset] = self.cached_ds[dataset.tablename].get_queries() + self.cached_ds[dataset] = self.cached_ds[ + dataset.tablename + ].get_queries() else: raise ValueError elif dataset.namespace == "syn": @@ -233,8 +245,8 @@ def launch_jobs(self, func, params, local=True): if local: results = [func(p) for p in params] return results - print(f'launching {len(params)} jobs') - executor = submitit.AutoExecutor(folder='/checkpoint/gsz/jobs') + logger.info(f"launching {len(params)} jobs") + executor = submitit.AutoExecutor(folder="/checkpoint/gsz/jobs") executor.update_parameters( nodes=1, gpus_per_node=8, @@ -248,9 +260,9 @@ def launch_jobs(self, func, params, local=True): slurm_constraint="bldg1", ) jobs = executor.map_array(func, params) - print(f'launched {len(jobs)} jobs') - # for job, param in zip(jobs, params): - # print(f"{job.job_id=} {param=}") + logger.info(f"launched {len(jobs)} jobs") + for job, param in zip(jobs, params): + logger.info(f"{job.job_id=} {param[0]=}") results = [job.result() for job in jobs] - print(f'received {len(results)} results') + print(f"received {len(results)} results") return results diff --git a/benchs/bench_fw/descriptors.py b/benchs/bench_fw/descriptors.py index 113f46b545..f1dd7354c2 100644 --- a/benchs/bench_fw/descriptors.py +++ b/benchs/bench_fw/descriptors.py @@ -9,6 +9,7 @@ import faiss # @manual=//faiss/python:pyfaiss_gpu from .utils import timer + logger = logging.getLogger(__name__) @@ -101,7 +102,9 @@ def k_means(self, io, k, dry_run): tablename=f"{self.get_filename()}kmeans_{k}.npy" ) meta_filename = kmeans_vectors.tablename + ".json" - if not io.file_exist(kmeans_vectors.tablename) or not io.file_exist(meta_filename): + if not io.file_exist(kmeans_vectors.tablename) or not io.file_exist( + meta_filename + ): if dry_run: return None, None, kmeans_vectors.tablename x = io.get_dataset(self) diff --git a/benchs/bench_fw/index.py b/benchs/bench_fw/index.py index 4c536aa753..14f2158e64 100644 --- a/benchs/bench_fw/index.py +++ b/benchs/bench_fw/index.py @@ -4,19 +4,19 @@ # LICENSE file in the root directory of this source tree. -from copy import copy import logging import os from collections import OrderedDict +from copy import copy from dataclasses import dataclass from typing import ClassVar, Dict, List, Optional import faiss # @manual=//faiss/python:pyfaiss_gpu - import numpy as np + from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu - OperatingPointsWithRanges, knn_intersection_measure, + OperatingPointsWithRanges, ) from faiss.contrib.factory_tools import ( # @manual=//faiss/contrib:faiss_contrib_gpu reverse_index_factory, @@ -27,7 +27,13 @@ ) from .descriptors import DatasetDescriptor -from .utils import distance_ratio_measure, get_cpu_info, timer, refine_distances_knn, refine_distances_range +from .utils import ( + distance_ratio_measure, + get_cpu_info, + refine_distances_knn, + refine_distances_range, + timer, +) logger = logging.getLogger(__name__) @@ -106,7 +112,9 @@ def set_index_param(index, name, val, assert_same=False): icm_encoder_factory = faiss.GpuIcmEncoderFactory(ngpus) if isinstance(index, faiss.IndexProductLocalSearchQuantizer): for i in range(index.plsq.nsplits): - lsq = faiss.downcast_Quantizer(index.plsq.subquantizer(i)) + lsq = faiss.downcast_Quantizer( + index.plsq.subquantizer(i) + ) if lsq.icm_encoder_factory is None: lsq.icm_encoder_factory = icm_encoder_factory else: @@ -119,29 +127,39 @@ def set_index_param(index, name, val, assert_same=False): obj = faiss.extract_index_ivf(index) elif name in ["use_beam_LUT", "max_beam_size"]: if isinstance(index, faiss.IndexProductResidualQuantizer): - obj = [faiss.downcast_Quantizer(index.prq.subquantizer(i)) for i in range(index.prq.nsplits)] + obj = [ + faiss.downcast_Quantizer(index.prq.subquantizer(i)) + for i in range(index.prq.nsplits) + ] else: obj = index.rq elif name == "encode_ils_iters": if isinstance(index, faiss.IndexProductLocalSearchQuantizer): - obj = [faiss.downcast_Quantizer(index.plsq.subquantizer(i)) for i in range(index.plsq.nsplits)] + obj = [ + faiss.downcast_Quantizer(index.plsq.subquantizer(i)) + for i in range(index.plsq.nsplits) + ] else: obj = index.lsq else: obj = index - + if not isinstance(obj, list): obj = [obj] for o in obj: test = getattr(o, name) - if assert_same and not name == 'use_beam_LUT': + if assert_same and not name == "use_beam_LUT": assert test == val else: setattr(o, name, val) @staticmethod def filter_index_param_dict_list(param_dict_list): - if param_dict_list is not None and param_dict_list[0] is not None and "k_factor" in param_dict_list[0]: + if ( + param_dict_list is not None + and param_dict_list[0] is not None + and "k_factor" in param_dict_list[0] + ): filtered = copy(param_dict_list) del filtered[0]["k_factor"] return filtered @@ -153,6 +171,7 @@ def is_flat(self): return isinstance(model, faiss.IndexFlat) def is_ivf(self): + return False model = self.get_model() return faiss.try_extract_index_ivf(model) is not None @@ -243,7 +262,9 @@ def knn_search_quantizer(self, query_vectors, k): pretransform = None quantizer_query_vectors = query_vectors - quantizer, _, _ = self.get_quantizer(dry_run=False, pretransform=pretransform) + quantizer, _, _ = self.get_quantizer( + dry_run=False, pretransform=pretransform + ) QD, QI, _, QP, _ = quantizer.knn_search( dry_run=False, search_parameters=None, @@ -300,7 +321,9 @@ def knn_search( # Index2Layer doesn't support search xq = self.io.get_dataset(query_vectors) xb = index.reconstruct_n(0, index.ntotal) - (D, I), t, _ = timer("knn_search 2layer", lambda: faiss.knn(xq, xb, k)) + (D, I), t, _ = timer( + "knn_search 2layer", lambda: faiss.knn(xq, xb, k) + ) elif self.is_ivf() and not isinstance(index, faiss.IndexRefine): index_ivf = faiss.extract_index_ivf(index) nprobe = ( @@ -310,7 +333,7 @@ def knn_search( else index_ivf.nprobe ) xqt, QD, QI, QP = self.knn_search_quantizer( - query_vectors=query_vectors, + query_vectors=query_vectors, k=nprobe, ) if index_ivf.parallel_mode != 2: @@ -358,11 +381,19 @@ def knn_search( "construction_params": self.get_construction_params(), "search_params": search_parameters, "knn_intersection": knn_intersection_measure( - I, I_gt, - ) if I_gt is not None else None, + I, + I_gt, + ) + if I_gt is not None + else None, "distance_ratio": distance_ratio_measure( - I, R, D_gt, self.metric_type, - ) if D_gt is not None else None, + I, + R, + D_gt, + self.metric_type, + ) + if D_gt is not None + else None, } logger.info("knn_search: end") return D, I, R, P, None @@ -377,12 +408,14 @@ def reconstruct( ): logger.info("reconstruct: begin") filename = ( - self.get_knn_search_name(parameters, query_vectors, k, reconstruct=True) + self.get_knn_search_name( + parameters, query_vectors, k, reconstruct=True + ) + "zip" ) if self.io.file_exist(filename): logger.info(f"Using cached results for {filename}") - P, = self.io.read_file(filename, ["P"]) + (P,) = self.io.read_file(filename, ["P"]) P["index"] = self.get_index_name() P["codec"] = self.get_codec_name() P["factory"] = self.get_model_name() @@ -395,15 +428,21 @@ def reconstruct( codec_meta = self.fetch_meta() Index.set_index_param_dict(codec, parameters) xb = self.io.get_dataset(self.database_vectors) - xb_encoded, encode_t, _ = timer("sa_encode", lambda: codec.sa_encode(xb)) + xb_encoded, encode_t, _ = timer( + "sa_encode", lambda: codec.sa_encode(xb) + ) xq = self.io.get_dataset(query_vectors) if self.is_decode_supported(): - xb_decoded, decode_t, _ = timer("sa_decode", lambda: codec.sa_decode(xb_encoded)) + xb_decoded, decode_t, _ = timer( + "sa_decode", lambda: codec.sa_decode(xb_encoded) + ) mse = np.square(xb_decoded - xb).sum(axis=1).mean().item() _, I = faiss.knn(xq, xb_decoded, k, metric=self.metric_type) asym_recall = knn_intersection_measure(I, I_gt) xq_decoded = codec.sa_decode(codec.sa_encode(xq)) - _, I = faiss.knn(xq_decoded, xb_decoded, k, metric=self.metric_type) + _, I = faiss.knn( + xq_decoded, xb_decoded, k, metric=self.metric_type + ) else: mse = None asym_recall = None @@ -604,7 +643,7 @@ def fetch_index(self): if self.is_ivf() and not isinstance(index, faiss.IndexRefine): xbt, QD, QI, QP = self.knn_search_quantizer( - query_vectors=self.database_vectors, + query_vectors=self.database_vectors, k=1, ) index_ivf = faiss.extract_index_ivf(index) @@ -638,22 +677,21 @@ def get_index(self): def get_construction_params(self): return self.construction_params - # def get_code_size(self): - # def get_index_code_size(index): - # index = faiss.downcast_index(index) - # if isinstance(index, faiss.IndexPreTransform): - # return get_index_code_size(index.index) - # elif isinstance(index, faiss.IndexHNSWFlat): - # return index.d * 4 # TODO - # elif type(index) in [faiss.IndexRefine, faiss.IndexRefineFlat]: - # return get_index_code_size( - # index.base_index - # ) + get_index_code_size(index.refine_index) - # else: - # return index.code_size - - # codec = self.get_codec() - # return get_index_code_size(codec) + def get_code_size(self, codec=None): + def get_index_code_size(index): + index = faiss.downcast_index(index) + if isinstance(index, faiss.IndexPreTransform): + return get_index_code_size(index.index) + elif type(index) in [faiss.IndexRefine, faiss.IndexRefineFlat]: + return get_index_code_size( + index.base_index + ) + get_index_code_size(index.refine_index) + else: + return index.code_size if hasattr(index, "code_size") else 0 + + if codec is None: + codec = self.get_codec() + return get_index_code_size(codec) def get_sa_code_size(self, codec=None): if codec is None: @@ -680,32 +718,28 @@ def add_range_or_val(name, range): if model_ivf is not None: add_range_or_val( "nprobe", - # [ + [2**i for i in range(12) if 2**i <= model_ivf.nlist * 0.5], + # [1, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28] + [ + # i + # for i in range(32, 64, 8) + # if i <= model_ivf.nlist * 0.1 + # ] + [ + # i + # for i in range(64, 128, 16) + # if i <= model_ivf.nlist * 0.1 + # ] + [ + # i + # for i in range(128, 256, 32) + # if i <= model_ivf.nlist * 0.1 + # ] + [ + # i + # for i in range(256, 512, 64) + # if i <= model_ivf.nlist * 0.1 + # ] + [ # 2**i - # for i in range(12) - # if 2**i <= model_ivf.nlist * 0.5 + # for i in range(9, 12) + # if 2**i <= model_ivf.nlist * 0.1 # ], - [1, 2, 4, 6, 8, 10, 12, 14, 16, 20, 24, 28] + [ - i - for i in range(32, 64, 8) - if i <= model_ivf.nlist * 0.1 - ] + [ - i - for i in range(64, 128, 16) - if i <= model_ivf.nlist * 0.1 - ] + [ - i - for i in range(128, 256, 32) - if i <= model_ivf.nlist * 0.1 - ] + [ - i - for i in range(256, 512, 64) - if i <= model_ivf.nlist * 0.1 - ] + [ - 2**i - for i in range(9, 12) - if 2**i <= model_ivf.nlist * 0.1 - ], ) model = faiss.downcast_index(model) if isinstance(model, faiss.IndexRefine): @@ -718,7 +752,9 @@ def add_range_or_val(name, range): "efSearch", [2**i for i in range(3, 11)], ) - elif isinstance(model, faiss.IndexResidualQuantizer) or isinstance(model, faiss.IndexProductResidualQuantizer): + elif isinstance(model, faiss.IndexResidualQuantizer) or isinstance( + model, faiss.IndexProductResidualQuantizer + ): add_range_or_val( "max_beam_size", [1, 2, 4, 8, 16, 32], @@ -727,7 +763,9 @@ def add_range_or_val(name, range): "use_beam_LUT", [1], ) - elif isinstance(model, faiss.IndexLocalSearchQuantizer) or isinstance(model, faiss.IndexProductLocalSearchQuantizer): + elif isinstance(model, faiss.IndexLocalSearchQuantizer) or isinstance( + model, faiss.IndexProductLocalSearchQuantizer + ): add_range_or_val( "encode_ils_iters", [2, 4, 8, 16], @@ -854,7 +892,9 @@ def fetch_meta(self, dry_run=False): def fetch_codec(self, dry_run=False): codec_filename = self.get_codec_name() + "codec" meta_filename = self.get_codec_name() + "json" - if self.io.file_exist(codec_filename) and self.io.file_exist(meta_filename): + if self.io.file_exist(codec_filename) and self.io.file_exist( + meta_filename + ): codec = self.io.read_index(codec_filename) assert self.d == codec.d assert self.metric_type == codec.metric_type @@ -874,6 +914,7 @@ def fetch_codec(self, dry_run=False): "training_size": self.training_vectors.num_vectors, "codec_size": codec_size, "sa_code_size": self.get_sa_code_size(codec), + "code_size": self.get_code_size(codec), "cpu": get_cpu_info(), } self.io.write_json(meta, meta_filename, overwrite=True) @@ -921,7 +962,9 @@ def get_quantizer(self, dry_run, pretransform=None): training_vectors = self.training_vectors else: training_vectors = pretransform.transform(self.training_vectors) - centroids, t, requires = training_vectors.k_means(self.io, model_ivf.nlist, dry_run) + centroids, t, requires = training_vectors.k_means( + self.io, model_ivf.nlist, dry_run + ) if requires is not None: return None, None, requires quantizer = IndexFromFactory( @@ -944,11 +987,11 @@ def assemble(self, dry_run): model = self.get_model() opaque = True t_aggregate = 0 - try: - reverse_index_factory(model) - opaque = False - except NotImplementedError: - opaque = True + # try: + # reverse_index_factory(model) + # opaque = False + # except NotImplementedError: + # opaque = True if opaque: codec = model else: @@ -958,7 +1001,9 @@ def assemble(self, dry_run): if not isinstance(sub_index, faiss.IndexFlat): # replace the sub-index with Flat and fetch pre-trained pretransform = self.get_pretransform() - codec, meta, report = pretransform.fetch_codec(dry_run=dry_run) + codec, meta, report = pretransform.fetch_codec( + dry_run=dry_run + ) if report is not None: return None, None, report t_aggregate += meta["training_time"] @@ -978,7 +1023,9 @@ def assemble(self, dry_run): training_vectors=transformed_training_vectors, ) wrapper.set_io(self.io) - codec.index, meta, report = wrapper.fetch_codec(dry_run=dry_run) + codec.index, meta, report = wrapper.fetch_codec( + dry_run=dry_run + ) if report is not None: return None, None, report t_aggregate += meta["training_time"] @@ -1008,14 +1055,18 @@ def assemble(self, dry_run): d=model.base_index.d, metric=model.base_index.metric_type, database_vectors=self.database_vectors, - construction_params=IndexBase.filter_index_param_dict_list(self.construction_params), + construction_params=IndexBase.filter_index_param_dict_list( + self.construction_params + ), search_params=None, factory=reverse_index_factory(model.base_index), training_vectors=self.training_vectors, ) wrapper.set_io(self.io) codec = faiss.clone_index(model) - codec.base_index, meta, requires = wrapper.fetch_codec(dry_run=dry_run) + codec.base_index, meta, requires = wrapper.fetch_codec( + dry_run=dry_run + ) if requires is not None: return None, None, requires t_aggregate += meta["training_time"] diff --git a/benchs/bench_fw/optimize.py b/benchs/bench_fw/optimize.py new file mode 100644 index 0000000000..473436ea68 --- /dev/null +++ b/benchs/bench_fw/optimize.py @@ -0,0 +1,333 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from dataclasses import dataclass +from typing import Dict, List, Tuple + +import faiss # @manual=//faiss/python:pyfaiss_gpu + +# from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu +# OperatingPoints, +# ) + +from .benchmark import Benchmark +from .descriptors import DatasetDescriptor, IndexDescriptor +from .utils import dict_merge, filter_results, ParetoMetric, ParetoMode + +logger = logging.getLogger(__name__) + + +@dataclass +class Optimizer: + distance_metric: str = "L2" + num_threads: int = 32 + run_local: bool = True + + def __post_init__(self): + self.cached_benchmark = None + if self.distance_metric == "IP": + self.distance_metric_type = faiss.METRIC_INNER_PRODUCT + elif self.distance_metric == "L2": + self.distance_metric_type = faiss.METRIC_L2 + else: + raise ValueError + + def set_io(self, benchmark_io): + self.io = benchmark_io + self.io.distance_metric = self.distance_metric + self.io.distance_metric_type = self.distance_metric_type + + def benchmark_and_filter_candidates( + self, + index_descs, + training_vectors, + database_vectors, + query_vectors, + result_file, + include_flat, + min_accuracy, + pareto_metric, + ): + benchmark = Benchmark( + num_threads=self.num_threads, + training_vectors=training_vectors, + database_vectors=database_vectors, + query_vectors=query_vectors, + index_descs=index_descs, + k=10, + distance_metric=self.distance_metric, + ) + benchmark.set_io(self.io) + results = benchmark.benchmark( + result_file=result_file, local=self.run_local, train=True, knn=True + ) + assert results + filtered = filter_results( + results=results, + evaluation="knn", + accuracy_metric="knn_intersection", + min_accuracy=min_accuracy, + name_filter=None + if include_flat + else (lambda n: not n.startswith("Flat")), + pareto_mode=ParetoMode.GLOBAL, + pareto_metric=pareto_metric, + ) + assert filtered + index_descs = [ + IndexDescriptor( + factory=v["factory"], + construction_params=v["construction_params"], + search_params=v["search_params"], + ) + for _, _, _, _, v in filtered + ] + return index_descs, filtered + + def optimize_quantizer( + self, + training_vectors: DatasetDescriptor, + query_vectors: DatasetDescriptor, + nlists: List[int], + min_accuracy: float, + ): + quantizer_descs = {} + for nlist in nlists: + # cluster + centroids, _, _ = training_vectors.k_means( + self.io, + nlist, + dry_run=False, + ) + + descs = [IndexDescriptor(factory="Flat"),] + [ + IndexDescriptor( + factory="HNSW32", + construction_params=[{"efConstruction": 2**i}], + ) + for i in range(6, 11) + ] + + descs, _ = self.benchmark_and_filter_candidates( + descs, + training_vectors=centroids, + database_vectors=centroids, + query_vectors=query_vectors, + result_file=f"result_{centroids.get_filename()}json", + include_flat=True, + min_accuracy=min_accuracy, + pareto_metric=ParetoMetric.TIME, + ) + quantizer_descs[nlist] = descs + + return quantizer_descs + + def optimize_ivf( + self, + result_file: str, + training_vectors: DatasetDescriptor, + database_vectors: DatasetDescriptor, + query_vectors: DatasetDescriptor, + quantizers: Dict[int, List[IndexDescriptor]], + codecs: List[Tuple[str, str]], + min_accuracy: float, + ): + ivf_descs = [] + for nlist, quantizer_descs in quantizers.items(): + # build IVF index + for quantizer_desc in quantizer_descs: + for pretransform, fine_ivf in codecs: + if pretransform is None: + pretransform = "" + else: + pretransform = pretransform + "," + if quantizer_desc.construction_params is None: + construction_params = [ + None, + quantizer_desc.search_params, + ] + else: + construction_params = [ + None + ] + quantizer_desc.construction_params + if quantizer_desc.search_params is not None: + dict_merge( + construction_params[1], + quantizer_desc.search_params, + ) + ivf_descs.append( + IndexDescriptor( + factory=f"{pretransform}IVF{nlist}({quantizer_desc.factory}),{fine_ivf}", + construction_params=construction_params, + ) + ) + return self.benchmark_and_filter_candidates( + ivf_descs, + training_vectors, + database_vectors, + query_vectors, + result_file, + include_flat=False, + min_accuracy=min_accuracy, + pareto_metric=ParetoMetric.TIME_SPACE, + ) + + # train an IVFFlat index + # find the nprobe required for the given accuracy + def ivf_flat_nprobe_required_for_accuracy( + self, + result_file: str, + training_vectors: DatasetDescriptor, + database_vectors: DatasetDescriptor, + query_vectors: DatasetDescriptor, + nlist, + accuracy, + ): + _, results = self.benchmark_and_filter_candidates( + index_descs=[ + IndexDescriptor(factory=f"IVF{nlist}(Flat),Flat"), + ], + training_vectors=training_vectors, + database_vectors=database_vectors, + query_vectors=query_vectors, + result_file=result_file, + include_flat=False, + min_accuracy=accuracy, + pareto_metric=ParetoMetric.TIME, + ) + nprobe = nlist // 2 + for _, _, _, k, v in results: + if ( + ".knn" in k + and "nprobe" in v["search_params"] + and v["knn_intersection"] >= accuracy + ): + nprobe = min(nprobe, v["search_params"]["nprobe"]) + return nprobe + + # train candidate IVF codecs + # benchmark them at the same nprobe + # keep only the space _and_ time Pareto optimal + def optimize_codec( + self, + result_file: str, + d: int, + training_vectors: DatasetDescriptor, + database_vectors: DatasetDescriptor, + query_vectors: DatasetDescriptor, + nlist: int, + nprobe: int, + min_accuracy: float, + ): + codecs = ( + [ + (None, "Flat"), + (None, "SQfp16"), + (None, "SQ8"), + ] + [ + (f"OPQ{M}_{M * dim}", f"PQ{M}x{b}") + for M in [8, 12, 16, 32, 48, 64, 96, 128, 192, 256] + if d % M == 0 + for dim in range(2, 18, 2) + if M * dim <= d + for b in range(4, 14, 2) + if M * b < d * 8 # smaller than SQ8 + ] + [ + (None, f"PQ{M}x{b}") + for M in [8, 12, 16, 32, 48, 64, 96, 128, 192, 256] + if d % M == 0 + for b in range(8, 14, 2) + if M * b < d * 8 # smaller than SQ8 + ] + ) + factory = {} + for opq, pq in codecs: + factory[ + f"IVF{nlist},{pq}" if opq is None else f"{opq},IVF{nlist},{pq}" + ] = ( + opq, + pq, + ) + + _, filtered = self.benchmark_and_filter_candidates( + index_descs=[ + IndexDescriptor( + factory=f"IVF{nlist},{pq}" + if opq is None + else f"{opq},IVF{nlist},{pq}", + search_params={ + "nprobe": nprobe, + }, + ) + for opq, pq in codecs + ], + training_vectors=training_vectors, + database_vectors=database_vectors, + query_vectors=query_vectors, + result_file=result_file, + include_flat=False, + min_accuracy=min_accuracy, + pareto_metric=ParetoMetric.TIME_SPACE, + ) + results = [ + factory[r] for r in set(v["factory"] for _, _, _, k, v in filtered) + ] + return results + + def optimize( + self, + d: int, + training_vectors: DatasetDescriptor, + database_vectors_list: List[DatasetDescriptor], + query_vectors: DatasetDescriptor, + min_accuracy: float, + ): + # train an IVFFlat index + # find the nprobe required for near perfect accuracy + nlist = 4096 + nprobe_at_95 = self.ivf_flat_nprobe_required_for_accuracy( + result_file=f"result_ivf{nlist}_flat.json", + training_vectors=training_vectors, + database_vectors=database_vectors_list[0], + query_vectors=query_vectors, + nlist=nlist, + accuracy=0.95, + ) + + # train candidate IVF codecs + # benchmark them at the same nprobe + # keep only the space and time Pareto optima + codecs = self.optimize_codec( + result_file=f"result_ivf{nlist}_codec.json", + d=d, + training_vectors=training_vectors, + database_vectors=database_vectors_list[0], + query_vectors=query_vectors, + nlist=nlist, + nprobe=nprobe_at_95, + min_accuracy=min_accuracy, + ) + + # optimize coarse quantizers + quantizers = self.optimize_quantizer( + training_vectors=training_vectors, + query_vectors=query_vectors, + nlists=[4096, 8192, 16384, 32768], + min_accuracy=0.7, + ) + + # combine them with the codecs + # test them at different scales + for database_vectors in database_vectors_list: + self.optimize_ivf( + result_file=f"result_{database_vectors.get_filename()}json", + training_vectors=training_vectors, + database_vectors=database_vectors, + query_vectors=query_vectors, + quantizers=quantizers, + codecs=codecs, + min_accuracy=min_accuracy, + ) diff --git a/benchs/bench_fw/utils.py b/benchs/bench_fw/utils.py index e1e513169b..3151c0c2da 100644 --- a/benchs/bench_fw/utils.py +++ b/benchs/bench_fw/utils.py @@ -3,15 +3,22 @@ # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -from time import perf_counter +import functools import logging +from enum import Enum from multiprocessing.pool import ThreadPool -import numpy as np +from time import perf_counter + import faiss # @manual=//faiss/python:pyfaiss_gpu -import functools +import numpy as np + +from faiss.contrib.evaluation import ( # @manual=//faiss/contrib:faiss_contrib_gpu + OperatingPoints, +) logger = logging.getLogger(__name__) + def timer(name, func, once=False) -> float: logger.info(f"Measuring {name}") t1 = perf_counter() @@ -34,28 +41,41 @@ def timer(name, func, once=False) -> float: def refine_distances_knn( - xq: np.ndarray, xb: np.ndarray, I: np.ndarray, metric, + xq: np.ndarray, + xb: np.ndarray, + I: np.ndarray, + metric, ): - """ Recompute distances between xq[i] and xb[I[i, :]] """ + """Recompute distances between xq[i] and xb[I[i, :]]""" nq, k = I.shape - xq = np.ascontiguousarray(xq, dtype='float32') + xq = np.ascontiguousarray(xq, dtype="float32") nq2, d = xq.shape - xb = np.ascontiguousarray(xb, dtype='float32') + xb = np.ascontiguousarray(xb, dtype="float32") nb, d2 = xb.shape - I = np.ascontiguousarray(I, dtype='int64') + I = np.ascontiguousarray(I, dtype="int64") assert nq2 == nq assert d2 == d - D = np.empty(I.shape, dtype='float32') + D = np.empty(I.shape, dtype="float32") D[:] = np.inf if metric == faiss.METRIC_L2: faiss.fvec_L2sqr_by_idx( - faiss.swig_ptr(D), faiss.swig_ptr(xq), faiss.swig_ptr(xb), - faiss.swig_ptr(I), d, nq, k + faiss.swig_ptr(D), + faiss.swig_ptr(xq), + faiss.swig_ptr(xb), + faiss.swig_ptr(I), + d, + nq, + k, ) else: faiss.fvec_inner_products_by_idx( - faiss.swig_ptr(D), faiss.swig_ptr(xq), faiss.swig_ptr(xb), - faiss.swig_ptr(I), d, nq, k + faiss.swig_ptr(D), + faiss.swig_ptr(xq), + faiss.swig_ptr(xb), + faiss.swig_ptr(I), + d, + nq, + k, ) return D @@ -97,7 +117,10 @@ def distance_ratio_measure(I, R, D_GT, metric): @functools.cache def get_cpu_info(): - return [l for l in open("/proc/cpuinfo", "r") if "model name" in l][0][13:].strip() + return [l for l in open("/proc/cpuinfo", "r") if "model name" in l][0][ + 13: + ].strip() + def dict_merge(target, source): for k, v in source.items(): @@ -105,3 +128,121 @@ def dict_merge(target, source): dict_merge(target[k], v) else: target[k] = v + + +class Cost: + def __init__(self, values): + self.values = values + + def __le__(self, other): + return all( + v1 <= v2 for v1, v2 in zip(self.values, other.values, strict=True) + ) + + def __lt__(self, other): + return all( + v1 < v2 for v1, v2 in zip(self.values, other.values, strict=True) + ) + + +class ParetoMode(Enum): + DISABLE = 1 # no Pareto filtering + INDEX = 2 # index-local optima + GLOBAL = 3 # global optima + + +class ParetoMetric(Enum): + TIME = 0 # time vs accuracy + SPACE = 1 # space vs accuracy + TIME_SPACE = 2 # (time, space) vs accuracy + + +def range_search_recall_at_precision(experiment, precision): + return round( + max( + r + for r, p in zip( + experiment["range_search_pr"]["recall"], + experiment["range_search_pr"]["precision"], + ) + if p > precision + ), + 6, + ) + + +def filter_results( + results, + evaluation, + accuracy_metric, # str or func + time_metric=None, # func or None -> use default + space_metric=None, # func or None -> use default + min_accuracy=0, + max_space=0, + max_time=0, + scaling_factor=1.0, + name_filter=None, # func + pareto_mode=ParetoMode.DISABLE, + pareto_metric=ParetoMetric.TIME, +): + if isinstance(accuracy_metric, str): + accuracy_key = accuracy_metric + accuracy_metric = lambda v: v[accuracy_key] + + if time_metric is None: + time_metric = lambda v: v["time"] * scaling_factor + ( + v["quantizer"]["time"] if "quantizer" in v else 0 + ) + + if space_metric is None: + space_metric = lambda v: results["indices"][v["codec"]]["code_size"] + + fe = [] + ops = {} + if pareto_mode == ParetoMode.GLOBAL: + op = OperatingPoints() + ops["global"] = op + for k, v in results["experiments"].items(): + if f".{evaluation}" in k: + accuracy = accuracy_metric(v) + if min_accuracy > 0 and accuracy < min_accuracy: + continue + space = space_metric(v) + if space is None: + space = 0 + if max_space > 0 and space > max_space: + continue + time = time_metric(v) + if max_time > 0 and time > max_time: + continue + idx_name = v["index"] + ( + "snap" + if "search_params" in v and v["search_params"]["snap"] == 1 + else "" + ) + if name_filter is not None and not name_filter(idx_name): + continue + experiment = (accuracy, space, time, k, v) + if pareto_mode == ParetoMode.DISABLE: + fe.append(experiment) + continue + if pareto_mode == ParetoMode.INDEX: + if idx_name not in ops: + ops[idx_name] = OperatingPoints() + op = ops[idx_name] + if pareto_metric == ParetoMetric.TIME: + op.add_operating_point(experiment, accuracy, time) + elif pareto_metric == ParetoMetric.SPACE: + op.add_operating_point(experiment, accuracy, space) + else: + op.add_operating_point( + experiment, accuracy, Cost([time, space]) + ) + + if ops: + for op in ops.values(): + for v, _, _ in op.operating_points: + fe.append(v) + + fe.sort() + return fe diff --git a/benchs/bench_fw_notebook.ipynb b/benchs/bench_fw_notebook.ipynb index c6183a8eb9..5752aaf5fb 100644 --- a/benchs/bench_fw_notebook.ipynb +++ b/benchs/bench_fw_notebook.ipynb @@ -1,617 +1,529 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "be081589-e1b2-4569-acb7-44203e273899", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "import itertools\n", - "from faiss.contrib.evaluation import OperatingPoints\n", - "from enum import Enum\n", - "from bench_fw.benchmark_io import BenchmarkIO as BIO\n", - "from copy import copy\n", - "import numpy as np\n", - "import datetime\n", - "import glob\n", - "import io\n", - "import json\n", - "from zipfile import ZipFile" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a6492e95-24c7-4425-bf0a-27e10e879ca6", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "root = \"/checkpoint/gsz/bench_fw/range/ssnpp\"\n", - "results = BIO(root).read_json(\"result.json\")\n", - "results.keys()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0875d269-aef4-426d-83dd-866970f43777", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "results['experiments']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a7ff7078-29c7-407c-a079-201877b764ad", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "class Cost:\n", - " def __init__(self, values):\n", - " self.values = values\n", - "\n", - " def __le__(self, other):\n", - " return all(v1 <= v2 for v1, v2 in zip(self.values, other.values, strict=True))\n", - "\n", - " def __lt__(self, other):\n", - " return all(v1 < v2 for v1, v2 in zip(self.values, other.values, strict=True))\n", - "\n", - "class ParetoMode(Enum):\n", - " DISABLE = 1 # no Pareto filtering\n", - " INDEX = 2 # index-local optima\n", - " GLOBAL = 3 # global optima\n", - "\n", - "\n", - "class ParetoMetric(Enum):\n", - " TIME = 0 # time vs accuracy\n", - " SPACE = 1 # space vs accuracy\n", - " TIME_SPACE = 2 # (time, space) vs accuracy\n", - "\n", - "def range_search_recall_at_precision(experiment, precision):\n", - " return round(max(r for r, p in zip(experiment['range_search_pr']['recall'], experiment['range_search_pr']['precision']) if p > precision), 6)\n", - "\n", - "def filter_results(\n", - " results,\n", - " evaluation,\n", - " accuracy_metric, # str or func\n", - " time_metric=None, # func or None -> use default\n", - " space_metric=None, # func or None -> use default\n", - " min_accuracy=0,\n", - " max_space=0,\n", - " max_time=0,\n", - " scaling_factor=1.0,\n", - " \n", - " pareto_mode=ParetoMode.DISABLE,\n", - " pareto_metric=ParetoMetric.TIME,\n", - "):\n", - " if isinstance(accuracy_metric, str):\n", - " accuracy_key = accuracy_metric\n", - " accuracy_metric = lambda v: v[accuracy_key]\n", - "\n", - " if time_metric is None:\n", - " time_metric = lambda v: v['time'] * scaling_factor + (v['quantizer']['time'] if 'quantizer' in v else 0)\n", - "\n", - " if space_metric is None:\n", - " space_metric = lambda v: results['indices'][v['codec']]['sa_code_size']\n", - " \n", - " fe = []\n", - " ops = {}\n", - " if pareto_mode == ParetoMode.GLOBAL:\n", - " op = OperatingPoints()\n", - " ops[\"global\"] = op\n", - " for k, v in results['experiments'].items():\n", - " if f\".{evaluation}\" in k:\n", - " accuracy = accuracy_metric(v)\n", - " if min_accuracy > 0 and accuracy < min_accuracy:\n", - " continue\n", - " space = space_metric(v)\n", - " if space is None:\n", - " space = 0 \n", - " if max_space > 0 and space > max_space:\n", - " continue\n", - " time = time_metric(v)\n", - " if max_time > 0 and time > max_time:\n", - " continue\n", - " idx_name = v['index'] + (\"snap\" if 'search_params' in v and v['search_params'][\"snap\"] == 1 else \"\")\n", - " # if idx_name.startswith(\"HNSW\"):\n", - " # continue\n", - " experiment = (accuracy, space, time, k, v)\n", - " if pareto_mode == ParetoMode.DISABLE:\n", - " fe.append(experiment)\n", - " continue\n", - " if pareto_mode == ParetoMode.INDEX:\n", - " if idx_name not in ops:\n", - " ops[idx_name] = OperatingPoints()\n", - " op = ops[idx_name]\n", - " if pareto_metric == ParetoMetric.TIME:\n", - " op.add_operating_point(experiment, accuracy, time)\n", - " elif pareto_metric == ParetoMetric.SPACE:\n", - " op.add_operating_point(experiment, accuracy, space)\n", - " else:\n", - " op.add_operating_point(experiment, accuracy, Cost([time, space]))\n", - "\n", - " if ops:\n", - " for op in ops.values():\n", - " for v, _, _ in op.operating_points:\n", - " fe.append(v)\n", - "\n", - " fe.sort()\n", - " return fe" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f080a6e2-1565-418b-8732-4adeff03a099", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "def plot_metric(experiments, accuracy_title, cost_title, plot_space=False, plot=None):\n", - " if plot is None:\n", - " plot = plt.subplot()\n", - " x = {}\n", - " y = {}\n", - " for accuracy, space, time, k, v in experiments:\n", - " idx_name = v['index'] + (\"snap\" if 'search_params' in v and v['search_params'][\"snap\"] == 1 else \"\")\n", - " if idx_name not in x:\n", - " x[idx_name] = []\n", - " y[idx_name] = []\n", - " x[idx_name].append(accuracy)\n", - " if plot_space:\n", - " y[idx_name].append(space)\n", - " else:\n", - " y[idx_name].append(time)\n", - "\n", - " #plt.figure(figsize=(10,6))\n", - " #plt.title(accuracy_title)\n", - " plot.set_xlabel(accuracy_title)\n", - " plot.set_ylabel(cost_title)\n", - " marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n", - " for index in x.keys():\n", - " plot.plot(x[index], y[index], marker=next(marker), label=index, linewidth=0)\n", - " plot.legend(bbox_to_anchor=(1, 1), loc='upper left')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "61007155-5edc-449e-835e-c141a01a2ae5", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# index local optima\n", - "accuracy_metric = \"knn_intersection\"\n", - "fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 32 cores)\", plot_space=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f9f94dcc-5abe-4cad-9619-f5d1d24fb8c1", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# global optima\n", - "accuracy_metric = \"knn_intersection\"\n", - "fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.5, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 32 cores)\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "36e82084-18f6-4546-a717-163eb0224ee8", - "metadata": {}, - "outputs": [], - "source": [ - "# index local optima @ precision 0.8\n", - "precision = 0.8\n", - "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", - "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aff79376-39f7-47c0-8b83-1efe5192bb7e", - "metadata": {}, - "outputs": [], - "source": [ - "# index local optima @ precision 0.2\n", - "precision = 0.2\n", - "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", - "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b4834f1f-bbbe-4cae-9aa0-a459b0c842d1", - "metadata": {}, - "outputs": [], - "source": [ - "# global optima @ precision 0.8\n", - "precision = 0.8\n", - "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", - "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9aead830-6209-4956-b7ea-4a5e0029d616", - "metadata": {}, - "outputs": [], - "source": [ - "def plot_range_search_pr_curves(experiments):\n", - " x = {}\n", - " y = {}\n", - " show = {\n", - " 'Flat': None,\n", - " }\n", - " for _, _, _, k, v in fr:\n", - " if \".weighted\" in k: # and v['index'] in show:\n", - " x[k] = v['range_search_pr']['recall']\n", - " y[k] = v['range_search_pr']['precision']\n", - " \n", - " plt.title(\"range search recall\")\n", - " plt.xlabel(\"recall\")\n", - " plt.ylabel(\"precision\")\n", - " for index in x.keys():\n", - " plt.plot(x[index], y[index], '.', label=index)\n", - " plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "92e45502-7a31-4a15-90df-fa3032d7d350", - "metadata": {}, - "outputs": [], - "source": [ - "precision = 0.8\n", - "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", - "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME_SPACE, scaling_factor=1)\n", - "plot_range_search_pr_curves(fr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fdf8148a-0da6-4c5e-8d60-f8f85314574c", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "root = \"/checkpoint/gsz/bench_fw/ivf/bigann\"\n", - "scales = [1, 2, 5, 10, 20, 50]\n", - "fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))\n", - "fig.tight_layout()\n", - "for plot, scale in zip(plots, scales, strict=True):\n", - " results = BIO(root).read_json(f\"result{scale}.json\")\n", - " accuracy_metric = \"knn_intersection\"\n", - " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - " plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 64 cores)\", plot=plot)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e503828c-ee61-45f7-814b-cce6461109bc", - "metadata": {}, - "outputs": [], - "source": [ - "x = {}\n", - "y = {}\n", - "accuracy=0.9\n", - "root = \"/checkpoint/gsz/bench_fw/ivf/bigann\"\n", - "scales = [1, 2, 5, 10, 20, 50]\n", - "#fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))\n", - "#fig.tight_layout()\n", - "for scale in scales:\n", - " results = BIO(root).read_json(f\"result{scale}.json\")\n", - " scale *= 1_000_000\n", - " accuracy_metric = \"knn_intersection\"\n", - " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=accuracy, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - " seen = set()\n", - " print(scale)\n", - " for _, _, _, _, exp in fr:\n", - " fact = exp[\"factory\"]\n", - " # \"HNSW\" in fact or \n", - " if fact in seen or fact in [\"Flat\", \"IVF512,Flat\", \"IVF1024,Flat\", \"IVF2048,Flat\"]:\n", - " continue\n", - " seen.add(fact)\n", - " if fact not in x:\n", - " x[fact] = []\n", - " y[fact] = []\n", - " x[fact].append(scale)\n", - " y[fact].append(exp[\"time\"] + exp[\"quantizer\"][\"time\"])\n", - " if (exp[\"knn_intersection\"] > 0.92):\n", - " print(fact)\n", - " print(exp[\"search_params\"])\n", - " print(exp[\"knn_intersection\"])\n", - "\n", - " #plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 64 cores)\", plot=plot)\n", - " \n", - "plt.title(f\"recall @ 1 = {accuracy*100}%\")\n", - "plt.xlabel(\"database size\")\n", - "plt.ylabel(\"time\")\n", - "plt.xscale(\"log\")\n", - "plt.yscale(\"log\")\n", - "\n", - "marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n", - "for index in x.keys():\n", - " if \"HNSW\" in index:\n", - " plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker), linestyle=\"dashed\")\n", - " else:\n", - " plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker))\n", - "plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "37a99bb2-f998-461b-a345-7cc6e702cb3a", - "metadata": {}, - "outputs": [], - "source": [ - "# global optima\n", - "accuracy_metric = \"sym_recall\"\n", - "fr = filter_results(results, evaluation=\"rec\", accuracy_metric=accuracy_metric, time_metric=lambda e:e['encode_time'], min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.SPACE, scaling_factor=1)\n", - "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"space\", plot_space=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c973ce4e-3566-4f02-bd93-f113e3e0c791", - "metadata": {}, - "outputs": [], - "source": [ - "def pretty_time(s):\n", - " if s is None:\n", - " return \"None\"\n", - " s = int(s * 1000) / 1000\n", - " m, s = divmod(s, 60)\n", - " h, m = divmod(m, 60)\n", - " d, h = divmod(h, 24)\n", - " r = \"\"\n", - " if d > 0:\n", - " r += f\"{int(d)}d \"\n", - " if h > 0:\n", - " r += f\"{int(h)}h \"\n", - " if m > 0:\n", - " r += f\"{int(m)}m \"\n", - " if s > 0 or len(r) == 0:\n", - " r += f\"{s:.3f}s\"\n", - " return r\n", - "\n", - "def pretty_size(s):\n", - " if s > 1024 * 1024:\n", - " return f\"{s / 1024 / 1024:.1f}\".rstrip('0').rstrip('.') + \"MB\"\n", - " if s > 1024:\n", - " return f\"{s / 1024:.1f}\".rstrip('0').rstrip('.') + \"KB\"\n", - " return f\"{s}\"\n", - "\n", - "def pretty_mse(m):\n", - " if m is None:\n", - " return \"None\"\n", - " else:\n", - " return f\"{m:.6f}\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1ddcf226-fb97-4a59-9fc3-3ed8f7d5e703", - "metadata": {}, - "outputs": [], - "source": [ - "data = {}\n", - "root = \"/checkpoint/gsz/bench_fw/bigann\"\n", - "scales = [1, 2, 5, 10, 20, 50]\n", - "for scale in scales:\n", - " results = BIO(root).read_json(f\"result{scale}.json\")\n", - " accuracy_metric = \"knn_intersection\"\n", - " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - " d = {}\n", - " data[f\"{scale}M\"] = d\n", - " for _, _, _, _, exp in fr:\n", - " fact = exp[\"factory\"]\n", - " # \"HNSW\" in fact or \n", - " if fact in [\"Flat\", \"IVF512,Flat\", \"IVF1024,Flat\", \"IVF2048,Flat\"]:\n", - " continue\n", - " if fact not in d:\n", - " d[fact] = []\n", - " d[fact].append({\n", - " \"nprobe\": exp[\"search_params\"][\"nprobe\"],\n", - " \"recall\": exp[\"knn_intersection\"],\n", - " \"time\": exp[\"time\"] + exp[\"quantizer\"][\"time\"],\n", - " })\n", - "data\n", - "# with open(\"/checkpoint/gsz/bench_fw/codecs.json\", \"w\") as f:\n", - "# json.dump(data, f)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e54eebb6-0a9f-4a72-84d2-f12c5bd44510", - "metadata": {}, - "outputs": [], - "source": [ - "ds = \"deep1b\"\n", - "data = []\n", - "jss = []\n", - "root = f\"/checkpoint/gsz/bench_fw/codecs/{ds}\"\n", - "results = BIO(root).read_json(f\"result.json\")\n", - "for k, e in results[\"experiments\"].items():\n", - " if \"rec\" in k and e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and \"PRQ\" in e['factory'] and e['sym_recall'] > 0.0:\n", - " code_size = results['indices'][e['codec']]['sa_code_size']\n", - " codec_size = results['indices'][e['codec']]['codec_size']\n", - " training_time = results['indices'][e['codec']]['training_time']\n", - " # training_size = results['indices'][e['codec']]['training_size']\n", - " cpu = e['cpu'] if 'cpu' in e else \"\"\n", - " ps = ', '.join([f\"{k}={v}\" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else \" \"\n", - " eps = ', '.join([f\"{k}={v}\" for k,v in e['reconstruct_params'].items() if k != \"snap\"]) if e['reconstruct_params'] else \" \"\n", - " data.append((code_size, f\"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{training_size}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|\"))\n", - " jss.append({\n", - " 'factory': e['factory'],\n", - " 'parameters': e['construction_params'][0] if e['construction_params'] else \"\",\n", - " 'evaluation_params': e['reconstruct_params'],\n", - " 'code_size': code_size,\n", - " 'codec_size': codec_size,\n", - " 'training_time': training_time,\n", - " 'training_size': training_size,\n", - " 'mse': e['mse'],\n", - " 'sym_recall': e['sym_recall'],\n", - " 'asym_recall': e['asym_recall'],\n", - " 'encode_time': e['encode_time'],\n", - " 'decode_time': e['decode_time'],\n", - " 'cpu': cpu,\n", - " })\n", - "\n", - "print(\"|factory key|construction parameters|evaluation parameters|code size|codec size|training time|training size|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|\")\n", - "print(\"|-|-|-|-|-|-|-|-|-|\")\n", - "data.sort()\n", - "for d in data:\n", - " print(d[1])\n", - "\n", - "with open(f\"/checkpoint/gsz/bench_fw/codecs_{ds}_test.json\", \"w\") as f:\n", - " json.dump(jss, f)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d1216733-9670-407c-b3d2-5f87bce0321c", - "metadata": {}, - "outputs": [], - "source": [ - "def read_file(filename: str, keys):\n", - " results = []\n", - " with ZipFile(filename, \"r\") as zip_file:\n", - " for key in keys:\n", - " with zip_file.open(key, \"r\") as f:\n", - " if key in [\"D\", \"I\", \"R\", \"lims\"]:\n", - " results.append(np.load(f))\n", - " elif key in [\"P\"]:\n", - " t = io.TextIOWrapper(f)\n", - " results.append(json.load(t))\n", - " else:\n", - " raise AssertionError()\n", - " return results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "56de051e-22db-4bef-b242-1ddabc9e0bb9", - "metadata": {}, - "outputs": [], - "source": [ - "ds = \"contriever\"\n", - "data = []\n", - "jss = []\n", - "root = f\"/checkpoint/gsz/bench_fw/codecs/{ds}\"\n", - "for lf in glob.glob(root + '/*rec*.zip'):\n", - " e, = read_file(lf, ['P'])\n", - " if e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and \"PRQ\" in e['factory'] and e['sym_recall'] > 0.0:\n", - " code_size = e['codec_meta']['sa_code_size']\n", - " codec_size = e['codec_meta']['codec_size']\n", - " training_time = e['codec_meta']['training_time']\n", - " training_size = None # e['codec_meta']['training_size']\n", - " cpu = e['cpu'] if 'cpu' in e else \"\"\n", - " ps = ', '.join([f\"{k}={v}\" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else \" \"\n", - " eps = ', '.join([f\"{k}={v}\" for k,v in e['reconstruct_params'].items() if k != \"snap\"]) if e['reconstruct_params'] else \" \"\n", - " if eps in ps and eps != \"encode_ils_iters=16\" and eps != \"max_beam_size=32\":\n", - " eps = \" \"\n", - " data.append((code_size, f\"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|\"))\n", - " eps = e['reconstruct_params']\n", - " del eps['snap']\n", - " params = copy(e['construction_params'][0]) if e['construction_params'] else {}\n", - " for k, v in e['reconstruct_params'].items():\n", - " params[k] = v\n", - " jss.append({\n", - " 'factory': e['factory'],\n", - " 'params': params,\n", - " 'construction_params': e['construction_params'][0] if e['construction_params'] else {},\n", - " 'evaluation_params': e['reconstruct_params'],\n", - " 'code_size': code_size,\n", - " 'codec_size': codec_size,\n", - " 'training_time': training_time,\n", - " # 'training_size': training_size,\n", - " 'mse': e['mse'],\n", - " 'sym_recall': e['sym_recall'],\n", - " 'asym_recall': e['asym_recall'],\n", - " 'encode_time': e['encode_time'],\n", - " 'decode_time': e['decode_time'],\n", - " 'cpu': cpu,\n", - " })\n", - "\n", - "print(\"|factory key|construction parameters|encode/decode parameters|code size|codec size|training time|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|\")\n", - "print(\"|-|-|-|-|-|-|-|-|-|\")\n", - "data.sort()\n", - "# for d in data:\n", - "# print(d[1])\n", - "\n", - "print(len(data))\n", - "\n", - "with open(f\"/checkpoint/gsz/bench_fw/codecs_{ds}_5.json\", \"w\") as f:\n", - " json.dump(jss, f)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2fd712bf-f147-4c1b-9dbf-b04428e4c1eb", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python [conda env:.conda-faiss_from_source] *", - "language": "python", - "name": "conda-env-.conda-faiss_from_source-py" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "be081589-e1b2-4569-acb7-44203e273899", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import itertools\n", + "from faiss.contrib.evaluation import OperatingPoints\n", + "from enum import Enum\n", + "from bench_fw.benchmark_io import BenchmarkIO as BIO\n", + "from bench_fw.utils import filter_results, ParetoMode, ParetoMetric\n", + "from copy import copy\n", + "import numpy as np\n", + "import datetime\n", + "import glob\n", + "import io\n", + "import json\n", + "from zipfile import ZipFile\n", + "import tabulate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6492e95-24c7-4425-bf0a-27e10e879ca6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "root = \"/checkpoint/gsz/bench_fw/optimize/bigann\"\n", + "results = BIO(root).read_json(\"result_std_d_bigann10M.json\")\n", + "results.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0875d269-aef4-426d-83dd-866970f43777", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "results['experiments']" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f080a6e2-1565-418b-8732-4adeff03a099", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def plot_metric(experiments, accuracy_title, cost_title, plot_space=False, plot=None):\n", + " if plot is None:\n", + " plot = plt.subplot()\n", + " x = {}\n", + " y = {}\n", + " for accuracy, space, time, k, v in experiments:\n", + " idx_name = v['index'] + (\"snap\" if 'search_params' in v and v['search_params'][\"snap\"] == 1 else \"\")\n", + " if idx_name not in x:\n", + " x[idx_name] = []\n", + " y[idx_name] = []\n", + " x[idx_name].append(accuracy)\n", + " if plot_space:\n", + " y[idx_name].append(space)\n", + " else:\n", + " y[idx_name].append(time)\n", + "\n", + " #plt.figure(figsize=(10,6))\n", + " #plt.title(accuracy_title)\n", + " plot.set_xlabel(accuracy_title)\n", + " plot.set_ylabel(cost_title)\n", + " plot.set_yscale(\"log\")\n", + " marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n", + " for index in x.keys():\n", + " plot.plot(x[index], y[index], marker=next(marker), label=index, linewidth=0)\n", + " plot.legend(bbox_to_anchor=(1, 1), loc='upper left')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61007155-5edc-449e-835e-c141a01a2ae5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# index local optima\n", + "accuracy_metric = \"knn_intersection\"\n", + "fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1, min_accuracy=0.95)\n", + "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 32 cores)\", plot_space=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9f94dcc-5abe-4cad-9619-f5d1d24fb8c1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# global optima\n", + "accuracy_metric = \"knn_intersection\"\n", + "fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.90, max_space=64, max_time=0, name_filter=lambda n: not n.startswith(\"Flat\"), pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 32 cores)\", plot_space=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c10f587-26ef-49ec-83a9-88f6a2a433e8", + "metadata": {}, + "outputs": [], + "source": [ + "def pretty_params(p):\n", + " p = copy(p)\n", + " if 'snap' in p and p['snap'] == 0:\n", + " del p['snap']\n", + " return p\n", + " \n", + "tabulate.tabulate([(accuracy, space, time, v['factory'], pretty_params(v['construction_params'][1]), pretty_params(v['search_params'])) \n", + " for accuracy, space, time, k, v in fr],\n", + " tablefmt=\"html\",\n", + " headers=[\"accuracy\",\"space\", \"time\", \"factory\", \"quantizer cfg\", \"search cfg\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36e82084-18f6-4546-a717-163eb0224ee8", + "metadata": {}, + "outputs": [], + "source": [ + "# index local optima @ precision 0.8\n", + "precision = 0.8\n", + "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", + "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aff79376-39f7-47c0-8b83-1efe5192bb7e", + "metadata": {}, + "outputs": [], + "source": [ + "# index local optima @ precision 0.2\n", + "precision = 0.2\n", + "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", + "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4834f1f-bbbe-4cae-9aa0-a459b0c842d1", + "metadata": {}, + "outputs": [], + "source": [ + "# global optima @ precision 0.8\n", + "precision = 0.8\n", + "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", + "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9aead830-6209-4956-b7ea-4a5e0029d616", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_range_search_pr_curves(experiments):\n", + " x = {}\n", + " y = {}\n", + " show = {\n", + " 'Flat': None,\n", + " }\n", + " for _, _, _, k, v in fr:\n", + " if \".weighted\" in k: # and v['index'] in show:\n", + " x[k] = v['range_search_pr']['recall']\n", + " y[k] = v['range_search_pr']['precision']\n", + " \n", + " plt.title(\"range search recall\")\n", + " plt.xlabel(\"recall\")\n", + " plt.ylabel(\"precision\")\n", + " for index in x.keys():\n", + " plt.plot(x[index], y[index], '.', label=index)\n", + " plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92e45502-7a31-4a15-90df-fa3032d7d350", + "metadata": {}, + "outputs": [], + "source": [ + "precision = 0.8\n", + "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", + "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME_SPACE, scaling_factor=1)\n", + "plot_range_search_pr_curves(fr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdf8148a-0da6-4c5e-8d60-f8f85314574c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "root = \"/checkpoint/gsz/bench_fw/ivf/bigann\"\n", + "scales = [1, 2, 5, 10, 20, 50]\n", + "fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))\n", + "fig.tight_layout()\n", + "for plot, scale in zip(plots, scales, strict=True):\n", + " results = BIO(root).read_json(f\"result{scale}.json\")\n", + " accuracy_metric = \"knn_intersection\"\n", + " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + " plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 64 cores)\", plot=plot)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e503828c-ee61-45f7-814b-cce6461109bc", + "metadata": {}, + "outputs": [], + "source": [ + "x = {}\n", + "y = {}\n", + "accuracy=0.9\n", + "root = \"/checkpoint/gsz/bench_fw/ivf/bigann\"\n", + "scales = [1, 2, 5, 10, 20, 50]\n", + "#fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))\n", + "#fig.tight_layout()\n", + "for scale in scales:\n", + " results = BIO(root).read_json(f\"result{scale}.json\")\n", + " scale *= 1_000_000\n", + " accuracy_metric = \"knn_intersection\"\n", + " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=accuracy, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + " seen = set()\n", + " print(scale)\n", + " for _, _, _, _, exp in fr:\n", + " fact = exp[\"factory\"]\n", + " # \"HNSW\" in fact or \n", + " if fact in seen or fact in [\"Flat\", \"IVF512,Flat\", \"IVF1024,Flat\", \"IVF2048,Flat\"]:\n", + " continue\n", + " seen.add(fact)\n", + " if fact not in x:\n", + " x[fact] = []\n", + " y[fact] = []\n", + " x[fact].append(scale)\n", + " y[fact].append(exp[\"time\"] + exp[\"quantizer\"][\"time\"])\n", + " if (exp[\"knn_intersection\"] > 0.92):\n", + " print(fact)\n", + " print(exp[\"search_params\"])\n", + " print(exp[\"knn_intersection\"])\n", + "\n", + " #plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 64 cores)\", plot=plot)\n", + " \n", + "plt.title(f\"recall @ 1 = {accuracy*100}%\")\n", + "plt.xlabel(\"database size\")\n", + "plt.ylabel(\"time\")\n", + "plt.xscale(\"log\")\n", + "plt.yscale(\"log\")\n", + "\n", + "marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n", + "for index in x.keys():\n", + " if \"HNSW\" in index:\n", + " plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker), linestyle=\"dashed\")\n", + " else:\n", + " plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker))\n", + "plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37a99bb2-f998-461b-a345-7cc6e702cb3a", + "metadata": {}, + "outputs": [], + "source": [ + "# global optima\n", + "accuracy_metric = \"sym_recall\"\n", + "fr = filter_results(results, evaluation=\"rec\", accuracy_metric=accuracy_metric, time_metric=lambda e:e['encode_time'], min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.SPACE, scaling_factor=1)\n", + "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"space\", plot_space=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c973ce4e-3566-4f02-bd93-f113e3e0c791", + "metadata": {}, + "outputs": [], + "source": [ + "def pretty_time(s):\n", + " if s is None:\n", + " return \"None\"\n", + " s = int(s * 1000) / 1000\n", + " m, s = divmod(s, 60)\n", + " h, m = divmod(m, 60)\n", + " d, h = divmod(h, 24)\n", + " r = \"\"\n", + " if d > 0:\n", + " r += f\"{int(d)}d \"\n", + " if h > 0:\n", + " r += f\"{int(h)}h \"\n", + " if m > 0:\n", + " r += f\"{int(m)}m \"\n", + " if s > 0 or len(r) == 0:\n", + " r += f\"{s:.3f}s\"\n", + " return r\n", + "\n", + "def pretty_size(s):\n", + " if s > 1024 * 1024:\n", + " return f\"{s / 1024 / 1024:.1f}\".rstrip('0').rstrip('.') + \"MB\"\n", + " if s > 1024:\n", + " return f\"{s / 1024:.1f}\".rstrip('0').rstrip('.') + \"KB\"\n", + " return f\"{s}\"\n", + "\n", + "def pretty_mse(m):\n", + " if m is None:\n", + " return \"None\"\n", + " else:\n", + " return f\"{m:.6f}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ddcf226-fb97-4a59-9fc3-3ed8f7d5e703", + "metadata": {}, + "outputs": [], + "source": [ + "data = {}\n", + "root = \"/checkpoint/gsz/bench_fw/bigann\"\n", + "scales = [1, 2, 5, 10, 20, 50]\n", + "for scale in scales:\n", + " results = BIO(root).read_json(f\"result{scale}.json\")\n", + " accuracy_metric = \"knn_intersection\"\n", + " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + " d = {}\n", + " data[f\"{scale}M\"] = d\n", + " for _, _, _, _, exp in fr:\n", + " fact = exp[\"factory\"]\n", + " # \"HNSW\" in fact or \n", + " if fact in [\"Flat\", \"IVF512,Flat\", \"IVF1024,Flat\", \"IVF2048,Flat\"]:\n", + " continue\n", + " if fact not in d:\n", + " d[fact] = []\n", + " d[fact].append({\n", + " \"nprobe\": exp[\"search_params\"][\"nprobe\"],\n", + " \"recall\": exp[\"knn_intersection\"],\n", + " \"time\": exp[\"time\"] + exp[\"quantizer\"][\"time\"],\n", + " })\n", + "data\n", + "# with open(\"/checkpoint/gsz/bench_fw/codecs.json\", \"w\") as f:\n", + "# json.dump(data, f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e54eebb6-0a9f-4a72-84d2-f12c5bd44510", + "metadata": {}, + "outputs": [], + "source": [ + "ds = \"deep1b\"\n", + "data = []\n", + "jss = []\n", + "root = f\"/checkpoint/gsz/bench_fw/codecs/{ds}\"\n", + "results = BIO(root).read_json(f\"result.json\")\n", + "for k, e in results[\"experiments\"].items():\n", + " if \"rec\" in k and e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and \"PRQ\" in e['factory'] and e['sym_recall'] > 0.0:\n", + " code_size = results['indices'][e['codec']]['sa_code_size']\n", + " codec_size = results['indices'][e['codec']]['codec_size']\n", + " training_time = results['indices'][e['codec']]['training_time']\n", + " # training_size = results['indices'][e['codec']]['training_size']\n", + " cpu = e['cpu'] if 'cpu' in e else \"\"\n", + " ps = ', '.join([f\"{k}={v}\" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else \" \"\n", + " eps = ', '.join([f\"{k}={v}\" for k,v in e['reconstruct_params'].items() if k != \"snap\"]) if e['reconstruct_params'] else \" \"\n", + " data.append((code_size, f\"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{training_size}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|\"))\n", + " jss.append({\n", + " 'factory': e['factory'],\n", + " 'parameters': e['construction_params'][0] if e['construction_params'] else \"\",\n", + " 'evaluation_params': e['reconstruct_params'],\n", + " 'code_size': code_size,\n", + " 'codec_size': codec_size,\n", + " 'training_time': training_time,\n", + " 'training_size': training_size,\n", + " 'mse': e['mse'],\n", + " 'sym_recall': e['sym_recall'],\n", + " 'asym_recall': e['asym_recall'],\n", + " 'encode_time': e['encode_time'],\n", + " 'decode_time': e['decode_time'],\n", + " 'cpu': cpu,\n", + " })\n", + "\n", + "print(\"|factory key|construction parameters|evaluation parameters|code size|codec size|training time|training size|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|\")\n", + "print(\"|-|-|-|-|-|-|-|-|-|\")\n", + "data.sort()\n", + "for d in data:\n", + " print(d[1])\n", + "\n", + "with open(f\"/checkpoint/gsz/bench_fw/codecs_{ds}_test.json\", \"w\") as f:\n", + " json.dump(jss, f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1216733-9670-407c-b3d2-5f87bce0321c", + "metadata": {}, + "outputs": [], + "source": [ + "def read_file(filename: str, keys):\n", + " results = []\n", + " with ZipFile(filename, \"r\") as zip_file:\n", + " for key in keys:\n", + " with zip_file.open(key, \"r\") as f:\n", + " if key in [\"D\", \"I\", \"R\", \"lims\"]:\n", + " results.append(np.load(f))\n", + " elif key in [\"P\"]:\n", + " t = io.TextIOWrapper(f)\n", + " results.append(json.load(t))\n", + " else:\n", + " raise AssertionError()\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56de051e-22db-4bef-b242-1ddabc9e0bb9", + "metadata": {}, + "outputs": [], + "source": [ + "ds = \"contriever\"\n", + "data = []\n", + "jss = []\n", + "root = f\"/checkpoint/gsz/bench_fw/codecs/{ds}\"\n", + "for lf in glob.glob(root + '/*rec*.zip'):\n", + " e, = read_file(lf, ['P'])\n", + " if e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and \"PRQ\" in e['factory'] and e['sym_recall'] > 0.0:\n", + " code_size = e['codec_meta']['sa_code_size']\n", + " codec_size = e['codec_meta']['codec_size']\n", + " training_time = e['codec_meta']['training_time']\n", + " training_size = None # e['codec_meta']['training_size']\n", + " cpu = e['cpu'] if 'cpu' in e else \"\"\n", + " ps = ', '.join([f\"{k}={v}\" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else \" \"\n", + " eps = ', '.join([f\"{k}={v}\" for k,v in e['reconstruct_params'].items() if k != \"snap\"]) if e['reconstruct_params'] else \" \"\n", + " if eps in ps and eps != \"encode_ils_iters=16\" and eps != \"max_beam_size=32\":\n", + " eps = \" \"\n", + " data.append((code_size, f\"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|\"))\n", + " eps = e['reconstruct_params']\n", + " del eps['snap']\n", + " params = copy(e['construction_params'][0]) if e['construction_params'] else {}\n", + " for k, v in e['reconstruct_params'].items():\n", + " params[k] = v\n", + " jss.append({\n", + " 'factory': e['factory'],\n", + " 'params': params,\n", + " 'construction_params': e['construction_params'][0] if e['construction_params'] else {},\n", + " 'evaluation_params': e['reconstruct_params'],\n", + " 'code_size': code_size,\n", + " 'codec_size': codec_size,\n", + " 'training_time': training_time,\n", + " # 'training_size': training_size,\n", + " 'mse': e['mse'],\n", + " 'sym_recall': e['sym_recall'],\n", + " 'asym_recall': e['asym_recall'],\n", + " 'encode_time': e['encode_time'],\n", + " 'decode_time': e['decode_time'],\n", + " 'cpu': cpu,\n", + " })\n", + "\n", + "print(\"|factory key|construction parameters|encode/decode parameters|code size|codec size|training time|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|\")\n", + "print(\"|-|-|-|-|-|-|-|-|-|\")\n", + "data.sort()\n", + "# for d in data:\n", + "# print(d[1])\n", + "\n", + "print(len(data))\n", + "\n", + "with open(f\"/checkpoint/gsz/bench_fw/codecs_{ds}_5.json\", \"w\") as f:\n", + " json.dump(jss, f)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:.conda-faiss_from_source] *", + "language": "python", + "name": "conda-env-.conda-faiss_from_source-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 + } diff --git a/benchs/bench_fw_optimize.py b/benchs/bench_fw_optimize.py new file mode 100644 index 0000000000..31b56f9f51 --- /dev/null +++ b/benchs/bench_fw_optimize.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os + +from bench_fw.benchmark_io import BenchmarkIO +from bench_fw.descriptors import DatasetDescriptor +from bench_fw.optimize import Optimizer + +logging.basicConfig(level=logging.INFO) + + +def bigann(bio): + optimizer = Optimizer( + distance_metric="L2", + num_threads=32, + run_local=False, + ) + optimizer.set_io(bio) + query_vectors = DatasetDescriptor(namespace="std_q", tablename="bigann1M") + xt = bio.get_dataset(query_vectors) + optimizer.optimize( + d=xt.shape[1], + training_vectors=DatasetDescriptor( + namespace="std_t", + tablename="bigann1M", + num_vectors=2_000_000, + ), + database_vectors_list=[ + DatasetDescriptor( + namespace="std_d", + tablename="bigann1M", + ), + DatasetDescriptor(namespace="std_d", tablename="bigann10M"), + ], + query_vectors=query_vectors, + min_accuracy=0.85, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("experiment") + parser.add_argument("path") + args = parser.parse_args() + assert os.path.exists(args.path) + path = os.path.join(args.path, args.experiment) + if not os.path.exists(path): + os.mkdir(path) + bio = BenchmarkIO( + path=path, + ) + if args.experiment == "bigann": + bigann(bio)