Skip to content

Commit

Permalink
InnerProduct testing for CAGRA+HNSW (rapidsai#2297)
Browse files Browse the repository at this point in the history
Authors:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#2297
  • Loading branch information
divyegala authored May 23, 2024
1 parent 0b6f542 commit 64827fc
Showing 1 changed file with 24 additions and 3 deletions.
27 changes: 24 additions & 3 deletions python/pylibraft/pylibraft/test/test_hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,26 @@ def run_hnsw_build_search_test(
k=10,
dtype=np.float32,
metric="sqeuclidean",
build_algo="ivf_pq",
intermediate_graph_degree=128,
graph_degree=64,
search_params={},
):
dataset = generate_data((n_rows, n_cols), dtype)
if metric == "inner_product":
dataset = normalize(dataset, norm="l2", axis=1)
if dtype in [np.int8, np.uint8]:
pytest.skip(
"inner_product metric is not supported for int8/uint8 data"
)
if build_algo == "nn_descent":
pytest.skip("inner_product metric is not supported for nn_descent")

build_params = cagra.IndexParams(
metric=metric,
intermediate_graph_degree=intermediate_graph_degree,
graph_degree=graph_degree,
build_algo=build_algo,
)

index = cagra.build(build_params, dataset)
Expand All @@ -57,7 +65,14 @@ def run_hnsw_build_search_test(
out_dist, out_idx = hnsw.search(search_params, hnsw_index, queries, k)

# Calculate reference values with sklearn
nn_skl = NearestNeighbors(n_neighbors=k, algorithm="brute", metric=metric)
skl_metric = {
"sqeuclidean": "sqeuclidean",
"inner_product": "cosine",
"euclidean": "euclidean",
}[metric]
nn_skl = NearestNeighbors(
n_neighbors=k, algorithm="brute", metric=skl_metric
)
nn_skl.fit(dataset)
skl_idx = nn_skl.kneighbors(queries, return_distance=False)

Expand All @@ -69,9 +84,15 @@ def run_hnsw_build_search_test(
@pytest.mark.parametrize("k", [10, 20])
@pytest.mark.parametrize("ef", [30, 40])
@pytest.mark.parametrize("num_threads", [2, 4])
def test_hnsw(dtype, k, ef, num_threads):
@pytest.mark.parametrize("metric", ["sqeuclidean", "inner_product"])
@pytest.mark.parametrize("build_algo", ["ivf_pq", "nn_descent"])
def test_hnsw(dtype, k, ef, num_threads, metric, build_algo):
# Note that inner_product tests use normalized input which we cannot
# represent in int8, therefore we test only sqeuclidean metric here.
run_hnsw_build_search_test(
dtype=dtype, k=k, search_params={"ef": ef, "num_threads": num_threads}
dtype=dtype,
k=k,
metric=metric,
build_algo=build_algo,
search_params={"ef": ef, "num_threads": num_threads},
)

0 comments on commit 64827fc

Please sign in to comment.