Rename kernel arch finding function for dispatch #1536

Merged May 19, 2023 · 82 commits

Commits
fb3e03c
add cutlass source files for fusedL2NN with initial set of changes
mdoijade Nov 18, 2022
7039a50
temp commit to store current progress
mdoijade Nov 30, 2022
356c2e1
working code with atomicCAS based reduction from each thread
mdoijade Dec 21, 2022
6663d4c
improve perf by warp reduce and reduce register by adjusting gemm blo…
mdoijade Dec 27, 2022
ca88ad1
improve the perf of fusedL2NN cutlass kernel by reducing atomic locks…
mdoijade Dec 30, 2022
4e4e6ff
add the custom gemm fused epilogue header required for passing params…
mdoijade Dec 30, 2022
8af3ae7
rename final_op as cg_reduce_op, cleanup
mdoijade Dec 30, 2022
d2c3833
add whole block multi-row single lock impl which performs 5-7% faster…
mdoijade Jan 3, 2023
ea9e62a
merge branch-23.02
mdoijade Jan 4, 2023
b38414f
fix connected components reduction functor for working with cutlass
mdoijade Jan 4, 2023
68b75c8
merge branch-23.02
mdoijade Jan 4, 2023
c89f1c1
fix clang format and copyright year
mdoijade Jan 4, 2023
207a964
add comments to cutlass headers which are customized
mdoijade Jan 4, 2023
ff22154
fix clang format issues
mdoijade Jan 4, 2023
cc3e669
merge branch-23.04
mdoijade Feb 3, 2023
2866e35
merge branch-23.04
mdoijade Feb 3, 2023
bf8f271
fix style checks
mdoijade Feb 3, 2023
c9421a9
fix the style checks for header include
mdoijade Feb 3, 2023
9215abe
fix style check in fused_l2_knn
mdoijade Feb 3, 2023
08745d1
fix style issues
mdoijade Feb 3, 2023
8f7c56e
merge branch-23.04
mdoijade Feb 6, 2023
383b805
Merge branch 'branch-23.04' into fusedL2NN_cutlass
mdoijade Feb 7, 2023
4e541f7
Merge branch 'branch-23.04' into fusedL2NN_cutlass
cjnolet Feb 9, 2023
6deab56
Merge branch 'branch-23.04' into fusedL2NN_cutlass
cjnolet Feb 16, 2023
e7470e8
use new gemm shape 32,128,16, make use of shared mem for reduction in…
mdoijade Feb 23, 2023
20f7fb6
fix formatting issues, move cutlass_check to a separate header follow…
mdoijade Feb 23, 2023
ff09aca
merge remote branch
mdoijade Feb 23, 2023
46f1e83
merge branch-23.04
mdoijade Feb 23, 2023
4941064
add missed out cutlass_utils.cuh
mdoijade Feb 23, 2023
29dd569
Merge branch 'branch-23.04' into fusedL2NN_cutlass
cjnolet Mar 9, 2023
1b062d5
remove usage of shared mem gmem ptr storage, use block offset to stor…
mdoijade Mar 9, 2023
72df128
merge branch-23.04 from upstream changes
mdoijade Mar 9, 2023
e4025c2
Merge branch 'branch-23.04' into fusedL2NN_cutlass
mdoijade Mar 9, 2023
c2b0f70
Merge branch 'branch-23.04' into fusedL2NN_cutlass
cjnolet Mar 11, 2023
03432a3
Merge branch 'branch-23.04' into fusedL2NN_cutlass
cjnolet Mar 12, 2023
8a42b77
use tile32 whenever group size is 32 as it uses optimal reduce, elimi…
mdoijade Mar 16, 2023
1239c20
Merge remote-tracking branch 'remotes/mdoijade_fork/fusedL2NN_cutlass…
mdoijade Mar 16, 2023
a8e0607
Merge branch 'branch-23.04' into fusedL2NN_cutlass
mdoijade Mar 16, 2023
96e6e1e
fix formatting issue
mdoijade Mar 16, 2023
a4e45be
use ops::l2_exp_cutlass_op from updated changes
mdoijade Mar 17, 2023
49ec9ec
Merge branch 'branch-23.04' into fusedL2NN_cutlass
cjnolet Mar 20, 2023
7fa45e1
persistent cutlass version based on grouped gemm though only using it…
mdoijade Apr 20, 2023
087dbf0
add support for vectorized epilogue, move the sources to fused_distan…
mdoijade Apr 27, 2023
edc0cc7
merge upstream changes
mdoijade Apr 27, 2023
4ce7bce
merge branch-23.06
mdoijade Apr 27, 2023
cb020ef
remove the data parallel fusedL2NN cutlass source, fix the connect_co…
mdoijade Apr 28, 2023
42b3890
fix formatting issues
mdoijade Apr 28, 2023
ff537fd
merge branch-23.06
mdoijade Apr 28, 2023
c116d23
fix copyright and formatting issues in couple cutlass source files
mdoijade Apr 28, 2023
527c89d
restrict cutlass kernel to sm 80+ using _cuda_arch_
mdoijade Apr 28, 2023
90c2c39
add ignore -Wtautological-compare to not report warnings as error in …
mdoijade Apr 28, 2023
f5a493b
add ignore warning pragma -Wtautological-compare to pairwise_distance…
mdoijade Apr 28, 2023
3be1e1c
Merge branch 'branch-23.06' into fusedL2NN_cutlass
cjnolet May 4, 2023
ab219fc
fix the failure in cluster_test due to incorrect row_id passed to red…
mdoijade May 4, 2023
4b556a1
Merge remote-tracking branch 'mdoijade_fork/fusedL2NN_cutlass' into f…
mdoijade May 4, 2023
06af196
Merge branch 'branch-23.06' into fusedL2NN_cutlass
mdoijade May 4, 2023
5ab9f59
remove redundant code in custom_epilogue_with_broadcast.h, and add co…
mdoijade May 4, 2023
294b306
Merge remote-tracking branch 'mdoijade_fork/fusedL2NN_cutlass' into f…
mdoijade May 4, 2023
65d9c8e
Merge branch 'branch-23.06' into fusedL2NN_cutlass
mdoijade May 5, 2023
1fce36a
remove the redundant header inclusion in fusedl2knn tests which was p…
mdoijade May 5, 2023
b072d80
remove commented code and fix formating
mdoijade May 5, 2023
a7c7303
add larger input test cases to test fusedL2nn all code paths
mdoijade May 5, 2023
d741c4a
fix launch config for small input sizes, fix atomicCAS optimal path s…
mdoijade May 8, 2023
7e4b298
fix formatting issues
mdoijade May 10, 2023
c3ab218
Merge remote-tracking branch 'mdoijade_fork/fusedL2NN_cutlass' into f…
mdoijade May 10, 2023
bc1bfad
Merge branch 'branch-23.06' into fusedL2NN_cutlass
mdoijade May 10, 2023
7f1d30d
move raft copyright below cutlass's and fix start year to be 2023
mdoijade May 10, 2023
514bd1e
add specialization for veclen=1 with 32x128x16 having no reg spills, …
mdoijade May 11, 2023
cba6bbc
move smem init code to their respective tile iterators instead of hav…
mdoijade May 11, 2023
71b91e9
combine the optimal path and non-optimal path gmem writes
mdoijade May 12, 2023
ad2ce75
use the new dispatch mechanism to select the appropriate kernel at ru…
mdoijade May 12, 2023
6b87cc9
fix comments and small cleanup
mdoijade May 12, 2023
383f5bd
add comment about persistent_gemm.h mapping to its cutlass version
mdoijade May 12, 2023
3f4ceff
fix all formatting issues
mdoijade May 12, 2023
5dbf438
Merge remote-tracking branch 'mdoijade_fork/fusedL2NN_cutlass' into f…
mdoijade May 12, 2023
268c218
Merge branch 'branch-23.06' into fusedL2NN_cutlass
mdoijade May 12, 2023
e21a8da
Merge branch 'branch-23.06' into fusedL2NN_cutlass
cjnolet May 16, 2023
5f8f33b
remove dead/commented code from epilogue broadcast header
mdoijade May 16, 2023
d603f81
Merge remote-tracking branch 'mdoijade_fork/fusedL2NN_cutlass' into f…
mdoijade May 16, 2023
4140734
fix formatting
mdoijade May 16, 2023
61a86d3
rename kernel_runtime_arch as kernel_virtual_arch and update comments…
mdoijade May 19, 2023
52ce5d8
merge branch-23.06
mdoijade May 19, 2023
7 changes: 4 additions & 3 deletions cpp/include/raft/distance/detail/fused_l2_nn.cuh
@@ -325,12 +325,13 @@ void fusedL2NNImpl(OutT* min,
                     decltype(distance_op),
                     decltype(fin_op)>;
 
-  // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the
-  // current system. Other methods to determine the architecture (that do not
+  // Get pointer to fp32 SIMT kernel to determine the best compute architecture
+  // (out of those the kernel was compiled for) that most closely matches the
+  // current device. Other methods to determine the architecture (that do not
   // require a pointer) can be error prone. See:
   // https://github.com/NVIDIA/cub/issues/545
   void* kernel_ptr  = reinterpret_cast<void*>(kernel);
-  auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr);
+  auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr);
   auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future());
 
   if (cutlass_range.contains(runtime_arch)) {
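The hunk above is the whole dispatch pattern in miniature: take the address of a kernel that is always compiled into the binary, ask the runtime which virtual architecture that kernel image targets, and branch on an architecture range. A minimal sketch of that pattern using the `raft::util::arch` utilities touched by this PR; `my_simt_kernel` and the launch comments are hypothetical placeholders, not RAFT code:

```cpp
#include <raft/util/arch.cuh>

namespace arch = raft::util::arch;

// Hypothetical SIMT kernel; any __global__ function linked into the binary
// works here, since only its function pointer is inspected.
__global__ void my_simt_kernel(const float* x, float* y, int n)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { y[i] = x[i]; }
}

void dispatch_example()
{
  // Query the virtual architecture the selected kernel image was compiled
  // for. Taking the address of an actual kernel avoids the error-prone
  // pointer-free methods discussed in NVIDIA/cub#545.
  void* kernel_ptr  = reinterpret_cast<void*>(my_simt_kernel);
  auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr);

  // Take the CUTLASS path only on SM80 and newer.
  auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future());
  if (cutlass_range.contains(runtime_arch)) {
    // launch the CUTLASS-based kernel
  } else {
    // launch the legacy SIMT kernel
  }
}
```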
@@ -108,13 +108,14 @@ void pairwise_matrix_dispatch(OpT distance_op,
   auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future());
   auto legacy_range  = arch::SM_range(arch::SM_min(), arch::SM_80());
 
-  // Get pointer to SM60 kernel to determine the runtime architecture of the
-  // current system. Other methods to determine the architecture (that do not
+  // Get pointer to SM60 kernel to determine the best compute architecture
+  // (out of those the kernel was compiled for) that most closely matches the
+  // current device. Other methods to determine the architecture (that do not
   // require a pointer) can be error prone. See:
   // https://github.com/NVIDIA/cub/issues/545
   auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range);
   void* kernel_ptr  = reinterpret_cast<void*>(sm60_wrapper.kernel_ptr);
-  auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr);
+  auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr);
 
   if (cutlass_range.contains(runtime_arch)) {
     // If device is SM_80 or later, use CUTLASS-based kernel.
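For this two-way dispatch to be exhaustive and non-overlapping, `contains` must treat each range as half-open (inclusive lower bound, exclusive upper bound), so that SM80 falls in `cutlass_range` but not in `legacy_range`. A minimal sketch of that check, assuming the range simply stores its two endpoint tags (the type and member names here are illustrative, not RAFT's actual ones):

```cpp
// Illustrative half-open interval over compute architectures:
// legacy_range  = [SM_min, SM_80)     -> legacy SIMT kernels
// cutlass_range = [SM_80,  SM_future) -> CUTLASS kernels
template <typename SM_begin, typename SM_end>
struct sm_range_sketch {
  SM_begin begin;  // inclusive
  SM_end end;      // exclusive

  template <typename SM_t>
  bool contains(SM_t arch) const
  {
    // Half-open: an architecture equal to `begin` is inside the range,
    // one equal to `end` is not, so adjacent ranges never overlap.
    return begin.value() <= arch.value() && arch.value() < end.value();
  }
};
```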
15 changes: 8 additions & 7 deletions cpp/include/raft/util/arch.cuh
@@ -30,10 +30,10 @@ namespace raft::util::arch {
  * compute architecture that a kernel is compiled with. It can only be used
  * inside kernels with a template argument.
  *
- * - raft::util::arch::kernel_runtime_arch : a function that computes at *run-time*
- *   which version of a kernel will launch (i.e., it will return the compute
- *   architecture of the version of the kernel that will be launched by the
- *   driver).
+ * - raft::util::arch::kernel_virtual_arch : a function that computes at *run-time*
+ *   the virtual compute architecture of the kernel version that the driver
+ *   will launch (i.e., the virtual architecture that the selected version of
+ *   the kernel was compiled for).
  *
  * - raft::util::arch::SM_range : a compile-time value to represent an open interval
  *   of compute architectures. This can be used to check if the current
@@ -97,7 +97,7 @@ struct SM_compute_arch {
 // compute architecture of the version of the kernel that the driver picks when
 // the kernel runs.
 struct SM_runtime {
-  friend SM_runtime kernel_runtime_arch(void*);
+  friend SM_runtime kernel_virtual_arch(void*);
 
  private:
   const int _version;
@@ -107,15 +107,16 @@ struct SM_runtime {
   __host__ __device__ int value() const { return _version; }
 };
 
-// Computes which compute architecture of a kernel will run
+// Computes which virtual compute architecture the given kernel was compiled for;
+// the driver picks the kernel version that most closely matches the current hardware.
 //
 // Semantics are described above in the documentation of SM_runtime.
 //
 // This function requires a pointer to the kernel that will run. Other methods
 // to determine the architecture (that do not require a pointer) can be error
 // prone. See:
 // https://github.com/NVIDIA/cub/issues/545
-inline SM_runtime kernel_runtime_arch(void* kernel)
+inline SM_runtime kernel_virtual_arch(void* kernel)
 {
   cudaFuncAttributes attributes;
   RAFT_CUDA_TRY(cudaFuncGetAttributes(&attributes, kernel));
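The hunk is cut off right after the `cudaFuncGetAttributes` call. A plausible completion of the function body under two stated assumptions: `cudaFuncAttributes::ptxVersion` reports the virtual architecture of the selected kernel image encoded as major * 10 + minor (e.g. 80 for compute_80), and the `SM_*` tags in this header use a major * 100 + minor * 10 encoding (800 for SM80), so the attribute is scaled by 10:

```cpp
inline SM_runtime kernel_virtual_arch(void* kernel)
{
  cudaFuncAttributes attributes;
  RAFT_CUDA_TRY(cudaFuncGetAttributes(&attributes, kernel));
  // ptxVersion is the virtual compute architecture the chosen kernel image
  // was compiled for, encoded as major * 10 + minor (e.g. 80 for 8.0).
  // Scale by 10 to match the encoding assumed for SM_80(), SM_future(), etc.
  return SM_runtime(attributes.ptxVersion * 10);
}
```

If the encoding assumption holds, this scaling is what lets `SM_range::contains` reduce to a plain integer comparison against compile-time tags like `SM_80()`.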