Skip to content

Commit

Permalink
Merge commit c3b9374984208f37484fb7b86c44345729592835 from Faiss mast…
Browse files Browse the repository at this point in the history
…er branch

Commit: c3b9374984208f37484fb7b86c44345729592835
Parents: 0a00d8137a386a0efd7f789e3e0912ab4eb73508
Author: Gergely Szilvasy <[email protected]>
Committer: Facebook GitHub Bot <[email protected]>
Date: Fri Oct 20 2023 10:53:56 GMT-0400 (Eastern Daylight Time)

bench_fw - fixes & nits for oss (#3102)

Summary: Pull Request resolved: facebookresearch/faiss#3102

Reviewed By: pemazare

Differential Revision: D50426528

Pulled By: algoriddle

fbshipit-source-id: 886960b8b522318967fc5ec305666871b496cae8
Signed-off-by: Alexandr Guzhva <[email protected]>
  • Loading branch information
alexanderguzhva committed Nov 27, 2023
1 parent 8c8222e commit 217ce0c
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 37 deletions.
72 changes: 42 additions & 30 deletions thirdparty/faiss/benchs/bench_fw/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

from contextlib import contextmanager
import json
import logging
import time
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
from operator import itemgetter
from statistics import median, mean
from time import perf_counter
from typing import Any, List, Optional
from .descriptors import DatasetDescriptor, IndexDescriptor

Expand All @@ -26,6 +27,15 @@
logger = logging.getLogger(__name__)


@contextmanager
def timer(name) -> float:
logger.info(f"Measuring {name}")
t1 = t2 = perf_counter()
yield lambda: t2 - t1
t2 = perf_counter()
logger.info(f"Time for {name}: {t2 - t1:.3f} seconds")


def refine_distances_knn(
D: np.ndarray, I: np.ndarray, xq: np.ndarray, xb: np.ndarray, metric
):
Expand Down Expand Up @@ -77,7 +87,7 @@ def range_search_pr_curve(
tbl = np.vstack(
[dist_ann, metric_score, cum_score, precision, recall, unique_key]
)
group_by_dist_max_cum_score = np.empty(len(dist_ann), np.bool)
group_by_dist_max_cum_score = np.empty(len(dist_ann), bool)
group_by_dist_max_cum_score[-1] = True
group_by_dist_max_cum_score[:-1] = dist_ann[1:] != dist_ann[:-1]
tbl = tbl[:, group_by_dist_max_cum_score]
Expand Down Expand Up @@ -161,11 +171,13 @@ def optimizer(codec, search, cost_metric, perf_metric):
op.add_operating_point(key, perf, cost)


def distance_ratio_measure(R, D_GT, metric):
def distance_ratio_measure(I, R, D_GT, metric):
sum_of_R = np.sum(np.where(I >= 0, R, 0))
sum_of_D_GT = np.sum(np.where(I >= 0, D_GT, 0))
if metric == faiss.METRIC_INNER_PRODUCT:
return (np.sum(R) / np.sum(D_GT)).item()
return (sum_of_R / sum_of_D_GT).item()
elif metric == faiss.METRIC_L2:
return (np.sum(D_GT) / np.sum(R)).item()
return (sum_of_D_GT / sum_of_R).item()
else:
raise RuntimeError(f"unknown metric {metric}")

Expand All @@ -188,7 +200,7 @@ def get_range_search_metric_function(range_metric, D, R):
assert R is not None
assert D.shape == R.shape
if isinstance(range_metric, list):
aradius, ascore = [], []
aradius, ascore, aradius_from, aradius_to = [], [], [], []
radius_to = 0
for rsd in range_metric:
assert isinstance(rsd, list)
Expand All @@ -212,6 +224,8 @@ def get_range_search_metric_function(range_metric, D, R):
)
aradius.append(real_radius)
ascore.append(score)
aradius_from.append(radius_from)
aradius_to.append(radius_to)

def sigmoid(x, a, b, c):
return a / (1 + np.exp(b * x - c))
Expand All @@ -229,6 +243,7 @@ def sigmoid(x, a, b, c):
cutoff,
lambda x: np.where(x < cutoff, sigmoid(x, *popt), 0),
popt.tolist(),
list(zip(aradius, ascore, aradius_from, aradius_to, strict=True))
)
else:
# Assuming that the range_metric is a float,
Expand All @@ -244,7 +259,7 @@ def sigmoid(x, a, b, c):
f"range_search_metric_function {range_metric=} {real_range=}"
)
assert isinstance(real_range, float)
return real_range * 2, lambda x: np.where(x < real_range, 1, 0), []
return real_range * 2, lambda x: np.where(x < real_range, 1, 0), [], []


@dataclass
Expand Down Expand Up @@ -312,9 +327,9 @@ def range_search_reference(self, index_desc, range_metric):
assert len(range_metric) > 0
ri = len(range_metric[0]) - 1
m_radius = (
max(range_metric, key=itemgetter(ri))[ri]
max(rm[ri] for rm in range_metric)
if self.distance_metric_type == faiss.METRIC_L2
else min(range_metric, key=itemgetter(ri))[ri]
else min(rm[ri] for rm in range_metric)
)
else:
m_radius = range_metric
Expand All @@ -329,13 +344,14 @@ def range_search_reference(self, index_desc, range_metric):
gt_radius,
range_search_metric_function,
coefficients,
coefficients_training_data,
) = get_range_search_metric_function(
range_metric,
D if not flat else None,
R if not flat else None,
)
logger.info("range_search_reference: end")
return gt_radius, range_search_metric_function, coefficients
return gt_radius, range_search_metric_function, coefficients, coefficients_training_data

def estimate_range(self, index_desc, parameters, range_scoring_radius):
D, I, R, P = self.knn_search(
Expand Down Expand Up @@ -397,16 +413,12 @@ def range_search(
)
# QD = QD[:, :index.nprobe]
# QI = QI[:, :index.nprobe]
logger.info("Timing range_search_preassigned")
faiss.cvar.indexIVF_stats.reset()
t0 = time.time()
lims, D, I = index.range_search_preassigned(xq, radius, QI, QD)
t = time.time() - t0
with timer("range_search_preassigned") as t:
lims, D, I = index.range_search_preassigned(xq, radius, QI, QD)
else:
logger.info("Timing range_search")
t0 = time.time()
lims, D, I = index.range_search(xq, radius)
t = time.time() - t0
with timer("range_search") as t:
lims, D, I = index.range_search(xq, radius)
if flat:
R = D
else:
Expand All @@ -415,7 +427,7 @@ def range_search(
lims, D, I, xq, xb, self.distance_metric_type
)
P = {
"time": t,
"time": t(),
"radius": radius,
"count": lims[-1].item(),
"parameters": parameters,
Expand Down Expand Up @@ -560,16 +572,12 @@ def knn_search(
)
# QD = QD[:, :index.nprobe]
# QI = QI[:, :index.nprobe]
logger.info("Timing knn search_preassigned")
faiss.cvar.indexIVF_stats.reset()
t0 = time.time()
D, I = index.search_preassigned(xq, k, QI, QD)
t = time.time() - t0
with timer("knn search_preassigned") as t:
D, I = index.search_preassigned(xq, k, QI, QD)
else:
logger.info("Timing knn search")
t0 = time.time()
D, I = index.search(xq, k)
t = time.time() - t0
with timer("knn search") as t:
D, I = index.search(xq, k)
if flat or level > 0:
R = D
else:
Expand All @@ -578,7 +586,7 @@ def knn_search(
D, I, xq, xb, self.distance_metric_type
)
P = {
"time": t,
"time": t(),
"parameters": parameters,
"index": index_desc.factory,
"level": level,
Expand Down Expand Up @@ -646,7 +654,7 @@ def experiment(parameters, cost_metric, perf_metric):
I, self.gt_knn_I
),
"distance_ratio": distance_ratio_measure(
R, self.gt_knn_D, self.distance_metric_type
I, R, self.gt_knn_D, self.distance_metric_type
),
}
results["experiments"][key] = metrics
Expand Down Expand Up @@ -691,8 +699,12 @@ def benchmark(self) -> str:
gt_radius,
range_search_metric_function,
coefficients,
coefficients_training_data,
) = self.range_search_reference(index_desc, range_metric)
results["metrics"][metric_key] = coefficients
results["metrics"][metric_key] = {
"coefficients": coefficients,
"training_data": coefficients_training_data,
}
gt_rsm = self.range_ground_truth(
gt_radius, range_search_metric_function
)
Expand Down
14 changes: 7 additions & 7 deletions thirdparty/faiss/benchs/bench_fw/benchmark_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def write_file(
def get_dataset(self, dataset):
if dataset not in self.cached_ds:
self.cached_ds[dataset] = self.read_nparray(
os.path.join(self.path, dataset.name)
os.path.join(self.path, dataset.tablename)
)
return self.cached_ds[dataset]

Expand All @@ -207,9 +207,9 @@ def read_nparray(
filename: str,
):
fn = self.download_file_from_blobstore(filename)
logger.info(f"Loading nparray from {fn}\n")
logger.info(f"Loading nparray from {fn}")
nparray = np.load(fn)
logger.info(f"Loaded nparray {nparray.shape} from {fn}\n")
logger.info(f"Loaded nparray {nparray.shape} from {fn}")
return nparray

def write_nparray(
Expand All @@ -218,7 +218,7 @@ def write_nparray(
filename: str,
):
fn = self.get_local_filename(filename)
logger.info(f"Saving nparray {nparray.shape} to {fn}\n")
logger.info(f"Saving nparray {nparray.shape} to {fn}")
np.save(fn, nparray)
self.upload_file_to_blobstore(filename)

Expand All @@ -227,10 +227,10 @@ def read_json(
filename: str,
):
fn = self.download_file_from_blobstore(filename)
logger.info(f"Loading json {fn}\n")
logger.info(f"Loading json {fn}")
with open(fn, "r") as fp:
json_dict = json.load(fp)
logger.info(f"Loaded json {json_dict} from {fn}\n")
logger.info(f"Loaded json {json_dict} from {fn}")
return json_dict

def write_json(
Expand All @@ -240,7 +240,7 @@ def write_json(
overwrite: bool = False,
):
fn = self.get_local_filename(filename)
logger.info(f"Saving json {json_dict} to {fn}\n")
logger.info(f"Saving json {json_dict} to {fn}")
with open(fn, "w") as fp:
json.dump(json_dict, fp)
self.upload_file_to_blobstore(filename, overwrite=overwrite)

0 comments on commit 217ce0c

Please sign in to comment.