Skip to content

Commit

Permalink
Fix ANN error
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Jan 28, 2021
1 parent e0f6a80 commit 7d5e215
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
5 changes: 3 additions & 2 deletions cpp/src/knn/knn.cu
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -54,7 +54,8 @@ void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
ML::MetricType metric, float metricArg,
float *index_items, int n) {
MLCommon::Selection::approx_knn_build_index(
index, params, D, metric, metricArg, index_items, n, handle.get_stream());
index, params, D, metric, metricArg, index_items, n, handle.get_device(),
handle.get_stream());
}

void approx_knn_search(ML::knnIndex *index, int n, const float *x, int k,
Expand Down
22 changes: 13 additions & 9 deletions cpp/src_prims/selection/knn.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -263,18 +263,15 @@ void approx_knn_ivfsq_build_index(ML::knnIndex *index, ML::IVFSQParam *params,
template <typename IntType = int>
void approx_knn_build_index(ML::knnIndex *index, ML::knnIndexParam *params,
IntType D, ML::MetricType metric, float metricArg,
float *index_items, IntType n,
float *index_items, IntType n, int dev_id,
cudaStream_t userStream) {
int device;
CUDA_CHECK(cudaGetDevice(&device));

faiss::gpu::StandardGpuResources *gpu_res =
new faiss::gpu::StandardGpuResources();
gpu_res->noTempMemory();
gpu_res->setCudaMallocWarning(false);
gpu_res->setDefaultStream(device, userStream);
gpu_res->setTempMemory(3000000000);
gpu_res->setCudaMallocWarning(true);
gpu_res->setDefaultStream(dev_id, userStream);
index->gpu_res = gpu_res;
index->device = device;
index->device = dev_id;
index->index = nullptr;

if (dynamic_cast<ML::IVFFlatParam *>(params)) {
Expand All @@ -285,6 +282,8 @@ void approx_knn_build_index(ML::knnIndex *index, ML::knnIndexParam *params,
userStream);
index->index->train(n, h_index_items.data());
index->index->add(n, h_index_items.data());
CUDA_CHECK(cudaStreamSynchronize(userStream));
CUDA_CHECK(cudaPeekAtLastError());
return;
} else if (dynamic_cast<ML::IVFPQParam *>(params)) {
ML::IVFPQParam *IVFPQ_param = dynamic_cast<ML::IVFPQParam *>(params);
Expand All @@ -298,12 +297,17 @@ void approx_knn_build_index(ML::knnIndex *index, ML::knnIndexParam *params,

index->index->train(n, index_items);
index->index->add(n, index_items);
CUDA_CHECK(cudaStreamSynchronize(userStream));
CUDA_CHECK(cudaPeekAtLastError());
}

template <typename IntType = int>
void approx_knn_search(ML::knnIndex *index, IntType n, const float *x,
IntType k, float *distances, int64_t *labels) {
index->index->search(n, x, k, distances, labels);
cudaStream_t stream = index->gpu_res->getDefaultStream(index->device);
CUDA_CHECK(cudaStreamSynchronize(stream));
CUDA_CHECK(cudaPeekAtLastError());
}

/**
Expand Down

0 comments on commit 7d5e215

Please sign in to comment.