Skip to content

Commit

Permalink
Add L2SqrtExpanded support to ivf_pq (#1138)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1138
  • Loading branch information
benfred authored Jan 12, 2023
1 parent bbe0755 commit 36b315d
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 19 deletions.
7 changes: 6 additions & 1 deletion cpp/include/raft/spatial/knn/detail/ivf_pq_search.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -171,6 +171,7 @@ void select_clusters(const handle_t& handle,
*/
float norm_factor;
switch (metric) {
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2Expanded: norm_factor = 1.0 / -2.0; break;
case raft::distance::DistanceType::InnerProduct: norm_factor = 0.0; break;
default: RAFT_FAIL("Unsupported distance type %d.", int(metric));
Expand All @@ -189,6 +190,7 @@ void select_clusters(const handle_t& handle,
float beta;
uint32_t gemm_k = dim;
switch (metric) {
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2Expanded: {
alpha = -2.0;
beta = 0.0;
Expand Down Expand Up @@ -710,6 +712,7 @@ __global__ void ivfpq_compute_similarity_kernel(uint32_t n_rows,
if constexpr (PrecompBaseDiff) {
// Reduce number of memory reads later by pre-computing parts of the score
switch (metric) {
case distance::DistanceType::L2SqrtExpanded:
case distance::DistanceType::L2Expanded: {
for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) {
base_diff[i] = query[i] - cluster_center[i];
Expand Down Expand Up @@ -743,6 +746,7 @@ __global__ void ivfpq_compute_similarity_kernel(uint32_t n_rows,
float pq_c = *cur_pq_center;
cur_pq_center += PqShift;
switch (metric) {
case distance::DistanceType::L2SqrtExpanded:
case distance::DistanceType::L2Expanded: {
float diff;
if constexpr (PrecompBaseDiff) {
Expand Down Expand Up @@ -809,6 +813,7 @@ __global__ void ivfpq_compute_similarity_kernel(uint32_t n_rows,
switch (metric) {
// If the metric is non-negative, we can use the query_kth approximation as an early stop
// threshold to skip some iterations when computing the score. Add such metrics here.
case distance::DistanceType::L2SqrtExpanded:
case distance::DistanceType::L2Expanded: {
early_stop_limit = query_kth;
} break;
Expand Down
11 changes: 10 additions & 1 deletion cpp/test/neighbors/ann_ivf_pq.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -436,6 +436,15 @@ inline auto enum_variety_ip() -> test_cases_t
});
}

inline auto enum_variety_l2sqrt() -> test_cases_t
{
return map<ivf_pq_inputs>(enum_variety(), [](const ivf_pq_inputs& x) {
ivf_pq_inputs y(x);
y.index_params.metric = distance::DistanceType::L2SqrtExpanded;
return y;
});
}

/**
* Try different number of n_probes, some of which may trigger the non-fused version of the search
* kernel.
Expand Down
5 changes: 3 additions & 2 deletions cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,7 @@ using f32_f32_i64 = ivf_pq_test<float, float, int64_t>;

TEST_BUILD_SEARCH(f32_f32_i64)
TEST_BUILD_EXTEND_SEARCH(f32_f32_i64)
INSTANTIATE(f32_f32_i64, enum_variety_l2() + enum_variety_ip() + big_dims_small_lut());
INSTANTIATE(f32_f32_i64,
enum_variety_l2() + enum_variety_ip() + big_dims_small_lut() + enum_variety_l2sqrt());

} // namespace raft::neighbors::ivf_pq
13 changes: 7 additions & 6 deletions python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,7 @@ from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport (
def _get_metric(metric):
SUPPORTED_DISTANCES = {
"l2_expanded": DistanceType.L2Expanded,
# TODO(tfeher): fix inconsistency: index building for L2SqrtExpanded is
# only supported by build, not by search.
# "euclidean": DistanceType.L2SqrtExpanded
"euclidean": DistanceType.L2SqrtExpanded,
"inner_product": DistanceType.InnerProduct
}
if metric not in SUPPORTED_DISTANCES:
Expand All @@ -76,7 +74,8 @@ def _get_metric(metric):

cdef _get_metric_string(DistanceType metric):
return {DistanceType.L2Expanded : "l2_expanded",
DistanceType.InnerProduct: "inner_product"}[metric]
DistanceType.InnerProduct: "inner_product",
DistanceType.L2SqrtExpanded: "euclidean"}[metric]


cdef _get_codebook_string(c_ivf_pq.codebook_gen codebook):
Expand Down Expand Up @@ -135,9 +134,11 @@ cdef class IndexParams:
n_list : int, default = 1024
The number of clusters used in the coarse quantizer.
metric : string denoting the metric type, default="l2_expanded"
Valid values for metric: ["l2_expanded", "inner_product"], where
- l2_expanded is the equclidean distance without the square root
Valid values for metric: ["l2_expanded", "inner_product",
"euclidean"], where
- l2_expanded is the euclidean distance without the square root
operation, i.e.: distance(a,b) = \\sum_i (a_i - b_i)^2,
- euclidean is the euclidean distance
- inner product distance is defined as
distance(a, b) = \\sum_i a_i * b_i.
kmeans_n_iters : int, default = 20
Expand Down
19 changes: 10 additions & 9 deletions python/pylibraft/pylibraft/test/test_ivf_pq.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,14 @@ def check_distances(dataset, queries, metric, out_idx, out_dist, eps=None):
X = queries[np.newaxis, i, :]
Y = dataset[out_idx[i, :], :]
if metric == "l2_expanded":
dist[i, :] = pairwise_distances(X, Y, "sqeuclidean")
elif metric == "euclidean":
dist[i, :] = pairwise_distances(X, Y, "euclidean")
elif metric == "inner_product":
dist[i, :] = np.matmul(X, Y.T)
else:
raise ValueError("Invalid metric")

# Note: raft l2 metric does not include the square root operation like
# sklearn's euclidean.
if metric == "l2_expanded":
dist = np.power(dist, 2)

dist_eps = abs(dist)
dist_eps[dist < 1e-3] = 1e-3
diff = abs(out_dist - dist) / dist_eps
Expand Down Expand Up @@ -179,9 +176,11 @@ def run_ivf_pq_build_search_test(
out_dist = out_dist_device.copy_to_host()

# Calculate reference values with sklearn
skl_metric = {"l2_expanded": "euclidean", "inner_product": "cosine"}[
metric
]
skl_metric = {
"l2_expanded": "sqeuclidean",
"inner_product": "cosine",
"euclidean": "euclidean",
}[metric]
nn_skl = NearestNeighbors(
n_neighbors=k, algorithm="brute", metric=skl_metric
)
Expand Down Expand Up @@ -253,7 +252,9 @@ def test_ivf_pq_n(params):
)


@pytest.mark.parametrize("metric", ["l2_expanded", "inner_product"])
@pytest.mark.parametrize(
"metric", ["l2_expanded", "inner_product", "euclidean"]
)
@pytest.mark.parametrize("dtype", [np.float32])
@pytest.mark.parametrize("codebook_kind", ["subspace", "cluster"])
@pytest.mark.parametrize("rotation", [True, False])
Expand Down

0 comments on commit 36b315d

Please sign in to comment.