Skip to content

Commit

Permalink
Fix instantiations
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed Jul 3, 2023
1 parent 2739c83 commit ec5689b
Show file tree
Hide file tree
Showing 9 changed files with 36 additions and 46 deletions.
16 changes: 4 additions & 12 deletions cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ template <class DATA_T, class IdxT, int numElementsPerThread>
__global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, dataset_dim]
const IdxT dataset_size,
const uint32_t dataset_dim,
- const uint32_t dataset_ld,
IdxT* const knn_graph, // [graph_chunk_size, graph_degree]
const uint32_t graph_size,
const uint32_t graph_degree)
Expand All @@ -91,9 +90,9 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size,
float dist = 0.0;
for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) {
float diff = spatial::knn::detail::utils::mapping<float>{}(
- dataset[d + static_cast<uint64_t>(dataset_ld) * srcNode]) -
+ dataset[d + static_cast<uint64_t>(dataset_dim) * srcNode]) -
spatial::knn::detail::utils::mapping<float>{}(
- dataset[d + static_cast<uint64_t>(dataset_ld) * dstNode]);
+ dataset[d + static_cast<uint64_t>(dataset_dim) * dstNode]);
dist += diff * diff;
}
dist += __shfl_xor_sync(0xffffffff, dist, 1);
Expand Down Expand Up @@ -239,7 +238,6 @@ void sort_knn_graph(raft::resources const& res,
"dataset size is expected to have the same number of graph index size");
const uint32_t dataset_size = dataset.extent(0);
const uint32_t dataset_dim = dataset.extent(1);
- const uint32_t dataset_ld = dataset.stride(0);
const DataT* dataset_ptr = dataset.data_handle();

const IdxT graph_size = dataset_size;
Expand All @@ -265,13 +263,8 @@ void sort_knn_graph(raft::resources const& res,
graph_size * input_graph_degree,
resource::get_cuda_stream(res));

- void (*kernel_sort)(const DataT* const,
- const IdxT,
- const uint32_t,
- const uint32_t,
- IdxT* const,
- const uint32_t,
- const uint32_t);
+ void (*kernel_sort)(
+ const DataT* const, const IdxT, const uint32_t, IdxT* const, const uint32_t, const uint32_t);
if (input_graph_degree <= 32) {
constexpr int numElementsPerThread = 1;
kernel_sort = kern_sort<DataT, IdxT, numElementsPerThread>;
Expand Down Expand Up @@ -306,7 +299,6 @@ void sort_knn_graph(raft::resources const& res,
d_dataset.data_handle(),
dataset_size,
dataset_dim,
- dataset_ld,
d_input_graph.data_handle(),
graph_size,
input_graph_degree);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include <algorithm>
#include <cassert>
+ #include <cstdint>
#include <iostream>
#include <memory>
#include <numeric>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@
float_uint32=("float", "uint32_t", "float"), # data_t, idx_t, distance_t
int8_uint32=("int8_t", "uint32_t", "float"),
uint8_uint32=("uint8_t", "uint32_t", "float"),
+ float_uint64=("float", "uint64_t", "float"),
)

# knn
Expand Down
50 changes: 19 additions & 31 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -316,37 +316,25 @@ if(BUILD_TESTS)
NEIGHBORS_TEST
PATH
test/neighbors/ann_cagra/test_float_uint32_t.cu
- # test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
- # test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu
- # test/neighbors/ann_cagra/test_float_int64_t.cu
- # test/neighbors/ann_ivf_flat/test_float_int64_t.cu
- # test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu
- # test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu
- # test/neighbors/ann_ivf_pq/test_float_int64_t.cu
- # test/neighbors/ann_ivf_pq/test_float_uint32_t.cu
- # test/neighbors/ann_ivf_pq/test_float_int64_t.cu
- # test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu
- # test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu
- # test/neighbors/knn.cu
- # test/neighbors/fused_l2_knn.cu
- # test/neighbors/tiled_knn.cu
- # test/neighbors/haversine.cu
- # test/neighbors/ball_cover.cu
- # test/neighbors/epsilon_neighborhood.cu
- # test/neighbors/refine.cu
- # test/neighbors/selection.cu
- # src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu
- # src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu
- # src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu
- # src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu
- # src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu
- # src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu
- # src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu
- # src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu
- # src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu
- # src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu
- # src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu
- # src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu
+ test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
+ test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu
+ test/neighbors/ann_cagra/test_float_int64_t.cu
+ test/neighbors/ann_ivf_flat/test_float_int64_t.cu
+ test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu
+ test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu
+ test/neighbors/ann_ivf_pq/test_float_int64_t.cu
+ test/neighbors/ann_ivf_pq/test_float_uint32_t.cu
+ test/neighbors/ann_ivf_pq/test_float_int64_t.cu
+ test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu
+ test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu
+ test/neighbors/knn.cu
+ test/neighbors/fused_l2_knn.cu
+ test/neighbors/tiled_knn.cu
+ test/neighbors/haversine.cu
+ test/neighbors/ball_cover.cu
+ test/neighbors/epsilon_neighborhood.cu
+ test/neighbors/refine.cu
+ test/neighbors/selection.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu
Expand Down
4 changes: 2 additions & 2 deletions cpp/test/neighbors/ann_cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/
#pragma once

- #define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA
- #define RAFT_COMPILED_CAGRA
+ // #define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA
+ // #define RAFT_COMPILED_CAGRA
#include "../test_utils.cuh"
#include "ann_utils.cuh"
#include <raft/core/resource/cuda_stream.hpp>
Expand Down
4 changes: 3 additions & 1 deletion cpp/test/neighbors/ann_cagra/test_float_int64_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#include <gtest/gtest.h>

- #undef RAFT_EXPLICIT_INSTANTIATE_ONLY
+ #undef RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA
+ #undef RAFT_COMPILED_CAGRA

#include "../ann_cagra.cuh"

namespace raft::neighbors::experimental::cagra {
Expand Down
2 changes: 2 additions & 0 deletions cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

+ #define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA
+ #define RAFT_COMPILED_CAGRA
#include <gtest/gtest.h>

#include "../ann_cagra.cuh"
Expand Down
2 changes: 2 additions & 0 deletions cpp/test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
* limitations under the License.
*/

+ #define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA
+ #define RAFT_COMPILED_CAGRA
#include <gtest/gtest.h>

#include "../ann_cagra.cuh"
Expand Down
2 changes: 2 additions & 0 deletions cpp/test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
+ #define RAFT_EXPLICIT_INSTANTIATE_ONLY_CAGRA
+ #define RAFT_COMPILED_CAGRA

#include <gtest/gtest.h>

Expand Down

0 comments on commit ec5689b

Please sign in to comment.