Skip to content

Commit

Permalink
fix nan issues in L2 expanded sqrt KNN distances (#411)
Browse files Browse the repository at this point in the history
- Fix NaNs in L2-expanded sqrt KNN distances: floating-point rounding in fusedL2Knn can produce very small negative values where the true distance is zero, and taking their square root yields NaN.
- Re-enable the fused L2 KNN (fusedL2Knn) path.

Authors:
  - Mahesh Doijade (https://github.com/mdoijade)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #411
  • Loading branch information
mdoijade authored Dec 9, 2021
1 parent b29ec65 commit 8ddb61c
Showing 1 changed file with 63 additions and 60 deletions.
123 changes: 63 additions & 60 deletions cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -294,67 +294,66 @@ void brute_force_knn_impl(
cudaStream_t stream = raft::select_stream(userStream, internalStreams, n_int_streams, i);

// // TODO: Enable this once we figure out why it's causing pytest failures in cuml.
// if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true &&
// (metric == raft::distance::DistanceType::L2Unexpanded ||
// metric == raft::distance::DistanceType::L2SqrtUnexpanded //||
// // metric == raft::distance::DistanceType::L2Expanded ||
// // metric == raft::distance::DistanceType::L2SqrtExpanded)
// )) {
// fusedL2Knn(D,
// out_i_ptr,
// out_d_ptr,
// input[i],
// search_items,
// sizes[i],
// n,
// k,
// rowMajorIndex,
// rowMajorQuery,
// stream,
// metric);
// } else {
switch (metric) {
case raft::distance::DistanceType::Haversine:

ASSERT(D == 2,
"Haversine distance requires 2 dimensions "
"(latitude / longitude).");

haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream);
break;
default:
faiss::MetricType m = build_faiss_metric(metric);

faiss::gpu::StandardGpuResources gpu_res;

gpu_res.noTempMemory();
gpu_res.setDefaultStream(device, stream);

faiss::gpu::GpuDistanceParams args;
args.metric = m;
args.metricArg = metricArg;
args.k = k;
args.dims = D;
args.vectors = input[i];
args.vectorsRowMajor = rowMajorIndex;
args.numVectors = sizes[i];
args.queries = search_items;
args.queriesRowMajor = rowMajorQuery;
args.numQueries = n;
args.outDistances = out_d_ptr;
args.outIndices = out_i_ptr;

/**
* @todo: Until FAISS supports pluggable allocation strategies,
* we will not reap the benefits of the pool allocator for
* avoiding device-wide synchronizations from cudaMalloc/cudaFree
*/
bfKnn(&gpu_res, args);
if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true &&
(metric == raft::distance::DistanceType::L2Unexpanded ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded)) {
fusedL2Knn(D,
out_i_ptr,
out_d_ptr,
input[i],
search_items,
sizes[i],
n,
k,
rowMajorIndex,
rowMajorQuery,
stream,
metric);
} else {
switch (metric) {
case raft::distance::DistanceType::Haversine:

ASSERT(D == 2,
"Haversine distance requires 2 dimensions "
"(latitude / longitude).");

haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream);
break;
default:
faiss::MetricType m = build_faiss_metric(metric);

faiss::gpu::StandardGpuResources gpu_res;

gpu_res.noTempMemory();
gpu_res.setDefaultStream(device, stream);

faiss::gpu::GpuDistanceParams args;
args.metric = m;
args.metricArg = metricArg;
args.k = k;
args.dims = D;
args.vectors = input[i];
args.vectorsRowMajor = rowMajorIndex;
args.numVectors = sizes[i];
args.queries = search_items;
args.queriesRowMajor = rowMajorQuery;
args.numQueries = n;
args.outDistances = out_d_ptr;
args.outIndices = out_i_ptr;

/**
* @todo: Until FAISS supports pluggable allocation strategies,
* we will not reap the benefits of the pool allocator for
* avoiding device-wide synchronizations from cudaMalloc/cudaFree
*/
bfKnn(&gpu_res, args);
}
}
}

CUDA_CHECK(cudaPeekAtLastError());
// }
CUDA_CHECK(cudaPeekAtLastError());
}

// Sync internal streams if used. We don't need to
// sync the user stream because we'll already have
Expand All @@ -379,7 +378,11 @@ void brute_force_knn_impl(
float p = 0.5; // standard l2
if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg;
raft::linalg::unaryOp<float>(
res_D, res_D, n * k, [p] __device__(float input) { return powf(input, p); }, userStream);
res_D,
res_D,
n * k,
[p] __device__(float input) { return powf(fabsf(input), p); },
userStream);
}

query_metric_processor->revert(search_items);
Expand Down

0 comments on commit 8ddb61c

Please sign in to comment.