Skip to content

Commit

Permalink
Using expanded distance forms in RaftFlatIndex.cu (facebookresearch…
Browse files Browse the repository at this point in the history
…#3021)

Summary:
This is a minor bug that comes with a perf impact. The classic FAISS `FlatIndex` always uses expanded form of distance computation even though an argument `exactDistances` is provided. `RaftFlatIndex` was using this argument to determine whether the computation should be exhaustive.

This PR includes one additional change to eagerly initialize the `cublas_handle` on the `device_resources` instance when it's created.

Pull Request resolved: facebookresearch#3021

Reviewed By: pemazare

Differential Revision: D48739660

Pulled By: mdouze

fbshipit-source-id: a361334eb243df86c169c69d24bb10fed8876ee9
  • Loading branch information
cjnolet authored and abhinavdangeti committed Jul 12, 2024
1 parent 471da78 commit 88a9bb3
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
3 changes: 3 additions & 0 deletions faiss/gpu/StandardGpuResources.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,9 @@ raft::device_resources& StandardGpuResourcesImpl::getRaftHandle(int device) {
// Make sure we are using the stream the user may have already assigned
// to the current GpuResources
raftHandles_.emplace(std::make_pair(device, getDefaultStream(device)));

// Initialize cublas handle
raftHandles_[device].get_cublas_handle();
}

// Otherwise, our base default handle
Expand Down
3 changes: 1 addition & 2 deletions faiss/gpu/impl/RaftUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ inline raft::distance::DistanceType faiss_to_raft(
case MetricType::METRIC_INNER_PRODUCT:
return raft::distance::DistanceType::InnerProduct;
case MetricType::METRIC_L2:
return exactDistance ? raft::distance::DistanceType::L2Unexpanded
: raft::distance::DistanceType::L2Expanded;
return raft::distance::DistanceType::L2Expanded;
case MetricType::METRIC_L1:
return raft::distance::DistanceType::L1;
case MetricType::METRIC_Linf:
Expand Down

0 comments on commit 88a9bb3

Please sign in to comment.