From 3888f9bb11046f8ab3ddbd9abea22f20ba77f130 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 30 Aug 2023 09:05:59 -0700 Subject: [PATCH] Using expanded distance forms in `RaftFlatIndex.cu` (#3021) Summary: This is a minor bug that comes with a perf impact. The classic FAISS `FlatIndex` always uses expanded form of distance computation even though an argument `exactDistances` is provided. `RaftFlatIndex` was using this argument to determine whether the computation should be exhaustive. This PR includes one additional change to eagerly initialize the `cublas_handle` on the `device_resources` instance when it's created. Pull Request resolved: https://github.com/facebookresearch/faiss/pull/3021 Reviewed By: pemazare Differential Revision: D48739660 Pulled By: mdouze fbshipit-source-id: a361334eb243df86c169c69d24bb10fed8876ee9 --- faiss/gpu/StandardGpuResources.cpp | 3 +++ faiss/gpu/impl/RaftUtils.h | 3 +-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/faiss/gpu/StandardGpuResources.cpp b/faiss/gpu/StandardGpuResources.cpp index 4b35aa4c0a..2d4c675d64 100644 --- a/faiss/gpu/StandardGpuResources.cpp +++ b/faiss/gpu/StandardGpuResources.cpp @@ -432,6 +432,9 @@ raft::device_resources& StandardGpuResourcesImpl::getRaftHandle(int device) { // Make sure we are using the stream the user may have already assigned // to the current GpuResources raftHandles_.emplace(std::make_pair(device, getDefaultStream(device))); + + // Initialize cublas handle + raftHandles_[device].get_cublas_handle(); } // Otherwise, our base default handle diff --git a/faiss/gpu/impl/RaftUtils.h b/faiss/gpu/impl/RaftUtils.h index 6c744051ae..f1ea19ed33 100644 --- a/faiss/gpu/impl/RaftUtils.h +++ b/faiss/gpu/impl/RaftUtils.h @@ -36,8 +36,7 @@ inline raft::distance::DistanceType faiss_to_raft( case MetricType::METRIC_INNER_PRODUCT: return raft::distance::DistanceType::InnerProduct; case MetricType::METRIC_L2: - return exactDistance ? raft::distance::DistanceType::L2Unexpanded - : raft::distance::DistanceType::L2Expanded; + return raft::distance::DistanceType::L2Expanded; case MetricType::METRIC_L1: return raft::distance::DistanceType::L1; case MetricType::METRIC_Linf: