diff --git a/cpp/src/knn/knn.cu b/cpp/src/knn/knn.cu index 4055d61a82..27d07955a6 100644 --- a/cpp/src/knn/knn.cu +++ b/cpp/src/knn/knn.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,7 +54,8 @@ void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index, ML::MetricType metric, float metricArg, float *index_items, int n) { MLCommon::Selection::approx_knn_build_index( - index, params, D, metric, metricArg, index_items, n, handle.get_stream()); + index, params, D, metric, metricArg, index_items, n, handle.get_device(), + handle.get_stream()); } void approx_knn_search(ML::knnIndex *index, int n, const float *x, int k, diff --git a/cpp/src_prims/selection/knn.cuh b/cpp/src_prims/selection/knn.cuh index 4d3a0f5dc7..51996834be 100644 --- a/cpp/src_prims/selection/knn.cuh +++ b/cpp/src_prims/selection/knn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -263,18 +263,15 @@ void approx_knn_ivfsq_build_index(ML::knnIndex *index, ML::IVFSQParam *params, template void approx_knn_build_index(ML::knnIndex *index, ML::knnIndexParam *params, IntType D, ML::MetricType metric, float metricArg, - float *index_items, IntType n, + float *index_items, IntType n, int dev_id, cudaStream_t userStream) { - int device; - CUDA_CHECK(cudaGetDevice(&device)); - faiss::gpu::StandardGpuResources *gpu_res = new faiss::gpu::StandardGpuResources(); - gpu_res->noTempMemory(); - gpu_res->setCudaMallocWarning(false); - gpu_res->setDefaultStream(device, userStream); + gpu_res->setTempMemory(3000000000); + gpu_res->setCudaMallocWarning(true); + gpu_res->setDefaultStream(dev_id, userStream); index->gpu_res = gpu_res; - index->device = device; + index->device = dev_id; index->index = nullptr; if (dynamic_cast(params)) { @@ -285,6 +282,8 @@ void approx_knn_build_index(ML::knnIndex *index, ML::knnIndexParam *params, userStream); index->index->train(n, h_index_items.data()); index->index->add(n, h_index_items.data()); + CUDA_CHECK(cudaStreamSynchronize(userStream)); + CUDA_CHECK(cudaPeekAtLastError()); return; } else if (dynamic_cast(params)) { ML::IVFPQParam *IVFPQ_param = dynamic_cast(params); @@ -298,12 +297,17 @@ void approx_knn_build_index(ML::knnIndex *index, ML::knnIndexParam *params, index->index->train(n, index_items); index->index->add(n, index_items); + CUDA_CHECK(cudaStreamSynchronize(userStream)); + CUDA_CHECK(cudaPeekAtLastError()); } template void approx_knn_search(ML::knnIndex *index, IntType n, const float *x, IntType k, float *distances, int64_t *labels) { index->index->search(n, x, k, distances, labels); + cudaStream_t stream = index->gpu_res->getDefaultStream(index->device); + CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaPeekAtLastError()); } /**