diff --git a/python/pylibraft/pylibraft/test/test_hnsw.py b/python/pylibraft/pylibraft/test/test_hnsw.py index 487f190e4e..8cdf8c904f 100644 --- a/python/pylibraft/pylibraft/test/test_hnsw.py +++ b/python/pylibraft/pylibraft/test/test_hnsw.py @@ -29,6 +29,7 @@ def run_hnsw_build_search_test( k=10, dtype=np.float32, metric="sqeuclidean", + build_algo="ivf_pq", intermediate_graph_degree=128, graph_degree=64, search_params={}, @@ -36,11 +37,18 @@ def run_hnsw_build_search_test( dataset = generate_data((n_rows, n_cols), dtype) if metric == "inner_product": dataset = normalize(dataset, norm="l2", axis=1) + if dtype in [np.int8, np.uint8]: + pytest.skip( + "inner_product metric is not supported for int8/uint8 data" + ) + if build_algo == "nn_descent": + pytest.skip("inner_product metric is not supported for nn_descent") build_params = cagra.IndexParams( metric=metric, intermediate_graph_degree=intermediate_graph_degree, graph_degree=graph_degree, + build_algo=build_algo, ) index = cagra.build(build_params, dataset) @@ -57,7 +65,14 @@ def run_hnsw_build_search_test( out_dist, out_idx = hnsw.search(search_params, hnsw_index, queries, k) # Calculate reference values with sklearn - nn_skl = NearestNeighbors(n_neighbors=k, algorithm="brute", metric=metric) + skl_metric = { + "sqeuclidean": "sqeuclidean", + "inner_product": "cosine", + "euclidean": "euclidean", + }[metric] + nn_skl = NearestNeighbors( + n_neighbors=k, algorithm="brute", metric=skl_metric + ) nn_skl.fit(dataset) skl_idx = nn_skl.kneighbors(queries, return_distance=False) @@ -69,9 +84,15 @@ def run_hnsw_build_search_test( @pytest.mark.parametrize("k", [10, 20]) @pytest.mark.parametrize("ef", [30, 40]) @pytest.mark.parametrize("num_threads", [2, 4]) -def test_hnsw(dtype, k, ef, num_threads): +@pytest.mark.parametrize("metric", ["sqeuclidean", "inner_product"]) +@pytest.mark.parametrize("build_algo", ["ivf_pq", "nn_descent"]) +def test_hnsw(dtype, k, ef, num_threads, metric, build_algo): # Note that inner_product tests use normalized input which we cannot # represent in int8, therefore we test only sqeuclidean metric here. run_hnsw_build_search_test( - dtype=dtype, k=k, search_params={"ef": ef, "num_threads": num_threads} + dtype=dtype, + k=k, + metric=metric, + build_algo=build_algo, + search_params={"ef": ef, "num_threads": num_threads}, )