-
Notifications
You must be signed in to change notification settings - Fork 197
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
34 changed files
with
7,612 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
481 changes: 481 additions & 0 deletions
481
cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh
Large diffs are not rendered by default.
Oops, something went wrong.
24 changes: 24 additions & 0 deletions
24
cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel.cuh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY | ||
#include "search_multi_cta_kernel-inl.cuh" | ||
#endif | ||
|
||
#ifdef RAFT_COMPILED | ||
#include "search_multi_cta_kernel-ext.cuh" | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
2,401 changes: 2,401 additions & 0 deletions
2,401
cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh
Large diffs are not rendered by default.
Oops, something went wrong.
24 changes: 24 additions & 0 deletions
24
cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel.cuh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
#pragma once | ||
|
||
#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY | ||
#include "search_single_cta_kernel-inl.cuh" | ||
#endif | ||
|
||
#ifdef RAFT_COMPILED | ||
#include "search_single_cta_kernel-ext.cuh" | ||
#endif |
113 changes: 113 additions & 0 deletions
113
cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Copyright (c) 2023, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
header = """ | ||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
/* | ||
* NOTE: this file is generated by search_multi_cta_00_generate.py | ||
* | ||
* Make changes there and run in this directory: | ||
* | ||
* > python search_multi_cta_00_generate.py | ||
* | ||
*/ | ||
#include <raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh> | ||
namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { | ||
#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \\ | ||
BLOCK_SIZE, \\ | ||
BLOCK_COUNT, \\ | ||
MAX_ELEMENTS, \\ | ||
MAX_DATASET_DIM, \\ | ||
DATA_T, \\ | ||
DISTANCE_T, \\ | ||
INDEX_T, \\ | ||
LOAD_T) \\ | ||
template __global__ void search_kernel<TEAM_SIZE, \\ | ||
BLOCK_SIZE, \\ | ||
BLOCK_COUNT, \\ | ||
MAX_ELEMENTS, \\ | ||
MAX_DATASET_DIM, \\ | ||
DATA_T, \\ | ||
DISTANCE_T, \\ | ||
INDEX_T, \\ | ||
LOAD_T>(INDEX_T* const result_indices_ptr, \\ | ||
DISTANCE_T* const result_distances_ptr, \\ | ||
const DATA_T* const dataset_ptr, \\ | ||
const size_t dataset_dim, \\ | ||
const size_t dataset_size, \\ | ||
const DATA_T* const queries_ptr, \\ | ||
const INDEX_T* const knn_graph, \\ | ||
const uint32_t graph_degree, \\ | ||
const unsigned num_distilation, \\ | ||
const uint64_t rand_xor_mask, \\ | ||
const INDEX_T* seed_ptr, \\ | ||
const uint32_t num_seeds, \\ | ||
uint32_t* const visited_hashmap_ptr, \\ | ||
const uint32_t hash_bitlen, \\ | ||
const uint32_t itopk_size, \\ | ||
const uint32_t num_parents, \\ | ||
const uint32_t min_iteration, \\ | ||
const uint32_t max_iteration, \\ | ||
uint32_t* const num_executed_iterations); | ||
""" | ||
|
||
trailer = """ | ||
#undef instantiate_multi_cta_search_kernel | ||
} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search | ||
""" | ||
|
||
mxdim_team = [(128, 8), (256, 16), (512, 32), (1024, 32)] | ||
block = [(64, 16), (128, 8), (256, 4), (512, 2), (1024, 1)] | ||
mxelem = [64, 128, 256] | ||
load_types = ["uint4", "uint64_t"] | ||
search_types = dict( | ||
float_uint32=("float", "uint32_t", "float"), # data_t, idx_t, distance_t | ||
int8_uint32=("int8_t", "uint32_t", "float"), | ||
uint8_uint32=("uint8_t", "uint32_t", "float"), | ||
) | ||
|
||
# knn | ||
for type_path, (data_t, idx_t, distance_t) in search_types.items(): | ||
for (mxdim, team) in mxdim_team: | ||
path = f"search_multi_cta_{type_path}_dim{mxdim}_t{team}.cu" | ||
with open(path, "w") as f: | ||
f.write(header) | ||
for load_t in load_types: | ||
for b in block: | ||
for elem in mxelem: | ||
f.write( | ||
f"instantiate_multi_cta_search_kernel({team}, {b[0]}, {b[1]}, {elem}, {mxdim},{data_t}, {distance_t}, {idx_t}, {load_t});\n" | ||
) | ||
f.write(trailer) | ||
# For pasting into CMakeLists.txt | ||
print(f"src/neighbors/detail/cagra/{path}") |
100 changes: 100 additions & 0 deletions
100
cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
|
||
/* | ||
* Copyright (c) 2023, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
/* | ||
* NOTE: this file is generated by search_multi_cta_00_generate.py | ||
* | ||
* Make changes there and run in this directory: | ||
* | ||
* > python search_multi_cta_00_generate.py | ||
* | ||
*/ | ||
|
||
#include <raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh> | ||
|
||
namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { | ||
#define instantiate_multi_cta_search_kernel(TEAM_SIZE, \ | ||
BLOCK_SIZE, \ | ||
BLOCK_COUNT, \ | ||
MAX_ELEMENTS, \ | ||
MAX_DATASET_DIM, \ | ||
DATA_T, \ | ||
DISTANCE_T, \ | ||
INDEX_T, \ | ||
LOAD_T) \ | ||
template __global__ void search_kernel<TEAM_SIZE, \ | ||
BLOCK_SIZE, \ | ||
BLOCK_COUNT, \ | ||
MAX_ELEMENTS, \ | ||
MAX_DATASET_DIM, \ | ||
DATA_T, \ | ||
DISTANCE_T, \ | ||
INDEX_T, \ | ||
LOAD_T>(INDEX_T* const result_indices_ptr, \ | ||
DISTANCE_T* const result_distances_ptr, \ | ||
const DATA_T* const dataset_ptr, \ | ||
const size_t dataset_dim, \ | ||
const size_t dataset_size, \ | ||
const DATA_T* const queries_ptr, \ | ||
const INDEX_T* const knn_graph, \ | ||
const uint32_t graph_degree, \ | ||
const unsigned num_distilation, \ | ||
const uint64_t rand_xor_mask, \ | ||
const INDEX_T* seed_ptr, \ | ||
const uint32_t num_seeds, \ | ||
uint32_t* const visited_hashmap_ptr, \ | ||
const uint32_t hash_bitlen, \ | ||
const uint32_t itopk_size, \ | ||
const uint32_t num_parents, \ | ||
const uint32_t min_iteration, \ | ||
const uint32_t max_iteration, \ | ||
uint32_t* const num_executed_iterations); | ||
|
||
instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, float, float, uint32_t, uint4); | ||
instantiate_multi_cta_search_kernel(32, 64, 16, 64, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 64, 16, 128, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 64, 16, 256, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 128, 8, 64, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 128, 8, 128, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 128, 8, 256, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 256, 4, 64, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 256, 4, 128, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 256, 4, 256, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 512, 2, 64, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 512, 2, 128, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 512, 2, 256, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 1024, 1, 64, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 1024, 1, 128, 1024, float, float, uint32_t, uint64_t); | ||
instantiate_multi_cta_search_kernel(32, 1024, 1, 256, 1024, float, float, uint32_t, uint64_t); | ||
|
||
#undef instantiate_multi_cta_search_kernel | ||
|
||
} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search |
Oops, something went wrong.