Skip to content

Commit

Permalink
Add cagra template specializations
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed May 14, 2023
1 parent 95428b3 commit 97f7a48
Show file tree
Hide file tree
Showing 34 changed files with 7,612 additions and 4 deletions.
27 changes: 27 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,33 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu
src/neighbors/brute_force_knn_int_float_int.cu
src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu
src/neighbors/cagra_build_float_uint32.cu
src/neighbors/cagra_prune_float_uint32.cu
src/neighbors/cagra_search_float_uint32.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu
src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu
Expand Down
2 changes: 0 additions & 2 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@
#include <rmm/cuda_stream_view.hpp>

#include "factory.cuh"
#include "search_multi_cta.cuh"
#include "search_multi_kernel.cuh"
#include "search_plan.cuh"
#include "search_single_cta.cuh"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include "compute_distance.hpp"
#include "device_common.hpp"
#include "hashmap.hpp"
#include "search_multi_cta_kernel-inl.cuh"
#include "search_multi_cta_kernel.cuh"
#include "search_plan.cuh"
#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible
#include "utils.hpp"
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

// Dispatch header for the search_multi_cta kernel: it contains no code of its
// own and only selects which companion header(s) to include, based on two
// build-time macros.

// Included unless RAFT_EXPLICIT_INSTANTIATE_ONLY is defined.
// NOTE(review): presumably the "-inl" header holds the kernel template
// definitions -- confirm against that file.
#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "search_multi_cta_kernel-inl.cuh"
#endif

// Included only when RAFT_COMPILED is defined.
// NOTE(review): presumably the "-ext" header declares the instantiations that
// are provided by the precompiled library -- its contents are not visible
// here, confirm.
#ifdef RAFT_COMPILED
#include "search_multi_cta_kernel-ext.cuh"
#endif
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
#include "device_common.hpp"
#include "hashmap.hpp"
#include "search_plan.cuh"
#include "search_single_cta_kernel-inl.cuh"
#include "search_single_cta_kernel.cuh"
#include "topk_by_radix.cuh"
#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk
#include "utils.hpp"
Expand Down
2,401 changes: 2,401 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

// Dispatch header for the search_single_cta kernel: it contains no code of its
// own and only selects which companion header(s) to include, based on two
// build-time macros.

// Included unless RAFT_EXPLICIT_INSTANTIATE_ONLY is defined.
// NOTE(review): presumably the "-inl" header holds the kernel template
// definitions -- confirm against that file.
#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "search_single_cta_kernel-inl.cuh"
#endif

// Included only when RAFT_COMPILED is defined.
// NOTE(review): presumably the "-ext" header declares the instantiations that
// are provided by the precompiled library -- its contents are not visible
// here, confirm.
#ifdef RAFT_COMPILED
#include "search_single_cta_kernel-ext.cuh"
#endif
113 changes: 113 additions & 0 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) 2023, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Prologue written at the start of every generated .cu file (see the
# f.write(header) call below): license text, a "generated file" notice, the
# include of the kernel -inl header, the namespace opening and the definition
# of the instantiate_multi_cta_search_kernel macro.
# This is a regular (non-raw) string, so each "\\" below is emitted as a single
# "\" line-continuation character in the macro definition.
# The literal begins with a newline, so generated files start with an empty line.
header = """
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* NOTE: this file is generated by search_multi_cta_00_generate.py
*
* Make changes there and run in this directory:
*
* > python search_multi_cta_00_generate.py
*
*/
#include <raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh>
namespace raft::neighbors::experimental::cagra::detail::multi_cta_search {
#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \\
BLOCK_SIZE, \\
BLOCK_COUNT, \\
MAX_ELEMENTS, \\
MAX_DATASET_DIM, \\
DATA_T, \\
DISTANCE_T, \\
INDEX_T, \\
LOAD_T) \\
template __global__ void search_kernel<TEAM_SIZE, \\
BLOCK_SIZE, \\
BLOCK_COUNT, \\
MAX_ELEMENTS, \\
MAX_DATASET_DIM, \\
DATA_T, \\
DISTANCE_T, \\
INDEX_T, \\
LOAD_T>(INDEX_T* const result_indices_ptr, \\
DISTANCE_T* const result_distances_ptr, \\
const DATA_T* const dataset_ptr, \\
const size_t dataset_dim, \\
const size_t dataset_size, \\
const DATA_T* const queries_ptr, \\
const INDEX_T* const knn_graph, \\
const uint32_t graph_degree, \\
const unsigned num_distilation, \\
const uint64_t rand_xor_mask, \\
const INDEX_T* seed_ptr, \\
const uint32_t num_seeds, \\
uint32_t* const visited_hashmap_ptr, \\
const uint32_t hash_bitlen, \\
const uint32_t itopk_size, \\
const uint32_t num_parents, \\
const uint32_t min_iteration, \\
const uint32_t max_iteration, \\
uint32_t* const num_executed_iterations);
"""

# Epilogue written at the end of every generated .cu file: removes the
# instantiation macro again and closes the namespace opened by the prologue.
# The literal begins with a newline, so an empty line separates it from the
# last instantiation in the output.
trailer = """
#undef instantiate_multi_cta_search_kernel
} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search
"""

# (MAX_DATASET_DIM, TEAM_SIZE) pairs; one output file is produced per pair
# and per entry of search_types.
mxdim_team = [
    (128, 8),
    (256, 16),
    (512, 32),
    (1024, 32),
]

# (BLOCK_SIZE, BLOCK_COUNT) pairs passed to the instantiation macro.
block = [
    (64, 16),
    (128, 8),
    (256, 4),
    (512, 2),
    (1024, 1),
]

# Values used for the MAX_ELEMENTS macro argument.
mxelem = [64, 128, 256]

# Values used for the LOAD_T macro argument.
load_types = ["uint4", "uint64_t"]

# File-name tag -> (data_t, idx_t, distance_t)
search_types = {
    "float_uint32": ("float", "uint32_t", "float"),
    "int8_uint32": ("int8_t", "uint32_t", "float"),
    "uint8_uint32": ("uint8_t", "uint32_t", "float"),
}

# Generate one .cu file per (type combination, (max dim, team size)) pair.
# Each file consists of the shared prologue, one explicit-instantiation line
# for every combination of load type, block configuration and max-element
# count, and the shared epilogue.
from pathlib import Path

# Write the files next to this script so that the result does not depend on
# the current working directory of the caller.
out_dir = Path(__file__).resolve().parent

for type_path, (data_t, idx_t, distance_t) in search_types.items():
    for (mxdim, team) in mxdim_team:
        path = f"search_multi_cta_{type_path}_dim{mxdim}_t{team}.cu"
        with open(out_dir / path, "w") as f:
            f.write(header)
            for load_t in load_types:
                for block_size, block_count in block:
                    for elem in mxelem:
                        # Argument order follows the macro's parameter list:
                        # TEAM_SIZE, BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS,
                        # MAX_DATASET_DIM, DATA_T, DISTANCE_T, INDEX_T, LOAD_T.
                        # Every argument is separated by ", " so the output
                        # matches the formatting of the committed files.
                        f.write(
                            f"instantiate_multi_cta_search_kernel({team}, {block_size}, {block_count}, {elem}, {mxdim}, {data_t}, {distance_t}, {idx_t}, {load_t});\n"
                        )
            f.write(trailer)
        # For pasting into CMakeLists.txt
        print(f"src/neighbors/detail/cagra/{path}")
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@

/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* NOTE: this file is generated by search_multi_cta_00_generate.py
*
* Make changes there and run in this directory:
*
* > python search_multi_cta_00_generate.py
*
*/

#include <raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh>

namespace raft::neighbors::experimental::cagra::detail::multi_cta_search {
// Helper macro: expands to a single explicit instantiation of the
// `search_kernel` template for the given template arguments. The function
// parameter list spelled out below has to match the kernel's declaration.
// NOTE(review): `search_kernel` is not defined in this file; presumably it
// comes from the included search_multi_cta_kernel-inl.cuh -- confirm.
#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \
BLOCK_SIZE, \
BLOCK_COUNT, \
MAX_ELEMENTS, \
MAX_DATASET_DIM, \
DATA_T, \
DISTANCE_T, \
INDEX_T, \
LOAD_T) \
template __global__ void search_kernel<TEAM_SIZE, \
BLOCK_SIZE, \
BLOCK_COUNT, \
MAX_ELEMENTS, \
MAX_DATASET_DIM, \
DATA_T, \
DISTANCE_T, \
INDEX_T, \
LOAD_T>(INDEX_T* const result_indices_ptr, \
DISTANCE_T* const result_distances_ptr, \
const DATA_T* const dataset_ptr, \
const size_t dataset_dim, \
const size_t dataset_size, \
const DATA_T* const queries_ptr, \
const INDEX_T* const knn_graph, \
const uint32_t graph_degree, \
const unsigned num_distilation, \
const uint64_t rand_xor_mask, \
const INDEX_T* seed_ptr, \
const uint32_t num_seeds, \
uint32_t* const visited_hashmap_ptr, \
const uint32_t hash_bitlen, \
const uint32_t itopk_size, \
const uint32_t num_parents, \
const uint32_t min_iteration, \
const uint32_t max_iteration, \
uint32_t* const num_executed_iterations);

// All instantiations in this file use TEAM_SIZE = 32, MAX_DATASET_DIM = 1024,
// float data, float distances and uint32_t indices. Each (BLOCK_SIZE,
// BLOCK_COUNT) pair is combined with MAX_ELEMENTS = 64, 128 and 256.

// LOAD_T = uint4
instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, float, float, uint32_t, uint4);
instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, float, float, uint32_t, uint4);
// LOAD_T = uint64_t
instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, float, float, uint32_t, uint64_t);
instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, float, float, uint32_t, uint64_t);

// Remove the helper macro so it is not visible past this point.
#undef instantiate_multi_cta_search_kernel

} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search
Loading

0 comments on commit 97f7a48

Please sign in to comment.