diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 27ef00e385..dff9aceb8d 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,14 +20,6 @@ #include #include #include -#include -#include -#include - -#include - -#include -#include #include #include #include @@ -42,9 +34,19 @@ #include #include #include -#include +#include +#include + +#include +#include +#include + #include +#include +#include +#include + namespace raft::neighbors::detail { using namespace raft::spatial::knn::detail; using namespace raft::spatial::knn; @@ -78,7 +80,7 @@ void tiled_brute_force_knn(const raft::resources& handle, size_t tile_cols = 0; auto stream = resource::get_cuda_stream(handle); auto device_memory = resource::get_workspace_resource(handle); - auto total_mem = device_memory->get_mem_info(stream).second; + auto total_mem = rmm::available_device_memory().second; faiss_select::chooseTileSize(m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols); // for unittesting, its convenient to be able to put a max size on the tiles