diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 2866049188..0b89377630 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -293,11 +293,13 @@ void brute_force_knn_impl( cudaStream_t stream = raft::select_stream(userStream, internalStreams, n_int_streams, i); + // TODO: Enable this once we figure out why it's causing pytest failures in cuml. if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && (metric == raft::distance::DistanceType::L2Unexpanded || - metric == raft::distance::DistanceType::L2SqrtUnexpanded || - metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded)) { + metric == raft::distance::DistanceType::L2SqrtUnexpanded //|| + // metric == raft::distance::DistanceType::L2Expanded || + // metric == raft::distance::DistanceType::L2SqrtExpanded) + )) { fusedL2Knn(D, out_i_ptr, out_d_ptr,