diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh index bb37fd03db..6e4c99b646 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh @@ -18,6 +18,9 @@ #include "../ann_common.h" +#include "common_faiss.h" +#include "processing.hpp" + #include #include diff --git a/cpp/include/raft/spatial/knn/detail/common_faiss.h b/cpp/include/raft/spatial/knn/detail/common_faiss.h new file mode 100644 index 0000000000..0c0398a336 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/common_faiss.h @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +inline faiss::MetricType build_faiss_metric( + raft::distance::DistanceType metric) { + switch (metric) { + case raft::distance::DistanceType::CosineExpanded: + return faiss::MetricType::METRIC_INNER_PRODUCT; + case raft::distance::DistanceType::CorrelationExpanded: + return faiss::MetricType::METRIC_INNER_PRODUCT; + case raft::distance::DistanceType::L2Expanded: + return faiss::MetricType::METRIC_L2; + case raft::distance::DistanceType::L2Unexpanded: + return faiss::MetricType::METRIC_L2; + case raft::distance::DistanceType::L2SqrtExpanded: + return faiss::MetricType::METRIC_L2; + case raft::distance::DistanceType::L2SqrtUnexpanded: + return faiss::MetricType::METRIC_L2; + case raft::distance::DistanceType::L1: + return faiss::MetricType::METRIC_L1; + case raft::distance::DistanceType::InnerProduct: + return faiss::MetricType::METRIC_INNER_PRODUCT; + case raft::distance::DistanceType::LpUnexpanded: + return faiss::MetricType::METRIC_Lp; + case raft::distance::DistanceType::Linf: + return faiss::MetricType::METRIC_Linf; + case raft::distance::DistanceType::Canberra: + return faiss::MetricType::METRIC_Canberra; + case raft::distance::DistanceType::BrayCurtis: + return faiss::MetricType::METRIC_BrayCurtis; + case raft::distance::DistanceType::JensenShannon: + return faiss::MetricType::METRIC_JensenShannon; + default: + THROW("MetricType not supported: %d", metric); + } +} + +} // namespace detail +} // namespace knn +} // namespace spatial +} // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 75226299e6..09494e9eb1 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -36,6 +36,8 @@ #include "haversine_distance.cuh" #include "processing.hpp" +#include "common_faiss.h" + namespace raft { namespace spatial { namespace knn { @@ -167,40 +169,6 @@ inline void knn_merge_parts(value_t *inK, value_idx *inV, value_t *outK, inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); } -inline faiss::MetricType build_faiss_metric( - raft::distance::DistanceType metric) { - switch (metric) { - case raft::distance::DistanceType::CosineExpanded: - return faiss::MetricType::METRIC_INNER_PRODUCT; - case raft::distance::DistanceType::CorrelationExpanded: - return faiss::MetricType::METRIC_INNER_PRODUCT; - case raft::distance::DistanceType::L2Expanded: - return faiss::MetricType::METRIC_L2; - case raft::distance::DistanceType::L2Unexpanded: - return faiss::MetricType::METRIC_L2; - case raft::distance::DistanceType::L2SqrtExpanded: - return faiss::MetricType::METRIC_L2; - case raft::distance::DistanceType::L2SqrtUnexpanded: - return faiss::MetricType::METRIC_L2; - case raft::distance::DistanceType::L1: - return faiss::MetricType::METRIC_L1; - case raft::distance::DistanceType::InnerProduct: - return faiss::MetricType::METRIC_INNER_PRODUCT; - case raft::distance::DistanceType::LpUnexpanded: - return faiss::MetricType::METRIC_Lp; - case raft::distance::DistanceType::Linf: - return faiss::MetricType::METRIC_Linf; - case raft::distance::DistanceType::Canberra: - return faiss::MetricType::METRIC_Canberra; - case raft::distance::DistanceType::BrayCurtis: - return faiss::MetricType::METRIC_BrayCurtis; - case raft::distance::DistanceType::JensenShannon: - return faiss::MetricType::METRIC_JensenShannon; - default: - THROW("MetricType not supported: %d", metric); - } -} - /** * Search the kNN for the k-nearest neighbors of a set of query vectors * @param[in] input vector of device device memory array pointers to search