diff --git a/benchs/bench_fw_notebook.ipynb b/benchs/bench_fw_notebook.ipynb index 5752aaf5fb..c38ed11068 100644 --- a/benchs/bench_fw_notebook.ipynb +++ b/benchs/bench_fw_notebook.ipynb @@ -1,529 +1,532 @@ { - "cells": [ - { - "cell_type": "code", - "execution_count": null, - "id": "be081589-e1b2-4569-acb7-44203e273899", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "import matplotlib.pyplot as plt\n", - "import itertools\n", - "from faiss.contrib.evaluation import OperatingPoints\n", - "from enum import Enum\n", - "from bench_fw.benchmark_io import BenchmarkIO as BIO\n", - "from bench_fw.utils import filter_results, ParetoMode, ParetoMetric\n", - "from copy import copy\n", - "import numpy as np\n", - "import datetime\n", - "import glob\n", - "import io\n", - "import json\n", - "from zipfile import ZipFile\n", - "import tabulate" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a6492e95-24c7-4425-bf0a-27e10e879ca6", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "root = \"/checkpoint/gsz/bench_fw/optimize/bigann\"\n", - "results = BIO(root).read_json(\"result_std_d_bigann10M.json\")\n", - "results.keys()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0875d269-aef4-426d-83dd-866970f43777", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "results['experiments']" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f080a6e2-1565-418b-8732-4adeff03a099", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "def plot_metric(experiments, accuracy_title, cost_title, plot_space=False, plot=None):\n", - " if plot is None:\n", - " plot = plt.subplot()\n", - " x = {}\n", - " y = {}\n", - " for accuracy, space, time, k, v in experiments:\n", - " idx_name = v['index'] + (\"snap\" if 'search_params' in v and v['search_params'][\"snap\"] == 1 else \"\")\n", - " if idx_name not in x:\n", - " x[idx_name] = []\n", - " y[idx_name] = []\n", - " x[idx_name].append(accuracy)\n", - " if plot_space:\n", - " y[idx_name].append(space)\n", - " else:\n", - " y[idx_name].append(time)\n", - "\n", - " #plt.figure(figsize=(10,6))\n", - " #plt.title(accuracy_title)\n", - " plot.set_xlabel(accuracy_title)\n", - " plot.set_ylabel(cost_title)\n", - " plot.set_yscale(\"log\")\n", - " marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n", - " for index in x.keys():\n", - " plot.plot(x[index], y[index], marker=next(marker), label=index, linewidth=0)\n", - " plot.legend(bbox_to_anchor=(1, 1), loc='upper left')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "61007155-5edc-449e-835e-c141a01a2ae5", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# index local optima\n", - "accuracy_metric = \"knn_intersection\"\n", - "fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1, min_accuracy=0.95)\n", - "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 32 cores)\", plot_space=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f9f94dcc-5abe-4cad-9619-f5d1d24fb8c1", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "# global optima\n", - "accuracy_metric = \"knn_intersection\"\n", - "fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.90, max_space=64, max_time=0, name_filter=lambda n: not n.startswith(\"Flat\"), pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 32 cores)\", plot_space=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0c10f587-26ef-49ec-83a9-88f6a2a433e8", - "metadata": {}, - "outputs": [], - "source": [ - "def pretty_params(p):\n", - " p = copy(p)\n", - " if 'snap' in p and p['snap'] == 0:\n", - " del p['snap']\n", - " return p\n", - " \n", - "tabulate.tabulate([(accuracy, space, time, v['factory'], pretty_params(v['construction_params'][1]), pretty_params(v['search_params'])) \n", - " for accuracy, space, time, k, v in fr],\n", - " tablefmt=\"html\",\n", - " headers=[\"accuracy\",\"space\", \"time\", \"factory\", \"quantizer cfg\", \"search cfg\"])" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "36e82084-18f6-4546-a717-163eb0224ee8", - "metadata": {}, - "outputs": [], - "source": [ - "# index local optima @ precision 0.8\n", - "precision = 0.8\n", - "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", - "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "aff79376-39f7-47c0-8b83-1efe5192bb7e", - "metadata": {}, - "outputs": [], - "source": [ - "# index local optima @ precision 0.2\n", - "precision = 0.2\n", - "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", - "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b4834f1f-bbbe-4cae-9aa0-a459b0c842d1", - "metadata": {}, - "outputs": [], - "source": [ - "# global optima @ precision 0.8\n", - "precision = 0.8\n", - "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", - "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9aead830-6209-4956-b7ea-4a5e0029d616", - "metadata": {}, - "outputs": [], - "source": [ - "def plot_range_search_pr_curves(experiments):\n", - " x = {}\n", - " y = {}\n", - " show = {\n", - " 'Flat': None,\n", - " }\n", - " for _, _, _, k, v in fr:\n", - " if \".weighted\" in k: # and v['index'] in show:\n", - " x[k] = v['range_search_pr']['recall']\n", - " y[k] = v['range_search_pr']['precision']\n", - " \n", - " plt.title(\"range search recall\")\n", - " plt.xlabel(\"recall\")\n", - " plt.ylabel(\"precision\")\n", - " for index in x.keys():\n", - " plt.plot(x[index], y[index], '.', label=index)\n", - " plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "92e45502-7a31-4a15-90df-fa3032d7d350", - "metadata": {}, - "outputs": [], - "source": [ - "precision = 0.8\n", - "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", - "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME_SPACE, scaling_factor=1)\n", - "plot_range_search_pr_curves(fr)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "fdf8148a-0da6-4c5e-8d60-f8f85314574c", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "root = \"/checkpoint/gsz/bench_fw/ivf/bigann\"\n", - "scales = [1, 2, 5, 10, 20, 50]\n", - "fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))\n", - "fig.tight_layout()\n", - "for plot, scale in zip(plots, scales, strict=True):\n", - " results = BIO(root).read_json(f\"result{scale}.json\")\n", - " accuracy_metric = \"knn_intersection\"\n", - " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - " plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 64 cores)\", plot=plot)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e503828c-ee61-45f7-814b-cce6461109bc", - "metadata": {}, - "outputs": [], - "source": [ - "x = {}\n", - "y = {}\n", - "accuracy=0.9\n", - "root = \"/checkpoint/gsz/bench_fw/ivf/bigann\"\n", - "scales = [1, 2, 5, 10, 20, 50]\n", - "#fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))\n", - "#fig.tight_layout()\n", - "for scale in scales:\n", - " results = BIO(root).read_json(f\"result{scale}.json\")\n", - " scale *= 1_000_000\n", - " accuracy_metric = \"knn_intersection\"\n", - " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=accuracy, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - " seen = set()\n", - " print(scale)\n", - " for _, _, _, _, exp in fr:\n", - " fact = exp[\"factory\"]\n", - " # \"HNSW\" in fact or \n", - " if fact in seen or fact in [\"Flat\", \"IVF512,Flat\", \"IVF1024,Flat\", \"IVF2048,Flat\"]:\n", - " continue\n", - " seen.add(fact)\n", - " if fact not in x:\n", - " x[fact] = []\n", - " y[fact] = []\n", - " x[fact].append(scale)\n", - " y[fact].append(exp[\"time\"] + exp[\"quantizer\"][\"time\"])\n", - " if (exp[\"knn_intersection\"] > 0.92):\n", - " print(fact)\n", - " print(exp[\"search_params\"])\n", - " print(exp[\"knn_intersection\"])\n", - "\n", - " #plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 64 cores)\", plot=plot)\n", - " \n", - "plt.title(f\"recall @ 1 = {accuracy*100}%\")\n", - "plt.xlabel(\"database size\")\n", - "plt.ylabel(\"time\")\n", - "plt.xscale(\"log\")\n", - "plt.yscale(\"log\")\n", - "\n", - "marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n", - "for index in x.keys():\n", - " if \"HNSW\" in index:\n", - " plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker), linestyle=\"dashed\")\n", - " else:\n", - " plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker))\n", - "plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "37a99bb2-f998-461b-a345-7cc6e702cb3a", - "metadata": {}, - "outputs": [], - "source": [ - "# global optima\n", - "accuracy_metric = \"sym_recall\"\n", - "fr = filter_results(results, evaluation=\"rec\", accuracy_metric=accuracy_metric, time_metric=lambda e:e['encode_time'], min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.SPACE, scaling_factor=1)\n", - "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"space\", plot_space=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c973ce4e-3566-4f02-bd93-f113e3e0c791", - "metadata": {}, - "outputs": [], - "source": [ - "def pretty_time(s):\n", - " if s is None:\n", - " return \"None\"\n", - " s = int(s * 1000) / 1000\n", - " m, s = divmod(s, 60)\n", - " h, m = divmod(m, 60)\n", - " d, h = divmod(h, 24)\n", - " r = \"\"\n", - " if d > 0:\n", - " r += f\"{int(d)}d \"\n", - " if h > 0:\n", - " r += f\"{int(h)}h \"\n", - " if m > 0:\n", - " r += f\"{int(m)}m \"\n", - " if s > 0 or len(r) == 0:\n", - " r += f\"{s:.3f}s\"\n", - " return r\n", - "\n", - "def pretty_size(s):\n", - " if s > 1024 * 1024:\n", - " return f\"{s / 1024 / 1024:.1f}\".rstrip('0').rstrip('.') + \"MB\"\n", - " if s > 1024:\n", - " return f\"{s / 1024:.1f}\".rstrip('0').rstrip('.') + \"KB\"\n", - " return f\"{s}\"\n", - "\n", - "def pretty_mse(m):\n", - " if m is None:\n", - " return \"None\"\n", - " else:\n", - " return f\"{m:.6f}\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1ddcf226-fb97-4a59-9fc3-3ed8f7d5e703", - "metadata": {}, - "outputs": [], - "source": [ - "data = {}\n", - "root = \"/checkpoint/gsz/bench_fw/bigann\"\n", - "scales = [1, 2, 5, 10, 20, 50]\n", - "for scale in scales:\n", - " results = BIO(root).read_json(f\"result{scale}.json\")\n", - " accuracy_metric = \"knn_intersection\"\n", - " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", - " d = {}\n", - " data[f\"{scale}M\"] = d\n", - " for _, _, _, _, exp in fr:\n", - " fact = exp[\"factory\"]\n", - " # \"HNSW\" in fact or \n", - " if fact in [\"Flat\", \"IVF512,Flat\", \"IVF1024,Flat\", \"IVF2048,Flat\"]:\n", - " continue\n", - " if fact not in d:\n", - " d[fact] = []\n", - " d[fact].append({\n", - " \"nprobe\": exp[\"search_params\"][\"nprobe\"],\n", - " \"recall\": exp[\"knn_intersection\"],\n", - " \"time\": exp[\"time\"] + exp[\"quantizer\"][\"time\"],\n", - " })\n", - "data\n", - "# with open(\"/checkpoint/gsz/bench_fw/codecs.json\", \"w\") as f:\n", - "# json.dump(data, f)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e54eebb6-0a9f-4a72-84d2-f12c5bd44510", - "metadata": {}, - "outputs": [], - "source": [ - "ds = \"deep1b\"\n", - "data = []\n", - "jss = []\n", - "root = f\"/checkpoint/gsz/bench_fw/codecs/{ds}\"\n", - "results = BIO(root).read_json(f\"result.json\")\n", - "for k, e in results[\"experiments\"].items():\n", - " if \"rec\" in k and e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and \"PRQ\" in e['factory'] and e['sym_recall'] > 0.0:\n", - " code_size = results['indices'][e['codec']]['sa_code_size']\n", - " codec_size = results['indices'][e['codec']]['codec_size']\n", - " training_time = results['indices'][e['codec']]['training_time']\n", - " # training_size = results['indices'][e['codec']]['training_size']\n", - " cpu = e['cpu'] if 'cpu' in e else \"\"\n", - " ps = ', '.join([f\"{k}={v}\" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else \" \"\n", - " eps = ', '.join([f\"{k}={v}\" for k,v in e['reconstruct_params'].items() if k != \"snap\"]) if e['reconstruct_params'] else \" \"\n", - " data.append((code_size, f\"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{training_size}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|\"))\n", - " jss.append({\n", - " 'factory': e['factory'],\n", - " 'parameters': e['construction_params'][0] if e['construction_params'] else \"\",\n", - " 'evaluation_params': e['reconstruct_params'],\n", - " 'code_size': code_size,\n", - " 'codec_size': codec_size,\n", - " 'training_time': training_time,\n", - " 'training_size': training_size,\n", - " 'mse': e['mse'],\n", - " 'sym_recall': e['sym_recall'],\n", - " 'asym_recall': e['asym_recall'],\n", - " 'encode_time': e['encode_time'],\n", - " 'decode_time': e['decode_time'],\n", - " 'cpu': cpu,\n", - " })\n", - "\n", - "print(\"|factory key|construction parameters|evaluation parameters|code size|codec size|training time|training size|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|\")\n", - "print(\"|-|-|-|-|-|-|-|-|-|\")\n", - "data.sort()\n", - "for d in data:\n", - " print(d[1])\n", - "\n", - "with open(f\"/checkpoint/gsz/bench_fw/codecs_{ds}_test.json\", \"w\") as f:\n", - " json.dump(jss, f)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d1216733-9670-407c-b3d2-5f87bce0321c", - "metadata": {}, - "outputs": [], - "source": [ - "def read_file(filename: str, keys):\n", - " results = []\n", - " with ZipFile(filename, \"r\") as zip_file:\n", - " for key in keys:\n", - " with zip_file.open(key, \"r\") as f:\n", - " if key in [\"D\", \"I\", \"R\", \"lims\"]:\n", - " results.append(np.load(f))\n", - " elif key in [\"P\"]:\n", - " t = io.TextIOWrapper(f)\n", - " results.append(json.load(t))\n", - " else:\n", - " raise AssertionError()\n", - " return results" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "56de051e-22db-4bef-b242-1ddabc9e0bb9", - "metadata": {}, - "outputs": [], - "source": [ - "ds = \"contriever\"\n", - "data = []\n", - "jss = []\n", - "root = f\"/checkpoint/gsz/bench_fw/codecs/{ds}\"\n", - "for lf in glob.glob(root + '/*rec*.zip'):\n", - " e, = read_file(lf, ['P'])\n", - " if e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and \"PRQ\" in e['factory'] and e['sym_recall'] > 0.0:\n", - " code_size = e['codec_meta']['sa_code_size']\n", - " codec_size = e['codec_meta']['codec_size']\n", - " training_time = e['codec_meta']['training_time']\n", - " training_size = None # e['codec_meta']['training_size']\n", - " cpu = e['cpu'] if 'cpu' in e else \"\"\n", - " ps = ', '.join([f\"{k}={v}\" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else \" \"\n", - " eps = ', '.join([f\"{k}={v}\" for k,v in e['reconstruct_params'].items() if k != \"snap\"]) if e['reconstruct_params'] else \" \"\n", - " if eps in ps and eps != \"encode_ils_iters=16\" and eps != \"max_beam_size=32\":\n", - " eps = \" \"\n", - " data.append((code_size, f\"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|\"))\n", - " eps = e['reconstruct_params']\n", - " del eps['snap']\n", - " params = copy(e['construction_params'][0]) if e['construction_params'] else {}\n", - " for k, v in e['reconstruct_params'].items():\n", - " params[k] = v\n", - " jss.append({\n", - " 'factory': e['factory'],\n", - " 'params': params,\n", - " 'construction_params': e['construction_params'][0] if e['construction_params'] else {},\n", - " 'evaluation_params': e['reconstruct_params'],\n", - " 'code_size': code_size,\n", - " 'codec_size': codec_size,\n", - " 'training_time': training_time,\n", - " # 'training_size': training_size,\n", - " 'mse': e['mse'],\n", - " 'sym_recall': e['sym_recall'],\n", - " 'asym_recall': e['asym_recall'],\n", - " 'encode_time': e['encode_time'],\n", - " 'decode_time': e['decode_time'],\n", - " 'cpu': cpu,\n", - " })\n", - "\n", - "print(\"|factory key|construction parameters|encode/decode parameters|code size|codec size|training time|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|\")\n", - "print(\"|-|-|-|-|-|-|-|-|-|\")\n", - "data.sort()\n", - "# for d in data:\n", - "# print(d[1])\n", - "\n", - "print(len(data))\n", - "\n", - "with open(f\"/checkpoint/gsz/bench_fw/codecs_{ds}_5.json\", \"w\") as f:\n", - " json.dump(jss, f)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python [conda env:.conda-faiss_from_source] *", - "language": "python", - "name": "conda-env-.conda-faiss_from_source-py" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.5" - } - }, - "nbformat": 4, - "nbformat_minor": 5 - } + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "id": "be081589-e1b2-4569-acb7-44203e273899", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import itertools\n", + "from faiss.contrib.evaluation import OperatingPoints\n", + "from enum import Enum\n", + "from faiss.benchs.bench_fw.benchmark_io import BenchmarkIO as BIO\n", + "from faiss.benchs.bench_fw.utils import filter_results, ParetoMode, ParetoMetric\n", + "from copy import copy\n", + "import numpy as np\n", + "import datetime\n", + "import glob\n", + "import io\n", + "import json\n", + "from zipfile import ZipFile\n", + "import tabulate" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a6492e95-24c7-4425-bf0a-27e10e879ca6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import getpass\n", + "username = getpass.getuser()\n", + "root = f\"/home/{username}/simsearch/data/ivf/results/sift1M\"\n", + "results = BIO(root).read_json(\"result.json\")\n", + "results.keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0875d269-aef4-426d-83dd-866970f43777", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "results['experiments']" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f080a6e2-1565-418b-8732-4adeff03a099", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "def plot_metric(experiments, accuracy_title, cost_title, plot_space=False, plot=None):\n", + " if plot is None:\n", + " plot = plt.subplot()\n", + " x = {}\n", + " y = {}\n", + " for accuracy, space, time, k, v in experiments:\n", + " idx_name = v['index'] + (\"snap\" if 'search_params' in v and v['search_params'][\"snap\"] == 1 else \"\")\n", + " if idx_name not in x:\n", + " x[idx_name] = []\n", + " y[idx_name] = []\n", + " x[idx_name].append(accuracy)\n", + " if plot_space:\n", + " y[idx_name].append(space)\n", + " else:\n", + " y[idx_name].append(time)\n", + "\n", + " #plt.figure(figsize=(10,6))\n", + " #plt.title(accuracy_title)\n", + " plot.set_xlabel(accuracy_title)\n", + " plot.set_ylabel(cost_title)\n", + " plot.set_yscale(\"log\")\n", + " marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n", + " for index in x.keys():\n", + " plot.plot(x[index], y[index], marker=next(marker), label=index, linewidth=0)\n", + " plot.legend(bbox_to_anchor=(1, 1), loc='upper left')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "61007155-5edc-449e-835e-c141a01a2ae5", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# index local optima\n", + "accuracy_metric = \"knn_intersection\"\n", + "fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1, min_accuracy=0.95)\n", + "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 32 cores)\", plot_space=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f9f94dcc-5abe-4cad-9619-f5d1d24fb8c1", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# global optima\n", + "accuracy_metric = \"knn_intersection\"\n", + "fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.25, name_filter=lambda n: not n.startswith(\"Flat\"), pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + "#fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.90, max_space=64, max_time=0, name_filter=lambda n: not n.startswith(\"Flat\"), pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 32 cores)\", plot_space=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0c10f587-26ef-49ec-83a9-88f6a2a433e8", + "metadata": {}, + "outputs": [], + "source": [ + "def pretty_params(p):\n", + " p = copy(p)\n", + " if 'snap' in p and p['snap'] == 0:\n", + " del p['snap']\n", + " return p\n", + " \n", + "tabulate.tabulate([(accuracy, space, time, v['factory'], pretty_params(v['construction_params'][1]), pretty_params(v['search_params'])) \n", + " for accuracy, space, time, k, v in fr],\n", + " tablefmt=\"html\",\n", + " headers=[\"accuracy\",\"space\", \"time\", \"factory\", \"quantizer cfg\", \"search cfg\"])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "36e82084-18f6-4546-a717-163eb0224ee8", + "metadata": {}, + "outputs": [], + "source": [ + "# index local optima @ precision 0.8\n", + "precision = 0.8\n", + "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", + "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aff79376-39f7-47c0-8b83-1efe5192bb7e", + "metadata": {}, + "outputs": [], + "source": [ + "# index local optima @ precision 0.2\n", + "precision = 0.2\n", + "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", + "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4834f1f-bbbe-4cae-9aa0-a459b0c842d1", + "metadata": {}, + "outputs": [], + "source": [ + "# global optima @ precision 0.8\n", + "precision = 0.8\n", + "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", + "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + "plot_metric(fr, accuracy_title=f\"range recall @ precision {precision}\", cost_title=\"time (seconds, 16 cores)\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9aead830-6209-4956-b7ea-4a5e0029d616", + "metadata": {}, + "outputs": [], + "source": [ + "def plot_range_search_pr_curves(experiments):\n", + " x = {}\n", + " y = {}\n", + " show = {\n", + " 'Flat': None,\n", + " }\n", + " for _, _, _, k, v in fr:\n", + " if \".weighted\" in k: # and v['index'] in show:\n", + " x[k] = v['range_search_pr']['recall']\n", + " y[k] = v['range_search_pr']['precision']\n", + " \n", + " plt.title(\"range search recall\")\n", + " plt.xlabel(\"recall\")\n", + " plt.ylabel(\"precision\")\n", + " for index in x.keys():\n", + " plt.plot(x[index], y[index], '.', label=index)\n", + " plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92e45502-7a31-4a15-90df-fa3032d7d350", + "metadata": {}, + "outputs": [], + "source": [ + "precision = 0.8\n", + "accuracy_metric = lambda exp: range_search_recall_at_precision(exp, precision)\n", + "fr = filter_results(results, evaluation=\"weighted\", accuracy_metric=accuracy_metric, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME_SPACE, scaling_factor=1)\n", + "plot_range_search_pr_curves(fr)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdf8148a-0da6-4c5e-8d60-f8f85314574c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "root = \"/checkpoint/gsz/bench_fw/ivf/bigann\"\n", + "scales = [1, 2, 5, 10, 20, 50]\n", + "fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))\n", + "fig.tight_layout()\n", + "for plot, scale in zip(plots, scales, strict=True):\n", + " results = BIO(root).read_json(f\"result{scale}.json\")\n", + " accuracy_metric = \"knn_intersection\"\n", + " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + " plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 64 cores)\", plot=plot)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e503828c-ee61-45f7-814b-cce6461109bc", + "metadata": {}, + "outputs": [], + "source": [ + "x = {}\n", + "y = {}\n", + "accuracy=0.9\n", + "root = \"/checkpoint/gsz/bench_fw/ivf/bigann\"\n", + "scales = [1, 2, 5, 10, 20, 50]\n", + "#fig, plots = plt.subplots(len(scales), sharex=True, figsize=(5,25))\n", + "#fig.tight_layout()\n", + "for scale in scales:\n", + " results = BIO(root).read_json(f\"result{scale}.json\")\n", + " scale *= 1_000_000\n", + " accuracy_metric = \"knn_intersection\"\n", + " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=accuracy, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + " seen = set()\n", + " print(scale)\n", + " for _, _, _, _, exp in fr:\n", + " fact = exp[\"factory\"]\n", + " # \"HNSW\" in fact or \n", + " if fact in seen or fact in [\"Flat\", \"IVF512,Flat\", \"IVF1024,Flat\", \"IVF2048,Flat\"]:\n", + " continue\n", + " seen.add(fact)\n", + " if fact not in x:\n", + " x[fact] = []\n", + " y[fact] = []\n", + " x[fact].append(scale)\n", + " y[fact].append(exp[\"time\"] + exp[\"quantizer\"][\"time\"])\n", + " if (exp[\"knn_intersection\"] > 0.92):\n", + " print(fact)\n", + " print(exp[\"search_params\"])\n", + " print(exp[\"knn_intersection\"])\n", + "\n", + " #plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"time (seconds, 64 cores)\", plot=plot)\n", + " \n", + "plt.title(f\"recall @ 1 = {accuracy*100}%\")\n", + "plt.xlabel(\"database size\")\n", + "plt.ylabel(\"time\")\n", + "plt.xscale(\"log\")\n", + "plt.yscale(\"log\")\n", + "\n", + "marker = itertools.cycle((\"o\", \"v\", \"^\", \"<\", \">\", \"s\", \"p\", \"P\", \"*\", \"h\", \"X\", \"D\")) \n", + "for index in x.keys():\n", + " if \"HNSW\" in index:\n", + " plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker), linestyle=\"dashed\")\n", + " else:\n", + " plt.plot(x[index], y[index], label=index, linewidth=1, marker=next(marker))\n", + "plt.legend(bbox_to_anchor=(1.0, 1.0), loc='upper left')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "37a99bb2-f998-461b-a345-7cc6e702cb3a", + "metadata": {}, + "outputs": [], + "source": [ + "# global optima\n", + "accuracy_metric = \"sym_recall\"\n", + "fr = filter_results(results, evaluation=\"rec\", accuracy_metric=accuracy_metric, time_metric=lambda e:e['encode_time'], min_accuracy=0.9, pareto_mode=ParetoMode.GLOBAL, pareto_metric=ParetoMetric.SPACE, scaling_factor=1)\n", + "plot_metric(fr, accuracy_title=\"knn intersection\", cost_title=\"space\", plot_space=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c973ce4e-3566-4f02-bd93-f113e3e0c791", + "metadata": {}, + "outputs": [], + "source": [ + "def pretty_time(s):\n", + " if s is None:\n", + " return \"None\"\n", + " s = int(s * 1000) / 1000\n", + " m, s = divmod(s, 60)\n", + " h, m = divmod(m, 60)\n", + " d, h = divmod(h, 24)\n", + " r = \"\"\n", + " if d > 0:\n", + " r += f\"{int(d)}d \"\n", + " if h > 0:\n", + " r += f\"{int(h)}h \"\n", + " if m > 0:\n", + " r += f\"{int(m)}m \"\n", + " if s > 0 or len(r) == 0:\n", + " r += f\"{s:.3f}s\"\n", + " return r\n", + "\n", + "def pretty_size(s):\n", + " if s > 1024 * 1024:\n", + " return f\"{s / 1024 / 1024:.1f}\".rstrip('0').rstrip('.') + \"MB\"\n", + " if s > 1024:\n", + " return f\"{s / 1024:.1f}\".rstrip('0').rstrip('.') + \"KB\"\n", + " return f\"{s}\"\n", + "\n", + "def pretty_mse(m):\n", + " if m is None:\n", + " return \"None\"\n", + " else:\n", + " return f\"{m:.6f}\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1ddcf226-fb97-4a59-9fc3-3ed8f7d5e703", + "metadata": {}, + "outputs": [], + "source": [ + "data = {}\n", + "root = \"/checkpoint/gsz/bench_fw/bigann\"\n", + "scales = [1, 2, 5, 10, 20, 50]\n", + "for scale in scales:\n", + " results = BIO(root).read_json(f\"result{scale}.json\")\n", + " accuracy_metric = \"knn_intersection\"\n", + " fr = filter_results(results, evaluation=\"knn\", accuracy_metric=accuracy_metric, min_accuracy=0, pareto_mode=ParetoMode.INDEX, pareto_metric=ParetoMetric.TIME, scaling_factor=1)\n", + " d = {}\n", + " data[f\"{scale}M\"] = d\n", + " for _, _, _, _, exp in fr:\n", + " fact = exp[\"factory\"]\n", + " # \"HNSW\" in fact or \n", + " if fact in [\"Flat\", \"IVF512,Flat\", \"IVF1024,Flat\", \"IVF2048,Flat\"]:\n", + " continue\n", + " if fact not in d:\n", + " d[fact] = []\n", + " d[fact].append({\n", + " \"nprobe\": exp[\"search_params\"][\"nprobe\"],\n", + " \"recall\": exp[\"knn_intersection\"],\n", + " \"time\": exp[\"time\"] + exp[\"quantizer\"][\"time\"],\n", + " })\n", + "data\n", + "# with open(\"/checkpoint/gsz/bench_fw/codecs.json\", \"w\") as f:\n", + "# json.dump(data, f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e54eebb6-0a9f-4a72-84d2-f12c5bd44510", + "metadata": {}, + "outputs": [], + "source": [ + "ds = \"deep1b\"\n", + "data = []\n", + "jss = []\n", + "root = f\"/checkpoint/gsz/bench_fw/codecs/{ds}\"\n", + "results = BIO(root).read_json(f\"result.json\")\n", + "for k, e in results[\"experiments\"].items():\n", + " if \"rec\" in k and e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and \"PRQ\" in e['factory'] and e['sym_recall'] > 0.0:\n", + " code_size = results['indices'][e['codec']]['sa_code_size']\n", + " codec_size = results['indices'][e['codec']]['codec_size']\n", + " training_time = results['indices'][e['codec']]['training_time']\n", + " # training_size = results['indices'][e['codec']]['training_size']\n", + " cpu = e['cpu'] if 'cpu' in e else \"\"\n", + " ps = ', '.join([f\"{k}={v}\" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else \" \"\n", + " eps = ', '.join([f\"{k}={v}\" for k,v in e['reconstruct_params'].items() if k != \"snap\"]) if e['reconstruct_params'] else \" \"\n", + " data.append((code_size, f\"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{training_size}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|\"))\n", + " jss.append({\n", + " 'factory': e['factory'],\n", + " 'parameters': e['construction_params'][0] if e['construction_params'] else \"\",\n", + " 'evaluation_params': e['reconstruct_params'],\n", + " 'code_size': code_size,\n", + " 'codec_size': codec_size,\n", + " 'training_time': training_time,\n", + " 'training_size': training_size,\n", + " 'mse': e['mse'],\n", + " 'sym_recall': e['sym_recall'],\n", + " 'asym_recall': e['asym_recall'],\n", + " 'encode_time': e['encode_time'],\n", + " 'decode_time': e['decode_time'],\n", + " 'cpu': cpu,\n", + " })\n", + "\n", + "print(\"|factory key|construction parameters|evaluation parameters|code size|codec size|training time|training size|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|\")\n", + "print(\"|-|-|-|-|-|-|-|-|-|\")\n", + "data.sort()\n", + "for d in data:\n", + " print(d[1])\n", + "\n", + "with open(f\"/checkpoint/gsz/bench_fw/codecs_{ds}_test.json\", \"w\") as f:\n", + " json.dump(jss, f)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d1216733-9670-407c-b3d2-5f87bce0321c", + "metadata": {}, + "outputs": [], + "source": [ + "def read_file(filename: str, keys):\n", + " results = []\n", + " with ZipFile(filename, \"r\") as zip_file:\n", + " for key in keys:\n", + " with zip_file.open(key, \"r\") as f:\n", + " if key in [\"D\", \"I\", \"R\", \"lims\"]:\n", + " results.append(np.load(f))\n", + " elif key in [\"P\"]:\n", + " t = io.TextIOWrapper(f)\n", + " results.append(json.load(t))\n", + " else:\n", + " raise AssertionError()\n", + " return results" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "56de051e-22db-4bef-b242-1ddabc9e0bb9", + "metadata": {}, + "outputs": [], + "source": [ + "ds = \"contriever\"\n", + "data = []\n", + "jss = []\n", + "root = f\"/checkpoint/gsz/bench_fw/codecs/{ds}\"\n", + "for lf in glob.glob(root + '/*rec*.zip'):\n", + " e, = read_file(lf, ['P'])\n", + " if e['factory'] != 'Flat': # and e['sym_recall'] > 0.0: # and \"PRQ\" in e['factory'] and e['sym_recall'] > 0.0:\n", + " code_size = e['codec_meta']['sa_code_size']\n", + " codec_size = e['codec_meta']['codec_size']\n", + " training_time = e['codec_meta']['training_time']\n", + " training_size = None # e['codec_meta']['training_size']\n", + " cpu = e['cpu'] if 'cpu' in e else \"\"\n", + " ps = ', '.join([f\"{k}={v}\" for k,v in e['construction_params'][0].items()]) if e['construction_params'] else \" \"\n", + " eps = ', '.join([f\"{k}={v}\" for k,v in e['reconstruct_params'].items() if k != \"snap\"]) if e['reconstruct_params'] else \" \"\n", + " if eps in ps and eps != \"encode_ils_iters=16\" and eps != \"max_beam_size=32\":\n", + " eps = \" \"\n", + " data.append((code_size, f\"|{e['factory']}|{ps}|{eps}|{code_size}|{pretty_size(codec_size)}|{pretty_time(training_time)}|{pretty_mse(e['mse'])}|{e['sym_recall']}|{e['asym_recall']}|{pretty_time(e['encode_time'])}|{pretty_time(e['decode_time'])}|{cpu}|\"))\n", + " eps = e['reconstruct_params']\n", + " del eps['snap']\n", + " params = copy(e['construction_params'][0]) if e['construction_params'] else {}\n", + " for k, v in e['reconstruct_params'].items():\n", + " params[k] = v\n", + " jss.append({\n", + " 'factory': e['factory'],\n", + " 'params': params,\n", + " 'construction_params': e['construction_params'][0] if e['construction_params'] else {},\n", + " 'evaluation_params': e['reconstruct_params'],\n", + " 'code_size': code_size,\n", + " 'codec_size': codec_size,\n", + " 'training_time': training_time,\n", + " # 'training_size': training_size,\n", + " 'mse': e['mse'],\n", + " 'sym_recall': e['sym_recall'],\n", + " 'asym_recall': e['asym_recall'],\n", + " 'encode_time': e['encode_time'],\n", + " 'decode_time': e['decode_time'],\n", + " 'cpu': cpu,\n", + " })\n", + "\n", + "print(\"|factory key|construction parameters|encode/decode parameters|code size|codec size|training time|mean squared error|sym recall @ 1|asym recall @ 1|encode time|decode time|cpu|\")\n", + "print(\"|-|-|-|-|-|-|-|-|-|\")\n", + "data.sort()\n", + "# for d in data:\n", + "# print(d[1])\n", + "\n", + "print(len(data))\n", + "\n", + "with open(f\"/checkpoint/gsz/bench_fw/codecs_{ds}_5.json\", \"w\") as f:\n", + " json.dump(jss, f)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "faiss_binary (local)", + "language": "python", + "name": "faiss_binary_local" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}