From 36b315dc2dc3009c578245ac99d3fbcdaf296aa4 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Thu, 12 Jan 2023 09:49:40 -0800 Subject: [PATCH] Add L2SqrtExpanded support to ivf_pq (#1138) Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1138 --- .../raft/spatial/knn/detail/ivf_pq_search.cuh | 7 ++++++- cpp/test/neighbors/ann_ivf_pq.cuh | 11 ++++++++++- .../ann_ivf_pq/test_float_int64_t.cu | 5 +++-- .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 13 +++++++------ .../pylibraft/pylibraft/test/test_ivf_pq.py | 19 ++++++++++--------- 5 files changed, 36 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh index 1df5671be2..16a78aec1c 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -171,6 +171,7 @@ void select_clusters(const handle_t& handle, */ float norm_factor; switch (metric) { + case raft::distance::DistanceType::L2SqrtExpanded: case raft::distance::DistanceType::L2Expanded: norm_factor = 1.0 / -2.0; break; case raft::distance::DistanceType::InnerProduct: norm_factor = 0.0; break; default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); @@ -189,6 +190,7 @@ void select_clusters(const handle_t& handle, float beta; uint32_t gemm_k = dim; switch (metric) { + case raft::distance::DistanceType::L2SqrtExpanded: case raft::distance::DistanceType::L2Expanded: { alpha = -2.0; beta = 0.0; @@ -710,6 +712,7 @@ __global__ void ivfpq_compute_similarity_kernel(uint32_t n_rows, if constexpr (PrecompBaseDiff) { // Reduce number of memory reads later by pre-computing parts of the score switch (metric) { + case distance::DistanceType::L2SqrtExpanded: case distance::DistanceType::L2Expanded: { for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { base_diff[i] = query[i] - cluster_center[i]; @@ -743,6 +746,7 @@ __global__ void ivfpq_compute_similarity_kernel(uint32_t n_rows, float pq_c = *cur_pq_center; cur_pq_center += PqShift; switch (metric) { + case distance::DistanceType::L2SqrtExpanded: case distance::DistanceType::L2Expanded: { float diff; if constexpr (PrecompBaseDiff) { @@ -809,6 +813,7 @@ __global__ void ivfpq_compute_similarity_kernel(uint32_t n_rows, switch (metric) { // If the metric is non-negative, we can use the query_kth approximation as an early stop // threshold to skip some iterations when computing the score. Add such metrics here. + case distance::DistanceType::L2SqrtExpanded: case distance::DistanceType::L2Expanded: { early_stop_limit = query_kth; } break; diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 353e8b65e5..94777aedd1 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -436,6 +436,15 @@ inline auto enum_variety_ip() -> test_cases_t }); } +inline auto enum_variety_l2sqrt() -> test_cases_t +{ + return map(enum_variety(), [](const ivf_pq_inputs& x) { + ivf_pq_inputs y(x); + y.index_params.metric = distance::DistanceType::L2SqrtExpanded; + return y; + }); +} + /** * Try different number of n_probes, some of which may trigger the non-fused version of the search * kernel. diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu index ecb2faa6a0..db42b1ee6a 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -22,6 +22,7 @@ using f32_f32_i64 = ivf_pq_test; TEST_BUILD_SEARCH(f32_f32_i64) TEST_BUILD_EXTEND_SEARCH(f32_f32_i64) -INSTANTIATE(f32_f32_i64, enum_variety_l2() + enum_variety_ip() + big_dims_small_lut()); +INSTANTIATE(f32_f32_i64, + enum_variety_l2() + enum_variety_ip() + big_dims_small_lut() + enum_variety_l2sqrt()); } // namespace raft::neighbors::ivf_pq diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index 002a097d0f..ee30864193 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -64,9 +64,7 @@ from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( def _get_metric(metric): SUPPORTED_DISTANCES = { "l2_expanded": DistanceType.L2Expanded, - # TODO(tfeher): fix inconsistency: index building for L2SqrtExpanded is - # only supported by build, not by search. - # "euclidean": DistanceType.L2SqrtExpanded + "euclidean": DistanceType.L2SqrtExpanded, "inner_product": DistanceType.InnerProduct } if metric not in SUPPORTED_DISTANCES: @@ -76,7 +74,8 @@ def _get_metric(metric): cdef _get_metric_string(DistanceType metric): return {DistanceType.L2Expanded : "l2_expanded", - DistanceType.InnerProduct: "inner_product"}[metric] + DistanceType.InnerProduct: "inner_product", + DistanceType.L2SqrtExpanded: "euclidean"}[metric] cdef _get_codebook_string(c_ivf_pq.codebook_gen codebook): @@ -135,9 +134,11 @@ cdef class IndexParams: n_list : int, default = 1024 The number of clusters used in the coarse quantizer. 
metric : string denoting the metric type, default="l2_expanded" - Valid values for metric: ["l2_expanded", "inner_product"], where - - l2_expanded is the equclidean distance without the square root + Valid values for metric: ["l2_expanded", "inner_product", + "euclidean"], where + - l2_expanded is the euclidean distance without the square root operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2, + - euclidean is the euclidean distance, i.e.: sqrt(l2_expanded), - inner product distance is defined as distance(a, b) = \\sum_i a_i * b_i. kmeans_n_iters : int, default = 20 diff --git a/python/pylibraft/pylibraft/test/test_ivf_pq.py b/python/pylibraft/pylibraft/test/test_ivf_pq.py index 35738cd471..db1389c6cd 100644 --- a/python/pylibraft/pylibraft/test/test_ivf_pq.py +++ b/python/pylibraft/pylibraft/test/test_ivf_pq.py @@ -59,17 +59,14 @@ def check_distances(dataset, queries, metric, out_idx, out_dist, eps=None): X = queries[np.newaxis, i, :] Y = dataset[out_idx[i, :], :] if metric == "l2_expanded": + dist[i, :] = pairwise_distances(X, Y, "sqeuclidean") + elif metric == "euclidean": dist[i, :] = pairwise_distances(X, Y, "euclidean") elif metric == "inner_product": dist[i, :] = np.matmul(X, Y.T) else: raise ValueError("Invalid metric") - # Note: raft l2 metric does not include the square root operation like - # sklearn's euclidean. 
- if metric == "l2_expanded": - dist = np.power(dist, 2) - dist_eps = abs(dist) dist_eps[dist < 1e-3] = 1e-3 diff = abs(out_dist - dist) / dist_eps @@ -179,9 +176,11 @@ def run_ivf_pq_build_search_test( out_dist = out_dist_device.copy_to_host() # Calculate reference values with sklearn - skl_metric = {"l2_expanded": "euclidean", "inner_product": "cosine"}[ - metric - ] + skl_metric = { + "l2_expanded": "sqeuclidean", + "inner_product": "cosine", + "euclidean": "euclidean", + }[metric] nn_skl = NearestNeighbors( n_neighbors=k, algorithm="brute", metric=skl_metric ) @@ -253,7 +252,9 @@ def test_ivf_pq_n(params): ) -@pytest.mark.parametrize("metric", ["l2_expanded", "inner_product"]) +@pytest.mark.parametrize( + "metric", ["l2_expanded", "inner_product", "euclidean"] +) @pytest.mark.parametrize("dtype", [np.float32]) @pytest.mark.parametrize("codebook_kind", ["subspace", "cluster"]) @pytest.mark.parametrize("rotation", [True, False])