Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bench_fw - fixes & nits for oss #3102

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 42 additions & 30 deletions benchs/bench_fw/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

from contextlib import contextmanager
import json
import logging
import time
from dataclasses import dataclass
from multiprocessing.pool import ThreadPool
from operator import itemgetter
from statistics import median, mean
from time import perf_counter
from typing import Any, List, Optional
from .descriptors import DatasetDescriptor, IndexDescriptor

Expand All @@ -26,6 +27,15 @@
logger = logging.getLogger(__name__)


@contextmanager
def timer(name) -> float:
logger.info(f"Measuring {name}")
t1 = t2 = perf_counter()
yield lambda: t2 - t1
t2 = perf_counter()
logger.info(f"Time for {name}: {t2 - t1:.3f} seconds")


def refine_distances_knn(
D: np.ndarray, I: np.ndarray, xq: np.ndarray, xb: np.ndarray, metric
):
Expand Down Expand Up @@ -77,7 +87,7 @@ def range_search_pr_curve(
tbl = np.vstack(
[dist_ann, metric_score, cum_score, precision, recall, unique_key]
)
group_by_dist_max_cum_score = np.empty(len(dist_ann), np.bool)
group_by_dist_max_cum_score = np.empty(len(dist_ann), bool)
group_by_dist_max_cum_score[-1] = True
group_by_dist_max_cum_score[:-1] = dist_ann[1:] != dist_ann[:-1]
tbl = tbl[:, group_by_dist_max_cum_score]
Expand Down Expand Up @@ -161,11 +171,13 @@ def optimizer(codec, search, cost_metric, perf_metric):
op.add_operating_point(key, perf, cost)


def distance_ratio_measure(R, D_GT, metric):
def distance_ratio_measure(I, R, D_GT, metric):
sum_of_R = np.sum(np.where(I >= 0, R, 0))
sum_of_D_GT = np.sum(np.where(I >= 0, D_GT, 0))
if metric == faiss.METRIC_INNER_PRODUCT:
return (np.sum(R) / np.sum(D_GT)).item()
return (sum_of_R / sum_of_D_GT).item()
elif metric == faiss.METRIC_L2:
return (np.sum(D_GT) / np.sum(R)).item()
return (sum_of_D_GT / sum_of_R).item()
else:
raise RuntimeError(f"unknown metric {metric}")

Expand All @@ -188,7 +200,7 @@ def get_range_search_metric_function(range_metric, D, R):
assert R is not None
assert D.shape == R.shape
if isinstance(range_metric, list):
aradius, ascore = [], []
aradius, ascore, aradius_from, aradius_to = [], [], [], []
radius_to = 0
for rsd in range_metric:
assert isinstance(rsd, list)
Expand All @@ -212,6 +224,8 @@ def get_range_search_metric_function(range_metric, D, R):
)
aradius.append(real_radius)
ascore.append(score)
aradius_from.append(radius_from)
aradius_to.append(radius_to)

def sigmoid(x, a, b, c):
return a / (1 + np.exp(b * x - c))
Expand All @@ -229,6 +243,7 @@ def sigmoid(x, a, b, c):
cutoff,
lambda x: np.where(x < cutoff, sigmoid(x, *popt), 0),
popt.tolist(),
list(zip(aradius, ascore, aradius_from, aradius_to, strict=True))
)
else:
# Assuming that the range_metric is a float,
Expand All @@ -244,7 +259,7 @@ def sigmoid(x, a, b, c):
f"range_search_metric_function {range_metric=} {real_range=}"
)
assert isinstance(real_range, float)
return real_range * 2, lambda x: np.where(x < real_range, 1, 0), []
return real_range * 2, lambda x: np.where(x < real_range, 1, 0), [], []


@dataclass
Expand Down Expand Up @@ -312,9 +327,9 @@ def range_search_reference(self, index_desc, range_metric):
assert len(range_metric) > 0
ri = len(range_metric[0]) - 1
m_radius = (
max(range_metric, key=itemgetter(ri))[ri]
max(rm[ri] for rm in range_metric)
if self.distance_metric_type == faiss.METRIC_L2
else min(range_metric, key=itemgetter(ri))[ri]
else min(rm[ri] for rm in range_metric)
)
else:
m_radius = range_metric
Expand All @@ -329,13 +344,14 @@ def range_search_reference(self, index_desc, range_metric):
gt_radius,
range_search_metric_function,
coefficients,
coefficients_training_data,
) = get_range_search_metric_function(
range_metric,
D if not flat else None,
R if not flat else None,
)
logger.info("range_search_reference: end")
return gt_radius, range_search_metric_function, coefficients
return gt_radius, range_search_metric_function, coefficients, coefficients_training_data

def estimate_range(self, index_desc, parameters, range_scoring_radius):
D, I, R, P = self.knn_search(
Expand Down Expand Up @@ -397,16 +413,12 @@ def range_search(
)
# QD = QD[:, :index.nprobe]
# QI = QI[:, :index.nprobe]
logger.info("Timing range_search_preassigned")
faiss.cvar.indexIVF_stats.reset()
t0 = time.time()
lims, D, I = index.range_search_preassigned(xq, radius, QI, QD)
t = time.time() - t0
with timer("range_search_preassigned") as t:
lims, D, I = index.range_search_preassigned(xq, radius, QI, QD)
else:
logger.info("Timing range_search")
t0 = time.time()
lims, D, I = index.range_search(xq, radius)
t = time.time() - t0
with timer("range_search") as t:
lims, D, I = index.range_search(xq, radius)
if flat:
R = D
else:
Expand All @@ -415,7 +427,7 @@ def range_search(
lims, D, I, xq, xb, self.distance_metric_type
)
P = {
"time": t,
"time": t(),
"radius": radius,
"count": lims[-1].item(),
"parameters": parameters,
Expand Down Expand Up @@ -560,16 +572,12 @@ def knn_search(
)
# QD = QD[:, :index.nprobe]
# QI = QI[:, :index.nprobe]
logger.info("Timing knn search_preassigned")
faiss.cvar.indexIVF_stats.reset()
t0 = time.time()
D, I = index.search_preassigned(xq, k, QI, QD)
t = time.time() - t0
with timer("knn search_preassigned") as t:
D, I = index.search_preassigned(xq, k, QI, QD)
else:
logger.info("Timing knn search")
t0 = time.time()
D, I = index.search(xq, k)
t = time.time() - t0
with timer("knn search") as t:
D, I = index.search(xq, k)
if flat or level > 0:
R = D
else:
Expand All @@ -578,7 +586,7 @@ def knn_search(
D, I, xq, xb, self.distance_metric_type
)
P = {
"time": t,
"time": t(),
"parameters": parameters,
"index": index_desc.factory,
"level": level,
Expand Down Expand Up @@ -646,7 +654,7 @@ def experiment(parameters, cost_metric, perf_metric):
I, self.gt_knn_I
),
"distance_ratio": distance_ratio_measure(
R, self.gt_knn_D, self.distance_metric_type
I, R, self.gt_knn_D, self.distance_metric_type
),
}
results["experiments"][key] = metrics
Expand Down Expand Up @@ -691,8 +699,12 @@ def benchmark(self) -> str:
gt_radius,
range_search_metric_function,
coefficients,
coefficients_training_data,
) = self.range_search_reference(index_desc, range_metric)
results["metrics"][metric_key] = coefficients
results["metrics"][metric_key] = {
"coefficients": coefficients,
"training_data": coefficients_training_data,
}
gt_rsm = self.range_ground_truth(
gt_radius, range_search_metric_function
)
Expand Down
14 changes: 7 additions & 7 deletions benchs/bench_fw/benchmark_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def write_file(
def get_dataset(self, dataset):
if dataset not in self.cached_ds:
self.cached_ds[dataset] = self.read_nparray(
os.path.join(self.path, dataset.name)
os.path.join(self.path, dataset.tablename)
)
return self.cached_ds[dataset]

Expand All @@ -207,9 +207,9 @@ def read_nparray(
filename: str,
):
fn = self.download_file_from_blobstore(filename)
logger.info(f"Loading nparray from {fn}\n")
logger.info(f"Loading nparray from {fn}")
nparray = np.load(fn)
logger.info(f"Loaded nparray {nparray.shape} from {fn}\n")
logger.info(f"Loaded nparray {nparray.shape} from {fn}")
return nparray

def write_nparray(
Expand All @@ -218,7 +218,7 @@ def write_nparray(
filename: str,
):
fn = self.get_local_filename(filename)
logger.info(f"Saving nparray {nparray.shape} to {fn}\n")
logger.info(f"Saving nparray {nparray.shape} to {fn}")
np.save(fn, nparray)
self.upload_file_to_blobstore(filename)

Expand All @@ -227,10 +227,10 @@ def read_json(
filename: str,
):
fn = self.download_file_from_blobstore(filename)
logger.info(f"Loading json {fn}\n")
logger.info(f"Loading json {fn}")
with open(fn, "r") as fp:
json_dict = json.load(fp)
logger.info(f"Loaded json {json_dict} from {fn}\n")
logger.info(f"Loaded json {json_dict} from {fn}")
return json_dict

def write_json(
Expand All @@ -240,7 +240,7 @@ def write_json(
overwrite: bool = False,
):
fn = self.get_local_filename(filename)
logger.info(f"Saving json {json_dict} to {fn}\n")
logger.info(f"Saving json {json_dict} to {fn}")
with open(fn, "w") as fp:
json.dump(json_dict, fp)
self.upload_file_to_blobstore(filename, overwrite=overwrite)
58 changes: 0 additions & 58 deletions build.sh

This file was deleted.