Merge pull request #4 from divyegala/python-ann-bench-use-gbench
Python ann bench use gbench
cjnolet authored Aug 28, 2023
2 parents 1720e11 + be3da1a commit b9e7771
Showing 3 changed files with 186 additions and 52 deletions.
221 changes: 177 additions & 44 deletions bench/ann/plot.py
@@ -23,9 +23,11 @@

mpl.use("Agg") # noqa
import argparse
from collections import OrderedDict
import itertools
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import os


@@ -42,6 +44,16 @@
}
}

def positive_int(input_str: str) -> int:
try:
i = int(input_str)
if i < 1:
raise ValueError
except ValueError:
raise argparse.ArgumentTypeError(f"{input_str} is not a positive integer")

return i
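
The new `positive_int` helper is used later in this diff as an argparse `type` callable for `-k/--count` and `-bs/--batch-size`. As a minimal sketch of how such a type callable behaves (the stand-alone parser here is only illustrative):

```python
import argparse

# Illustrative stand-alone parser; the real flags are added in main() below.
parser = argparse.ArgumentParser()
parser.add_argument("-k", "--count", default=10, type=positive_int,
                    help="number of nearest neighbors to search for")

args = parser.parse_args(["-k", "25"])
print(args.count)        # 25
# parser.parse_args(["-k", "0"]) would exit with:
#   error: argument -k/--count: 0 is not a positive integer
```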


def generate_n_colors(n):
vs = np.linspace(0.3, 0.9, 7)
@@ -76,42 +88,35 @@ def get_left_right(metric):
return "right"


def get_plot_label(xm, ym):
template = "%(xlabel)s-%(ylabel)s tradeoff - %(updown)s and" " to the %(leftright)s is better"
return template % {
"xlabel": xm["description"],
"ylabel": ym["description"],
"updown": get_up_down(ym),
"leftright": get_left_right(xm),
}


def create_pointset(data, xn, yn):
xm, ym = (metrics[xn], metrics[yn])
rev_y = -1 if ym["worst"] < 0 else 1
rev_x = -1 if xm["worst"] < 0 else 1
data.sort(key=lambda t: (rev_y * t[-1], rev_x * t[-2]))

axs, ays, als = [], [], []
axs, ays, als, aidxs = [], [], [], []
# Generate Pareto frontier
xs, ys, ls = [], [], []
xs, ys, ls, idxs = [], [], [], []
last_x = xm["worst"]
comparator = (lambda xv, lx: xv > lx) if last_x < 0 else (lambda xv, lx: xv < lx)
for algo_name, xv, yv in data:
for algo_name, index_name, xv, yv in data:
if not xv or not yv:
continue
axs.append(xv)
ays.append(yv)
als.append(algo_name)
aidxs.append(algo_name)
if comparator(xv, last_x):
last_x = xv
xs.append(xv)
ys.append(yv)
ls.append(algo_name)
return xs, ys, ls, axs, ays, als
idxs.append(index_name)
return xs, ys, ls, idxs, axs, ays, als, aidxs
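
`create_pointset` now carries the index name alongside the algorithm name, in both the raw point lists and the Pareto frontier. The frontier rule itself is unchanged: after sorting so the best throughput comes first, a point is kept only if it improves on the best recall seen so far. A toy illustration of that rule (made-up numbers, not real benchmark output):

```python
# Toy data: (algo_name, index_name, recall, qps); numbers are made up.
points = [
    ("algoA", "idx1", 0.80, 5000.0),
    ("algoA", "idx2", 0.90, 3000.0),
    ("algoA", "idx3", 0.85, 2000.0),  # dominated: worse recall AND worse qps than idx2
]

# Sort by descending qps, then keep a point only if its recall beats the
# best recall seen so far -- the same monotone sweep create_pointset performs.
frontier, best_recall = [], float("-inf")
for algo, index, recall, qps in sorted(points, key=lambda p: -p[3]):
    if recall > best_recall:
        best_recall = recall
        frontier.append((algo, index, recall, qps))

print(frontier)
# [('algoA', 'idx1', 0.8, 5000.0), ('algoA', 'idx2', 0.9, 3000.0)]
```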


def create_plot(all_data, raw, x_scale, y_scale, fn_out, linestyles):
def create_plot_search(all_data, raw, x_scale, y_scale, fn_out, linestyles,
dataset, k, batch_size):
xn = "k-nn"
yn = "qps"
xm, ym = (metrics[xn], metrics[yn])
@@ -122,13 +127,13 @@ def create_plot(all_data, raw, x_scale, y_scale, fn_out, linestyles):

# Sorting by mean y-value helps aligning plots with labels
def mean_y(algo):
xs, ys, ls, axs, ays, als = create_pointset(all_data[algo], xn, yn)
xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset(all_data[algo], xn, yn)
return -np.log(np.array(ys)).mean()

# Find range for logit x-scale
min_x, max_x = 1, 0
for algo in sorted(all_data.keys(), key=mean_y):
xs, ys, ls, axs, ays, als = create_pointset(all_data[algo], xn, yn)
xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset(all_data[algo], xn, yn)
min_x = min([min_x] + [x for x in xs if x > 0])
max_x = max([max_x] + [x for x in xs if x < 1])
color, faded, linestyle, marker = linestyles[algo]
@@ -169,7 +174,7 @@ def inv_fun(x):
else:
ax.set_xscale(x_scale)
ax.set_yscale(y_scale)
ax.set_title(get_plot_label(xm, ym))
ax.set_title(f"{dataset} k={k} batch_size={batch_size}")
plt.gca().get_position()
# plt.gca().set_position([box.x0, box.y0, box.width * 0.8, box.height])
ax.legend(handles, labels, loc="center left", bbox_to_anchor=(1, 0.5), prop={"size": 9})
@@ -188,39 +193,125 @@ def inv_fun(x):
# Workaround for bug https://github.com/matplotlib/matplotlib/issues/6789
ax.spines["bottom"]._adjust_location()

print(f"writing search output to {fn_out}")
plt.savefig(fn_out, bbox_inches="tight")
plt.close()


def load_all_results(dataset_path):
def create_plot_build(build_results, search_results, linestyles, fn_out,
dataset, k, batch_size):
xn = "k-nn"
yn = "qps"

recall_85 = [-1] * len(linestyles)
qps_85 = [-1] * len(linestyles)
bt_85 = [0] * len(linestyles)
i_85 = [-1] * len(linestyles)
recall_90 = [-1] * len(linestyles)
qps_90 = [-1] * len(linestyles)
bt_90 = [0] * len(linestyles)
i_90 = [-1] * len(linestyles)
recall_95 = [-1] * len(linestyles)
qps_95 = [-1] * len(linestyles)
bt_95 = [0] * len(linestyles)
i_95 = [-1] * len(linestyles)
data = OrderedDict()
colors = OrderedDict()

# Sorting by mean y-value helps aligning plots with labels
def mean_y(algo):
xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset(search_results[algo], xn, yn)
return -np.log(np.array(ys)).mean()

for pos, algo in enumerate(sorted(search_results.keys(), key=mean_y)):
xs, ys, ls, idxs, axs, ays, als, aidxs = create_pointset(search_results[algo], xn, yn)
# x is recall, y is qps, ls is algo_name, idxs is index_name
for i in range(len(xs)):
if xs[i] >= 0.85 and xs[i] < 0.9 and ys[i] > qps_85[pos]:
qps_85[pos] = ys[i]
bt_85[pos] = build_results[(ls[i], idxs[i])][0][2]
i_85[pos] = idxs[i]
elif xs[i] >= 0.9 and xs[i] < 0.95 and ys[i] > qps_90[pos]:
qps_90[pos] = ys[i]
bt_90[pos] = build_results[(ls[i], idxs[i])][0][2]
i_90[pos] = idxs[i]
elif xs[i] >= 0.95 and ys[i] > qps_95[pos]:
qps_95[pos] = ys[i]
bt_95[pos] = build_results[(ls[i], idxs[i])][0][2]
i_95[pos] = idxs[i]
data[algo] = [bt_85[pos], bt_90[pos], bt_95[pos]]
colors[algo] = linestyles[algo][0]

index = ['@85% Recall', '@90% Recall', '@95% Recall']

df = pd.DataFrame(data, index=index)
plt.figure(figsize=(12, 9))
ax = df.plot.bar(rot=0, color=colors)
fig = ax.get_figure()
print(f"writing search output to {fn_out}")
plt.title("Build Time for Highest QPS")
plt.suptitle(f"{dataset} k={k} batch_size={batch_size}")
plt.ylabel("Build Time (s)")
fig.savefig(fn_out)
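
The bar chart above is driven by a small DataFrame whose rows are the three recall buckets and whose columns are algorithms, with one build time per cell. A rough sketch of that shape, using made-up numbers and hypothetical algorithm names:

```python
import pandas as pd
import matplotlib as mpl
mpl.use("Agg")  # headless backend, matching the top of this script

# Made-up build times (seconds): the build time of whichever index gave the
# highest QPS inside each recall bucket, per algorithm.
data = {"algoA": [12.0, 15.5, 30.2], "algoB": [8.3, 9.1, 20.7]}
index = ["@85% Recall", "@90% Recall", "@95% Recall"]

df = pd.DataFrame(data, index=index)
ax = df.plot.bar(rot=0)              # one group of bars per recall bucket
ax.set_ylabel("Build Time (s)")
ax.get_figure().savefig("build_times_example.png")
```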


def load_lines(results_path, result_files, method, index_key):
results = dict()
results_path = os.path.join(dataset_path, "result", "search")
for result_filepath in os.listdir(results_path):
with open(os.path.join(results_path, result_filepath), 'r') as f:

linebreaker = "name,iterations"

for result_filename in result_files:
with open(os.path.join(results_path, result_filename), 'r') as f:
lines = f.readlines()
lines = lines[:-1] if lines[-1] == "\n" else lines
idx = 0
for pos, line in enumerate(lines):
if "QPS" in line:
if linebreaker in line:
idx = pos
break

keys = lines[idx].split(',')
recall_idx = -1
qps_idx = -1
for pos, key in enumerate(keys):
if "Recall" in key:
recall_idx = pos
if "QPS" in key:
qps_idx = pos
if method == "build":
if "hnswlib" in result_filename:
key_idx = [2]
else:
key_idx = [10]
elif method == "search":
if "hnswlib" in result_filename:
key_idx = [10, 6]
else:
key_idx = [12, 10]

for line in lines[idx+1:]:
split_lines = line.split(',')

algo_name = split_lines[0].split('.')[0].strip("\"")
if algo_name not in results:
results[algo_name] = []
results[algo_name].append([algo_name, float(split_lines[recall_idx]),
float(split_lines[qps_idx])])
index_name = split_lines[0].split('/')[0].strip("\"")

if index_key == "algo":
dict_key = algo_name
elif index_key == "index":
dict_key = (algo_name, index_name)
if dict_key not in results:
results[dict_key] = []
to_add = [algo_name, index_name]
for key_i in key_idx:
to_add.append(float(split_lines[key_i]))
results[dict_key].append(to_add)

return results
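
`load_lines` works directly on the Google Benchmark CSV output: it scans for the header line containing `name,iterations`, then pulls values out of fixed columns of every row after it. A hedged alternative sketch of the same preamble-skipping idea using pandas (the helper name, file name, and column handling here are illustrative, not part of this PR):

```python
import pandas as pd

def read_gbench_csv(path, marker="name,iterations"):
    """Read a Google Benchmark CSV, skipping any preamble before the header."""
    with open(path, "r") as f:
        header_pos = next(pos for pos, line in enumerate(f) if marker in line)
    # Hand everything from the header line onward to pandas.
    return pd.read_csv(path, skiprows=header_pos)

# df = read_gbench_csv("raft_cagra-10-10000.csv")   # hypothetical result file
# df["name"] then holds the per-benchmark case names used above to derive
# algo_name and index_name.
```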


def load_all_results(dataset_path, algorithms, k, batch_size, method, index_key):
results_path = os.path.join(dataset_path, "result", method)
result_files = os.listdir(results_path)
result_files = [result_filename for result_filename in result_files \
if f"{k}-{batch_size}" in result_filename]
if len(algorithms) > 0:
result_files = [result_filename for result_filename in result_files if \
result_filename.split('-')[0] in algorithms]

results = load_lines(results_path, result_files, method, index_key)

return results


@@ -232,8 +323,27 @@ def main():
parser.add_argument("--dataset-path", help="path to dataset folder",
default=os.path.join(os.getenv("RAFT_HOME"),
"bench", "ann", "data"))
parser.add_argument("--output-filename",
default="plot.png")
parser.add_argument("--output-filepath",
help="directory for PNG to be saved",
default=os.getcwd())
parser.add_argument("--algorithms",
help="plot only comma separated list of named \
algorithms",
default=None)
parser.add_argument(
"-k", "--count", default=10, type=positive_int, help="the number of nearest neighbors to search for"
)
parser.add_argument(
"-bs", "--batch-size", default=10000, type=positive_int, help="number of query vectors to use in each query trial"
)
parser.add_argument(
"--build",
action="store_true"
)
parser.add_argument(
"--search",
action="store_true"
)
parser.add_argument(
"--x-scale",
help="Scale to use when drawing the X-axis. \
@@ -249,15 +359,38 @@
parser.add_argument(
"--raw", help="Show raw results (not just Pareto frontier) in faded colours", action="store_true"
)
args = parser.parse_args()

output_filepath = os.path.join(args.dataset_path, args.dataset, args.output_filename)
print(f"writing output to {output_filepath}")

results = load_all_results(os.path.join(args.dataset_path, args.dataset))
linestyles = create_linestyles(sorted(results.keys()))
args = parser.parse_args()

create_plot(results, args.raw, args.x_scale, args.y_scale, output_filepath, linestyles)
if args.algorithms:
algorithms = args.algorithms.split(',')
else:
algorithms = []
k = args.count
batch_size = args.batch_size
if not args.build and not args.search:
build = True
search = True
else:
build = args.build
search = args.search

search_output_filepath = os.path.join(args.output_filepath, f"search-{args.dataset}-{k}-{batch_size}.png")
build_output_filepath = os.path.join(args.output_filepath, f"build-{args.dataset}-{k}-{batch_size}.png")

search_results = load_all_results(
os.path.join(args.dataset_path, args.dataset),
algorithms, k, batch_size, "search", "algo")
linestyles = create_linestyles(sorted(search_results.keys()))
if search:
create_plot_search(search_results, args.raw, args.x_scale, args.y_scale,
search_output_filepath, linestyles, args.dataset, k, batch_size)
if build:
build_results = load_all_results(
os.path.join(args.dataset_path, args.dataset),
algorithms, k, batch_size, "build", "index")
create_plot_build(build_results, search_results, linestyles, build_output_filepath,
args.dataset, k, batch_size)


if __name__ == "__main__":
12 changes: 6 additions & 6 deletions bench/ann/run.py
@@ -35,15 +35,15 @@ def validate_algorithm(algos_conf, algo):
return algo in algos_conf_keys and not algos_conf[algo]["disabled"]


def find_executable(algos_conf, algo):
def find_executable(algos_conf, algo, k, batch_size):
executable = algos_conf[algo]["executable"]
conda_path = os.path.join(os.getenv("CONDA_PREFIX"), "bin", "ann",
executable)
build_path = os.path.join(os.getenv("RAFT_HOME"), "cpp", "build", executable)
if os.path.exists(conda_path):
return (executable, conda_path, algo)
return (executable, conda_path, f"{algo}-{k}-{batch_size}")
elif os.path.exists(build_path):
return (executable, build_path, algo)
return (executable, build_path, f"{algo}-{k}-{batch_size}")
else:
raise FileNotFoundError(executable)
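
With this change the third element of the returned tuple encodes the algorithm together with `k` and `batch_size`, so each (algo, k, batch_size) combination gets its own entry and, downstream, its own result file name, which is what `plot.py` filters on with `f"{k}-{batch_size}"`. A small sketch of how `main()` groups indices under that tuple key (names and paths here are hypothetical):

```python
# Hypothetical return value for algo="raft_cagra", k=10, batch_size=10000:
executable_path = ("RAFT_CAGRA_ANN_BENCH",                 # executable name from algos_conf
                   "/path/to/bin/ann/RAFT_CAGRA_ANN_BENCH",
                   "raft_cagra-10-10000")                  # f"{algo}-{k}-{batch_size}"

# main() uses the whole tuple as a dictionary key and appends index configs to it:
executables_to_run = {}
if executable_path not in executables_to_run:
    executables_to_run[executable_path] = {"index": []}
executables_to_run[executable_path]["index"].append(
    {"name": "raft_cagra.graph_degree32", "algo": "raft_cagra"}  # illustrative index entry
)
```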

@@ -198,7 +198,7 @@ def main():
curr_algo = index["algo"]
if index["name"] in indices and \
validate_algorithm(algos_conf, curr_algo):
executable_path = find_executable(algos_conf, curr_algo)
executable_path = find_executable(algos_conf, curr_algo, k, batch_size)
if executable_path not in executables_to_run:
executables_to_run[executable_path] = {"index": []}
executables_to_run[executable_path]["index"].append(index)
@@ -212,7 +212,7 @@
curr_algo = index["algo"]
if curr_algo in algorithms and \
validate_algorithm(algos_conf, curr_algo):
executable_path = find_executable(algos_conf, curr_algo)
executable_path = find_executable(algos_conf, curr_algo, k, batch_size)
if executable_path not in executables_to_run:
executables_to_run[executable_path] = {"index": []}
executables_to_run[executable_path]["index"].append(index)
@@ -222,7 +222,7 @@
for index in conf_file["index"]:
curr_algo = index["algo"]
if validate_algorithm(algos_conf, curr_algo):
executable_path = find_executable(algos_conf, curr_algo)
executable_path = find_executable(algos_conf, curr_algo, k, batch_size)
if executable_path not in executables_to_run:
executables_to_run[executable_path] = {"index": []}
executables_to_run[executable_path]["index"].append(index)
5 changes: 3 additions & 2 deletions docs/source/raft_ann_benchmarks.md
@@ -180,14 +180,15 @@ CSV file in `<dataset-path>/<dataset>/search/result/<algo>.csv`.

The usage of this script is:
```bash
usage: plot.py [-h] [--dataset DATASET] [--dataset-path DATASET_PATH] [--output-filename OUTPUT_FILENAME] [--x-scale X_SCALE] [--y-scale {linear,log,symlog,logit}] [--raw]
usage: plot.py [-h] [--dataset DATASET] [--dataset-path DATASET_PATH] [--output-filepath OUTPUT_FILEPATH] [--x-scale X_SCALE] [--y-scale {linear,log,symlog,logit}] [--raw]
options:
-h, --help show this help message and exit
--dataset DATASET dataset to download (default: glove-100-inner)
--dataset-path DATASET_PATH
path to dataset folder (default: ${RAFT_HOME}/bench/ann/data)
--output-filename OUTPUT_FILENAME
--output-filepath OUTPUT_FILEPATH
directory for PNG to be saved (default: os.getcwd())
--x-scale X_SCALE Scale to use when drawing the X-axis. Typically linear, logit or a2 (default: linear)
--y-scale {linear,log,symlog,logit}
Scale to use when drawing the Y-axis (default: linear)
