Skip to content

Commit

Permalink
Add links to docs and move helpers to utils.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Sep 21, 2023
1 parent cfd48c5 commit c4b2a05
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 237 deletions.
283 changes: 77 additions & 206 deletions notebooks/ivf_flat_example.ipynb

Large diffs are not rendered by default.

42 changes: 11 additions & 31 deletions notebooks/tutorial_ivf_pq.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"from pylibraft.common import DeviceResources\n",
"from pylibraft.neighbors import ivf_pq, refine\n",
"from adjustText import adjust_text\n",
"from utils import calc_recall, load_dataset\n",
"\n",
"%matplotlib inline"
]
Expand Down Expand Up @@ -194,15 +195,18 @@
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The index and data will be saved in /tmp/raft_example\n"
]
}
],
"source": [
"DATASET_URL = \"http://ann-benchmarks.com/sift-128-euclidean.hdf5\"\n",
"DATASET_FILENAME = DATASET_URL.split('/')[-1]\n",
"\n",
"## download the dataset\n",
"dataset_path = os.path.join(WORK_FOLDER, DATASET_FILENAME)\n",
"if not os.path.exists(dataset_path):\n",
" urllib.request.urlretrieve(DATASET_URL, dataset_path)"
"f = load_dataset(DATASET_URL)"
]
},
{
Expand All @@ -227,8 +231,6 @@
}
],
"source": [
"f = h5py.File(dataset_path, \"r\")\n",
"\n",
"metric = f.attrs['distance']\n",
"\n",
"dataset = cp.array(f['train'])\n",
Expand Down Expand Up @@ -456,28 +458,6 @@
}
],
"source": [
"## Check the quality of the prediction (recall)\n",
"def calc_recall(found_indices, ground_truth):\n",
" found_indices = cp.asarray(found_indices)\n",
" bs, k = found_indices.shape\n",
" if bs != ground_truth.shape[0]:\n",
" raise RuntimeError(\n",
" \"Batch sizes do not match {} vs {}\".format(\n",
" bs, ground_truth.shape[0])\n",
" )\n",
" if k > ground_truth.shape[1]:\n",
" raise RuntimeError(\n",
" \"Not enough indices in the ground truth ({} > {})\".format(\n",
" k, ground_truth.shape[1])\n",
" )\n",
" n = 0\n",
" # Go over the batch\n",
" for i in range(bs):\n",
" # Note, ivf-pq does not guarantee the ordered input, hence the use of intersect1d\n",
" n += cp.intersect1d(found_indices[i, :k], ground_truth[i, :k]).size\n",
" recall = n / found_indices.size\n",
" return recall\n",
"\n",
"recall_first_try = calc_recall(neighbors, gt_neighbors)\n",
"print(f\"Got recall = {recall_first_try} with the default parameters (k = {k}).\")"
]
Expand Down
103 changes: 103 additions & 0 deletions notebooks/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import cupy as cp
import h5py
import os
import tempfile
import time
import urllib

def calc_recall(found_indices, ground_truth):
    """Measure the quality of an ANN search result (recall).

    Recall is the fraction of the returned neighbor indices that also
    appear in the exact (ground-truth) neighbor lists.

    Parameters
    ----------
    found_indices : array-like, shape (batch_size, k)
        Neighbor indices returned by the approximate search.
    ground_truth : array, shape (batch_size, >= k)
        Exact neighbor indices for the same queries.

    Returns
    -------
    float
        Recall in [0, 1].

    Raises
    ------
    RuntimeError
        If the batch sizes differ, or the ground truth holds fewer than
        k neighbors per query.
    """
    found_indices = cp.asarray(found_indices)
    bs, k = found_indices.shape
    if bs != ground_truth.shape[0]:
        raise RuntimeError(
            "Batch sizes do not match {} vs {}".format(
                bs, ground_truth.shape[0]
            )
        )
    if k > ground_truth.shape[1]:
        raise RuntimeError(
            "Not enough indices in the ground truth ({} > {})".format(
                k, ground_truth.shape[1]
            )
        )
    # ivf-pq does not guarantee ordered output, hence the use of
    # intersect1d instead of an element-wise comparison.
    matches = sum(
        cp.intersect1d(found_indices[row, :k], ground_truth[row, :k]).size
        for row in range(bs)
    )
    return matches / found_indices.size


class BenchmarkTimer:
    """Provides a context manager that runs a code block `reps` times
    and records results to the instance variable `timings`. Use like:

    .. code-block:: python

        timer = BenchmarkTimer(rep=5)
        for _ in timer.benchmark_runs():
            ... do something ...
        print(np.min(timer.timings))

    This class is borrowed from the rapids/cuml benchmark suite
    """

    def __init__(self, reps=1, warmup=0):
        # Number of untimed warmup iterations run before the measured ones.
        self.warmup = warmup
        # Number of measured iterations.
        self.reps = reps
        # Wall-clock durations (seconds) of the measured iterations only.
        self.timings = []

    def benchmark_runs(self):
        """Yield `warmup + reps` iteration indices, timing each one.

        Only the durations of the non-warmup iterations are appended to
        ``self.timings`` (exactly once each — the original version
        recorded every iteration and then the non-warmup ones a second
        time, skewing the statistics).
        """
        for r in range(self.reps + self.warmup):
            t0 = time.time()
            yield r
            t1 = time.time()
            if r >= self.warmup:
                self.timings.append(t1 - t0)


def load_dataset(dataset_url, work_folder=None):
    """Download a dataset and open it as an HDF5 file.

    The file at ``dataset_url`` is expected to be an hdf5 file in
    ann-benchmarks format. It is downloaded into ``work_folder`` (unless a
    cached copy is already present) and opened read-only.

    Parameters
    ----------
    dataset_url : str
        Address of the hdf5 file.
    work_folder : str, optional
        Name of the local folder to store the dataset. Defaults to a
        "raft_example" subfolder of the system temp directory.

    Returns
    -------
    h5py.File
        The opened dataset file (caller is responsible for closing it).
    """
    # BUG FIX: the original reassigned dataset_url to a hard-coded URL
    # here, silently ignoring the caller's argument.
    dataset_filename = dataset_url.split("/")[-1]

    # We'll need to store some data in this tutorial
    if work_folder is None:
        work_folder = os.path.join(tempfile.gettempdir(), "raft_example")

    # exist_ok avoids a race between the existence check and creation.
    os.makedirs(work_folder, exist_ok=True)
    print("The index and data will be saved in", work_folder)

    ## download the dataset unless a cached copy already exists
    dataset_path = os.path.join(work_folder, dataset_filename)
    if not os.path.exists(dataset_path):
        urllib.request.urlretrieve(dataset_url, dataset_path)

    return h5py.File(dataset_path, "r")

0 comments on commit c4b2a05

Please sign in to comment.