diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml index 69629645ab..4572ae1b98 100644 --- a/.github/ISSUE_TEMPLATE/config.yml +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -2,4 +2,4 @@ blank_issues_enabled: true contact_links: - name: CUTLASS Discord url: https://discord.gg/nvidiadeveloper - about: Come chat about using and contributing to CUTLASS! + about: Come chat about using and contributing to CUTLASS! diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bb701ed3e..80f712397a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,20 @@ # NVIDIA CUTLASS Changelog +## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3) (2023-10-31) +* [Mixed Precision Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types. +* [Mixed Precision Ampere GEMMs](https://github.com/NVIDIA/cutlass/commit/7d8317a63e0a978a8dbb3c1fb7af4dbe4f286616) with support for canonical layouts (TN) and {fp16, bf16} x {s8/u8}. +* [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors. +* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors. +* Profiler support for lower-aligned Hopper GEMMs. +* Performance Improvements to [Scatter-Gather Hopper Example](/examples/52_hopper_gather_scatter_fusion) +* Sub-Byte type fixes and improvements +* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details. +* Fusion support for backprop fusions including drelu, dgelu, and dbias. +* Support for void-C kernels and SM80 mixed-precision GEMMs in the CUTLASS Python interface + +## [3.2.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-10-25) +* Minor patch for issue/1138 + ## [3.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-09-22) * Python support SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0. * SM80 EVT support in C++ and Python. diff --git a/CMakeLists.txt b/CMakeLists.txt index b880de0a52..ec5dc5f2b5 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,7 +40,7 @@ endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set") -project(CUTLASS VERSION 3.2.1 LANGUAGES CXX) +project(CUTLASS VERSION 3.3.0 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) if (CUDA_VERSION VERSION_LESS 11.3) @@ -87,19 +87,6 @@ set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.") find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED) -# Install cutlass_library Python package -execute_process( - WORKING_DIRECTORY ${CUTLASS_DIR}/python - COMMAND ${Python3_EXECUTABLE} ${CUTLASS_DIR}/python/setup_library.py develop --user - RESULT_VARIABLE cutlass_lib_GENERATOR_INSTALL_RESULT - OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log - ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log -) - -if(NOT cutlass_lib_GENERATOR_INSTALL_RESULT EQUAL 0) - message(FATAL_ERROR "Error installing cutlass_library package. 
See ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log") -endif() - ################################################################################ set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library") @@ -829,6 +816,8 @@ function(cutlass_add_executable_tests NAME TARGET) set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME}) file(MAKE_DIRECTORY ${TEST_GEN_DIR}) + set(TEST_SETS_SUPPORTED default) + set(TEST_EXE_PATH $) set(TEST_USE_EXTENDED_FORMAT ON) configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY) @@ -883,7 +872,12 @@ if (CUTLASS_INSTALL_TESTS) file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/ctest") - file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n") + file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n\n") + + file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SET})\n") + file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SET} \"default\")\n") + file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "endif()\n\n") + foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES}) file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "include(${GENERATED_FILE})\n") endforeach() diff --git a/README.md b/README.md index 2d09925798..f1ea3b4364 100644 --- a/README.md +++ b/README.md @@ -1,8 +1,8 @@ ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 3.2 +# CUTLASS 3.3 -_CUTLASS 3.2 - August 2023_ +_CUTLASS 3.3 - October 2023_ CUTLASS is a collection of CUDA C++ template abstractions for implementing high-performance matrix-matrix multiplication (GEMM) and related computations at all levels @@ -41,26 +41,17 @@ and improves code composability and readability. More documentation specific to In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. -# What's New in CUTLASS 3.2 - -CUTLASS 3.2.0 is an update to CUTLASS adding: -- New warp-specialized persistent FP8 GEMM kernel [kernel schedules](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](/examples/54_hopper_fp8_warp_specialized_gemm). -- New [Epilogue Visitor Tree (EVT)](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue. -- [Stream-K](/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release. -- Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp). 
-- Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA. -- [Hopper GEMM+Permute](/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue. -- New CUTLASS 2D Convolution Python interface. New [example](/examples/python/03_basic_conv2d.ipynb) here. -- Support for Windows (MSVC) builds. - -CUTLASS 3.2.1 is an update to CUTLASS adding: -- Python support SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0. -- SM80 EVT support in C++ and Python. -- Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details. -- Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](/python/README.md) for details. -- SM90 TF32 kernel improvements for all layouts. -- SM90 rasterization direction support in the CUTLASS profiler. -- Improvement for CUTLASS profiler build times. +# What's New in CUTLASS 3.3 + +CUTLASS 3.3.0 is an update to CUTLASS adding: + +- New [Mixed Precision Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input types with optimal performance. +- New [Mixed Precision Ampere GEMMs](https://github.com/NVIDIA/cutlass/commit/7d8317a63e0a978a8dbb3c1fb7af4dbe4f286616) with support for canonical layouts (TN) and {fp16, bf16} x {s8/u8}. They also include fast numeric conversion recipes and warp level shuffles to achieve optimal performance. +- New [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors (across s8/fp8/fp16/bf16/tf32 types) with optimal performance. As a part of this, new kernel schedules, and Copy Ops [SM80\_CP\_ASYNC\_CACHE\_\*](/include/cute/arch/copy_sm80.hpp) were also added. +- EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details. +- Various subbyte enhancements like tagged device ptrs, support for vectorized copy, various operators to treat subbyte iterators as pointers, and full-fledged CuTe Tensor support. +- Support for Clang as a host compiler +- Support for void-C kernels and SM80 mixed-precision GEMMs in the CUTLASS Python interface Minimum requirements: @@ -103,7 +94,7 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA # Compatibility CUTLASS requires a C++17 host compiler and -performs best when built with the [**CUDA 12.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit). +performs best when built with the [**CUDA 12.2.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit-archive). It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0 and CUDA 12.1. ## Operating Systems @@ -114,9 +105,10 @@ We have tested the following environments. | Ubuntu 18.04 | GCC 7.5.0 | | Ubuntu 20.04 | GCC 10.3.0 | | Ubuntu 22.04 | GCC 11.2.0 | +| Ubuntu 22.04 | Clang 10.0.0 | +| Ubuntu 22.04 | Clang 14.0.6 | | Windows 10.0 | Visual Studio 2019 v16.11.27 | -Note: We plan to add Clang compiler support soon. 
Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended. ## Hardware diff --git a/bin2hex.cmake b/bin2hex.cmake index b0773dd659..44935f2d24 100644 --- a/bin2hex.cmake +++ b/bin2hex.cmake @@ -6,7 +6,7 @@ function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED) endif() string(REGEX REPLACE "(....)" "\\1\n" HEX_OUTPUT ${HEX_INPUT}) - string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," HEX_OUTPUT ${HEX_OUTPUT}) + string(REGEX REPLACE "([0-9a-f][0-9a-f])" "char(0x\\1)," HEX_OUTPUT ${HEX_OUTPUT}) set(HEX_OUTPUT "static char const ${VARIABLE_NAME}[] = {\n ${HEX_OUTPUT}\n};\n") diff --git a/cmake/CTestTestfile.configure.cmake b/cmake/CTestTestfile.configure.cmake index 3fc3994647..524ba1f82a 100644 --- a/cmake/CTestTestfile.configure.cmake +++ b/cmake/CTestTestfile.configure.cmake @@ -1,5 +1,12 @@ # Generated file +set(TEST_SETS_SUPPORTED @TEST_SETS_SUPPORTED@) + +#? if (DEFINED ENV{CUTLASS_TEST_SET} AND NOT ENV{CUTLASS_TEST_SET} IN_LIST TEST_SETS_SUPPORTED) +#? message(STATUS "Skipping tests for @TEST_EXE_PATH@ as $ENV{CUTLASS_TEST_SET} is not in the set of ${TEST_SETS_SUPPORTED}.") +#? return() +#? endif() + set(TEST_EXE_PATH @TEST_EXE_PATH@) set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@) set(CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT @TEST_USE_EXTENDED_FORMAT@) @@ -11,4 +18,3 @@ else() endif() @_INLINE_PER_TEST_CODE@ - diff --git a/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt b/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt index 783cbf8448..07928e1c5c 100644 --- a/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt +++ b/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt @@ -33,7 +33,13 @@ cutlass_example_add_executable( ampere_gemm_universal_streamk.cu ) +# Deliberately test non-square sizes to ensure that internal transpose is +# not triggered when using SM80 EVT +set(TEST_COMMAND_00 --m=512 --n=768 --k=1152) + cutlass_example_add_executable( 47_ampere_gemm_universal_streamk_broadcast ampere_gemm_universal_streamk_broadcast.cu + TEST_COMMAND_OPTIONS + TEST_COMMAND_00 ) diff --git a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu index 5a98c7ae64..0ba72abb78 100644 --- a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu +++ b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu @@ -107,16 +107,6 @@ using ClusterShape = Shape<_1,_2,_1>; // S using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder -using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ArchTag, OperatorClass, - ElementA, LayoutA, AlignmentA, - ElementB, LayoutB, AlignmentB, - ElementAccumulator, - TileShape, ClusterShape, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto - >::CollectiveOp; - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, @@ -127,6 +117,17 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; +using CollectiveMainloop = 
typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< Shape, // Indicates ProblemShape CollectiveMainloop, @@ -397,6 +398,7 @@ int run(Options &options) GpuTimer timer; timer.start(); for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); CUTLASS_CHECK(gemm.run()); } timer.stop(); diff --git a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu index 25f637ac49..d0e479397b 100644 --- a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu +++ b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu @@ -86,7 +86,7 @@ CUTLASS builders make an attempt to pick the best schedule when `Auto` is provided such that the assembled collectives have the best performance, but this is not a guarantee. A user relying on `Auto` may get a free performance upgrade with newer CUTLASS releases in case we can provide more optimized - implementations that the builder can transparently assemble for `Auto`. But a user should not rely on + implementations that the builder can transparently assemble for `Auto`. But a user should not rely on `Auto` if they require a specific scheduling policy and/or stage count to be used. If a user decides to let the builders pick the collective specialization via `Auto` schedules, @@ -289,7 +289,7 @@ struct ExampleRunner { // EVTs can be constructed by composing the fundamental load/store/compute visitor operations defined in include/cutlass/epilogue/fusion // For more complex examples of EVT construction please refer to include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp using CustomEVT = // alpha * acc + beta * C - cutlass::epilogue::fusion::Sm90EVT, // beta * C + (alpha * acc) + cutlass::epilogue::fusion::Sm90EVT, // beta * C + (alpha * acc) cutlass::epilogue::fusion::Sm90ScalarBroadcast, // beta cutlass::epilogue::fusion::Sm90SrcFetch, // C cutlass::epilogue::fusion::Sm90EVT, // alpha * acc @@ -322,7 +322,7 @@ struct ExampleRunner { ElementAccumulator, Shape<_128,_128,_64>, Shape<_2,_1,_1>, cute::conditional_t, - cutlass::gemm::collective::StageCountAutoCarveout<(int)sizeof(typename CollectiveEpilogue::SharedStorage)>, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, StageCountType>, MainloopScheduleType >::CollectiveOp; diff --git a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu index c99afc05e6..b962e3dcfc 100644 --- a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu +++ b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu @@ -93,7 +93,7 @@ struct Options { int mode = 1; // N-mode gather/scatter by default float alpha = 1.0f; - float beta = 1.0f; + float beta = 0.0f; bool reference_check = true; int iterations = 20; @@ -179,19 +179,27 @@ struct ExampleRunner { // Useful aliases + using ProblemShape = Shape; + + using StrideA = cutlass::gemm::TagToStrideA_t; + using 
StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + // Alias to for the epilogue type that supports gather/scatter - using Epilogue = cutlass::epilogue::collective::EpilogueGatherScatter< - cutlass::gemm::TagToStrideC_t, - cutlass::gemm::TagToStrideC_t, - cutlass::epilogue::thread::LinearCombination< - ElementD, 1, - ElementAccumulator, ElementComputeEpilogue, - cutlass::epilogue::thread::ScaleType::Default, - cutlass::FloatRoundStyle::round_to_nearest, ElementC - >, - cutlass::gemm::EpilogueDefault, - GatherC, - ScatterD + using Epilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::EpilogueGatherScatter< + StrideC, StrideD, + cutlass::epilogue::thread::LinearCombination< + ElementD, 1, + ElementAccumulator, ElementComputeEpilogue, + cutlass::epilogue::thread::ScaleType::Default, + cutlass::FloatRoundStyle::round_to_nearest, ElementC + >, + cutlass::gemm::EpilogueDefault, + GatherC, + ScatterD + > >; // Alias to for the mainloop type @@ -202,27 +210,21 @@ struct ExampleRunner ElementAccumulator, Shape<_128,_128,_64>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCount<5>, - cutlass::gemm::KernelMultistage + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; - using ProblemShape = Shape; - using Kernel = cutlass::gemm::kernel::GemmGather< ProblemShape, Mainloop, Epilogue, + void, GatherA, GatherB >; using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - using StrideA = typename Kernel::StrideA; - using StrideB = typename Kernel::StrideB; - using StrideC = typename Kernel::StrideC; - using StrideD = typename Kernel::StrideD; - static constexpr bool DoGatherA = not cutlass::platform::is_same::value; static constexpr bool DoGatherB = not cutlass::platform::is_same::value; static constexpr bool DoGatherC = not cutlass::platform::is_same::value; @@ -250,16 +252,19 @@ struct ExampleRunner using MainloopRef = Mainloop; - using EpilogueRef = typename cutlass::epilogue::collective::DefaultEpilogue< - StrideC, StrideD, - typename Epilogue::ThreadEpilogueOp, - typename Epilogue::EpilogueSchedule + using EpilogueRef = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + StrideC, StrideD, + typename Epilogue::ThreadEpilogueOp, + typename Epilogue::EpilogueSchedule + > >; using KernelRef = cutlass::gemm::kernel::GemmUniversal< - ProblemShape, - MainloopRef, - EpilogueRef + ProblemShape, + MainloopRef, + EpilogueRef, + void >; using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; @@ -289,9 +294,10 @@ struct ExampleRunner >::CollectiveOp; using KernelOpt = cutlass::gemm::kernel::GemmUniversal< - ProblemShape, - MainloopOpt, - EpilogueOpt + ProblemShape, + MainloopOpt, + EpilogueOpt, + void >; using GemmOpt = cutlass::gemm::device::GemmUniversalAdapter; @@ -404,6 +410,7 @@ struct ExampleRunner typename Epilogue::ScatterD{gather_indices.get()} }, hw_info, + {}, typename Kernel::GatherA{gather_indices.get()}, typename Kernel::GatherB{gather_indices.get()} }, diff --git a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp index 579122210a..07de1639d6 100644 --- a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp +++ b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp @@ -47,9 +47,9 @@ template < class ProblemShape_, class CollectiveMainloop_, class 
CollectiveEpilogue_, + class TileScheduler_, class GatherA_, - class GatherB_, - class TileScheduler_ = void + class GatherB_ > class GemmGather { @@ -58,8 +58,6 @@ class GemmGather // Type Aliases // using ProblemShape = ProblemShape_; - using TileSchedulerTag = TileScheduler_; - using TileScheduler = TileScheduler_; static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, "ProblemShape{} should be or "); @@ -74,8 +72,10 @@ class GemmGather using StrideB = typename CollectiveMainloop::StrideB; using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; using MainloopArguments = typename CollectiveMainloop::Arguments; using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); // Epilogue derived types using CollectiveEpilogue = CollectiveEpilogue_; @@ -85,17 +85,48 @@ class GemmGather using StrideD = typename CollectiveEpilogue::StrideD; using EpilogueArguments = typename CollectiveEpilogue::Arguments; using EpilogueParams = typename CollectiveEpilogue::Params; - static_assert(std::is_same_v, - "Mainloop and epilogue do not agree on accumulator value type."); + + static_assert(cute::is_void_v or cute::is_same_v, + "Non-persistent warp-specialized kernel does not support specializing the tile scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; using GatherA = GatherA_; using GatherB = GatherB_; - static constexpr int SharedStorageSize = static_cast(cute::max( - sizeof(typename CollectiveMainloop::SharedStorage), - sizeof(typename CollectiveEpilogue::SharedStorage))); + // Kernel level shared memory storage + struct SharedStorage { + union TensorStorage { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using GmemTiledCopyA = typename CollectiveMainloop::GmemTiledCopyA; + using GmemTiledCopyB = typename CollectiveMainloop::GmemTiledCopyB; + static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same."); - static constexpr uint32_t MaxThreadsPerBlock = cute::size(TiledMma{}); + static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumMmaWarpGroups = cute::size(TiledMma{}) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups; + static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance."); + + static constexpr uint32_t MaxThreadsPerBlock = NumWarpGroups * NumThreadsPerWarpGroup; static constexpr uint32_t MinBlocksPerMultiprocessor = 1; // Device side 
arguments @@ -105,6 +136,7 @@ class GemmGather MainloopArguments mainloop{}; EpilogueArguments epilogue{}; KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; GatherA gather_A{}; GatherB gather_B{}; }; @@ -123,14 +155,20 @@ class GemmGather // Methods // - // Convert to underlying arguments. + // Convert to underlying arguments. In this case, a simple copy for the aliased type. static Params to_underlying_arguments(Arguments const& args, void* workspace) { (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::IF_SWAP_AB::value) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } return { args.mode, - args.problem_shape, + problem_shape, CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), args.gather_A, @@ -138,17 +176,18 @@ class GemmGather }; } - static - Status - initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { - return Status::kSuccess; - } - - static + CUTLASS_HOST_DEVICE static bool can_implement(Arguments const& args) { - return args.mode == GemmUniversalMode::kGemm or - (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + return implementable; } static @@ -157,23 +196,23 @@ class GemmGather return 0; } - static constexpr - dim3 - get_grid_shape(Params const& params) { - int batch_count = 1; - if constexpr (rank(ProblemShape{}) == 4) { - batch_count = cute::size<3>(params.problem_shape); - } + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return Status::kSuccess; + } - return dim3( - cute::size(cute::ceil_div(cute::shape<0>(params.problem_shape), cute::shape<0>(TileShape{}))), - cute::size(cute::ceil_div(cute::shape<1>(params.problem_shape), cute::shape<1>(TileShape{}))), - batch_count - ); + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = Shape<_1,_1,_1>{}; + auto tile_shape = TileShape{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_tiled_cta_shape_mnl( + problem_shape_MNKL, tile_shape, cluster_shape); } - static constexpr - dim3 + static dim3 get_block_shape() { return dim3(MaxThreadsPerBlock, 1, 1); } @@ -184,8 +223,75 @@ class GemmGather using namespace cute; using X = Underscore; + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. 
Aborting.\n"); + return; + } + #endif + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int warp_group_idx = canonical_warp_group_idx(); + CUTLASS_ASSERT(warp_group_idx < NumWarpGroups); + WarpGroupRole warp_group_role = warp_group_idx < NumLoadWarpGroups ? WarpGroupRole::Producer : WarpGroupRole::Consumer; + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + mainloop_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + epi_load_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + // Preconditions - CUTE_STATIC_ASSERT(is_static::value); + static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); // Separate out problem shape for convenience // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) @@ -195,69 +301,113 @@ class GemmGather auto K = get<2>(problem_shape_MNKL); auto L = get<3>(problem_shape_MNKL); - // Preconditions - static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - - // Get the appropriate blocks for this thread block -- potential for thread block locality - int thread_idx = int(threadIdx.x); - auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - auto [m_coord, n_coord, l_coord] = blockIdx; - auto blk_coord_mnkl = make_coord(m_coord, n_coord, _, l_coord); // (m,n,k,l) - // Represent the full tensors Tensor mA_mkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA, params.gather_A); //(m,k,l) Tensor mB_nkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB, params.gather_B); //(n,k,l) - // Get batch slice - Tensor mA_mk = mA_mkl(_,_,l_coord); // (m,k) - Tensor mB_nk = mB_nkl(_,_,l_coord); // (n,k) + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + TiledMma tiled_mma; - // Slice to get the tiles this thread block is responsible for - Tensor gA = local_tile(mA_mk, blk_shape, take<0,3>(blk_coord_mnkl), Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) - Tensor gB = local_tile(mB_nk, blk_shape, take<0,3>(blk_coord_mnkl), Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) - // Compute tile residues for predication - auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord_mnkl); // M - BLK_M * m_coord - auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord_mnkl); // N - BLK_N * n_coord - auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max - auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); + auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - // Allocate the tiled_mma and the accumulators for the (M,N) blk_shape - TiledMma tiled_mma; - Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) - clear(accumulators); + // Slice with m_coord and n_coord + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + // Get pipeline iterators and increments from tensor shapes auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - int k_tile_count = size<2>(gA); - - // Perform the collective scoped MMA - 
CollectiveMainloop collective_mma; - collective_mma( - accumulators, - gA, - gB, - accumulators, - k_tile_iter, k_tile_count, - residue_mnk, - thread_idx, - smem_buf - ); - - // Epilogue and write to gD - CollectiveEpilogue epilogue{params.epilogue}; - epilogue( - problem_shape_MNKL, - blk_shape, - blk_coord_mnkl, - accumulators, - tiled_mma, - residue_mnk, - thread_idx, - smem_buf - ); + auto k_tile_count = size<2>(gA); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + // Wait for all threads in the thread block + __syncthreads(); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + + if (warp_group_role == WarpGroupRole::Producer) { + // Compute tile residues for predication + auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord); // M - BLK_M * m_coord + auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord); // N - BLK_N * n_coord + auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, + gB, + k_tile_iter, k_tile_count, + residue_mnk, + thread_idx, + shared_storage.tensors.mainloop + ); + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + epi_load_pipe_producer_state = + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + thread_idx, + shared_storage.tensors.epilogue + ); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } + else if (warp_group_role == WarpGroupRole::Consumer) { + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + warp_group_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + + // Epilogue and write to gD + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + } } }; diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu new file mode 100644 index 0000000000..92c2207cd8 --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu @@ -0,0 +1,558 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + + This example shows how to perform GEMM where the input tensors A and B have different element types. CUTLASS currently supports upcasting + from a narrower (fewer bits) to a wider (more bits) type and utilizing the tensor core instruction for the wider type. For instance, when doing + INT8 x FP16, CUTLASS will convert INT8 -> FP16 and do math using FP16 tensor cores. Similarly, for INT4 x INT8, it will upcast to INT8 and issue math + using INT8 tensor cores. + + The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap + A and B in the main loop. Consequently, it is essential to consider this when constructing the epilogue, as illustrated in this example. + + Limitations: + 1) Only supported combinations are 16-bit x {8-bit, 4-bit, 2-bit} and {8-bit} x {4-bit, 2-bit}. + 2) The narrow type must always be in K-major format. + 3) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format. + 4) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the + operands to ensure the narrow type passes through the register file, and TMA epilogues do not currently support swap + transpose operations. + We plan to address this limitation in the future. 
+ + Examples: + + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=2048 --n=2048 --k=2048 --l=2 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" + +#include "helper.h" +#include "unfused_weight_dequantize.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +using MmaType = cutlass::half_t; +using QuantType = int8_t; + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_256,_64>; // Threadblock-level tile size +using ClusterShape = Shape<_2,_1,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder + + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, 
LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + // Lie here about layout of C and D since we do swap and transpose trick + ElementC, typename cutlass::layout::LayoutTranspose::type, AlignmentC, + ElementC, typename cutlass::layout::LayoutTranspose::type, AlignmentC, + cutlass::epilogue::NoSmemWarpSpecialized // This is the only epi supporting the required swap + transpose. + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +// Initialization functions don't handle sub-byte types so we use uint8 to initialize and a separate +// kernel to pack the data if it is necessary. +using InitializationType = cute::conditional_t < 8, uint8_t, QuantType>; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_B_init; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_B_dq; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 1000; + int m = 5120, n = 4096, k = 4096; + int l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "55_hopper_warp_specialized_gemm\n\n" + << " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= The number of independent gemm problems with mnk shape\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "55_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + uint64_t seed=2023) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +template +bool initialize_quant_tensor( + cutlass::TensorView view, + uint64_t seed=2023) { + + Element scope_max, scope_min; + constexpr int bits_input = cute::sizeof_bits_v; + static_assert(bits_input <= 8, "Quantization type can be at most 8 bits"); + + if constexpr (bits_input == 8) { + // Directly init 1-byte types + static_assert(cute::is_same_v, "Init type should equal quant type for 1 byte types"); + scope_max = std::numeric_limits::max(); + scope_min = std::numeric_limits::min(); + } else { + static_assert(cute::is_same_v, "Init type should be uint8_t for sub-byte types"); + scope_max = (1 << bits_input); + scope_min = 0; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +template +bool initialize_with_one( + cutlass::TensorView view) { + cutlass::reference::host::TensorFill(view, Element(1.0f)); + return true; +} + +template +void prepare_packed_data(cutlass::HostTensor view_dst_data, + cutlass::HostTensor view_src_data, + const L& cute_layout) { + if constexpr (cute::is_same_v) { + view_dst_data.copy_in_device_to_device(view_src_data.device_data()); + } + else { + pack_data(view_dst_data.device_data(), view_src_data.device_data(), cute_layout); + } +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + auto shape_b = 
cute::make_shape(options.n, options.k, options.l); + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + + tensor_A.resize(a_coord); + tensor_B_init.resize(b_coord); + tensor_B.resize(b_coord); + tensor_B_dq.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + // We need scales since the "dequantize" kernels expects them. We just set them to 1 so the values get converted + // to the mma type. + cutlass::HostTensor tensor_scale; + tensor_scale.resize({1 * options.l, options.n}); + + initialize_tensor(tensor_A.host_view(), seed + 2022); + initialize_quant_tensor(tensor_B_init.host_view(), seed + 2021); + initialize_tensor(tensor_C.host_view(), seed + 2020); + initialize_with_one(tensor_scale.host_view()); + + tensor_A.sync_device(); + tensor_B_init.sync_device(); + tensor_C.sync_device(); + tensor_scale.sync_device(); + + auto layout_B = make_layout(shape_b, stride_B); + prepare_packed_data(tensor_B, tensor_B_init, layout_B); + + auto shape_scale = cute::make_shape(options.n, 1, options.l); + auto layout_scale = make_layout(shape_scale); + dequantize_weight(tensor_B_dq.device_data(), tensor_B.device_data(), layout_B, tensor_scale.device_data(), layout_scale); + + tensor_B.sync_host(); + tensor_B_dq.sync_host(); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, + {{options.alpha, options.beta}, tensor_C.device_data(), stride_C, tensor_D.device_data(), stride_D} + }; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B_dq.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + + using unused_t = decltype(D); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + cutlass::reference::host::GettEpilogueParams< + typename Gemm::EpilogueOutputOp::ElementScalar, + typename Gemm::EpilogueOutputOp::ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + unused_t, // bias + unused_t, // aux + unused_t, // valpha + unused_t // vbeta + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + 
epilogue_params.beta = options.beta; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/python/cutlass/profiler/__init__.py b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt similarity index 84% rename from python/cutlass/profiler/__init__.py rename to examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt index d9e9cdc854..8b8ac5ba15 100644 --- a/python/cutlass/profiler/__init__.py +++ b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt @@ -1,5 +1,4 @@ -################################################################################################# -# + # Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # @@ -27,11 +26,10 @@ # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# -""" -Profilers for Python Interface -""" -from cutlass.profiler.event_profiler import CUDAEventProfiler + +cutlass_example_add_executable( + 55_hopper_mixed_dtype_gemm + 55_hopper_mixed_dtype_gemm.cu + ) diff --git a/examples/55_hopper_mixed_dtype_gemm/README.md b/examples/55_hopper_mixed_dtype_gemm/README.md new file mode 100644 index 0000000000..2c0236512c --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/README.md @@ -0,0 +1,36 @@ +This example shows how to do mixed types GEMMs in CUTLASS. + +## High level overview +This example shows how to perform GEMMs on Hopper when A and B have different types. This implementation always passes the type with fewer bits through the register file and upcasts to the type with the higher bit count. + +When relying on `KernelScheduleAuto`, the main loop supporting different A and B types will be selected whenever the bit count of A is not equal to the bit count of B. Users can manually select the mixed type main loop and explicitly choose the scheduling policy by specifying one of the following schedules to the `CollectiveBuilder`: `KernelTmaWarpSpecializedMixedInput`, `KernelTmaWarpSpecializedPingpongMixedInput` or `KernelTmaWarpSpecializedCooperativeMixedInput`. + +This first version only supports mixed type GEMMs using TMA. + +## Performance + +While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. We expect this implementation to be performant for `{fp16, bf16} x int8` for problems that are compute bound. + +We are currently optimizing the following cases: +1. Memory bound cases for all types +1. 
Compute bound cases for `{16-bit, 8-bit} x {4-bit, 2-bit}` + +As a result, we do not suggest using this example as a benchmarking reference until all of our optimizations are complete (this will be clearly stated in this README in a future release). + +## Limitations + +* The type that needs to be converted must go through the register file. This means that the collective will swap and transpose whenever the type with fewer bits is the B operand. The user must be aware of when these swaps happen to control the layout of the epilogue as shown in the example. Note that TMA epilogues currently do not support swap + transpose, so non-tma epilogues must be used in this case. We plan to relax this limitation in a future release. + +* The layout of the narrow type must be K-major. This means the following: + * Narrow type is the A operand: Must be Row-Major + * Narrow type is the B operand: Must be Column-Major + +* For 8-bit x 4-bit or 2-bit, both inputs must be K-major. + +* TMA requires an alignment of 128 bits. As a result, for a type with `B` bits, `B x TILE_K` must be a multiple of 128 bits. + +## Upcoming features + +* Support for applying scales after conversion, but before issuing tensor core math (input scale fusion) is planned for v3.4. + +* Many optimizations for SOL performance. diff --git a/examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.h b/examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.h new file mode 100644 index 0000000000..e19a01dc69 --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/unfused_weight_dequantize.h @@ -0,0 +1,178 @@ +#pragma once + +#include "cute/tensor.hpp" + +#include +#include "helper.h" + +template +__global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer, + const QuantizedElement* q_buffer, + const OperandLayout operand_layout, + const ElementScale* scale_buffer, + const ScaleBroadCastLayout broadcasted_scale_layout, + ThrLayout thr_layout) { + using namespace cute; + + // Represent the full tensors to gmem elements. 
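// A minimal sketch of the mixed-input mainloop selection described in the README above, assuming
// the usual CUTLASS 3.x CollectiveBuilder convention. The element types, alignments, tile and
// cluster shapes below are illustrative assumptions, not values taken from this diff:
//
//   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
//       cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
//       int8_t,          cutlass::layout::RowMajor,    16,  // narrow operand A, K-major
//       cutlass::half_t, cutlass::layout::ColumnMajor, 8,   // wide operand B, K-major
//       float,                                              // accumulator
//       cute::Shape<cute::_128, cute::_128, cute::_64>,     // CTA tile (M,N,K)
//       cute::Shape<cute::_1, cute::_1, cute::_1>,          // cluster shape
//       cutlass::gemm::collective::StageCountAuto,
//       cutlass::gemm::KernelTmaWarpSpecializedMixedInput   // or the Pingpong/Cooperative variants
//     >::CollectiveOp;
//
// With KernelScheduleAuto instead of an explicit schedule, the mixed-input mainloop is selected
// automatically whenever the bit counts of A and B differ, as noted in the README.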
+ // These are expected to have shape [MN, K, L] + Tensor gmem_op_dq = make_tensor(make_gmem_ptr(dq_buffer), operand_layout); + auto init_quantized_iterator = [&]() { + if constexpr (cute::sizeof_bits_v >= 8) { + return make_gmem_ptr(q_buffer); + } else { + return subbyte_iterator(q_buffer); + } + }; + Tensor gmem_op_q = make_tensor(init_quantized_iterator(), operand_layout); + // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting + // It is expected that K % G == 0 + Tensor gmem_scale_broadcasted = make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout); + + // Assign 1 thread per element in the thread block + auto blk_shape = make_shape(size<0>(thr_layout), _1{}, _1{}); // + auto blk_coord = make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L) + + // Tile across the block + auto gOp_dq = local_tile(gmem_op_dq, blk_shape, blk_coord); + auto gScale = local_tile(gmem_scale_broadcasted, blk_shape, blk_coord); + auto gOp_q = local_tile(gmem_op_q, blk_shape, blk_coord); + + auto tOpDq_gOpDq = local_partition(gOp_dq, thr_layout, threadIdx.x); + auto tScale_gScale = local_partition(gScale, thr_layout, threadIdx.x); + auto tOpQ_gOpQ = local_partition(gOp_q, thr_layout, threadIdx.x); + + // Make a fragment of registers to hold gmem loads + Tensor rmem_op_q = make_fragment_like(tOpQ_gOpQ(_, _, _, 0)); + Tensor rmem_scale = make_fragment_like(tScale_gScale(_, _, _, 0)); + Tensor rmem_op_dq = make_fragment_like(tOpDq_gOpDq(_, _, _, 0)); + + Tensor pred_id = make_identity_tensor(shape(operand_layout)); + auto pred_blk_tile = local_tile(pred_id, blk_shape, blk_coord); + auto pred_thr_partition = local_partition(pred_blk_tile, thr_layout, threadIdx.x); + + const auto num_iters = size<3>(tOpDq_gOpDq); + + for (int ii = 0; ii < num_iters; ++ii) { + const auto thread_offset = get<0>(pred_thr_partition(0, 0, 0, ii)); + if (thread_offset < size<0>(operand_layout)) { + copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q); + copy(tScale_gScale(_, _, _, ii), rmem_scale); + transform(rmem_op_q, rmem_op_dq, [] (const QuantizedElement& elt) { return DequantizedElement(elt); } ); + transform(rmem_op_dq, rmem_scale, rmem_op_dq, multiplies{}); + copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii)); + } + } +} + +template +void dequantize_weight(DequantizedElement* dq_buffer, + const QuantizedElement* q_buffer, + const OperandLayout operand_layout, + const ElementScale* scale_buffer, + const ScaleLayout scale_layout) { + + using namespace cute; + + constexpr int tpb = 128; + auto thr_layout = make_layout(make_shape(Int{})); + + const auto num_rows = get<0>(shape(operand_layout)); + const auto num_cols = get<1>(shape(operand_layout)); // [MN, K, L] + const auto batches = get<2>(shape(operand_layout)); // [MN, K, L] + const auto num_cols_scale = get<1>(shape(scale_layout)); // [MN, G, L] + + if (num_rows != size<0>(scale_layout)) { + std::cerr << "Invalid first dimension for scales. Must match first dim for weights." + << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout) + << std::endl; + exit(-1); + } + + if (num_cols % num_cols_scale != 0) { + std::cerr << "Invalid shape for weight / scales. Weight cols must be a multiple of scale cols." 
+ << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout) + << std::endl; + exit(-1); + } + + const auto scale_stride0 = get<0>(stride(scale_layout)); + const auto scale_stride1 = get<1>(stride(scale_layout)); + const auto scale_stride2 = get<2>(stride(scale_layout)); + + auto scale_shape_bcast = make_shape(num_rows, make_shape(num_cols / num_cols_scale, num_cols_scale), batches); + auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2); + auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast); + + const auto blocks_x = num_cols; + const auto blocks_y = batches; + + dim3 blocks(blocks_x, blocks_y, 1); + dequantize_weight_kernel<<>>(dq_buffer, q_buffer, operand_layout, scale_buffer, scale_layout_bcast, thr_layout); + CUDA_CHECK(cudaDeviceSynchronize()); +} + + +template +__global__ void pack_data_kernel(SubbyteType* packed_data_ptr, + const uint8_t* unpacked_data_ptr, + const size_t max_elts) { + using namespace cute; + + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + uint8_t data[ELTS_PER_THREAD]; + if (tid < max_elts) { + const uint8_t* read_ptr = unpacked_data_ptr + tid * ELTS_PER_THREAD; + for (int ii = 0; ii < ELTS_PER_THREAD; ++ii) { + data[ii] = read_ptr[ii]; + } + + + using WriteType = cute::array_subbyte; + WriteType* write_ptr = reinterpret_cast(packed_data_ptr); + + WriteType packed_data; + for (int ii = 0; ii < ELTS_PER_THREAD; ++ii) { + SubbyteType elt(data[ii]); + packed_data[ii] = elt; + } + write_ptr[tid] = packed_data; + } + +} + +template +void pack_data(SubbyteType* packed_data, const uint8_t* unpacked_data, const OperandLayout operand_layout) { + static_assert(cute::sizeof_bits_v < 8, "First operand must be a sub-byte type"); + constexpr int packed_elements = 8 / cute::sizeof_bits_v; + + if (cute::stride<0>(operand_layout) == 1 && (cute::shape<0>(operand_layout) % packed_elements)) { + std::cerr << "Invalid shape / stride for dimension 0. Contiguous dimension must be a multiple of " + << packed_elements << std::endl; + exit(-1); + } + + if (cute::stride<1>(operand_layout) == 1 && (cute::shape<1>(operand_layout) % packed_elements)) { + std::cerr << "Invalid shape / stride for dimension 1. 
Contiguous dimension must be a multiple of " + << packed_elements << std::endl; + exit(-1); + } + + const int64_t total_threads = cute::size(operand_layout) / packed_elements; + + const int threads_per_block = 256; + const int64_t num_blocks = (total_threads + threads_per_block - 1) / threads_per_block; + pack_data_kernel<<>>(packed_data, unpacked_data, total_threads); + CUDA_CHECK(cudaDeviceSynchronize()); +} diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index cf604d861f..47445ca247 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -135,6 +135,7 @@ foreach(EXAMPLE 52_hopper_gather_scatter_fusion 53_hopper_gemm_permute 54_hopper_fp8_warp_specialized_gemm + 55_hopper_mixed_dtype_gemm ) add_subdirectory(${EXAMPLE}) diff --git a/examples/python/00_basic_gemm.ipynb b/examples/python/00_basic_gemm.ipynb index 6c8222e0de..e7a130b6f6 100644 --- a/examples/python/00_basic_gemm.ipynb +++ b/examples/python/00_basic_gemm.ipynb @@ -269,6 +269,7 @@ "metadata": {}, "outputs": [], "source": [ + "tiles = [td for td in tiles if td.threadblock_shape[0] >= 128]\n", "idx = random.randint(0, len(tiles)-1)\n", "td = tiles[idx]\n", "print('Tile description {} is: {}'.format(idx, td))\n", diff --git a/examples/python/04_epilogue_visitor.ipynb b/examples/python/04_epilogue_visitor.ipynb index 72547d1999..3f47afa019 100644 --- a/examples/python/04_epilogue_visitor.ipynb +++ b/examples/python/04_epilogue_visitor.ipynb @@ -29,7 +29,7 @@ "import cutlass\n", "from cutlass.epilogue import relu\n", "from cutlass import Tensor as FakeTensor\n", - "from cutlass.profiler import CUDAEventProfiler\n", + "from cutlass.utils.profiler import CUDAEventProfiler\n", "\n", "# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n", "# omit this information.\n", @@ -160,10 +160,6 @@ " return example_epilogue(accum, alpha, C, beta, aux, bias)\n", "\n", "torch_reference = TorchReference()\n", - "if hasattr(torch, \"compile\"):\n", - " # If the torch.compile feature is available\n", - " torch_reference = torch.compile(torch_reference)\n", - "\n", "tensor_D_ref, tensor_F_ref = torch_reference(tensor_A, tensor_B, alpha, tensor_C, beta, aux, bias)\n", "\n", "assert torch.equal(tensor_D, tensor_D_ref)\n", diff --git a/include/cute/algorithm/copy.hpp b/include/cute/algorithm/copy.hpp index 65f80af8f9..9506db7919 100644 --- a/include/cute/algorithm/copy.hpp +++ b/include/cute/algorithm/copy.hpp @@ -130,6 +130,17 @@ copy_if(PrdTensor const& pred, // copy_if -- Predicated CopyAtom // +namespace detail { + +// Trait that detects if atom's traits has a member function with(bool) +template +constexpr bool has_with_bool = false; + +template +constexpr bool has_with_bool().with(declval()))>> = true; + +} // end namespace detail + template const& copy_atom, auto dst_v = group_modes<1,R>(dst); CUTE_UNROLL for (int i = 0; i < size<1>(src_v); ++i) { - if (pred(i)) { - copy_atom.call(src_v(_,i), dst_v(_,i)); + // If copy traits can be transformed with a predicate value, do it, otherwise branch here + if constexpr (detail::has_with_bool>) { + copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i)); + } + else { + if (pred(i)) { + copy_atom.call(src_v(_,i), dst_v(_,i)); + } } } } @@ -169,15 +186,17 @@ void copy_vec(Tensor const& src, Tensor & dst) { - using SrcType = typename SrcEngine::value_type; - using DstType = typename DstEngine::value_type; + using SrcType = typename SrcEngine::element_type; + using DstType = typename DstEngine::element_type; if constexpr (sizeof(SrcType) 
== sizeof(DstType) && sizeof(VecType) > sizeof(DstType)) { /* @pre is_aligned(src.data()) && * is_aligned(dst.data()) */ - auto src_v = recast(src); - auto dst_v = recast(dst); + using SrcVecType = conditional_t, VecType const volatile, VecType const>; + using DstVecType = conditional_t, VecType volatile, VecType >; + auto src_v = recast(src); + auto dst_v = recast(dst); #if 0 if (thread0()) { diff --git a/include/cute/algorithm/functional.hpp b/include/cute/algorithm/functional.hpp index ea17ecb907..1fee742460 100644 --- a/include/cute/algorithm/functional.hpp +++ b/include/cute/algorithm/functional.hpp @@ -170,6 +170,76 @@ CUTE_NAMED_BINARY_OP(min_fn, cute::min); #undef CUTE_BINARY_OP #undef CUTE_NAMED_BINARY_OP +/**********/ +/** Fold **/ +/**********/ + +#define CUTE_FOLD_OP(NAME,OP) \ + struct NAME##_unary_rfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(T&&... t) const { \ + return (t OP ...); \ + } \ + }; \ + struct NAME##_unary_lfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(T&&... t) const { \ + return (... OP t); \ + } \ + }; \ + struct NAME##_binary_rfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(U&& u, T&&... t) const { \ + return (t OP ... OP u); \ + } \ + }; \ + struct NAME##_binary_lfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(U&& u, T&&... t) const { \ + return (u OP ... OP t); \ + } \ + } + +CUTE_FOLD_OP(plus, +); +CUTE_FOLD_OP(minus, -); +CUTE_FOLD_OP(multiplies, *); +CUTE_FOLD_OP(divides, /); +CUTE_FOLD_OP(modulus, %); + +CUTE_FOLD_OP(plus_assign, +=); +CUTE_FOLD_OP(minus_assign, -=); +CUTE_FOLD_OP(multiplies_assign, *=); +CUTE_FOLD_OP(divides_assign, /=); +CUTE_FOLD_OP(modulus_assign, %=); + +CUTE_FOLD_OP(bit_and, &); +CUTE_FOLD_OP(bit_or, |); +CUTE_FOLD_OP(bit_xor, ^); +CUTE_FOLD_OP(left_shift, <<); +CUTE_FOLD_OP(right_shift, >>); + +CUTE_FOLD_OP(bit_and_assign, &=); +CUTE_FOLD_OP(bit_or_assign, |=); +CUTE_FOLD_OP(bit_xor_assign, ^=); +CUTE_FOLD_OP(left_shift_assign, <<=); +CUTE_FOLD_OP(right_shift_assign, >>=); + +CUTE_FOLD_OP(logical_and, &&); +CUTE_FOLD_OP(logical_or, ||); + +CUTE_FOLD_OP(equal_to, ==); +CUTE_FOLD_OP(not_equal_to, !=); +CUTE_FOLD_OP(greater, >); +CUTE_FOLD_OP(less, <); +CUTE_FOLD_OP(greater_equal, >=); +CUTE_FOLD_OP(less_equal, <=); + +#undef CUTE_FOLD_OP + /**********/ /** Meta **/ /**********/ diff --git a/include/cute/arch/copy.hpp b/include/cute/arch/copy.hpp index aa7bb333ed..8c2552ecb3 100644 --- a/include/cute/arch/copy.hpp +++ b/include/cute/arch/copy.hpp @@ -48,11 +48,21 @@ struct UniversalCopy using SRegisters = S[1]; using DRegisters = D[1]; + template CUTE_HOST_DEVICE static constexpr void - copy(S const& src, - D & dst) + copy(S_ const& src, + D_ & dst) { - dst = src; + dst = static_cast(static_cast(src)); + } + + // Accept mutable temporaries + template + CUTE_HOST_DEVICE static constexpr void + copy(S_ const& src, + D_ && dst) + { + copy(src, dst); } }; diff --git a/include/cute/arch/copy_sm80.hpp b/include/cute/arch/copy_sm80.hpp index c6c44121bd..7dd12de2d3 100644 --- a/include/cute/arch/copy_sm80.hpp +++ b/include/cute/arch/copy_sm80.hpp @@ -96,6 +96,66 @@ struct SM80_CP_ASYNC_CACHEGLOBAL } }; +/// Copy via cp.async with caching at all levels +template +struct SM80_CP_ASYNC_CACHEALWAYS_ZFILL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, 
"cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst, + bool pred) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + int src_size = pred ? sizeof(TS) : 0; + asm volatile("cp.async.ca.shared.global [%0], [%1], %2, %3;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS)), + "r"(src_size)); +#else + CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at global level +template +struct SM80_CP_ASYNC_CACHEGLOBAL_ZFILL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst, + bool pred) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + int src_size = pred ? sizeof(TS) : 0; + asm volatile("cp.async.cg.shared.global [%0], [%1], %2, %3;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS)), + "r"(src_size)); +#else + CUTE_RUNTIME_ASSERT("Support for cp.async instructions has not been enabled"); +#endif + } +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// /// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. diff --git a/include/cute/arch/copy_sm90_tma.hpp b/include/cute/arch/copy_sm90_tma.hpp index 46cace385d..fcb9189f71 100644 --- a/include/cute/arch/copy_sm90_tma.hpp +++ b/include/cute/arch/copy_sm90_tma.hpp @@ -785,7 +785,7 @@ tma_store_arrive() { #endif } -// Wait on prior N (Count) TMA_STORE instructions to complete +// Wait until at most Count committed TMA_STOREs are pending and all prior commits are complete template CUTE_HOST_DEVICE static void tma_store_wait() { diff --git a/include/cute/atom/copy_traits.hpp b/include/cute/atom/copy_traits.hpp index 9c4821d90d..53548f3aaf 100644 --- a/include/cute/atom/copy_traits.hpp +++ b/include/cute/atom/copy_traits.hpp @@ -92,6 +92,21 @@ struct Copy_Traits using RefLayout = SrcLayout; }; +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +void +copy_explode(PtrS&& s, int_sequence, + PtrD&& d, int_sequence) +{ + return Operation::copy(s[Is]..., d[Id]...); +} + +} // end namespace detail + // // Generic copy_unpack for any Copy_Traits // @@ -123,9 +138,8 @@ copy_unpack(Copy_Traits const&, CUTE_STATIC_ASSERT_V(size(rD) == Int{}, "In CopyAtom, dst layout doesn't vectorize into registers. 
This dst layout is incompatible with this tiled copy."); - detail::explode(Operation::copy, - rS, make_int_sequence{}, - rD, make_int_sequence{}); + detail::copy_explode(rS, make_int_sequence{}, + rD, make_int_sequence{}); } // diff --git a/include/cute/atom/copy_traits_sm80.hpp b/include/cute/atom/copy_traits_sm80.hpp index 089d19347f..4e311be38d 100644 --- a/include/cute/atom/copy_traits_sm80.hpp +++ b/include/cute/atom/copy_traits_sm80.hpp @@ -51,6 +51,13 @@ struct Copy_Traits> // Reference map from (thr,val) to bit using RefLayout = SrcLayout; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } }; template @@ -66,6 +73,95 @@ struct Copy_Traits> // Reference map from (thr,val) to bit using RefLayout = SrcLayout; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // Predicate value that determines whether to load or zfill + bool pred = false; + + // Overload copy_unpack for zfill variant to pass the predicate into the op + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem source for cp.async."); + static_assert(is_smem::value, "Expected smem destination for cp.async."); + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int<1>{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + CUTE_STATIC_ASSERT_V(size(rD) == Int<1>{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); + + SM80_CP_ASYNC_CACHEALWAYS_ZFILL::copy(rS[0], rD[0], traits.pred); + } +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // Predicate value that determines whether to load or zfill + bool pred = false; + + // Overload copy_unpack for zfill variant to pass the predicate into the op + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem source for cp.async."); + static_assert(is_smem::value, "Expected smem destination for cp.async."); + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int<1>{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + CUTE_STATIC_ASSERT_V(size(rD) == Int<1>{}, + "In CopyAtom, dst layout doesn't vectorize into registers. 
This dst layout is incompatible with this tiled copy."); + + SM80_CP_ASYNC_CACHEGLOBAL_ZFILL::copy(rS[0], rD[0], traits.pred); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp index d2617abdd8..132ba5201a 100644 --- a/include/cute/atom/copy_traits_sm90_tma.hpp +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -79,6 +79,14 @@ struct Copy_Traits copy_unpack_(void const* const dst_ptr, Coord const& src_coord, seq) const { +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); +#endif + SM90_TMA_LOAD::copy(&tma_desc_, tma_load_mbar_, dst_ptr, get(src_coord)...); } @@ -185,6 +193,14 @@ struct Copy_Traits copy_unpack_(void const* const dst_ptr, Coord const& src_coord, seq) const { +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); +#endif + SM90_TMA_LOAD_MULTICAST::copy(&tma_desc_, tma_load_mbar_, multicast_mask_, dst_ptr, get(src_coord)...); } @@ -298,6 +314,14 @@ struct Copy_Traits copy_unpack_(void const* const src_ptr, Coord const& dst_coord, seq) const { +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + SM90_TMA_STORE::copy(&tma_desc_, src_ptr, get(dst_coord)...); } @@ -354,8 +378,8 @@ struct Copy_Traits "Extra arguments not set. Set .with() before use."); static_assert(is_gmem::value, "Expected gmem src for SM90_BULK_COPY_G2S"); static_assert(is_smem::value, "Expected smem dst for SM90_BULK_COPY_G2S"); - SM90_BULK_COPY_G2S::copy(src.data().get(), *get<0>(traits.bulk_load_mbar_), - dst.data().get(), int32_t(NumBitsPerTMA::value / 8)); + SM90_BULK_COPY_G2S::copy(raw_pointer_cast(src.data()), *get<0>(traits.bulk_load_mbar_), + raw_pointer_cast(dst.data()), int32_t(NumBitsPerTMA::value / 8)); } // Record the memory barrier for the instruction @@ -390,7 +414,7 @@ struct Copy_Traits { static_assert(is_smem::value, "Expected smem src for SM90_BULK_COPY_S2G"); static_assert(is_gmem::value, "Expected gmem dst for SM90_BULK_COPY_S2G"); - SM90_BULK_COPY_S2G::copy(src.data().get(), dst.data().get(), int32_t(NumBitsPerTMA::value / 8)); + SM90_BULK_COPY_S2G::copy(raw_pointer_cast(src.data()), raw_pointer_cast(dst.data()), int32_t(NumBitsPerTMA::value / 8)); } }; @@ -497,7 +521,7 @@ coalesce_256(Tensor const& tensor) // and construct a TMA Descriptor for the resulting instruction // At the same time, construct the Tma Tensor's Stride to generate // the TMA coordinates that the instruction consumes. 
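// Usage note for the SM80_CP_ASYNC_*_ZFILL additions in copy_traits_sm80.hpp above: because the
// plain cp.async Copy_Traits now expose with(bool), which copy_if detects through
// detail::has_with_bool, a predicated copy zero-fills masked lanes instead of skipping them.
// A sketch with illustrative tensor names (tApA is a predicate tensor, tAgA/tAsA are the
// partitioned gmem/smem tiles):
//
//   auto atom = Copy_Atom<SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>, cute::half_t>{};
//   copy_if(atom, tApA, tAgA(_,_,_,k_tile), tAsA(_,_,_,write_stage));
//   // lanes where tApA(...) is false now write zeros into smem rather than leaving stale data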
-// +// template const& gtensor, // The original GM // Perform the tiling to the gmem vector again, but with indirections to the gtensor modes auto gbasis = make_identity_layout(shape(gtensor)); auto tile_gbasis_tmp = gbasis.compose(smem_inv_h); - + // Instead of the recast (gbasis doesn't have type info), replace the shape with the already-recasted shape // tma_box_shape:gmem_mode auto tile_gbasis = make_layout(shape(tile_gstride), stride(tile_gbasis_tmp)); @@ -530,8 +554,8 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM // NOTE This is essentially ArithmeticTuple complement... // NOTE in pursuit of implementing an ArithmeticTuple logical_divide for smem_inv_h auto tile_gbasis_remaining_stride = filter_tuple(flatten(shape (gtensor_T)), flatten(stride(gtensor_T)), - flatten(stride(gbasis)), - [&](auto s, auto d, auto e) + flatten(stride(gbasis)), + [&](auto s, auto d, auto e) { if constexpr (is_constant<1, decltype(s)>::value || is_constant<0, decltype(d)>::value) { return cute::tuple<>{}; // If size-1 or stride-0, then don't append @@ -551,7 +575,7 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride)))); // Append the remaining basis modes that contribute to the TMA with size-1 - auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(repeat(Int<1>{}))), + auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(repeat(Int<1>{}))), tuple_cat(wrap(stride(tma_gbasis_tile)), wrap(tile_gbasis_remaining_stride))); // Group the trailing modes to make this max rank-5 -- TMA rank limitation @@ -570,7 +594,7 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM // // TMA desc creation - // + // constexpr int tma_dim = decltype(rank(tma_gbasis))::value; @@ -579,7 +603,7 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM // void* gmem_address = (void*) raw_pointer_cast(gtensor_T.data()); - auto gmem_layout = gtensor_T.layout(); + auto gmem_layout = gtensor_T.layout(); cute::array gmem_prob_shape = {1,1,1,1,1}; cute::array gmem_prob_stride = {0,0,0,0,0}; @@ -665,20 +689,20 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM // // Construct the descriptor // - + TmaDescriptor tma_desc = {0}; - + // // TMA general info // - + #if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) - + CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; - + // TMA smem swizzle type CUtensorMapSwizzle smem_swizzle = TMA::to_CUtensorMapSwizzle(get_tma_swizzle_bits(swizzle)); CUresult result = cuTensorMapEncodeTiled( @@ -694,7 +718,7 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM smem_swizzle, tma_l2Promotion, tma_oobFill); - + if (result != CUDA_SUCCESS) { std::cerr << "TMA Desc Addr: " << &tma_desc << "\nformat " << tma_format @@ -711,8 +735,11 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM std::cerr << "Error: Failed to initialize the TMA descriptor " << result << std::endl; assert(false); } - + #endif // (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + auto recast_ratio = cute::ratio(Int::value>{}, + Int::value>{}); + // Finally, get the inverse permutation of the E bases for the mocked gmem stride // NOTE This is essentially ArithmeticTuple inverse... 
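// Note: recast_ratio above (its template arguments are elided in this rendering) appears to
// capture the size ratio between the gmem tensor's value type and the TMA internal type; it is
// used below to rescale the element-unit gmem strides into TMA-coordinate units, replacing the
// earlier ratio(size(tma_gstride), size(smem_inv_h)) scale factor.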
auto gmem_stride_bases = transform_leaf(stride(gbasis), [&](auto ei) { @@ -727,9 +754,9 @@ make_tma_copy_desc(Tensor const& gtensor, // The original GM [[maybe_unused]] auto j = find_if(tma_gbasis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); }); if constexpr (decltype(j == rank(tma_gbasis_stride))::value) { return Int<0>{}; // If not-found, return arithmetic identity -- no contribution to the TMA - } else + } else if constexpr (decltype(j == Int<0>{})::value) { - auto scale = ratio(size(tma_gstride), size(smem_inv_h)) * basis_get(ei, stride(gtensor)); + auto scale = recast_ratio * basis_get(ei, stride(gtensor)); return E{} * scale; // Return TMA Coord basis -- with a recast scale factor } else if constexpr (decltype(rank(tma_gbasis_stride) == Int<1>{})::value) { @@ -959,21 +986,23 @@ template CUTE_HOST_RTC auto make_tma_copy(CopyOp const& copy_op, Tensor const& gtensor, SLayout const& slayout, - CTA_Tile const& cta_tile, + CTA_Tiler const& cta_tiler, Cluster_Size const& cluster_size) { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler); + auto cta_t_tile = make_layout(cluster_size); return detail::make_tma_copy_tiled(copy_op, gtensor, slayout, - make_layout(cluster_size), - make_identity_layout(cta_tile)); + cta_t_tile, + cta_v_tile); } // Explicit defaulting diff --git a/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp index 6d391b2173..7ed1061b18 100644 --- a/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp +++ b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp @@ -37,12 +37,13 @@ #include #endif -#include "cute/arch/copy_sm90_desc.hpp" -#include "cute/swizzle_layout.hpp" +#include +#include namespace cute::detail { template +CUTE_HOST_DEVICE constexpr TMA::SmemSwizzleBits get_tma_swizzle_bits(Swizzle) { diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp index 68bd290e6d..94e8a8752e 100644 --- a/include/cute/atom/mma_atom.hpp +++ b/include/cute/atom/mma_atom.hpp @@ -155,7 +155,8 @@ struct MMA_Atom> if constexpr (has_dereference::value) { // If the intended FrgTypeA is a view (of the current tensor), forward the whole - static_assert(is_same, typename remove_cvref_t::value_type>::value, "Expecting ValTypeA type"); + static_assert(is_same::value_type>::value + , "Expecting ValTypeA type"); return make_tensor(std::forward(atensor)); } else { // Else, the intended FrgTypeA is a value type, construct a new tensor with a fragment layout @@ -176,7 +177,8 @@ struct MMA_Atom> if constexpr (has_dereference::value) { // If the intended FrgTypeB is a view (of the current tensor), forward the whole - static_assert(is_same::value_type>::value, "Expecting ValTypeB type"); + static_assert(is_same::value_type>::value + , "Expecting ValTypeB type"); return make_tensor(std::forward(btensor)); } else { // Else, the intended FrgTypeB is a value type, construct a new tensor with a fragment layout @@ -224,6 +226,11 @@ struct TiledMMA : MMA_Atom // thr_idx -> (ThrV,ThrM,ThrN,ThrK) using TidLayout = decltype(right_inverse(ThrLayoutVMNK{})); + CUTE_HOST_DEVICE constexpr auto + get_thr_layout_vmnk() const { + return ThrLayoutVMNK{}; + } + // Tile a tensor or a layout from shape // (M,N,...) 
// to shape @@ -295,8 +302,8 @@ struct TiledMMA : MMA_Atom thrfrg_A(ATensor&& atensor) { CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<2>{}); - CUTE_STATIC_ASSERT_V(size<0>(atensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); - CUTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); + //CUTE_STATIC_ASSERT_V(size<0>(atensor) % size<0>(TiledShape_MNK{}) == Int<0>{}); + //UTE_STATIC_ASSERT_V(size<1>(atensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); // Reorder the tensor for the TiledAtom auto t_tile = make_tile(left_inverse(get<0>(PermutationsMNK{})), @@ -353,8 +360,8 @@ struct TiledMMA : MMA_Atom thrfrg_B(BTensor&& btensor) { CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<2>{}); - CUTE_STATIC_ASSERT_V(size<0>(btensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); - CUTE_STATIC_ASSERT_V(size<1>(btensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); + //CUTE_STATIC_ASSERT_V(size<0>(btensor) % size<1>(TiledShape_MNK{}) == Int<0>{}); + //CUTE_STATIC_ASSERT_V(size<1>(btensor) % size<2>(TiledShape_MNK{}) == Int<0>{}); // Reorder the tensor for the TiledAtom auto t_tile = make_tile(left_inverse(get<1>(PermutationsMNK{})), diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp index 993205c413..27d40e3d0d 100644 --- a/include/cute/atom/mma_traits_sm90_gmma.hpp +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -117,16 +117,22 @@ using Layout_SW128_Atom = typename conditional>::type; // -// Tensor to LayoutType utility +// Tensor (position-dependent swizzle) to LayoutType utility // -// smem_ptr_swizzle LayoutType -template +template CUTE_HOST_DEVICE constexpr LayoutType -layout_type(Tensor>>, - Layout> const&) +layout_type(Tensor> const&) { + static_assert(is_same::value, + "Expected uint128_t type in LayoutType conversion."); + + using Swizzle = get_swizzle_t; + constexpr int B = Swizzle::num_bits; + constexpr int M = Swizzle::num_base; + constexpr int S = Swizzle::num_shft; + static_assert(M == 4, "Unsupported layout swizzle"); static_assert(0 <= B && B <= 3, "Unsupported layout swizzle"); static_assert(S == 3, "Unsupported layout swizzle"); @@ -140,16 +146,6 @@ layout_type(Tensor>> return LayoutType::INTERLEAVE; // ERROR } -// smem_ptr non-swizzled LayoutType -template -CUTE_HOST_DEVICE constexpr -LayoutType -layout_type(Tensor>, - Layout> const&) -{ - return LayoutType::INTERLEAVE; -} - /////////////////////////////////////////////////////////////////////////////// // Construction method for GMMA Descriptors /////////////////////////////////////////////////////////////////////////////// @@ -211,7 +207,7 @@ make_gmma_desc(Tensor const& tensor) desc.bitfield.layout_type_ = uint8_t(LAYOUT_TYPE); // Start address (4LSB not included) - uint32_t start_address = cast_smem_ptr_to_uint(u128_tensor.data().get()); + uint32_t start_address = cast_smem_ptr_to_uint(raw_pointer_cast(u128_tensor.data())); desc.bitfield.start_address_ = start_address >> 4; constexpr uint8_t base_offset = 0; @@ -314,57 +310,67 @@ make_gmma_desc(Tensor const& tensor) struct DescriptorIterator { + using reference = GmmaDescriptor; + using element_type = GmmaDescriptor; + using value_type = GmmaDescriptor; + GmmaDescriptor desc_; // Dereference returns the GmmaDescriptor CUTE_HOST_DEVICE constexpr - GmmaDescriptor const& operator*() const { return desc_; } + reference operator*() const { return desc_; } // Advance and return a new GmmaDescriptor template CUTE_HOST_DEVICE constexpr - GmmaDescriptor operator[](Index const& i) const { return *(*this + i); } + reference 
operator[](Index const& i) const { return *(*this + i); } // Return an advanced iterator template CUTE_HOST_DEVICE constexpr DescriptorIterator operator+(Index const& offset) const { - return { GmmaDescriptor {desc_ + uint64_t(offset)} }; + return { GmmaDescriptor{desc_ + uint64_t(offset)} }; } CUTE_HOST_DEVICE friend void - print(DescriptorIterator const&) { printf("GMMA::DescriptorIterator"); } + print(DescriptorIterator) { printf("GMMA::DescriptorIterator"); } }; +template +CUTE_HOST_DEVICE constexpr +GmmaDescriptor +raw_pointer_cast(DescriptorIterator const& ptr) { + return ptr.desc_; +} + +// Recast a DescriptorIterator Tensor to uint64_t, it's RegType in mma_unpack +template +CUTE_HOST_DEVICE constexpr +DescriptorIterator +recast_ptr(DescriptorIterator const& iter) { + static_assert(is_same::value, "Can only cast GmmaDescriptorIterator to uint64_t."); + return iter; // Do nothing, it will still dereference to GmmaDescriptor and decay to uint64_t +} + // The GMMA Traits below have custom fragment type flags for their smem desc tensors. // These flags specialize a MakeTensor customization point to correctly make the fragment that is desired. template struct smem_desc : DescriptorIterator {}; -// Recast a DescriptorIterator Tensor to uint64_t, it's RegType -template -CUTE_HOST_DEVICE constexpr -auto -recast(Tensor,TLayout> const& tensor, type_list) -{ - static_assert(is_same::value, "Can only cast descriptors to uint64_t."); - return make_tensor(tensor.data(), Layout<_1,_0>{}); -} - } // end namespace GMMA // Customization point for creating a GMMA::smem_desc Tensor template struct MakeTensor> { - template + template CUTE_HOST_DEVICE constexpr auto - operator()(Tensor const& smem_tensor) + operator()(Tensor const& smem_tensor) { - static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); + static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); return make_tensor(GMMA::DescriptorIterator{GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, - recast(smem_tensor).layout()); + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); } }; diff --git a/include/cute/container/array.hpp b/include/cute/container/array.hpp index 3b0831657a..dcf01ba1c4 100644 --- a/include/cute/container/array.hpp +++ b/include/cute/container/array.hpp @@ -41,13 +41,14 @@ namespace cute template struct array { - using value_type = T; + using element_type = T; + using value_type = remove_cv_t; using size_type = size_t; using difference_type = ptrdiff_t; - using reference = value_type&; - using const_reference = const value_type&; - using pointer = value_type*; - using const_pointer = const value_type*; + using reference = element_type&; + using const_reference = const element_type&; + using pointer = element_type*; + using const_pointer = const element_type*; using iterator = pointer; using const_iterator = const_pointer; @@ -190,20 +191,21 @@ struct array } } - value_type __elems_[N > 0 ? 
N : 1]; + element_type __elems_[N]; }; template struct array { - using value_type = T; + using element_type = T; + using value_type = remove_cv_t; using size_type = size_t; using difference_type = ptrdiff_t; - using reference = value_type&; - using const_reference = const value_type&; - using pointer = value_type*; - using const_pointer = const value_type*; + using reference = element_type&; + using const_reference = const element_type&; + using pointer = element_type*; + using const_pointer = const element_type*; using const_iterator = const_pointer; using iterator = const_iterator; diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp index f39bcd66bf..88e7abf6ec 100644 --- a/include/cute/container/array_subbyte.hpp +++ b/include/cute/container/array_subbyte.hpp @@ -39,11 +39,18 @@ #include // sizeof_bits #include -#include // dummy_type namespace cute { +template +struct is_subbyte { + static constexpr bool value = sizeof_bits_v < 8; +}; + +template +constexpr bool is_subbyte_v = is_subbyte::value; + // // Underlying subbyte storage type // @@ -53,43 +60,44 @@ using subbyte_storage_type_t = conditional_t<(sizeof_bits_v <= 8), uint8_t, conditional_t<(sizeof_bits_v <= 32), uint32_t, conditional_t<(sizeof_bits_v <= 64), uint64_t, conditional_t<(sizeof_bits_v <= 128), uint128_t, - dummy_type>>>>>; + T>>>>>; -template -struct subbyte_iterator; +template struct subbyte_iterator; +template struct swizzle_ptr; // // subbyte_reference // Proxy object for sub-byte element references // template -struct subbyte_reference +struct subbyte_reference { // Iterator Element type (const or non-const) using element_type = T; - // Iterator Value type without type qulifier. + // Iterator Value type without type qualifier. using value_type = remove_cv_t; // Storage type (const or non-const) using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; - static_assert(!is_same_v, "Storage type is not supported"); + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); static_assert(sizeof_bits_v <= sizeof_bits_v, "Size of Element must not be greater than Storage."); - // Number of logical elements per stored object - static constexpr uint8_t ElementsPerStoredItem = sizeof_bits_v / sizeof_bits_v; - // Bitmask for covering one item - static constexpr storage_type BitMask = storage_type((storage_type(1) << sizeof_bits_v) - 1); - private: + // Bitmask for covering one item + static constexpr storage_type BitMask = storage_type(storage_type(-1) >> (sizeof_bits_v - sizeof_bits_v)); + // Flag for fast branching on straddled elements + static constexpr bool is_storage_unaligned = ((sizeof_bits_v % sizeof_bits_v) != 0); + friend class subbyte_iterator; - + // Pointer to storage element storage_type* ptr_ = nullptr; - // Index into elements packed into storage_type element. RI: 0 <= idx_ < ElementsPerStoredItem + // Bit index of value_type starting position within storage_type element. 
+ // RI: 0 <= idx_ < sizeof_bit uint8_t idx_ = 0; // Ctor @@ -100,38 +108,73 @@ struct subbyte_reference public: // Copy Ctor - CUTE_HOST_DEVICE constexpr + CUTE_HOST_DEVICE constexpr subbyte_reference(subbyte_reference const& other) { *this = element_type(other); } // Copy Assignment - CUTE_HOST_DEVICE constexpr + CUTE_HOST_DEVICE constexpr subbyte_reference& operator=(subbyte_reference const& other) { return *this = element_type(other); } - // Dtor - ~subbyte_reference() = default; - // Assignment - template + template CUTE_HOST_DEVICE constexpr - enable_if_t, subbyte_reference&> operator=(element_type x) { + enable_if_t, subbyte_reference&> operator=(element_type x) + { static_assert(is_same_v, "Do not specify template arguments!"); - storage_type item = (reinterpret_cast(x) & BitMask); - storage_type kUpdateMask = storage_type(~(BitMask << (idx_ * sizeof_bits_v))); - *ptr_ = storage_type((*ptr_ & kUpdateMask) | (item << (idx_ * sizeof_bits_v))); + storage_type item = (reinterpret_cast(x) & BitMask); + + // Update the current storage element + storage_type bit_mask_0 = storage_type(BitMask << idx_); + ptr_[0] = storage_type((ptr_[0] & ~bit_mask_0) | (item << idx_)); + + // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic) + if (is_storage_unaligned && idx_ + sizeof_bits_v > sizeof_bits_v) { + uint8_t straddle_bits = uint8_t(sizeof_bits_v - idx_); + storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); + // Update the next storage element + ptr_[1] = storage_type((ptr_[1] & ~bit_mask_1) | (item >> straddle_bits)); + } + return *this; } + // Comparison of referenced values + CUTE_HOST_DEVICE constexpr friend + bool operator==(subbyte_reference const& x, subbyte_reference const& y) { return x.get() == y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator!=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() != y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator< (subbyte_reference const& x, subbyte_reference const& y) { return x.get() < y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator> (subbyte_reference const& x, subbyte_reference const& y) { return x.get() > y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator<=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() <= y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator>=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() >= y.get(); } + + // Value CUTE_HOST_DEVICE - element_type get() const { + element_type get() const + { if constexpr (is_same_v) { // Extract to bool -- potentially faster impl - return bool((*ptr_) & (BitMask << (idx_ * sizeof_bits_v))); + return bool((*ptr_) & (BitMask << idx_)); } else { // Extract to element_type - storage_type item = storage_type((*ptr_ >> (idx_ * sizeof_bits_v)) & BitMask); - return reinterpret_cast(item); + // Extract from the current storage element + auto item = storage_type((ptr_[0] >> idx_) & BitMask); + + // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic) + if (is_storage_unaligned && idx_ + sizeof_bits_v > sizeof_bits_v) { + uint8_t straddle_bits = uint8_t(sizeof_bits_v - idx_); + storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); + // Extract from the next storage element + item |= storage_type((ptr_[1] & bit_mask_1) << straddle_bits); + } + + return reinterpret_cast(item); } } @@ -142,112 +185,145 @@ struct subbyte_reference } }; - // // subbyte_iterator // 
Random-access iterator over subbyte references // template -struct subbyte_iterator +struct subbyte_iterator { // Iterator Element type (const or non-const) using element_type = T; - // Iterator Value type without type qulifier. + // Iterator Value type without type qualifier. using value_type = remove_cv_t; // Storage type (const or non-const) using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; // Reference proxy type using reference = subbyte_reference; - static_assert(!is_same_v, "Storage type is not supported"); + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); static_assert(sizeof_bits_v <= sizeof_bits_v, "Size of Element must not be greater than Storage."); - // Number of logical elements per stored object - static constexpr uint8_t ElementsPerStoredItem = sizeof_bits_v / sizeof_bits_v; - private: + template friend class swizzle_ptr; + // Pointer to storage element storage_type* ptr_ = nullptr; - // Index into elements packed into storage_type element. RI: 0 <= idx_ < ElementsPerStoredItem + // Bit index of value_type starting position within storage_type element. + // RI: 0 <= idx_ < sizeof_bit uint8_t idx_ = 0; public: + // Ctor + subbyte_iterator() = default; + + // Ctor template CUTE_HOST_DEVICE constexpr - subbyte_iterator(PointerType* ptr, uint8_t idx = 0): ptr_(reinterpret_cast(ptr)), idx_(idx) { } + subbyte_iterator(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast(ptr)), idx_(idx) { } - subbyte_iterator() = default; CUTE_HOST_DEVICE constexpr - subbyte_iterator& operator++() { - ++idx_; - if (idx_ == ElementsPerStoredItem) { - ++ptr_; - idx_ = 0; - } - return *this; + reference operator*() const { + return reference(ptr_, idx_); } CUTE_HOST_DEVICE constexpr - subbyte_iterator& operator--() { - if (idx_) { - --idx_; - } else { - --ptr_; - idx_ = ElementsPerStoredItem - 1; - } + subbyte_iterator& operator+=(uint64_t k) { + k = sizeof_bits_v * k + idx_; + ptr_ += k / sizeof_bits_v; + idx_ = k % sizeof_bits_v; return *this; } CUTE_HOST_DEVICE constexpr - subbyte_iterator operator++(int) { - subbyte_iterator ret(*this); - ++(*this); - return ret; + subbyte_iterator operator+(uint64_t k) const { + return subbyte_iterator(ptr_, idx_) += k; } CUTE_HOST_DEVICE constexpr - subbyte_iterator operator--(int) { - subbyte_iterator ret(*this); - --(*this); - return ret; + reference operator[](uint64_t k) const { + return *(*this + k); } CUTE_HOST_DEVICE constexpr - subbyte_iterator& operator+=(uint64_t k) { - k += idx_; - ptr_ += k / ElementsPerStoredItem; - idx_ = k % ElementsPerStoredItem; + subbyte_iterator& operator++() { + idx_ += sizeof_bits_v; + if (idx_ >= sizeof_bits_v) { + ++ptr_; + idx_ -= sizeof_bits_v; + } return *this; } CUTE_HOST_DEVICE constexpr - subbyte_iterator operator+(uint64_t k) const { - return subbyte_iterator(ptr_,idx_) += k; + subbyte_iterator operator++(int) { + subbyte_iterator ret(*this); + ++(*this); + return ret; } CUTE_HOST_DEVICE constexpr - reference operator*() const { - return reference(ptr_, idx_); + subbyte_iterator& operator--() { + if (idx_ >= sizeof_bits_v) { + idx_ -= sizeof_bits_v; + } else { + --ptr_; + idx_ += sizeof_bits_v - sizeof_bits_v; + } + return *this; } CUTE_HOST_DEVICE constexpr - reference operator[](uint64_t k) const { - return *(*this + k); + subbyte_iterator operator--(int) { + subbyte_iterator ret(*this); + --(*this); + return ret; } - CUTE_HOST_DEVICE constexpr - friend bool operator==(subbyte_iterator const& x, subbyte_iterator const& y) { + 
CUTE_HOST_DEVICE constexpr friend + bool operator==(subbyte_iterator const& x, subbyte_iterator const& y) { return x.ptr_ == y.ptr_ && x.idx_ == y.idx_; } + CUTE_HOST_DEVICE constexpr friend + bool operator< (subbyte_iterator const& x, subbyte_iterator const& y) { + return x.ptr_ < y.ptr_ || (x.ptr_ == y.ptr_ && x.idx_ < y.idx_); + } + CUTE_HOST_DEVICE constexpr friend + bool operator!=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x == y); } + CUTE_HOST_DEVICE constexpr friend + bool operator<=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(y < x); } + CUTE_HOST_DEVICE constexpr friend + bool operator> (subbyte_iterator const& x, subbyte_iterator const& y) { return (y < x); } + CUTE_HOST_DEVICE constexpr friend + bool operator>=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x < y); } + + // Conversion to raw pointer with loss of subbyte index + CUTE_HOST_DEVICE constexpr friend + T* raw_pointer_cast(subbyte_iterator const& x) { + assert(x.idx_ == 0); + return reinterpret_cast(x.ptr_); + } + + // Conversion to NewT_ with possible loss of subbyte index + template + CUTE_HOST_DEVICE constexpr friend + auto recast_ptr(subbyte_iterator const& x) { + using NewT = conditional_t<(is_const_v), NewT_ const, NewT_>; + if constexpr (is_subbyte::value) { // Making subbyte_iter, preserve the subbyte idx + return subbyte_iterator(x.ptr_, x.idx_); + } else { // Not subbyte, assume/assert subbyte idx 0 + return reinterpret_cast(raw_pointer_cast(x)); + } + CUTE_GCC_UNREACHABLE; + } - CUTE_HOST_DEVICE constexpr - friend bool operator!=(subbyte_iterator const& x, subbyte_iterator const& y) { - return !(x == y); + CUTE_HOST_DEVICE friend void print(subbyte_iterator x) { + printf("subptr[%db](%p.%u)", int(sizeof_bits::value), x.ptr_, x.idx_); } }; @@ -281,26 +357,20 @@ struct array_subbyte // Storage type (const or non-const) using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; - static_assert(!is_same_v, "Storage type is not supported"); - - // Number of logical elements per stored object - static constexpr uint8_t ElementsPerStoredItem = sizeof_bits_v / sizeof_bits_v; - - // Bitmask for covering one item - static constexpr storage_type BitMask = ((storage_type(1) << sizeof_bits::value) - 1); - - // Number of storage elements - static constexpr size_type StorageElements = (N + ElementsPerStoredItem - 1) / ElementsPerStoredItem; + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); private: + // Number of storage elements, ceil_div + static constexpr size_type StorageElements = (N * sizeof_bits_v + sizeof_bits_v - 1) / sizeof_bits_v; + // Internal storage storage_type storage[StorageElements]; public: CUTE_HOST_DEVICE constexpr - array_subbyte() { } + array_subbyte() {} CUTE_HOST_DEVICE constexpr array_subbyte(array_subbyte const& x) { @@ -334,20 +404,11 @@ struct array_subbyte } } - // Efficient fill method CUTE_HOST_DEVICE constexpr void fill(T const& value) { - storage_type item = (reinterpret_cast(value) & BitMask); - - // Reproduce the value over the bits of the storage item - CUTE_UNROLL - for (size_type s = sizeof_bits_v; s < sizeof_bits_v; s *= 2) { - item |= item << s; - } - CUTE_UNROLL - for (size_type i = 0; i < StorageElements; ++i) { - storage[i] = item; + for (size_type i = 0; i < N; ++i) { + at(i) = value; } } @@ -428,12 +489,12 @@ struct array_subbyte CUTE_HOST_DEVICE constexpr iterator end() { - return iterator(storage + N / ElementsPerStoredItem, N % 
ElementsPerStoredItem); + return iterator(storage) + N; } CUTE_HOST_DEVICE constexpr const_iterator end() const { - return const_iterator(storage + N / ElementsPerStoredItem, N % ElementsPerStoredItem); + return const_iterator(storage) + N; } CUTE_HOST_DEVICE constexpr @@ -509,6 +570,12 @@ T&& get(array_subbyte&& a) namespace CUTE_STL_NAMESPACE { +template +struct is_reference> + : CUTE_STL_NAMESPACE::true_type +{}; + + template struct tuple_size> : CUTE_STL_NAMESPACE::integral_constant diff --git a/include/cute/container/bit_field.hpp b/include/cute/container/bit_field.hpp index 0cd3e4fe5d..bd85dc65e4 100644 --- a/include/cute/container/bit_field.hpp +++ b/include/cute/container/bit_field.hpp @@ -72,16 +72,10 @@ struct bit_field // Number of bits in data_[idx] used for NumBits if straddling, else 0 static constexpr uint32_t bit_hi = (idx + 1 < N) ? (storage_type_bits - bit_lo) : 0; -private: - // MSVC issues warning C4293 ("shift count negative or too big, undefined behavior") - // if we use NumBits directly in the shift expression, even if the shift occurs - // in the branch of a ternary expression where NumBits is known to be less than - // the number of bits of the value being shifted. - static constexpr uint32_t MollifiedNumBits = NumBits > 63u ? 63u : NumBits; public: // NumBits mask - static constexpr value_type mask = (NumBits < 64u) ? ((uint64_t(1) << MollifiedNumBits) - 1) : uint64_t(-1); + static constexpr value_type mask = value_type(uint64_t(-1) >> (64u - NumBits)); // NumBits mask for BitStart static constexpr storage_type mask_lo = storage_type(mask) << bit_lo; // NumBits mask for leftover bits in data_[idx+1] if straddling, else 0 @@ -93,7 +87,7 @@ struct bit_field CUTE_HOST_DEVICE constexpr value_type get() const { storage_type result = (data_[idx] & mask_lo) >> bit_lo; - if constexpr (bit_hi) { + if constexpr (bit_hi != 0) { result |= (data_[idx+1] & mask_hi) << bit_hi; } return static_cast(result); @@ -104,7 +98,7 @@ struct bit_field void set(value_type x) { storage_type item = static_cast(x & mask); data_[idx] = static_cast((data_[idx] & ~mask_lo) | (item << bit_lo)); - if constexpr (bit_hi) { + if constexpr (bit_hi != 0) { data_[idx+1] = static_cast((data_[idx+1] & ~mask_hi) | (item >> bit_hi)); } } diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp index 4497034fcf..d7a59b884c 100644 --- a/include/cute/int_tuple.hpp +++ b/include/cute/int_tuple.hpp @@ -40,7 +40,7 @@ /** IntTuple is an integer or a tuple of IntTuples. * This file holds utilities for working with IntTuples, * but does not hold a concrete concept or class of IntTuple. - */ + */ namespace cute { @@ -49,7 +49,7 @@ namespace cute // Even though is_tuple is false and tuple_size doesn't compile, // CuTe defines rank(Integral) as 1, so it's useful for get<0>(Integral) to return its input template >::value)> -CUTE_HOST_DEVICE constexpr +CUTE_HOST_DEVICE constexpr decltype(auto) get(T&& t) noexcept { @@ -59,7 +59,7 @@ get(T&& t) noexcept // Custom recursive get for anything that implements get(.) (for a single integer I). template -CUTE_HOST_DEVICE constexpr +CUTE_HOST_DEVICE constexpr decltype(auto) get(T&& t) noexcept { @@ -218,19 +218,29 @@ static constexpr int depth_v = depth_t::value; // product // -template -CUTE_HOST_DEVICE constexpr -auto -product(IntTuple const& a) +// Implementation of product (see below) as a function object +struct Product { - if constexpr (is_tuple::value) { - return cute::apply(a, [](auto const&... v){ return (Int<1>{} * ... 
* product(v)); }); - } else { - return a; - } + template + CUTE_HOST_DEVICE constexpr + auto + operator()(IntTuple const& a) const + { + if constexpr (is_tuple::value) { + if constexpr (tuple_size::value == 0) { + return Int<1>{}; + } else { + return cute::transform_apply(a, Product{}, multiplies_unary_lfold{}); + } + } else { + return a; + } - CUTE_GCC_UNREACHABLE; -} + CUTE_GCC_UNREACHABLE; + } +}; +// Callable product function object +CUTE_INLINE_CONSTANT Product product; // Return a rank(t) tuple @a result such that get(@a result) = product(get(@a t)) template @@ -259,7 +269,7 @@ size(IntTuple const& a) if constexpr (sizeof...(Is) == 0) { return product(a); } else { - return product(get(a)); + return size(get(a)); } CUTE_GCC_UNREACHABLE; @@ -361,7 +371,7 @@ shape_div(IntTupleA const& a, IntTupleB const& b) if constexpr (is_static::value && is_static::value) { static_assert(IntTupleA::value % IntTupleB::value == 0 || IntTupleB::value % IntTupleA::value == 0, "Static shape_div failure"); return C{}; - } else { // int int + } else { // int int //assert(a % b == 0 || b % a == 0); // Wave dynamic assertion return a / b != 0 ? a / b : signum(a) * signum(b); // Division with rounding away from zero } diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp index 5072f0121e..6925d400eb 100644 --- a/include/cute/layout.hpp +++ b/include/cute/layout.hpp @@ -1034,7 +1034,7 @@ complement(Shape const& shape, Stride const& stride, CoSizeHi const& cosize_hi) // Should just be a sort and a fold... // Then we could even handle dynamic strides (but they would destroy all static strides) - auto [shape_, stride_, result_shape_, result_stride] = + auto [shape_, stride_, result_shape_, result_stride] = fold(make_seq{}, cute::make_tuple(shape, stride, cute::make_tuple(), cute::make_tuple(Int<1>{})), [](auto const& init, auto i) @@ -1094,7 +1094,7 @@ CUTE_HOST_DEVICE constexpr auto inverse_seq(Shape const& shape, Stride const& stride, seq) { - auto next_I = find_if(stride, [](auto a) { return is_constant{}; }); + auto next_I = cute::find_if(stride, [](auto a) { return is_constant{}; }); if constexpr (next_I == decltype(rank(stride))::value) { return seq{}; @@ -1197,22 +1197,16 @@ auto max_common_layout(Layout const& a, Layout const& b) { - if constexpr (is_static::value && is_static::value && - is_static::value && is_static::value) - { - Layout inv_b = right_inverse(b); - Layout common = coalesce(composition(a, inv_b)); + Layout inv_b = right_inverse(b); + Layout common = coalesce(composition(a, inv_b)); - if constexpr (is_constant<1, decltype(stride<0>(common))>::value) { - // Truncate to the size of the contiguous vector (static stride-1 mode) - return composition(inv_b, layout<0>(common)); - } else { - return Layout<_1,_0>{}; - } + // NOTE: If one of the layouts is dynamic, we can't prove alignment+vectorization is valid + // We assume dynamic shapes/strides obey alignment requirements (i.e. are large and multiples of the vector) + if constexpr (is_static(common))>::value && + is_constant<1, decltype(stride<0>(common))>::value) { + // Truncate to the size of the contiguous vector (static stride-1 mode) + return composition(inv_b, layout<0>(common)); } else { - // CASE: One of the layouts is dynamic, can't prove alignment+vectorization is valid - // NOTE: Could weaken if we assume dynamic shapes/strides obey alignment requirements - // (i.e. 
are large and multiples of the vector) return Layout<_1,_0>{}; } } @@ -1231,21 +1225,15 @@ auto max_common_vector(Layout const& a, Layout const& b) { - if constexpr (is_static::value && is_static::value && - is_static::value && is_static::value) - { - Layout common = coalesce(composition(a, right_inverse(b))); + Layout common = coalesce(composition(a, right_inverse(b))); - if constexpr (is_constant<1, decltype(stride<0>(common))>::value) { - // Truncate to the size of the contiguous vector (static stride-1 mode) - return shape<0>(common); - } else { - return Int<1>{}; - } + // NOTE: If one of the layouts is dynamic, we can't prove alignment+vectorization is valid + // We assume dynamic shapes/strides obey alignment requirements (i.e. are large and multiples of the vector) + if constexpr (is_static(common))>::value && + is_constant<1, decltype(stride<0>(common))>::value) { + // Truncate to the size of the contiguous vector (static stride-1 mode) + return shape<0>(common); } else { - // CASE: One of the layouts is dynamic, can't prove alignment+vectorization is valid - // NOTE: Could weaken if we assume dynamic shapes/strides obey alignment requirements - // (i.e. are large and multiples of the vector) return Int<1>{}; } @@ -1412,6 +1400,21 @@ tiled_divide(Layout const& layout, return div(_, repeat(_)); } +// Same as zipped_divide, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +flat_divide(Layout const& layout, + Tile const& tile) +{ + auto div = zipped_divide(layout, tile); + + auto R0 = rank<0>(div); + auto R1 = rank<1>(div); + return div(repeat(_), repeat(_)); +} + // // Logical product // @@ -1606,7 +1609,7 @@ template CUTE_HOST_DEVICE constexpr auto -recast(Layout const& layout) +recast_layout(Layout const& layout) { if constexpr (sizeof_bits::value == sizeof_bits::value) { return layout; diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp index 7b3b6f4f68..b8765d4a81 100644 --- a/include/cute/layout_composed.hpp +++ b/include/cute/layout_composed.hpp @@ -573,7 +573,7 @@ template CUTE_HOST_DEVICE constexpr auto -recast(ComposedLayout const& layout) +recast_layout(ComposedLayout const& layout) { if constexpr (sizeof(NewType) == sizeof(OldType)) { return layout; diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp index ead3005cc8..ac6ff53921 100644 --- a/include/cute/numeric/arithmetic_tuple.hpp +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -126,24 +126,18 @@ operator+(tuple const& t, ArithmeticTuple const& u) { template CUTE_HOST_DEVICE constexpr -auto +ArithmeticTuple const& operator+(C, ArithmeticTuple const& u) { - if constexpr (t == 0) { - return u; - } else { - static_assert(t == 0, "Artihmetic tuple op+ error!"); - } + static_assert(t == 0, "Artihmetic tuple op+ error!"); + return u; } template CUTE_HOST_DEVICE constexpr -auto +ArithmeticTuple const& operator+(ArithmeticTuple const& t, C) { - if constexpr (u == 0) { - return t; - } else { - static_assert(u == 0, "Artihmetic tuple op+ error!"); - } + static_assert(u == 0, "Artihmetic tuple op+ error!"); + return t; } // @@ -153,30 +147,41 @@ operator+(ArithmeticTuple const& t, C) { template struct ArithmeticTupleIterator { + using value_type = ArithTuple; + using element_type = ArithTuple; + using reference = ArithTuple; + ArithTuple coord_; CUTE_HOST_DEVICE constexpr - ArithmeticTupleIterator() : coord_() {} - CUTE_HOST_DEVICE constexpr - ArithmeticTupleIterator(ArithTuple const& coord) : 
coord_(coord) {} + ArithmeticTupleIterator(ArithTuple const& coord = {}) : coord_(coord) {} CUTE_HOST_DEVICE constexpr ArithTuple const& operator*() const { return coord_; } template CUTE_HOST_DEVICE constexpr - auto operator+(Coord const& c) const { - return ArithmeticTupleIterator(coord_ + c); - } + auto operator[](Coord const& c) const { return *(*this + c); } template CUTE_HOST_DEVICE constexpr - auto operator[](Coord const& c) const { return *(*this + c); } + auto operator+(Coord const& c) const { + return ArithmeticTupleIterator(coord_ + c); + } }; -template -CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) { - printf("ArithTuple"); print(iter.coord_); +template +CUTE_HOST_DEVICE constexpr +auto +make_inttuple_iter(Tuple const& t) { + return ArithmeticTupleIterator(as_arithmetic_tuple(t)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_inttuple_iter(T0 const& t0, T1 const& t1, Ts const&... ts) { + return make_tuple_iter(cute::make_tuple(t0, t1, ts...)); } // @@ -211,7 +216,7 @@ struct is_integral> : true_type {}; // Get the scalar T out of a ScaledBasis template CUTE_HOST_DEVICE constexpr auto -basis_value(SB const& e) +basis_value(SB const& e) { if constexpr (is_scaled_basis::value) { return basis_value(e.value()); @@ -224,7 +229,7 @@ basis_value(SB const& e) // Apply the N... pack to another Tuple template CUTE_HOST_DEVICE constexpr auto -basis_get(SB const& e, Tuple const& t) +basis_get(SB const& e, Tuple const& t) { if constexpr (is_scaled_basis::value) { return basis_get(e.value(), get(t)); @@ -448,36 +453,44 @@ template CUTE_HOST_DEVICE constexpr auto operator+(C, ScaledBasis const& u) { - if constexpr (t == 0) { - return u; - } else { - static_assert(t == 0, "ScaledBasis op+ error!"); - } + static_assert(t == 0, "ScaledBasis op+ error!"); + return u; } template CUTE_HOST_DEVICE constexpr auto operator+(ScaledBasis const& t, C) { - if constexpr (u == 0) { - return t; - } else { - static_assert(u == 0, "ScaledBasis op+ error!"); - } + static_assert(u == 0, "ScaledBasis op+ error!"); + return t; } // // Display utilities // +template +CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) +{ + printf("ArithTuple"); print(iter.coord_); +} + template -CUTE_HOST_DEVICE void print(ScaledBasis const& e) { +CUTE_HOST_DEVICE void print(ScaledBasis const& e) +{ print(e.value()); printf("@%d", N); } #if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator const& iter) +{ + return os << "ArithTuple" << iter.coord_; +} + template -CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) { +CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) +{ return os << e.value() << "@" << N; } #endif diff --git a/include/cute/numeric/int.hpp b/include/cute/numeric/int.hpp index 9be920d1b0..ece592d909 100644 --- a/include/cute/numeric/int.hpp +++ b/include/cute/numeric/int.hpp @@ -140,6 +140,11 @@ struct sizeof_bits> { static constexpr size_t value = Bits; }; +template +struct sizeof_bits> { + static constexpr size_t value = Bits; +}; + template static constexpr int sizeof_bits_v = sizeof_bits::value; diff --git a/include/cute/numeric/integer_subbyte.hpp b/include/cute/numeric/integer_subbyte.hpp index b10b45d870..80b6f0ef0e 100644 --- a/include/cute/numeric/integer_subbyte.hpp +++ b/include/cute/numeric/integer_subbyte.hpp @@ -36,6 +36,8 @@ #include #endif +#include + #include #include diff --git a/include/cute/numeric/integral_constant.hpp 
b/include/cute/numeric/integral_constant.hpp index a88892251b..e2c1e0aeee 100644 --- a/include/cute/numeric/integral_constant.hpp +++ b/include/cute/numeric/integral_constant.hpp @@ -85,7 +85,7 @@ struct is_integral> : true_type {}; // is_static detects if an (abstract) value is defined completely by it's type (no members) template -struct is_static : bool_constant::value> {}; +struct is_static : bool_constant>::value> {}; template constexpr bool is_static_v = is_static::value; diff --git a/include/cute/numeric/integral_ratio.hpp b/include/cute/numeric/integral_ratio.hpp index 028ffffd66..3d3eb0134c 100644 --- a/include/cute/numeric/integral_ratio.hpp +++ b/include/cute/numeric/integral_ratio.hpp @@ -40,12 +40,16 @@ namespace cute { /** Compile-time rational arithmetic type. - * Like cute::C for std::integral_constant, cute::R for std::ratio has a short name + * Like cute::C for std::integral_constant, cute::R for std::ratio has a short name * for error messages and compile times. * The static data members @a num and @a den represent the reduced numerator and denominator - * of the rational value. Thus, two cute::R types with different @a n or @a d are distinct types - * even if they represent the same rational value. A cute::R exposes the reduced canonical type - * via its type member. That is, cute::R<3,6>::type is cute::R<1,2> and cute::R<6,3>::type is cute::C<2> + * of the rational value. Thus, two cute::R types with different @a n or @a d are distinct types + * even if they represent the same rational value. + * A cute::R exposes the reduced canonical type via its ::type member. + * That is, cute::R<3,6>::type is cute::R<1,2> and cute::R<6,3>::type is cute::C<2>. + * A cute::R::value can be used much like any other trait::value. It can be involved in + * arithmetic expressions (according to the operator-overloads for cute::C and cute::R, + * though these may be incomplete) but with a potential rational value rather than an integral value. 
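 *
 * For example (an illustrative restatement of the reduction rules above, not an exhaustive spec):
 *
 *   static_assert(cute::R<3,6>::num == 1 && cute::R<3,6>::den == 2);           // 3/6 reduces to 1/2
 *   static_assert(cute::is_same_v<typename cute::R<6,3>::type, cute::C<2>>);   // whole-valued ratios decay to cute::C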
*/ template class R { @@ -53,7 +57,7 @@ class R { static constexpr auto an = abs(n); static constexpr auto ad = abs(d); static constexpr auto g = gcd(an, ad); - + public: static constexpr auto num = signum(n) * signum(d) * an / g; static constexpr auto den = ad / g; @@ -63,28 +67,28 @@ class R { template CUTE_HOST_DEVICE constexpr -typename R::type +typename R::type ratio(C, C) { return {}; } template CUTE_HOST_DEVICE constexpr -typename R::type +typename R::type operator*(R, R) { return {}; } template CUTE_HOST_DEVICE constexpr -typename R::type +typename R::type operator*(R, C) { return {}; } template CUTE_HOST_DEVICE constexpr -typename R::type +typename R::type operator*(C, R) { return {}; } @@ -109,28 +113,28 @@ operator*(R, C const& c) { template CUTE_HOST_DEVICE constexpr -typename R::type +typename R::type operator+(R, R) { return {}; } template CUTE_HOST_DEVICE constexpr -typename R::type +typename R::type operator+(R, C) { return {}; } template CUTE_HOST_DEVICE constexpr -typename R::type +typename R::type operator+(C, R) { return {}; } template CUTE_HOST_DEVICE constexpr -bool_constant::num == R::num && R::den == R::den> +bool_constant::num == R::num && R::den == R::den> operator==(R, R) { return {}; } @@ -144,14 +148,14 @@ operator==(R, C) { template CUTE_HOST_DEVICE constexpr -bool_constant::num == c && R::den == 1> +bool_constant::num == c && R::den == 1> operator==(C, R) { return {}; } template CUTE_HOST_DEVICE constexpr -typename R::type +typename R::type abs(R) { return {}; } diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp index fc717c9310..82f4b9729f 100644 --- a/include/cute/numeric/math.hpp +++ b/include/cute/numeric/math.hpp @@ -130,6 +130,8 @@ has_single_bit(T x) { } // Smallest number of bits needed to represent the given value +// For x == 0, this is 0 +// For x != 0, this is 1 + floor(log2(x)) // bit_width( 0b0000 ) = 0 // bit_width( 0b0001 ) = 1 // bit_width( 0b0010 ) = 2 @@ -203,7 +205,7 @@ CUTE_HOST_DEVICE constexpr T rotl(T x, int s) { constexpr int N = numeric_limits::digits; - return s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s); + return static_cast(s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s)); } // Computes the result of circular bitwise right-rotation @@ -212,7 +214,7 @@ CUTE_HOST_DEVICE constexpr T rotr(T x, int s) { constexpr int N = numeric_limits::digits; - return s == 0 ? x : s > 0 ? (x >> s) | (x << (N - s)) : rotl(x, -s); + return static_cast(s == 0 ? x : s > 0 ? (x >> s) | (x << (N - s)) : rotl(x, -s)); } // Counts the number of consecutive 0 bits, starting from the most significant bit diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp index 6c6a738f10..20eb79e494 100644 --- a/include/cute/pointer.hpp +++ b/include/cute/pointer.hpp @@ -33,308 +33,232 @@ #include #include -#include +#include // sizeof_bits #include +#include + +#include +#include +#include namespace cute { // -// has_dereference to determine if a type is a pointer concept +// recast_ptr -- Create an iterator over values of type T. +// For most types this will simply be T*, but certain types require more care. +// Subbyte Types: uint2_t, uint4_t, etc +// Requires construction of a subbyte_iterator in order to properly +// resolve each element in byte-addressed memory. 
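//
// For illustration, a minimal sketch of the dispatch described above (the return types follow the
// overloads defined below; uint4_t is one of the sub-byte types mentioned):
//
//   unsigned char buf[8];
//   auto p8 = cute::recast_ptr<int8_t>(static_cast<void*>(buf));        // byte-sized type: plain int8_t*
//   auto p4 = cute::recast_ptr<cute::uint4_t>(static_cast<void*>(buf)); // sub-byte type: cute::subbyte_iterator<cute::uint4_t>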
// -template -struct has_dereference : false_type { -}; - -template -struct has_dereference())>> : true_type { -}; - -template +template CUTE_HOST_DEVICE constexpr -T* -raw_pointer_cast(T* ptr) { - return ptr; -} - -// -// Extract the physical type from a logical elem type. -// -template -struct get_raw_type -{ - using type = T; -}; - -template -using get_raw_type_t = typename get_raw_type::type; - - -// -// Pointer categories -// - -template -struct is_gmem : false_type {}; - -template -struct is_smem : false_type {}; - -// Anything that is not gmem or smem is rmem -template -struct is_rmem : bool_constant< not (is_gmem::value || is_smem::value)> {}; - -// -// A very simplified wrapper for pointers -- use for constructing tagged pointers -// -template -struct device_ptr +auto +recast_ptr(void* ptr) { - using value_type = T; - - static const uint32_t ElementsPerStoredItem = sizeof(T) * 8 / sizeof_bits_v; - - CUTE_HOST_DEVICE constexpr - device_ptr(T* ptr) : ptr_(ptr) {} - - CUTE_HOST_DEVICE constexpr - T* get() const { return ptr_; } - - CUTE_HOST_DEVICE constexpr - T& operator*() const { return *ptr_; } - - template - CUTE_HOST_DEVICE constexpr - T& operator[](Index const& i) const { - static_assert(sizeof_bits_v >= 8, "Use subbyte_iterator to access the element"); - return ptr_[i]; + if constexpr (is_subbyte::value) { + return subbyte_iterator(ptr); + } else { + return reinterpret_cast(ptr); } + CUTE_GCC_UNREACHABLE; +} - template - CUTE_HOST_DEVICE constexpr - DerivedType operator+(Index const& i) const { return {ptr_ + i / ElementsPerStoredItem}; } - - CUTE_HOST_DEVICE constexpr friend - ptrdiff_t operator-(device_ptr const& a, - device_ptr const& b) { - return a.ptr_ - b.ptr_; +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(void const* ptr) +{ + if constexpr (is_subbyte::value) { + return subbyte_iterator(ptr); + } else { + return reinterpret_cast(ptr); } + CUTE_GCC_UNREACHABLE; +} - T* ptr_; -}; - -template +// Disambiguate nullptr +template CUTE_HOST_DEVICE constexpr -T* -raw_pointer_cast(device_ptr ptr) { - return ptr.get(); +auto +recast_ptr(decltype(nullptr)) { // nullptr_t + return recast_ptr(static_cast(nullptr)); } // // gmem_ptr // -template -struct gmem_ptr : device_ptr> { - using device_ptr>::device_ptr; +template +struct gmem_ptr : iter_adaptor> { + using iter_adaptor>::iter_adaptor; }; -template +template +struct is_gmem : false_type {}; +template // Found the gmem +struct is_gmem> : true_type {}; +template // Recurse on ::iterator, if possible +struct is_gmem> : is_gmem {}; + +// Idempotent gmem tag on an iterator +template CUTE_HOST_DEVICE constexpr -gmem_ptr -make_gmem_ptr(T* ptr) { - return {ptr}; +auto +make_gmem_ptr(Iterator iter) { + if constexpr (is_gmem::value) { + return iter; + } else { + return gmem_ptr{iter}; + } + CUTE_GCC_UNREACHABLE; } +// Explicitly typed construction from a raw pointer template CUTE_HOST_DEVICE constexpr -gmem_ptr +auto make_gmem_ptr(void* ptr) { - return {reinterpret_cast(ptr)}; + return make_gmem_ptr(recast_ptr(ptr)); } +// Explicitly typed construction from a raw pointer template CUTE_HOST_DEVICE constexpr -gmem_ptr +auto make_gmem_ptr(void const* ptr) { - return {reinterpret_cast(ptr)}; + return make_gmem_ptr(recast_ptr(ptr)); } -// nullptr_t overloads are needed because otherwise, -// make_gmem_ptr(nullptr) will be ambiguous, -// as std::nullptr_t can be converted to any pointer -// or pointer to member type. 
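//
// For illustration, a minimal sketch of the disambiguation provided by the overload that follows;
// the element type must be supplied explicitly since nullptr carries none:
//
//   auto g = cute::make_gmem_ptr<float>(nullptr);   // a gmem-tagged null float* iterator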
+// nullptr_t overload for make_gmem_ptr(nullptr) disambiguation template CUTE_HOST_DEVICE constexpr -gmem_ptr +auto make_gmem_ptr(decltype(nullptr)) { // nullptr_t - return {static_cast(nullptr)}; + return make_gmem_ptr(recast_ptr(nullptr)); } -template -struct is_gmem> : true_type {}; +// The gmem tag is invariant over type-recast +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(gmem_ptr
<P>
const& ptr) { + return make_gmem_ptr(recast_ptr(ptr.get())); +} // // smem_ptr // -template -struct smem_ptr : device_ptr> { - using device_ptr>::device_ptr; +template +struct smem_ptr : iter_adaptor> { + using iter_adaptor>::iter_adaptor; }; -template -CUTE_HOST_DEVICE constexpr -smem_ptr -make_smem_ptr(T* ptr) { - return {ptr}; -} +template +struct is_smem : false_type {}; +template // Found the smem +struct is_smem> : true_type {}; +template // Recurse on ::iterator, if possible +struct is_smem> : is_smem {}; -template +// Idempotent smem tag on an iterator +template CUTE_HOST_DEVICE constexpr -smem_ptr -make_smem_ptr(void* ptr) { - return {reinterpret_cast(ptr)}; +auto +make_smem_ptr(Iterator iter) { + if constexpr (is_smem::value) { + return iter; + } else { + return smem_ptr{iter}; + } + CUTE_GCC_UNREACHABLE; } -template +// Make a smem swizzle pointer, common operation +template CUTE_HOST_DEVICE constexpr -smem_ptr -make_smem_ptr(void const* ptr) { - return {reinterpret_cast(ptr)}; +auto +make_smem_ptr(Iterator ptr, Swizzle sw) +{ + return make_swizzle_ptr(make_smem_ptr(ptr), sw); } -template -struct is_smem> : true_type {}; - -// -// rmem_ptr -// - -template -struct rmem_ptr : device_ptr> { - using device_ptr>::device_ptr; -}; - +// Explicitly typed construction from a raw pointer template CUTE_HOST_DEVICE constexpr -rmem_ptr -make_rmem_ptr(T* ptr) { - return {ptr}; +auto +make_smem_ptr(void* ptr) { + return make_smem_ptr(recast_ptr(ptr)); } +// Explicitly typed construction from a raw pointer template CUTE_HOST_DEVICE constexpr -rmem_ptr -make_rmem_ptr(void* ptr) { - return {reinterpret_cast(ptr)}; +auto +make_smem_ptr(void const* ptr) { + return make_smem_ptr(recast_ptr(ptr)); } -template +// The smem tag is invariant over type-recast +template CUTE_HOST_DEVICE constexpr -rmem_ptr -make_rmem_ptr(void const* ptr) { - return {reinterpret_cast(ptr)}; +auto +recast_ptr(smem_ptr
<P>
const& ptr) { + return make_smem_ptr(recast_ptr(ptr.get())); } -template -struct is_rmem> : true_type {}; - // -// counting iterator -- quick and dirty +// rmem_ptr // -struct counting -{ - using index_type = int; - using value_type = index_type; - - CUTE_HOST_DEVICE constexpr - counting() : n_(0) {} - CUTE_HOST_DEVICE constexpr - counting(index_type const& n) : n_(n) {} - - CUTE_HOST_DEVICE constexpr - index_type operator[](index_type const& i) const { return n_ + i; } - - CUTE_HOST_DEVICE constexpr - index_type const& operator*() const { return n_; } - - CUTE_HOST_DEVICE constexpr - counting operator+(index_type const& i) const { return {n_ + i}; } - CUTE_HOST_DEVICE constexpr - counting& operator++() { ++n_; return *this; } - - CUTE_HOST_DEVICE constexpr - bool operator==(counting const& other) const { return n_ == other.n_; } - CUTE_HOST_DEVICE constexpr - bool operator!=(counting const& other) const { return n_ != other.n_; } - - CUTE_HOST_DEVICE constexpr - bool operator< (counting const& other) const { return n_ < other.n_; } - - index_type n_; +template +struct rmem_ptr : iter_adaptor> { + using iter_adaptor>::iter_adaptor; }; -// -// recast -// - -template -CUTE_HOST_DEVICE constexpr -auto -recast(T* ptr) { - return reinterpret_cast(ptr); -} - -template -CUTE_HOST_DEVICE constexpr -auto -recast(T const* ptr) { - return reinterpret_cast(ptr); -} - -template -CUTE_HOST_DEVICE constexpr -auto -recast(gmem_ptr const& ptr) { - return make_gmem_ptr(recast(ptr.ptr_)); -} - -template -CUTE_HOST_DEVICE constexpr -auto -recast(gmem_ptr const& ptr) { - return make_gmem_ptr(recast(ptr.ptr_)); -} +// Anything that is not gmem or smem is rmem +template +struct is_rmem : bool_constant::value || is_smem::value)> {}; +template +struct is_rmem> : true_type {}; -template +// Idempotent rmem tag on an iterator +template CUTE_HOST_DEVICE constexpr auto -recast(smem_ptr const& ptr) { - return make_smem_ptr(recast(ptr.ptr_)); +make_rmem_ptr(Iterator iter) { + if constexpr (is_rmem::value) { + return iter; + } else { + return rmem_ptr{iter}; + } + CUTE_GCC_UNREACHABLE; } -template +// Explicitly typed construction from a raw pointer +template CUTE_HOST_DEVICE constexpr auto -recast(smem_ptr const& ptr) { - return make_smem_ptr(recast(ptr.ptr_)); +make_rmem_ptr(void* ptr) { + return make_rmem_ptr(recast_ptr(ptr)); } -template +// Explicitly typed construction from a raw pointer +template CUTE_HOST_DEVICE constexpr auto -recast(rmem_ptr const& ptr) { - return make_rmem_ptr(recast(ptr.ptr_)); +make_rmem_ptr(void const* ptr) { + return make_rmem_ptr(recast_ptr(ptr)); } -template +// The rmem tag is invariant over type-recast +template CUTE_HOST_DEVICE constexpr auto -recast(rmem_ptr const& ptr) { - return make_rmem_ptr(recast(ptr.ptr_)); +recast_ptr(rmem_ptr
<P>
const& ptr) { + return make_rmem_ptr(recast_ptr(ptr.get())); } // @@ -342,46 +266,40 @@ recast(rmem_ptr const& ptr) { // template -CUTE_HOST_DEVICE void print(T const* const ptr) -{ - printf("raw_ptr_%db(%p)", int(sizeof_bits::value), ptr); -} - -template -CUTE_HOST_DEVICE void print(gmem_ptr const& ptr) +CUTE_HOST_DEVICE void print(gmem_ptr ptr) { - printf("gmem_ptr_%db(%p)", int(sizeof_bits::value), ptr.get()); + printf("gmem_"); print(ptr.get()); } template -CUTE_HOST_DEVICE void print(smem_ptr const& ptr) +CUTE_HOST_DEVICE void print(smem_ptr ptr) { - printf("smem_ptr_%db(%p)", int(sizeof_bits::value), ptr.get()); + printf("smem_"); print(ptr.get()); } template -CUTE_HOST_DEVICE void print(rmem_ptr const& ptr) +CUTE_HOST_DEVICE void print(rmem_ptr ptr) { - printf("rmem_ptr_%db(%p)", int(sizeof_bits::value), ptr.get()); + printf("rmem_"); print(ptr.get()); } #if !defined(__CUDACC_RTC__) template -CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr const& ptr) +CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr ptr) { - return os << "gmem_ptr_" << int(sizeof_bits::value) << "b"; + return os << "gmem_[" << int(sizeof_bits>::value) << "b]"; } template -CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr const& ptr) +CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr ptr) { - return os << "smem_ptr_" << int(sizeof_bits::value) << "b"; + return os << "smem_[" << int(sizeof_bits>::value) << "b]"; } template -CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr const& ptr) +CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr ptr) { - return os << "rmem_ptr_" << int(sizeof_bits::value) << "b"; + return os << "rmem_[" << int(sizeof_bits>::value) << "b]"; } #endif // !defined(__CUDACC_RTC__) diff --git a/include/cute/pointer_base.hpp b/include/cute/pointer_base.hpp new file mode 100644 index 0000000000..75cdf8c2a9 --- /dev/null +++ b/include/cute/pointer_base.hpp @@ -0,0 +1,247 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include // sizeof_bits + +namespace cute +{ + +// +// C++20 iterator_traits +// + +namespace detail { +// Default reference type of an iterator +template +struct iter_ref { using type = decltype(*declval()); }; +// Prefer to propagate ::reference +template +struct iter_ref> { using type = typename T::reference; }; +} // end namespace detail + +template +using iter_reference = detail::iter_ref; +template +using iter_reference_t = typename iter_reference::type; + +namespace detail { +// Default element_type of an iterator +template +struct iter_e { using type = remove_reference_t::type>; }; +// Prefer to propagate ::element_type +template +struct iter_e> { using type = typename T::element_type; }; +} // end namespace detail + +template +using iter_element = detail::iter_e; +template +using iter_element_t = typename iter_element::type; + +namespace detail { +// Default value_type of an iterator +template +struct iter_v { using type = remove_cv_t::type>; }; +// Prefer to propagate ::value_type +template +struct iter_v> { using type = typename T::value_type; }; +} // end namespace detail + +template +using iter_value = detail::iter_v; +template +using iter_value_t = typename iter_value::type; + +template +struct iterator_traits { + using reference = iter_reference_t; + using element_type = iter_element_t; + using value_type = iter_value_t; +}; + +// +// has_dereference to determine if a type is an iterator concept +// + +namespace detail { +template +struct has_dereference : CUTE_STL_NAMESPACE::false_type {}; +template +struct has_dereference())>> : CUTE_STL_NAMESPACE::true_type {}; +} // end namespace detail + +template +using has_dereference = detail::has_dereference; + +// +// raw_pointer_cast +// + +template +CUTE_HOST_DEVICE constexpr +T* +raw_pointer_cast(T* ptr) { + return ptr; +} + +// +// A very simplified iterator adaptor. +// Derived classed may override methods, but be careful to reproduce interfaces exactly. +// Clients should never have an instance of this class. Do not write methods that take this as a param. 
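//
// For illustration, a minimal sketch of the intended CRTP-style use, mirroring the gmem_ptr /
// smem_ptr / rmem_ptr tags elsewhere in this change ("my_tag_ptr" is a hypothetical name):
//
//   template <class Iterator>
//   struct my_tag_ptr : iter_adaptor<Iterator, my_tag_ptr<Iterator>> {
//     using iter_adaptor<Iterator, my_tag_ptr<Iterator>>::iter_adaptor;
//   };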
+// + +template +struct iter_adaptor +{ + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + + iterator ptr_; + + CUTE_HOST_DEVICE constexpr + iter_adaptor(iterator ptr = {}) : ptr_(ptr) {} + + CUTE_HOST_DEVICE constexpr + reference operator*() const { return *ptr_; } + + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { return ptr_[i]; } + + template + CUTE_HOST_DEVICE constexpr + DerivedType operator+(Index const& i) const { return {ptr_ + i}; } + + CUTE_HOST_DEVICE constexpr + iterator get() const { return ptr_; } + + CUTE_HOST_DEVICE constexpr + friend bool operator==(DerivedType const& x, DerivedType const& y) { return x.ptr_ == y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator!=(DerivedType const& x, DerivedType const& y) { return x.ptr_ != y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator< (DerivedType const& x, DerivedType const& y) { return x.ptr_ < y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator<=(DerivedType const& x, DerivedType const& y) { return x.ptr_ <= y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator> (DerivedType const& x, DerivedType const& y) { return x.ptr_ > y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator>=(DerivedType const& x, DerivedType const& y) { return x.ptr_ >= y.ptr_; } +}; + +template +CUTE_HOST_DEVICE constexpr +auto +raw_pointer_cast(iter_adaptor const& x) { + return raw_pointer_cast(x.ptr_); +} + +// +// counting iterator -- quick and dirty +// + +template +struct counting_iterator +{ + using index_type = T; + using value_type = T; + using reference = T; + + index_type n_; + + CUTE_HOST_DEVICE constexpr + counting_iterator(index_type n = 0) : n_(n) {} + + CUTE_HOST_DEVICE constexpr + index_type operator*() const { return n_; } + + CUTE_HOST_DEVICE constexpr + index_type operator[](index_type i) const { return n_ + i; } + + CUTE_HOST_DEVICE constexpr + counting_iterator operator+(index_type i) const { return {n_ + i}; } + CUTE_HOST_DEVICE constexpr + counting_iterator& operator++() { ++n_; return *this; } + CUTE_HOST_DEVICE constexpr + counting_iterator operator++(int) { counting_iterator ret = *this; ++n_; return ret; } + + CUTE_HOST_DEVICE constexpr + friend bool operator==(counting_iterator const& x, counting_iterator const& y) { return x.n_ == y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator!=(counting_iterator const& x, counting_iterator const& y) { return x.n_ != y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator< (counting_iterator const& x, counting_iterator const& y) { return x.n_ < y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator<=(counting_iterator const& x, counting_iterator const& y) { return x.n_ <= y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator> (counting_iterator const& x, counting_iterator const& y) { return x.n_ > y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator>=(counting_iterator const& x, counting_iterator const& y) { return x.n_ >= y.n_; } +}; + +template +CUTE_HOST_DEVICE constexpr +T +raw_pointer_cast(counting_iterator const& x) { + return x.n_; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(T const* const ptr) +{ + printf("ptr[%db](%p)", int(sizeof_bits::value), ptr); +} + +template +CUTE_HOST_DEVICE void print(counting_iterator ptr) +{ + printf("counting_iter_"); print(ptr.n_); +} + +#if 
!defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator ptr) +{ + return os << "counting_iter_" << ptr.n_; +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/include/cute/pointer_flagged.hpp b/include/cute/pointer_flagged.hpp new file mode 100644 index 0000000000..bf7981954a --- /dev/null +++ b/include/cute/pointer_flagged.hpp @@ -0,0 +1,146 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include // cast_smem_ptr_to_uint + +#include +#include +#include + +#include + +namespace cute +{ + +// +// Stand-in Swizzle Layout +// A model of a nullptr smem_ptr with B == sizeof_bits::value +// That represents an unset pointer. 
This is a placeholder type that is waiting for an smem_ptr +// + +template +struct smem_ptr_flag_bits : Int<0> {}; + +using smem_ptr_flag = smem_ptr_flag_bits<1>; + +// A flagged construction method to transform ComposedLayout +// Make a swizzle pointer tensor and check that the intended type size matches +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Iterator const& ptr, + ComposedLayout,Layout> const& layout) +{ + static_assert(is_smem::value, "Expected smem."); + static_assert(B == sizeof_bits>::value, "Expected a B-bit pointer type."); + return make_tensor(make_smem_ptr(ptr.get(), layout.layout_a()), + layout.layout_b()); +} + +// NOTE: To preserve smem_ptr_flag_bits under recast ops +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Layout> const& layout) +{ + return composition(layout.layout_a(), smem_ptr_flag_bits{}, upcast(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout,Layout> const& layout) +{ + return composition(layout.layout_a(), smem_ptr_flag_bits{}, downcast(layout.layout_b())); +} + +// +// Conversion with swizzle_layout +// + +template +CUTE_HOST_DEVICE +auto +as_position_independent_swizzle_layout(ComposedLayout,Layout> const& layout) +{ + return composition(recast_layout>(layout.layout_a()), Int<0>{}, layout.layout_b()); +} + +template +CUTE_HOST_DEVICE +auto +as_position_independent_swizzle_tensor(Tensor&& tensor) +{ + static_assert(is_smem>::value, "Expected smem tensor."); + using SwizzleFn = get_swizzle_t>; + if constexpr (SwizzleFn::num_bits == 0) { + return tensor; + } else { +#if !defined(NDEBUG) + { + uint32_t address = cast_smem_ptr_to_uint(raw_pointer_cast(std::forward(tensor).data())); + uint32_t mask = ((uint32_t(1) << SwizzleFn::num_base) - 1) | SwizzleFn::swizzle_code; + assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle + } +#endif + using T = typename remove_cvref_t::value_type; + // Recast swizzle from acting on byte-addressed pointers to elements of type-T + auto new_swizzle = recast_layout(SwizzleFn{}); + // Strip off everything and create a new smem_ptr for type-T + auto new_ptr = make_smem_ptr(raw_pointer_cast(std::forward(tensor).data())); + return make_tensor(new_ptr, composition(new_swizzle, Int<0>{}, tensor.layout())); + } + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +// Capture and cast smem_ptr_flag Layouts to offset-0 layouts +template +CUTE_HOST_DEVICE +void +print_latex(ComposedLayout,Layout> const& layout) +{ + print_latex(as_position_independent_swizzle_layout(layout)); +} + +template +CUTE_HOST_DEVICE void print(smem_ptr_flag_bits ptr) +{ + printf("smem_ptr[%db](unset)", B); +} + +} // end namespace cute diff --git a/include/cute/pointer_swizzle.hpp b/include/cute/pointer_swizzle.hpp new file mode 100644 index 0000000000..58646057e4 --- /dev/null +++ b/include/cute/pointer_swizzle.hpp @@ -0,0 +1,172 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include // iterator_traits +#include + +#include +#include + +/* This implements a swizzle pointer of the form + * InvolutionFn o PtrAdd + * where the InvolutionFn need not be linear. + * + * This differs subtly from swizzle_layout because the smem pointer is used + * as the offset. That means that swizzle_layout will implement position-independent + * swizzle layouts, while swizzle_ptr implements position-dependent swizzle tensors. + * Arch chose to design hardware with position-dependent swizzles. + * + * For clarity: + * NormalLayout : DeRef <- PtrAdd <- [Layout] + * ComposedLayout: DeRef <- PtrAdd <- [Swizzle <- OffsetAdd <- Layout] + * SwizzlePtr : [DeRef <- Swizzle <- PtrAdd] <- Layout + * + * Furthermore, for known swizzles, this pointer attempts to decay itself + * to a normal-pointer with a new layout containing dynamic or static strides. + * This is possible by determining the subdomain of the InvolutionFn + * that is identity and testing if the Layout's codomain is contained + * within it. + */ + +namespace cute +{ + +// concept SwizzleFn { +// CUTE_HOST_DEVICE constexpr static uint apply(uint); +// } +// See Swizzle in swizzle.hpp for common swizzle-functions. 
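//
// For illustration, a minimal sketch using the factories defined below (the shared-memory pointer
// is a placeholder): a nontrivial swizzle wraps the iterator, while Swizzle<0,M,S> decays back to
// the plain iterator.
//
//   extern __shared__ float smem[];
//   auto swizzled = make_swizzle_ptr(make_smem_ptr(smem), Swizzle<3,3,3>{});  // swizzle_ptr<smem_ptr<float*>, Swizzle<3,3,3>>
//   auto plain    = make_swizzle_ptr(make_smem_ptr(smem), Swizzle<0,3,3>{});  // decays to smem_ptr<float*>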
+ +template +struct swizzle_ptr : iter_adaptor> +{ + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + + using iter_adaptor>::iter_adaptor; + + template + CUTE_HOST_DEVICE constexpr static + Iter apply_swizzle(Iter ptr) { + return {apply_swizzle(ptr.get())}; + } + + template + CUTE_HOST_DEVICE constexpr static + T* apply_swizzle(T* ptr) { + return reinterpret_cast(SwizzleFn::apply(reinterpret_cast(ptr))); + } + + template + CUTE_HOST_DEVICE constexpr static + subbyte_iterator apply_swizzle(subbyte_iterator ptr) { + return {apply_swizzle(ptr.ptr_), ptr.idx_}; + } + + CUTE_HOST_DEVICE constexpr + reference operator*() const { + return *apply_swizzle(this->get()); + } + + template + CUTE_HOST_DEVICE constexpr + reference operator[](Int const& i) const { + return *apply_swizzle(this->get() + i); + } +}; + +template // Default No-Swizzle +struct get_swizzle { using type = Swizzle<0,4,3>; }; +template // Found the SwizzleFn +struct get_swizzle> { using type = SwizzleFn; }; +template // Recurse into anything with a ::iterator +struct get_swizzle> : get_swizzle {}; + +template +using get_swizzle_t = typename get_swizzle::type; + +template +CUTE_HOST_DEVICE constexpr +swizzle_ptr +make_swizzle_ptr(Iterator ptr, SwizzleFn) { + return {ptr}; +} + +// Swizzle-0 specialization for immediate decay +template +CUTE_HOST_DEVICE constexpr +Iterator +make_swizzle_ptr(Iterator ptr, Swizzle<0,M,S>) { + return ptr; +} + +// +// Recast +// + +template +CUTE_HOST_DEVICE constexpr +auto +raw_pointer_cast(swizzle_ptr const& ptr) { + return raw_pointer_cast(ptr.get()); +} + +// SwizzleFn operates on the pointer address, so it doesn't care about the type +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(swizzle_ptr const& ptr) { + return make_swizzle_ptr(recast_ptr(ptr.get()), SwizzleFn{}); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(swizzle_ptr ptr) +{ + print(SwizzleFn{}); printf("_"); print(ptr.get()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, swizzle_ptr ptr) +{ + return os << SwizzleFn{} << "_" << ptr.get(); +} +#endif + +} // end namespace cute diff --git a/include/cute/swizzle.hpp b/include/cute/swizzle.hpp index 39ac311de2..54a668bec2 100644 --- a/include/cute/swizzle.hpp +++ b/include/cute/swizzle.hpp @@ -124,92 +124,6 @@ composition(Swizzle, Swizzle) //return ComposedFn, Swizzle>{}; } -// -// Inverse -// - -template -CUTE_HOST_DEVICE constexpr -Swizzle -right_inverse(Swizzle const& sw) -{ - return sw; -} - -template -CUTE_HOST_DEVICE constexpr -Swizzle -left_inverse(Swizzle const& sw) -{ - return sw; -} - -// Kludge -- Probably want an OffsetFn here instead -template ::value)> -CUTE_HOST_DEVICE constexpr -auto -right_inverse(T const& t) -{ - return -t; -} - -// Kludge -- Probably want an OffsetFn here instead -template ::value)> -CUTE_HOST_DEVICE constexpr -auto -left_inverse(T const& t) -{ - return -t; -} - -// -// Upcast and Downcast -// - -template -CUTE_HOST_DEVICE constexpr -auto -upcast(Swizzle const& swizzle) -{ - static_assert(has_single_bit(N), "N must be a power of two"); - constexpr int log2_n = bit_width(uint32_t(N)) - 1; - constexpr int NewM = M - log2_n; - if constexpr (NewM >= 0) { - return Swizzle{}; - } else { - return Swizzle{}; - } - - CUTE_GCC_UNREACHABLE; -} - -template -CUTE_HOST_DEVICE constexpr -auto -downcast(Swizzle const& swizzle) 
-{ - static_assert(has_single_bit(N), "N must be a power of two"); - constexpr int log2_n = bit_width(uint32_t(N)) - 1; - return Swizzle{}; -} - -template -CUTE_HOST_DEVICE constexpr -auto -recast(Swizzle const& swizzle) -{ - if constexpr (sizeof_bits::value == sizeof_bits::value) { - return swizzle; - } else if constexpr (sizeof_bits::value > sizeof_bits::value) { - static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a multiple of OldType"); - return upcast::value/sizeof_bits::value>(swizzle); - } else if constexpr (sizeof_bits::value < sizeof_bits::value) { - static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a divisor of OldType"); - return downcast::value/sizeof_bits::value>(swizzle); - } -} - // // Utility for slicing and swizzle "offsets" // @@ -218,8 +132,8 @@ recast(Swizzle const& swizzle) // consumed and which bits are free. Furthermore, it is useful to know whether // each of these bits is known statically or dynamically. -// MixedBits is an 32-bit unsigned integer class where some bits are known statically -// and some bits are known dynamically. These sets of bits are disjoint and it is +// MixedBits is an 32-bit unsigned integer class where some bits are known statically +// and some bits are known dynamically. These sets of bits are disjoint and it is // known statically which bits are known dynamically. // MixedBits can only be manipulated through bitwise operations @@ -524,6 +438,12 @@ to_mixed_bits(Layout const& layout, Coord const& coord) // Display utilities // +template +CUTE_HOST_DEVICE void print(Swizzle const&) +{ + printf("Sw<%d,%d,%d>", B, M, S); +} + template CUTE_HOST_DEVICE void print(MixedBits const& m) { @@ -531,22 +451,16 @@ CUTE_HOST_DEVICE void print(MixedBits const& m) } #if !defined(__CUDACC_RTC__) -template -CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) -{ - return os << "M_" << S << "|(" << m.dynamic_int_ << "&" << F << ")=" << uint32_t(m); -} - template -CUTE_HOST_DEVICE void print(Swizzle const&) +CUTE_HOST std::ostream& operator<<(std::ostream& os, Swizzle const&) { - print("S<%d,%d,%d>", B, M, S); + return os << "Sw<" << B << "," << M << "," << S << ">"; } -template -CUTE_HOST std::ostream& operator<<(std::ostream& os, Swizzle const&) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) { - return os << "S<" << B << "," << M << "," << S << ">"; + return os << "M_" << S << "|(" << m.dynamic_int_ << "&" << F << ")=" << uint32_t(m); } #endif // !defined(__CUDACC_RTC__) diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp index be966d97e7..1bbccd0c05 100644 --- a/include/cute/swizzle_layout.hpp +++ b/include/cute/swizzle_layout.hpp @@ -147,6 +147,7 @@ get_swizzle_portion(Layout) // Get the "non-swizzle" part of a composed layout, // which is the underlying (non-composed) Layout. template +CUTE_HOST_DEVICE constexpr auto get_nonswizzle_portion(ComposedLayout,Offset,LayoutB> const& slayout) { @@ -155,6 +156,7 @@ get_nonswizzle_portion(ComposedLayout,Offset,LayoutB> const& slay // The non-swizzle part of a non-swizzled layout is just the Layout. 
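//
// For illustration, a sketch of the two accessors above, assuming the usual composed-swizzle
// construction via composition():
//
//   auto slayout = composition(Swizzle<3,3,3>{}, Layout<Shape<_8,_64>, Stride<_64,_1>>{});
//   auto swz     = get_swizzle_portion(slayout);     // the Swizzle<3,3,3> part
//   auto base    = get_nonswizzle_portion(slayout);  // the underlying Layout<Shape<_8,_64>, Stride<_64,_1>>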
template +CUTE_HOST_DEVICE constexpr auto get_nonswizzle_portion(Layout const& slayout) { @@ -361,6 +363,88 @@ left_inverse(ComposedLayout,Offset,Layout> const& layout) } } +template +CUTE_HOST_DEVICE constexpr +Swizzle +right_inverse(Swizzle const& sw) +{ + return sw; +} + +template +CUTE_HOST_DEVICE constexpr +Swizzle +left_inverse(Swizzle const& sw) +{ + return sw; +} + +// Kludge -- Probably want an OffsetFn here instead +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +right_inverse(T const& t) +{ + return -t; +} + +// Kludge -- Probably want an OffsetFn here instead +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +left_inverse(T const& t) +{ + return -t; +} + +// +// Upcast and Downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Swizzle const& swizzle) +{ + static_assert(has_single_bit(N), "N must be a power of two"); + constexpr int log2_n = bit_width(uint32_t(N)) - 1; + constexpr int NewM = M - log2_n; + if constexpr (NewM >= 0) { + return Swizzle{}; + } else { + return Swizzle{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Swizzle const& swizzle) +{ + static_assert(has_single_bit(N), "N must be a power of two"); + constexpr int log2_n = bit_width(uint32_t(N)) - 1; + return Swizzle{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast_layout(Swizzle const& swizzle) +{ + if constexpr (sizeof_bits::value == sizeof_bits::value) { + return swizzle; + } else if constexpr (sizeof_bits::value > sizeof_bits::value) { + static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a multiple of OldType"); + return upcast::value/sizeof_bits::value>(swizzle); + } else if constexpr (sizeof_bits::value < sizeof_bits::value) { + static_assert(sizeof_bits::value % sizeof_bits::value == 0, "NewType must be a divisor of OldType"); + return downcast::value/sizeof_bits::value>(swizzle); + } +} + // // Other operations // diff --git a/include/cute/swizzle_ptr.hpp b/include/cute/swizzle_ptr.hpp deleted file mode 100644 index fde7454f14..0000000000 --- a/include/cute/swizzle_ptr.hpp +++ /dev/null @@ -1,303 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -#pragma once - -#include - -#include - -#include -#include - -#include -#include -#include - -/* This implements a swizzle pointer of the form - * InvolutionFn o PtrAdd - * where the InvolutionFn need not be linear. - * - * This differs subtly from swizzle_layout because the smem pointer is used - * as the offset. That means that swizzle_layout will implement position-independent - * swizzle layouts, while swizzle_ptr implements position-dependent swizzle tensors. - * Arch chose to design hardware with position-dependent swizzles. - * - * For clarity: - * NormalLayout : DeRef <- PtrAdd <- [Layout] - * ComposedLayout: DeRef <- PtrAdd <- [Swizzle <- OffsetAdd <- Layout] - * SwizzlePtr : [DeRef <- Swizzle <- PtrAdd] <- Layout - * - * Furthermore, for known swizzles, this pointer attempts to decay itself - * to a normal-pointer with a new layout containing dynamic or static strides. - * This is possible by determining the subdomain of the InvolutionFn - * that is identity and testing if the Layout's codomain is contained - * within it. - */ - -namespace cute -{ - -template -struct smem_ptr_swizzle -{ - static_assert(is_empty::value, "Swizzle can't have state."); - - static const uint32_t ElementsPerStoredItem = sizeof(T) * 8 / sizeof_bits_v; - - CUTE_HOST_DEVICE constexpr - T* get() const - { - return ptr_; - } - - CUTE_HOST_DEVICE constexpr static - Swizzle get_swizzle() - { - return {}; - } - - CUTE_HOST_DEVICE constexpr static - T* apply_swizzle(T* ptr) - { - return reinterpret_cast(Swizzle::apply(reinterpret_cast(ptr))); - } - - CUTE_HOST_DEVICE constexpr - T& operator*() const - { - return *apply_swizzle(get()); - } - - template - CUTE_HOST_DEVICE constexpr - T& operator[](Int const& i) const - { - static_assert(sizeof_bits_v >= 8, "Use subbyte_iterator to access the element"); - return *apply_swizzle(get() + i); - } - - template - CUTE_HOST_DEVICE constexpr - smem_ptr_swizzle operator+(Int const& i) const - { - return {ptr_ + i / ElementsPerStoredItem}; - } - - T* ptr_; -}; - -template -struct is_smem> : true_type {}; - -// Make a swizzle pointer -template -CUTE_HOST_DEVICE constexpr -auto -make_smem_ptr(T* ptr, Swizzle const&) -{ - return smem_ptr_swizzle{ptr}; -} - -// Specialization for immediate decay -template -CUTE_HOST_DEVICE constexpr -auto -make_smem_ptr(T* ptr, Swizzle<0,M,S> const&) -{ - return make_smem_ptr(ptr); -} - -// A model of a nullptr smem_ptr with B == sizeof_bits::value -// That represents an unset pointer. 
This is a placeholder type that is waiting for an smem_ptr -template -struct smem_ptr_flag_bits : Int<0> {}; - -using smem_ptr_flag = smem_ptr_flag_bits<1>; - -// A flagged construction method to transform ComposedLayout -// Make a swizzle pointer tensor and check that the intended type size matches -template -CUTE_HOST_DEVICE constexpr -auto -make_tensor(smem_ptr const& ptr, - ComposedLayout,Layout> const& layout) -{ - static_assert(B == sizeof_bits::value, "Expected a B-bit pointer type."); - return make_tensor(make_smem_ptr(ptr.get(), layout.layout_a()), - layout.layout_b()); -} - -// NOTE: To preserve smem_ptr_flag_bits under recast ops -template -CUTE_HOST_DEVICE constexpr -auto -upcast(ComposedLayout,Layout> const& layout) -{ - return composition(layout.layout_a(), smem_ptr_flag_bits{}, upcast(layout.layout_b())); -} - -template -CUTE_HOST_DEVICE constexpr -auto -downcast(ComposedLayout,Layout> const& layout) -{ - return composition(layout.layout_a(), smem_ptr_flag_bits{}, downcast(layout.layout_b())); -} - -// -// Recast -// Swizzle operates on the pointer address, so it doesn't care about the type -// - -template -CUTE_HOST_DEVICE constexpr -auto -recast(smem_ptr_swizzle const& ptr) -{ - return smem_ptr_swizzle{recast(ptr.ptr_)}; -} - -template -CUTE_HOST_DEVICE constexpr -auto -recast(smem_ptr_swizzle const& ptr) -{ - return smem_ptr_swizzle{recast(ptr.ptr_)}; -} - -template -CUTE_HOST_DEVICE constexpr -T* -raw_pointer_cast(smem_ptr_swizzle ptr) { - return ptr.get(); -} - -// -// Conversion with swizzle_layout -// - -template -CUTE_HOST_DEVICE -auto -as_position_independent_swizzle_layout(ComposedLayout,Layout> const& layout) -{ - return composition(recast,uint_bit_t>(layout.layout_a()), Int<0>{}, layout.layout_b()); -} - -template -CUTE_HOST_DEVICE -auto -as_position_independent_swizzle_tensor(Tensor>, Layout> const& tensor) -{ - { - uint32_t address = cast_smem_ptr_to_uint(tensor.data().get()); - uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code); - assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle - } - auto new_swizzle = recast,uint_bit_t>>(tensor.data().get_swizzle()); - return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout())); -} - -template -CUTE_HOST_DEVICE -auto -as_position_independent_swizzle_tensor(Tensor>, Layout>& tensor) -{ - { - [[maybe_unused]] uint32_t address = cast_smem_ptr_to_uint(tensor.data().get()); - [[maybe_unused]] uint32_t mask = ((uint32_t(1) << Swizzle::num_base) - 1) & (Swizzle::swizzle_code); - assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle - } - auto new_swizzle = recast,uint_bit_t>>(tensor.data().get_swizzle()); - return make_tensor(make_smem_ptr(tensor.data().get()), composition(new_swizzle, Int<0>{}, tensor.layout())); -} - -template -CUTE_HOST_DEVICE -auto -as_position_independent_swizzle_tensor(Tensor>, Layout>&& tensor) -{ - return as_position_independent_swizzle_tensor(tensor); -} - -// Pass through everything else -// Used if the tensor doesn't have a swizzled layout, e.g. 
Layout_MN_INTER_Atom, Layout_K_INTER_Atom -template -CUTE_HOST_DEVICE constexpr -auto -as_position_independent_swizzle_tensor(Tensor const& tensor) -{ - return tensor; -} - -template -CUTE_HOST_DEVICE constexpr -auto -as_position_independent_swizzle_tensor(Tensor&& tensor) -{ - return tensor; -} - -// -// Print -// - -// Capture and cast smem_ptr_flag Layouts to offset-0 layouts -template -CUTE_HOST_DEVICE -void -print_latex(ComposedLayout,Layout> const& layout) -{ - auto new_swizzle = recast,uint_bit_t>(layout.layout_a()); - print_latex(composition(new_swizzle, Int<0>{}, layout.layout_b())); -} - -template -CUTE_HOST_DEVICE void print(smem_ptr_flag_bits const& ptr) -{ - printf("smem_ptr_%db(unset)", B); -} - -template -CUTE_HOST_DEVICE void print(smem_ptr_swizzle> const& ptr) -{ - printf("smem_ptr_S<%d,%d,%d>_%db(%p)", B, M, S, int(sizeof_bits::value), ptr.get()); -} - -#if !defined(__CUDACC_RTC__) -template -CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr_swizzle> const&) -{ - return os << "smem_ptr_S<" << B << "," << M << "," << S << ">_" << int(sizeof_bits::value) << "b"; -} -#endif - -} // end namespace cute diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp index c4c89de3dd..9b4e744c59 100644 --- a/include/cute/tensor.hpp +++ b/include/cute/tensor.hpp @@ -33,16 +33,16 @@ #include #include +#include +#include + #include #include #include -#include -#include -#include +#include #include #include -#include namespace cute { @@ -52,63 +52,54 @@ namespace cute // // concept Engine { -// using value_type = ; +// using iterator = ; +// using value_type = ; +// using element_type = ; +// using reference = ; // iterator begin(); // }; template -using ArrayEngine = typename conditional<(sizeof_bits::value % 8 == 0), - array_aligned, - array_subbyte>::type; +struct ArrayEngine +{ + using Storage = typename conditional<(sizeof_bits::value % 8 == 0), + array_aligned, + array_subbyte>::type; + using iterator = typename Storage::iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + Storage storage_; + + CUTE_HOST_DEVICE constexpr auto begin() const { return storage_.begin(); } + CUTE_HOST_DEVICE constexpr auto begin() { return storage_.begin(); } +}; template struct ViewEngine { - using value_type = typename cute::remove_cvref())>::type; - - using iterator = Iterator; + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; iterator storage_; - CUTE_HOST_DEVICE constexpr - iterator const& - begin() const { - return storage_; - } - - CUTE_HOST_DEVICE constexpr - iterator& - begin() { - return storage_; - } + CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; } + CUTE_HOST_DEVICE constexpr iterator & begin() { return storage_; } }; -template -struct is_rmem> : is_rmem {}; -template -struct is_smem> : is_smem {}; -template -struct is_gmem> : is_gmem {}; template struct ConstViewEngine { - using value_type = typename cute::remove_cvref())>::type; - - using iterator = Iterator; + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; iterator storage_; - CUTE_HOST_DEVICE constexpr - iterator const& - begin() const { - 
return storage_; - } + CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; } }; -template -struct is_rmem> : is_rmem {}; -template -struct is_smem> : is_smem {}; -template -struct is_gmem> : is_gmem {}; // // Tensor // @@ -116,14 +107,13 @@ struct is_gmem> : is_gmem {}; template struct Tensor { - using value_type = typename Engine::value_type; - //using pointer = typename engine_traits::pointer; - //using const_pointer = typename engine_traits::const_pointer; - //using reference = typename engine_traits::reference; - //using const_reference = typename engine_traits::const_reference; + using iterator = typename Engine::iterator; + using value_type = typename Engine::value_type; + using element_type = typename Engine::element_type; + using reference = typename Engine::reference; - using engine_type = Engine; - using layout_type = Layout; + using engine_type = Engine; + using layout_type = Layout; CUTE_HOST_DEVICE constexpr Tensor() {} @@ -323,18 +313,11 @@ struct Tensor cute::tuple rep_; }; - -template +template struct is_tensor : false_type {}; template struct is_tensor> : true_type {}; -template -struct is_rmem> : is_rmem {}; -template -struct is_smem> : is_smem {}; -template -struct is_gmem> : is_gmem {}; // Customization point for creation of owning and non-owning Tensors template struct MakeTensor @@ -471,8 +454,7 @@ CUTE_HOST_DEVICE constexpr auto make_counting_tensor(Layout const& layout) { - return make_tensor(ArithmeticTupleIterator(as_arithmetic_tuple(repeat_like(coshape(layout), Int<0>{}))), - layout); + return make_tensor(make_inttuple_iter(repeat_like(coshape(layout), Int<0>{})), layout); } // @@ -665,16 +647,14 @@ group_modes(Tensor&& tensor) // -- doesn't check dynamic integer divisibility // -- doesn't check alignment -// A tagged version for dispatching -template >::value)> +template CUTE_HOST_DEVICE constexpr auto -recast(Tensor&& tensor, type_list) +recast(Tensor&& tensor) { using OldType = typename remove_cvref_t::value_type; auto old_layout = tensor.layout(); - auto new_layout = recast(old_layout); + auto new_layout = recast_layout(old_layout); // If this is an upcast of a normal Layout with static negative strides, then offset as well if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { @@ -682,38 +662,14 @@ recast(Tensor&& tensor, type_list) auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{}); auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); }); - return make_tensor(recast(std::forward(tensor).data() + offset), new_layout); + return make_tensor(recast_ptr(std::forward(tensor).data() + offset), new_layout); } else { - return make_tensor(recast(std::forward(tensor).data() ), new_layout); + return make_tensor(recast_ptr(std::forward(tensor).data() ), new_layout); } CUTE_GCC_UNREACHABLE; } -template -CUTE_HOST_DEVICE constexpr -auto -recast(Tensor const& tensor) -{ - return recast(tensor, type_list{}); -} - -template -CUTE_HOST_DEVICE constexpr -auto -recast(Tensor& tensor) -{ - return recast(tensor, type_list{}); -} - -template -CUTE_HOST_DEVICE constexpr -auto -recast(Tensor&& tensor) -{ - return recast(std::forward>(tensor), type_list{}); -} - // // max_common_vector // @@ -736,13 +692,12 @@ max_common_vector(Tensor const& a, { using SrcType = typename Tensor::value_type; using DstType = typename Tensor::value_type; - - using SrcRef = decltype(*(a.data())); - using DstRef = decltype(*(b.data())); + using SrcRef = typename 
Tensor::reference; + using DstRef = typename Tensor::reference; // Determine if vectorization candidates at all if constexpr (// Should be the same value_types, else the copy is also performing a cast - sizeof(SrcType) == sizeof(DstType) && + sizeof_bits_v == sizeof_bits_v && // The types should be trivially copyable so that vectorization is valid is_trivially_copyable::value && is_trivially_copyable::value && @@ -759,144 +714,222 @@ max_common_vector(Tensor const& a, } // -// Key algebraic operations +// Key algebraic operations -- Divide and Product // -template with shape +// ** Result Tensor shape ((BLK_A,a),(BLK_B,b),x,y). +// ** Each mode of the Tile is applied to the corresponding mode of the Tensor. +// ** See logical_divide(Layout,Tuple) +// +// * A Shape (BLK_A,BLK_B) +// ** Result Tensor shape ((BLK_A,a),(BLK_B,b),x,y). +// ** Equivalent to applying Tile. +// ** See logical_divide(Layout,Tuple) and logical_divide(Layout,Int) +// +// Note that the Tile/Shape Tilers must be weakly_congruent to the Tensor +template >::value)> CUTE_HOST_DEVICE constexpr auto logical_divide(Tensor && tensor, - Tile const& tile) + Tiler const& tiler) // Layout or Tile or Shape { return make_tensor(std::forward(tensor).data(), - logical_divide(tensor.layout(), tile)); + logical_divide(tensor.layout(), tiler)); } -// zipped_divide is logical_divide with modes gathered into standard form ((BLK_A,BLK_B),(a,b)) -template or Shape, this zips modes into standard form ((BLK_A,BLK_B),(a,b,x,y)) +template >::value)> CUTE_HOST_DEVICE constexpr auto -zipped_divide(Tensor && tensor, - Tile const& tile) // Layout or Tile +zipped_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape { return make_tensor(std::forward(tensor).data(), - zipped_divide(tensor.layout(), tile)); + zipped_divide(tensor.layout(), tiler)); } -// tiled_divide is logical_divide with the second output mode flattened ((BLK_A,BLK_B),a,b) -template >::value)> CUTE_HOST_DEVICE constexpr auto -tiled_divide(Tensor && tensor, - Tile const& tile) // Layout or Tile +tiled_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape { return make_tensor(std::forward(tensor).data(), - tiled_divide(tensor.layout(), tile)); + tiled_divide(tensor.layout(), tiler)); +} + +// flat_divide is zipped_divide with the both modes flattened (BLK_A,BLK_B,a,b,x,y) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +flat_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(std::forward(tensor).data(), + flat_divide(tensor.layout(), tiler)); } // logical_product on a Tensor doesn't make sense since it often increases cosize +// though this might make sense for creating Tensors with broadcasted (stride-0) modes // -// Logical Divide utilities: local_partition and local_tile +// Tensor partitioning utilities // -template >::value)> CUTE_HOST_DEVICE constexpr auto -local_partition(Tensor && tensor, - Tile const& tile, - Coord const& coord) +inner_partition(Tensor && tensor, + Tiler const& tiler, + Coord const& coord) { - constexpr int R1 = decltype(rank(tensor))::value; - - // Split the modes of tensor according to the modes of tile - // zipped_divide returns something like ((VEC_A,VEC_B,...),(a,b,...)) - - // The_coord is the coord into the first mode, flatten the rest - return zipped_divide(std::forward(tensor), tile)(coord, repeat(_)); + auto tensor_tiled = zipped_divide(std::forward(tensor), tiler); + constexpr int R0 = decltype(rank<0>(tensor_tiled))::value; + + // The coord slices into 
the second mode (the "rest" mode), flatten the first + if constexpr (is_tuple::value) { + // Append trailing modes if coord is tuple + constexpr int R1 = decltype(rank<1>(tensor_tiled))::value;; + return tensor_tiled(repeat(_), append(coord,_)); + } else { + // Flat indexing if coord is not tuple + return tensor_tiled(repeat(_), coord); + } } -template >::value)> CUTE_HOST_DEVICE constexpr auto -local_partition(Tensor && tensor, - Tile const& tile, - Coord const& coord, - Projection const& proj) +outer_partition(Tensor && tensor, + Tiler const& tiler, + Coord const& coord) { - return local_partition(std::forward(tensor), - dice(proj, tile), - dice(proj, coord)); + auto tensor_tiled = zipped_divide(std::forward(tensor), tiler); + constexpr int R1 = decltype(rank<1>(tensor_tiled))::value; + + // The coord slices into the first mode (the "tile" mode), flatten the second + if constexpr (is_tuple::value) { + // Append trailing modes if coord is tuple + constexpr int R0 = decltype(rank<0>(tensor_tiled))::value; + return tensor_tiled(append(coord,_), repeat(_)); + } else { + // Flat indexing if coord is not tuple + return tensor_tiled(coord, repeat(_)); + } } -// Special case with Layout and Integral that extracts the coord first -// e.g. local_partition(tensor, ThrLayout, threadIdx.x) -template >::value && - is_integral::value)> -CUTE_HOST_DEVICE +// Tile a tensor according to @a tiler and use @a coord to index into the remainder, keeping the tile. +// This is typical at the CTA level where tiles of data are extracted: +// Tensor data = ... // ( M, N) +// Tensor cta_data = local_tile(data, Shape<_32,_64>{}, make_coord(blockIdx.x,blockIdx.y)); // (_32,_64) +template >::value)> +CUTE_HOST_DEVICE constexpr auto -local_partition(Tensor && tensor, - Layout const& tile, - Index const& index) +local_tile(Tensor && tensor, + Tiler const& tiler, // tiler to apply + Coord const& coord) // coord to slice into "remainder" { - return local_partition(std::forward(tensor), - product_each(shape(tile)), - tile.get_flat_coord(index)); + return inner_partition(std::forward(tensor), + tiler, + coord); } -// Special case with Layout and Integral that extracts the coord first -// e.g. local_partition(tensor, ThrLayout, threadIdx.x, Step<_1,X,_1>{}) -template >::value && - is_integral::value)> +// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience +// when using projections of the same tiler. +// This is typical at the CTA level where tiles of data are extracted as projections: +// Tensor dataA = ... // (M,K) +// Tensor dataB = ... // (N,K) +// Tensor dataC = ... 
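// A short end-to-end sketch of the tiling flow documented here. The extents,
// pointer, and thread count are illustrative; the calls follow the
// zipped_divide / local_tile / local_partition semantics described above.
//   Tensor mA   = make_tensor(make_gmem_ptr(ptr_A), make_shape(M, K), make_stride(K, Int<1>{})); // (M,K) row-major
//   Tensor A_z  = zipped_divide(mA, Shape<_128,_64>{});                          // ((_128,_64),(m',k')) standard form
//   Tensor gA   = local_tile(mA, Shape<_128,_64>{}, make_coord(blockIdx.x, _));  // (_128,_64,k) this CTA's tiles
//   auto   tlay = Layout<Shape<_32,_8>>{};                                       // 256 threads
//   Tensor tAgA = local_partition(gA, tlay, threadIdx.x);                        // (_4,_8,k) per-thread slice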
// (M,N) +// auto cta_tiler = Shape<_32, _64, _4>{}; +// auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); +// Tensor ctaA = local_tile(dataA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (_32,_4,k) +// Tensor ctaB = local_tile(dataA, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (_64,_4,k) +// Tensor ctaC = local_tile(dataA, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (_32,_64) +template >::value)> CUTE_HOST_DEVICE auto -local_partition(Tensor && tensor, - Layout const& tile, - Index const& index, - Projection const& proj) +local_tile(Tensor && tensor, + Tiler const& tiler, // tiler to apply + Coord const& coord, // coord to slice into "remainder" + Proj const& proj) // projection to apply to tiler and coord { - return local_partition(std::forward(tensor), - dice(proj, product_each(shape(tile))), - dice(proj, tile).get_flat_coord(index)); + return local_tile(std::forward(tensor), + dice(proj, tiler), + dice(proj, coord)); } -template >{}, thr_idx); // ( _8, _4) +template >::value)> -CUTE_HOST_DEVICE constexpr +CUTE_HOST_DEVICE auto -local_tile(Tensor && tensor, - Tile const& tile, - Coord const& coord) +local_partition(Tensor && tensor, + Layout const& tile, // coord -> index + Index const& index) // index to slice for { - constexpr int R0 = decltype(rank(tile))::value; - constexpr int R1 = decltype(rank(tensor))::value; - - // Split the modes of tensor according to the modes of tile - // zipped_divide returns something like ((VEC_A,VEC_B,...),(a,b,...)) - - // The padded_coord is the coord into the second mode, flatten the rest - return zipped_divide(std::forward(tensor), tile)(repeat(_), append(coord,_)); + static_assert(is_integral::value); + return outer_partition(std::forward(tensor), + product_each(shape(tile)), + tile.get_flat_coord(index)); } -template , Stride<_16,_1,_0>>{}; +// Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{}); // (M/2,K/1) +// Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{}); // (N/16,K/1) +// Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{}); // (M/2,N/16) +template >::value)> CUTE_HOST_DEVICE auto -local_tile(Tensor && tensor, - Tile const& tile, - Coord const& coord, - Proj const& proj) +local_partition(Tensor && tensor, + Layout const& tile, // coord -> index + Index const& index, // index to slice for + Projection const& proj) { - return local_tile(std::forward(tensor), - dice(proj, tile), - dice(proj, coord)); + return local_partition(std::forward(tensor), + dice(proj, tile), + index); } // @@ -906,7 +939,7 @@ local_tile(Tensor && tensor, template CUTE_HOST_DEVICE void print(Tensor const& tensor) { - print(tensor.data()); print(" o "); print(tensor.layout()); + print(tensor.data()); print(" o "); print(tensor.layout()); } template @@ -951,8 +984,6 @@ CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor) } } - - #if !defined(__CUDACC_RTC__) template CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor const& tensor) @@ -1008,7 +1039,9 @@ CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const // Extended Engines // -#include +#include +#include + // // Tensor Algorithms // diff --git a/include/cute/tensor_predicate.hpp b/include/cute/tensor_predicate.hpp index 730f219462..826f757960 100644 --- a/include/cute/tensor_predicate.hpp +++ b/include/cute/tensor_predicate.hpp @@ -60,4 +60,20 @@ struct TrivialPredTensor } }; +template +struct FunctionPredTensor +{ + CUTE_HOST_DEVICE constexpr + FunctionPredTensor(Fn const& fn) : fn_(fn) {} + + template + 
CUTE_HOST_DEVICE constexpr + auto + operator()(Coords const&... coords) const { + return fn_(coords...); + } + + Fn const& fn_; +}; + } // end namespace cute diff --git a/include/cute/underscore.hpp b/include/cute/underscore.hpp index 155f5eb1ce..7f8e5c08e9 100644 --- a/include/cute/underscore.hpp +++ b/include/cute/underscore.hpp @@ -95,14 +95,16 @@ using has_int0 = has_elem>; // Slice keeps only the elements of Tuple B that are paired with an Underscore // +namespace detail { + template CUTE_HOST_DEVICE constexpr auto -slice(A const& a, B const& b) +lift_slice(A const& a, B const& b) { if constexpr (is_tuple::value) { static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); - return filter_tuple(a, b, [](auto const& x, auto const& y) { return slice(x,y); }); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return lift_slice(x,y); }); } else if constexpr (is_underscore::value) { return cute::tuple{b}; } else { @@ -112,18 +114,40 @@ slice(A const& a, B const& b) CUTE_GCC_UNREACHABLE; } +} // end namespace detail + +// Entry point overrides the lifting so that slice(_,b) == b +template +CUTE_HOST_DEVICE constexpr +auto +slice(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return detail::lift_slice(x,y); }); + } else if constexpr (is_underscore::value) { + return b; + } else { + return cute::tuple<>{}; + } + + CUTE_GCC_UNREACHABLE; +} + // // Dice keeps only the elements of Tuple B that are paired with an Int // +namespace detail { + template CUTE_HOST_DEVICE constexpr auto -dice(A const& a, B const& b) +lift_dice(A const& a, B const& b) { if constexpr (is_tuple::value) { static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); - return filter_tuple(a, b, [](auto const& x, auto const& y) { return dice(x,y); }); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return lift_dice(x,y); }); } else if constexpr (is_underscore::value) { return cute::tuple<>{}; } else { @@ -133,6 +157,26 @@ dice(A const& a, B const& b) CUTE_GCC_UNREACHABLE; } +} // end namespace detail + +// Entry point overrides the lifting so that dice(1,b) == b +template +CUTE_HOST_DEVICE constexpr +auto +dice(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return detail::lift_dice(x,y); }); + } else if constexpr (is_underscore::value) { + return cute::tuple<>{}; + } else { + return b; + } + + CUTE_GCC_UNREACHABLE; +} + // // Display utilities // diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp index e951e901c4..d3c75eb567 100644 --- a/include/cute/util/type_traits.hpp +++ b/include/cute/util/type_traits.hpp @@ -74,7 +74,10 @@ using CUTE_STL_NAMESPACE::is_void_v; using CUTE_STL_NAMESPACE::is_base_of; using CUTE_STL_NAMESPACE::is_base_of_v; +using CUTE_STL_NAMESPACE::is_const; using CUTE_STL_NAMESPACE::is_const_v; +using CUTE_STL_NAMESPACE::is_volatile; +using CUTE_STL_NAMESPACE::is_volatile_v; // using CUTE_STL_NAMESPACE::true_type; // using CUTE_STL_NAMESPACE::false_type; @@ -115,6 +118,7 @@ template using is_std_integral = CUTE_STL_NAMESPACE::is_integral; using CUTE_STL_NAMESPACE::is_empty; +using CUTE_STL_NAMESPACE::is_empty_v; using CUTE_STL_NAMESPACE::invoke_result_t; diff --git a/include/cutlass/arch/arch.h 
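// Worked example of the slice/dice entry points above (selector and coordinate
// values are illustrative):
//   auto sel = make_coord(_, 2);          // Underscore in mode 0, an Int in mode 1
//   auto crd = make_coord(4, 8);
//   slice(sel, crd)  ==  make_coord(4);   // keeps elements of crd paired with an Underscore
//   dice (sel, crd)  ==  make_coord(8);   // keeps elements of crd paired with an Int
//   slice(_, crd)    ==  crd;             // top-level Underscore returns crd itself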
b/include/cutlass/arch/arch.h index 2ad489518c..ca8d634341 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -41,7 +41,7 @@ namespace cutlass { namespace arch { -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +#if defined(__NVCC__) || defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) /// Computes laneId within a warp CUTLASS_DEVICE diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h index 6b491404e0..6daffceff9 100644 --- a/include/cutlass/arch/barrier.h +++ b/include/cutlass/arch/barrier.h @@ -484,6 +484,22 @@ void fence_view_async_shared() { #endif } +// Arrive on completion of in-flight cp.async operations issued by the calling thread +CUTLASS_DEVICE +void cpasync_barrier_arrive(uint64_t const* smem_ptr) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "cp.async.mbarrier.arrive.shared.b64 [%0];\n\t" + "}" + : + : "r"(smem_addr)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif +} + //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index 7a70114d5b..180db93e12 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -92,6 +92,7 @@ struct UseStagedAccumulation { static bool const value = platform::is_same::value || platform::is_same::value; }; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tag indicating the complex multiply-add operation @@ -128,7 +129,7 @@ struct OpClassWmmaTensorOp {}; ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Tag classifing operators as Tensor Core with structure sparse operations. +/// Tag classifying operators as Tensor Core with structure sparse operations. 
struct OpClassSparseTensorOp {}; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/reg_reconfig.h b/include/cutlass/arch/reg_reconfig.h index 2b74a22e6c..22d8d50290 100644 --- a/include/cutlass/arch/reg_reconfig.h +++ b/include/cutlass/arch/reg_reconfig.h @@ -55,7 +55,6 @@ void warpgroup_reg_alloc(){ asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) ); #endif } - template CUTLASS_DEVICE void warpgroup_reg_dealloc(){ diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 63ba80893f..bfdfa072a9 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -730,6 +730,24 @@ struct multiplies> { } }; +template +struct maximum_absolute_value_reduction, PropogateNaN> { + + CUTLASS_HOST_DEVICE + T operator() (T const& scalar, Array const& rhs) const { + + T result = scalar; + maximum_absolute_value_reduction scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result = scalar_op(result, rhs[i]); + } + + return result; + } +}; + template struct scale> { T const scaling_factor_; @@ -797,6 +815,24 @@ struct divides> { } }; +template +struct reciprocal_approximate> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + reciprocal_approximate scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + template struct maximum, false> { diff --git a/include/cutlass/barrier.h b/include/cutlass/barrier.h index 63502571fc..04c63af2f9 100644 --- a/include/cutlass/barrier.h +++ b/include/cutlass/barrier.h @@ -219,15 +219,17 @@ using Barrier = GenericBarrier; * * @param ThreadCount_ Number of threads that will wait on a NamedBarrier with a given ID * @param Offset Value added to the ID passed in by the user to determine the NamedBarrier ID to call into + * @param MaxNumNamedBarriers The maximum number of unique barrier IDs that will be requested on this type **/ template < uint32_t ThreadCount_, - uint32_t Offset = 0 + uint32_t Offset = 0, + uint32_t MaxNumNamedBarriers = 16 > struct NamedBarrierManager { - static constexpr uint32_t MaxNumNamedBarriers = 16; - static_assert(Offset < MaxNumNamedBarriers, "Barrier IDs cannot exceed 15"); - static constexpr uint32_t ValidBarrierIds = MaxNumNamedBarriers - Offset; + static constexpr uint32_t HardwareMaxNumNamedBarriers = 16; + static_assert(MaxNumNamedBarriers <= HardwareMaxNumNamedBarriers); + static_assert(MaxNumNamedBarriers + Offset <= HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15"); // Number of threads participating in the barrier static constexpr uint32_t ThreadCount = ThreadCount_; @@ -239,7 +241,7 @@ struct NamedBarrierManager { // template parameter BarrierId, so passing in 0 suffices. 
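// Sketch of the Array specialization of maximum_absolute_value_reduction added
// in cutlass/array.h above. Values are illustrative; the functor folds |rhs[i]|
// into a running scalar maximum:
//   cutlass::Array<float, 4> frag;
//   frag[0] = -3.0f; frag[1] = 1.5f; frag[2] = -0.25f; frag[3] = 2.0f;
//   cutlass::maximum_absolute_value_reduction<cutlass::Array<float, 4>, true> red;
//   float amax = red(0.5f, frag);   // max(|0.5|, 3.0, 1.5, 0.25, 2.0) == 3.0f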
using T = typename BarrierSync<0>::T; - using IntegerSequence = cute::make_integer_sequence; + using IntegerSequence = cute::make_integer_sequence; CUTLASS_DEVICE static @@ -275,7 +277,7 @@ struct NamedBarrierManager { CUTLASS_DEVICE static void check_barrier_in_range(uint32_t idx) { - if (idx >= ValidBarrierIds) { + if (idx >= MaxNumNamedBarriers) { CUTE_RUNTIME_ASSERT("Index exceeds barrier count"); } } diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index 22364d59e1..519d676067 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -42,11 +42,9 @@ #include "cutlass/cutlass.h" #include "cutlass/functional.h" -#include "cutlass/half.h" #include "cutlass/real.h" -#include "cutlass/bfloat16.h" -#include "cutlass/tfloat32.h" +#include "cutlass/numeric_types.h" #include "cutlass/fast_math.h" @@ -490,6 +488,15 @@ CUTLASS_HOST_DEVICE tfloat32_t conj(tfloat32_t const& z) { return z; } +CUTLASS_HOST_DEVICE float_e4m3_t conj(float_e4m3_t const& z) { + return z; +} + +CUTLASS_HOST_DEVICE float_e5m2_t conj(float_e5m2_t const& z) { + return z; +} + + /// Returns the complex conjugate template CUTLASS_HOST_DEVICE complex conj(complex const &z) { diff --git a/include/cutlass/core_io.h b/include/cutlass/core_io.h index 63617afa25..48188a7a21 100644 --- a/include/cutlass/core_io.h +++ b/include/cutlass/core_io.h @@ -57,6 +57,7 @@ #include "cutlass/conv/convolution.h" #include "cutlass/conv/conv2d_problem_size.h" #include "cutlass/conv/conv3d_problem_size.h" + /////////////////////////////////////////////////////////////////////////////////////////////////// /// Output operator for CUDA built-in dim3 type diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp index 2defe558c3..becc077c5b 100644 --- a/include/cutlass/detail/layout.hpp +++ b/include/cutlass/detail/layout.hpp @@ -182,11 +182,11 @@ constexpr bool is_tma_copy_engine() { return false; } else { - if constexpr ( cute::is_base_of_v - || cute::is_base_of_v - || cute::is_base_of_v - || cute::is_base_of_v - || cute::is_base_of_v + if constexpr ( cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v ) { return true; } @@ -198,6 +198,7 @@ constexpr bool is_tma_copy_engine() { template constexpr int get_alignment_count_from_gmem_tiled_copy() { + if constexpr (cute::is_void_v) { return 1; } diff --git a/include/cutlass/detail/mma.hpp b/include/cutlass/detail/mma.hpp new file mode 100644 index 0000000000..84d62af053 --- /dev/null +++ b/include/cutlass/detail/mma.hpp @@ -0,0 +1,66 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/mma.h" +#include "cute/layout.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct IsSparseTensorOp : cute::false_type { }; + +// The following metafunction is used to extract the OperatorClass from a cutlass 3.x kernel. +template +struct get_operator_class { + static constexpr bool is_sparse_op = IsSparseTensorOp::value; + static constexpr bool is_tensor_op = cute::size<0>(typename TiledMma::AtomShape_MNK{}) >= 8; + using type = cute::conditional_t< + is_tensor_op, + cute::conditional_t< + is_sparse_op, + cutlass::arch::OpClassSparseTensorOp, + cutlass::arch::OpClassTensorOp + >, + cutlass::arch::OpClassSimt + >; +}; + +template +using get_operator_class_t = typename get_operator_class::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl index dec4b9ff6e..c0401f6d5a 100644 --- a/include/cutlass/epilogue/collective/builders/sm90_builder.inl +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -98,13 +98,18 @@ sm90_get_epilogue_smem_swizzle_layout_atom() { } // Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one. 
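// Worked example of the TileShape-aware selection changed below, assuming
// Schedule = cutlass::epilogue::TmaWarpSpecializedCooperative and
// EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto (names
// taken from the builder interface; the element type is illustrative):
//   TileShape_MNK = Shape<_128,_128,_64>  selects EpilogueTile Shape<_128,_32>
//   TileShape_MNK = Shape< _64,_128,_64>  selects EpilogueTile Shape< _64,_32>  (new branch)
//   using EpiTile = decltype(detail::sm90_compute_tile_shape_or_override<
//                      cutlass::half_t, EpilogueTileAuto,
//                      TmaWarpSpecializedCooperative, Shape<_64,_128,_64>>());  // -> Shape<_64,_32>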
-template +template constexpr auto sm90_compute_tile_shape_or_override() { if constexpr (cute::is_same_v) { if constexpr (detail::sm90_is_cooperative_v) { - return Shape<_128,_32>{}; + if constexpr (size<0>(TileShape_MNK{}) >= 128) { + return Shape<_128,_32>{}; + } + else { + return Shape<_64,_32>{}; + } } else if constexpr (detail::sm90_is_warp_specialized_v) { if constexpr (sizeof_bits_v == 8) { @@ -191,13 +196,17 @@ struct CallbacksBuilder< TileShape_MNK, EpilogueTile_MN, ElementAccumulator, - enable_if_t + enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && not is_subbyte_v> > { using GmemStrideTypeAux = gemm::TagToStrideC_t; using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); - using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator< + using CopyOpR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator< + GmemStrideTypeAux, typename FusionOp::ElementAux>()); + using CopyOpS2R = decltype(detail::sm90_get_smem_load_op_for_source< GmemStrideTypeAux, typename FusionOp::ElementAux>()); + using SmemCopyOpAux = conditional_t; using Callbacks = fusion::FusionCallbacks< Sm90TmaWarpSpecialized, @@ -206,6 +215,32 @@ struct CallbacksBuilder< >; }; +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator +> +struct CallbacksBuilder< + Sm90TmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && sizeof_bits_v == 1> +> { + using Callbacks = fusion::FusionCallbacks< + Sm90TmaWarpSpecialized, + FusionOp, TileShape_MNK, EpilogueTile_MN, + Layout<_1,_0>, DefaultCopy // aux bit tensor doesn't use smem + >; +}; + // Helper for building TMA warp-specialized collective epilogues, specialized by // the fusion operation performed and the dispatch policy to use. template < @@ -279,7 +314,7 @@ struct EpilogueDescriptor { using EpilogueTile = decltype( detail::sm90_compute_tile_shape_or_override< - ElementD, EpilogueTileType, Schedule + ElementD, EpilogueTileType, Schedule, TileShape_MNK >() ); using DispatchPolicy = @@ -443,7 +478,7 @@ struct CollectiveBuilder< cute::is_same_v >> { private: using EpilogueTile_MN = - decltype(detail::sm90_compute_tile_shape_or_override()); + decltype(detail::sm90_compute_tile_shape_or_override()); using DispatchPolicy = decltype(detail::sm90_get_tma_dispatch_policy()); @@ -623,7 +658,7 @@ CollectiveBuilder< cute::is_base_of_v >> { private: using EpilogueTile_MN = decltype(detail::sm90_compute_tile_shape_or_override< - ElementD, EpilogueTileType, Schedule>()); + ElementD, EpilogueTileType, Schedule, TileShape_MNK>()); // MSVC doesn't seem to be able to deduce DispatchPolicy correctly if it's // defined as decltype of a detail::sm90_get_tma_dispatch_policy call. // Instead, we paste in the contents of that function. 
A natural refactoring diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp index aea8721d66..fbfde723b9 100644 --- a/include/cutlass/epilogue/collective/default_epilogue.hpp +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -111,6 +111,18 @@ class DefaultEpilogue { return args; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + template CUTLASS_HOST_DEVICE static bool can_implement( diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp index 62d2ef755b..d871dc23f5 100644 --- a/include/cutlass/epilogue/collective/detail.hpp +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -87,7 +87,7 @@ CUTLASS_HOST_DEVICE auto get_epilogue_stride(Stride stride){ if constexpr (cute::is_base_of_v) { return cute::make_stride(cute::get<1>(stride), cute::get<0>(stride), cute::get<2>(stride)); - } + } else { return stride; } diff --git a/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp index 70edf77d5f..b9001b127c 100644 --- a/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp +++ b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp @@ -130,7 +130,19 @@ class EpilogueTensorBroadcast { return args; } - template + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + + template CUTLASS_HOST_DEVICE static bool can_implement( [[maybe_unused]] ProblemShape const& problem_shape, diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp index 0374a1036b..9e91f8349c 100644 --- a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -52,7 +52,7 @@ namespace collective { /// Ways to generalize this: /// - CTA tile shape /// - vectorization requirements (GMEM) -/// - vectoriz(able) transform() +/// - vectoriz(able) transform() /// template < class StrideC_, @@ -120,7 +120,19 @@ class Epilogue { return args; } - template + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + + template CUTLASS_HOST_DEVICE static bool can_implement( [[maybe_unused]] ProblemShape const& problem_shape, @@ -200,8 +212,8 @@ class Epilogue { // Tile gD and gC by the shape of SmemLayout first auto tile = make_shape(size<0>(sC), size<1>(sC)); - Tensor gCt = local_tile(gC, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) - Tensor gDt = local_tile(gD, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gCt = flat_divide(gC, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gDt = flat_divide(gD, tile); // 
(SMEM_M,SMEM_N,TILE_M,TILE_N) // Partition sC, gC, and gD for the output auto tiled_s2r = TiledCopyS2R{}; @@ -216,7 +228,7 @@ class Epilogue { // Repeat the D-partitioning for coordinates and predication Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) - Tensor cDt = local_tile(cD, tile, _); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) Tensor tDcD = tD.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) CUTE_STATIC_ASSERT(size<1>(tCaC) % size<3>(tDgC) == 0); // TILE_M divides MMA_M @@ -258,7 +270,7 @@ class Epilogue { for (int pipe_n = 0; pipe_n < size<2>(tCsC); ++pipe_n) { int mma_m = step_m * size<1>(tCsC) + pipe_m; int mma_n = step_n * size<2>(tCsC) + pipe_n; - + copy(tiled_r2s, tCaC(_,mma_m,mma_n), tCsC(_,pipe_m,pipe_n)); } } @@ -279,14 +291,14 @@ class Epilogue { // source is needed Tensor tDgCmn = tDgC(_,_,_,step_m,step_n); CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tDgDmn); ++m) + for (int m = 0; m < size<1>(tDgDmn); ++m) { CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tDgDmn); ++n) + for (int n = 0; n < size<2>(tDgDmn); ++n) { // Predication if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && - get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) + get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) { // Step 5. Elementwise operation with conversion CUTLASS_PRAGMA_UNROLL @@ -309,14 +321,14 @@ class Epilogue { } CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<1>(tDgDmn); ++m) + for (int m = 0; m < size<1>(tDgDmn); ++m) { CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<2>(tDgDmn); ++n) + for (int n = 0; n < size<2>(tDgDmn); ++n) { // Predication if (get<0>(tDcDmn(0,m,n)) < get<0>(residue_mnk) && - get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) + get<1>(tDcDmn(0,m,n)) < get<1>(residue_mnk)) { // Step 6. 
Copy to GMEM copy(CopyAtomR2G{}, tDrD(_,m,n), tDgDmn(_,m,n)); diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp index fe146a8546..27b5f37b70 100644 --- a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -40,6 +40,7 @@ #include "cutlass/epilogue/collective/detail.hpp" #include "cutlass/epilogue/thread/scale_type.h" #include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" #include "cutlass/detail/layout.hpp" #include "cutlass/trace.h" @@ -165,7 +166,7 @@ class CollectiveEpilogue< using LoadPipeline = cutlass::PipelineTransactionAsync; using LoadPipelineState = cutlass::PipelineState; constexpr static uint32_t TmaTransactionBytes = - size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof(SmemElementC)); + (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; // TMA pipeline for storing D using StorePipeline = cute::conditional_t + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream); + } + + template CUTLASS_HOST_DEVICE static bool can_implement( ProblemShape const& problem_shape, @@ -252,7 +265,7 @@ class CollectiveEpilogue< constexpr int tma_alignment_bits = 128; auto problem_shape_MNKL = append<4>(problem_shape, 1); auto [M,N,K,L] = problem_shape_MNKL; - + constexpr int min_tma_aligned_elements_D = tma_alignment_bits / cutlass::sizeof_bits::value; bool implementable = cutlass::detail::check_alignment(cute::make_shape(M,N,L), StrideD{}); @@ -275,7 +288,7 @@ class CollectiveEpilogue< // Compute number of epilogue subtiles constexpr int epi_m = size<0>(tile_shape_MNK) / size<0>(EpilogueTile{}); constexpr int epi_n = size<1>(tile_shape_MNK) / size<1>(EpilogueTile{}); - + return epi_m * epi_n; } @@ -326,6 +339,15 @@ class CollectiveEpilogue< auto [M, N, K, L] = problem_shape_mnkl; auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + // Tile residue + auto m_max_coord = unwrap(cute::transform(make_seq(tile_shape_MNK)>{}, [&](auto i) { + return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl); + })); + auto n_max_coord = unwrap(cute::transform(make_seq(tile_shape_MNK)>{}, [&](auto i) { + return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl); + })); + auto residue_mn = make_coord(m_max_coord, n_max_coord); + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for Tensor mC = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (CTA_M,CTA_N) @@ -335,7 +357,7 @@ class CollectiveEpilogue< if constexpr (not ReuseSmemC and is_source_supported) { ptr_sC = shared_tensors.smem_C.data(); } - Tensor gC_epi = local_tile(gC, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) Tensor sC_epi = 
make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) @@ -344,12 +366,15 @@ class CollectiveEpilogue< Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) // Get the fusion callbacks for the producer load warp - auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks( - problem_shape_mnkl, - CtaTileMNK{}, - tile_coord_mnkl, - EpilogueTile{}, - thread_idx); + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + residue_mn, + EpilogueTile{}, + thread_idx + }; + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); // Predication for TMA load (one thread issues TMA load) @@ -445,12 +470,13 @@ class CollectiveEpilogue< auto epi_tile_m = size<0>(EpilogueTile{}); auto epi_tile_n = size<1>(EpilogueTile{}); + // Represent the full output tensor, slice to get the tile this CTA is responsible for Tensor mD = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (CTA_M,CTA_N) - + // Apply epilogue subtiling - Tensor gD_epi = local_tile(gD, EpilogueTile{}, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) // Construct the corresponding pipelined smem tensors SmemElementC* ptr_sC = reinterpret_cast(shared_tensors.smem_D.data()); @@ -503,19 +529,37 @@ class CollectiveEpilogue< Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + // Coordinate tensors and residue for tile quantization + auto m_max_coord = unwrap(cute::transform(make_seq(CtaTileMNK{})>{}, [&](auto i) { + auto c_m = get<0,i>(problem_shape_mnkl) - get<0,i>(CtaTileMNK{}) * get<0,i>(tile_coord_mnkl); + return cute::max(0, c_m); + })); + auto n_max_coord = unwrap(cute::transform(make_seq(CtaTileMNK{})>{}, [&](auto i) { + auto c_n = get<1,i>(problem_shape_mnkl) - get<1,i>(CtaTileMNK{}) * get<1,i>(tile_coord_mnkl); + return cute::max(0, c_n); + })); + auto residue_mn = make_coord(m_max_coord, n_max_coord); + Tensor cD = make_identity_tensor(take<0,2>(CtaTileMNK{})); + Tensor tRS_cD = thread_r2s.partition_S(flat_divide(cD, EpilogueTile{})); + CUTE_STATIC_ASSERT(mma_tile_m == epi_tile_m, "EPI_TILE_M must equal MMA_TILE_M"); CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); // Get the fusion callbacks for the consumer store warps constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout - auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks( - problem_shape_mnkl, - CtaTileMNK{}, - tile_coord_mnkl, - EpilogueTile{}, - tiled_copy_C_atom, - thread_idx, - tRS_rC); + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + residue_mn, + EpilogueTile{}, + tiled_copy_C_atom, + thread_idx, + cD, + tRS_cD, + tRS_rC + }; + auto cst_callbacks = fusion_callbacks.get_consumer_store_callbacks(cst_args); bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); @@ -548,27 +592,13 @@ class CollectiveEpilogue< for (int epi_n = 
0; epi_n < size<3>(gD_epi); ++epi_n) { CUTLASS_PRAGMA_UNROLL for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { + bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; + // The current tile in accumulator int mma_m = epi_m; int mma_n = (epi_n * epi_tile_n) / mma_tile_n; Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); - // Wait for a smem buffer to be available - if (issue_tma_store) { - store_pipeline.producer_acquire(store_pipe_producer_state); - } - synchronize(); - - if constexpr (ReuseSmemC) { - // Let dma warp know smem buffer is consumed and empty after StagesD producer commits - if (issued_stores >= StagesD) { - if (is_producer_load_needed) { - load_pipeline.consumer_release(load_pipe_consumer_state); - } - ++load_pipe_consumer_state; - } - } - if (is_producer_load_needed) { // Wait for the producer load to fill smem load_pipeline.consumer_wait(load_wait_state); @@ -580,7 +610,7 @@ class CollectiveEpilogue< } // First loop fusion callback entry point - cst_callbacks.step_begin(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); if (is_producer_load_needed) { if constexpr (not ReuseSmemC) { @@ -602,9 +632,9 @@ class CollectiveEpilogue< // Copy tile from register to smem copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); - // Next loop fusion callback entry point + // Post visit, pre async fence callback entry point constexpr bool issue_smem_store = true; // No smem store predication - cst_callbacks.step_next(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + cst_callbacks.postvisit(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); // Write the tile from smem to gmem with TMA cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA @@ -613,8 +643,8 @@ class CollectiveEpilogue< copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); } - // Last loop fusion callback entry point - cst_callbacks.step_end(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + // Post async fence, pre TMA commit callback entry point + cst_callbacks.step(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); // Commit the TMA stores for this stage if (issue_tma_store) { @@ -622,6 +652,42 @@ class CollectiveEpilogue< } ++store_pipe_producer_state; ++issued_stores; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = issued_stores > StorePipeline::UnacquiredStages; + + // Free an smem buffer for reduction if necessary + if (cst_callbacks.is_reduction_buffer_needed(epi_m, epi_n, is_last_iteration) && not store_finished) { + if (issue_tma_store) { + store_pipeline.producer_tail(store_pipe_producer_state); // wait for all TMA stores to finish + } + synchronize(); + } + + // Smem reduction callback entry point using least recently acquired load buffer for workspace + cst_callbacks.reduce(sC_epi(_,_,load_pipe_consumer_state.index()), + synchronize, epi_m, epi_n, is_last_iteration); + + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + 
load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + else { + // Smem reduction callback entry point using most recently acquired store buffer for workspace + cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), + synchronize, epi_m, epi_n, is_last_iteration); + } } // for epi_m } // for epi_n @@ -644,9 +710,8 @@ class CollectiveEpilogue< if constexpr (ReuseSmemC) { if (fusion_callbacks.is_producer_load_needed()) { - // Issue releases on up to StagesD previously issued TMA stores - constexpr int release_stages = - cute::min(StagesD, get_load_pipe_increment(CtaTileMNK{})); + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{})); CUTLASS_PRAGMA_UNROLL for (int stage = 0; stage < release_stages; ++stage) { load_pipeline.consumer_release(load_pipe_consumer_state); diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp index 848d9a1146..82e1240463 100644 --- a/include/cutlass/epilogue/fusion/operations.hpp +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -60,15 +60,21 @@ struct FusionOperation { using ElementBias = void; static constexpr int AlignmentBias = 0; static constexpr bool IsPerRowBiasSupported = false; + static constexpr bool IsDePerRowBiasSupported = false; + using ActivationFn = void; static constexpr bool IsEltActSupported = false; + static constexpr bool IsDeEltActSupported = false; using ElementAux = void; using GmemLayoutTagAux = void; static constexpr int AlignmentAux = 0; static constexpr bool IsAuxOutSupported = false; + static constexpr bool IsAuxInSupported = false; + using ElementAmax = void; static constexpr bool IsAbsMaxSupported = false; + }; // D = alpha * acc @@ -242,6 +248,51 @@ struct ScaledLinCombPerRowBiasEltActAmaxAux static constexpr bool IsAuxOutSupported = true; }; +// Z = Aux +// dY = alpha * acc + beta * C +// D = d_activation(dY, Z) +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombDeEltAct + : LinCombEltAct { + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxInSupported = true; +}; + +// Z = Aux +// dY = alpha * acc + beta * C +// D = d_activation(dY, Z) +// dBias = sum of columns of D +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementBias_ = ElementCompute_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / sizeof_bits_v, + int AlignmentBias_ = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombDeEltActDePerRowBias + : LinCombDeEltAct { + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsDePerRowBiasSupported = true; +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue::fusion diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp 
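// Hedged sketch: instantiating the new LinCombDeEltAct operation above for a
// dReLU-style backward fusion. The activation functor (epilogue::thread::dReLU)
// and the RowMajor aux layout are assumptions used only for illustration:
//   using DReluFusion = cutlass::epilogue::fusion::LinCombDeEltAct<
//       cutlass::layout::RowMajor,          // GmemLayoutTagAux: layout of the saved forward tensor Z
//       cutlass::epilogue::thread::dReLU,   // D = d_activation(dY, Z)
//       cutlass::half_t,                    // ElementOutput
//       float>;                             // ElementCompute
//   static_assert(DReluFusion::IsAuxInSupported, "Z is read as an aux input");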
b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp index 84f75f92ac..a3767b7835 100644 --- a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -117,7 +117,7 @@ template< FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > using Sm90LinearCombination = - Sm90EVT, // beta * C + (alpha * acc) + Sm90EVT, // beta * C + (alpha * acc) Sm90ScalarBroadcast, // beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc @@ -143,9 +143,9 @@ struct FusionCallbacks< fusion::LinearCombination, CtaTileShapeMNK, EpilogueTile -> : Sm90LinearCombination { +> : Sm90LinearCombination::type, ElementCompute, ElementScalar, RoundStyle> { - using Impl = Sm90LinearCombination; + using Impl = Sm90LinearCombination::type, ElementCompute, ElementScalar, RoundStyle>; using Operation = fusion::LinearCombination; struct Arguments { @@ -208,7 +208,7 @@ struct FusionCallbacks< EpilogueTile > : Sm90LinCombEltAct { - using Impl = Sm90LinCombEltAct; + using Impl = Sm90LinCombEltAct::type, ElementCompute, ElementScalar, RoundStyle>; using Operation = fusion::LinCombEltAct; struct Arguments { @@ -255,10 +255,10 @@ template< FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > using Sm90LinCombPerRowBias = - Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90EVT, // beta * C + (alpha * acc + bias) Sm90ScalarBroadcast, // beta Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias + Sm90EVT, // alpha * acc + bias Sm90ScalarBroadcast, // alpha Sm90AccFetch, // acc Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias @@ -541,10 +541,10 @@ template< FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > using Sm90PerRowLinCombPerRowBias = - Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90EVT, // beta * C + (alpha * acc + bias) Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,_0>, AlignmentScalar>, // beta Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias + Sm90EVT, // alpha * acc + bias Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,_0>, AlignmentScalar>, // alpha Sm90AccFetch, // acc Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias @@ -669,10 +669,10 @@ template< FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > using Sm90ScaledLinCombPerRowBias = - Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90EVT, // beta * C + (alpha * acc + bias) Sm90ScalarBroadcast, 2>, // scale_c * beta Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias + Sm90EVT, // alpha * acc + bias Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha Sm90AccFetch, // acc Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, Stride<_1,_0,int>, AlignmentBias> // bias @@ -1003,6 +1003,229 @@ struct FusionCallbacks< ///////////////////////////////////////////////////////////////////////////////////////////////// +template< + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpS2R, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDeEltAct = + Sm90EVT, // activation(beta * C + (alpha * acc), aux) + Sm90LinearCombination, // beta * C + (alpha * acc) + Sm90AuxLoad, // aux + >; + +template < + int StagesC, + int 
StagesD, + int FragmentSize, + bool ReuseSmemC, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementScalar, + int AlignmentAux, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpS2R +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementScalar, AlignmentAux, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpS2R +> : Sm90LinCombDeEltAct< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementScalar, AlignmentAux, RoundStyle + > { + + using Impl = + Sm90LinCombDeEltAct< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementScalar, AlignmentAux, RoundStyle + >; + using Operation = + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementScalar, AlignmentAux, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux const* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + return + { // binary op : activation(beta * C + (alpha * acc), aux) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, ElementAux(0), dAux}, // leaf args : aux + activation // binary args : activation + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpS2R, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementBias = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDeEltActDePerRowBias = + Sm90EVT, // Identity for final conversion + Sm90EVT, AlignmentBias>, + Sm90LinCombDeEltAct + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementBias, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpS2R +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombDeEltActDePerRowBias< + 
GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpS2R +> : Sm90LinCombDeEltActDePerRowBias< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombDeEltActDePerRowBias< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombDeEltActDePerRowBias< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux const* aux_ptr = nullptr; + StrideAux dAux = {}; + + using StrideBias = Stride<_1,_0,int>; + ElementBias* dbias_ptr = nullptr; + StrideBias dDbias = {}; + + operator typename Impl::Arguments() const { + return + { // unary op : identity/convert + { // unary op : reduce(activation(beta * C + (alpha * acc), aux)) + { // binary op : activation(beta * C + (alpha * acc), aux) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, ElementAux(0), dAux}, // leaf args : aux + activation // binary args : activation + }, // end binary op + {dbias_ptr, ElementCompute(0), dDbias} // unary args : reduce + }, // end unary op + {} // unary args : identity/convert + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass::epilogue::fusion ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp index 0d62a4bdcb..20e4118f8b 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -38,6 +38,7 @@ #include "cutlass/cutlass.h" #include "cutlass/array.h" #include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" #include "cute/tensor.hpp" @@ -60,6 +61,25 @@ using namespace detail; // ///////////////////////////////////////////////////////////////////////////////////////////////// +// The template argument provided for ComputeFn must be able to accept +// exactly one template parameter. In Standard C++, it's OK for +// ComputeFn to have other template parameters, as long as those have +// defaults. 
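The point made in the comment above (and continued below with its `Foo`/`FooHomogeneous` example) is easiest to see in a complete, compilable form. The arithmetic body and the `apply_fn` driver here are illustrative additions, not part of CUTLASS.

```cpp
// Wrapper pattern for template-template arguments with defaulted extra parameters.
template <class A, class B = A>
struct Foo {
  auto operator()(A a, B b) const { return a + b; }
};

// As the comment notes, some compilers reject Foo as an argument for a
// `template <class> class` parameter because it declares two parameters, even
// though the second is defaulted. A single-parameter subclass sidesteps that.
template <class A>
struct FooHomogeneous : Foo<A, A> {};

template <template <class> class Fn, class T>
T apply_fn(T a, T b) { return Fn<T>{}(a, b); }

// apply_fn<FooHomogeneous>(1, 2) is accepted everywhere the code builds;
// apply_fn<Foo>(1, 2) relies on the defaulted-parameter matching rule.
```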
For example, the following struct Foo would work. +// +// template +// struct Foo { +// CUTLASS_HOST_DEVICE auto operator() (A a, B b); +// }; +// +// However, some compilers, such as Clang, require that the argument +// take _exactly_ one template parameter. This is nonstandard C++ +// behavior. One work-around for this case is to create a subclass +// with exactly one template parameter, and then use that subclass as +// the template argument. +// +// template +// struct FooHomogeneous : public Foo {}; +// template< template class ComputeFn, class ElementOutput, @@ -67,77 +87,25 @@ template< FloatRoundStyle RoundStyle, class = void > -struct Sm90Compute : Sm90VisitorImpl<> { - - using Sm90VisitorImpl<>::Sm90VisitorImpl; - - struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { - template - CUTLASS_DEVICE Array - visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, - Array const&... frg_inputs) { - return transform_apply(cute::make_tuple(frg_inputs...), - [&] (auto&& frg_input) { - using ElementInput = typename cute::remove_cvref_t::Element; - using ConvertInput = NumericArrayConverter; - ConvertInput convert_input{}; - - return convert_input(frg_input); - }, - [&] (auto&&... cvt_frg_inputs) { - using ComputeOutput = ComputeFn>; - using ConvertOutput = NumericArrayConverter; - ComputeOutput compute_output{}; - ConvertOutput convert_output{}; - - return convert_output(compute_output(cvt_frg_inputs...)); - } - ); - } +struct Sm90Compute { +private: + using EmptyArguments = typename Sm90VisitorImpl<>::Arguments; + template + struct ComputeArguments { + using type = EmptyArguments; }; - template < - bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor - > - CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { - return ConsumerStoreCallbacks(); - } - -}; - -// partial specialization for compute fns that define an Arguments member, e.g. activation hyperparameters -template< - template class ComputeFn, - class ElementOutput, - class ElementCompute, - FloatRoundStyle RoundStyle -> -struct Sm90Compute< - ComputeFn, - ElementOutput, - ElementCompute, - RoundStyle, - cute::void_t::Arguments> -> { + // partial specialization for compute fns that define an Arguments member, e.g. 
activation hyperparameters + template + struct ComputeArguments> { + using type = typename Fn::Arguments; + }; +public: struct SharedStorage { }; - using Arguments = typename ComputeFn::Arguments; + using Arguments = typename ComputeArguments>::type; using Params = Arguments; @@ -147,6 +115,18 @@ struct Sm90Compute< return args; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + CUTLASS_DEVICE bool is_producer_load_needed() const { return false; @@ -166,19 +146,9 @@ struct Sm90Compute< Params const params; - template < - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile - > + template CUTLASS_DEVICE auto - get_producer_load_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - int thread_idx) { + get_producer_load_callbacks(ProducerLoadArgs const& args) { return EmptyProducerLoadCallbacks{}; } @@ -207,7 +177,12 @@ struct Sm90Compute< ComputeOutput compute_output{}; ConvertOutput convert_output{}; - return convert_output(compute_output(cvt_frg_inputs..., params)); + if constexpr (is_same_v) { + return convert_output(compute_output(cvt_frg_inputs...)); + } + else { + return convert_output(compute_output(cvt_frg_inputs..., params)); + } } ); } @@ -216,22 +191,10 @@ struct Sm90Compute< template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor + class... Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { return ConsumerStoreCallbacks(params); } @@ -255,7 +218,7 @@ template < class InputAddOp // Z > struct Sm90TreeVisitor< - Sm90Compute, + Sm90Compute, Sm90ScalarBroadcast, Sm90SrcFetch, InputAddOp @@ -263,7 +226,7 @@ struct Sm90TreeVisitor< Sm90ScalarBroadcast, Sm90SrcFetch, InputAddOp, - Sm90Compute + Sm90Compute > { using Impl = @@ -271,7 +234,7 @@ struct Sm90TreeVisitor< Sm90ScalarBroadcast, Sm90SrcFetch, InputAddOp, - Sm90Compute + Sm90Compute >; CUTLASS_DEVICE bool @@ -334,37 +297,482 @@ struct Sm90TreeVisitor< template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor + class... 
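The refactored `Sm90Compute` above merges the two earlier specializations by detecting whether the compute functor defines a nested `Arguments` type and branching with `if constexpr` at the call site. A small host-only sketch of that detection idiom follows; `ReLU`, `LeakyReLU`, and `apply_activation` are illustrative names, not CUTLASS types.

```cpp
// Detect an optional nested Arguments type and forward it only when present.
#include <type_traits>

struct EmptyArguments {};

template <class Fn, class = void>
struct ComputeArguments { using type = EmptyArguments; };        // no hyperparameters

template <class Fn>
struct ComputeArguments<Fn, std::void_t<typename Fn::Arguments>> {
  using type = typename Fn::Arguments;                           // e.g. activation hyperparameters
};

struct ReLU {                                                    // defines no Arguments
  float operator()(float x) const { return x > 0.f ? x : 0.f; }
};
struct LeakyReLU {                                               // defines Arguments
  struct Arguments { float slope = 0.01f; };
  float operator()(float x, Arguments const& args) const { return x > 0.f ? x : args.slope * x; }
};

template <class Fn>
float apply_activation(float x, typename ComputeArguments<Fn>::type const& params = {}) {
  Fn fn{};
  if constexpr (std::is_same_v<typename ComputeArguments<Fn>::type, EmptyArguments>) {
    return fn(x);                 // functor takes no extra arguments
  } else {
    return fn(x, params);         // forward the hyperparameters
  }
}
```

Usage: `apply_activation<ReLU>(x)` compiles with no parameter pack-through, while `apply_activation<LeakyReLU>(x, {0.2f})` receives its hyperparameters, mirroring the `is_same_v` branch in `visit()` above.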
Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { return ConsumerStoreCallbacks( is_C_load_needed(), - Impl::get_consumer_store_callbacks( - problem_shape_mnkl, - tile_shape_mnk, - tile_coord_mnkl, - epi_tile, - tiled_copy, - thread_idx, - tCrC - ) + Impl::get_consumer_store_callbacks(args) ); } }; +// ReLU with aux bit tensor dReLU/dZ +// Aux(i) = Z(i) >= 0 ? 1 : 0 +namespace detail { +template < + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL, + int Alignment, + bool EnableNullptr +> +struct Sm90ReLUAuxStore { + static_assert(Alignment % 128 == 0, "sub-16B alignment not supported yet"); + + struct SharedStorage {}; + + struct Arguments { + cutlass::uint1b_t* ptr_aux = nullptr; + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90ReLUAuxStore() { } + + CUTLASS_HOST_DEVICE + Sm90ReLUAuxStore(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params const params; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rAux, + GTensor&& tC_gAux, + CTensor tC_cAux, + ResidueMN residue_mn, + Params const& params) + : tC_rAux(cute::forward(tC_rAux)), + tC_gAux(cute::forward(tC_gAux)), + tC_cAux(tC_cAux), + residue_mn(residue_mn), + params(params) {} + + RTensor tC_rAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ResidueMN residue_mn; + Params const& params; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + using ConvertInput = NumericArrayConverter; + using ConvertAux = PackPredicates; + using ComputeOutput = cutlass::epilogue::thread::ReLu; + using ConvertOutput = NumericArrayConverter; + ConvertInput convert_input{}; + ComputeOutput relu{}; + ConvertAux convert_aux{}; + ConvertOutput convert_output{}; + + Array frg_compute = convert_input(frg_input); + bool frg_aux[FragmentSize]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + ElementCompute pre_relu = frg_compute[i]; + frg_compute[i] = relu(frg_compute[i]); + frg_aux[i] = frg_compute[i] == pre_relu; + } + + static_assert(FragmentSize % 8 == 0, "Predicate vector must be byte-aligned"); + Tensor tC_rAux_frg = recast(coalesce(tC_rAux(_,_,_,epi_m,epi_n))); // (EPI_V) + tC_rAux_frg(epi_v) = convert_aux(frg_aux); + + return 
convert_output(frg_compute); + } + + CUTLASS_DEVICE void + end() { + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + return; + } + } + + // Copy vectorizes into byte-aligned stores + constexpr int V = cute::min(Alignment, decltype(max_common_vector(tC_rAux, tC_gAux))::value); + if constexpr (V > 0 && V % 8 == 0) { + using VecType = uint_bit_t; + Tensor tC_rAux_vec = recast(tC_rAux); + Tensor tC_gAux_vec = recast(tC_gAux); + Tensor tC_cAux_vec = tC_cAux.compose(make_layout(Int{}, Int{})); + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux_vec(coords...), residue_mn); }; + copy_if(FunctionPredTensor(predicate_fn), tC_rAux_vec, tC_gAux_vec); + } + // sub-byte vectorization, must serialize threads + else { + // Assumes no inter-warp sharing of bytes (most copy layouts should satisfy this) + int lane_idx = canonical_lane_idx(); + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(coords...), residue_mn); }; + CUTLASS_PRAGMA_NO_UNROLL + for (int i = 0; i < NumThreadsPerWarp; ++i) { + if (lane_idx == i) { + copy_if(FunctionPredTensor(predicate_fn), tC_rAux, tC_gAux); + } + __syncwarp(); + } + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator(params.ptr_aux)); + Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gAux, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tC_rAux = make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks(cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_mn, params); + } +}; +} // namespace detail + +// Specialization on the generic compute+aux EVT +template < + // Compute node + template class Activation, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + // Aux node + int Stages, + class EpilogueTile, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpR2S, + int Alignment, + bool EnableNullptr, + // Input node + class InputOp +> +struct Sm90TreeVisitor< + Sm90Compute, cutlass::epilogue::thread::ReLu>, void>>, + Sm90TreeVisitor< + Sm90AuxStore< + Stages, + EpilogueTile, + cutlass::uint1b_t, + RoundStyle, + StrideMNL, + SmemLayoutAtom, + CopyOpR2S, + Alignment, + EnableNullptr + >, + InputOp + > +> : Sm90VisitorImpl< + Sm90VisitorImpl< + InputOp, + detail::Sm90ReLUAuxStore + >, + Sm90Compute + > +{ + using Impl = + Sm90VisitorImpl< + Sm90VisitorImpl< + InputOp, + detail::Sm90ReLUAuxStore + >, + Sm90Compute + >; + + using Impl::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) { } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + auto& [callbacks_input, callbacks_relu_aux] = get<0>(CallbacksImpl::callbacks_tuple).callbacks_tuple; + + Array frg_input = callbacks_input.visit(frg_acc, epi_v, epi_m, epi_n); + return callbacks_relu_aux.visit(frg_acc, epi_v, epi_m, epi_n, frg_input); + } + }; + + 
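The `Sm90ReLUAuxStore::visit()` shown above records one predicate bit per element (whether the pre-activation value was non-negative) and packs the fragment of predicates into a bit tensor before it is copied to global memory. A host sketch of that per-fragment step, assuming byte-aligned fragments as the `static_assert` above requires:

```cpp
// Apply ReLU to a fragment and pack the dReLU/dZ predicates into a byte array.
#include <array>
#include <cstdint>

template <int FragmentSize>
std::array<std::uint8_t, FragmentSize / 8>
relu_with_aux(std::array<float, FragmentSize>& frg) {
  static_assert(FragmentSize % 8 == 0, "Predicate vector must be byte-aligned");
  std::array<std::uint8_t, FragmentSize / 8> bits{};
  for (int i = 0; i < FragmentSize; ++i) {
    bool keep = frg[i] >= 0.f;                            // dReLU/dZ at this element
    frg[i] = keep ? frg[i] : 0.f;                         // the ReLU output itself
    bits[i / 8] |= std::uint8_t(std::uint8_t(keep) << (i % 8));  // pack into the bitmap
  }
  return bits;
}
```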
template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks( + Impl::get_consumer_store_callbacks(args) + ); + } + +}; + +// Aux load for uint1b_t +template < + int Stages, + class EpilogueTile, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpS2R, + int Alignment, + bool EnableNullptr +> +struct Sm90AuxLoad< + Stages, + EpilogueTile, + cutlass::uint1b_t, + StrideMNL, + SmemLayoutAtom, + CopyOpS2R, + Alignment, + EnableNullptr +> { + static_assert(Alignment % 128 == 0, "sub-16B alignment not supported yet"); + + struct SharedStorage {}; + + struct Arguments { + cutlass::uint1b_t const* ptr_aux = nullptr; + cutlass::uint1b_t null_default = cutlass::uint1b_t(0); + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad() { } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad(Params const& params, SharedStorage const&) + : params(params) { } + + Params const params; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& tC_rAux_, GTensor&& tC_gAux_, ResidueMN residue_mn_, Params const& params_) + : tC_rAux(cute::forward(tC_rAux_)), + tC_gAux(cute::forward(tC_gAux_)), + residue_mn(residue_mn_), + params(params_) {} + + RTensor tC_rAux; // (CPY,CPY_M,CPY_N,{EPI_M,EPI_N}) + GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ResidueMN residue_mn; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if constexpr (decltype(rank(tC_rAux))::value == 5) { + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + return; + } + } + + if (elem_less(repeat_like(residue_mn, _0{}), residue_mn)) { // (partially) in-bounds CTA tile + copy(tC_gAux, tC_rAux); + } + } + } + + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + if constexpr (decltype(rank(tC_rAux))::value == 3) { + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + return; + } + } + + if (elem_less(repeat_like(residue_mn, _0{}), residue_mn)) { + copy(tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); + } + } + } + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + using ElementRegister = typename remove_cvref_t::value_type; + if constexpr (decltype(rank(tC_rAux))::value == 3) { + return recast>(coalesce(tC_rAux))(epi_v); + } + else { + return recast>(coalesce(tC_rAux(_,_,_,epi_m,epi_n)))(epi_v); + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator(params.ptr_aux)); + Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gAux, args.epi_tile, args.tiled_copy, args.thread_idx); + + // If byte-unaligned vectorization, store in registers as uint32_t to reduce redundant pack+unpack instruction sequences + constexpr int V = decltype(max_common_vector(tC_gAux.layout(), make_layout(tC_gAux.shape())))::value; + Tensor tC_rAux = [&] () { + if constexpr (V % 8 != 0) { + return make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) + } else { + return make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + } + }(); + + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + fill(tC_rAux, params.null_default); + } + } + + return ConsumerStoreCallbacks(cute::move(tC_rAux), cute::move(tC_gAux), args.residue_mn, params); + } +}; + +// dReLU specialization +template< + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle +> +struct Sm90Compute< + cutlass::epilogue::thread::dReLU, + ElementOutput, + ElementCompute, + RoundStyle +> : Sm90VisitorImpl<> { + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input, + Array const& frg_aux) { + using ConvertInput = NumericArrayConverter; + using ComputeOutput = cutlass::epilogue::thread::dReLU>; + using ConvertOutput = NumericArrayConverter; + ConvertInput convert_input{}; + ComputeOutput compute_output{}; + ConvertOutput convert_output{}; + + return convert_output(compute_output(convert_input(frg_input), frg_aux)); // don't convert frg_aux for dReLU + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks(); + } +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass::epilogue::fusion diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp index 348a62befe..e615447dbe 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -71,23 +71,10 @@ struct Sm90AccFetch : Sm90VisitorImpl<> { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor + class... 
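The `Sm90Compute<dReLU, ...>` specialization above consumes that bitmap in the backward pass: the packed aux fragment is passed through unconverted and gates the incoming gradient. A matching host sketch (again with illustrative names) pairs with the packing example earlier:

```cpp
// Backward step: the forward-pass bitmap masks the upstream gradient.
#include <array>
#include <cstdint>

template <int FragmentSize>
std::array<float, FragmentSize>
drelu_apply(std::array<float, FragmentSize> const& frg_dY,
            std::array<std::uint8_t, FragmentSize / 8> const& aux_bits) {
  std::array<float, FragmentSize> frg_dX{};
  for (int i = 0; i < FragmentSize; ++i) {
    bool keep = (aux_bits[i / 8] >> (i % 8)) & 1;   // bit stored by the forward ReLU
    frg_dX[i] = keep ? frg_dY[i] : 0.f;
  }
  return frg_dX;
}
```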
Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { - + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { return ConsumerStoreCallbacks{}; } }; @@ -131,24 +118,12 @@ struct Sm90SrcFetch : Sm90VisitorImpl<> { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor + class... Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - return ConsumerStoreCallbacks(tCrC); + return ConsumerStoreCallbacks(args.tCrC); } }; @@ -223,6 +198,18 @@ struct Sm90AuxLoad { return Params{tma_load_aux, args.null_default, use_default}; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + CUTLASS_HOST_DEVICE Sm90AuxLoad() { } @@ -277,33 +264,23 @@ struct Sm90AuxLoad { } }; - template < - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL - > + template CUTLASS_DEVICE auto - get_producer_load_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - int thread_idx) { + get_producer_load_callbacks(ProducerLoadArgs const& args) { - auto [M, N, K, L] = problem_shape_mnkl; - auto [m, n, k, l] = tile_coord_mnkl; + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; Tensor mAux = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor gAux = local_tile(mAux, take<0,2>(tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) - Tensor gAux_epi = local_tile(gAux, epi_tile, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gAux_epi = flat_divide(gAux, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) Tensor sAux_epi = make_tensor(make_smem_ptr(smem_aux), SmemLayout{}); // (EPI_TILE_M,EPI_TILE_N,PIPE) ThrCopy thrblk_g2s = params_ptr->tma_load_aux.get_slice(_0{}); Tensor bGS_gAux = thrblk_g2s.partition_S(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) Tensor bGS_sAux = thrblk_g2s.partition_D(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) - return ProducerLoadCallbacks( - cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr); + return ProducerLoadCallbacks(cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr); } template @@ -321,7 +298,7 @@ struct Sm90AuxLoad { Params const* params_ptr; CUTLASS_DEVICE void - step_begin(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { if constexpr (EnableNullptr) { if (params_ptr->use_default) { fill(tC_rAux, params_ptr->null_default); @@ -347,35 +324,24 @@ struct Sm90AuxLoad { template < bool ReferenceSrc, // do 
register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class TiledCopy, - class SrcTensor + class... Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto [M, N, K, L] = problem_shape_mnkl; + auto [M, N, K, L] = args.problem_shape_mnkl; Tensor mAux = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - mAux, tile_shape_mnk, tile_coord_mnkl, epi_tile, tiled_copy, thread_idx); + mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) auto tiled_s2r = conditional_return( - make_tiled_copy_S(Copy_Atom{}, tiled_copy), - make_tiled_copy_D(Copy_Atom{}, tiled_copy) + make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) ); Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) - auto tSR_sAux = tiled_s2r.get_slice(thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE) + auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE) return ConsumerStoreCallbacks(cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr); @@ -400,7 +366,7 @@ struct Sm90ScalarBroadcast { static_assert( (cute::is_same_v>) || // scalar broadcast, e.g. alpha (cute::is_same_v>) || // batched scalar broadcast, e.g. per-batch alpha - (cute::is_same_v>)); + (cute::is_same_v>)); struct SharedStorage { }; @@ -418,6 +384,18 @@ struct Sm90ScalarBroadcast { return args; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + CUTLASS_DEVICE bool is_producer_load_needed() const { return false; @@ -443,24 +421,14 @@ struct Sm90ScalarBroadcast { Element scalar; Params const* params_ptr; - template < - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile - > + template CUTLASS_DEVICE auto - get_producer_load_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - int thread_idx) { + get_producer_load_callbacks(ProducerLoadArgs const& args) { // Get the scalar for batched broadcast if constexpr ( - cute::is_same_v> || + cute::is_same_v> || cute::is_same_v>) { - auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; update_scalar(l_coord); } @@ -487,28 +455,16 @@ struct Sm90ScalarBroadcast { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor + class... 
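The recurring change in this file replaces the long callback parameter lists with single `ProducerLoadArgs` / `ConsumerStoreArgs` bundles, so each visitor names only the fields it uses and new fields (such as `tCcD` and `residue_mn`) can be added without touching every node. The field names below come from the diff; the struct layout and template parameters are an illustrative sketch, not the actual CUTLASS definition.

```cpp
// Sketch of the parameter-object refactor applied throughout this file.
template <class ProblemShape, class TileShape, class TileCoord, class EpilogueTile,
          class TiledCopy, class CoordTensor, class Residue, class SrcTensor>
struct ConsumerStoreArgs {
  ProblemShape problem_shape_mnkl;
  TileShape    tile_shape_mnk;
  TileCoord    tile_coord_mnkl;
  EpilogueTile epi_tile;
  TiledCopy    tiled_copy;
  int          thread_idx;
  CoordTensor  tCcD;        // coordinate tensor used for predication
  Residue      residue_mn;  // valid extent of this CTA tile
  SrcTensor    tCrC;        // source-fetch register tensor
};

// A visitor now takes the whole bundle and picks what it needs, e.g.:
//   template <bool ReferenceSrc, class... Args>
//   auto get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
//     return ConsumerStoreCallbacks(args.tCrC);   // Sm90SrcFetch only needs tCrC
//   }
```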
Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { // Get the scalar for batched broadcast if constexpr ( - cute::is_same_v> || + cute::is_same_v> || cute::is_same_v>) { - auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; update_scalar(l_coord); } @@ -579,6 +535,18 @@ struct Sm90RowBroadcast { return args; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + CUTLASS_HOST_DEVICE Sm90RowBroadcast() { } @@ -633,29 +601,19 @@ struct Sm90RowBroadcast { } }; - template < - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile - > + template CUTLASS_DEVICE auto - get_producer_load_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - int thread_idx) { + get_producer_load_callbacks(ProducerLoadArgs const& args) { - auto [M, N, K, L] = problem_shape_mnkl; - auto [m, n, k, l] = tile_coord_mnkl; + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); - Tensor gRow = local_tile(mRow, take<0,2>(tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); - constexpr int EpiTiles = size(shape_div(take<0,2>(tile_shape_mnk), epi_tile)); + constexpr int EpiTiles = decltype(size(shape_div(take<0,2>(args.tile_shape_mnk), args.epi_tile)))::value; return ProducerLoadCallbacks( cute::move(gRow), cute::move(sRow), params); } @@ -673,7 +631,7 @@ struct Sm90RowBroadcast { Params const& params; CUTLASS_DEVICE void - step_begin(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { if constexpr (EnableNullptr) { if (params.ptr_row == nullptr) { fill(tCrRow, params.null_default); @@ -704,31 +662,19 @@ struct Sm90RowBroadcast { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor + class... 
Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE) make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages), make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{}))); Tensor tCsRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE) - sRow, epi_tile, tiled_copy, thread_idx); + sRow, args.epi_tile, args.tiled_copy, args.thread_idx); Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N) - constexpr int EpiTiles = size(shape_div(take<0,2>(tile_shape_mnk), epi_tile)); + constexpr int EpiTiles = decltype(size(shape_div(take<0,2>(args.tile_shape_mnk), args.epi_tile)))::value; return ConsumerStoreCallbacks( cute::move(tCrRow), cute::move(tCsRow), params); } @@ -769,6 +715,18 @@ struct Sm90ColBroadcast { return args; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + CUTLASS_DEVICE bool is_producer_load_needed() const { return false; @@ -788,19 +746,9 @@ struct Sm90ColBroadcast { Params params; - template < - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile - > + template CUTLASS_DEVICE auto - get_producer_load_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - int thread_idx) { + get_producer_load_callbacks(ProducerLoadArgs const& args) { return EmptyProducerLoadCallbacks{}; } @@ -847,27 +795,15 @@ struct Sm90ColBroadcast { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor + class... 
Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto [M, N, K, L] = problem_shape_mnkl; + auto [M, N, K, L] = args.problem_shape_mnkl; Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - mCol, tile_shape_mnk, tile_coord_mnkl, epi_tile, tiled_copy, thread_idx); + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) return ConsumerStoreCallbacks( diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp index 8e1ffb08a2..374309ee29 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -36,6 +36,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/workspace.h" #include "cute/tensor.hpp" #include "sm90_visitor_tma_warpspecialized.hpp" @@ -122,6 +123,18 @@ struct Sm90AuxStore { return {tma_store_aux, is_nullptr}; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + return cutlass::Status::kSuccess; + } + CUTLASS_HOST_DEVICE Sm90AuxStore() { } @@ -143,18 +156,9 @@ struct Sm90AuxStore { return false; } - template < - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL - > + template CUTLASS_DEVICE auto - get_producer_load_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - int thread_idx) { + get_producer_load_callbacks(ProducerLoadArgs const& args) { return EmptyProducerLoadCallbacks{}; } @@ -202,7 +206,7 @@ struct Sm90AuxStore { } CUTLASS_DEVICE void - step_next(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + postvisit(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { if constexpr (EnableNullptr) { if (params_ptr->is_nullptr) { return; @@ -219,7 +223,7 @@ struct Sm90AuxStore { } CUTLASS_DEVICE void - step_end(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + step(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { if constexpr (EnableNullptr) { if (params_ptr->is_nullptr) { return; @@ -236,40 +240,29 @@ struct Sm90AuxStore { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class TiledCopy, - class SrcTensor + class... 
Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto [M, N, K, L] = problem_shape_mnkl; - auto [m, n, k, l] = tile_coord_mnkl; + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; Tensor mAux = params_ptr->tma_store_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) - Tensor gAux = local_tile(mAux, take<0,2>(tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - gAux, epi_tile, tiled_copy, thread_idx); + gAux, args.epi_tile, args.tiled_copy, args.thread_idx); Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) - Tensor gAux_epi = local_tile(gAux, epi_tile, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor gAux_epi = flat_divide(gAux, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) auto tiled_r2s = conditional_return( - make_tiled_copy_S(Copy_Atom{}, tiled_copy), - make_tiled_copy_D(Copy_Atom{}, tiled_copy) + make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) ); - auto tRS_sAux = tiled_r2s.get_slice(thread_idx).partition_D(sAux_epi); // (R2S,R2S_M,R2S_N,PIPE) + auto tRS_sAux = tiled_r2s.get_slice(args.thread_idx).partition_D(sAux_epi); // (R2S,R2S_M,R2S_N,PIPE) ThrCopy thrblk_s2g = params_ptr->tma_store_aux.get_slice(_0{}); Tensor bSG_sAux = thrblk_s2g.partition_S(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) @@ -294,7 +287,7 @@ struct Sm90AuxStore { // Scalar reduction template < template class RegReduceFn, - template class AtomicReduceFn, + template class GmemReduceFn, class ElementOutput, class ElementCompute, FloatRoundStyle RoundStyle, @@ -302,10 +295,15 @@ template < bool EnableNullptr = true // Noop on nullptr params > struct Sm90ScalarReduction { +private: static_assert( (cute::is_same_v>) || // scalar reduction, e.g. tensor max element (cute::is_same_v>) || // batched scalar reduction, e.g. 
per-batch max element - (cute::is_same_v>)); + (cute::is_same_v>)); + static constexpr bool IsAtomic = is_atomic>::value; + static_assert(IsAtomic, "non-atomic scalar reduction not supported yet"); + +public: struct SharedStorage { }; struct Arguments { @@ -322,6 +320,26 @@ struct Sm90ScalarReduction { return args; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + if constexpr (IsAtomic) { + auto [M, N, K, L] = problem_shape; + Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar); + if (args.ptr_scalar != nullptr) { + return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream); + } + } + + return cutlass::Status::kSuccess; + } + CUTLASS_DEVICE bool is_producer_load_needed() const { return false; @@ -341,19 +359,9 @@ struct Sm90ScalarReduction { Params const params; - template < - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile - > + template CUTLASS_DEVICE auto - get_producer_load_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - int thread_idx) { + get_producer_load_callbacks(ProducerLoadArgs const& args) { return EmptyProducerLoadCallbacks{}; } @@ -362,18 +370,18 @@ struct Sm90ScalarReduction { CUTLASS_DEVICE ConsumerStoreCallbacks( int l_coord, - CTensor&& tCcScalar, + CTensor tCcScalar, ResidueMN residue_mn, Params const& params) : scalar(params.reduction_identity), l_coord(l_coord), - tCcScalar(cute::forward(tCcScalar)), + tCcScalar(tCcScalar), residue_mn(residue_mn), params(params) {} ElementCompute scalar; int l_coord; - CTensor tCcScalar; + CTensor tCcScalar; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) ResidueMN residue_mn; Params params; @@ -393,10 +401,11 @@ struct Sm90ScalarReduction { ReduceInput reduce_input{}; Array frg_I = convert_input(frg_input); + Tensor tCcScalar_mn = tCcScalar(_,_,_,epi_m,epi_n); CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { - if (elem_less(tCcScalar(epi_v * FragmentSize + i), residue_mn)) { + if (elem_less(tCcScalar_mn(epi_v * FragmentSize + i), residue_mn)) { scalar = reduce_input(scalar, frg_I[i]); } } @@ -413,7 +422,7 @@ struct Sm90ScalarReduction { } using ConvertI = NumericConverter; - using ReduceInput = AtomicReduceFn; + using ReduceInput = GmemReduceFn; ConvertI convert_I{}; ReduceInput reduce_input{}; @@ -426,36 +435,12 @@ struct Sm90ScalarReduction { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor + class... 
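The scalar reduction above relies on two things that the new host-side hooks establish: `initialize_workspace()` fills the destination with the reduction identity before any CTA runs, and each fragment element is predicated with `elem_less(..., residue_mn)` so out-of-bounds positions never contribute. A sequential host sketch of that structure (one loop iteration standing in for one CTA, `std::max` standing in for the atomic gmem combine):

```cpp
// Why the output must start at the reduction identity: every tile folds its
// partial result into the same destination, so stale data would corrupt it.
#include <algorithm>
#include <limits>
#include <vector>

float scalar_max_reduction(std::vector<float> const& data, int tile_size) {
  float identity = -std::numeric_limits<float>::infinity();  // identity for max
  float result = identity;                                   // what fill_workspace establishes
  int num_tiles = (int(data.size()) + tile_size - 1) / tile_size;
  for (int t = 0; t < num_tiles; ++t) {                      // one iteration per "CTA"
    float partial = identity;
    int begin = t * tile_size;
    int end = std::min<int>(begin + tile_size, int(data.size()));  // residue predication
    for (int i = begin; i < end; ++i) {
      partial = std::max(partial, data[i]);
    }
    result = std::max(result, partial);                      // the atomic gmem combine step
  }
  return result;
}
```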
Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { - - int l_coord = static_cast(get<3>(tile_coord_mnkl)); - - // Compute tile residues and coordinate tensors for predication - auto [M, N, K, L] = problem_shape_mnkl; - auto [m, n, k, l] = tile_coord_mnkl; - auto residue_mn = make_coord( - M - static_cast(m) * size<0>(tile_shape_mnk), - N - static_cast(n) * size<1>(tile_shape_mnk) - ); - Tensor cScalar = make_identity_tensor(take<0,2>(tile_shape_mnk)); - Tensor tCcScalar = sm90_partition_for_epilogue(cScalar, epi_tile, tiled_copy, thread_idx); - - return ConsumerStoreCallbacks(l_coord, cute::move(tCcScalar), residue_mn, params); + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks( + get<3>(args.tile_coord_mnkl), args.tCcD, args.residue_mn, params); } }; @@ -466,7 +451,7 @@ struct Sm90ScalarReduction { // Row vector reduction template < template class RegReduceFn, - template class AtomicReduceFn, + template class GmemReduceFn, int Stages, class CtaTileShapeMNK, class ElementOutput, @@ -477,12 +462,16 @@ template < bool EnableNullptr = true // Noop on nullptr params > struct Sm90RowReduction { +private: static_assert(Stages == 0, "Smem usage not supported yet"); static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); static_assert( (cute::is_same_v>) || // row vector reduction, e.g. per-col sum over all batches (cute::is_same_v>)); // batched row vector reduction, e.g. per-col sum per batch + static constexpr bool IsAtomic = is_atomic>::value; + static_assert(IsAtomic, "non-atomic row reduction not supported yet"); +public: struct SharedStorage { }; struct Arguments { @@ -499,6 +488,26 @@ struct Sm90RowReduction { return args; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + if constexpr (IsAtomic) { + auto [M, N, K, L] = problem_shape; + Layout mRow_layout = make_layout(make_shape(M,N,L), args.dRow); + if (args.ptr_row != nullptr) { + return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream); + } + } + + return cutlass::Status::kSuccess; + } + CUTLASS_DEVICE bool is_producer_load_needed() const { return false; @@ -518,19 +527,9 @@ struct Sm90RowReduction { Params params; - template < - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile - > + template CUTLASS_DEVICE auto - get_producer_load_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - int thread_idx) { + get_producer_load_callbacks(ProducerLoadArgs const& args) { return EmptyProducerLoadCallbacks{}; } @@ -540,12 +539,12 @@ struct Sm90RowReduction { ConsumerStoreCallbacks( RTensor&& tCrRow, GTensor&& tCgRow, - CTensor&& tCcRow, + CTensor tCcRow, ResidueMN residue_mn, Params const& params) : tCrRow(cute::forward(tCrRow)), tCgRow(cute::forward(tCgRow)), - tCcRow(cute::forward(tCcRow)), + tCcRow(tCcRow), residue_mn(residue_mn), params(params) {} @@ -580,7 +579,7 @@ struct Sm90RowReduction { CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { - 
if (elem_less(tCcRow_mn(i), residue_mn)) { + if (elem_less(tCcRow_mn(epi_v * FragmentSize + i), residue_mn)) { ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); tCrRow_vmn = reduce_input(tCrRow_vmn, frg_I[i]); } @@ -590,7 +589,7 @@ struct Sm90RowReduction { } CUTLASS_DEVICE void - step_end(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + step(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { if constexpr (EnableNullptr) { if (params.ptr_row == nullptr) { return; @@ -599,7 +598,7 @@ struct Sm90RowReduction { if (epi_m == size<3>(tCrRow)-1) { // assumes M-major subtile loop using ConvertI = NumericConverter; - using ReduceInput = AtomicReduceFn; + using ReduceInput = GmemReduceFn; ConvertI convert_I{}; ReduceInput reduce_input{}; @@ -616,7 +615,7 @@ struct Sm90RowReduction { for (int i = 0; i < size(tCrRow_flt); ++i) { // partially OOB in M must still issue gmem reduction, so only consider residue_n // in case last epi tile in column is fully OOB in M and CTA tile is partially OOB in M - if (residue_n > get<1>(tCcRow_flt(i)) && + if (residue_n > get<1>(tCcRow_flt(i)) && // fully OOB in M does not need to issue gmem reduction, skip residue_m > 0) { reduce_input(&tCgRow_flt(i), convert_I(tCrRow_flt(i))); @@ -632,41 +631,20 @@ struct Sm90RowReduction { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor + class... Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { - auto [M, N, K, L] = problem_shape_mnkl; + auto [M, N, K, L] = args.problem_shape_mnkl; Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); Tensor tCgRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - mRow, tile_shape_mnk, tile_coord_mnkl, epi_tile, tiled_copy, thread_idx); + mRow, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); Tensor tCrRow = make_tensor_like(tCgRow(_,_,_,_,_0{})); // (CPY,CPY_M,CPY_N,EPI_M) fill(tCrRow, params.reduction_identity); - // Compute tile residues and coordinate tensors for predication - auto [m, n, k, l] = tile_coord_mnkl; - auto residue_mn = make_coord( - M - static_cast(m) * size<0>(tile_shape_mnk), - N - static_cast(n) * size<1>(tile_shape_mnk) - ); - Tensor cRow = make_identity_tensor(take<0,2>(tile_shape_mnk)); - Tensor tCcRow = sm90_partition_for_epilogue(cRow, epi_tile, tiled_copy, thread_idx); - - return ConsumerStoreCallbacks( - cute::move(tCrRow), cute::move(tCgRow), cute::move(tCcRow), residue_mn, params); + return ConsumerStoreCallbacks( + cute::move(tCrRow), cute::move(tCgRow), args.tCcD, args.residue_mn, params); } }; @@ -675,7 +653,7 @@ struct Sm90RowReduction { // Col vector reduction template < template class RegReduceFn, - template class AtomicReduceFn, + template class GmemReduceFn, int Stages, class CtaTileShapeMNK, class ElementOutput, @@ -683,29 +661,110 @@ template < FloatRoundStyle RoundStyle, class StrideMNL = Stride<_1,_0,_0>, int Alignment = 128 / sizeof_bits_v, - bool EnableNullptr = true // Noop on nullptr params + bool EnableNullptr = true, // Noop on nullptr params + // If this 
is false, ptr_col is assumed to point to a compact m-major (round_nearest(M,CTA_M), ceil_div(N,CTA_N), L) + // tensor of ElementCompute. It is the user's responsibility to reduce this to a (M, L) tensor of ElementOutput + bool FinalReduction = true > struct Sm90ColReduction { +private: static_assert(Stages == 0, "Smem usage not supported yet"); static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); static_assert( (cute::is_same_v>) || // col vector reduction, e.g. per-row sum over all batches (cute::is_same_v>)); // batched col vector reduction, e.g. per-row sum per batch + static constexpr bool IsAtomic = is_atomic>::value; + static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); +public: struct SharedStorage { }; struct Arguments { - ElementOutput* ptr_col = nullptr; + void* ptr_col = nullptr; // ElementOutput* if FinalReduction, else ElementCompute* ElementCompute reduction_identity = 0; StrideMNL dCol = {}; }; - using Params = Arguments; + struct Params { + void* ptr_col = nullptr; + ElementCompute reduction_identity = 0; + StrideMNL dCol = {}; + ElementCompute* reduction_buffer = nullptr; + int* tile_counters = nullptr; + }; template static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { - return args; + ElementCompute* reduction_buffer; + int* tile_counters; + if constexpr (IsAtomic) { + reduction_buffer = nullptr; + tile_counters = nullptr; + } + else if constexpr (not FinalReduction) { + reduction_buffer = reinterpret_cast(args.ptr_col); + tile_counters = nullptr; + } + else { + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, sizeof(int)); + + reduction_buffer = reinterpret_cast(workspace); + tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + } + + return { + args.ptr_col, + args.reduction_identity, + args.dCol, + reduction_buffer, + tile_counters + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + if constexpr (IsAtomic || not FinalReduction) { + return 0; + } + + size_t workspace_size = 0; + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + + // Increment by size of reduction buffer + workspace_size += product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); + // Align and increment by size of tile counters + workspace_size = round_nearest(workspace_size, sizeof(int)); + workspace_size += cute::ceil_div(M, tile_M) * sizeof(int); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + if constexpr (IsAtomic) { + auto [M, N, K, L] = problem_shape; + Layout mCol_layout = make_layout(make_shape(M,N,L), args.dCol); + if (args.ptr_col != nullptr) { + return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream); + } + return Status::kSuccess; + } + + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * 
sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, sizeof(int)); + + int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + size_t tile_counters_size = cute::ceil_div(M, tile_M) * sizeof(int); + return zero_workspace(tile_counters, tile_counters_size, stream); } CUTLASS_DEVICE bool @@ -727,66 +786,48 @@ struct Sm90ColReduction { Params params; - template < - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile - > + template CUTLASS_DEVICE auto - get_producer_load_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - int thread_idx) { + get_producer_load_callbacks(ProducerLoadArgs const& args) { return EmptyProducerLoadCallbacks{}; } - template + template struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { CUTLASS_DEVICE - ConsumerStoreCallbacks( - RTensor&& tCrCol, - GTensor&& tCgCol, - CTensor&& tCcCol, - ResidueMN residue_mn, - Params const& params) - : tCrCol(cute::forward(tCrCol)), - tCgCol(cute::forward(tCgCol)), - tCcCol(cute::forward(tCcCol)), - residue_mn(residue_mn), + ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) + : args_tuple(cute::forward(args_tuple)), params(params) {} - RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - ResidueMN residue_mn; + ArgsTuple args_tuple; Params const& params; + bool do_final_reduction = false; template CUTLASS_DEVICE auto visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, Array const& frg_input) { - if constexpr (EnableNullptr) { if (params.ptr_col == nullptr) { return frg_input; } } + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); + using ConvertInput = NumericArrayConverter; using ReduceInput = RegReduceFn; ConvertInput convert_input{}; ReduceInput reduce_input{}; Array frg_I = convert_input(frg_input); - Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); - Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); - CUTLASS_PRAGMA_UNROLL for (int i = 0; i < FragmentSize; ++i) { - if (elem_less(tCcCol_mn(i), residue_mn)) { + if (elem_less(tCcCol_mn(epi_v * FragmentSize + i), residue_mn)) { ElementCompute& tCrCol_vmn = tCrCol_mn(epi_v * FragmentSize + i); tCrCol_vmn = reduce_input(tCrCol_vmn, frg_I[i]); } @@ -795,71 +836,283 @@ struct Sm90ColReduction { return frg_input; } + template CUTLASS_DEVICE void - end() { + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration) { + if (not is_last_iteration) { + return; + } + + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + auto [m, n, k, l] = tile_coord_mnkl; + constexpr bool ReferenceSrc = decltype(ref_src)::value; + + // Runtime nullptr is noop if constexpr (EnableNullptr) { if (params.ptr_col == nullptr) { return; } } - using ConvertI = NumericConverter; - using ReduceInput = AtomicReduceFn; + // fully OOB CTA in partially OOB cluster + if (not elem_less(cCol(_0{},_0{}), residue_mn)) { + return; + } - ConvertI convert_I{}; - ReduceInput 
reduce_input{}; + // + // 1. Warp shuffle reduction + // + using FragmentShuffle = Array; + using ReduceShuffle = RegReduceFn; + ReduceShuffle reduce_shuffle{}; + Tensor tCrCol_frg = recast(filter(tCrCol)); + CUTLASS_PRAGMA_UNROLL + for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int frg_idx = 0; frg_idx < size(tCrCol_frg); ++frg_idx) { + uint64_t frg_shfl = reinterpret_cast(tCrCol_frg(frg_idx)); + frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(_0{},reduction_cols)); + tCrCol_frg(frg_idx) = reduce_shuffle(tCrCol_frg(frg_idx), reinterpret_cast(frg_shfl)); + } + } + bool is_reduced_lane = get<1>(lane_mn) == 0; - // Filter so we don't issue redunant copies over stride-0 modes - Tensor tCrCol_flt = filter_zeros(tCrCol); - Tensor tCgCol_flt = filter_zeros(tCgCol); - Tensor tCcCol_flt = make_tensor(tCcCol.data(), make_layout(tCgCol_flt.shape(), tCcCol.stride())); + // + // 2. Atomic reduction + // + if constexpr (IsAtomic) { + // Filter so we don't issue redunant copies over stride-0 modes + Tensor tCrCol_flt = filter_zeros(tCrCol); + Tensor tCcCol_flt = make_tensor(tCcCol.data(), make_layout(tCrCol_flt.shape(), tCcCol.stride())); + + Tensor tCgCol = sm90_partition_for_epilogue(gCol_l(_,_,l), epi_tile, tiled_copy, thread_idx); + Tensor tCgCol_flt = filter_zeros(tCgCol); + + // NOTE: atomic reduction is performed in the output type + using ConvertOutput = NumericConverter; + using ReduceOutput = GmemReduceFn; + ConvertOutput convert_output{}; + ReduceOutput reduce_output{}; + + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrCol_flt); ++i) { + if (elem_less(tCcCol_flt(i), residue_mn)) { + reduce_output(&tCgCol_flt(i), convert_output(tCrCol_flt(i))); + } + } + } + sync_fn(); + } - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < size(tCrCol_flt); ++i) { - if (elem_less(tCcCol_flt(i), residue_mn)) { - reduce_input(&tCgCol_flt(i), convert_I(tCrCol_flt(i))); + // + // 2. One warp in N, skip threadblock smem reduction + // + else if constexpr (decltype(size<1>(warp_layout_MN))::value <= 1) { + // Dump warp reduction to gmem workspace + using ElementGmem = conditional_t; + Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_nl(_,_,n,l), epi_tile, tiled_copy, thread_idx); + if (is_reduced_lane) { + // Filter so we don't issue redunant copies over stride-0 modes + copy(filter(tCrCol), recast(filter(tCgBuf))); } + sync_fn(); + } + + // + // 2. 
Multiple warps in N, do threadblock smem reduction + // + else { + Tensor sBuf = make_tensor(make_smem_ptr(raw_pointer_cast(smem_buffer.data())), sBuf_layout); + static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= + decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), + "smem reduction buffer not large enough, use a larger epilogue tile"); + + // Dump warp reduction to smem workspace + Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<1>(warp_mn)), epi_tile, tiled_copy, thread_idx); + if (is_reduced_lane) { + // Filter so we don't issue redundant copies over stride-0 modes + copy(filter(tCrCol), filter(tCsBuf)); + } + sync_fn(); + + constexpr int SmemFragSize = cute::max(1, sizeof(uint32_t) / sizeof(ElementCompute)); + using FragmentSmem = Array; + using VectorSmem = uint_bit_t>; + using ReduceSmem = GmemReduceFn; + ReduceSmem reduce_smem{}; + + Tensor sBuf_frg = recast(filter_zeros(sBuf)); + Tensor sBuf_vec = recast(filter_zeros(sBuf)); + constexpr int FragsPerCol = decltype(size<0>(sBuf_frg))::value; + + // Do the threadblock smem reduction + CUTLASS_PRAGMA_UNROLL + for (int reduction_cols = size<1>(warp_layout_MN) / 2; reduction_cols > 1; reduction_cols /= 2) { + int FragsPerReduction = reduction_cols * FragsPerCol; + CUTLASS_PRAGMA_NO_UNROLL + for (int frg_idx = thread_idx; frg_idx < FragsPerReduction; frg_idx += size(tiled_copy)) { + FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerReduction)); + sBuf_vec(frg_idx) = reinterpret_cast(frg_smem); + } + sync_fn(); + } + + // Do final smem reduction and dump to gmem workspace + using VectorGmem = conditional_t; + Tensor gBuf_vec = recast(filter(gBuf_nl(_,_,n,l))); + CUTLASS_PRAGMA_NO_UNROLL + for (int frg_idx = thread_idx; frg_idx < FragsPerCol; frg_idx += size(tiled_copy)) { + FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerCol)); + gBuf_vec(frg_idx) = reinterpret_cast(frg_smem); + } + sync_fn(); + } + + // + // 3. Increment atomic counters to signal final gmem reduction + // + if constexpr (not IsAtomic && FinalReduction) { + // Ensure gmem writes are visible to other threads before incrementing counter + __threadfence(); + sync_fn(); + // Collective thread 0 increments atomic tile counter and copies value to smem + int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); + if (thread_idx == 0) { + *prev_tile_count = atomicAdd(&params.tile_counters[m], 1); + } + sync_fn(); + // Broadcast tile count to other threads in CTA and determine final reduction status + do_final_reduction = *prev_tile_count == size<2>(gBuf_nl) * size<3>(gBuf_nl) - 1; + sync_fn(); + } + } + + CUTLASS_DEVICE void + end() { + // + // 4. 
Do final gmem reduction if necessary + // + if constexpr (not IsAtomic && FinalReduction) { + if (not do_final_reduction) { + return; + } + + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + + using ReduceOutput = GmemReduceFn; + using ConvertOutput = NumericConverter; + ReduceOutput reduce_output{}; + ConvertOutput convert_output{}; + + // Reduction over batches + if (size<2>(stride(gCol_l)) == 0) { + CUTLASS_PRAGMA_NO_UNROLL + for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { + Tensor tRgBuf_nl = gBuf_nl(m,_0{},_,_); + ElementCompute output = tRgBuf_nl(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int nl = 1; nl < size(tRgBuf_nl); ++nl) { + output = reduce_output(output, tRgBuf_nl(nl)); + } + if (elem_less(cCol(m,_0{}), residue_mn)) { + gCol_l(m,_0{},_0{}) = convert_output(output); + } + } + } + // No reduction over batches + else { + CUTLASS_PRAGMA_NO_UNROLL + for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { + bool do_store = elem_less(cCol(m,_0{}), residue_mn); + CUTLASS_PRAGMA_NO_UNROLL + for (int l = 0; l < size<3>(gBuf_nl); ++l) { + Tensor tRgBuf_n = gBuf_nl(m,_0{},_,l); + ElementCompute output = tRgBuf_n(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int n = 1; n < size(tRgBuf_n); ++n) { + output = reduce_output(output, tRgBuf_n(n)); + } + if (do_store) { + gCol_l(m,_0{},l) = convert_output(output); + } + } + } + } + } } + CUTLASS_DEVICE bool + is_reduction_buffer_needed(int epi_m, int epi_n, bool is_last_iteration) const { + auto const& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_mn, epi_tile, tiled_copy, thread_idx] = args_tuple; + + return (not IsAtomic && // atomic reduction doesn't use smem + is_last_iteration && // smem reduction happens after epilogue loop + (decltype(size<1>(warp_layout_MN))::value > 1 || // smem reduction happens when multiple warps are in N + FinalReduction)); // smem is used to broadcast tile counters for final reduction + } + }; template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor + class... 
Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { - - auto [M, N, K, L] = problem_shape_mnkl; - Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + Layout ref_layout_MN = [&] () { + if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } + else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + }(); // tile_mn -> tv_idx + + // Get the MN layout + coord of lanes to determine shuffle reduction iterations + using _W = Int; + Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx + Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx + Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx + Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn + int lane_idx = canonical_lane_idx(); + auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); + + // Get the MN layout + coord of warps to determine smem reduction iterations + Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx + Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx + Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx + Layout inv_warp_layout_MN = right_inverse(warp_layout_MN); // warp_idx -> warp_mn + int warp_idx = args.thread_idx / NumThreadsPerWarp; + auto warp_mn = idx2crd(inv_warp_layout_MN(warp_idx), shape(warp_layout_MN)); + + // Partition output gmem and register tensors + auto [tile_M, tile_N, tile_K] = args.tile_shape_mnk; + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); // (M,N,L) + Tensor gCol_l = local_tile(mCol, take<0,2>(args.tile_shape_mnk), make_coord(m,n,_)); // (CTA_M,CTA_N,L) Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) - mCol, tile_shape_mnk, tile_coord_mnkl, epi_tile, tiled_copy, thread_idx); + gCol_l(_,_,l), args.epi_tile, args.tiled_copy, args.thread_idx); Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) fill(tCrCol, params.reduction_identity); - // Compute tile residues and coordinate tensors for predication - auto [m, n, k, l] = tile_coord_mnkl; - auto residue_mn = make_coord( - M - static_cast(m) * size<0>(tile_shape_mnk), - N - static_cast(n) * size<1>(tile_shape_mnk) - ); - Tensor cCol = make_identity_tensor(take<0,2>(tile_shape_mnk)); - Tensor tCcCol = sm90_partition_for_epilogue(cCol, epi_tile, tiled_copy, thread_idx); + // Partition gmem+smem reduction buffer tensors + Layout gBuf_layout = make_layout(take<0,2>(args.tile_shape_mnk), make_stride(_1{}, _0{})); + Layout mBuf_layout = blocked_product(gBuf_layout, make_layout(ceil_div(make_shape(M,N,L), shape(gBuf_layout)))); + Tensor mBuf = make_tensor(make_gmem_ptr(params.reduction_buffer), mBuf_layout); // (ceil_M,ceil_N,L) + Tensor gBuf_nl = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(m,_,_)); // (CTA_M,CTA_N,REST_N,L) + Layout sBuf_layout = blocked_product(gBuf_layout,make_layout(make_shape(_1{},_1{},size<1>(warp_layout_MN)))); // (CTA_M,CTA_N,WARPS_N) - return 
ConsumerStoreCallbacks(cute::move(tCrCol), cute::move(tCgCol), cute::move(tCcCol), residue_mn, params); + return ConsumerStoreCallbacks( + make_tuple(bool_constant{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + args.tile_coord_mnkl, args.residue_mn, args.epi_tile, args.tiled_copy, args.thread_idx), + params + ); } }; diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp index 85b69333d6..bca7d1cedd 100644 --- a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -37,6 +37,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/workspace.h" #include "cute/tensor.hpp" @@ -71,7 +72,7 @@ sm90_partition_for_epilogue( TiledCopy tiled_copy, int thread_idx) { ThrCopy thread_copy = tiled_copy.get_thread_slice(thread_idx); - Tensor cT_epi = local_tile(cT, epi_tile, _); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,...) + Tensor cT_epi = flat_divide(cT, epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,...) if constexpr (ReferenceSrc) { return thread_copy.partition_S(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) } @@ -111,6 +112,84 @@ sm90_partition_for_epilogue( // ///////////////////////////////////////////////////////////////////////////////////////////////// +template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class ResidueMN, + class EpilogueTile +> +struct ProducerLoadArgs { + ProblemShapeMNKL problem_shape_mnkl; + TileShapeMNK tile_shape_mnk; + TileCoordMNKL tile_coord_mnkl; + ResidueMN residue_mn; + EpilogueTile epi_tile; + int thread_idx; + + CUTLASS_DEVICE + ProducerLoadArgs( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + ResidueMN residue_mn, + EpilogueTile epi_tile, + int thread_idx) + : problem_shape_mnkl(problem_shape_mnkl), + tile_shape_mnk(tile_shape_mnk), + tile_coord_mnkl(tile_coord_mnkl), + residue_mn(residue_mn), + epi_tile(epi_tile), + thread_idx(thread_idx) {} +}; + +template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class ResidueMN, + class EpilogueTile, + class TiledCopy, + class CoordTensor, + class ThrCoordTensor, + class ThrSrcTensor +> +struct ConsumerStoreArgs { + ProblemShapeMNKL problem_shape_mnkl; + TileShapeMNK tile_shape_mnk; + TileCoordMNKL tile_coord_mnkl; + ResidueMN residue_mn; + EpilogueTile epi_tile; + TiledCopy tiled_copy; + int thread_idx; + CoordTensor cD; + ThrCoordTensor tCcD; + ThrSrcTensor const& tCrC; + + CUTLASS_DEVICE + ConsumerStoreArgs( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + ResidueMN residue_mn, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + int thread_idx, + CoordTensor cD, + ThrCoordTensor tCcD, + ThrSrcTensor const& tCrC) + : problem_shape_mnkl(problem_shape_mnkl), + tile_shape_mnk(tile_shape_mnk), + tile_coord_mnkl(tile_coord_mnkl), + residue_mn(residue_mn), + epi_tile(epi_tile), + tiled_copy(tiled_copy), + thread_idx(thread_idx), + cD(cD), + tCcD(tCcD), + tCrC(tCrC) {} +}; + template struct Sm90VisitorImplBase { // Shared memory allocation @@ -132,6 +211,46 @@ struct Sm90VisitorImplBase { ); } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return transform_apply(tuple{}, args, + [&] (auto&& op, auto const& op_args) { + using 
Op = cute::remove_cvref_t; + size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); + return round_nearest(op_workspace_size, MinWorkspaceAlignment); + }, + [&] (auto&&... op_workspace_size) { + return (0 + ... + op_workspace_size); + } + ); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + Status status = Status::kSuccess; + uint8_t* op_workspace = reinterpret_cast(workspace); + return transform_apply(tuple{}, args, + // Initialize each operation's workspace, stopping at the first error + [&] (auto&& op, auto const& op_args) { + if (status != Status::kSuccess) { + return status; + } + + using Op = cute::remove_cvref_t; + status = Op::initialize_workspace(problem_shape, op_args, op_workspace, stream); + if (op_workspace != nullptr) { + size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); + op_workspace += round_nearest(op_workspace_size, MinWorkspaceAlignment); + } + return status; + }, + // Return the final status + [&] (auto const&...) { return status; } + ); + } + CUTLASS_HOST_DEVICE Sm90VisitorImplBase() {} @@ -167,13 +286,11 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { // e.g. for batched beta this must always be true regardless of current batch idx CUTLASS_DEVICE bool is_producer_load_needed() const { - bool needed = false; - for_each(ops, - [&] (auto const& op) { - needed |= op.is_producer_load_needed(); + return apply(ops, + [] (auto const&... op) { + return (false || ... || op.is_producer_load_needed()); } ); - return needed; } // Is a producer TMA load specifically for C needed @@ -183,13 +300,11 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { // e.g. for batched beta this can be false depending on current batch idx CUTLASS_DEVICE bool is_C_load_needed() const { - bool needed = false; - for_each(ops, - [&] (auto const& op) { - needed |= op.is_C_load_needed(); + return apply(ops, + [] (auto const&... op) { + return (false || ... || op.is_C_load_needed()); } ); - return needed; } // @@ -241,28 +356,12 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { // Producer load callbacks factory // All operations must redefine this, but most can just dispatch to the base impl - template < - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile - > + template CUTLASS_DEVICE auto - get_producer_load_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - int thread_idx) { + get_producer_load_callbacks(ProducerLoadArgs const& args) { return transform_apply(ops, [&] (auto& op) { - return op.get_producer_load_callbacks( - problem_shape_mnkl, - tile_shape_mnk, - tile_coord_mnkl, - epi_tile, - thread_idx - ); + return op.get_producer_load_callbacks(args); }, [] (auto&&... callbacks) { auto callbacks_tuple = cute::make_tuple(callbacks...); @@ -293,10 +392,10 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { // Start of subtile store iteration. Smem broadcasts usually performed here. // Upon entry, all producer loads for this subtile are completed and visible. 
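// The hooks declared just below (previsit, visit, postvisit, step, and the new reduce)
// are invoked in a fixed order for every epilogue subtile. The standalone host-side
// sketch that follows only illustrates that ordering; it is not CUTLASS code, the names
// ToyCallbacks and run_epilogue are invented for the example, and in the real collective
// these calls are made from device code with smem fences and TMA commits between stages.
#include <cstdio>

struct ToyCallbacks {
  void begin()                 { std::printf("begin epilogue\n"); }
  // Start of a subtile: producer loads are visible, smem broadcasts happen here.
  void previsit(int m, int n)  { std::printf("  previsit  (%d,%d)\n", m, n); }
  // Per-fragment compute; stands in for the EVT visit() of each node.
  float visit(float acc)       { return acc; }
  // Before the smem async fence: smem stores for the D subtile.
  void postvisit(int m, int n) { std::printf("  postvisit (%d,%d)\n", m, n); }
  // Before the TMA store commit: auxiliary gmem stores.
  void step(int m, int n)      { std::printf("  step      (%d,%d)\n", m, n); }
  // After the TMA store commit: smem reductions, typically only on the last subtile.
  void reduce(int m, int n, bool is_last) {
    if (is_last) { std::printf("  reduce    (%d,%d) last subtile\n", m, n); }
  }
  // After the subtile loop: deferred work such as a final gmem reduction.
  void end()                   { std::printf("end epilogue\n"); }
};

// Sketch of the order in which a consumer-store loop drives the callbacks.
void run_epilogue(ToyCallbacks& cb, int epi_m_tiles, int epi_n_tiles) {
  cb.begin();
  for (int n = 0; n < epi_n_tiles; ++n) {
    for (int m = 0; m < epi_m_tiles; ++m) {
      bool is_last = (m == epi_m_tiles - 1) && (n == epi_n_tiles - 1);
      cb.previsit(m, n);
      (void) cb.visit(1.0f);
      cb.postvisit(m, n);
      cb.step(m, n);
      cb.reduce(m, n, is_last);
    }
  }
  cb.end();
}

int main() {
  ToyCallbacks cb;
  run_epilogue(cb, 2, 2);
  return 0;
}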
CUTLASS_DEVICE void - step_begin(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { for_each(callbacks_tuple, [&] (auto& callbacks) { - callbacks.step_begin(epi_m, epi_n, load_iteration, is_producer_load_needed); + callbacks.previsit(epi_m, epi_n, load_iteration, is_producer_load_needed); } ); } @@ -308,24 +407,49 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { Array const&... frg_inputs) // depends on the N-naryness of the op = delete; // Must be implemented for each operation - // After D smem store, before smem async fence. Smem reductions usually performed here. + // After visit call, before smem async fence. Smem stores usually performed here. // Upon exit, all smem stores for TMA must have been issued CUTLASS_DEVICE void - step_next(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + postvisit(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { for_each(callbacks_tuple, [&] (auto& callbacks) { - callbacks.step_next(epi_m, epi_n, store_iteration, issue_smem_store); + callbacks.postvisit(epi_m, epi_n, store_iteration, issue_smem_store); } ); } - // End of subtile iteration, before TMA store commit. Aux stores usually performed here + // After async fence, before TMA store commit. Aux stores usually performed here // Upon exit, all TMA stores for this subtile must have been issued CUTLASS_DEVICE void - step_end(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + step(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.step(epi_m, epi_n, store_iteration, issue_tma_store); + } + ); + } + + // After TMA store commit. Smem reductions usually performed here + // reduction_buffer is an arbitrary smem tensor that can be used for workspace + // It is each node's responsibility to assert that this buffer is sufficiently sized + // and to ensure that this buffer is no longer needed upon callback exit + // i.e. results are synchronized and no longer in the reduction buffer + template + CUTLASS_DEVICE void + reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration) { for_each(callbacks_tuple, [&] (auto& callbacks) { - callbacks.step_end(epi_m, epi_n, store_iteration, issue_tma_store); + callbacks.reduce(reduction_buffer, sync_fn, epi_m, epi_n, is_last_iteration); + } + ); + } + + // Collective can query this to determine whether a buffer needs to be freed for reduction + CUTLASS_DEVICE bool + is_reduction_buffer_needed(int epi_m, int epi_n, bool is_last_iteration) const { + return apply(callbacks_tuple, + [&] (auto const&... callbacks) { + return (false || ... || callbacks.is_reduction_buffer_needed(epi_m, epi_n, is_last_iteration)); } ); } @@ -345,33 +469,13 @@ struct Sm90VisitorImpl : Sm90VisitorImplBase { // All operations must redefine this template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class... 
Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { return transform_apply(ops, [&] (auto& op) { - return op.template get_consumer_store_callbacks( - problem_shape_mnkl, - tile_shape_mnk, - tile_coord_mnkl, - epi_tile, - tiled_copy, - thread_idx, - tCrC - ); + return op.template get_consumer_store_callbacks(args); }, [] (auto&&... callbacks) { auto callbacks_tuple = cute::make_tuple(callbacks...); @@ -430,33 +534,13 @@ struct Sm90TreeVisitor : Sm90VisitorImpl { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor + class... Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { return ConsumerStoreCallbacks( Sm90VisitorImpl:: - get_consumer_store_callbacks( - problem_shape_mnkl, - tile_shape_mnk, - tile_coord_mnkl, - epi_tile, - tiled_copy, - thread_idx, - tCrC - ) + get_consumer_store_callbacks(args) ); } @@ -502,33 +586,13 @@ struct Sm90SplitTreeVisitor : Sm90VisitorImpl CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { return ConsumerStoreCallbacks( Sm90VisitorImpl:: - get_consumer_store_callbacks( - problem_shape_mnkl, - tile_shape_mnk, - tile_coord_mnkl, - epi_tile, - tiled_copy, - thread_idx, - tCrC - ) + get_consumer_store_callbacks(args) ); } @@ -601,33 +665,13 @@ struct Sm90TopologicalVisitor : Sm90VisitorImpl { template < bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy - class ProblemShapeMNKL, - class TileShapeMNK, - class TileCoordMNKL, - class EpilogueTile, - class TiledCopy, - class SrcTensor + class... 
Args > CUTLASS_DEVICE auto - get_consumer_store_callbacks( - ProblemShapeMNKL problem_shape_mnkl, - TileShapeMNK tile_shape_mnk, - TileCoordMNKL tile_coord_mnkl, - EpilogueTile epi_tile, - TiledCopy tiled_copy, - int thread_idx, - SrcTensor const& tCrC) { + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { return ConsumerStoreCallbacks( Sm90VisitorImpl:: - get_consumer_store_callbacks( - problem_shape_mnkl, - tile_shape_mnk, - tile_coord_mnkl, - epi_tile, - tiled_copy, - thread_idx, - tCrC - ) + get_consumer_store_callbacks(args) ); } @@ -663,6 +707,33 @@ struct Sm90VisitorImplBase { }; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + CUTLASS_HOST_DEVICE Sm90VisitorImplBase() {} @@ -702,6 +773,43 @@ struct Sm90VisitorImplBase { }; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream); + workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + CUTLASS_HOST_DEVICE Sm90VisitorImplBase() {} @@ -746,6 +854,53 @@ struct Sm90VisitorImplBase { }; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_size = round_nearest(workspace_size, 
MinWorkspaceAlignment); + + workspace_size += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream); + workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op2::initialize_workspace(problem_shape, args.op_2, workspace_ptr + workspace_offset, stream); + workspace_offset += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + CUTLASS_HOST_DEVICE Sm90VisitorImplBase() {} @@ -795,6 +950,63 @@ struct Sm90VisitorImplBase { }; } + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op3::get_workspace_size(problem_shape, args.op_3); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream); + workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op2::initialize_workspace(problem_shape, args.op_2, workspace_ptr + workspace_offset, stream); + workspace_offset += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = 
Op3::initialize_workspace(problem_shape, args.op_3, workspace_ptr + workspace_offset, stream); + workspace_offset += Op3::get_workspace_size(problem_shape, args.op_3); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + CUTLASS_HOST_DEVICE Sm90VisitorImplBase() {} diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 1236e52e11..33f9aad849 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -171,8 +171,8 @@ struct ReLu> { template struct Clamp { struct Arguments { - T lower_bound = cutlass::platform::numeric_limits::min(); - T upper_bound = cutlass::platform::numeric_limits::max(); + T lower_bound = CUTLASS_STL_NAMESPACE::numeric_limits::min(); + T upper_bound = CUTLASS_STL_NAMESPACE::numeric_limits::max(); }; CUTLASS_HOST_DEVICE @@ -615,12 +615,13 @@ struct dGELU > { template struct dReLU { CUTLASS_HOST_DEVICE - T operator()(T const& d_t, bool d_relu) const { + T operator()(T d_t, bool d_relu) const { return d_relu ? d_t : T(0); } + template CUTLASS_HOST_DEVICE - T operator()(T const& d_t, uint1b_t d_relu) const { + T operator()(T d_t, U d_relu) const { return operator()(d_t, static_cast(d_relu)); } }; @@ -649,6 +650,20 @@ struct dReLU> { return operator()(d_t, preds); } + + template + CUTLASS_HOST_DEVICE + Array operator()(Array const& d_t, Array const& d_relu) const { + Array y; + dReLU relu_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + y[i] = relu_op(d_t[i], d_relu[i]); + } + + return y; + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h b/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h index a33b6ddf8c..754f7035a7 100644 --- a/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h +++ b/include/cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h @@ -44,6 +44,14 @@ namespace cutlass { namespace epilogue { namespace threadblock { +namespace detail { + +struct EVT2xBase { }; + +template +static constexpr bool is_2x_evt_v = platform::is_base_of::value; + +} // namespace detail //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -67,7 +75,8 @@ class EpilogueWithVisitorCallbacks : typename DefaultEpilogue::Shape, DefaultEpilogue::kPartitionsK, typename DefaultEpilogue::WarpMmaOperator, - typename DefaultEpilogue::AccumulatorFragmentIterator> + typename DefaultEpilogue::AccumulatorFragmentIterator>, + public detail::EVT2xBase { public: diff --git a/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp b/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp index 54845b2646..71398200ef 100644 --- a/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp +++ b/include/cutlass/epilogue/threadblock/fusion/visitor_load.hpp @@ -96,7 +96,7 @@ struct VisitorScalarBroadcast { (cute::is_same_v>) || // scalar broadcast, e.g. alpha (cute::is_same_v>) || (cute::is_same_v>)); // batched scalar broadcast, e.g. 
per-batch alpha - + struct SharedStorage { }; struct Arguments { @@ -132,12 +132,12 @@ struct VisitorScalarBroadcast { CUTLASS_DEVICE Callbacks(Element scalar) : scalar(scalar) {} - + Element scalar; template CUTLASS_DEVICE auto // returns an Array - visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, Array const& frg_acc) { Array frg_scalar; frg_scalar.fill(scalar); @@ -224,7 +224,7 @@ struct VisitorAuxLoad{ // Global load type static int constexpr vec_bits = ThreadMap::kElementsPerAccess * sizeof_bits::value; using VecType = uint_bit_t; - static int constexpr VecLength = sizeof(VecType) / sizeof(Element); + static int constexpr VecLength = sizeof(VecType) / sizeof(Element); CUTLASS_HOST_DEVICE VisitorAuxLoad() { } @@ -272,7 +272,7 @@ struct VisitorAuxLoad{ template CUTLASS_DEVICE auto // returns an Array - visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, Array const& frg_acc) { Tensor tC_rAux_frg = recast>(coalesce(tC_rAux(_,_,_,iter_idx%Stages))); return tC_rAux_frg(frg_idx); @@ -285,9 +285,9 @@ struct VisitorAuxLoad{ gemm::GemmCoord threadblock_tile_offset, int thread_idx, ProblemShape problem_shape - ) { + ) { Tensor mAux = make_tensor( - make_gmem_ptr(params_ptr->ptr_aux), + make_gmem_ptr(params_ptr->ptr_aux), problem_shape, params_ptr->dAux); // (M,N,L) // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER @@ -299,14 +299,14 @@ struct VisitorAuxLoad{ // Generate the pred tensor Tensor cAux = make_identity_tensor(mAux.shape()); - Tensor tC_cAux = local_partition( + Tensor tC_cAux = outer_partition( group_modes<3,6>(ThreadMap::partition(cAux, thread_idx, threadblock_tile_offset)), Shape>{}, (_0{}) ); return Callbacks< - decltype(tC_gAux), decltype(tC_rAux), + decltype(tC_gAux), decltype(tC_rAux), decltype(tC_cAux), ProblemShape>( cute::move(tC_gAux), cute::move(tC_rAux), @@ -354,7 +354,7 @@ struct VisitorRowBroadcast { CUTLASS_HOST_DEVICE VisitorRowBroadcast(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms) { } - + Params const* params_ptr; template @@ -372,7 +372,7 @@ struct VisitorRowBroadcast { tC_cRow(cute::forward(tC_cRow)), n(get<1>(problem_shape)), params_ptr(params_ptr) { } - + GTensor tC_gRow; RTensor tC_rRow; CTensor tC_cRow; @@ -394,7 +394,7 @@ struct VisitorRowBroadcast { template CUTLASS_DEVICE auto // returns an Array - visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, Array const& frg_acc) { Tensor rRow_frg = recast>(coalesce(tC_rRow)); return rRow_frg(column_idx); @@ -409,10 +409,10 @@ struct VisitorRowBroadcast { ProblemShape problem_shape ) { Tensor mRow = make_tensor( - make_gmem_ptr(params_ptr->ptr_row), + make_gmem_ptr(params_ptr->ptr_row), problem_shape, params_ptr->dRow); - + // VECTOR, FRAGMENT_COLUMN Tensor tC_gRow = recast( ThreadMap::partition(mRow, thread_idx, threadblock_tile_offset) @@ -421,14 +421,14 @@ struct VisitorRowBroadcast { // Generate the pred tensor Tensor cRow = make_identity_tensor(mRow.shape()); - Tensor tC_cRow = local_partition( + Tensor tC_cRow = outer_partition( ThreadMap::partition(cRow, thread_idx, threadblock_tile_offset)(_,_,_0{},_0{},_0{},_0{}), Shape>{}, (_0{}) ); - + return Callbacks< - decltype(tC_gRow), decltype(tC_rRow), + decltype(tC_gRow), decltype(tC_rRow), decltype(tC_cRow), ProblemShape>( cute::move(tC_gRow), cute::move(tC_rRow), @@ 
-472,7 +472,7 @@ struct VisitorColBroadcast { CUTLASS_HOST_DEVICE VisitorColBroadcast(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms) { } - + Params const* params_ptr; template @@ -490,7 +490,7 @@ struct VisitorColBroadcast { tC_cCol(cute::forward(tC_cCol)), m(get<0>(problem_shape)), params_ptr(params_ptr) { } - + GTensor tC_gCol; RTensor tC_rCol; CTensor tC_cCol; @@ -510,7 +510,7 @@ struct VisitorColBroadcast { template CUTLASS_DEVICE auto // returns an Array - visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, Array const& frg_acc) { Array frg_col; frg_col.fill(tC_rCol(row_idx,iter_idx)); @@ -529,7 +529,7 @@ struct VisitorColBroadcast { make_gmem_ptr(params_ptr->ptr_col), problem_shape, params_ptr->dCol); - + // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER Tensor tC_gCol = group_modes<1,4>( ThreadMap::partition(mCol, thread_idx, threadblock_tile_offset)(_0{},_0{},_,_,_,_)); diff --git a/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp b/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp index 5edd5cf091..328a49ecc3 100644 --- a/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp +++ b/include/cutlass/epilogue/threadblock/fusion/visitor_store.hpp @@ -118,7 +118,7 @@ struct VisitorAuxStore{ template CUTLASS_DEVICE auto // returns an Array - visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, Array const& frg_acc, Array const& frg_input) { using ConvertInput = NumericArrayConverter; @@ -152,8 +152,8 @@ struct VisitorAuxStore{ ProblemShape problem_shape ) { Tensor mAux = make_tensor( - make_gmem_ptr(params_ptr->ptr_aux), - problem_shape, + make_gmem_ptr(params_ptr->ptr_aux), + problem_shape, params_ptr->dAux); // (M,N,L) // VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER Tensor tC_gAux = recast(group_modes<3,6>(ThreadMap::partition(mAux, thread_idx, threadblock_tile_offset))); @@ -161,14 +161,14 @@ struct VisitorAuxStore{ // Generate the pred tensor Tensor cAux = make_identity_tensor(mAux.shape()); - Tensor tC_cAux = local_partition( + Tensor tC_cAux = outer_partition( group_modes<3,6>(ThreadMap::partition(cAux, thread_idx, threadblock_tile_offset)), Shape>{}, (_0{}) ); return Callbacks< - decltype(tC_gAux), decltype(tC_rAux), + decltype(tC_gAux), decltype(tC_rAux), decltype(tC_cAux), ProblemShape>( cute::move(tC_gAux), cute::move(tC_rAux), @@ -186,7 +186,7 @@ struct VisitorAuxStore{ ///////////////////////////////////////////////////////////////////////////////////////////////// // Helper functions template < - template class ReduceFn, + template class ReduceFn, int kThreads, class T> CUTLASS_DEVICE void intra_warp_row_reduce(T& value) { @@ -266,7 +266,7 @@ struct VisitorColReduction { CUTLASS_HOST_DEVICE VisitorColReduction(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms) { } - + Params const* params_ptr; template @@ -307,7 +307,7 @@ struct VisitorColReduction { template CUTLASS_DEVICE auto // returns an Array - visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, Array const& frg_acc, Array const& frg_input) { @@ -414,13 +414,13 @@ struct VisitorRowReduction { VisitorRowReduction(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms), 
smem_reduce(const_cast(shared_storage.reduction.data())) { } - + Params const* params_ptr; ElementCompute* smem_reduce; template < class RTensorR2S, class STensorR2S, class CTensorR2S, - class STensorS2R, class RTensorS2R, class CTensorS2R, + class STensorS2R, class RTensorS2R, class CTensorS2R, class GTensor, class CTensor, class ProblemShape> struct Callbacks : EmptyCallbacks { CUTLASS_DEVICE @@ -465,7 +465,7 @@ struct VisitorRowReduction { // R->G GTensor tC_gRow; CTensor tC_cRow; - + Params const* params_ptr; int n; int m; @@ -477,10 +477,10 @@ struct VisitorRowReduction { template CUTLASS_DEVICE auto // returns an Array - visit(int iter_idx, int row_idx, int column_idx, int frg_idx, + visit(int iter_idx, int row_idx, int column_idx, int frg_idx, Array const& frg_acc, Array const& frg_input) { - + using ConvertInput = NumericArrayConverter; ConvertInput convert_input{}; Tensor tRS_rRow_frg = recast>(coalesce(tRS_rSrc)); @@ -528,13 +528,13 @@ struct VisitorRowReduction { atomic_reduce(&tC_gRow(j), tSR_rRows(j)); } - } + } } private: template - CUTLASS_DEVICE ElementCompute + CUTLASS_DEVICE ElementCompute reduction(Array& reduce_buffer, Array const& result) { using ReduceInput = RegReduceFn; ReduceInput reduce_input{}; @@ -556,7 +556,7 @@ struct VisitorRowReduction { make_gmem_ptr(params_ptr->ptr_row), problem_shape, params_ptr->dRow); - + // // Step 1: reduce fragment input (Src) into tRS_rSrc // @@ -567,7 +567,7 @@ struct VisitorRowReduction { Tensor cSrc = make_identity_tensor(mRow.shape()); // FRAGMENT_COLUMN, FRAGMENT_ROW, (ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER) Tensor tRS_cSrc = group_modes<2,5>(ThreadMap::partition(cSrc, thread_idx, threadblock_tile_offset)(_0{},_,_,_,_,_)); - + // // Step 2: copy the partial results in tRS_rSrc to sRows in shared memory // @@ -587,18 +587,18 @@ struct VisitorRowReduction { // VECTOR*ACCESS_WIDTH*FRAGMENT_COL,ACCESS_ROWS*WARPS_PER_ROW*GROUPS*CLUSTERS Tensor sRows_nm = coalesce(group_modes<1,5>(group_modes<0,3>(sRows)), Shape<_1,_1>{}); // SMEM_ROW/THREADS,ACCESS_ROWS*WARPS_PER_ROW*GROUPS*CLUSTERS - Tensor tSR_sRows = local_partition(sRows_nm, Shape,_1>{}, thread_idx); + Tensor tSR_sRows = outer_partition(sRows_nm, Shape,_1>{}, thread_idx); // SMEM_ROW/THREADS Tensor tSR_rRows = make_tensor_like(tSR_sRows(_,_0{})); // Coord Tensor cRows_nm = make_identity_tensor(sRows_nm.shape()); - Tensor tSR_cRows = local_partition(cRows_nm, Shape,_1>{}, thread_idx)(_,_0{}); - + Tensor tSR_cRows = outer_partition(cRows_nm, Shape,_1>{}, thread_idx)(_,_0{}); + // // Step 4: atomically reduce the results to global memory // - - Tensor tC_gRow = local_partition( + + Tensor tC_gRow = outer_partition( // Cta tile local_tile( mRow, typename ThreadMap::CtaShapeMNL{}, make_coord(_,_,_),Step<_1,_1, X>{} @@ -608,7 +608,7 @@ struct VisitorRowReduction { )(_0{},_); Tensor cRow = make_identity_tensor(mRow.shape()); - Tensor tC_cRow = local_partition( + Tensor tC_cRow = outer_partition( // Cta tile local_tile( cRow, typename ThreadMap::CtaShapeMNL{}, make_coord(_,_,_), Step<_1,_1, X>{} @@ -680,7 +680,7 @@ struct VisitorScalarReduction { CUTLASS_HOST_DEVICE VisitorScalarReduction(Params const& params, SharedStorage const& shared_storage) : params_ptr(¶ms) { } - + Params const* params_ptr; template @@ -760,7 +760,7 @@ struct VisitorScalarReduction { ); Tensor tC_gScalar = mScalar(_,_,threadblock_tile_offset.k()); - + return Callbacks< decltype(tC_cSrc), decltype(tC_gScalar), ProblemShape>( diff --git a/include/cutlass/fast_math.h b/include/cutlass/fast_math.h index 
5bcd0c9784..31f2cb5fda 100644 --- a/include/cutlass/fast_math.h +++ b/include/cutlass/fast_math.h @@ -388,6 +388,12 @@ struct FastDivmod { return quotient; } + /// Alias for `div` to match the interface of FastDivmodU64 + CUTLASS_HOST_DEVICE + int divide(int dividend) const { + return div(dividend); + } + /// Computes integer division and modulus using precomputed values. This is computationally /// inexpensive. /// @@ -529,6 +535,54 @@ struct FastDivmodU64 { ///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Object to encapsulate the fast division+modulus operation for 64b integer division +/// in which the divisor is a power of two. +struct FastDivmodU64Pow2 { + + uint64_t divisor; + unsigned int shift_right; + + /// Default ctor + CUTLASS_HOST_DEVICE + FastDivmodU64Pow2(): divisor(0), shift_right(0) { } + + /// Construct the FastDivmod object, in host code ideally. + /// + /// This precomputes some values based on the divisor and is computationally expensive. + CUTLASS_HOST_DEVICE + FastDivmodU64Pow2(uint64_t divisor_): divisor(divisor_), shift_right(FastDivmodU64::integer_log2(divisor_)) { } + + /// Returns the quotient of floor(dividend / divisor) + CUTLASS_HOST_DEVICE + uint64_t divide(uint64_t dividend) const { + return dividend >> shift_right; + } + + /// Computes the remainder given a computed quotient and dividend + CUTLASS_HOST_DEVICE + uint64_t modulus(uint64_t dividend) const { + // See https://docs.nvidia.com/cuda/cuda-c-best-practices-guide/index.html#division-modulo-operations + return dividend & (divisor - 1); + } + + /// Returns the quotient of floor(dividend / divisor) and computes the remainder + CUTLASS_HOST_DEVICE + uint64_t divmod(uint64_t &remainder, uint64_t dividend) const { + uint64_t quotient = divide(dividend); + remainder = modulus(dividend); + return quotient; + } + + /// Computes integer division and modulus using precomputed values. This is computationally + /// inexpensive. + CUTLASS_HOST_DEVICE + void operator()(uint64_t &quotient, uint64_t &remainder, uint64_t dividend) const { + quotient = divmod(remainder, dividend); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Computes the coordinate decomposition from a linear index (64-bit linear index => coord) /// /// This decomposition is accelerated by the FastDivmodU64 object. It is assumed that diff --git a/include/cutlass/float8.h b/include/cutlass/float8.h index e2d2245c43..38af9c1824 100644 --- a/include/cutlass/float8.h +++ b/include/cutlass/float8.h @@ -489,6 +489,10 @@ struct alignas(1) float_e4m3_t : float8_base { explicit float_e4m3_t(int x): float_e4m3_t(float(x)) { } + CUTLASS_HOST_DEVICE + explicit float_e4m3_t(unsigned x): float_e4m3_t(float(x)) { + } + /// E5M2 conversion. Defined after float_e5m2_t is defined. 
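// FastDivmodU64Pow2 above exploits the fact that for a power-of-two divisor d,
// x / d is a right shift by log2(d) and x % d is a bitwise AND with (d - 1).
// The host-side test below checks that identity against the ordinary operators.
// It is standalone example code, not part of CUTLASS: Pow2Divmod is an invented
// name and its shift computation is a simplification of integer_log2.
#include <cassert>
#include <cstdint>

struct Pow2Divmod {
  std::uint64_t divisor;
  unsigned shift_right;

  // Assumes divisor is a non-zero power of two.
  explicit Pow2Divmod(std::uint64_t d) : divisor(d), shift_right(0) {
    while ((std::uint64_t(1) << shift_right) < d) { ++shift_right; }
  }

  std::uint64_t divide(std::uint64_t x) const  { return x >> shift_right; }
  std::uint64_t modulus(std::uint64_t x) const { return x & (divisor - 1); }

  std::uint64_t divmod(std::uint64_t& remainder, std::uint64_t x) const {
    remainder = modulus(x);
    return divide(x);
  }
};

int main() {
  Pow2Divmod fast(128);
  for (std::uint64_t x : {0ull, 1ull, 127ull, 128ull, 1000003ull}) {
    std::uint64_t rem = 0;
    std::uint64_t quo = fast.divmod(rem, x);
    assert(quo == x / 128 && rem == x % 128);
  }
  return 0;
}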
CUTLASS_HOST_DEVICE explicit float_e4m3_t(float_e5m2_t x); @@ -694,6 +698,10 @@ struct alignas(1) float_e5m2_t : float8_base { explicit float_e5m2_t(int x): float_e5m2_t(float(x)) { } + CUTLASS_HOST_DEVICE + explicit float_e5m2_t(unsigned x): float_e5m2_t(float(x)) { + } + /// E4M3 conversion CUTLASS_HOST_DEVICE explicit float_e5m2_t(float_e4m3_t x); diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h index 0554dd7adb..40cfef28f4 100644 --- a/include/cutlass/functional.h +++ b/include/cutlass/functional.h @@ -228,6 +228,25 @@ struct divides { } }; +/// reciprocal_approximate +template +struct reciprocal_approximate { + CUTLASS_HOST_DEVICE + T operator()(T lhs) const { + return divide(T(1), lhs); + } +}; + +template <> +struct reciprocal_approximate { + CUTLASS_HOST_DEVICE + float operator()(float lhs) const { + float ret; + ret = 1.0f / lhs; + return ret; + } +}; + /// Negate template struct negate { @@ -273,7 +292,7 @@ struct less { } }; -template +template struct maximum { CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { @@ -281,8 +300,17 @@ struct maximum { } }; -// Maximum with nan propogation -// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN +// This is a subclass and not an alias +// in order to work around a known Clang issue, +// where a template template parameter with one template parameter +// does not match classes that take multiple template parameters +// but have defaults for all but the first. +template +struct maximum_with_default_nan_propagation : public maximum +{}; + +// Maximum with nan propagation +// To propagate NANs, the "max" of a two element that contains NaNs should also return a NaN template struct maximum { CUTLASS_HOST_DEVICE @@ -319,10 +347,21 @@ struct maximum { } }; +// This is a subclass and not an alias +// in order to work around a known Clang issue, +// where a template template parameter with one template parameter +// does not match classes that take multiple template parameters +// but have defaults for all but the first. template -using maximum_with_nan_propogation = maximum; +struct maximum_with_nan_propagation : maximum +{}; -template +// This alias exists for backwards compatibility only. +// Please use the correctly spelled class template above. +template +using maximum_with_nan_propogation = maximum_with_nan_propagation; + +template struct minimum{ CUTLASS_HOST_DEVICE T operator()(T const &lhs, T const &rhs) const { @@ -350,24 +389,24 @@ struct minimum { } }; -template +template struct maximum_absolute_value { CUTLASS_HOST_DEVICE float operator()(T const &lhs, T const &rhs) const { absolute_value_op abs_op; - maximum max_op; + maximum max_op; return max_op(abs_op(lhs), abs_op(rhs)); } }; // assumes the left operand is already an absolute value -template +template struct maximum_absolute_value_reduction { CUTLASS_HOST_DEVICE float operator()(T const &lhs, T const &rhs) const { absolute_value_op abs_op; - maximum max_op; + maximum max_op; return max_op(lhs, abs_op(rhs)); } @@ -382,6 +421,15 @@ struct multiply_add { } }; +// Fused multiply-add that takes exactly one template parameter. +// This is useful for working around a known Clang issue, +// where a template template parameter with one template parameter +// does not match classes that take multiple template parameters +// but have defaults for all but the first. 
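// The wrapper idiom used for maximum_with_nan_propagation above, and for
// homogeneous_multiply_add just below, exists because (as the comments note)
// some compilers will not bind a class template that has extra defaulted
// parameters to a template template parameter declared with exactly one
// parameter. The self-contained sketch below reproduces the pattern with
// invented names (maximum, maximum_nan, reduce3); it illustrates the
// workaround and is not the CUTLASS definition.
#include <cassert>
#include <cmath>

template <class T, bool PropagateNaN = false>
struct maximum {
  T operator()(T a, T b) const {
    if constexpr (PropagateNaN) {
      // Return NaN if either input is NaN, otherwise the larger value.
      return (a != a) ? a : (b != b) ? b : (a < b ? b : a);
    }
    else {
      return a < b ? b : a;
    }
  }
};

// Single-parameter subclass: always usable where template <class> class is expected.
template <class T>
struct maximum_nan : maximum<T, true> {};

// A consumer in the style of the reduction templates in this file.
template <template <class> class ReduceFn, class T>
T reduce3(T a, T b, T c) {
  ReduceFn<T> op;
  return op(op(a, b), c);
}

int main() {
  // reduce3<maximum, float>(...) may be rejected by compilers affected by the
  // issue; the single-parameter wrapper is accepted everywhere.
  float r = reduce3<maximum_nan, float>(1.0f, 4.0f, 2.0f);
  assert(r == 4.0f);
  float n = reduce3<maximum_nan, float>(1.0f, NAN, 2.0f);
  assert(std::isnan(n));
  return 0;
}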
+template +struct homogeneous_multiply_add : public multiply_add +{}; + /// Fused multiply-add template struct multiply_add_relu0 { @@ -582,6 +630,14 @@ struct atomic_maximum { } }; +// is_atomic +template +struct is_atomic : platform::false_type {}; +template +struct is_atomic> : platform::true_type {}; +template +struct is_atomic> : platform::true_type {}; + ///////////////////////////////////////////////////////////////////////////////////////////////// // diff --git a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl index f0df56d408..8fa85d8d56 100644 --- a/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl +++ b/include/cutlass/gemm/collective/builders/sm90_gmma_builder.inl @@ -69,11 +69,11 @@ constexpr int compute_stage_count_or_override(StageCountAutoCarveout stage_count) { // 32 bytes to account for barriers etc. constexpr int stage_barrier_bytes = 32; - constexpr int a_bytes = static_cast(sizeof(ElementA)); - constexpr int b_bytes = static_cast(sizeof(ElementB)); + constexpr int a_bits = static_cast(sizeof_bits::value); + constexpr int b_bits = static_cast(sizeof_bits::value); constexpr int stage_bytes = - (a_bytes * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + - (b_bytes * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + + (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes; return (CapacityBytes - carveout_bytes) / stage_bytes; @@ -95,13 +95,39 @@ is_warpspecialized_transpose_B(){ constexpr bool IsInputSizeTwoBytes = is_input_size_two_bytes(); constexpr bool IsLayoutAmnBmn = cutlass::gemm::detail::is_mn_major_A() && cutlass::gemm::detail::is_mn_major_B(); - constexpr bool IsWarpSpecialized = cute::is_base_of_v || - cute::is_base_of_v || - cute::is_base_of_v; + constexpr bool IsWarpSpecialized = cute::is_base_of_v || + cute::is_base_of_v || + cute::is_base_of_v || + cute::is_base_of_v || + cute::is_base_of_v || + cute::is_base_of_v; constexpr bool IsWarpSpecializedTransposeB = !IsInputSizeTwoBytes && IsLayoutAmnBmn && IsWarpSpecialized; return IsWarpSpecializedTransposeB; } +template +struct Sm90TypeWidths { + static constexpr bool IsElementALarger = (cute::sizeof_bits_v) > cute::sizeof_bits_v; + using WideType = cute::conditional_t; + using NarrowType = cute::conditional_t; +}; + + +template +constexpr bool +sm90_is_narrow_type_k_major() { + using Widths = Sm90TypeWidths; + using NarrowType = typename Widths::NarrowType; + using WideType = typename Widths::WideType; + + constexpr bool IsANarrow = cute::is_same_v; + constexpr cute::GMMA::Major NarrowGmmaMajor = IsANarrow ? 
detail::gmma_rs_tag_to_major_A() : + detail::gmma_rs_tag_to_major_B(); + + constexpr bool IsNarrowLayoutKMajor = NarrowGmmaMajor == cute::GMMA::Major::K; + return IsNarrowLayoutKMajor; +} + } // namespace detail ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -297,6 +323,135 @@ struct CollectiveBuilder< ///////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA_TMA_WS_RS Mixed GEMM +template < + class ElementPairA_, + class GmemLayoutPairA_, + int AlignmentA, + class ElementPairB_, + class GmemLayoutPairB_, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementPairA_, + GmemLayoutPairA_, + AlignmentA, + ElementPairB_, + GmemLayoutPairB_, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v)> +> { + +public: + static constexpr bool IsATransformed = cute::sizeof_bits_v < cute::sizeof_bits_v; + + // Split out items for processessing, no splitting for now since scales aren't supported. + using ElementA = ElementPairA_; + using ElementB = ElementPairB_; + + using GmemLayoutA = GmemLayoutPairA_; + using GmemLayoutB = GmemLayoutPairB_; + + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Should meet TMA alignment requirement\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B(); + static constexpr bool IsWarpSpecializedTransposeB = detail::is_warpspecialized_transpose_B< + ElementA, GmemLayoutA, ElementB, GmemLayoutB, KernelScheduleType>(); + static_assert(!IsWarpSpecializedTransposeB, "Mixed input GEMM does not support WS transpose B."); + + // If A is scaled, then we don't need to swap. Otherwise, we must ensure B goes to RF and we must swap the operands. + static constexpr bool SwapAB = !IsATransformed; + static_assert(detail::sm90_is_narrow_type_k_major(), "The narrow type must be K-major."); + + static_assert((IsATransformed && (cute::sizeof_bits_v <= 8) && (sizeof(ElementB) == 2)) || + (!IsATransformed && (cute::sizeof_bits_v <= 8) && (sizeof(ElementA) == 2)) || + (GmmaMajorA == cute::GMMA::Major::K && GmmaMajorB == cute::GMMA::Major::K), + "The unscaled element must be 2 bytes OR both inputs must be K-major"); + // When we relax the above assertion, we must handle setting the tile mma GmmaMajorB correctly. + static constexpr cute::GMMA::Major TiledMmaGmmaMajorB = SwapAB ? 
GmmaMajorA : GmmaMajorB; + + using ElementMma = cute::conditional_t; + using AtomLayoutMNK = cute::conditional_t, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector< + ElementMma, ElementMma, ElementAccumulator, TileShape_MNK, GMMA::Major::K, TiledMmaGmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + using SmemLayoutAtomA = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + using SmemLayoutAtomB = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + + using RealElementA = cute::conditional_t; + using RealElementB = cute::conditional_t; + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; + + using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; + + // We pack the scale data with the operand that will be optionally scaled and converted before MMA. + using StrideAPair = TagToStrideA_t; + using StrideBPair = TagToStrideB_t; + + using GmemTiledCopyAPair = GmemTiledCopyA; + using SmemLayoutAtomAPair = SmemLayoutAtomA; + using SmemCopyAtomAPair = SmemCopyAtomA; + + using GmemTiledCopyBPair = GmemTiledCopyB; + using SmemLayoutAtomBPair = SmemLayoutAtomB; + using SmemCopyAtomBPair = SmemCopyAtomB; + + + // If the src type of the converter is the same as ElementA, + // interpret this as if the user wanted to apply the scale to the A matrix. 
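To make the operand-swap logic in this mixed-input builder easier to follow, here is a much-simplified, host-only sketch of the `IsATransformed` / `SwapAB` decision: the narrower operand is the one that must be staged through registers and converted, so when B is the narrow type the operands are swapped. `MixedInputOperandMap` is a hypothetical name, and it uses `sizeof` rather than `sizeof_bits`, so it does not model sub-byte types.

```cpp
#include <cstdint>
#include <type_traits>

// Simplified stand-in for the SwapAB decision above. The operand that needs
// conversion (the narrower type) must land on the register-sourced "A" side
// of the RS mainloop; the wider operand stays in shared memory.
template <class ElementA, class ElementB>
struct MixedInputOperandMap {
  static constexpr bool IsATransformed = (sizeof(ElementA) < sizeof(ElementB));
  static constexpr bool SwapAB = !IsATransformed;

  // Type that ends up on the register-sourced side of the MMA.
  using RegisterSide = std::conditional_t<SwapAB, ElementB, ElementA>;
  // Type that stays in shared memory and feeds the smem descriptor.
  using SmemSide     = std::conditional_t<SwapAB, ElementA, ElementB>;
};

using MapAB = MixedInputOperandMap<std::int8_t, std::uint16_t>;
static_assert(MapAB::IsATransformed && !MapAB::SwapAB, "A is the narrow operand: no swap");

using MapBA = MixedInputOperandMap<std::uint16_t, std::int8_t>;
static_assert(!MapBA::IsATransformed && MapBA::SwapAB, "B is the narrow operand: swap");

int main() {}
```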
+ using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementPairA_, + StrideAPair, + ElementPairB_, + StrideBPair, + TiledMma, + GmemTiledCopyAPair, + SmemLayoutAtomAPair, + SmemCopyAtomAPair, + cute::identity, + GmemTiledCopyBPair, + SmemLayoutAtomBPair, + SmemCopyAtomBPair, + cute::identity + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // GMMA_TMA_WS_FP8_FAST_ACCUM_SS template < class ElementA, @@ -490,7 +645,8 @@ template < class StageCountType, class KernelScheduleType > -struct CollectiveBuilder< +struct [[deprecated("Use one of KernelCpAsyncWarpSpecialized schedules instead")]] +CollectiveBuilder< arch::Sm90, arch::OpClassTensorOp, ElementA, @@ -506,6 +662,61 @@ struct CollectiveBuilder< KernelScheduleType, cute::enable_if_t< cute::is_same_v> +> { + // Map to warp-specialized kernels for better performance + using CollectiveOp = typename CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelCpAsyncWarpSpecialized + >::CollectiveOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_CpAsync_WS_SS +template < + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + not detail::is_use_rmem_A() + > > { static_assert(is_static::value); static_assert(is_static::value); @@ -523,14 +734,19 @@ struct CollectiveBuilder< static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + using AtomLayoutMNK = cute::conditional_t, + Layout(TileShape_MNK{}) < 128) ? 1 : 2>,_1,_1>>, Layout>>; + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< - MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>())); + MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + static constexpr int NumLoadWarpGroups = cute::is_same_v ? 
2 : 1; using GmemTiledCopyA = decltype(detail::make_cp_async_gmem_tiled_copy< - 128, ElementA, AlignmentA, TagToStrideA_t, + NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementA, AlignmentA, TagToStrideA_t, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using GmemTiledCopyB = decltype(detail::make_cp_async_gmem_tiled_copy< - 128, ElementB, AlignmentB, TagToStrideB_t, + NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementB, AlignmentB, TagToStrideB_t, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); using SmemLayoutAtomA = decltype(detail::ss_smem_selector< @@ -541,8 +757,11 @@ struct CollectiveBuilder< static constexpr int PipelineStages = detail::compute_stage_count_or_override< detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB, TileShape_MNK>(StageCountType{}); + using DispatchPolicy = MainloopSm90CpAsyncGmmaWarpSpecialized< + PipelineStages, ClusterShape_MNK, KernelScheduleType>; + using CollectiveOp = CollectiveMma< - MainloopSm90CpAsyncGmma, + DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t, @@ -562,6 +781,110 @@ struct CollectiveBuilder< ///////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA_CpAsync_WS_RS +template < + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) && + detail::is_use_rmem_A() + > +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false == 0, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + // For fp32 types, map to tf32 MMA value type + using MmaElementA = cute::conditional_t, tfloat32_t, ElementA>; + using MmaElementB = cute::conditional_t, tfloat32_t, ElementB>; + + static_assert(detail::is_aligned(), + "Minimum alignment required for cp.async is 4B."); + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_rs_tag_to_major_B(); + static constexpr bool SwapAB = detail::is_swapAB(); + static constexpr bool IsWarpSpecializedTransposeB = detail::is_warpspecialized_transpose_B< + ElementA, GmemLayoutA, ElementB, GmemLayoutB, KernelScheduleType>(); + + using AtomLayoutMNK = cute::conditional_t, + Layout(TileShape_MNK{}) < 128) ? 
1 : 2>,_1,_1>>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector< + MmaElementA, MmaElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{})); + + static constexpr int NumLoadWarpGroups = 1; + + using GmemTiledCopyA = decltype(detail::make_cp_async_gmem_tiled_copy< + NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementA, AlignmentA, TagToStrideA_t, + decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using GmemTiledCopyB = decltype(detail::make_cp_async_gmem_tiled_copy< + NumThreadsPerWarpGroup * NumLoadWarpGroups, ElementB, AlignmentB, TagToStrideB_t, + decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + using SmemLayoutAtomA = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + using SmemLayoutAtomB = decltype(detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), IsWarpSpecializedTransposeB>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override< + detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB, TileShape_MNK>(StageCountType{}); + + using DispatchPolicy = MainloopSm90CpAsyncGmmaRmemAWarpSpecialized< + PipelineStages, ClusterShape_MNK, KernelScheduleType>; + + using SmemCopyAtomA = cute::conditional_t>; + using SmemCopyAtomB = cute::conditional_t, void>; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + // GMMA auto kernel schedule template < class ElementA, @@ -601,15 +924,26 @@ struct CollectiveBuilder< static constexpr bool IsTmaCompatible = detail::is_aligned< ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(); +static constexpr bool IsMixedWidthInput = cute::sizeof_bits_v != cute::sizeof_bits_v; + #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1))) // Persistent schedules perform best for CUDA Toolkits with version >= 12.1 // KernelTmaWarpSpecializedCooperative requires TileShape_M to be at least 128 - using KernelWarpSpecializedSchedule = cute::conditional_t(TileShape_MNK{}) == Int<64>{}, + using KernelTmaWarpSpecializedScheduleSameInput = cute::conditional_t(TileShape_MNK{}) == Int<64>{}, KernelTmaWarpSpecializedPingpong, KernelTmaWarpSpecializedCooperative>; + + using KernelTmaWarpSpecializedScheduleMixedInput = cute::conditional_t(TileShape_MNK{}) == Int<64>{}, + KernelTmaWarpSpecializedPingpongMixedInput, KernelTmaWarpSpecializedCooperativeMixedInput>; + + using KernelTmaWarpSpecializedSchedule = cute::conditional_t; #else - using KernelWarpSpecializedSchedule = KernelTmaWarpSpecialized; + using KernelTmaWarpSpecializedSchedule = cute::conditional_t; #endif + // Non-persistent schedule is a safer choice for CpAsync kernels due to register pressure + using KernelCpAsyncWarpSpecializedSchedule = KernelCpAsyncWarpSpecialized; + using KernelSchedule = cute::conditional_t; + static_assert((cute::is_same_v && IsMixedWidthInput) || !IsMixedWidthInput, "Only TMA warp specialized kernels are supported for mixed width input."); using CollectiveOp = typename CollectiveBuilder< 
arch::Sm90, arch::OpClassTensorOp, @@ -623,7 +957,7 @@ static constexpr bool IsTmaCompatible = detail::is_aligned< TileShape_MNK, ClusterShape_MNK, StageCountType, - cute::conditional_t + KernelSchedule >::CollectiveOp; }; diff --git a/include/cutlass/gemm/collective/collective_mma.hpp b/include/cutlass/gemm/collective/collective_mma.hpp index de6c77b4c5..985e0ecc48 100644 --- a/include/cutlass/gemm/collective/collective_mma.hpp +++ b/include/cutlass/gemm/collective/collective_mma.hpp @@ -67,9 +67,11 @@ struct CollectiveMma { #include "cutlass/gemm/collective/sm70_mma_twostage.hpp" #include "cutlass/gemm/collective/sm80_mma_multistage.hpp" -#include "cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp" +#include "cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp" +#include "cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp" #include "cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp new file mode 100644 index 0000000000..1e1c5e6dff --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_rs_warpspecialized.hpp @@ -0,0 +1,662 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape_, + class TileShape_, + class KernelSchedule, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90CpAsyncGmmaRmemAWarpSpecialized, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90CpAsyncGmmaRmemAWarpSpecialized; + using TileShape = TileShape_; + using ClusterShape = ClusterShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + + // Swap and transpose A/B for A k-major layout and B mn-major layout since WGMMA is k-major only (e.g. tf32, Fp32, Int8, Fp8 WGMMA) + static constexpr bool IsLayoutAkBmn = + cute::is_same_v, layout::RowMajor> && + cute::is_same_v, layout::RowMajor>; + + static constexpr bool IsInputSizeTwoBytes = sizeof(ElementA) == 2 && sizeof(ElementB) == 2; + static constexpr bool SwapAB = !IsInputSizeTwoBytes && IsLayoutAkBmn; + using InternalGmemTiledCopyA = cute::conditional_t; + using InternalGmemTiledCopyB = cute::conditional_t; + using InternalSmemLayoutAtomA = cute::conditional_t; + using InternalSmemLayoutAtomB = cute::conditional_t; + using InternalSmemCopyAtomA = cute::conditional_t; + using InternalSmemCopyAtomB = cute::conditional_t; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
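The comment above explains why non-f32 element types are reinterpreted as same-width unsigned integer types before the gmem-to-smem copy: copying through a bit-identical unsigned carrier guarantees a pure bit move with no numeric conversion or rounding. A minimal, standalone sketch of that carrier-type selection follows; `uint_bit` and `bit_identical_carrier_t` are illustrative names, not the CUTLASS utilities.

```cpp
#include <cstdint>
#include <type_traits>

// Map a bit width to an unsigned integer type of exactly that width.
template <int Bits> struct uint_bit;  // primary template intentionally undefined
template <> struct uint_bit<8>  { using type = std::uint8_t;  };
template <> struct uint_bit<16> { using type = std::uint16_t; };
template <> struct uint_bit<32> { using type = std::uint32_t; };
template <> struct uint_bit<64> { using type = std::uint64_t; };

// Carrier type used to stage an element through the copy path bit-for-bit.
template <class T>
using bit_identical_carrier_t = typename uint_bit<8 * sizeof(T)>::type;

// A 16-bit element would be staged as uint16_t, a 64-bit element as uint64_t, etc.
static_assert(std::is_same_v<bit_identical_carrier_t<float>,  std::uint32_t>);
static_assert(std::is_same_v<bit_identical_carrier_t<double>, std::uint64_t>);

int main() {}
```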
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using InternalElementA = cute::conditional_t; + using InternalElementB = cute::conditional_t; + using InternalStrideA = cute::conditional_t; + using InternalStrideB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineAsync; + using PipelineState = typename MainloopPipeline::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + InternalSmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + InternalSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major only (e.g. tf32, fp32, fp8, int8). 
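The `SwapAB` / `TransposeB` flags defined in this mainloop encode a small decision table driven by the operand layouts: WGMMA can only source B from shared memory in K-major order for non-16-bit types. The sketch below is a hypothetical, host-only restatement of that table (names like `plan_for` are illustrative), ignoring the additional constraints the real code asserts.

```cpp
#include <cstdio>

// Simplified decision table mirroring the SwapAB / TransposeB logic above.
// For 16-bit inputs WGMMA can consume MN-major smem directly, so neither
// path is needed.
enum class Major { K, MN };

struct RsMainloopPlan {
  bool swap_ab;      // route the MN-major operand through registers as "A"
  bool transpose_b;  // explicitly transpose B tiles inside shared memory
};

constexpr RsMainloopPlan plan_for(Major major_a, Major major_b, bool is_16bit_input) {
  if (is_16bit_input)                               return {false, false};
  if (major_a == Major::K  && major_b == Major::MN) return {true,  false};  // SwapAB
  if (major_a == Major::MN && major_b == Major::MN) return {false, true};   // TransposeB
  return {false, false};  // both K-major: B is read from smem as-is
}

static_assert(plan_for(Major::K,  Major::MN, false).swap_ab);
static_assert(plan_for(Major::MN, Major::MN, false).transpose_b);
static_assert(!plan_for(Major::K, Major::K,  false).swap_ab);

int main() { std::puts("ok"); }
```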
+ static constexpr bool IsLayoutAmnBmn = + cute::is_same_v, layout::ColumnMajor> && + cute::is_same_v, layout::RowMajor>; + static constexpr bool TransposeB = !IsInputSizeTwoBytes && IsLayoutAmnBmn; + using TransposeOperandB = decltype(cutlass::transform::collective::detail::make_transpose_operand_b( + 0, 0, TiledMma{}, SmemLayoutB{}, InternalSmemLayoutAtomB{}, + InternalElementB{}, cute::bool_constant{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + + using GmmaSmemLayoutAtomB = decltype(transform::collective::detail::gmma_smem_transpose_or_passthrough< + TransposeB, InternalSmemLayoutAtomB, InternalElementB>()); + + // SmemLayoutB for GMMA is different from SmemLayoutB for TMA if TransposeB + using GmmaSmemLayoutB = decltype(tile_to_shape( + GmmaSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(!SwapAB || !TransposeB, "Cannot SwapAB and TransposeB at the same time."); + static_assert(TransposeB xor (cute::is_same_v), + "Should be same layout if not TransposeB."); + static_assert(!TransposeB || ((size<1>(SmemLayoutB{}) * sizeof_bits::value) / 8) == 128, + "SmemLayoutB K must be 128bytes to be transposed."); + static_assert(!transform::collective::detail::use_universal_transposition(), + "Warp specialized ARF kernels have not supported universal B transposition yet."); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<256> { + cute::array_aligned, 256> smem_A; + cute::array_aligned, 256> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + InternalElementA const* ptr_A = nullptr; + InternalStrideA dA{}; + InternalElementB const* ptr_B = nullptr; + InternalStrideB dB{}; + uint32_t mma_promotion_interval = 4; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + if constexpr (not SwapAB) { + return { + reinterpret_cast(args.ptr_A), + args.dA, + reinterpret_cast(args.ptr_B), + args.dB + }; + } + else { + return { + reinterpret_cast(args.ptr_B), + args.dB, + reinterpret_cast(args.ptr_A), + args.dA + }; + } + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = 
DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, + class TensorB, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + load( + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_in, + TensorB const& gB_in, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + TensorStorage& shared_tensors) + { + using namespace cute; + + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + Tensor gA = domain_offset(make_coord(0, get<2>(residue_mnk), 0), gA_in); + Tensor gB = domain_offset(make_coord(0, get<2>(residue_mnk), 0), gB_in); + + // Partition the copying of A and B tiles across the threads + InternalGmemTiledCopyA gmem_tiled_copy_a; + InternalGmemTiledCopyB gmem_tiled_copy_b; + auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); + auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + // Construct identity layout for sA and sB + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tAcA = gmem_thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tBcB = gmem_thr_copy_b.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n + } + + // 0-th stage with predication on k to account for residue + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + int write_stage = smem_pipe_write.index(); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) + copy_if(gmem_tiled_copy_a, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,write_stage)); + } + else { + clear(tAsA(_,_,k,write_stage)); + } + } + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + 
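The 0-th k-tile load above predicates the copy against the K residue and zero-fills out-of-bounds smem slots so the MMA reads zeros for them. The following is a deliberately scalar, host-side analogue of that `copy_if` / `clear` pattern, ignoring the `domain_offset` origin-shift bookkeeping the real code uses to place the partial tile first; `load_k_tile_with_residue` is an illustrative name only.

```cpp
#include <array>
#include <cstdio>

constexpr int BLK_K = 8;

// Scalar analogue of the predicated residue-tile copy: elements beyond the
// problem's K extent are not loaded; their smem slots are zero-filled so they
// contribute nothing to the accumulation.
void load_k_tile_with_residue(const float* gmem, int K, std::array<float, BLK_K>& smem_tile) {
  for (int k = 0; k < BLK_K; ++k) {
    if (k < K) {
      smem_tile[k] = gmem[k];   // predicate holds: inside the valid K range
    } else {
      smem_tile[k] = 0.0f;      // predicate fails: clear instead of copy
    }
  }
}

int main() {
  float gmem[5] = {1, 2, 3, 4, 5};     // K = 5, so 3 slots of the tile are residue
  std::array<float, BLK_K> smem{};
  load_k_tile_with_residue(gmem, 5, smem);
  for (float v : smem) std::printf("%g ", v);  // 1 2 3 4 5 0 0 0
  std::printf("\n");
}
```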
CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) + copy_if(gmem_tiled_copy_b, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + + ++k_tile_iter; + --k_tile_count; + + // UNLOCK smem_pipe_write + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + + // Advance smem_pipe_write + ++smem_pipe_write; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + int write_stage = smem_pipe_write.index(); + + // Copy gmem to smem for *k_tile_iter + copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // UNLOCK smem_pipe_write + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + // Issue the epilogue waits + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) + { + using namespace cute; + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2."); + static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_M,BLK_K,PIPE) + + // If TransposeB, GMMA will read from transposed B layout SMEM + Tensor gmma_sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), GmmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + // Allocate fragments and 
descriptors + Tensor tCsA = thread_mma.partition_A(sA); + Tensor tCrA = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(gmma_sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + + + auto smem_tiled_copy_A = make_tiled_copy_A(InternalSmemCopyAtomA{}, tiled_mma); + + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + TransposeOperandB transpose = cutlass::transform::collective::detail::make_transpose_operand_b( + warp_idx, warp_group_thread_idx, tiled_mma, SmemLayoutB{}, + InternalSmemLayoutAtomB{}, InternalElementB{}, + cute::bool_constant{}); + + warpgroup_fence_operand(accum); + // first k tile + { + pipeline.consumer_wait(smem_pipe_read); + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + + bool skip_wait = (pipeline.consumer_try_wait(smem_pipe_read) == BarrierStatus::WaitDone); + + // copy smem->rmem for A operand + copy(smem_tiled_copy_A, tCsA(_,_,0,read_stage), tCrA_copy_view(_,_,0)); + // transpose B operand in SMEM + transpose(sB, gmma_sB, read_stage, 0); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) - 1; ++k_block) { + copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + if (k_block == 0) { + transpose(sB, gmma_sB, read_stage, 1); + transpose.synchronize(); + } + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + } + + warpgroup_wait<2>(); + + + if (k_tile_count - 1 > 0) { + if (!skip_wait) { + pipeline.consumer_wait(smem_pipe_read); + } + copy(smem_tiled_copy_A, tCsA(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); + transpose(sB, gmma_sB, smem_pipe_read.index(), 0); + } + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + --k_tile_count; + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + bool skip_wait = (pipeline.consumer_try_wait(smem_pipe_read) == BarrierStatus::WaitDone); + + warpgroup_fence_operand(accum); 
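The unrolled K-block loops that follow implement a software-pipelined schedule: while one MMA executes, the next A fragment is copied from smem to registers (and, if needed, the next B transpose is issued), the wait keeps two MMA batches in flight, and the stage barrier is released once k_block 1 has been consumed. The sketch below is a host-side model of that schedule with illustrative prints; none of these names are CUTLASS APIs, and real timing is governed by `warpgroup_wait` and the pipeline barriers.

```cpp
#include <cstdio>

constexpr int K_BLOCKS = 4;

// Host-side model of the per-k-tile consumer schedule: prefetch the operand
// needed next, issue the current MMA, and release the previous smem stage as
// soon as its last user (k_block 1's prefetch of k_block 2) has completed.
void consume_k_tile(int read_stage, int next_stage) {
  for (int k_block = 0; k_block < K_BLOCKS; ++k_block) {
    if (k_block == K_BLOCKS - 1) {
      std::printf("  prefetch A k_block 0 of stage %d\n", next_stage);
    } else {
      std::printf("  prefetch A k_block %d of stage %d\n", k_block + 1, read_stage);
    }
    std::printf("  issue MMA on k_block %d (at most two batches outstanding)\n", k_block);
    if (k_block == 1) {
      std::printf("  release smem barrier of the previous stage\n");
    }
  }
}

int main() {
  for (int tile = 0; tile < 2; ++tile) {
    std::printf("k_tile %d:\n", tile);
    consume_k_tile(/*read_stage=*/tile % 2, /*next_stage=*/(tile + 1) % 2);
  }
}
```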
+ // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + if (k_block == size<2>(tCrA) - 1) { + if (!skip_wait) { + pipeline.consumer_wait(smem_pipe_read); + } + copy(smem_tiled_copy_A, tCsA(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); + // transpose B operand in SMEM + transpose(sB, gmma_sB, smem_pipe_read.index(), 0); + } else { + copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + // transpose B operand in SMEM + if (k_block < 2) { + transpose.synchronize(k_block); // make transpose of k_block available + } + if (k_block == 0) { + transpose(sB, gmma_sB, read_stage, 1); + } + } + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + if (k_block == 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + warpgroup_fence_operand(accum); + + } + + warpgroup_fence_operand(accum); + + if (k_tile_count > 0) { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA) - 1; ++k_block) { + copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + if (k_block < 2) { + transpose.synchronize(k_block); // make k_block transpose available + } + if (k_block == 0) { + transpose(sB, gmma_sB, read_stage, 1); + } + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + if (k_block == 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + warpgroup_fence_operand(accum); + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp deleted file mode 100644 index b842eace70..0000000000 --- 
a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss.hpp +++ /dev/null @@ -1,609 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
- * - **************************************************************************************************/ -#pragma once - -#include "cutlass/cutlass.h" -#include "cutlass/gemm/dispatch_policy.hpp" -#include "cutlass/pipeline/pipeline.hpp" -#include "cute/arch/cluster_sm90.hpp" -#include "cutlass/arch/reg_reconfig.h" - -#include "cute/arch/copy_sm90.hpp" -#include "cute/atom/mma_atom.hpp" -#include "cute/algorithm/gemm.hpp" - -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass::gemm::collective { -using namespace cute; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int Stages, - class ClusterShape, - class TileShape_, - class ElementA_, - class StrideA_, - class ElementB_, - class StrideB_, - class TiledMma_, - class GmemTiledCopyA_, - class SmemLayoutAtomA_, - class SmemCopyAtomA_, - class TransformA_, - class GmemTiledCopyB_, - class SmemLayoutAtomB_, - class SmemCopyAtomB_, - class TransformB_> -struct CollectiveMma< - MainloopSm90CpAsyncGmmaUnpredicated, - TileShape_, - ElementA_, - StrideA_, - ElementB_, - StrideB_, - TiledMma_, - GmemTiledCopyA_, - SmemLayoutAtomA_, - SmemCopyAtomA_, - TransformA_, - GmemTiledCopyB_, - SmemLayoutAtomB_, - SmemCopyAtomB_, - TransformB_> -{ - // - // Type Aliases - // - using DispatchPolicy = MainloopSm90CpAsyncGmmaUnpredicated; - using TileShape = TileShape_; - using ElementA = ElementA_; - using StrideA = StrideA_; - using ElementB = ElementB_; - using StrideB = StrideB_; - using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; - using SmemLayoutAtomA = SmemLayoutAtomA_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using SmemCopyAtomA = SmemCopyAtomA_; - using SmemCopyAtomB = SmemCopyAtomB_; - using TransformA = TransformA_; - using TransformB = TransformB_; - using ArchTag = typename DispatchPolicy::ArchTag; - - static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - using SmemLayoutA = decltype(tile_to_shape( - SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); - using SmemLayoutB = decltype(tile_to_shape( - SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); - - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(cute::is_base_of::value && - cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - - struct SharedStorage - { - cute::array_aligned> smem_a; - cute::array_aligned> smem_b; - }; - - struct Arguments { - ElementA const* ptr_A; - StrideA dA; - ElementB const* ptr_B; - StrideB dB; - }; - - using Params = Arguments; - - // - // Methods - // - - CollectiveMma() = default; - - template 
- static constexpr Params - to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { - (void) workspace; - return args; - } - - /// Perform a collective-scoped matrix multiply-accumulate - template < - class TensorA, - class TensorB, - class FrgTensorC, - class KTileIterator, - class ResidueMNK - > - CUTLASS_DEVICE void - operator() ( - TensorA gA, - TensorB gB, - FrgTensorC& accum, - KTileIterator k_tile_iter, int k_tile_count, - ResidueMNK residue_mnk, - int thread_idx, - char *smem_buf, - Params const& mainloop_params) - { - using namespace cute; - - (void) residue_mnk; - - static_assert(is_gmem::value, "A tensor must be gmem resident."); - static_assert(is_gmem::value, "B tensor must be gmem resident."); - static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); - static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); - static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::is_same::value, - "SM90 warpgroup MMA must specify transforms through MMA_Atom."); - static_assert(cute::is_same::value, - "SM90 warpgroup MMA must specify transforms through MMA_Atom."); - static_assert(cute::is_same::value, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(cute::is_same::value, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - - SharedStorage& storage = *reinterpret_cast(smem_buf); - Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - - // Partition the copying of A and B tiles across the threads - GmemTiledCopyA gmem_tiled_copy_a; - GmemTiledCopyB gmem_tiled_copy_b; - auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); - auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); - - Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) - Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) - Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) - Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) - - // Tile MMA atom and compute thread partitions across A, B and C - TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - - // Allocate registers for pipelining - Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) - - Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE) - Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE) - - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N - CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tAsA)); // PIPE - CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tBsB)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE - - // - // Prologue - // - - CUTLASS_PRAGMA_UNROLL - for (int k_pipe = 0; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { - copy(gmem_tiled_copy_a, 
tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); - copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); - cp_async_fence(); - ++k_tile_iter; - --k_tile_count; - } - - // Current pipe index in smem to read from - int smem_pipe_read = 0; - // Current pipe index in smem to write to - int smem_pipe_write = DispatchPolicy::Stages-1; - - // - // Pipelined Main Loop - // - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) - { - // Copy gmem to smem before computing gemm on each k-pipe - // pipe index in smem where the next gmem tile will be read into - copy(gmem_tiled_copy_a, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); - copy(gmem_tiled_copy_b, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); - cp_async_fence(); - if (k_tile_count > 0) { ++k_tile_iter; } - - // - // Compute on k_tile - // - warpgroup_fence_operand(accum); - warpgroup_arrive(); - - cp_async_wait(); - cute::gemm(tiled_mma, tCrA(_,_,_,smem_pipe_read), tCrB(_,_,_,smem_pipe_read), accum); - warpgroup_commit_batch(); - - // - // Advance the pipe - // - ++smem_pipe_read; - smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? smem_pipe_read = 0 : smem_pipe_read; - - ++smem_pipe_write; - smem_pipe_write = (smem_pipe_write == DispatchPolicy::Stages) ? smem_pipe_write = 0 : smem_pipe_write; - - // Wait for the pipeline MMAs to drain - warpgroup_wait<0>(); - warpgroup_fence_operand(accum); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - int Stages, - class ClusterShape, - class TileShape_, - class ElementA_, - class StrideA_, - class ElementB_, - class StrideB_, - class TiledMma_, - class GmemTiledCopyA_, - class SmemLayoutAtomA_, - class SmemCopyAtomA_, - class TransformA_, - class GmemTiledCopyB_, - class SmemLayoutAtomB_, - class SmemCopyAtomB_, - class TransformB_> -struct CollectiveMma< - MainloopSm90CpAsyncGmma, - TileShape_, - ElementA_, - StrideA_, - ElementB_, - StrideB_, - TiledMma_, - GmemTiledCopyA_, - SmemLayoutAtomA_, - SmemCopyAtomA_, - TransformA_, - GmemTiledCopyB_, - SmemLayoutAtomB_, - SmemCopyAtomB_, - TransformB_> -{ - // - // Type Aliases - // - using DispatchPolicy = MainloopSm90CpAsyncGmma; - using TileShape = TileShape_; - using ElementA = ElementA_; - using StrideA = StrideA_; - using ElementB = ElementB_; - using StrideB = StrideB_; - using TiledMma = TiledMma_; - using ElementAccumulator = typename TiledMma::ValTypeC; - using GmemTiledCopyA = GmemTiledCopyA_; - using GmemTiledCopyB = GmemTiledCopyB_; - using SmemLayoutAtomA = SmemLayoutAtomA_; - using SmemLayoutAtomB = SmemLayoutAtomB_; - using SmemCopyAtomA = SmemCopyAtomA_; - using SmemCopyAtomB = SmemCopyAtomB_; - using TransformA = TransformA_; - using TransformB = TransformB_; - using ArchTag = typename DispatchPolicy::ArchTag; - - static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - - static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); - static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); - static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); 
- - using SmemLayoutA = decltype(tile_to_shape( - SmemLayoutAtomA{}, - make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); - using SmemLayoutB = decltype(tile_to_shape( - SmemLayoutAtomB{}, - make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); - - static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); - static_assert(cute::is_base_of::value && - cute::is_base_of::value, - "MMA atom must source both A and B operand from smem_desc for this mainloop."); - - struct SharedStorage - { - cute::array_aligned> smem_a; - cute::array_aligned> smem_b; - }; - - struct Arguments { - ElementA const* ptr_A; - StrideA dA; - ElementB const* ptr_B; - StrideB dB; - }; - - using Params = Arguments; - - // - // Methods - // - - CollectiveMma() = default; - - template - static constexpr Params - to_underlying_arguments(ProblemShape const& _, Arguments const& args, void* workspace) { - (void) workspace; - return args; - } - - template - CUTLASS_HOST_DEVICE static bool - can_implement( - [[maybe_unused]] ProblemShape const& problem_shape, - [[maybe_unused]] Arguments const& args) { - return true; - } - - /// Perform a collective-scoped matrix multiply-accumulate - template < - class FrgTensorD, - class TensorA, - class TensorB, - class FrgTensorC, - class KTileIterator, - class ResidueMNK - > - CUTLASS_DEVICE void - operator() ( - FrgTensorD &accum, - TensorA gA_in, - TensorB gB_in, - FrgTensorC const &src_accum, - KTileIterator k_tile_iter, int k_tile_count, - ResidueMNK residue_mnk, - int thread_idx, - char *smem_buf) - { - using namespace cute; - - static_assert(is_rmem::value, "D tensor must be rmem resident."); - static_assert(is_gmem::value, "A tensor must be gmem resident."); - static_assert(is_gmem::value, "B tensor must be gmem resident."); - static_assert(is_rmem::value, "C tensor must be rmem resident."); - static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2."); - static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2."); - static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); - static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); - static_assert(cute::is_same::value, - "SM90 warpgroup MMA must specify transforms through MMA_Atom."); - static_assert(cute::is_same::value, - "SM90 warpgroup MMA must specify transforms through MMA_Atom."); - static_assert(cute::is_same::value, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - static_assert(cute::is_same::value, - "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); - - SharedStorage& storage = *reinterpret_cast(smem_buf); - Tensor sA = make_tensor(make_smem_ptr(storage.smem_a.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) - Tensor sB = make_tensor(make_smem_ptr(storage.smem_b.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) - - // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) - // This aligns the tensor with BLK_K for all but the 0th k_tile - Tensor gA = domain_offset(make_coord(0, get<2>(residue_mnk), 0), gA_in); - Tensor gB = domain_offset(make_coord(0, get<2>(residue_mnk), 0), gB_in); - - // Partition the copying of A and B tiles across the threads - GmemTiledCopyA gmem_tiled_copy_a; - GmemTiledCopyB gmem_tiled_copy_b; - auto gmem_thr_copy_a = gmem_tiled_copy_a.get_slice(thread_idx); - auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); - - Tensor tAgA = 
gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) - Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) - Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) - Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) - - // - // PREDICATES - // - - // Allocate predicate tensors for m and n - Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); - Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); - - // Construct identity layout for sA and sB - Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) - Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) - - // Repeat the partitioning with identity layouts - Tensor tAcA = gmem_thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) - Tensor tBcB = gmem_thr_copy_b.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) - - // Set predicates for m bounds - CUTLASS_PRAGMA_UNROLL - for (int m = 0; m < size<0>(tApA); ++m) { - tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m - } - // Set predicates for n bounds - CUTLASS_PRAGMA_UNROLL - for (int n = 0; n < size<0>(tBpB); ++n) { - tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n - } - - // - // Prologue/PREFETCH - // - - // Clear the smem tiles to account for predicated off loads - clear(tAsA); - clear(tBsB); - - // Start async loads for 0th k-tile, where we take care of the k residue - { - constexpr int k_pipe = 0; - - Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < size<2>(tAsA); ++k) { - if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) - copy_if(gmem_tiled_copy_a, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,k_pipe)); - } - } - Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); - CUTLASS_PRAGMA_UNROLL - for (int k = 0; k < size<2>(tBsB); ++k) { - if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) - copy_if(gmem_tiled_copy_b, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,k_pipe)); - } - } - cp_async_fence(); - ++k_tile_iter; - --k_tile_count; - } - - // Start async loads for 1st k-tile onwards, no k-residue handling needed - CUTLASS_PRAGMA_UNROLL - for (int k_pipe = 1; k_pipe < DispatchPolicy::Stages-1; ++k_pipe) { - if (k_tile_count <= 0) { - clear(tApA); - clear(tBpB); - } - copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,k_pipe)); // CpAsync - copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,k_pipe)); // CpAsync - cp_async_fence(); - ++k_tile_iter; - --k_tile_count; - } - - // - // MMA Atom partitioning - // - - // Tile MMA atom and compute thread partitions across A, B and C - TiledMma tiled_mma; - auto thr_mma = tiled_mma.get_thread_slice(thread_idx); - - // Allocate registers for pipelining - Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) - Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) - - Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_N,MMA_K,PIPE) - Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_N,PIPE) - - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(src_accum)); // M - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N - CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(src_accum)); // N - 
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE - CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tAsA)); // PIPE - CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tBsB)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE - CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE - - // Current pipe index in smem to read from - int smem_pipe_read = 0; - // Current pipe index in smem to write to - int smem_pipe_write = DispatchPolicy::Stages-1; - - // - // Pipelined Main Loop - // - CUTLASS_PRAGMA_NO_UNROLL - for ( ; k_tile_count > -(DispatchPolicy::Stages-1); --k_tile_count) - { - // - // Copy gmem to smem for *k_tile_iter - // - if (k_tile_count <= 0) { - clear(tApA); - clear(tBpB); - } - copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,smem_pipe_write)); // CpAsync - copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,smem_pipe_write)); // CpAsync - cp_async_fence(); - ++k_tile_iter; - - // - // Compute on k_tile - // - warpgroup_fence_operand(accum); - warpgroup_arrive(); - - cp_async_wait(); - cute::gemm(tiled_mma, accum, tCrA(_,_,_,smem_pipe_read), tCrB(_,_,_,smem_pipe_read), src_accum); - warpgroup_commit_batch(); - - // - // Advance the pipe - // - ++smem_pipe_read; - smem_pipe_read = (smem_pipe_read == DispatchPolicy::Stages) ? smem_pipe_read = 0 : smem_pipe_read; - - ++smem_pipe_write; - smem_pipe_write = (smem_pipe_write == DispatchPolicy::Stages) ? smem_pipe_write = 0 : smem_pipe_write; - - // Wait for the pipeline MMAs to drain - warpgroup_wait<0>(); - warpgroup_fence_operand(accum); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace cutlass::gemm::collective - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp new file mode 100644 index 0000000000..1b74153f46 --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_multistage_gmma_ss_warpspecialized.hpp @@ -0,0 +1,483 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape_, + class TileShape_, + class KernelSchedule, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90CpAsyncGmmaWarpSpecialized, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90CpAsyncGmmaWarpSpecialized; + using TileShape = TileShape_; + using ClusterShape = ClusterShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineAsync; + using PipelineState = typename MainloopPipeline::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) 
% size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}))); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}))); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, + class TensorB, + class KTileIterator, + class ResidueMNK + > + CUTLASS_DEVICE void + load( + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_in, + TensorB const& gB_in, + KTileIterator k_tile_iter, int k_tile_count, + ResidueMNK residue_mnk, + int thread_idx, + TensorStorage& shared_tensors) + { + using namespace cute; + + static_assert(is_gmem::value, "A tensor must be gmem resident."); + static_assert(is_gmem::value, "B tensor must be gmem resident."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // Shift tensor so residue_k is at origin (Can't read any k_coord < residue_k) + // This aligns the tensor with BLK_K for all but the 0th k_tile + Tensor gA = domain_offset(make_coord(0, get<2>(residue_mnk), 0), gA_in); + Tensor gB = domain_offset(make_coord(0, get<2>(residue_mnk), 0), gB_in); + + // Partition the copying of A and B tiles across the threads + GmemTiledCopyA gmem_tiled_copy_a; + GmemTiledCopyB gmem_tiled_copy_b; + auto gmem_thr_copy_a = 
gmem_tiled_copy_a.get_slice(thread_idx); + auto gmem_thr_copy_b = gmem_tiled_copy_b.get_slice(thread_idx); + + Tensor tAgA = gmem_thr_copy_a.partition_S(gA); // (ACPY,ACPY_M,ACPY_K,k) + Tensor tAsA = gmem_thr_copy_a.partition_D(sA); // (ACPY,ACPY_M,ACPY_K,PIPE) + Tensor tBgB = gmem_thr_copy_b.partition_S(gB); // (BCPY,BCPY_N,BCPY_K,k) + Tensor tBsB = gmem_thr_copy_b.partition_D(sB); // (BCPY,BCPY_N,BCPY_K,PIPE) + + // Allocate predicate tensors for m and n + Tensor tApA = make_tensor(make_shape(size<1>(tAsA), size<2>(tAsA)), Stride<_1,_0>{}); + Tensor tBpB = make_tensor(make_shape(size<1>(tBsB), size<2>(tBsB)), Stride<_1,_0>{}); + + // Construct identity layout for sA and sB + Tensor cA = make_identity_tensor(make_shape(size<0>(sA), size<1>(sA))); // (BLK_M,BLK_K) -> (blk_m,blk_k) + Tensor cB = make_identity_tensor(make_shape(size<0>(sB), size<1>(sB))); // (BLK_N,BLK_K) -> (blk_n,blk_k) + + // Repeat the partitioning with identity layouts + Tensor tAcA = gmem_thr_copy_a.partition_S(cA); // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k) + Tensor tBcB = gmem_thr_copy_b.partition_S(cB); // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k) + + // Set predicates for m bounds + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<0>(tApA); ++m) { + tApA(m,0) = get<0>(tAcA(0,m,0)) < get<0>(residue_mnk); // blk_m coord < residue_m + } + // Set predicates for n bounds + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<0>(tBpB); ++n) { + tBpB(n,0) = get<0>(tBcB(0,n,0)) < get<1>(residue_mnk); // blk_n coord < residue_n + } + + // 0-th stage with predication on k to account for residue + { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + int write_stage = smem_pipe_write.index(); + + // Copy gmem to smem for *k_tile_iter, predicating for k residue + Tensor tAgAk = tAgA(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tAsA); ++k) { + if (get<1>(tAcA(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gA shifted) + copy_if(gmem_tiled_copy_a, tApA(_,k), tAgAk(_,_,k), tAsA(_,_,k,write_stage)); + } + else { + clear(tAsA(_,_,k,write_stage)); + } + } + Tensor tBgBk = tBgB(_,_,_,*k_tile_iter); + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < size<2>(tBsB); ++k) { + if (get<1>(tBcB(0,0,k)) >= -get<2>(residue_mnk)) { // blk_k coord < residue_k (gB shifted) + copy_if(gmem_tiled_copy_b, tBpB(_,k), tBgBk(_,_,k), tBsB(_,_,k,write_stage)); + } + else { + clear(tBsB(_,_,k,write_stage)); + } + } + ++k_tile_iter; + --k_tile_count; + + // UNLOCK smem_pipe_write + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + + // Advance smem_pipe_write + ++smem_pipe_write; + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + int write_stage = smem_pipe_write.index(); + + // Copy gmem to smem for *k_tile_iter + copy_if(gmem_tiled_copy_a, tApA, tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy_if(gmem_tiled_copy_b, tBpB, tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // UNLOCK smem_pipe_write + pipeline.producer_commit(smem_pipe_write, cutlass::arch::cpasync_barrier_arrive); + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail( + MainloopPipeline pipeline, + PipelineState smem_pipe_write) { + // Issue the epilogue waits + /* This helps avoid early exit of blocks in Cluster 
+ * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) + { + using namespace cute; + + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) { + + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + warpgroup_arrive(); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + + // 
WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + warpgroup_arrive(); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp index 6fe3f4565c..2928192b42 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized.hpp @@ -59,7 +59,6 @@ template < int Stages, class ClusterShape, class KernelSchedule, - int PipelineAsyncMmaStages, class TileShape_, class ElementA_, class StrideA_, @@ -75,7 +74,7 @@ template < class SmemCopyAtomB_, class TransformB_> struct CollectiveMma< - MainloopSm90TmaGmmaRmemAWarpSpecialized, + MainloopSm90TmaGmmaRmemAWarpSpecialized, TileShape_, ElementA_, StrideA_, @@ -94,7 +93,7 @@ struct CollectiveMma< // // Type Aliases // - using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecialized; + using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecialized; using TileShape = TileShape_; using ElementA = ElementA_; using StrideA = StrideA_; @@ -190,9 +189,19 @@ struct CollectiveMma< static_assert(!SwapAB || !TransposeB, "Cannot SwapAB and TransposeB at the same time."); static_assert(TransposeB xor (cute::is_same_v), "Should be same layout if not TransposeB."); - static_assert(!TransposeB || size<1>(SmemLayoutB{}) * sizeof(InternalElementB) == 128, + static_assert(!TransposeB || (((size<1>(SmemLayoutB{}) * sizeof_bits::value)) / 8) == 128, "SmemLayoutB K must be 128bytes to be transposed."); - static_assert(!transform::collective::detail::use_universal_transposition(), + + static constexpr bool uses_universal_transposition() { + if constexpr (TransposeB) { + return 
transform::collective::detail::use_universal_transposition(); + } + else { + return false; + } + } + + static_assert(!uses_universal_transposition(), "Warp specialized ARF kernels have not supported universal B transposition yet."); static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); @@ -216,10 +225,10 @@ struct CollectiveMma< // Host side kernel arguments struct Arguments { - ElementA const* ptr_A; - StrideA dA; - ElementB const* ptr_B; - StrideB dB; + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; uint32_t mma_promotion_interval = 4; }; @@ -321,11 +330,9 @@ struct CollectiveMma< } static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; - static constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; - static_assert(K_PIPE_MMAS == 0, "no MMA stage should be asynchronous for this mainloop for now."); static constexpr uint32_t TmaTransactionBytes = - (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(InternalElementA)))+ - (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(InternalElementB))); + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8 + + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8 ; /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE @@ -335,19 +342,45 @@ struct CollectiveMma< cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); } + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// that the tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
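// --- Illustration only, not part of the patch: the shape contract documented just above. ---
// A minimal, hedged sketch of what tile_input_tensors promises the kernel layer:
// local_tile with a deferred coordinate and a Step projection turns a full (m,k,l)
// tensor into a (BLK_M,BLK_K,m,k,l) view. The extents, the float element type, and the
// function name below are invented for the example; only the local_tile pattern mirrors
// the collective's own code.
#include <cute/tensor.hpp>

void tile_input_tensors_contract_sketch(float* ptr_A) {
  using namespace cute;
  int M = 512, K = 256, L = 1;
  Tensor mA_mkl = make_tensor(make_gmem_ptr(ptr_A), make_shape(M, K, L));  // (m,k,l), compact column-major

  auto tile_shape = Shape<_128, _256, _64>{};                              // (BLK_M, BLK_N, BLK_K)
  // Keep the M and K modes of the tiler (drop N via the Underscore step) and defer the slice.
  Tensor gA_mkl = local_tile(mA_mkl, tile_shape, make_coord(_,_,_), Step<_1, Underscore, _1>{});
  // gA_mkl : (BLK_M, BLK_K, m, k, l) = (_128, _64, 4, 4, 1) for these extents.
  // The load() that follows then picks one (m,l) coordinate and walks the k mode, e.g.
  //   Tensor gA = gA_mkl(_,_,m_coord,_,l_coord);   // (BLK_M, BLK_K, k)
}
// --- End illustration. ---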
+ template + CUTLASS_DEVICE auto + tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective template < - class TensorA, class TMA_LOAD_A, - class TensorB, class TMA_LOAD_B, - class KTileIterator + class TensorA, class TensorB, + class KTileIterator, class BlockCoord > CUTLASS_DEVICE void load( + Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - TensorA const& gA, TMA_LOAD_A& tma_load_a, - TensorB const& gB, TMA_LOAD_B& tma_load_b, + cute::tuple const& tiled_tensors, + BlockCoord const& blk_coord, KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, @@ -372,8 +405,16 @@ struct CollectiveMma< constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); + Tensor gA_mkl = get<0>(tiled_tensors); + Tensor gB_nkl = get<1>(tiled_tensors); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) // Applies the mapping from block_tma_a Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) @@ -415,8 +456,8 @@ struct CollectiveMma< BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); int write_stage = smem_pipe_write.index(); - copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); ++k_tile_iter; // Advance smem_pipe_write @@ -508,9 +549,12 @@ struct CollectiveMma< auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA); // (CPY,CPY_M,CPY_K) + Tensor tCsA_copy_view = smem_thr_copy_A.partition_S(sA); // (CPY,CPY_M,CPY_K) CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCsA_copy_view) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA_copy_view) == size<2>(tCrA_copy_view)); // CPY_K CUTE_STATIC_ASSERT_V(size<1>(tCrA) == size<1>(accum)); // MMA_M CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K @@ -521,8 +565,6 @@ struct CollectiveMma< // // PIPELINED MAIN LOOP // - static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), - "ERROR : Incorrect number of MMAs in flight"); // We release buffers to producer warps(dma load) with some mmas in flight PipelineState smem_pipe_release = smem_pipe_read; @@ -548,43 +590,41 @@ struct CollectiveMma< barrier_token = pipeline.consumer_try_wait(smem_pipe_read); // copy smem->rmem for A operand - copy(smem_tiled_copy_A, tCsA(_,_,0,read_stage), tCrA_copy_view(_,_,0)); + copy(smem_tiled_copy_A, tCsA_copy_view(_,_,0,read_stage), tCrA_copy_view(_,_,0)); // transpose B operand in SMEM transpose(sB, gmma_sB, read_stage, 0); // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA) - 1; ++k_block) { - copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + copy(smem_tiled_copy_A, tCsA_copy_view(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); transpose.synchronize(k_block); transpose(sB, gmma_sB, read_stage, k_block + 1); warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; + if(k_block == 0) { + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } warpgroup_commit_batch(); } warpgroup_wait<2>(); - --k_tile_count; - if (k_tile_count > 0) { - pipeline.consumer_wait(smem_pipe_read, barrier_token); - copy(smem_tiled_copy_A, tCsA(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); - transpose(sB, gmma_sB, smem_pipe_read.index(), 0); - } warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; 
warpgroup_commit_batch(); + --k_tile_count; + if(k_tile_count == 0) { + return; + } + pipeline.consumer_wait(smem_pipe_read, barrier_token); + copy(smem_tiled_copy_A, tCsA_copy_view(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); + transpose(sB, gmma_sB, smem_pipe_read.index(), 0); warpgroup_wait<2>(); } - if (k_tile_count == 0) { - return; - } - warpgroup_fence_operand(accum); // Mainloop GMMAs CUTLASS_PRAGMA_NO_UNROLL @@ -606,12 +646,12 @@ struct CollectiveMma< } if (k_block == size<2>(tCrA) - 1) { pipeline.consumer_wait(smem_pipe_read, barrier_token); - copy(smem_tiled_copy_A, tCsA(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); + copy(smem_tiled_copy_A, tCsA_copy_view(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); // transpose B operand in SMEM transpose(sB, gmma_sB, smem_pipe_read.index(), 0); } else { - copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + copy(smem_tiled_copy_A, tCsA_copy_view(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); // transpose B operand in SMEM transpose.synchronize(k_block); // make transpose of k_block available transpose(sB, gmma_sB, read_stage, k_block + 1); @@ -620,7 +660,6 @@ struct CollectiveMma< warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); warpgroup_wait<2>(); if (k_block == 1) { @@ -647,8 +686,7 @@ struct CollectiveMma< // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int k_block = 0; k_block < size<2>(tCrA) - 1; ++k_block) { - - copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + copy(smem_tiled_copy_A, tCsA_copy_view(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); transpose.synchronize(k_block); // make k_block transpose available transpose(sB, gmma_sB, read_stage, k_block + 1); warpgroup_arrive(); @@ -667,7 +705,6 @@ struct CollectiveMma< warpgroup_arrive(); // (V,M) x (V,N) => (V,M,N) cute::gemm(tiled_mma, tCrA(_,_,size<2>(tCrA) - 1), tCrB(_,_,size<2>(tCrA) - 1,read_stage), accum); - tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); } diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp new file mode 100644 index 0000000000..f429507002 --- /dev/null +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp @@ -0,0 +1,830 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/detail/layout.hpp" + +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp" +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// WarpSpecialized Mainloop that source A operand from registers +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementAOptionalTuple, + class StrideAOptionalTuple, + class ElementBOptionalTuple, + class StrideBOptionalTuple, + class TiledMma_, + class GmemTiledCopyAOptionalTuple, + class SmemLayoutAtomAOptionalTuple, + class SmemCopyAtomAOptionalTuple, + class TransformA_, + class GmemTiledCopyBOptionalTuple, + class SmemLayoutAtomBOptionalTuple, + class SmemCopyAtomBOptionalTuple, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput, + TileShape_, + ElementAOptionalTuple, + StrideAOptionalTuple, + ElementBOptionalTuple, + StrideBOptionalTuple, + TiledMma_, + GmemTiledCopyAOptionalTuple, + SmemLayoutAtomAOptionalTuple, + SmemCopyAtomAOptionalTuple, + TransformA_, + GmemTiledCopyBOptionalTuple, + SmemLayoutAtomBOptionalTuple, + SmemCopyAtomBOptionalTuple, + TransformB_> +{ +private: + template + static constexpr auto + get_logical_ptr(PointerType const* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } + } + +public: + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput; + using TileShape = TileShape_; + + using ElementA = ElementAOptionalTuple; + using ElementB = ElementBOptionalTuple; + static constexpr bool IsATransformed = cute::sizeof_bits_v < 
cute::sizeof_bits_v; + using ElementScale = void; + + using StrideA = StrideAOptionalTuple; + using StrideB = StrideBOptionalTuple; + using StrideScale = void; + static constexpr int AlignmentScale = cute::Int<0>{}; + + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + + using GmemTiledCopyA = GmemTiledCopyAOptionalTuple; + using GmemTiledCopyB = GmemTiledCopyBOptionalTuple; + using GmemTiledCopyScale = void; + + using SmemLayoutAtomA = SmemLayoutAtomAOptionalTuple; + using SmemLayoutAtomB = SmemLayoutAtomBOptionalTuple; + using SmemLayoutAtomScale = void; + + using SmemCopyAtomA = SmemCopyAtomAOptionalTuple; + using SmemCopyAtomB = SmemCopyAtomBOptionalTuple; + using SmemCopyAtomScale = void; + + // Swap and transpose A/B for A k-major layout and B mn-major layout since WGMMA is k-major only (e.g. tf32, Fp32, Int8, Fp8 WGMMA) + static constexpr bool IsLayoutAkBmn = + cute::is_same_v, layout::RowMajor> && + cute::is_same_v, layout::RowMajor>; + + static constexpr bool IsInputSizeTwoBytes = sizeof(ElementA) == 2 && sizeof(ElementB) == 2; + + // We must ensure the type to be scaled goes to RF + static constexpr bool SwapAB = !IsATransformed; + using InternalSmemLayoutAtomA = cute::conditional_t; + using InternalSmemLayoutAtomB = cute::conditional_t; + using InternalSmemCopyAtomA = cute::conditional_t; + using InternalSmemCopyAtomB = cute::conditional_t; + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using ConvertedElementA = cute::conditional_t>>; + using ConvertedElementB = cute::conditional_t>>; + using RealInternalElementA = cute::conditional_t; + using RealInternalElementB = cute::conditional_t; + using InternalElementA = cute::conditional_t; + using InternalElementB = cute::conditional_t; + using InternalStrideA = cute::conditional_t; + using InternalStrideB = cute::conditional_t; + + using TransformA = TransformA_; + using TransformB = TransformB_; + using InternalTransformA = cute::conditional_t; + using InternalTransformB = cute::conditional_t; + + static_assert(sizeof(InternalElementB) == 2 || + (cute::is_same_v, layout::RowMajor> && + cute::is_same_v, layout::ColumnMajor>), + "B operand after swap must be 2 bytes OR K-major."); + static constexpr int IsSubbyteA = cute::sizeof_bits_v < 8; + using TmaElementA = cute::conditional_t; + + using ArchTag = typename DispatchPolicy::ArchTag; + + using MainloopPipeline = cutlass::PipelineTmaAsync< + DispatchPolicy::Stages, + typename DispatchPolicy::ClusterShape>; + using PipelineState = cutlass::PipelineState; + + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(InternalSmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile 
shape."); + + static_assert(cute::is_same_v || cute::is_same_v, + "The TMA mcast for A must match the mcast for scales or the scale tiled copy must be void."); + // Tile along modes in a way that maximizes the TMA box size. + using SmemLayoutA = decltype(tile_to_shape( + InternalSmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + InternalSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + // If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major only (e.g. tf32, fp32, fp8, int8). + static constexpr bool IsLayoutAmnBmn = + cute::is_same_v, layout::ColumnMajor> && + cute::is_same_v, layout::RowMajor>; + static constexpr bool TransposeB = !IsInputSizeTwoBytes && IsLayoutAmnBmn; + using TransposeOperandB = decltype(cutlass::transform::collective::detail::make_transpose_operand_b( + 0, 0, TiledMma{}, SmemLayoutB{}, InternalSmemLayoutAtomB{}, + InternalElementB{}, cute::bool_constant{})); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(not cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source A from rmem and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + using GmmaSmemLayoutAtomB = decltype(transform::collective::detail::gmma_smem_transpose_or_passthrough< + TransposeB, InternalSmemLayoutAtomB, InternalElementB>()); + + // SmemLayoutB for GMMA is different from SmemLayoutB for TMA if TransposeB + using GmmaSmemLayoutB = decltype(tile_to_shape( + GmmaSmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + conditional_t< ::cutlass::gemm::detail::is_major<0,InternalStrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(!SwapAB || !TransposeB, "Cannot SwapAB and TransposeB at the same time."); + static_assert(TransposeB xor (cute::is_same_v), + "Should be same layout if not TransposeB."); + static_assert(!TransposeB || size<1>(SmemLayoutB{}) * cute::sizeof_bits_v / 8 == 128, + "SmemLayoutB K must be 128bytes to be transposed."); + + static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); + + static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); + + static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment"); + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct { + cute::ArrayEngine> smem_A; + cute::ArrayEngine> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A = nullptr; + StrideA dA{}; + ElementB const* ptr_B = nullptr; + StrideB dB{}; + uint32_t mma_promotion_interval = 4; + }; + + // Device side kernel params + struct Params { + private: + using Outer = 
CollectiveMma; + + public: + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy( + GmemTiledCopyA{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), repeat_like(InternalStrideA{}, int32_t(0)), InternalStrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy( + GmemTiledCopyB{}, + make_tensor(Outer::get_logical_ptr(static_cast(nullptr)), repeat_like(InternalStrideB{}, int32_t(0)), InternalStrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any + TMA_A tma_load_a; + TMA_B tma_load_b; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + if constexpr (SwapAB) { + M = get<1>(problem_shape_MNKL); + N = get<0>(problem_shape_MNKL); + } + + InternalElementA const* ptr_A; + InternalStrideA dA; + InternalElementB const* ptr_B; + InternalStrideB dB; + + if constexpr (not SwapAB) { + ptr_A = reinterpret_cast(args.ptr_A); + ptr_B = reinterpret_cast(args.ptr_B); + dA = args.dA; + dB = args.dB; + } + else { + ptr_A = reinterpret_cast(args.ptr_B); + ptr_B = reinterpret_cast(args.ptr_A); + dA = args.dB; + dB = args.dA; + } + + Tensor tensor_a = make_tensor(get_logical_ptr(ptr_A), make_layout(make_shape(M,K,L), dA)); + Tensor tensor_b = make_tensor(get_logical_ptr(ptr_B), make_layout(make_shape(N,K,L), dB)); + typename Params::TMA_A tma_load_a = make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); // mcast along N mode for this M load, if any + typename Params::TMA_B tma_load_b = make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); // mcast along M mode for this N load, if any + return { + tma_load_a, + tma_load_b + }; + } + + template + CUTLASS_HOST_DEVICE static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + bool implementable = true; + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + return implementable; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr uint32_t TmaTransactionBytes = + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * 
static_cast(cute::sizeof_bits_v) / 8)+ + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v) / 8); + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// that the tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_write, + cute::tuple const& tiled_tensors, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + + using namespace cute; + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + if (warp_idx_in_warp_group == 0 and lane_predicate) { + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + + constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + Tensor gA_mkl = get<0>(tiled_tensors); + Tensor gB_nkl = get<1>(tiled_tensors); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the 
current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % 4; + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (warp_idx_in_warp_group == 0 and lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + using namespace cute; + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(rank(InternalSmemLayoutAtomA{}) == 2, "InternalSmemLayoutAtomA must be rank 2."); + static_assert(rank(InternalSmemLayoutAtomB{}) == 2, "InternalSmemLayoutAtomB must be rank 2."); + static_assert(!cute::is_void_v, + "SM90 GMMA mainloops must specify a non-void copy atom for RF sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a 
non-void copy atom for smem sourced instructions."); + + // Obtain warp index + int warp_idx = canonical_warp_idx_sync(); + [[maybe_unused]] int warp_group_thread_idx = thread_idx % 128; + + Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE) + + Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_M,BLK_K,PIPE) + + // If TransposeB, GMMA will read from transposed B layout SMEM + Tensor gmma_sB_position_dependent = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), + GmmaSmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + Tensor gmma_sB = as_position_independent_swizzle_tensor(gmma_sB_position_dependent); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsA = thread_mma.partition_A(sA); + + // Allocate fragments and descriptors + Tensor tCrA_mma = thread_mma.partition_fragment_A(sA(_,_,Int<0>{})); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrA_load = make_fragment_like(tCrA_mma); + + Tensor tCsB = thread_mma.partition_B(gmma_sB_position_dependent); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // Copy Atom A retiling + // + + auto smem_tiled_copy_A = make_tiled_copy_A(InternalSmemCopyAtomA{}, tiled_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(warp_group_thread_idx); + + Tensor tCrA_copy_view = smem_thr_copy_A.retile_D(tCrA_load); // (CPY,CPY_M,CPY_K) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrA_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrA_copy_view)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + TransposeOperandB transpose = cutlass::transform::collective::detail::make_transpose_operand_b( + warp_idx, warp_group_thread_idx, tiled_mma, SmemLayoutB{}, + InternalSmemLayoutAtomB{}, InternalElementB{}, + cute::bool_constant{}); + + warpgroup_fence_operand(accum); + + ConsumerToken barrier_token = {BarrierStatus::WaitAgain}; + // first k tile + { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + + ++smem_pipe_read; + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + + // copy smem->rmem for A operand + copy(smem_tiled_copy_A, tCsA(_,_,0,read_stage), tCrA_copy_view(_,_,0)); + transform_internal_A(tCrA_load(_, _, 0), tCrA_mma(_, _, 0)); + // transpose B operand in SMEM + transpose(sB, gmma_sB, read_stage, 0); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA_load) - 1; ++k_block) { + copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), 
tCrA_copy_view(_,_,k_block + 1)); + transform_internal_A(tCrA_load(_, _, k_block + 1), tCrA_mma(_, _, k_block + 1)); + transpose.synchronize(k_block); + transpose(sB, gmma_sB, read_stage, k_block + 1); + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + } + + warpgroup_wait<2>(); + + --k_tile_count; + if (k_tile_count > 0) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + copy(smem_tiled_copy_A, tCsA(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); + transform_internal_A(tCrA_load(_, _, 0), tCrA_mma(_, _, 0)); + transpose(sB, gmma_sB, smem_pipe_read.index(), 0); + } + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + const int final_k = size<2>(tCrA_load) - 1; + cute::gemm(tiled_mma, tCrA_mma(_,_, final_k), tCrB(_,_,final_k,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + } + + if (k_tile_count == 0) { + return; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 1; --k_tile_count) { + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + ++smem_pipe_read; + + warpgroup_fence_operand(accum); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA_load); ++k_block) { + if (k_block == 0) { + barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + } + if (k_block == size<2>(tCrA_load) - 1) { + pipeline.consumer_wait(smem_pipe_read, barrier_token); + copy(smem_tiled_copy_A, tCsA(_,_,0,smem_pipe_read.index()), tCrA_copy_view(_,_,0)); + transform_internal_A(tCrA_load(_, _, 0), tCrA_mma(_, _, 0)); + // transpose B operand in SMEM + transpose(sB, gmma_sB, smem_pipe_read.index(), 0); + } + else { + copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + transform_internal_A(tCrA_load(_, _, k_block + 1), tCrA_mma(_, _, k_block + 1)); + // transpose B operand in SMEM + transpose.synchronize(k_block); // make transpose of k_block available + transpose(sB, gmma_sB, read_stage, k_block + 1); + } + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + warpgroup_wait<2>(); + if (k_block == 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + warpgroup_fence_operand(accum); + + } + + warpgroup_fence_operand(accum); + + { + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + + warpgroup_fence_operand(accum); + + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA_load) - 1; ++k_block) { + + copy(smem_tiled_copy_A, tCsA(_,_,k_block + 1,read_stage), tCrA_copy_view(_,_,k_block + 1)); + transform_internal_A(tCrA_load(_, _, k_block + 1), tCrA_mma(_, _, k_block + 1)); + transpose.synchronize(k_block); // make k_block transpose available + transpose(sB, gmma_sB, read_stage, k_block + 1); + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA_mma(_,_,k_block), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + 
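// ---------------------------------------------------------------------------
// [Editor's note -- illustrative sketch, not part of this diff or of CUTLASS]
// The unrolled K-block loops in this mainloop release a shared-memory stage
// back to the producer only once the *next* tile's math has been issued
// (the `if (k_block == 1) consumer_release(...)` pattern), so one stage of
// MMAs stays in flight. A host-side sketch of that ordering, assuming a
// hypothetical 4-stage pipeline; all counts below are made up:
#include <cstdio>

int main() {
  const int stages   = 4;  // assumed pipeline depth
  const int k_tiles  = 6;  // assumed K tiles to consume
  const int k_blocks = 4;  // assumed MMAs (K blocks) per tile
  int read = 0, release = 0;

  for (int tile = 0; tile < k_tiles; ++tile) {
    std::printf("tile %d: wait on stage %d\n", tile, read);
    for (int k = 0; k < k_blocks; ++k) {
      // MMA for (tile, k) is issued here.
      if (k == 1 && tile > 0) {
        // The previous tile's stage is released only now, one tile late.
        std::printf("tile %d: release stage %d\n", tile, release);
        release = (release + 1) % stages;
      }
    }
    read = (read + 1) % stages;
  }
  return 0;  // the remaining stages are drained by mma_tail in the real kernel
}
// ---------------------------------------------------------------------------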
warpgroup_wait<2>(); + if (k_block == 1) { + // release prior barrier + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + + warpgroup_arrive(); + // (V,M) x (V,N) => (V,M,N) + const int final_k = size<2>(tCrA_load) - 1; + cute::gemm(tiled_mma, tCrA_mma(_,_,final_k), tCrB(_,_,final_k,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + warpgroup_commit_batch(); + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = 1; + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } + +private: + template + static constexpr bool + is_fast_converter_exact() { + using DstType = typename Converter::result_type; + using SrcType = typename Converter::source_type; + + constexpr bool IsIntToFP32Exact = cute::is_same_v && + (cute::numeric_limits::is_integer && cute::sizeof_bits_v <= 16); + + constexpr bool IsIntToFP16orBF16Exact = (cute::is_same_v || cute::is_same_v) && + (cute::numeric_limits::is_integer && cute::sizeof_bits_v <= 8); + + return IsIntToFP32Exact || IsIntToFP16orBF16Exact; + } + + template > + CUTLASS_DEVICE void + transform_internal_A(Tensor&& in, Tensor&& out) { + /// This is an element-wise conversion where we expect both tensors to have the same layout. + /// As a result, we can cast as a cutlass array to use the fast numeric converters without + /// worrying about indexing into the layout. + + /// The inputs must be backed by registers & be statically sized so we can unroll the conversion loops. 
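// ---------------------------------------------------------------------------
// [Editor's note -- illustrative sketch, not part of this diff or of CUTLASS]
// transform_internal_A below views the register-backed A fragment as a flat
// array and converts it element-wise into the MMA input type, using the fast
// converter only when the conversion is exact (e.g. an integer of <= 8 bits
// to f16/bf16, or <= 16 bits to f32). The same idea in plain C++; the types
// and fragment size here are assumptions:
#include <array>
#include <cstdint>
#include <cstdio>

template <class Dst, class Src, std::size_t N>
void convert_fragment(std::array<Src, N> const& in, std::array<Dst, N>& out) {
  // Both sides share one layout, so a plain element-wise loop is sufficient.
  for (std::size_t i = 0; i < N; ++i) {
    out[i] = static_cast<Dst>(in[i]);
  }
}

int main() {
  std::array<std::int8_t, 8> a_src{-2, -1, 0, 1, 2, 3, 4, 5};  // loaded A data
  std::array<float, 8>       a_mma{};                          // MMA input type
  // int8 -> float is exact, so a faster conversion routine could be swapped in
  // without changing results -- that is what is_fast_converter_exact() gates.
  convert_fragment(a_src, a_mma);
  std::printf("%.1f ... %.1f\n", a_mma[0], a_mma[7]);
  return 0;
}
// ---------------------------------------------------------------------------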
+ static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(cute::is_same_v, "Input engine must be same type as the A operand"); + static_assert(cute::is_same_v, "Output engine must be same type as the Mma input"); + static_assert(is_static_v, "Tensor layout for the conversion must be static"); + + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + + constexpr cutlass::FloatRoundStyle RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using DefaultConverterAToB = cutlass::NumericArrayConverter; + using FastConverterAToB = cutlass::FastNumericArrayConverter; + + using ConverterAToB = cute::conditional_t(), FastConverterAToB, DefaultConverterAToB>; + + SrcArray* src_array_ptr = reinterpret_cast(raw_pointer_cast(in.data())); + DstArray* dst_array_ptr = reinterpret_cast(raw_pointer_cast(out.data())); + *dst_array_ptr = std::move(ConverterAToB::convert(*src_array_ptr)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp index 932765ea7a..af14f1378c 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss.hpp @@ -321,8 +321,8 @@ struct CollectiveMma< // Set the bytes transferred in this TMA transaction (may involve multiple issues) constexpr uint32_t TmaTransactionBytes = static_cast( - (size<0>(sA) * size<1>(sA) * sizeof(InternalElementA)) + - (size<0>(sB) * size<1>(sB) * sizeof(InternalElementB))); + (size<0>(sA) * size<1>(sA) * sizeof_bits::value) / 8 + + (size<0>(sB) * size<1>(sB) * sizeof_bits::value) / 8); // Obtain warp index diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp index c7dee7b1d4..b0656ca407 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp @@ -250,8 +250,8 @@ struct CollectiveMma< static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr int K_PIPE_MMAS = 1; static constexpr uint32_t TmaTransactionBytes = - (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(ElementA)))+ - (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(ElementB))); + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8+ + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8; /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE @@ -261,19 +261,45 @@ struct CollectiveMma< cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); } + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// that the tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective template < - class TensorA, class TMA_LOAD_A, - class TensorB, class TMA_LOAD_B, - class KTileIterator + class TensorA, class TensorB, + class KTileIterator, class BlockCoord > CUTLASS_DEVICE void load( + Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - TensorA const& gA, TMA_LOAD_A& tma_load_a, - TensorB const& gB, TMA_LOAD_B& tma_load_b, + cute::tuple const& tiled_tensors, + BlockCoord const& blk_coord, KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, @@ -296,8 +322,16 @@ struct CollectiveMma< constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); + Tensor gA_mkl = get<0>(tiled_tensors); + Tensor gB_nkl = get<1>(tiled_tensors); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) // Applies the mapping from block_tma_a Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) @@ -340,8 +374,8 @@ struct CollectiveMma< BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); int write_stage = smem_pipe_write.index(); - copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); ++k_tile_iter; // Advance smem_pipe_write diff --git a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp index 0e16027139..7a67b24561 100644 --- a/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp +++ b/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp @@ -248,8 +248,8 @@ struct CollectiveMma< static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; static constexpr int K_PIPE_MMAS = 1; static constexpr uint32_t TmaTransactionBytes = - (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(ElementA)))+ - (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(ElementB))); + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)) / 8+ + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)) / 8; /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance CUTLASS_DEVICE @@ -259,19 +259,45 @@ struct CollectiveMma< cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); } + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// that the tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. 
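// ---------------------------------------------------------------------------
// [Editor's note -- illustrative sketch, not part of this diff or of CUTLASS]
// The tile_input_tensors contract described above, reduced to std::tuple: the
// collective may append whatever extra tensors it needs, but the kernel layer
// only relies on the first two entries being the tiled A and B views. The
// placeholder types here are assumptions:
#include <tuple>
#include <type_traits>

struct TiledA {};  // stands in for gA_mkl, shape (BLK_M,BLK_K,m,k,l)
struct TiledB {};  // stands in for gB_nkl, shape (BLK_N,BLK_K,n,k,l)
struct Extra  {};  // e.g. a scale tensor a mixed-input collective might add

auto tile_input_tensors_like() {
  return std::make_tuple(TiledA{}, TiledB{}, Extra{});
}

int main() {
  auto tiled_tensors = tile_input_tensors_like();
  static_assert(std::tuple_size_v<decltype(tiled_tensors)> >= 2,
                "Output of tile_input_tensors must have at least two elements (A, B)");
  [[maybe_unused]] auto gA_mkl = std::get<0>(tiled_tensors);  // always A
  [[maybe_unused]] auto gB_nkl = std::get<1>(tiled_tensors);  // always B
  return 0;
}
// ---------------------------------------------------------------------------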
+ template + CUTLASS_DEVICE auto + tile_input_tensors(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params, TileShapeMNK const& tileshape_mnk) { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, tileshape_mnk, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, tileshape_mnk, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + /// Perform a collective-scoped matrix multiply-accumulate /// Producer Perspective template < - class TensorA, class TMA_LOAD_A, - class TensorB, class TMA_LOAD_B, - class KTileIterator + class TensorA, class TensorB, + class KTileIterator, class BlockCoord > CUTLASS_DEVICE void load( + Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write, - TensorA const& gA, TMA_LOAD_A& tma_load_a, - TensorB const& gB, TMA_LOAD_B& tma_load_b, + cute::tuple const& tiled_tensors, + BlockCoord const& blk_coord, KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster, @@ -294,8 +320,16 @@ struct CollectiveMma< constexpr uint32_t cluster_shape_x = get<0>(DispatchPolicy::ClusterShape()); uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; - auto block_tma_a = tma_load_a.get_slice(cluster_local_block_id.y); - auto block_tma_b = tma_load_b.get_slice(cluster_local_block_id.x); + Tensor gA_mkl = get<0>(tiled_tensors); + Tensor gB_nkl = get<1>(tiled_tensors); + + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) // Applies the mapping from block_tma_a Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) @@ -338,8 +372,8 @@ struct CollectiveMma< BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); int write_stage = smem_pipe_write.index(); - copy(tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); - copy(tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); ++k_tile_iter; // Advance smem_pipe_write diff --git a/include/cutlass/gemm/device/gemm_universal_adapter.h b/include/cutlass/gemm/device/gemm_universal_adapter.h index 896bff187d..33a2958c19 100644 --- a/include/cutlass/gemm/device/gemm_universal_adapter.h +++ b/include/cutlass/gemm/device/gemm_universal_adapter.h @@ -41,6 +41,7 @@ #include "cutlass/device_kernel.h" #include "cutlass/gemm/gemm.h" #include "cutlass/detail/layout.hpp" +#include "cutlass/detail/mma.hpp" #if !defined(__CUDACC_RTC__) #include "cutlass/cluster_launch.hpp" @@ -51,6 +52,7 @@ #include "cutlass/gemm/device/gemm_universal_base.h" #include "cutlass/gemm/kernel/gemm_transpose_operands.h" #include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor_callbacks.h" // 3.x #include "cutlass/gemm/kernel/gemm_universal.hpp" @@ -111,10 +113,7 @@ class GemmUniversalAdapter< // Legacy: Assume MultiplyAdd only since we do not use this tag type in 3.0 using MathOperator = cutlass::arch::OpMultiplyAdd; - // All tensorop operations have atom shape's M >= 8 - using OperatorClass = cute::conditional_t< - cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}) >= 8, - cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>; + using OperatorClass = cutlass::detail::get_operator_class_t; using ArchTag = typename GemmKernel::ArchTag; @@ -398,6 +397,7 @@ class GemmUniversalAdapter< using GemmKernel = GemmKernel_; static bool const kInternalTranspose = + !cutlass::epilogue::threadblock::detail::is_2x_evt_v && // 2.x EVT does not require internal transpose cute::is_same::value; using ThreadblockShape = typename GemmKernel::Mma::Shape; diff --git a/include/cutlass/gemm/dispatch_policy.hpp b/include/cutlass/gemm/dispatch_policy.hpp index f122fe0384..ea5261155e 100644 --- a/include/cutlass/gemm/dispatch_policy.hpp +++ b/include/cutlass/gemm/dispatch_policy.hpp @@ -43,23 +43,42 @@ using namespace cute; ////////////////////////////////////////////////////////////////////////////// // -// Policies for categorical dispatch of mainloop against kernel grid schedules +// Kernel schedule policies (the base class tags, one for each kernel layer file) // struct KernelMultistage { }; +struct KernelCpAsyncWarpSpecialized { }; +struct KernelCpAsyncWarpSpecializedPingpong { }; +struct KernelCpAsyncWarpSpecializedCooperative { }; struct KernelTma { }; struct KernelTmaWarpSpecialized { }; struct KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedCooperative { }; +////////////////////////////////////////////////////////////////////////////// + +// +// Builder dispatch policies (not a 
part of the main CUTLASS layers, simply used to opt into +// specific collective builder dispatches) +// + // FP8 related policies (including Fast Accumulation) struct KernelTmaWarpSpecializedFP8FastAccum : KernelTmaWarpSpecialized { }; struct KernelTmaWarpSpecializedPingpongFP8FastAccum : KernelTmaWarpSpecializedPingpong { }; struct KernelTmaWarpSpecializedCooperativeFP8FastAccum: KernelTmaWarpSpecializedCooperative { }; +// Policies to opt into mixed type GEMMs +struct KernelTmaWarpSpecializedMixedInput : KernelTmaWarpSpecialized { }; +struct KernelTmaWarpSpecializedPingpongMixedInput : KernelTmaWarpSpecializedPingpong { }; +struct KernelTmaWarpSpecializedCooperativeMixedInput: KernelTmaWarpSpecializedCooperative { }; + +////////////////////////////////////////////////////////////////////////////// + // Policies for dispatch of epilogue struct EpilogueDefault { }; struct EpilogueTransposed { }; +////////////////////////////////////////////////////////////////////////////// + // // Collective Mainloop Policies // @@ -98,28 +117,30 @@ struct MainloopSm80CpAsync { using ClusterShape = Shape<_1,_1,_1>; }; -// n-buffer in smem (cp.async), pipelined with Hopper GMMA, WITHOUT predicated gmem loads +// n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads, warp specialized dynamic schedule template< int Stages_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelCpAsyncWarpSpecialized > -struct MainloopSm90CpAsyncGmmaUnpredicated { +struct MainloopSm90CpAsyncGmmaWarpSpecialized { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; using ArchTag = arch::Sm90; - using Schedule = KernelMultistage; + using Schedule = KernelSchedule; }; -// n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads +// n-buffer in smem (cp.async), pipelined with Hopper GMMA, with predicated gmem loads, warp specialized dynamic schedule template< int Stages_, - class ClusterShape_ = Shape<_1,_1,_1> + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelCpAsyncWarpSpecialized > -struct MainloopSm90CpAsyncGmma { +struct MainloopSm90CpAsyncGmmaRmemAWarpSpecialized { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; using ArchTag = arch::Sm90; - using Schedule = KernelMultistage; + using Schedule = KernelSchedule; }; // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA @@ -154,13 +175,11 @@ struct MainloopSm90TmaGmmaWarpSpecialized { template< int Stages_, class ClusterShape_ = Shape<_1,_1,_1>, - class KernelSchedule = KernelTmaWarpSpecialized, - int PipelineAsyncMmaStages_ = 0 + class KernelSchedule = KernelTmaWarpSpecialized > struct MainloopSm90TmaGmmaRmemAWarpSpecialized { constexpr static int Stages = Stages_; using ClusterShape = ClusterShape_; - constexpr static int PipelineAsyncMmaStages = PipelineAsyncMmaStages_; using ArchTag = arch::Sm90; using Schedule = KernelSchedule; static_assert( @@ -170,6 +189,26 @@ struct MainloopSm90TmaGmmaRmemAWarpSpecialized { "KernelSchedule must be one of the warp specialized policies"); }; +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelTmaWarpSpecialized +> +struct MainloopSm90TmaGmmaRmemAWarpSpecializedMixedInput { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; + static_assert( + cute::is_same_v || + 
cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "KernelSchedule must be one of the warp specialized policies"); +}; + // n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, Warp specialized dynamic schedule // For FP8 kernels template< diff --git a/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h b/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h index 6d6714d805..aba6e7fd85 100644 --- a/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h +++ b/include/cutlass/gemm/kernel/gemm_streamk_with_fused_epilogue.h @@ -375,14 +375,17 @@ struct GemmStreamkWithFusedEpilogue // Initialize the block mapping structure block_mapping = ThreadblockSwizzle( - typename ThreadblockSwizzle::template KernelTraits(), args.mode, args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count, sm_occupancy, device_sms, - avail_sms); + avail_sms, + sizeof(ElementA), + sizeof(ElementB), + sizeof(ElementC), + Epilogue::kAccumulatorFragments); } /// Returns the workspace size (in bytes) needed for these parameters diff --git a/include/cutlass/gemm/kernel/gemm_universal.hpp b/include/cutlass/gemm/kernel/gemm_universal.hpp index 4e046ddd3e..7bd98ce7d4 100644 --- a/include/cutlass/gemm/kernel/gemm_universal.hpp +++ b/include/cutlass/gemm/kernel/gemm_universal.hpp @@ -69,6 +69,9 @@ class GemmUniversal; #include "cutlass/gemm/kernel/sm70_gemm.hpp" #include "cutlass/gemm/kernel/sm90_gemm_tma.hpp" +#include "cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp" +#include "cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp" +#include "cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp" #include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp" #include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp" #include "cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp" diff --git a/include/cutlass/gemm/kernel/gemm_universal_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_streamk.h index c52a15c3c2..eaa3cfc1ba 100644 --- a/include/cutlass/gemm/kernel/gemm_universal_streamk.h +++ b/include/cutlass/gemm/kernel/gemm_universal_streamk.h @@ -357,14 +357,17 @@ struct GemmUniversalStreamk { // Initialize the block mapping structure block_mapping = ThreadblockSwizzle( - typename ThreadblockSwizzle::template KernelTraits(), args.mode, args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count, sm_occupancy, device_sms, - avail_sms); + avail_sms, + sizeof(ElementA), + sizeof(ElementB), + sizeof(ElementC), + Epilogue::kAccumulatorFragments); } diff --git a/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h b/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h index dd2c52f46b..50ecfbee3c 100644 --- a/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h +++ b/include/cutlass/gemm/kernel/gemm_universal_with_visitor_streamk.h @@ -233,14 +233,17 @@ class GemmWithEpilogueVisitorStreamk { // Initialize the block mapping structure block_mapping = ThreadblockSwizzle( - typename ThreadblockSwizzle::template KernelTraits(), args.mode, args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count, sm_occupancy, device_sms, - avail_sms); + avail_sms, + sizeof(ElementA), + sizeof(ElementB), + sizeof(ElementC), + Epilogue::kAccumulatorFragments); } diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp 
b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp index e5ae25a70c..cb4baf4d99 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp @@ -323,26 +323,26 @@ class GemmUniversal< static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); - // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto K = get<2>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) - Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) // Get the appropriate blocks for this thread block -- potential for thread block locality auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) TiledMma tiled_mma; - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape); + static_assert(tuple_size_v >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); + + // Extract out partitioned A and B. 
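// ---------------------------------------------------------------------------
// [Editor's note -- illustrative sketch, not part of this diff or of CUTLASS]
// The TmaTransactionBytes changes earlier in this diff count the staged tiles
// in bits (sizeof_bits / 8) instead of sizeof(), so sub-byte operand types are
// accounted for correctly. The arithmetic in plain C++; bits_of stands in for
// cutlass::sizeof_bits, and the 4-bit type and tile extents are assumptions:
#include <cstdint>
#include <cstdio>

template <class T> struct bits_of { static constexpr int value = 8 * sizeof(T); };
struct s4_placeholder {};  // hypothetical 4-bit integer element
template <> struct bits_of<s4_placeholder> { static constexpr int value = 4; };

template <class ElemA, class ElemB>
constexpr std::uint32_t transaction_bytes(int blk_m, int blk_n, int blk_k) {
  // Mirrors (size(sA) * sizeof_bits<A>) / 8 + (size(sB) * sizeof_bits<B>) / 8.
  return static_cast<std::uint32_t>(blk_m * blk_k * bits_of<ElemA>::value) / 8 +
         static_cast<std::uint32_t>(blk_n * blk_k * bits_of<ElemB>::value) / 8;
}

int main() {
  // A 128x128x64 tile of fp16 x s4: sizeof() would over-count the 4-bit operand.
  std::printf("%u bytes per stage\n",
              transaction_bytes<std::uint16_t, s4_placeholder>(128, 128, 64));
  return 0;
}
// ---------------------------------------------------------------------------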
+ Tensor gA_mkl = get<0>(tiled_tensors); + Tensor gB_nkl = get<1>(tiled_tensors); // Compute m_coord, n_coord, and l_coord with their post-tiled shapes auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); @@ -350,28 +350,21 @@ class GemmUniversal< auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - // Slice with m_coord and n_coord - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - // Get pipeline iterators and increments from tensor shapes - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); - auto k_tile_count = size<2>(gA); + auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); + auto k_tile_count = size<3>(gA_mkl); // Wait for all thread blocks in the Cluster cluster_wait_fn(); - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - if (warp_group_role == WarpGroupRole::Producer) { if (producer_warp_role == ProducerWarpRole::MainloopEpilogue) { collective_mainloop.load( + params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - gA, params.mainloop.tma_load_a, - gB, params.mainloop.tma_load_b, + tiled_tensors, + blk_coord, k_tile_iter, k_tile_count, lane_idx, block_rank_in_cluster, @@ -408,7 +401,7 @@ class GemmUniversal< mainloop_pipe_consumer_state, accumulators, k_tile_count, - thread_idx, + warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop ); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp index 7ad54f4afc..ca1ce1a424 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp @@ -31,6 +31,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/workspace.h" #include "cutlass/fast_math.h" #include "cutlass/kernel_hardware_info.hpp" #include "cute/arch/cluster_sm90.hpp" @@ -187,16 +188,29 @@ class GemmUniversal< CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; - TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( - problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, workspace); + + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = nullptr; return { args.mode, problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, 
mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, - scheduler, + TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace), workspace }; } @@ -215,19 +229,42 @@ class GemmUniversal< return implementable; } - static int + static size_t get_workspace_size(Arguments const& args) { - TileScheduler t; - return t.template get_workspace_size( + size_t workspace_size = 0; + workspace_size += TileScheduler::template get_workspace_size( args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; } - static - cutlass::Status + static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { - TileScheduler t; - return t.template initialize_workspace( - args.scheduler, workspace, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; } // Computes the kernel launch grid shape based on runtime parameters @@ -368,30 +405,13 @@ class GemmUniversal< } } (); - // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto K = get<2>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) - Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) // Get the appropriate blocks for this thread block -- potential for thread block locality TiledMma tiled_mma; auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) - - // Get pipeline stage increments from tensor shapes 
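// ---------------------------------------------------------------------------
// [Editor's note -- illustrative sketch, not part of this diff or of CUTLASS]
// The cooperative and pingpong kernels above now carve one workspace buffer
// into a tile-scheduler region followed by an epilogue region, rounding each
// offset up to the minimum alignment. The same bookkeeping in plain C++;
// round_up stands in for round_nearest, and all sizes are assumptions:
#include <cstddef>
#include <cstdio>

constexpr std::size_t round_up(std::size_t x, std::size_t align) {
  return (x + align - 1) / align * align;
}

int main() {
  const std::size_t align           = 16;   // assumed MinWorkspaceAlignment
  const std::size_t scheduler_bytes = 100;  // assumed scheduler workspace size
  const std::size_t epilogue_bytes  = 36;   // assumed epilogue workspace size

  std::size_t offset           = 0;
  std::size_t scheduler_offset = offset;    // scheduler workspace comes first
  offset = round_up(offset + scheduler_bytes, align);
  std::size_t epilogue_offset  = offset;    // epilogue workspace follows it
  offset = round_up(offset + epilogue_bytes, align);

  // get_workspace_size() returns the final offset; to_underlying_arguments()
  // hands workspace_ptr + scheduler_offset / + epilogue_offset to each part.
  std::printf("scheduler @ %zu, epilogue @ %zu, total %zu bytes\n",
              scheduler_offset, epilogue_offset, offset);
  return 0;
}
// ---------------------------------------------------------------------------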
- auto k_tile_count = size<3>(gA_mkl); - TileScheduler scheduler{params.scheduler}; auto work_tile_info = scheduler.get_current_work(); @@ -399,6 +419,19 @@ class GemmUniversal< CollectiveMainloop collective_mainloop; CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape); + static_assert(tuple_size_v >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(tiled_tensors); + Tensor gB_nkl = get<1>(tiled_tensors); + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + // Wait for all thread blocks in the Cluster cluster_wait_fn(); @@ -408,27 +441,24 @@ class GemmUniversal< // Mainloop Producer Warp if (producer_warp_role == ProducerWarpRole::Mainloop) { bool do_load_order_arrive = true; - while (work_tile_info.is_valid_tile) { + while (work_tile_info.is_valid()) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - // Slice with our work tile coordinates to construct mainloop tensor views - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. 
auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); - auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<2>(gA)), shape<2>(gA)); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<3>(gA_mkl)), shape<3>(gA_mkl)); collective_mainloop.load( + params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - gA, params.mainloop.tma_load_a, - gB, params.mainloop.tma_load_b, + tiled_tensors, + blk_coord, k_tile_iter, work_k_tile_count, lane_idx, block_rank_in_cluster, @@ -454,8 +484,8 @@ class GemmUniversal< // Epilogue Producer Warp else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { load_order_barrier.wait(); - while (work_tile_info.is_valid_tile) { - if (TileScheduler::compute_epilogue(work_tile_info)) { + while (work_tile_info.is_valid()) { + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); @@ -489,7 +519,7 @@ class GemmUniversal< // Do we potentially issue tail arrives for TMA stores, if epilogue load is waiting for it bool do_store_tail = false; - while (work_tile_info.is_valid_tile) { + while (work_tile_info.is_valid()) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); @@ -529,7 +559,7 @@ class GemmUniversal< TileScheduler::fixup( params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); - if (TileScheduler::compute_epilogue(work_tile_info)) { + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { // Epilogue and write to gD auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = collective_epilogue.store( diff --git a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp index dd1f5a6b0c..845ed861e3 100644 --- a/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp +++ b/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp @@ -31,6 +31,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/workspace.h" #include "cutlass/kernel_hardware_info.hpp" #include "cutlass/fast_math.h" #include "cute/arch/cluster_sm90.hpp" @@ -196,13 +197,28 @@ class GemmUniversal< CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + // Calculate workspace pointers + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + void* scheduler_workspace = workspace_ptr; + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* epilogue_workspace = workspace_ptr + workspace_offset; + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + + void* mainloop_workspace = 
nullptr; + return { args.mode, problem_shape, - CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), - CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, mainloop_workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, epilogue_workspace), hw_info, - TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler) + TileScheduler::to_underlying_arguments(problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, scheduler_workspace) }; } @@ -220,16 +236,42 @@ class GemmUniversal< return implementable; } - static - int + static size_t get_workspace_size(Arguments const& args) { - return 0; + size_t workspace_size = 0; + workspace_size += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; } - static - cutlass::Status + static cutlass::Status initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { - return Status::kSuccess; + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = TileScheduler::template initialize_workspace( + args.scheduler, workspace_ptr + workspace_offset, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset += TileScheduler::template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = CollectiveEpilogue::initialize_workspace(args.problem_shape, args.epilogue, workspace_ptr + workspace_offset, stream); + workspace_offset += CollectiveEpilogue::get_workspace_size(args.problem_shape, args.epilogue); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; } // Computes the kernel launch grid shape based on runtime parameters @@ -371,25 +413,26 @@ class GemmUniversal< } (); // Separate out problem shape for convenience - // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); - auto M = get<0>(problem_shape_MNKL); - auto N = get<1>(problem_shape_MNKL); - auto K = get<2>(problem_shape_MNKL); - auto L = get<3>(problem_shape_MNKL); - - // TMA requires special handling of strides to deal with coord codomain mapping - // Represent the full tensors -- get these from TMA - Tensor mA_mkl = params.mainloop.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) - Tensor mB_nkl = params.mainloop.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) // Get the appropriate blocks for this thread block -- potential for thread block locality TiledMma tiled_mma; auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) - // Make tiled views, defer the slice - Tensor gA_mkl = local_tile(mA_mkl, 
blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) - Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(tiled_tensors) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(tiled_tensors) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto tiled_tensors = collective_mainloop.tile_input_tensors(problem_shape_MNKL, params.mainloop, blk_shape); + static_assert(tuple_size_v >= 2, "Output of tile_input_tensors must have at least two elements (A, B)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(tiled_tensors); + Tensor gB_nkl = get<1>(tiled_tensors); // Get pipeline stage increments from tensor shapes auto k_tile_count = size<3>(gA_mkl); @@ -408,10 +451,6 @@ class GemmUniversal< } auto work_tile_info = scheduler.get_current_work(); - // In a warp specialized kernel, collectives expose data movement and compute operations separately - CollectiveMainloop collective_mainloop; - CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); - // Wait for all thread blocks in the Cluster cluster_wait_fn(); @@ -421,24 +460,21 @@ class GemmUniversal< // Mainloop Producer Warp if (producer_warp_role == ProducerWarpRole::Mainloop) { bool do_load_order_arrive = true; - while (work_tile_info.is_valid_tile) { + while (work_tile_info.is_valid()) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); - // Slice with our work tile coordinates to construct mainloop tensor views - Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) - Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) - - auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); collective_mainloop.load( + params.mainloop, mainloop_pipeline, mainloop_pipe_producer_state, - gA, params.mainloop.tma_load_a, - gB, params.mainloop.tma_load_b, + tiled_tensors, + blk_coord, k_tile_iter, k_tile_count, lane_idx, block_rank_in_cluster, @@ -465,7 +501,7 @@ class GemmUniversal< // Epilogue Producer Warp else if (producer_warp_role == ProducerWarpRole::Epilogue && collective_epilogue.is_producer_load_needed()) { load_order_barrier.wait(); - while (work_tile_info.is_valid_tile) { + while (work_tile_info.is_valid()) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); @@ -497,7 +533,7 @@ class GemmUniversal< else if (warp_group_role == WarpGroupRole::Consumer0 || warp_group_role == WarpGroupRole::Consumer1) { cutlass::arch::warpgroup_reg_alloc(); - while (work_tile_info.is_valid_tile) { + while (work_tile_info.is_valid()) { // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape auto m_coord = 
idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); @@ -515,7 +551,7 @@ class GemmUniversal< mainloop_pipe_consumer_state, accumulators, k_tile_count, - thread_idx, + warp_group_thread_idx, shared_storage.tensors.mainloop, params.mainloop ); diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp new file mode 100644 index 0000000000..bdcfa4ef01 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized.hpp @@ -0,0 +1,417 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/tensor.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_void_v or cute::is_same_v, + "Non-persistent warp-specialized kernel does not support specializing the tile scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + + // Kernel level shared memory storage + struct SharedStorage { + union TensorStorage { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) 
EpiLoadPipelineStorage epi_load; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using GmemTiledCopyA = typename CollectiveMainloop::GmemTiledCopyA; + using GmemTiledCopyB = typename CollectiveMainloop::GmemTiledCopyB; + static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same."); + + static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumMmaWarpGroups = cute::size(TiledMma{}) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups; + static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance."); + + static constexpr uint32_t MaxThreadsPerBlock = NumWarpGroups * NumThreadsPerWarpGroup; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + EpilogueParams epilogue; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::IF_SWAP_AB::value) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) + }; + } + + CUTLASS_HOST_DEVICE static + bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + return implementable; + } + + static + int + get_workspace_size(Arguments const& args) { + return 0; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = Shape<_1,_1,_1>{}; + auto tile_shape = TileShape{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_tiled_cta_shape_mnl( + problem_shape_MNKL, tile_shape, cluster_shape); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! 
defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + return; + } + #endif + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast<SharedStorage*>(smem_buf); + + int thread_idx = int(threadIdx.x); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int warp_group_idx = canonical_warp_group_idx(); + CUTLASS_ASSERT(warp_group_idx < NumWarpGroups); + WarpGroupRole warp_group_role = warp_group_idx < NumLoadWarpGroups ? WarpGroupRole::Producer : WarpGroupRole::Consumer; + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + mainloop_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + epi_load_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state<MainloopPipeline>(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state<EpiLoadPipeline>(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state<EpiStorePipeline>(); + + // Preconditions + static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Represent the full tensors + Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l) + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); //(n,k,l) + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + TiledMma tiled_mma; + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); + auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with m_coord and n_coord + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Get pipeline iterators and increments from tensor shapes + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + auto k_tile_count = size<2>(gA); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + // Wait for all threads in the thread block + __syncthreads(); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + + if (warp_group_role == WarpGroupRole::Producer) { + // Compute tile residues for predication + auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord); // M - BLK_M * m_coord + auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord); // N - BLK_N * n_coord + auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, + gB, + k_tile_iter, k_tile_count, + residue_mnk, + thread_idx, + shared_storage.tensors.mainloop + ); + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + epi_load_pipe_producer_state = + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + thread_idx, + shared_storage.tensors.epilogue + ); + 
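+ // As with the mainloop load above, wait upon the epilogue load consumer before the producer warp group exits.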
collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } + else if (warp_group_role == WarpGroupRole::Consumer) { + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + warp_group_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + + // Epilogue and write to gD + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp new file mode 100644 index 0000000000..7a8c5fa74a --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_cooperative.hpp @@ -0,0 +1,518 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cute/tensor.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + using GmemTiledCopyA = typename CollectiveMainloop::GmemTiledCopyA; + using GmemTiledCopyB = typename CollectiveMainloop::GmemTiledCopyB; + static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same"); + + static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumMmaWarpGroups = cute::size(TiledMma{}) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups; + static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance."); + + static constexpr 
uint32_t MaxThreadsPerBlock = NumWarpGroups * NumThreadsPerWarpGroup; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Kernel level shared memory storage + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + EpilogueParams epilogue; + KernelHardwareInfo hw_info; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + auto problem_shape = args.problem_shape; + if constexpr (detail::IF_SWAP_AB::value) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, workspace); + + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + hw_info, + scheduler + }; + } + + CUTLASS_HOST_DEVICE static + bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + return implementable; + } + + static + int + get_workspace_size(Arguments const& args) { + TileScheduler t; + return 
t.template get_workspace_size( + args.scheduler, args.problem_shape, args.hw_info, NumMmaWarpGroups); + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + TileScheduler t; + return t.template initialize_workspace( + args.scheduler, workspace, stream, args.problem_shape, args.hw_info, NumMmaWarpGroups); + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + return; + } + #endif + + static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + /* In the Cooperative kernel, one or multiple Consumers collaborate on the same tile */ + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int mma_thread_idx = thread_idx % size(TiledMma{}); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int warp_group_idx = canonical_warp_group_idx(); + CUTLASS_ASSERT(warp_group_idx < NumWarpGroups); + WarpGroupRole warp_group_role = warp_group_idx < NumLoadWarpGroups ? 
WarpGroupRole::Producer : WarpGroupRole::Consumer; + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + mainloop_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + epi_load_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Represent the full tensors + Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l) + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); //(n,k,l) + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + TileScheduler 
scheduler{params.scheduler}; + auto work_tile_info = scheduler.get_current_work(); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + + // Wait for all threads in the thread block + __syncthreads(); + + if (warp_group_role == WarpGroupRole::Producer) { + + while (work_tile_info.is_valid()) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with our work tile coordinates to construct mainloop tensor views + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Get the number of K tiles to compute for this work as well as the starting K tile offset of the work. + auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + auto work_k_tile_start = TileScheduler::get_work_k_tile_start(work_tile_info); + auto k_tile_iter = cute::make_coord_iterator(idx2crd(work_k_tile_start, shape<2>(gA)), shape<2>(gA)); + + // Compute tile residues for predication + auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord); // M - BLK_M * m_coord + auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord); // N - BLK_N * n_coord + auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, + gB, + k_tile_iter, work_k_tile_count, + residue_mnk, + thread_idx, + shared_storage.tensors.mainloop + ); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(work_k_tile_count); + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler) && + collective_epilogue.is_producer_load_needed()) { + epi_load_pipe_producer_state = + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer) { + + bool do_store_tail = false; + while (work_tile_info.is_valid()) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + auto work_k_tile_count = TileScheduler::get_work_k_tile_count(work_tile_info, problem_shape_MNKL, blk_shape); + + // Allocate the the accumulators for the (M,N) blk_shape + // 
+ // MSVC CTAD breaks if we say "Tensor" here, so we use "auto" instead. + auto accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + work_k_tile_count, + mma_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + work_k_tile_count + ); + + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(work_k_tile_count); + + // Index of warp group within consumer warp groups + int consumer_warp_group_idx = canonical_warp_group_idx() - NumLoadWarpGroups; + + // Perform reduction across splits, if needed + TileScheduler::fixup( + params.scheduler, work_tile_info, accumulators, NumMmaWarpGroups, consumer_warp_group_idx); + + if (TileScheduler::compute_epilogue(work_tile_info, params.scheduler)) { + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + mma_thread_idx, + shared_storage.tensors.epilogue + ); + epi_load_pipe_consumer_state = epi_load_pipe_consumer_state_next; + epi_store_pipe_producer_state = epi_store_pipe_producer_state_next; + do_store_tail = true; + } + + // Get next work tile + work_tile_info = fetch_next_work(work_tile_info, scheduler); + } // Scheduler work fetch loop + + if (do_store_tail) { + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state + ); + } + } // Consumer Warp Groups End + } + +private: + // Kernel helper function to get next work unit + CUTLASS_DEVICE + typename TileScheduler::WorkTileInfo + fetch_next_work( + typename TileScheduler::WorkTileInfo& work_tile_info, + TileScheduler& scheduler) const { + // Check whether we should continue on with the current work unit. If this is the case, + // the work unit will have been updated in continue_current_work to reflect the new + // tile to be computed. + if (scheduler.continue_current_work(work_tile_info)) { + return work_tile_info; + } + + // Get next work tile + scheduler.advance_to_next_work(); + return scheduler.get_current_work(); + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp new file mode 100644 index 0000000000..f43ff562d0 --- /dev/null +++ b/include/cutlass/gemm/kernel/sm90_gemm_warpspecialized_pingpong.hpp @@ -0,0 +1,516 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/fast_math.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t>> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(rank(ProblemShape{}) == 3 or rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using 
ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(!cute::is_same_v, "Ping-pong kernel does not currently support stream-K scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + using TileSchedulerParams = typename TileScheduler::Params; + + using GmemTiledCopyA = typename CollectiveMainloop::GmemTiledCopyA; + using GmemTiledCopyB = typename CollectiveMainloop::GmemTiledCopyB; + static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same"); + + static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumMmaWarpGroups = 2 * cute::size(TiledMma{}) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups; + static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance."); + static_assert(NumMmaWarpGroups == 2, "Pingpong kernel requires 2 MMA warp groups."); + + static constexpr uint32_t MaxThreadsPerBlock = NumWarpGroups * NumThreadsPerWarpGroup; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Order Sequence barrier with two stages: one for Mainloop and one for Epilogue + static constexpr uint32_t StagesPerMathWarpGroup = 2; + using MathWarpGroupOrderBarrier = cutlass::OrderedSequenceBarrier< + StagesPerMathWarpGroup, NumMmaWarpGroups>; + + // Kernel level shared memory storage + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128> { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + using MathWarpGroupOrderBarrierStorage = typename MathWarpGroupOrderBarrier::SharedStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) MathWarpGroupOrderBarrierStorage math_wg_order; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode; + ProblemShape problem_shape; + MainloopParams mainloop; + EpilogueParams epilogue; + KernelHardwareInfo hw_info; + TileSchedulerParams scheduler; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
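+ // For this persistent kernel the conversion also queries the device SM count (when the caller does not supply one) and builds the tile scheduler parameters.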
+ static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + CUTLASS_TRACE_HOST("to_underlying_arguments():"); + + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::IF_SWAP_AB::value) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + auto problem_shape_MNKL = append<4>(problem_shape, 1); + + // Get SM count if needed, otherwise use user supplied SM count + int sm_count = args.hw_info.sm_count; + if (sm_count <= 0) { + CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n" + " For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count."); + sm_count = KernelHardwareInfo::query_device_multiprocessor_count(args.hw_info.device_id); + } + + CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count); + + KernelHardwareInfo hw_info{args.hw_info.device_id, sm_count}; + TileSchedulerParams scheduler = TileScheduler::to_underlying_arguments( + problem_shape_MNKL, TileShape{}, ClusterShape{}, hw_info, args.scheduler, workspace); + + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + hw_info, + scheduler + }; + } + + CUTLASS_HOST_DEVICE static + bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + return implementable; + } + + static + int + get_workspace_size(Arguments const& args) { + return 0; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) { + return Status::kSuccess; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + // Given device SM count, set grid size s.t. we do not launch more thread blocks than we can run concurrently + TileSchedulerArguments args{}; + if constexpr (!std::is_const_v) { + args.max_swizzle_size = 1 << params.scheduler.log_swizzle_size_; + } + return TileScheduler::get_grid_shape(params.problem_shape, TileShape{}, ClusterShape{}, params.hw_info, args); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + return; + } + #endif + + // Preconditions + static_assert(rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. 
If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int warp_group_idx = canonical_warp_group_idx(); + CUTLASS_ASSERT(warp_group_idx < NumWarpGroups); + WarpGroupRole warp_group_role = warp_group_idx < NumLoadWarpGroups ? WarpGroupRole::Producer : WarpGroupRole::Consumer; + int warp_group_consumer_idx = warp_group_idx - NumLoadWarpGroups; + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + mainloop_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; // only 1 WG consumes at a time + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; // only 1 WG consumes at a time + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + typename MathWarpGroupOrderBarrier::Params params_math_wg_order_barrier; + // DMA Load WG will not participate in these Ordered Barrier syncs + params_math_wg_order_barrier.group_id = warp_group_consumer_idx; + params_math_wg_order_barrier.group_size = NumThreadsPerWarpGroup; // Number of threads / participants in a group + MathWarpGroupOrderBarrier math_wg_order_barrier(shared_storage.pipelines.math_wg_order, params_math_wg_order_barrier); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState 
mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Represent the full tensors + Tensor mA_mkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA); //(m,k,l) + Tensor mB_nkl = make_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB); //(n,k,l) + + // Get the appropriate blocks for this thread block -- potential for thread block locality + TiledMma tiled_mma; + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Get pipeline stage increments from tensor shapes + auto k_tile_count = size<3>(gA_mkl); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + TileScheduler scheduler{params.scheduler}; + + if (warp_group_consumer_idx == 1) { + // Advance 2nd Math WG to the next work tile for the startup + scheduler.advance_to_next_work(); + // Advance 2nd Math WG pipeline states to the end of 1st Math WG + mainloop_pipe_consumer_state.advance(k_tile_count); + epi_load_pipe_consumer_state.advance(c_tile_count); + epi_store_pipe_producer_state.advance(d_tile_count); + } + auto work_tile_info = scheduler.get_current_work(); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + + // Wait for all threads in the thread block + __syncthreads(); + + if (warp_group_role == WarpGroupRole::Producer) { + + while (work_tile_info.is_valid()) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with our work tile coordinates to construct mainloop tensor views + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + + // Compute tile residues for predication + auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord); // M - BLK_M * m_coord + auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord); // N - BLK_N * n_coord + auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, + gB, + k_tile_iter, k_tile_count, + residue_mnk, + thread_idx, + 
shared_storage.tensors.mainloop + ); + // Update starting pipeline state for the next tile + mainloop_pipe_producer_state.advance(k_tile_count); + + if (collective_epilogue.is_producer_load_needed()) { + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + // Update starting pipeline state for the next tile + epi_load_pipe_producer_state.advance(c_tile_count); + } + + // Get next work tile + scheduler.advance_to_next_work(); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + + // Make sure all Consumer Warp Groups have been waited upon + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + if (collective_epilogue.is_producer_load_needed()) { + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } // Producer Warp Group End + + else if (warp_group_role == WarpGroupRole::Consumer) { + + while (work_tile_info.is_valid()) { + // Compute m_coord, n_coord, l_coord with the post-tiled m-shape and n-shape + auto m_coord = idx2crd(work_tile_info.M_idx, shape<2>(gA_mkl)); + auto n_coord = idx2crd(work_tile_info.N_idx, shape<2>(gB_nkl)); + auto l_coord = idx2crd(work_tile_info.L_idx, shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Allocate the the accumulators for the (M,N) blk_shape + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + // Order two Math WG's MMA one after the other, helps hide Epilogue + math_wg_order_barrier.wait(); + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Cue for next Math WG's MMA to start + math_wg_order_barrier.arrive(); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + // Update starting mainloop pipeline state for the next tile + mainloop_pipe_consumer_state.advance(k_tile_count * NumMmaWarpGroups); + + // Order two Math WG's Epilogue one after the other + math_wg_order_barrier.wait(); + + // Epilogue and write to gD + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + // Update starting load/store pipeline states for the next tile + epi_load_pipe_consumer_state.advance(c_tile_count * NumMmaWarpGroups); + epi_store_pipe_producer_state.advance(d_tile_count * NumMmaWarpGroups); + + // Wait for all TMA stores to complete + epi_store_pipeline.producer_tail(epi_store_pipe_producer_state); + + // Cue for next Math WG's Epilogue to start + math_wg_order_barrier.arrive(); + + // Get next work tile + scheduler.advance_to_next_work(NumMmaWarpGroups); + work_tile_info = scheduler.get_current_work(); + } // Scheduler work fetch loop + } // Consumer Warp Groups End + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp index 
8fb60d9004..ff64c14a10 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp @@ -50,6 +50,7 @@ class PersistentTileSchedulerSm90 { private: uint64_t current_work_linear_idx_; + uint64_t total_grid_size_; public: struct WorkTileInfo { @@ -57,12 +58,29 @@ class PersistentTileSchedulerSm90 { int32_t N_idx = 0; int32_t L_idx = 0; bool is_valid_tile = false; + + CUTLASS_HOST_DEVICE + bool + is_valid() const { + return is_valid_tile; + } + + CUTLASS_HOST_DEVICE + static WorkTileInfo + invalid_work_tile() { + return {-1, -1, -1, false}; + } + + CUTLASS_HOST_DEVICE + bool + is_final_split(uint32_t k_tiles_per_output_tile) const { + return true; + } }; using Params = PersistentTileSchedulerSm90Params; using RasterOrder = typename Params::RasterOrder; using RasterOrderOptions = typename Params::RasterOrderOptions; - struct Arguments { int max_swizzle_size = 1; RasterOrderOptions raster_order = RasterOrderOptions::Heuristic; @@ -116,6 +134,8 @@ class PersistentTileSchedulerSm90 { else { current_work_linear_idx_ = uint64_t(blockIdx.x) * uint64_t(gridDim.y) + uint64_t(blockIdx.y); } + + total_grid_size_ = uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z); #else CUTLASS_ASSERT(false && "This line should never be reached"); #endif @@ -130,6 +150,10 @@ class PersistentTileSchedulerSm90 { CUTLASS_DEVICE WorkTileInfo get_current_work_for_linear_idx(uint64_t linear_idx) const { + if (linear_idx >= scheduler_params.blocks_per_problem_) { + return WorkTileInfo::invalid_work_tile(); + } + // Map worker's linear index into the CTA tiled problem shape to the corresponding MNL indices uint64_t work_idx_l, remainder; scheduler_params.divmod_batch_(work_idx_l, remainder, linear_idx); @@ -143,19 +167,13 @@ class PersistentTileSchedulerSm90 { scheduler_params.log_swizzle_size_, scheduler_params.raster_order_); - return {work_idx_m, work_idx_n, static_cast(work_idx_l), linear_idx < scheduler_params.blocks_per_problem_}; + return {work_idx_m, work_idx_n, static_cast(work_idx_l), true}; } CUTLASS_DEVICE void advance_to_next_work(uint32_t advance_count = 1) { - // MSVC requires protecting use of CUDA-specific nonstandard syntax, - // like blockIdx and gridDim, with __CUDA_ARCH__. -#if defined(__CUDA_ARCH__) - current_work_linear_idx_ += uint64_t(gridDim.x) * uint64_t(gridDim.y) * uint64_t(gridDim.z) * uint64_t(advance_count); -#else - CUTLASS_ASSERT(false && "This line should never be reached"); -#endif + current_work_linear_idx_ += total_grid_size_ * uint64_t(advance_count); } // get work_idx_m, work_idx_n from blk_per_grid_dim while applying swizzle @@ -163,8 +181,8 @@ class PersistentTileSchedulerSm90 { cute::tuple get_work_idx_m_and_n( uint64_t blk_per_grid_dim, - FastDivmodU64 const& divmod_cluster_shape_major, - FastDivmodU64 const& divmod_cluster_shape_minor, + FastDivmodU64Pow2 const& divmod_cluster_shape_major, + FastDivmodU64Pow2 const& divmod_cluster_shape_minor, FastDivmodU64 const& divmod_cluster_blk_major, int32_t log_swizzle_size, RasterOrder raster_order) { @@ -205,6 +223,46 @@ class PersistentTileSchedulerSm90 { } + // Computes the linear index within a batch given M and N tile offsets within the batch. 
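+ // The returned value is the post-batch-divmod remainder of a CTA's linear work index (see get_current_work_for_linear_idx).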
+ // This essentially inverts the mapping performed in get_work_idx_m_and_n + static CUTLASS_DEVICE + uint64_t + get_linear_idx_from_m_and_n( + int32_t tile_m, + int32_t tile_n, + FastDivmodU64Pow2 const& divmod_cluster_shape_major, + FastDivmodU64Pow2 const& divmod_cluster_shape_minor, + FastDivmodU64 const& divmod_cluster_blk_major, + int32_t log_swizzle_size, + RasterOrder raster_order) { + + auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); + + uint64_t minor_work_idx, major_work_idx, cluster_minor_offset; + if (raster_order == RasterOrder::AlongN) { + minor_work_idx = static_cast(tile_m); + major_work_idx = static_cast(tile_n); + cluster_minor_offset = cta_m_in_cluster; + } + else { + major_work_idx = static_cast(tile_m); + minor_work_idx = static_cast(tile_n); + cluster_minor_offset = cta_n_in_cluster; + } + + uint64_t cluster_idx_minor, cluster_idx_major, cluster_major_offset; + cluster_idx_minor = divmod_cluster_shape_minor.divide(minor_work_idx - cluster_minor_offset); + divmod_cluster_shape_major(cluster_idx_major, cluster_major_offset, major_work_idx); + + uint64_t cluster_idx_minor_div_swizzle = cluster_idx_minor >> log_swizzle_size; + uint64_t offset = cluster_idx_minor & ((1 << log_swizzle_size) - 1); + + uint64_t extra = cluster_idx_minor_div_swizzle * divmod_cluster_blk_major.divisor + cluster_idx_major; + + uint64_t cluster_id = (extra << log_swizzle_size) | offset; + return (cluster_id * divmod_cluster_shape_major.divisor + cluster_major_offset) * divmod_cluster_shape_minor.divisor + cluster_minor_offset; + } + // Given the inputs, computes the total number of output blocks this problem will compute over // Note that this is only the logical size of our grid, not the physical grid we will actually launch. template @@ -250,7 +308,7 @@ class PersistentTileSchedulerSm90 { // output tile. For the basic tile scheduler, this is always true. CUTLASS_HOST_DEVICE static bool - compute_epilogue(WorkTileInfo const&) { + compute_epilogue(WorkTileInfo const&, Params const&) { return true; } diff --git a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp index ff9cb20972..ad333e1d09 100644 --- a/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp +++ b/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp @@ -64,33 +64,44 @@ class PersistentTileSchedulerSm90StreamK { using RasterOrder = UnderlyingScheduler::RasterOrder; using RasterOrderOptions = UnderlyingScheduler::RasterOrderOptions; - // Use a dummy barrier manager to simply get the type used to store the barrier using BarrierType = typename NamedBarrierManager<1>::T; + using Params = PersistentTileSchedulerSm90StreamKParams; + using ReductionMode = Params::ReductionMode; + struct WorkTileInfo { int32_t M_idx = 0; int32_t N_idx = 0; int32_t K_idx = 0; int32_t L_idx = 0; - bool is_valid_tile = false; - - // Number of splits to be used in computing the {L_idx, M_idx, N_idx} output tile. - // Splits = 1 indicates that this is a data-parallel block. - uint32_t splits = 1; - // Number of k iterations to compute for the current tile + // Number of k tiles to compute for this unit of work. For stream-K, this + // can indicate the number of K tiles across multiple output tiles. 
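+ // A unit whose k_tile_count equals the number of K tiles per output tile is purely data-parallel and needs no fixup (see requires_fixup).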
uint32_t k_tile_count = 0; - // Number of k iterations remaining for the work unit as a whole + // Number of k tiles remaining for the work unit as a whole uint32_t k_tile_remaining = 0; - // Whether this unit of work is the final split for the given tile - bool is_final_split = true; - }; + CUTLASS_HOST_DEVICE + bool + is_valid() const { + // Use negative indices to denote invalid work + return M_idx >= 0; + } - using Params = PersistentTileSchedulerSm90StreamKParams; - using ReductionMode = Params::ReductionMode; + CUTLASS_HOST_DEVICE + static WorkTileInfo + invalid_work_tile() { + return {-1, -1, -1, -1, 0}; + } + + CUTLASS_HOST_DEVICE + bool + is_final_split(uint32_t k_tiles_per_output_tile) const { + return (K_idx + k_tile_count) == k_tiles_per_output_tile; + } + }; struct Arguments { @@ -117,6 +128,12 @@ class PersistentTileSchedulerSm90StreamK { CUTLASS_HOST_DEVICE Arguments(int splits_) : splits(splits_) {} + CUTLASS_HOST_DEVICE + Arguments(int splits_, int max_swizzle_size_, RasterOrderOptions raster_order_) : + splits(splits_), + max_swizzle_size(max_swizzle_size_), + raster_order(raster_order_) {} + // The splitting factor to be used in a split-K decomposition of the problem. // If this is set to a value greater than 1, stream-K decomposition logic // is bypassed in favor of a split-K decomposition. @@ -187,26 +204,20 @@ class PersistentTileSchedulerSm90StreamK { CUTLASS_DEVICE static WorkTileInfo get_current_work_for_linear_idx(uint64_t linear_idx, Params const& params) { - if (linear_idx >= params.units_per_problem_) { + // The maximum number of work units is units_per_problem_ * splits_. + // The multiplication by splits_ is used for handling split-K, in which + // units_per_problem_ is equal to the total number of output tiles. To account + // for the fact that we have splits_ peers per output tile, we multiply this + // value by splits_. For stream-K, this multiplication ends up being a no-op + // because splits_ is set to 1 for stream-K. + if (linear_idx >= params.units_per_problem_ * params.splits_) { // Invalid work. Return an empty result. - return {0, 0, 0, 0, false, 0}; + return WorkTileInfo::invalid_work_tile(); } - // Determine whether this work unit is a data-parallel or stream-K work unit - bool is_stream_k_unit = linear_idx < params.sk_units_; - - bool is_split_k = params.splits_ > 1; - - if (is_split_k || !is_stream_k_unit) { - // Bypass the stream-K scheduling logic for basic data-parallel or split-K work - return set_non_stream_k_work(linear_idx, params, is_split_k); - } - else { - // This is a stream-K work unit - WorkTileInfo work_tile_info; - set_stream_k_work(params, linear_idx, work_tile_info, /*new_unit = */ true); - return work_tile_info; - } + WorkTileInfo work_tile_info; + assign_work(params, linear_idx, work_tile_info); + return work_tile_info; } // Returns whether the current work_tile_info passed in should continue to be used. 
This @@ -233,7 +244,7 @@ class PersistentTileSchedulerSm90StreamK { return false; } - set_stream_k_work(params, linear_idx, work_tile_info, /* new_unit = */ false); + assign_work(params, linear_idx, work_tile_info); return true; } @@ -280,7 +291,7 @@ class PersistentTileSchedulerSm90StreamK { static bool requires_fixup(Params const& params, WorkTileInfo const& work_tile_info) { // Fixup is not needed for data-parallel tiles - return work_tile_info.k_tile_count != params.k_tiles_per_output_tile_; + return work_tile_info.k_tile_count != params.divmod_tiles_per_output_tile_.divisor; } // Performs the reduction across splits for a given output tile. @@ -293,7 +304,9 @@ class PersistentTileSchedulerSm90StreamK { FrgTensorC& accumulators, uint32_t num_barriers, uint32_t barrier_idx) { - using BarrierManager = NamedBarrierManager; + static constexpr uint32_t Offset = 2; + static constexpr uint32_t MaxNumNamedBarriers = 2; + using BarrierManager = NamedBarrierManager; return fixup_helper( params, work_tile_info, accumulators, num_barriers, barrier_idx); } @@ -331,24 +344,23 @@ class PersistentTileSchedulerSm90StreamK { using AccumulatorArrayT = Array; using BlockStripedReduceT = BlockStripedReduce; - AccumulatorArrayT* reduction_workspace_array = reinterpret_cast(group_reduction_workspace); - AccumulatorArrayT* accumulator_array = reinterpret_cast(&accumulators); - - int barrier_group_thread_idx = threadIdx.x % BarrierManager::ThreadCount; - // The number of tiles for which reduction is required is either: // (a) the total number of output tiles (in the case of split-K) // (b) the number of stream-K tiles - // To calcualte the the total number of output tiles in the split-K case, we + // To calculate the the total number of output tiles in the split-K case, we // note that, in the split-K case, the units_per_problem_ member of Params will be - // the total number of output tiles multiplied by the number of splits. - auto reduction_tiles = params.splits_ > 1 ? (params.units_per_problem_ / params.splits_) : params.sk_tiles_; + // the total number of output tiles. + auto reduction_tiles = params.splits_ > 1 ? params.units_per_problem_ : params.sk_tiles_; auto reduction_workspace_size = Params::get_reduction_workspace_size( reduction_tiles, to_gemm_coord(TileShape{}), sizeof_bits::value); BarrierType* lock_workspace = reinterpret_cast( reinterpret_cast(params.reduction_workspace_) + reduction_workspace_size); - if (!work_tile_info.is_final_split) { + AccumulatorArrayT* reduction_workspace_array = reinterpret_cast(group_reduction_workspace); + AccumulatorArrayT* accumulator_array = reinterpret_cast(&accumulators); + int barrier_group_thread_idx = threadIdx.x % BarrierManager::ThreadCount; + + if (!work_tile_info.is_final_split(params.divmod_tiles_per_output_tile_.divisor)) { if (work_tile_info.K_idx == 0) { // First peer initializes the workspace partials BlockStripedReduceT::store(reduction_workspace_array, *accumulator_array, barrier_group_thread_idx); @@ -359,7 +371,12 @@ class PersistentTileSchedulerSm90StreamK { BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); } else { - // Wait unitl the first split has stored its accumulators + // Wait until the first split has stored its accumulators. Note that the first split will have + // accumulated a value into the lock potentially greater than one (since the locked value is + // incremented by work_tile_info.k_tile_count below for both the deterministic and non-deterministic) + // cases. 
For non-deterministic reductions, all that non-first or last splits care about is whether + // the first split has been written, so we only wait while the locked value is less than 1. This + // avoids having to add logic to determine the work_tile_info.k_tile_count for the first split. BarrierManager::wait_lt(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, 1); } @@ -371,7 +388,11 @@ class PersistentTileSchedulerSm90StreamK { BarrierManager::arrive_inc(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.k_tile_count); } else { - // Wait until the preceding split added its accumulators + // Wait until the preceding split added its accumulators. + // For both the deterministic and non-deterministic case, each preceding split will have incremented + // the locked value by work_tile_info.k_tile_count. Thus, the final split konws that it can begin + // loading the partially-reduced value when the locked value reaches its starting K tile index (i.e., + // work_tile_info.K_idx). BarrierManager::wait_eq(barrier_idx, lock_workspace, barrier_group_thread_idx, lock_idx, work_tile_info.K_idx); // The block computing the final split for the tile adds previously-reduced partials @@ -384,54 +405,25 @@ class PersistentTileSchedulerSm90StreamK { // output tile. For the case of stream-K, this should only occur if the work is marked as the final split. CUTLASS_HOST_DEVICE static bool - compute_epilogue(WorkTileInfo const& work_tile_info) { - return work_tile_info.is_final_split; + compute_epilogue(WorkTileInfo const& work_tile_info, Params const& params) { + return work_tile_info.is_final_split(params.divmod_tiles_per_output_tile_.divisor); } // Returns the linearized index of the output tile corresponding to the tile with offset [L, M, K] CUTLASS_DEVICE static int output_tile_index(Params const& params, WorkTileInfo const& work_tile_info) { - if (params.splits_ > 1) { - auto tiles_mn = params.divmod_batch_.divisor / params.splits_; - if (params.raster_order_ == RasterOrder::AlongN) { - return - (tiles_mn * work_tile_info.L_idx) + - (params.divmod_cluster_shape_major_.divisor * - params.divmod_cluster_blk_major_.divisor * work_tile_info.M_idx) + - work_tile_info.N_idx; - } - else { - return - (tiles_mn * work_tile_info.L_idx) + - (params.divmod_cluster_shape_major_.divisor * - params.divmod_cluster_blk_major_.divisor * work_tile_info.N_idx) + - work_tile_info.M_idx; - } - } - else { - auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); - - uint64_t cta_per_grid_dim; - uint64_t cluster_dim_idx; - if (params.raster_order_ == RasterOrder::AlongN) { - uint64_t block_idx_m = (work_tile_info.M_idx - cta_m_in_cluster) / params.divmod_cluster_shape_minor_.divisor; - uint64_t block_idx_n = work_tile_info.N_idx; - cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor * - params.divmod_cluster_blk_major_.divisor * block_idx_m) + block_idx_n; - cluster_dim_idx = cta_m_in_cluster; - } - else { - uint64_t block_idx_m = work_tile_info.M_idx; - uint64_t block_idx_n = (work_tile_info.N_idx - cta_n_in_cluster) / params.divmod_cluster_shape_minor_.divisor; - cta_per_grid_dim = (params.divmod_cluster_shape_major_.divisor * - params.divmod_cluster_blk_major_.divisor * block_idx_n) + block_idx_m; - cluster_dim_idx = cta_n_in_cluster; - } + uint64_t linear_idx_in_batch = UnderlyingScheduler::get_linear_idx_from_m_and_n( + work_tile_info.M_idx, work_tile_info.N_idx, + params.divmod_cluster_shape_major_, + params.divmod_cluster_shape_minor_, + 
params.divmod_cluster_blk_major_, + params.log_swizzle_size_, + params.raster_order_ + ); - uint64_t tile_in_batch = params.divmod_cluster_shape_minor_.divisor * cta_per_grid_dim; - return params.divmod_batch_.divisor * work_tile_info.L_idx + tile_in_batch + cluster_dim_idx; - } + uint64_t tiles_mn = params.divmod_batch_.divisor; + return tiles_mn * work_tile_info.L_idx + linear_idx_in_batch; } template @@ -518,106 +510,125 @@ class PersistentTileSchedulerSm90StreamK { // iterations) is used to find the next tile in the current work unit. CUTLASS_DEVICE static void - set_stream_k_work( + assign_work( Params const& params, uint64_t linear_idx, - WorkTileInfo& work_tile_info, - bool new_unit) { - // In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K - // threadblock individually. For the most part, the set of K iterations corresponding to stream-K - // work was divided amongst stream-K threadblocks, and a threadblock determined which tile - // it would compute a (potentially-partial) output tile for based on the space of k iterations - // assigned to it. This often results in stream-K threadblocks processing tiles with different - // offsets in the K dimension from one another. This can reduce locality, but is lmitied to the - // (generally few) waves of threadblocks assigned to compute stream-K work. - // - // With the introduction of threadblock clusters, there is additional benefit to maintaining - // locality in the K dimension: shared portions of operands can be multicasted to threadblocks - // within a cluster. Thus, we would like to ensure that the assignment of stream-K work to - // threadblocks respects the ability to perform multicasting. - // - // To do so, we divide up the linearized stream-K units into clusters and share the same K - // offsets for work within clusters. - auto cluster_size = params.divmod_cluster_shape_major_.divisor * params.divmod_cluster_shape_minor_.divisor; - auto cluster_linear_work_idx = linear_idx / cluster_size; - - // Determine the starting k iteration computed by this stream-K work unit - uint32_t unit_iter_start = params.k_tiles_per_sk_unit_ * cluster_linear_work_idx; - - // Adjust the starting position and number of k iterations for "big units," which - // compute one extra iteration. These are the first big_units_ units in the - // linearized ID space. - bool is_big_unit = cluster_linear_work_idx < params.big_units_; - if (is_big_unit) { - // Since the "big units" are the first units in the linearized ID space, each - // of the units preceding this big unit computed one extra iteration. Thus, - // we must offset our start iteration by the number of units that precede - // the current unit in the linearized ID space. - unit_iter_start += cluster_linear_work_idx; - } else { - // Increment by one for each of the big clusters (since all big units precede this unit) - unit_iter_start += params.big_units_; + WorkTileInfo& work_tile_info) { + + uint64_t true_tile_id = linear_idx; + if (linear_idx >= params.sk_units_ && params.splits_ == 1) { + // Data-parallel work + true_tile_id = linear_idx - params.sk_units_ + params.sk_tiles_; + work_tile_info.K_idx = 0; + work_tile_info.k_tile_count = params.divmod_tiles_per_output_tile_.divisor; + work_tile_info.k_tile_remaining = params.divmod_tiles_per_output_tile_.divisor; } + else { + // In the CUTLASS 2.x implementation of stream K, stream-K work is assigned to each stream-K + // threadblock individually. 
For the most part, the set of K iterations corresponding to stream-K + // work was divided amongst stream-K threadblocks, and a threadblock determined which tile + // it would compute a (potentially-partial) output tile for based on the space of k iterations + // assigned to it. This often results in stream-K threadblocks processing tiles with different + // offsets in the K dimension from one another. This can reduce locality, but is lmitied to the + // (generally few) waves of threadblocks assigned to compute stream-K work. + // + // With the introduction of threadblock clusters, there is additional benefit to maintaining + // locality in the K dimension: shared portions of operands can be multicasted to threadblocks + // within a cluster. Thus, we would like to ensure that the assignment of stream-K work to + // threadblocks respects the ability to perform multicasting. + // + // To do so, we divide up the linearized stream-K units into clusters and share the same K + // offsets for work within clusters. + + // Equivalent to linear_idx / cluster_size + auto cluster_linear_work_idx = params.divmod_cluster_shape_minor_.divide( + params.divmod_cluster_shape_major_.divide(linear_idx) + ); + + uint64_t split; + params.divmod_clusters_mnl_(split, cluster_linear_work_idx, cluster_linear_work_idx); + auto big_unit_cmp = params.splits_ > 1 ? split : cluster_linear_work_idx; + auto linear_idx_mult = params.splits_ > 1 ? params.divmod_tiles_per_output_tile_.divisor : params.k_tiles_per_sk_unit_; + + // Determine the starting k iteration computed by this stream-K work unit + uint32_t unit_iter_start = (linear_idx_mult * cluster_linear_work_idx) + (params.k_tiles_per_sk_unit_ * split); + + // Adjust the starting position and number of k iterations for "big units," which + // compute one extra iteration. These are the first big_units_ units in the + // linearized ID space. + bool is_big_unit = big_unit_cmp < params.big_units_; + if (is_big_unit) { + // Since the "big units" are the first units in the linearized ID space, each + // of the units preceding this big unit computed one extra iteration. Thus, + // we must offset our start iteration by the number of units that precede + // the current unit in the linearized ID space. + unit_iter_start += big_unit_cmp; + } + else { + // Increment by one for each of the big clusters (since all big units precede this unit) + unit_iter_start += params.big_units_; + } - uint32_t unit_iters; - if (new_unit) { - unit_iters = params.k_tiles_per_sk_unit_; + if (work_tile_info.k_tile_count == 0) { + // This is a new unit + work_tile_info.k_tile_remaining = params.k_tiles_per_sk_unit_; - // Only adjust iteration count for big unit if we are initializing this - // work unit. For existing work units, the extra iteration for big units - // has already been accounted for in k_tiles_reamaining - if (is_big_unit) { - ++unit_iters; + // Only adjust iteration count for big unit if we are initializing this + // work unit. For existing work units, the extra iteration for big units + // has already been accounted for in k_tiles_reamaining + if (is_big_unit) { + ++work_tile_info.k_tile_remaining; + } } - } - else { - unit_iters = work_tile_info.k_tile_remaining; - } - // Find the output tile corresponding to the final k iteration covered by this - // work unit. Stream-K work units will work backwards in terms of the tiles they - // are responsible computing. 
This is beneficial because the final (partial) - // tile computed by a stream-K block is typically the beginning of the output - // tile, while the beginning (partial) tile is typically the ending of another - // output tile. Since ending portions of an output tile must reduce across - // other work units computing portions of that output tile, it is preferable - // for them to be computed later, so as to reduce the likelihood of blocking - // on other work. - uint32_t unit_iter_end = unit_iter_start + unit_iters - 1; - uint32_t true_tile_id = unit_iter_end / params.k_tiles_per_output_tile_; - uint32_t true_tile_iter_start = true_tile_id * params.k_tiles_per_output_tile_; - uint32_t true_tile_iter_end = true_tile_iter_start + params.k_tiles_per_output_tile_; - - // Bring the linearized tile ID back into the space of tiles, rather than clusters - true_tile_id *= cluster_size; - - auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); - - // The final linearized tile ID is in units of the cluster dimension over which we rasterize. - if (params.raster_order_ == RasterOrder::AlongN) { - true_tile_id += cta_n_in_cluster * params.divmod_cluster_shape_minor_.divisor; - } - else { - true_tile_id += cta_m_in_cluster * params.divmod_cluster_shape_minor_.divisor; - } + // Find the output tile corresponding to the final k iteration covered by this + // work unit. Stream-K work units will work backwards in terms of the tiles they + // are responsible computing. This is beneficial because the final (partial) + // tile computed by a stream-K block is typically the beginning of the output + // tile, while the beginning (partial) tile is typically the ending of another + // output tile. Since ending portions of an output tile must reduce across + // other work units computing portions of that output tile, it is preferable + // for them to be computed later, so as to reduce the likelihood of blocking + // on other work. + uint32_t unit_iter_end = unit_iter_start + work_tile_info.k_tile_remaining - 1; + + true_tile_id = params.divmod_tiles_per_output_tile_.divide(unit_iter_end); + uint32_t true_tile_iter_start = true_tile_id * params.divmod_tiles_per_output_tile_.divisor; + uint32_t true_tile_iter_end = true_tile_iter_start + params.divmod_tiles_per_output_tile_.divisor; + + // Bring the linearized tile ID back into the space of tiles, rather than clusters + true_tile_id *= params.divmod_cluster_shape_major_.divisor * params.divmod_cluster_shape_minor_.divisor; + + auto [cta_m_in_cluster, cta_n_in_cluster, _] = cute::block_id_in_cluster(); + + // The final linearized tile ID is in units of the cluster dimension over which we rasterize. + if (params.raster_order_ == RasterOrder::AlongN) { + true_tile_id += cta_n_in_cluster * params.divmod_cluster_shape_minor_.divisor; + } + else { + true_tile_id += cta_m_in_cluster * params.divmod_cluster_shape_minor_.divisor; + } - // The unit's starting k iteration in the current tile is either the starting - // iteration for the tile as a whole, or the starting k iteration for the unit - // as a whole (if the latter is greater than the former). - uint32_t tile_iter_start = max(true_tile_iter_start, unit_iter_start); + // The unit's starting k iteration in the current tile is either the starting + // iteration for the tile as a whole, or the starting k iteration for the unit + // as a whole (if the latter is greater than the former). 
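// Editor's note (illustrative sketch, not part of the patch): the logic above clips a single
// stream-K unit's contiguous k-iteration range to the output tile that owns its final iteration.
// The standalone helper below shows the same arithmetic for the simplest case, assuming no "big"
// units, no clusters, and splits == 1; all names here are hypothetical.

#include <algorithm>
#include <cstdint>

struct UnitTileSlice {
  uint32_t tile_id;       // output tile owning the unit's final k iteration
  uint32_t k_idx;         // starting k tile within that output tile
  uint32_t k_tile_count;  // k tiles this unit contributes to that output tile
};

inline UnitTileSlice slice_unit(uint32_t unit_idx,
                                uint32_t k_tiles_per_sk_unit,
                                uint32_t k_tiles_per_output_tile) {
  uint32_t unit_iter_start = unit_idx * k_tiles_per_sk_unit;
  uint32_t unit_iter_end   = unit_iter_start + k_tiles_per_sk_unit - 1;  // inclusive
  uint32_t tile_id         = unit_iter_end / k_tiles_per_output_tile;
  uint32_t tile_iter_start = tile_id * k_tiles_per_output_tile;
  uint32_t start = std::max(tile_iter_start, unit_iter_start);
  uint32_t end   = std::min(tile_iter_start + k_tiles_per_output_tile, unit_iter_end + 1);  // exclusive
  return {tile_id, start - tile_iter_start, end - start};
}

// Example: with k_tiles_per_output_tile == 8 and k_tiles_per_sk_unit == 6, unit 1 covers
// iterations [6, 12); its final iteration lands in tile 1, so it first computes k tiles [0, 4)
// of tile 1, and the leftover iterations [6, 8) (the tail of tile 0) are handled when the unit
// is continued with its remaining k tiles.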
+ uint32_t tile_iter_start = max(true_tile_iter_start, unit_iter_start); - // Similarly, the unit's ending k iteration (exclusive) is either the end of - // the current tile it is assigned, or the ending iteration of the unit as a whole - // (if the latter is less than the former). - uint32_t tile_iter_end = min(true_tile_iter_end, unit_iter_end + 1); + // Similarly, the unit's ending k iteration (exclusive) is either the end of + // the current tile it is assigned, or the ending iteration of the unit as a whole + // (if the latter is less than the former). + uint32_t tile_iter_end = min(true_tile_iter_end, unit_iter_end + 1); - uint32_t tile_iters = tile_iter_end - tile_iter_start; + // Set the k offset to be the starting k tile for this output tile + work_tile_info.K_idx = static_cast(tile_iter_start - true_tile_iter_start); + + work_tile_info.k_tile_count = tile_iter_end - tile_iter_start; + } uint64_t work_idx_l, remainder; params.divmod_batch_(work_idx_l, remainder, true_tile_id); - uint64_t cta_per_grid_dim, dontcare; - params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder); + uint64_t cta_per_grid_dim = params.divmod_cluster_shape_minor_.divide(remainder); auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n( cta_per_grid_dim, @@ -627,113 +638,11 @@ class PersistentTileSchedulerSm90StreamK { params.log_swizzle_size_, params.raster_order_); - // - // Update the work_tile_info - // - // Set the M, N, and L block offsets work_tile_info.M_idx = work_idx_m; work_tile_info.N_idx = work_idx_n; work_tile_info.L_idx = static_cast(work_idx_l); - // Set the k offset to be the starting k tile for this output tile - work_tile_info.K_idx = static_cast(tile_iter_start - true_tile_iter_start); - - // Set the split count to be the number of k tiles in the output tile - work_tile_info.splits = params.k_tiles_per_output_tile_; - - // Any checks for invalid work units should be done prior to this call - work_tile_info.is_valid_tile = true; - - work_tile_info.k_tile_count = tile_iters; - work_tile_info.k_tile_remaining = unit_iters; - - // Compute the epilogue if this unit of work contains the ending k iteration for - // the output tile in question - work_tile_info.is_final_split = (tile_iter_end == true_tile_iter_end); - } - - // Returns a WorkTileInfo to be computed for either the data-parallel or split-K - // work unit identified by the provided linear ID. - CUTLASS_DEVICE - static WorkTileInfo - set_non_stream_k_work(uint64_t linear_idx, Params const& params, bool is_split_k) { - - // The linearized ID space is in terms of work units, rather than tiles. However, - // to compute the correct block offset for a data-parallel tile, we must convert - // the current ID to the data-parallel tile it corresponds to. Each data-parallel - // unit maps to a single data-parallel tile, but each stream-K unit can map to more - // than one tile. Thus, we must offset the work-unit ID among the data-parallel units - // by the total number of output tiles that will be computed by stream-K units. - // - // The logic below also works for the split-K case, in which sk_units_ and sk_tiles_ - // are each 0. 
- uint64_t linear_work_idx = linear_idx - params.sk_units_ + params.sk_tiles_; - - // Map worker's linear index into the CTA-tiled problem shape to the corresponding MNL indices - uint64_t work_idx_l, remainder; - params.divmod_batch_(work_idx_l, remainder, linear_work_idx); - - uint64_t work_idx_k = 0; - if (is_split_k) { - params.divmod_k_(work_idx_k, remainder, remainder); - } - - uint64_t cta_per_grid_dim, dontcare; - params.divmod_cluster_shape_minor_(cta_per_grid_dim, dontcare, remainder); - - auto [work_idx_m, work_idx_n] = UnderlyingScheduler::get_work_idx_m_and_n( - cta_per_grid_dim, - params.divmod_cluster_shape_major_, - params.divmod_cluster_shape_minor_, - params.divmod_cluster_blk_major_, - params.log_swizzle_size_, - params.raster_order_); - - bool is_final_split = (work_idx_k == params.splits_ - 1); - - uint32_t k_tiles = params.k_tiles_per_output_tile_; - if (is_split_k) { - // Determine the number of iterations and starting iteration of this split. - // Doing so requires accounting for residual iterations, which are handled - // by the first big_units_ splits (with big_units_ = tiles % sm_count). - - // Offsets for "normal" units. No additional k iterations are performed, - // and big_units_ "big" units preceded us, each of which performed one - // additional iteration. Thus, we must increase our split starting offset - // by big_units_. - int additional_k_tiles = 0; - int split_start_offset = params.big_units_; - - if (work_idx_k < params.big_units_) { - // Offsets for "big" units. One additional k iteration is performed, - // and each split preceding us was a big unit, so we must increase - // our split starting offset by our split ID (work_idx_k). - additional_k_tiles = 1; - split_start_offset = work_idx_k; - } - - // Set up k iteration count and split starting iteration assuming the - // iteration space is evenly split. 
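// Editor's note (illustrative sketch, not part of the patch): the residual handling described
// above gives each of the first big_units splits one extra k tile and offsets the later splits
// by big_units. A hypothetical standalone version of that arithmetic:

#include <cstdint>

inline void split_k_range(uint32_t split_idx, uint32_t k_tiles_per_output_tile, uint32_t splits,
                          uint32_t& k_start, uint32_t& k_count) {
  uint32_t big_units = k_tiles_per_output_tile % splits;
  k_count = k_tiles_per_output_tile / splits;
  k_start = split_idx * k_count;
  if (split_idx < big_units) {
    // "Big" splits compute one extra k tile; offset by the number of preceding big splits
    ++k_count;
    k_start += split_idx;
  }
  else {
    // All big splits precede this split, each having consumed one extra k tile
    k_start += big_units;
  }
}

// Example: k_tiles_per_output_tile == 10 and splits == 4 give big_units == 2, so the four splits
// cover k tiles [0, 3), [3, 6), [6, 8), and [8, 10) respectively.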
- k_tiles /= params.splits_; - work_idx_k *= k_tiles; - - // Apply any fixup needed to handle residuals - work_idx_k += split_start_offset; - k_tiles += additional_k_tiles; - } - - return { - work_idx_m, - work_idx_n, - static_cast(work_idx_k), - static_cast(work_idx_l), - true, - params.k_tiles_per_output_tile_, - k_tiles, - k_tiles, // remaining iterations - is_final_split - }; } }; diff --git a/include/cutlass/gemm/kernel/tile_scheduler_params.h b/include/cutlass/gemm/kernel/tile_scheduler_params.h index eb98fd2f42..8cfb484545 100644 --- a/include/cutlass/gemm/kernel/tile_scheduler_params.h +++ b/include/cutlass/gemm/kernel/tile_scheduler_params.h @@ -78,8 +78,8 @@ struct PersistentTileSchedulerSm90Params { AlongN }; - FastDivmodU64 divmod_cluster_shape_major_{}; - FastDivmodU64 divmod_cluster_shape_minor_{}; + FastDivmodU64Pow2 divmod_cluster_shape_major_{}; + FastDivmodU64Pow2 divmod_cluster_shape_minor_{}; FastDivmodU64 divmod_batch_{}; FastDivmodU64 divmod_cluster_blk_major_{}; @@ -143,13 +143,13 @@ struct PersistentTileSchedulerSm90Params { divmod_batch_ = FastDivmodU64(problem_blocks_m * problem_blocks_n); if (raster_order == RasterOrder::AlongN) { - divmod_cluster_shape_major_ = FastDivmodU64(cluster_shape.n()); - divmod_cluster_shape_minor_ = FastDivmodU64(cluster_shape.m()); + divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.n()); + divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.m()); divmod_cluster_blk_major_ = FastDivmodU64(problem_blocks_n / cluster_shape.n()); } else { - divmod_cluster_shape_major_ = FastDivmodU64(cluster_shape.m()); - divmod_cluster_shape_minor_ = FastDivmodU64(cluster_shape.n()); + divmod_cluster_shape_major_ = FastDivmodU64Pow2(cluster_shape.m()); + divmod_cluster_shape_minor_ = FastDivmodU64Pow2(cluster_shape.n()); divmod_cluster_blk_major_ = FastDivmodU64(problem_blocks_m / cluster_shape.m()); } } @@ -374,15 +374,22 @@ struct PersistentTileSchedulerSm90StreamKParams { using RasterOrder = UnderlyingParams::RasterOrder; using RasterOrderOptions = UnderlyingParams::RasterOrderOptions; - FastDivmodU64 divmod_cluster_shape_major_{}; - FastDivmodU64 divmod_cluster_shape_minor_{}; + // Cluster dimensions are typically always a power of 2, so use + // the power-of-two variants of FastDivmod for these. + FastDivmodU64Pow2 divmod_cluster_shape_major_{}; + FastDivmodU64Pow2 divmod_cluster_shape_minor_{}; + FastDivmodU64 divmod_batch_{}; - FastDivmodU64 divmod_k_{}; FastDivmodU64 divmod_cluster_blk_major_{}; - int32_t log_swizzle_size_ = 0; + // Total number of cluster-sized output tiles (i.e., not including any + // splitting factors). This is primarily used for split-K decompositions, + // and may be overridden in other decompositions. + FastDivmodU64 divmod_clusters_mnl_{}; uint64_t units_per_problem_ = 0; + FastDivmod divmod_tiles_per_output_tile_{}; + int32_t log_swizzle_size_ = 0; RasterOrder raster_order_ = RasterOrder::AlongN; // The splitting factor to be used in a split-K decomposition of the problem. @@ -390,9 +397,6 @@ struct PersistentTileSchedulerSm90StreamKParams { // is bypassed in favor of a split-K decomposition. uint32_t splits_ = 1; - // Number of tiled k iterations required to compute a single output tile. - uint32_t k_tiles_per_output_tile_ = 0; - // Number of stream-K or split-K work units that compute an extra k iteration. // This is done to handle residuals in dividing up the k iteration space. 
// For stream-K, since the actual assignment of work to stream-K units will be done @@ -475,10 +479,10 @@ struct PersistentTileSchedulerSm90StreamKParams { raster_order_option ); - auto problem_blocks_m = problem_blocks.x; - auto problem_blocks_n = problem_blocks.y; auto problem_blocks_l = problem_blocks.z; + auto problem_blocks_m = round_up(problem_blocks.x, (1 << underlying_params.log_swizzle_size_) * cluster_shape.m()); + auto problem_blocks_n = round_up(problem_blocks.y, (1 << underlying_params.log_swizzle_size_) * cluster_shape.n()); uint64_t output_tiles = problem_blocks_m * problem_blocks_n * problem_blocks_l; // Reduction workspace is at the beginning of the workspace. Lock workspace follows. @@ -620,13 +624,17 @@ struct PersistentTileSchedulerSm90StreamKParams { divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_; divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_; divmod_batch_ = underlying_params.divmod_batch_; - divmod_k_ = FastDivmodU64(problem_blocks_m * problem_blocks_n); // Static k-splitting divmod. Unused for stream-K. + divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_; + + // Override divmod_clusters_mnl_ to be the number of cluster-sized stream-K units. + // This setting ensures that the use of this divmod for stream-K decompositions + // is essentially a no-op. + divmod_clusters_mnl_ = FastDivmodU64(sk_units / cluster_size); + splits_ = 1; log_swizzle_size_ = underlying_params.log_swizzle_size_; units_per_problem_ = static_cast(dp_units + sk_units); raster_order_ = underlying_params.raster_order_; - splits_ = 1; // Static k-splitting factor. Unused for stream-K. - k_tiles_per_output_tile_ = k_tiles_per_output_tile; big_units_ = static_cast(sk_big_units_per_cluster); reduction_workspace_ = reduction_workspace; sk_tiles_ = sk_tiles; @@ -755,6 +763,10 @@ struct PersistentTileSchedulerSm90StreamKParams { uint32_t barrier_bits, uint32_t accumulator_bits) { + auto log_swizzle_size = UnderlyingParams::get_log_swizzle_size(problem_blocks.x, problem_blocks.y, max_swizzle); + problem_blocks.x = round_up(problem_blocks.x, (1 << log_swizzle_size) * cluster_shape.m()); + problem_blocks.y = round_up(problem_blocks.y, (1 << log_swizzle_size) * cluster_shape.n()); + // Workspace is needed only for output tiles that will be split. Thus, we first determine the number // of output tiles that will be split, and then calculate the workspace needed to cover these. 
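// Editor's note (simplified stand-in, not CUTLASS's actual implementation): several members in
// this section are switched from FastDivmodU64 to FastDivmodU64Pow2 because cluster dimensions
// are powers of two, which lets division and modulus be replaced by a shift and a mask. The
// hypothetical struct below illustrates the idea.

#include <cstdint>

struct Pow2Divmod {
  uint64_t divisor;  // must be a power of two
  unsigned shift;    // log2(divisor)

  explicit Pow2Divmod(uint64_t d) : divisor(d), shift(0) {
    while ((uint64_t(1) << shift) < d) { ++shift; }
  }

  uint64_t divide(uint64_t x) const { return x >> shift; }

  void operator()(uint64_t& quotient, uint64_t& remainder, uint64_t x) const {
    quotient  = x >> shift;
    remainder = x & (divisor - 1);
  }
};

// Example: Pow2Divmod(4) maps 13 to quotient 3 and remainder 1 without an integer division.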
uint64_t output_tiles = problem_blocks.x * problem_blocks.y * problem_blocks.z; @@ -966,24 +978,25 @@ struct PersistentTileSchedulerSm90StreamKParams { void* reduction_workspace, ReductionMode reduction_mode) { - divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_, - divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_, - divmod_batch_ = FastDivmodU64(blocks_m * blocks_n * splits), - divmod_k_ = FastDivmodU64(blocks_m * blocks_n), - divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_, - log_swizzle_size_ = underlying_params.log_swizzle_size_, - units_per_problem_ = blocks_m * blocks_n * blocks_l * splits, - raster_order_ = underlying_params.raster_order_, - splits_ = splits, - k_tiles_per_output_tile_ = k_tiles_per_output_tile, - big_units_ = k_tiles_per_output_tile % splits, + divmod_cluster_shape_major_ = underlying_params.divmod_cluster_shape_major_; + divmod_cluster_shape_minor_ = underlying_params.divmod_cluster_shape_minor_; + divmod_batch_ = FastDivmodU64(blocks_m * blocks_n); + divmod_tiles_per_output_tile_ = FastDivmod(k_tiles_per_output_tile); + auto cluster_size = underlying_params.divmod_cluster_shape_major_.divisor * underlying_params.divmod_cluster_shape_minor_.divisor; + divmod_clusters_mnl_ = FastDivmodU64((blocks_m * blocks_n * blocks_l) / cluster_size); + splits_ = splits; + divmod_cluster_blk_major_ = underlying_params.divmod_cluster_blk_major_; + log_swizzle_size_ = underlying_params.log_swizzle_size_; + units_per_problem_ = blocks_m * blocks_n * blocks_l; + raster_order_ = underlying_params.raster_order_; + big_units_ = k_tiles_per_output_tile % splits; reduction_workspace_ = reduction_workspace; reduction_mode_ = reduction_mode; + k_tiles_per_sk_unit_ = k_tiles_per_output_tile / splits; // No stream-K work is performed for "basic" data-parallel and split-K decompositions sk_tiles_ = 0; sk_units_ = 0; - k_tiles_per_sk_unit_ = 0; } private: diff --git a/include/cutlass/gemm/threadblock/mma_sparse_multistage.h b/include/cutlass/gemm/threadblock/mma_sparse_multistage.h index 4b35e69603..6ed9692d90 100644 --- a/include/cutlass/gemm/threadblock/mma_sparse_multistage.h +++ b/include/cutlass/gemm/threadblock/mma_sparse_multistage.h @@ -643,9 +643,9 @@ class SparseMmaMultistage : // we can start right away on mma instructions if (warp_mma_k + 1 == Base::kWarpGemmIterations) warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], - warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % Detail::kBBufferSize], warp_loaded_frag_A[(warp_mma_k + 1) % 2], - warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + warp_loaded_frag_B[(warp_mma_k + 1) % Detail::kBBufferSize]); } } diff --git a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h index 196fe1a37a..8589aabbf5 100644 --- a/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h +++ b/include/cutlass/gemm/threadblock/threadblock_swizzle_streamk.h @@ -408,19 +408,21 @@ struct ThreadblockSwizzleStreamK { } /// Constructor: *Gemm* problem size (m, n, k) - template ThreadblockSwizzleStreamK( - KernelTraits const kernel_traits_, GemmUniversalMode const mode_, GemmCoord const problem_size_, GemmCoord const tile_size_, int const batch_split_, /// Either (mode == GemmUniversalMode::kBatched) the batch count, or (mode == GemmUniversalMode::kGemm) the tile-splitting factor (1 defaults to StreamK, >1 emulates Split-K) int const sm_occupancy_, int const 
device_sms_, - int const avail_sms_) /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) + int const avail_sms_, /// The number of SMs that StreamK dispatch heuristics will attempt to load-balance across (-1 defaults to device width, 1 implies classic data-parallel scheduling) + size_t const element_A_bytes_, + size_t const element_B_bytes_, + size_t const element_C_bytes_, + int const epilogue_acc_fragments_) : problem_size(problem_size_), - batch_count((mode_ == GemmUniversalMode::kBatched) ? batch_split_ : 1), + batch_count((mode_ == GemmUniversalMode::kBatched || mode_ == GemmUniversalMode::kArray) ? batch_split_ : 1), reduction_blocks(0), dp_blocks(0), dp_first_wave_tiles(1), // Default: one tile per DP-block in the first wave of DP blocks @@ -446,17 +448,17 @@ struct ThreadblockSwizzleStreamK { batch_count); size_t problem_bytes = - (sizeof(typename GemmKernel::ElementC) * problem_size.m() * problem_size.n()) + - (sizeof(typename GemmKernel::ElementA) * problem_size.m() * problem_size.k()) + - (sizeof(typename GemmKernel::ElementB) * problem_size.k() * problem_size.n()); + (element_C_bytes_ * problem_size.m() * problem_size.n()) + + (element_A_bytes_ * problem_size.m() * problem_size.k()) + + (element_B_bytes_ * problem_size.k() * problem_size.n()); size_t problem_flops = size_t(problem_size.m()) * size_t(problem_size.n()) * size_t(problem_size.k()) * 2; - float flops_per_byte = float(problem_flops) / float(problem_bytes); + [[maybe_unused]] float flops_per_byte = float(problem_flops) / float(problem_bytes); int output_tiles = tiled_shape.m() * tiled_shape.n(); int waves = (output_tiles + avail_sms - 1) / avail_sms; - float dp_efficiency = float(output_tiles) / float(waves * avail_sms); + [[maybe_unused]] float dp_efficiency = float(output_tiles) / float(waves * avail_sms); // // Determine dispatch composition of DP-tiles and SK-blocks @@ -528,8 +530,7 @@ struct ThreadblockSwizzleStreamK { (sk_blocks > 2 * sk_tiles)) { // Launch a reduction block for every accumulator fragment in each SK-tile - static const int kAccumulatorFragments = GemmKernel::Epilogue::kAccumulatorFragments; - reduction_blocks = sk_tiles * kAccumulatorFragments; + reduction_blocks = sk_tiles * epilogue_acc_fragments_; } diff --git a/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h b/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h index 9572f2e32e..6238097ca9 100644 --- a/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h +++ b/include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h @@ -294,6 +294,8 @@ struct DefaultMmaTensorOp< Policy, PartitionsK, AccumulatorsInRowMajor>; }; +///////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace warp } // namespace gemm } // namespace cutlass diff --git a/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h b/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h index ee58e39dc5..22598a217d 100644 --- a/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h @@ -132,10 +132,10 @@ struct FragmentShuffler ; using MmaFragment = Array; - + static uint32_t const kSelectBytesEvenThread = 0x5410; static uint32_t const kSelectBytesOddThread = 0x7632; @@ -168,7 +168,7 @@ struct FragmentShuffler (&mma_frag_src_ptr[n]); uint32_t *dst_ptr = reinterpret_cast(&mma_frag_dst_ptr[n]); - + // Shuffle data within the warp, pull from other threads within 
the warp uint32_t tmp0 = __shfl_up_sync(0xFFFFFFFF, src_ptr[0], delta_up_); uint32_t tmp1 = __shfl_down_sync(0xFFFFFFFF, src_ptr[0], delta_down_); @@ -218,7 +218,7 @@ struct FragmentShuffler ; using MmaFragment = Array; - + static uint32_t const kSelectBytesEvenThread = 0x5410; static uint32_t const kSelectBytesOddThread = 0x7632; @@ -260,7 +260,7 @@ struct FragmentShuffler struct FragmentConverter { - + using ElementDst = ElementDst_; using ElementSrc = ElementSrc_; @@ -522,17 +522,6 @@ class MmaMixedInputTensorOp { void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, FragmentA const &A, FragmentB const &B) const { - // Shuffle data within warp to obtain the mma.sync operand layout - detail::FragmentShuffler shuffler_A; - FragmentA tmp_A; - tmp_A = shuffler_A(A); - - // Convert the A operand to the Mma Instruction operand type - detail::FragmentConverter convert_A; - dst_A = convert_A(tmp_A); - - // Shuffle data within warp to obtain the mma.sync operand layout detail::FragmentShuffler shuffler_B; @@ -542,6 +531,27 @@ class MmaMixedInputTensorOp { // Convert the B operand to the Mma Instruction operand type detail::FragmentConverter convert_B; dst_B = convert_B(tmp_B); + + FragmentA tmp_A; + + Array * + ptr_tmp_A = reinterpret_cast *>(&tmp_A); + Array * + ptr_dst_A = reinterpret_cast *>(&dst_A); + + // Shuffle data within warp to obtain the mma.sync operand layout + detail::FragmentShuffler shuffler_A; + + // Convert the A operand to the Mma Instruction operand type + detail::FragmentConverter convert_A; + + tmp_A = shuffler_A(A); + ptr_dst_A[0] = convert_A(ptr_tmp_A[0]); + + ptr_dst_A[1] = convert_A(ptr_tmp_A[1]); } }; @@ -551,4 +561,4 @@ class MmaMixedInputTensorOp { } // namespace gemm } // namespace cutlass -///////////////////////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/gemm/warp/mma_sparse_tensor_op.h b/include/cutlass/gemm/warp/mma_sparse_tensor_op.h index e049f4f0fa..5b0fc40f77 100644 --- a/include/cutlass/gemm/warp/mma_sparse_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_sparse_tensor_op.h @@ -158,6 +158,7 @@ class SparseMmaTensorOp { /// Max ID2 static int const kMaxID2 = Policy::Operator::kMaxID2; + static int const kVerticalVisit = false; /// Data type of meta E that is moved at the same time using ElementE = typename cutlass::platform::conditional= 800) - D = C; MmaOperandA const *ptr_A = reinterpret_cast(&A); @@ -260,6 +259,36 @@ class SparseMmaTensorOp { MmaOperandC *ptr_D = reinterpret_cast(&D); MmaOperandE const *ptr_E = reinterpret_cast(&E); + if (kVerticalVisit) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; ++n) { + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; ++m) { + + int m_serpentine = ((n % 2) ? 
(MmaIterations::kRow - 1 - m) : m); + int id2 = m_serpentine % kMaxID2; + + if (AccumulatorsInRowMajor) { // matrix B is reordered + mma( + ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[n + m_serpentine * MmaIterations::kColumn], + ptr_E[(m_serpentine / kMaxID2)], + id2); + } else { + mma( + ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_A[m_serpentine], + ptr_B[n], + ptr_D[m_serpentine + n * MmaIterations::kRow], + ptr_E[(m_serpentine / kMaxID2)], + id2); + } + } + } + } else { CUTLASS_PRAGMA_UNROLL for (int m = 0; m < MmaIterations::kRow; ++m) { @@ -288,9 +317,7 @@ class SparseMmaTensorOp { } } } - #else - assert(0); - #endif + } } /// Transform the mma operands to the required types @@ -298,7 +325,6 @@ class SparseMmaTensorOp { void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, FragmentA const &A, FragmentB const &B) const { - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) // // Define conversions from source type to instruction type // @@ -308,25 +334,42 @@ class SparseMmaTensorOp { FloatRoundStyle const kRoundB = PreferredRoundingMode::kRound; - detail::ConvertAndPack - convert_A; - NumericArrayConverter - convert_B; - Array const *ptr_A = - reinterpret_cast const *>(&A); - Array * - ptr_dst_A = reinterpret_cast *>(&dst_A); - - dst_B = convert_B(B); - - ptr_dst_A[0] = convert_A(ptr_A[0]); - ptr_dst_A[1] = convert_A(ptr_A[1]); - #else - assert(0); - #endif + + if (kVerticalVisit) { + detail::ConvertAndPack + convert_A; + NumericArrayConverter + convert_B; + Array const *ptr_B = + reinterpret_cast const *>(&B); + Array * + ptr_dst_B = reinterpret_cast *>(&dst_B); + + dst_A = convert_A(A); + + ptr_dst_B[0] = convert_B(ptr_B[0]); + ptr_dst_B[1] = convert_B(ptr_B[1]); + } else { + detail::ConvertAndPack + convert_A; + NumericArrayConverter + convert_B; + Array const *ptr_A = + reinterpret_cast const *>(&A); + Array * + ptr_dst_A = reinterpret_cast *>(&dst_A); + + dst_B = convert_B(B); + + ptr_dst_A[0] = convert_A(ptr_A[0]); + ptr_dst_A[1] = convert_A(ptr_A[1]); + } } }; diff --git a/include/cutlass/gemm/warp/mma_tensor_op.h b/include/cutlass/gemm/warp/mma_tensor_op.h index 3124618c28..ba1b3ca71f 100644 --- a/include/cutlass/gemm/warp/mma_tensor_op.h +++ b/include/cutlass/gemm/warp/mma_tensor_op.h @@ -217,6 +217,12 @@ class MmaTensorOp { /// Number of partitions along K dimension static int const kPartitionsK = PartitionsK_; + #if defined(__CUDA_ARCH__) && ((__CUDA_ARCH__ < 800) || (__CUDA_ARCH__ == 890)) + static int const kVerticalVisit = true; + #else + static int const kVerticalVisit = false; + #endif + public: /// Iterates over the A operand in memory @@ -293,16 +299,8 @@ class MmaTensorOp { MmaOperandB const *ptr_B = reinterpret_cast(&B); MmaOperandC *ptr_D = reinterpret_cast(&D); - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) - // Serpentine visitation order maximizing reuse of Rb - // The visitation order is like - // _ - // | | | | - // | | | | - // |_| |_| - // - // Down Up Down Up - + + if (kVerticalVisit) { CUTLASS_PRAGMA_UNROLL for (int n = 0; n < MmaIterations::kColumn; ++n) { @@ -326,16 +324,7 @@ class MmaTensorOp { } } } - #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) - // Serpentine visitation order maximizing reuse of Ra - // The visitation order is like - // _________ - // _________| - // |_________ - // __________| - // - // Right Left Right Left - + } else { CUTLASS_PRAGMA_UNROLL for (int m = 0; m < MmaIterations::kRow; ++m) { @@ -358,9 +347,7 @@ class MmaTensorOp { } } } - 
#else - assert(0); - #endif + } } /// Transform the mma operands to the required types @@ -377,7 +364,7 @@ class MmaTensorOp { FloatRoundStyle const kRoundB = PreferredRoundingMode::kRound; - #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) + if (kVerticalVisit) { detail::ConvertAndPack convert_A; @@ -394,8 +381,7 @@ class MmaTensorOp { ptr_dst_B[0] = convert_B(ptr_B[0]); ptr_dst_B[1] = convert_B(ptr_B[1]); - - #elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + } else { detail::ConvertAndPack convert_A; @@ -412,9 +398,7 @@ class MmaTensorOp { ptr_dst_A[0] = convert_A(ptr_A[0]); ptr_dst_A[1] = convert_A(ptr_A[1]); - #else - assert(0); - #endif + } } }; diff --git a/include/cutlass/kernel_hardware_info.h b/include/cutlass/kernel_hardware_info.h index 46e5a89ca0..d0924296e8 100644 --- a/include/cutlass/kernel_hardware_info.h +++ b/include/cutlass/kernel_hardware_info.h @@ -53,7 +53,7 @@ struct KernelHardwareInfo { // Data members // int device_id = 0; - int sm_count = 0; + int sm_count = 0; // // Methods diff --git a/include/cutlass/numeric_conversion.h b/include/cutlass/numeric_conversion.h index a3ad138b0a..27fc0e6f95 100644 --- a/include/cutlass/numeric_conversion.h +++ b/include/cutlass/numeric_conversion.h @@ -1303,11 +1303,195 @@ struct NumericArrayConverter { // ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float; + using source_element = float_e4m3_t; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out_fp16; + uint16_t const& src_packed = reinterpret_cast(source); + + asm volatile( \ + "{\n" \ + "cvt.rn.f16x2.e4m3x2 %0, %1;\n" \ + "}\n" : "=r"(out_fp16): "h"(src_packed)); + + float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16)); + + result_type out; + out[0] = res0.x; + out[1] = res0.y; + return out; + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float_e4m3_t; + using source_element = float; + + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint16_t out; + + asm volatile( \ + "{\n" \ + "cvt.rn.satfinite.e4m3x2.f32 %0, %2, %1;\n" \ + "}" \ + : "=h"(out) : "f"(source[0]), "f"(source[1])); + + return reinterpret_cast(out); + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + +/// Partial specialization for Array <= Array +template < + FloatRoundStyle Round +> +struct NumericArrayConverter { + using result_element = float; + using source_element = float_e5m2_t; + + using 
result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + CUTLASS_DEVICE + static result_type convert(source_type const & source) { + + #if defined(CUDA_PTX_FP8_CVT_ENABLED) + uint32_t out_fp16; + uint16_t const& src_packed = reinterpret_cast(source); + + asm volatile( \ + "{\n" \ + "cvt.rn.f16x2.e5m2x2 %0, %1;\n" \ + "}\n" : "=r"(out_fp16): "h"(src_packed)); + + float2 res0 = __half22float2(reinterpret_cast<__half2 &>(out_fp16)); + + result_type out; + out[0] = res0.x; + out[1] = res0.y; + return out; + #else + result_type result; + NumericConverter converter; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + result[i] = converter(source[i]); + } + + return result; + #endif + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; +namespace detail { + +/// Special converters that can be used with 4 8-bit elements packed in a register. +/// Common use is for fast FP8 converters. +template < + typename T, + typename S, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest, + typename Transform = cutlass::transform::thread::UnaryTransform::Identity +> +struct NumericArrayConverterPacked4Element { + using result_type = Array; + using source_type = Array; + static FloatRoundStyle const round_style = Round; + + static_assert(platform::is_same::value || + platform::is_same::value, + "Unary Operator not supported."); + + CUTLASS_HOST_DEVICE + static result_type convert(source_type const & s) { + + result_type result; + NumericConverter convert_; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 4; ++i) { + if (platform::is_same::value) { + result[i] = convert_(s[i]); + } + else { // conjugate + result[i] = conj(convert_(s[i])); + } + } + + return result; + } + + CUTLASS_HOST_DEVICE + result_type operator()(source_type const &s) const { + return convert(s); + } +}; + /// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using result_element = float; using source_element = float_e4m3_t; @@ -1362,7 +1546,7 @@ struct NumericArrayConverter { template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using result_element = float_e4m3_t; using source_element = float; @@ -1406,11 +1590,17 @@ struct NumericArrayConverter { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using result_element = float; using source_element = float_e5m2_t; @@ -1465,7 +1655,7 @@ struct NumericArrayConverter { template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using result_element = float_e5m2_t; using source_element = float; @@ -1519,7 +1709,7 @@ struct NumericArrayConverter { template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using result_element = half_t; using source_element = float_e4m3_t; @@ -1564,7 +1754,7 @@ struct NumericArrayConverter { template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using 
result_element = float_e4m3_t; using source_element = half_t; @@ -1609,11 +1799,17 @@ struct NumericArrayConverter { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using result_element = half_t; using source_element = float_e5m2_t; @@ -1658,7 +1854,7 @@ struct NumericArrayConverter { template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using result_element = float_e5m2_t; using source_element = half_t; @@ -1713,7 +1909,7 @@ struct NumericArrayConverter { template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using result_element = bfloat16_t; using source_element = float_e4m3_t; @@ -1726,7 +1922,7 @@ struct NumericArrayConverter { #if defined(CUDA_PTX_FP8_CVT_ENABLED) // Convert f8 to float - NumericArrayConverter src2float; + NumericArrayConverterPacked4Element src2float; Array tmp_floats = src2float(source); // Convert float to bf16 @@ -1761,7 +1957,7 @@ struct NumericArrayConverter { template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using result_element = float_e4m3_t; using source_element = bfloat16_t; @@ -1782,7 +1978,7 @@ struct NumericArrayConverter { packed_tmp[1] = src2float(packed_source[1]); // Convert float to f8 - NumericArrayConverter float2result; + NumericArrayConverterPacked4Element float2result; return float2result(tmp); #else result_type result; @@ -1803,11 +1999,17 @@ struct NumericArrayConverter { } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations for Array <=> Array +// +///////////////////////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for Array <= Array template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using result_element = bfloat16_t; using source_element = float_e5m2_t; @@ -1820,7 +2022,7 @@ struct NumericArrayConverter { #if defined(CUDA_PTX_FP8_CVT_ENABLED) // Convert f8 to float - NumericArrayConverter src2float; + NumericArrayConverterPacked4Element src2float; Array tmp_floats = src2float(source); // Convert float to bf16 @@ -1855,7 +2057,7 @@ struct NumericArrayConverter { template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using result_element = float_e5m2_t; using source_element = bfloat16_t; @@ -1876,7 +2078,7 @@ struct NumericArrayConverter { packed_tmp[1] = src2float(packed_source[1]); // Convert float to f8 - NumericArrayConverter float2result; + NumericArrayConverterPacked4Element float2result; return float2result(tmp); #else result_type result; @@ -1907,7 +2109,7 @@ struct NumericArrayConverter { template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { using result_element = float_e4m3_t; using source_element = float_e5m2_t; @@ -1938,7 +2140,7 @@ struct NumericArrayConverter { template < FloatRoundStyle Round > -struct NumericArrayConverter { +struct NumericArrayConverterPacked4Element { 
using result_element = float_e5m2_t; using source_element = float_e4m3_t; @@ -1965,63 +2167,7 @@ struct NumericArrayConverter { } }; -///////////////////////////////////////////////////////////////////////////////////////////////// -// -// Partial specializations for: -// Array <=> Array -// Array <=> Array -// -// These are needed to avoid multiple-matching-template compilation errors (e.g., when -// compiling float_e4m3_t <=> float_e4m3_t, which among T <= float_e4m3_t and float_e4m3_t <= T -// should be used?) -// -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = float_e4m3_t; - using source_element = float_e4m3_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const &source) { - return source; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; - -/// Partial specialization for Array <= Array -template < - FloatRoundStyle Round -> -struct NumericArrayConverter { - using result_element = float_e5m2_t; - using source_element = float_e5m2_t; - - using result_type = Array; - using source_type = Array; - static FloatRoundStyle const round_style = Round; - - CUTLASS_HOST_DEVICE - static result_type convert(source_type const &source) { - return source; - } - - CUTLASS_HOST_DEVICE - result_type operator()(source_type const &s) const { - return convert(s); - } -}; +} ///////////////////////////////////////////////////////////////////////////////////////////////// // @@ -2058,7 +2204,7 @@ struct PackedNumericArrayConverter { packed_result_type* packed_result = reinterpret_cast(&result); const packed_source_type* packed_source = reinterpret_cast(&source); - NumericArrayConverter packed_converter; + detail::NumericArrayConverterPacked4Element packed_converter; CUTLASS_PRAGMA_UNROLL for (int i = 0; i < N / 4; ++i) { @@ -2150,8 +2296,11 @@ template < struct NumericArrayConverter : public PackedNumericArrayConverter {}; + + ///////////////////////////////////////////////////////////////////////////////////////////////// + /// Partial specialization for Array <= Array /// Conversion is performed with saturation regardless of setting of /// the `Round` template parameter. 
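// Editor's note (hedged usage sketch, not part of the patch): the fp8 array converters in this
// hunk are consumed through cutlass::NumericArrayConverter, which converts a whole Array at once.
// The function name below is hypothetical; on SM90 the 2-element fp8 <-> float specializations
// above lower to the cvt.*.e4m3x2 / e5m2x2 PTX instructions, with a scalar fallback otherwise.

#include "cutlass/array.h"
#include "cutlass/float8.h"
#include "cutlass/numeric_conversion.h"

__device__ void convert_fp8_fragment(cutlass::Array<cutlass::float_e4m3_t, 4> const& src,
                                      cutlass::Array<float, 4>& dst) {
  cutlass::NumericArrayConverter<float, cutlass::float_e4m3_t, 4> converter;
  dst = converter(src);  // larger arrays are handled in packs of 4 via the Packed4Element helpers
}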
@@ -2360,7 +2509,9 @@ struct FastNumericArrayConverter { /// Partial specialization for Array <= Array template -struct FastNumericArrayConverter { +struct FastNumericArrayConverter::is_integer> +> { using result_type = Array; using source_type = Array; static FloatRoundStyle const round_style = Round; @@ -2442,7 +2593,6 @@ struct FastNumericArrayConverter { result_type operator()(source_type const &s) const { return convert(s); } }; - /// Partial specialization for Array <= Array template struct FastNumericArrayConverter { @@ -2454,7 +2604,7 @@ struct FastNumericArrayConverter { CUTLASS_DEVICE static result_type convert(source_type const &source) { result_type result; - + #if 0 // Scalar conversion (Please keep this code for reference for vectorized version below) CUTLASS_PRAGMA_UNROLL for (int i = 0; i < 4; ++i) { @@ -2471,8 +2621,8 @@ struct FastNumericArrayConverter { // (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt) // The inline ptx below uses `msb=0` and `msb=1` from the above link to sign extend the sign-bit in 0, 1, 2, 3 bytes of s8x4 // into result_ptr[0] and result_ptr[1]'s 08-15 and 24-31 bits, respectively. - // Note that `__byte_perm(source_ptr[0], source_ptr[0], 0x9180);` won't achieve the same and doesn't sign extend the sign-bit. - // Thus, we use inline ptx `prmt.b32` instruction for the desired sign extend from `s8x2` to `s16x2`. + // Note that `__byte_perm(source_ptr[0], source_ptr[0], 0x9180);` won't acheive the same and doesn't sign extend the sign-bit. + // Thus, we use inline ptx `prmt.b32` instruction for the desired sign extend from s8x2 to s16x2. asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_ptr[0]) : "r"(source_ptr[0]), "n"(0x9180)); asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_ptr[1]) : "r"(source_ptr[0]), "n"(0xB3A2)); @@ -2508,7 +2658,6 @@ struct FastNumericArrayConverter { } }; - /// Partial specialization for Array <= Array template struct FastNumericArrayConverter { @@ -2519,7 +2668,7 @@ struct FastNumericArrayConverter { CUTLASS_DEVICE static result_type convert(source_type const &source) { result_type result; - + uint32_t const* source_ptr = reinterpret_cast(&source); uint32_t* result_ptr = reinterpret_cast(&result); @@ -2632,7 +2781,7 @@ struct FastNumericArrayConverter { template struct FastNumericArrayConverter::value || platform::is_same::value) && - (platform::is_same::value || platform::is_same::value)>::type> { + (platform::is_same::value || platform::is_same::value)>::type> { static_assert(!(N % 4), "N must be multiple of 4."); using result_type = Array; @@ -2658,7 +2807,7 @@ struct FastNumericArrayConverter { template using make_index_sequence = typename index_sequence_helper::type; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Get the register type used in kernel +// ///////////////////////////////////////////////////////////////////////////////////////////////// +namespace detail { + +template +struct get_unpacked_element_type { + using type = T; +}; + +} // namespace detail + } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/pipeline/sm90_pipeline.hpp b/include/cutlass/pipeline/sm90_pipeline.hpp index f0632830bc..4a50328db5 100644 --- a/include/cutlass/pipeline/sm90_pipeline.hpp +++ b/include/cutlass/pipeline/sm90_pipeline.hpp @@ -231,6 +231,7 @@ public : int warp_idx = canonical_warp_idx(); int 
lane_predicate = cute::elect_one_sync(); auto cluster_shape = ClusterShape{}; + if (warp_idx == 0 && lane_predicate == 1) { // Barrier FULL init for (int i = 0; i < Stages; ++i) { @@ -244,6 +245,8 @@ public : empty_barrier_ptr_[i].init(multicast_consumer_arrival_count); } } + cutlass::arch::fence_barrier_init(); + // Logic to optimally schedule Empty Arrives // Goal : To divide SYNCS Empty Arrival duty equally amongst the Warp-Group (128 threads) dim3 block_id = cute::block_id_in_cluster(); @@ -279,8 +282,6 @@ public : // STEP 2: Find if this dst block-id needs an arrival for this problem is_signalling_thread_ &= dst_blockid_ < cluster_size; is_signalling_thread_ &= is_same_row_or_col(dst_blockid_, block_id, cluster_shape); - - cutlass::arch::fence_barrier_init(); } CUTLASS_DEVICE @@ -899,6 +900,13 @@ public : producer_commit(state.index()); } + template + CUTLASS_DEVICE + void producer_commit(PipelineState state, UserDefinedArriveOp&& user_defined_arrive_op) { + cute::forward(user_defined_arrive_op)(producer_get_barrier(state.index())); + producer_commit(state); + } + // Prevents early exit of producer blocks in Cluster. // This should be called once before kernel exits. CUTLASS_DEVICE diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index f782c9ed9d..64dadb4b3e 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -127,13 +127,14 @@ #include // Minimum/maximum operations #include // nullptr_t #include // Arithmetic operations -#include // float_round_style, float_denorm_style #include // For methods on std::pair +#include // float_round_style, float_denorm_style #if (!defined(_MSC_VER) && (__cplusplus >= 201103L)) || (defined(_MSC_VER) && (_MS_VER >= 1500)) #include // For integral constants, conditional metaprogramming, and type traits #endif -#include "cutlass/cutlass.h" +#include +#include #endif @@ -389,10 +390,14 @@ struct conditional { typedef F type; }; +template +using void_t = void; + #else using std::enable_if; using std::conditional; +using std::void_t; #endif diff --git a/include/cutlass/predicate_vector.h b/include/cutlass/predicate_vector.h index b5ffe74eaf..3e41193ec0 100644 --- a/include/cutlass/predicate_vector.h +++ b/include/cutlass/predicate_vector.h @@ -169,7 +169,7 @@ struct PredicateVector { CUTLASS_HOST_DEVICE static constexpr bool computeWordMask() { Storage mask(0); CUTLASS_PRAGMA_UNROLL - for (int byte = 0; byte < sizeof(Storage); ++byte) { + for (size_t byte = 0; byte < sizeof(Storage); ++byte) { mask |= (kByteMask << (byte * 8)); } return mask; @@ -178,9 +178,8 @@ struct PredicateVector { /// Returns mask of last word. CUTLASS_HOST_DEVICE static constexpr bool computeLastWordMask() { Storage mask(0); - constexpr int count = (kBytes % sizeof(Storage) == 0) ? sizeof(Storage) : (kBytes % sizeof(Storage)); CUTLASS_PRAGMA_UNROLL - for (int byte = 0; byte < count; ++byte) { + for (int byte = 0; byte < kBytes % sizeof(Storage); ++byte) { mask |= (kByteMask << (byte * 8)); } return mask; @@ -514,7 +513,7 @@ struct PredicateVector { /// Returns true if entire predicate array is zero. 
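Two changes to `sm90_pipeline.hpp` above are worth calling out: `cutlass::arch::fence_barrier_init()` is now issued immediately after the barriers are initialized (before the empty-arrive scheduling logic), and `producer_commit` gains an overload that runs a user-defined arrive operation on the stage's barrier before the regular commit. The standalone sketch below mirrors the shape of that overload only; `Pipeline`, `Barrier`, and the callable are illustrative stand-ins rather than the CUTLASS types.

```c++
// Sketch of the new producer_commit overload pattern: the user-provided callable
// receives the stage's barrier, then the normal commit runs.
#include <cstdio>
#include <utility>

struct Barrier { int stage; };

struct Pipeline {
  static constexpr int kStages = 4;
  Barrier barriers[kStages];

  Barrier* producer_get_barrier(int index) { return &barriers[index]; }

  void producer_commit(int index) {
    std::printf("default commit for stage %d\n", index);
  }

  // Mirrors the added overload: forward the user-defined arrive op to the
  // stage's barrier, then fall through to the regular commit.
  template <typename UserDefinedArriveOp>
  void producer_commit(int index, UserDefinedArriveOp&& op) {
    std::forward<UserDefinedArriveOp>(op)(producer_get_barrier(index));
    producer_commit(index);
  }
};

int main() {
  Pipeline pipeline{};
  for (int i = 0; i < Pipeline::kStages; ++i) {
    pipeline.barriers[i].stage = i;
  }
  pipeline.producer_commit(0, [](Barrier* barrier) {
    std::printf("user-defined arrive on stage %d\n", barrier->stage);
  });
  return 0;
}
```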
CUTLASS_HOST_DEVICE bool is_zero() const { - constexpr Storage mask = computeWordMask(); + constexpr Storage mask = computeWordMask(); Storage result = 0; CUTLASS_PRAGMA_UNROLL for (int word = 0; word < kWordCount - 1; ++word) { @@ -522,6 +521,7 @@ struct PredicateVector { } constexpr Storage last_word_mask = computeLastWordMask(); result |= (storage(kWordCount - 1) & last_word_mask); + return result == 0; } diff --git a/include/cutlass/relatively_equal.h b/include/cutlass/relatively_equal.h index 00e737923c..e37e0bbc4a 100644 --- a/include/cutlass/relatively_equal.h +++ b/include/cutlass/relatively_equal.h @@ -81,6 +81,12 @@ bool relatively_equal_float(T a, T b, T epsilon, T nonzero_floor) { ///////////////////////////////////////////////////////////////////////////////////////////////// +template <> +CUTLASS_HOST_DEVICE +bool relatively_equal(bool a, bool b, bool, bool) { + return (a == b); +} + template <> CUTLASS_HOST_DEVICE bool relatively_equal(uint1b_t a, uint1b_t b, uint1b_t, uint1b_t) { diff --git a/include/cutlass/subbyte_reference.h b/include/cutlass/subbyte_reference.h index d06580522a..8ef6d18782 100644 --- a/include/cutlass/subbyte_reference.h +++ b/include/cutlass/subbyte_reference.h @@ -763,22 +763,18 @@ class SubbyteReference::value) | original_low_bits; - - - StorageUnit update_low_bits = (original_low_bits & kLowUpdateMask) | low_new_bits; - StorageUnit update_high_bits = (original_high_bits & kHighUpdateMask) | high_new_bits; - - updated = (CudaAtomicType(update_high_bits) << sizeof_bits::value) | update_low_bits; - - original = atomicCAS(reinterpret_cast(ptr_), original, updated); - - } while (updated != original); + original_low_bits = ((*ptr_)[low_storage_unit_idx_]); + update_low_bits = (original_low_bits & kLowUpdateMask) | low_new_bits; + original_low_bits = atomicCAS(&((*ptr_)[low_storage_unit_idx_]), original_low_bits, update_low_bits); + } while (update_low_bits != original_low_bits); + do { + original_high_bits = ((*ptr_)[high_storage_unit_idx_]); + update_high_bits = (original_high_bits & kHighUpdateMask) | high_new_bits; + original_high_bits = atomicCAS(&((*ptr_)[high_storage_unit_idx_]), original_high_bits, update_high_bits); + } while (update_high_bits != original_high_bits); } else { /// Only need update 1 storage unit. 
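The `SubbyteReference` change above replaces the single wide `atomicCAS` covering two adjacent storage units with two independent compare-and-swap loops, one per storage unit. The sketch below shows that per-word read-modify-CAS idiom in standalone form, with `std::atomic` standing in for the device-side `atomicCAS` and illustrative masks in `main`; note that each word is still updated atomically on its own, but the pair is no longer swapped as a single unit.

```c++
#include <atomic>
#include <cstdint>

using StorageUnit = uint32_t;

// Replace the bits cleared by keep_mask with new_bits, retrying until the CAS succeeds.
void update_bits(std::atomic<StorageUnit>& word, StorageUnit keep_mask, StorageUnit new_bits) {
  StorageUnit observed = word.load();
  StorageUnit desired;
  do {
    desired = (observed & keep_mask) | new_bits;
    // compare_exchange_weak reloads `observed` on failure, so the loop retries
    // against the latest value, just like the do/while atomicCAS loops above.
  } while (!word.compare_exchange_weak(observed, desired));
}

// A sub-byte value that straddles two words updates the low and high parts
// independently, mirroring the two loops in the modified SubbyteReference::set().
void update_straddling_value(std::atomic<StorageUnit>& low_word,
                             std::atomic<StorageUnit>& high_word,
                             StorageUnit low_keep_mask, StorageUnit low_new_bits,
                             StorageUnit high_keep_mask, StorageUnit high_new_bits) {
  update_bits(low_word, low_keep_mask, low_new_bits);
  update_bits(high_word, high_keep_mask, high_new_bits);
}

int main() {
  std::atomic<StorageUnit> low{0xFFFFFFFFu};
  std::atomic<StorageUnit> high{0xFFFFFFFFu};
  // Example masks: rewrite the top byte of the low word and the bottom byte of the high word.
  update_straddling_value(low, high,
                          /*low_keep_mask=*/0x00FFFFFFu, /*low_new_bits=*/0xAB000000u,
                          /*high_keep_mask=*/0xFFFFFF00u, /*high_new_bits=*/0x000000CDu);
  return 0;
}
```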
@@ -788,12 +784,13 @@ class SubbyteReference(ptr_), original, updated); + original = atomicCAS(&((*ptr_)[low_storage_unit_idx_]), original, updated); } while (updated != original); } #else + StorageUnit update_low_bits = ((*ptr_)[low_storage_unit_idx_] & kLowUpdateMask) | low_new_bits; StorageUnit update_high_bits = ((*ptr_)[high_storage_unit_idx_] & kHighUpdateMask) | high_new_bits; diff --git a/include/cutlass/workspace.h b/include/cutlass/workspace.h index 3c71f87826..537cceef62 100644 --- a/include/cutlass/workspace.h +++ b/include/cutlass/workspace.h @@ -45,6 +45,7 @@ #pragma once #if !defined(__CUDACC_RTC__) +#include "cuda.h" #include "cuda_runtime.h" #include "cutlass/trace.h" @@ -55,16 +56,19 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// + +static constexpr int MinWorkspaceAlignment = 16; + #if !defined(__CUDACC_RTC__) static Status -zero_workspace(void* workspace, int workspace_size, cudaStream_t stream = nullptr) { +zero_workspace(void* workspace, size_t workspace_size, cudaStream_t stream = nullptr) { if (workspace_size > 0) { if (workspace == nullptr) { CUTLASS_TRACE_HOST(" error: device workspace must not be null"); return Status::kErrorWorkspaceNull; } - CUTLASS_TRACE_HOST(" clearing barrier workspace"); + CUTLASS_TRACE_HOST(" clearing workspace"); cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_size, stream); if (cudaSuccess != result) { result = cudaGetLastError(); // to clear the error bit @@ -77,6 +81,47 @@ zero_workspace(void* workspace, int workspace_size, cudaStream_t stream = nullpt } #endif +#if !defined(__CUDACC_RTC__) +template +Status +fill_workspace(void* workspace, T fill_value, size_t fill_count, cudaStream_t stream = nullptr) { + static_assert(sizeof(T) == 4 || sizeof(T) == 2 || sizeof(T) == 1, "Unsupported fill type"); + if (fill_count > 0) { + if (workspace == nullptr) { + CUTLASS_TRACE_HOST(" error: device workspace must not be null"); + return Status::kErrorWorkspaceNull; + } + + CUTLASS_TRACE_HOST(" filling workspace"); + CUdeviceptr d_workspace = reinterpret_cast(workspace); + CUresult result = CUDA_SUCCESS; + if (sizeof(T) == 4) { + result = cuMemsetD32Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); + } + else if (sizeof(T) == 2) { + result = cuMemsetD16Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); + } + else if (sizeof(T) == 1) { + result = cuMemsetD8Async(d_workspace, reinterpret_cast(fill_value), fill_count, stream); + } + + if (CUDA_SUCCESS != result) { + const char** error_string_ptr = nullptr; + (void) cuGetErrorString(result, error_string_ptr); + if (error_string_ptr != nullptr) { + CUTLASS_TRACE_HOST(" cuMemsetD" << sizeof(T) * 8 << "Async() returned error " << *error_string_ptr); + } + else { + CUTLASS_TRACE_HOST(" cuMemsetD" << sizeof(T) * 8 << "Async() returned unrecognized error"); + } + return Status::kErrorInternal; + } + } + + return Status::kSuccess; +} +#endif + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/media/docs/gemm_api_3x.md b/media/docs/gemm_api_3x.md index 8197d2e721..8dded99d2d 100644 --- a/media/docs/gemm_api_3x.md +++ b/media/docs/gemm_api_3x.md @@ -296,7 +296,9 @@ freely with any mainloop. 
Each mainloop policy either prescribes a `Schedule` wi it needs to be run, or exposes a template API that lets the user pick a subset of the following schedules: ```c++ -struct KernelMultistage { }; +struct KernelCpAsyncWarpSpecialized { }; +struct KernelCpAsyncWarpSpecializedPingpong { }; +struct KernelCpAsyncWarpSpecializedCooperative { }; struct KernelTma { }; struct KernelTmaWarpSpecialized { }; struct KernelTmaWarpSpecializedPingpong { }; @@ -305,7 +307,7 @@ struct KernelTmaWarpSpecializedCooperative { }; - A single kernel schedule can support multiple mainloop implementations. For example, `KernelMultistage` can be composed with many different mainloop implementations across GPU -architectures such as `MainloopSm70TwoStage`, `MainloopSm80CpAsyncUnpredicated`, `MainloopSm90CpAsyncGmma`, and many more. +architectures such as `MainloopSm70TwoStage`, `MainloopSm80CpAsyncUnpredicated`, and many more. - A single mainloop can be composed with multiple possible kernel schedules. For example, the `MainloopSm90TmaGmmaWarpSpecialized` can be diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..44723087ec --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,28 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[project] +name = "cutlass" +version = "3.3.0.0" +description = "CUTLASS" +readme = "README.md" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", +] +dependencies = [ + "cuda-python>=11.8.0", + "networkx", + "numpy", + "pydot", + "rmm-cu12 ; python_version>='3.9'", + "scipy", + "treelib" +] + +[project.urls] +"Homepage" = "https://github.com/nvidia/cutlass" +"Bug Tracker" = "https://github.com/nvidia/cutlass/issues" diff --git a/python/LICENSE.txt b/python/LICENSE.txt new file mode 100644 index 0000000000..2913ab80f9 --- /dev/null +++ b/python/LICENSE.txt @@ -0,0 +1,27 @@ +Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +SPDX-License-Identifier: BSD-3-Clause + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/python/README.md b/python/README.md index 63388b077d..d86c6b2d74 100644 --- a/python/README.md +++ b/python/README.md @@ -67,14 +67,13 @@ The CUTLASS Python interface currently supports the following operations: * Grouped GEMM (for pre-SM90 kernels) ### Getting started -We recommend using the CUTLASS Python interface via one of the Docker images located in the [docker](/python/docker) directory. +We recommend using the CUTLASS Python interface via an [NGC PyTorch Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch): ```bash -docker build -t cutlass-cuda12.1:latest -f docker/Dockerfile-cuda12.1-pytorch . -docker run --gpus all -it --rm cutlass-cuda12.1:latest +docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.08-py3 ``` -The CUTLASS Python interface has been tested with CUDA 11.8, 12.0, and 12.1 on Python 3.8.10 and 3.9.7. +The CUTLASS Python interface has been tested with CUDA 11.8, 12.0, and 12.1 on Python 3.8 and 3.9. #### Optional environment variables Prior to installing the CUTLASS Python interface, one may optionally set the following environment variables: @@ -82,19 +81,21 @@ Prior to installing the CUTLASS Python interface, one may optionally set the fol * `CUDA_INSTALL_PATH`: the path to the installation of CUDA If these environment variables are not set, the installation process will infer them to be the following: -* `CUTLASS_PATH`: one directory level above the current directory (i.e., `$(pwd)/..`) +* `CUTLASS_PATH`: either one directory level above the current directory (i.e., `$(pwd)/..`) if installed locally or in the `source` directory of the location in which `cutlass_library` was installed * `CUDA_INSTALL_PATH`: the directory holding `/bin/nvcc` for the first version of `nvcc` on `$PATH` (i.e., `which nvcc | awk -F'/bin/nvcc' '{print $1}'`) **NOTE:** The version of `cuda-python` installed must match the CUDA version in `CUDA_INSTALL_PATH`. #### Installation -The CUTLASS Python interface can currently be installed via: +The CUTLASS Python interface can currently be installed by navigating to the root of the CUTLASS directory and performing ```bash -python setup.py develop --user +pip install . ``` -This will allow changes to the Python interface source to be reflected when using the Python interface. -We plan to add support for installing via `python setup.py install` in a future release. +If you would like to be able to make changes to the CUTLASS Python interface and have them reflected when using the interface, perform: +```bash +pip install -e . +``` ### Examples Jupyter notebook examples of using the CUTLASS Python interface are located in [examples/python](/examples/python). @@ -135,10 +136,7 @@ python setup_library.py develop --user Alternatively, `cutlass_library` will automatically be installed if you install the CUTLASS Python interface package.
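Returning briefly to the C++ runtime changes: the `fill_workspace` helper added to `include/cutlass/workspace.h` earlier in this patch dispatches to `cuMemsetD32Async`, `cuMemsetD16Async`, or `cuMemsetD8Async` based on `sizeof(T)`. The usage sketch below is a hedged illustration; the workspace size, fill value, and the `fill_flags` wrapper are made up for the example, and it assumes a CUDA context is already current (the `cudaMalloc` call here is enough to establish the primary context).

```c++
#include <cstdint>
#include <cuda_runtime.h>

#include "cutlass/cutlass.h"
#include "cutlass/workspace.h"

cutlass::Status fill_flags(cudaStream_t stream) {
  constexpr size_t kFlagCount = 1024;
  void* workspace = nullptr;
  if (cudaMalloc(&workspace, kFlagCount * sizeof(uint32_t)) != cudaSuccess) {
    return cutlass::Status::kErrorInternal;
  }

  // sizeof(uint32_t) == 4, so this takes the cuMemsetD32Async path.
  cutlass::Status status = cutlass::fill_workspace(workspace, uint32_t(1), kFlagCount, stream);
  if (status != cutlass::Status::kSuccess) {
    cudaFree(workspace);
    return status;
  }

  // ... launch kernels that consume the initialized workspace ...

  cudaFree(workspace);
  return cutlass::Status::kSuccess;
}
```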
-You can also use the [generator.py](/python/cutlass_library/generator.py) script directly without installing the module via: -```bash -python -m cutlass_library.generator -``` +You can also use the [generator.py](/python/cutlass_library/generator.py) script directly without installing the module. # Copyright diff --git a/python/cutlass/__init__.py b/python/cutlass/__init__.py index 39e9b4076f..0af9335715 100644 --- a/python/cutlass/__init__.py +++ b/python/cutlass/__init__.py @@ -37,14 +37,6 @@ import cutlass_library -def _cutlass_path_from_dir() -> str: - cutlass_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '../../') - if not os.path.isdir(cutlass_path): - raise Exception(f'Environment variable "CUTLASS_PATH" is not defined, ' - f'and default path of {cutlass_path} does not exist.') - return cutlass_path - - def _cuda_install_path_from_nvcc() -> str: import subprocess # Attempt to detect CUDA_INSTALL_PATH based on location of NVCC @@ -60,66 +52,41 @@ def _cuda_install_path_from_nvcc() -> str: return cuda_install_path -CUTLASS_PATH = os.getenv("CUTLASS_PATH", _cutlass_path_from_dir()) -CUDA_INSTALL_PATH = os.getenv("CUDA_INSTALL_PATH", _cuda_install_path_from_nvcc()) +CUTLASS_PATH = os.getenv("CUTLASS_PATH", cutlass_library.source_path) + +# Alias CUTLASS_PATH as source_path +source_path = CUTLASS_PATH + +_CUDA_INSTALL_PATH = None +def cuda_install_path(): + """ + Helper method for on-demand fetching of the CUDA installation path. This allows + the import of CUTLASS to proceed even if NVCC is not available, preferring to + raise this error only when an operation that needs NVCC is being performed. + """ + global _CUDA_INSTALL_PATH + if _CUDA_INSTALL_PATH is None: + _CUDA_INSTALL_PATH = os.getenv("CUDA_INSTALL_PATH", _cuda_install_path_from_nvcc()) + return _CUDA_INSTALL_PATH + CACHE_FILE = "compiled_cache.db" -# Import types/methods from the CUTLASS utility libraries for profiler generation/emission under -from cutlass_library.library import ( - ArchitectureNames, - ComplexTransform, - ComplexTransformTag, - ConvKind, - ConvKindNames, - ConvKindTag, - ConvMode, +from cutlass_library import ( DataType, - DataTypeNames, - DataTypeSize, - DataTypeTag, - EpilogueFunctor, - EpilogueScheduleSuffixes, - EpilogueScheduleTag, EpilogueScheduleType, - GemmKind, - GemmKindNames, - GemmUniversalMode, - IteratorAlgorithm, - IteratorAlgorithmNames, - IteratorAlgorithmTag, - LayoutTag, - LayoutType, - KernelScheduleSuffixes, - KernelScheduleTag, KernelScheduleType, - MathInstruction, - MathOperation, - MathOperationTag, + LayoutType, OpcodeClass, - OpcodeClassNames, - OpcodeClassTag, - OperationKind, - SharedMemPerCC, - ShortComplexLayoutNames, - ShortDataTypeNames, - ShortLayoutTypeNames, - SplitKMode, - StrideSupport, - StrideSupportNames, - StrideSupportTag, - SwizzlingFunctor, - SwizzlingFunctorTag, - TensorDescription, TileDescription, - TileSchedulerSuffixes, - TileSchedulerTag, TileSchedulerType, - get_complex_from_real, ) this = sys.modules[__name__] this.logger = logging.getLogger(__name__) +# RMM is only supported for Python 3.9+ +this.use_rmm = (sys.version_info.major == 3 and sys.version_info.major > 8) or sys.version_info.major > 3 + def set_log_level(level: int): """ Sets the log level @@ -134,11 +101,20 @@ def set_log_level(level: int): from cutlass.library_defaults import OptionRegistry from cutlass.backend.utils.device import device_cc -this.option_registry = OptionRegistry(device_cc()) +this._option_registry = None +def get_option_registry(): + """ + Helper method 
for on-demand initialization of the options registry. This avoids building + the registry when CUTLASS is imported. + """ + if this._option_registry is None: + this.logger.info("Initializing option registry") + this._option_registry = OptionRegistry(device_cc()) + return this._option_registry -this.__version__ = '3.2.1' +this.__version__ = '3.3.0' -from cutlass.backend import get_memory_pool +from cutlass.backend import create_memory_pool from cutlass.emit.pytorch import pytorch from cutlass.op.gemm import Gemm from cutlass.op.conv import Conv2d, Conv2dFprop, Conv2dDgrad, Conv2dWgrad @@ -146,4 +122,58 @@ def set_log_level(level: int): from cutlass.op.op import OperationBase from cutlass.backend.evt.ir.tensor import Tensor -get_memory_pool(init_pool_size=2 ** 30, max_pool_size=2 ** 32) + +this.memory_pool = None +def get_memory_pool(): + """" + Helper method for on-demand memory pool. This avoids allocating the memory pool unnecessarily + whe CUTLASS is imported. + """ + if this.use_rmm and this.memory_pool is None: + this.memory_pool = create_memory_pool(init_pool_size=2 ** 30, max_pool_size=2 ** 32) + return this.memory_pool + + +from cuda import cuda + +this._context = None +this._device_id = None +def initialize_cuda_context(): + if this._device_id is not None: + return + + if this.use_rmm: + # This also covers initializing the CUDA context + get_memory_pool() + + device_id = os.getenv("CUTLASS_CUDA_DEVICE_ID") + if device_id is None: + if not this.use_rmm: + # We must manually call cuInit in the absence of RMM + err, = cuda.cuInit(0) + if err != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f"cuInit failed with error {err}") + + err, device_count = cuda.cuDeviceGetCount() + if err != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f"cuDeviceGetCount failed with error {err}") + if device_count <= 0: + raise Exception("No CUDA devices found") + device_id = 0 + + this._device_id = device_id + + if not this.use_rmm and this._context is None: + # We must manually initialize the context in the absence of RMM + err, device = cuda.cuDeviceGet(this._device_id) + if err != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f"cuDeviceGet failed with error {err}") + + err, this._context = cuda.cuCtxCreate(0, device) + if err != cuda.CUresult.CUDA_SUCCESS: + raise Exception(f"cuCtxCreate failed with error {err}") + + +def device_id() -> int: + initialize_cuda_context() + return this._device_id diff --git a/python/cutlass/backend/__init__.py b/python/cutlass/backend/__init__.py index 9b94c78d50..f1dce8d73b 100644 --- a/python/cutlass/backend/__init__.py +++ b/python/cutlass/backend/__init__.py @@ -6,17 +6,11 @@ from cutlass.backend.frontend import * from cutlass.backend.gemm_operation import * from cutlass.backend.library import * -from cutlass.backend.memory_manager import PoolMemoryManager +from cutlass.backend.memory_manager import PoolMemoryManager, create_memory_pool from cutlass.backend.operation import * from cutlass.backend.reduction_operation import * from cutlass.backend.type_hint import * from cutlass.backend.utils import * from cutlass.backend.utils.device import device_cc -from cutlass.backend.utils.software import ( - CheckPackages, - SubstituteTemplate, - device_sm_count, - get_memory_pool, -) compiler = ArtifactManager() diff --git a/python/cutlass/backend/arguments.py b/python/cutlass/backend/arguments.py index 20a01e6267..2c188334b0 100644 --- a/python/cutlass/backend/arguments.py +++ b/python/cutlass/backend/arguments.py @@ -36,16 +36,10 @@ from cuda import cuda, cudart import numpy as 
np +import cutlass from cutlass.backend.frontend import CupyFrontend, NumpyFrontend, TorchFrontend -from cutlass.backend.utils.software import CheckPackages - -torch_available = CheckPackages().check_torch() -if torch_available: - import torch - -cupy_available = CheckPackages().check_cupy() -if cupy_available: - import cupy as cp +from cutlass.backend.memory_manager import DevicePtrWrapper +from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor class ArgumentBase: @@ -76,7 +70,7 @@ def __init__( self.ptr_A = self.tensor_to_ptr(A, "A") self.ptr_B = self.tensor_to_ptr(B, "B") self.ptr_C = self.tensor_to_ptr(C, "C") - self.ptr_D = self.tensor_to_ptr(D, "D", True) + self.ptr_D = self.tensor_to_ptr(D, "D", is_output=True) if C is not None: if not isinstance(C, cuda.CUdeviceptr): self.tensor_c_numel = prod(C.shape) @@ -88,18 +82,18 @@ def tensor_to_ptr(self, tensor, name, is_output=False): """ if tensor is None: return cuda.CUdeviceptr(0) - if isinstance(tensor, np.ndarray): + if is_numpy_tensor(tensor): if is_output: assert name self.buffers[name] = NumpyFrontend.argument(tensor, is_output) if is_output: self.host_tensors[name] = tensor return self.buffers[name].ptr - elif torch_available and isinstance(tensor, torch.Tensor): + elif is_torch_tensor(tensor): return TorchFrontend.argument(tensor) elif isinstance(tensor, cuda.CUdeviceptr): return tensor - elif cupy_available and isinstance(tensor, cp.ndarray): + elif is_cupy_tensor(tensor): return CupyFrontend.argument(tensor) else: raise TypeError("Unsupported Frontend. Only support numpy and torch") @@ -119,3 +113,23 @@ def sync(self, stream_sync=True): ) if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) + + self.free() + + def free(self): + """ + Frees allocated device-side memory + """ + # Free any device memory allocated manually + if not cutlass.use_rmm: + for name, buf in self.buffers.items(): + if isinstance(buf, DevicePtrWrapper): + err, = cudart.cudaFree(buf.ptr) + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaFree failed with error {err}") + + if hasattr(self, "workspace_buffer") and isinstance(self.workspace_buffer, DevicePtrWrapper): + err, = cudart.cudaFree(self.workspace_buffer.ptr) + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaFree failed with error {err}") + del self.workspace_buffer diff --git a/python/cutlass/backend/c_types.py b/python/cutlass/backend/c_types.py index 73d0c66d94..17954a93f7 100644 --- a/python/cutlass/backend/c_types.py +++ b/python/cutlass/backend/c_types.py @@ -32,7 +32,7 @@ import ctypes -from cutlass import ( +from cutlass_library import ( DataType, KernelScheduleType ) @@ -125,7 +125,7 @@ def get_mainloop_arguments_3x( Returns the ctypes structure to be used for the 3.x kernel's mainloop parameters. 
:param kernel_schedule: type of kernel schedule to be used in the mainloop - :type kerel_schedule: cutlass.KernelScheduleType + :type kernel_schedule: cutlass_library.KernelScheduleType :param element_A: data type of operand A :param element_B: data type of operand B :param alignment_A: alignment of operand A @@ -166,25 +166,10 @@ def from_generic_mainloop_args(args: GenericMainloopArguments3x_): args.ptr_A, args.stride_A, args.ptr_B, args.stride_B, ) - tma_alignment_bytes = 16 - is_tma_aligned_A = ((DataTypeSizeBytes[element_A] * alignment_A) % tma_alignment_bytes) == 0 - is_tma_aligned_B = ((DataTypeSizeBytes[element_B] * alignment_B) % tma_alignment_bytes) == 0 - is_tma_aligned = is_tma_aligned_A and is_tma_aligned_B - - if kernel_schedule == KernelScheduleType.Multistage: - return _MainloopArgumentsMultistage - elif kernel_schedule == KernelScheduleType.ScheduleAuto: - if is_tma_aligned: - return _MainloopArgumentsTma - else: - return _MainloopArgumentsMultistage - else: - if is_tma_aligned: - return _MainloopArgumentsTma - else: - raise Exception(f"Specified a kernel schedule using TMA ({kernel_schedule}), but " - "the provided data types and alignments are not properly aligned for " - "using TMA.") + # Currently all 3.x kernels (CpAsync and Tma) have the same argument structure. + # Should that become not the case, this is the place to return custom ctypes + # structures based on selected kernel schedule. + return _MainloopArgumentsTma def get_gemm_arguments_3x(mainloop_arguments, epilogue_functor): diff --git a/python/cutlass/backend/compiler.py b/python/cutlass/backend/compiler.py index f03cd2be6f..e04a4eb266 100644 --- a/python/cutlass/backend/compiler.py +++ b/python/cutlass/backend/compiler.py @@ -38,12 +38,13 @@ import tempfile from cuda import cuda, nvrtc +from cutlass_library import SubstituteTemplate -from cutlass import CACHE_FILE, CUDA_INSTALL_PATH, CUTLASS_PATH, logger +import cutlass +from cutlass import CACHE_FILE, CUTLASS_PATH, cuda_install_path, logger from cutlass.backend.gemm_operation import GemmOperationUniversal from cutlass.backend.library import ApiVersion from cutlass.backend.utils.device import device_cc -from cutlass.backend.utils.software import SubstituteTemplate IncludeTemplate = r"""#include "${include}" """ @@ -316,7 +317,7 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op # compile with nvcc cmd_template = "${cuda_install_path}/bin/nvcc ${options} -cubin ${srcfile} -o ${tarfile}" values = { - "cuda_install_path": CUDA_INSTALL_PATH, + "cuda_install_path": cuda_install_path(), "options": compilation_options.get_str(), "srcfile": temp_cu.name, "tarfile": temp_cubin.name, @@ -336,7 +337,7 @@ def emit_compile_(self, operation_list, compilation_options, host_compilation_op cmd = SubstituteTemplate( cmd_template, { - "cuda_install_path": CUDA_INSTALL_PATH, + "cuda_install_path": cuda_install_path(), "options": host_compilation_options.get_str(), }, ) @@ -356,18 +357,15 @@ def add_module(self, operations, compile_options=None, bypass_cache=False): Insert a new compiled device module """ include_paths = [ - CUDA_INSTALL_PATH + "/include", + cuda_install_path() + "/include", CUTLASS_PATH + "/include", CUTLASS_PATH + "/tools/util/include", CUTLASS_PATH + "/python/cutlass/cpp/include", ] - if device_cc() is not None: - arch = device_cc() - else: - # Find the maximum arch tag among the provided operations and compile for that target. - # Since we are compiling to .cubin files, only one architecture may be specified. 
- arch = max([op.arch for op in operations]) + cutlass.initialize_cuda_context() + arch = device_cc() + host_compile_options = CompilationOptions( self._nvcc_compile_options, arch, include_paths) if compile_options is None: diff --git a/python/cutlass/backend/conv2d_operation.py b/python/cutlass/backend/conv2d_operation.py index 466c71b491..4a2f2f03c2 100644 --- a/python/cutlass/backend/conv2d_operation.py +++ b/python/cutlass/backend/conv2d_operation.py @@ -34,9 +34,10 @@ from typing import Union from cuda import cuda +from cutlass_library import SubstituteTemplate import numpy as np -from cutlass import ( +from cutlass_library import ( ConvKindNames, ConvKindTag, DataTypeNames, @@ -71,13 +72,9 @@ ) from cutlass.backend.memory_manager import device_mem_alloc from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration -from cutlass.backend.utils.datatypes import to_device_ptr -from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate +from cutlass.backend.utils.device import to_device_ptr from cutlass.shape import GemmCoord -if CheckPackages().check_torch(): - import torch - class Conv2dArguments(ArgumentBase): """ diff --git a/python/cutlass/backend/epilogue.py b/python/cutlass/backend/epilogue.py index df87f6c9c2..784f8e9553 100644 --- a/python/cutlass/backend/epilogue.py +++ b/python/cutlass/backend/epilogue.py @@ -32,14 +32,15 @@ import ctypes +from cutlass_library import SubstituteTemplate import numpy as np from scipy.special import erf -from cutlass import DataType, DataTypeTag +from cutlass_library import DataType, DataTypeTag from cutlass.backend.c_types import MatrixCoord_ from cutlass.backend.frontend import NumpyFrontend from cutlass.backend.library import ActivationOp, ActivationOpTag -from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate +from cutlass.utils.datatypes import is_numpy_tensor, is_torch_available, is_torch_tensor dtype2ctype = { DataType.f16: ctypes.c_uint16, @@ -49,8 +50,7 @@ DataType.s32: ctypes.c_int32 } -torch_available = CheckPackages().check_torch() -if torch_available: +if is_torch_available(): import torch import torch.nn.functional as F @@ -59,11 +59,11 @@ def get_scalar(value): """ Returns a scalar value from a container (e.g., np.ndarray) """ - if isinstance(value, np.ndarray): + if is_numpy_tensor(value): if value.size != 1: raise Exception("Scalars used in epilogue must be of size 1") return value.reshape(-1)[0] - elif CheckPackages().check_torch() and isinstance(value, torch.Tensor): + elif is_torch_tensor(value): if value.size != 1: raise Exception("Scalars used in epilogue must be of size 1") return value.reshape(-1)[0] @@ -353,9 +353,9 @@ def __init__(self, alpha, beta, *args) -> None: class ActivationMeta(type): @classmethod def __call__(cls, x, *args): - if isinstance(x, np.ndarray): + if is_numpy_tensor(x): return cls.numpy(x, *args) - elif torch_available and isinstance(x, torch.Tensor): + elif is_torch_tensor(x): return cls.torch(x, *args) else: raise NotImplementedError("Unsupported tensor type") diff --git a/python/cutlass/backend/evt/backend/emitter_base.py b/python/cutlass/backend/evt/backend/emitter_base.py index 375378c943..6b35476e39 100644 --- a/python/cutlass/backend/evt/backend/emitter_base.py +++ b/python/cutlass/backend/evt/backend/emitter_base.py @@ -34,7 +34,7 @@ Base class for Epilogue Visitor Emitter """ -from cutlass import DataTypeTag +from cutlass_library import DataTypeTag from cutlass.backend.evt.ir import TopoVisitorNode, DAGIR diff --git 
a/python/cutlass/backend/evt/backend/sm80_nodes.py b/python/cutlass/backend/evt/backend/sm80_nodes.py index 0158a905d1..d4cb561f25 100644 --- a/python/cutlass/backend/evt/backend/sm80_nodes.py +++ b/python/cutlass/backend/evt/backend/sm80_nodes.py @@ -30,7 +30,7 @@ # ################################################################################################# -from cutlass import DataTypeTag +from cutlass_library import DataTypeSize, DataTypeTag from cutlass.backend.evt.ir import ( # Load Node diff --git a/python/cutlass/backend/evt/backend/sm90_emitter.py b/python/cutlass/backend/evt/backend/sm90_emitter.py index 2e28cc3fb2..ca8246ee0e 100644 --- a/python/cutlass/backend/evt/backend/sm90_emitter.py +++ b/python/cutlass/backend/evt/backend/sm90_emitter.py @@ -34,7 +34,7 @@ Emitter for Sm90 Epilogue Visitor """ -from cutlass import DataTypeTag, EpilogueScheduleTag +from cutlass_library import DataTypeTag, EpilogueScheduleTag from cutlass.backend import GemmOperationUniversal from cutlass.backend.evt.backend.emitter_base import FusionCallbacks diff --git a/python/cutlass/backend/evt/backend/sm90_nodes.py b/python/cutlass/backend/evt/backend/sm90_nodes.py index 3e29a3af1f..4304c2ccd5 100644 --- a/python/cutlass/backend/evt/backend/sm90_nodes.py +++ b/python/cutlass/backend/evt/backend/sm90_nodes.py @@ -32,7 +32,7 @@ from pycute import product -from cutlass import DataTypeSize, DataTypeTag +from cutlass_library import DataTypeSize, DataTypeTag from cutlass.backend.evt.ir import ( # Load Node AccumulatorImpl, diff --git a/python/cutlass/backend/evt/epilogue.py b/python/cutlass/backend/evt/epilogue.py index 75bc703e9a..a49c154189 100644 --- a/python/cutlass/backend/evt/epilogue.py +++ b/python/cutlass/backend/evt/epilogue.py @@ -37,12 +37,13 @@ import ctypes from cuda import cuda +from cutlass_library import DataType import numpy as np -from cutlass import DataType from cutlass.backend.epilogue import EpilogueFunctorBase import cutlass.backend.evt.backend from cutlass.backend.frontend import TensorFrontend +from cutlass.utils.datatypes import is_numpy_tensor class EpilogueFunctorVisitor(EpilogueFunctorBase): @@ -125,7 +126,7 @@ def get_tensor_ptr(self, tensor_name, kwargs, is_output=False): # The tensor frontend returns a device buffer for np.ndarray # and device ptr for other frontends buffer_or_ptr = TensorFrontend.argument(tensor, is_output) - if isinstance(tensor, np.ndarray): + if is_numpy_tensor(tensor): # Remember the host tensor for later synchronization setattr(self, f"{tensor_name}_buffer", buffer_or_ptr) setattr(self, f"{tensor_name}_host", tensor) diff --git a/python/cutlass/backend/evt/frontend/frontend_base.py b/python/cutlass/backend/evt/frontend/frontend_base.py index 8d9f6c6e37..74c7ff3976 100644 --- a/python/cutlass/backend/evt/frontend/frontend_base.py +++ b/python/cutlass/backend/evt/frontend/frontend_base.py @@ -36,7 +36,7 @@ from typing import Union -from cutlass import DataType +from cutlass_library import DataType from cutlass.backend.evt.ir import ( ComputeNode, DAGIR, diff --git a/python/cutlass/backend/evt/frontend/python_ast.py b/python/cutlass/backend/evt/frontend/python_ast.py index ac799d8092..ec32f0e667 100644 --- a/python/cutlass/backend/evt/frontend/python_ast.py +++ b/python/cutlass/backend/evt/frontend/python_ast.py @@ -38,8 +38,9 @@ import inspect import textwrap +from cutlass_library import DataType + import cutlass -from cutlass import DataType from cutlass.backend.evt.frontend.frontend_base import EVTFrontendBase from cutlass.backend.epilogue import 
relu from cutlass.backend.library import FunctionalOp diff --git a/python/cutlass/backend/evt/ir/dag_ir.py b/python/cutlass/backend/evt/ir/dag_ir.py index d0ac9402f0..b3cdeb3ea1 100644 --- a/python/cutlass/backend/evt/ir/dag_ir.py +++ b/python/cutlass/backend/evt/ir/dag_ir.py @@ -36,7 +36,8 @@ import networkx as nx -from cutlass import DataType +from cutlass_library import DataType + from cutlass.backend.evt.ir.node import NodeBase from cutlass.backend.utils import device_cc diff --git a/python/cutlass/backend/evt/ir/layout_nodes.py b/python/cutlass/backend/evt/ir/layout_nodes.py index 4262389897..5f3a3328b2 100644 --- a/python/cutlass/backend/evt/ir/layout_nodes.py +++ b/python/cutlass/backend/evt/ir/layout_nodes.py @@ -38,10 +38,10 @@ from copy import deepcopy +from cutlass_library import LayoutType from pycute import product, flatten import cutlass -from cutlass import LayoutType from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _tuple_to_list from cutlass.backend.evt.ir.node import NodeBase from cutlass.backend.evt.ir.tensor import Tensor diff --git a/python/cutlass/backend/evt/ir/node.py b/python/cutlass/backend/evt/ir/node.py index 9cf23331f3..9a2b75a5a4 100644 --- a/python/cutlass/backend/evt/ir/node.py +++ b/python/cutlass/backend/evt/ir/node.py @@ -37,7 +37,8 @@ import ctypes from re import sub -from cutlass import LayoutType +from cutlass_library import LayoutType + from cutlass.backend.evt.ir.layout_algorithm import _list_to_tuple, _reverse_tuple from cutlass.backend.evt.ir.tensor import Tensor diff --git a/python/cutlass/backend/evt/ir/store_nodes.py b/python/cutlass/backend/evt/ir/store_nodes.py index e050e43009..b9715034dd 100644 --- a/python/cutlass/backend/evt/ir/store_nodes.py +++ b/python/cutlass/backend/evt/ir/store_nodes.py @@ -36,7 +36,8 @@ import ctypes -from cutlass import DataType +from cutlass_library import DataType + from cutlass.backend.c_types import tuple_factory from cutlass.backend.epilogue import dtype2ctype, to_ctype_value from cutlass.backend.evt.ir.node import NodeBase, ImplBase, NoOpImpl diff --git a/python/cutlass/backend/evt/ir/tensor.py b/python/cutlass/backend/evt/ir/tensor.py index aa0c008e89..1ab3bbd886 100644 --- a/python/cutlass/backend/evt/ir/tensor.py +++ b/python/cutlass/backend/evt/ir/tensor.py @@ -34,7 +34,7 @@ High-level class for tensor """ -from cutlass import LayoutType +from cutlass_library import LayoutType from cutlass.backend.evt.ir.layout_algorithm import ( Layout, diff --git a/python/cutlass/backend/evt/passes/graph_drawer.py b/python/cutlass/backend/evt/passes/graph_drawer.py index 83406f96ee..ee853f6f13 100644 --- a/python/cutlass/backend/evt/passes/graph_drawer.py +++ b/python/cutlass/backend/evt/passes/graph_drawer.py @@ -32,9 +32,9 @@ import subprocess +from cutlass_library import DataTypeTag import pydot -from cutlass import DataTypeTag from cutlass.backend.evt.ir.dag_ir import DAGIR diff --git a/python/cutlass/backend/evt/passes/pass_preprocess_red.py b/python/cutlass/backend/evt/passes/pass_preprocess_red.py index afb8a9c46d..b617601549 100644 --- a/python/cutlass/backend/evt/passes/pass_preprocess_red.py +++ b/python/cutlass/backend/evt/passes/pass_preprocess_red.py @@ -42,7 +42,6 @@ from cutlass.backend.evt.passes.pass_manager import EVTPassBase - class PassPreprocessRed(EVTPassBase): """ Preprocess red nodes diff --git a/python/cutlass/backend/evt/passes/smem_size_calculator.py b/python/cutlass/backend/evt/passes/smem_size_calculator.py index 670367d075..a0c60f3797 100644 --- 
a/python/cutlass/backend/evt/passes/smem_size_calculator.py +++ b/python/cutlass/backend/evt/passes/smem_size_calculator.py @@ -34,6 +34,7 @@ Compute the shared memory size in bytes """ +import cutlass_library from pycute import shape_div, product import cutlass @@ -56,10 +57,13 @@ def __init__(self, dag_ir: DAGIR) -> None: def sm90_epilogue_tile(self, tile_description): # Get the epilogue tile size schedule = tile_description.epilogue_schedule - if schedule == cutlass.EpilogueScheduleType.TmaWarpSpecialized: + if schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecialized: epilogue_tile_mn = (64, 32) - elif schedule == cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative: - epilogue_tile_mn = (128, 32) + elif schedule == cutlass_library.EpilogueScheduleType.TmaWarpSpecializedCooperative: + if tile_description.threadblock_shape[0] >= 128: + epilogue_tile_mn = (128, 32) + else: + epilogue_tile_mn = (64, 32) else: raise NotImplementedError(f"Unsupported schedule: {schedule}") diff --git a/python/cutlass/backend/frontend.py b/python/cutlass/backend/frontend.py index a43dcbb00b..a39635fa99 100644 --- a/python/cutlass/backend/frontend.py +++ b/python/cutlass/backend/frontend.py @@ -34,15 +34,7 @@ import numpy as np from cutlass.backend.memory_manager import device_mem_alloc, todevice -from cutlass.backend.utils.software import CheckPackages - -torch_available = CheckPackages().check_torch() -if torch_available: - import torch - -cupy_available = CheckPackages().check_cupy() -if cupy_available: - import cupy as cp +from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor class NumpyFrontend: @@ -97,6 +89,7 @@ class CupyFrontend: def argument(cupy_ndarray: "cp.ndarray"): return cuda.CUdeviceptr(int(cupy_ndarray.data.ptr)) + class TensorFrontend: """ Universal Frontend for client-provide tensors @@ -104,11 +97,11 @@ class TensorFrontend: @staticmethod def argument(tensor, is_output=False): - if isinstance(tensor, np.ndarray): + if is_numpy_tensor(tensor): return NumpyFrontend.argument(tensor, is_output) - elif torch_available and isinstance(tensor, torch.Tensor): + elif is_torch_tensor(tensor): return TorchFrontend.argument(tensor) - elif cupy_available and isinstance(tensor, cp.ndarray): + elif is_cupy_tensor(tensor): return CupyFrontend.argument(tensor) else: raise NotImplementedError("Unknown Tensor Type") diff --git a/python/cutlass/backend/gemm_operation.py b/python/cutlass/backend/gemm_operation.py index 8bbf402418..c5c756db71 100644 --- a/python/cutlass/backend/gemm_operation.py +++ b/python/cutlass/backend/gemm_operation.py @@ -35,10 +35,10 @@ import enum from cuda import cuda, cudart +from cutlass_library import SubstituteTemplate import numpy as np -import rmm -from cutlass import ( +from cutlass_library import ( ComplexTransformTag, DataType, DataTypeNames, @@ -96,11 +96,7 @@ from cutlass.backend.memory_manager import device_mem_alloc, todevice from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration from cutlass.backend.type_hint import GemmOperation, Tensor -from cutlass.backend.utils.software import ( - CheckPackages, - SubstituteTemplate, - device_sm_count, -) +from cutlass.backend.utils.device import device_sm_count from cutlass.shape import GemmCoord, MatrixCoord @@ -163,7 +159,7 @@ class GemmArguments2x(ArgumentBase): :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray :param gemm_mode: GEMM mode - :type gemm_mode: :class:`cutlass.GemmUniversalMode` + :type gemm_mode: 
:class:`cutlass_library.GemmUniversalMode` :param output_op: output operator, optional :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` @@ -387,7 +383,7 @@ class GemmArguments2xStreamK(GemmArguments2x): :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray :param gemm_mode: GEMM mode - :type gemm_mode: :class:`cutlass.GemmUniversalMode` + :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` :param output_op: output operator, optional :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` @@ -426,9 +422,12 @@ def get_arguments(self): def initialize(self): # Get the host and device workspace - device_workspace_size = self.operation.rt_module.get_device_workspace_size(self) + device_workspace_size = self.operation.rt_module.get_device_workspace_size( + self, + device_sm_count(), + self.operation.rt_module.occupancy + ) - device_workspace_size = 10 << 20 if device_workspace_size > 0: self.workspace_buffer = device_mem_alloc(device_workspace_size) workspace_ptr = self.workspace_buffer.ptr @@ -626,7 +625,7 @@ def GemmArguments(operation, problem_size, A, B, C, D, gemm_mode=GemmUniversalMo :type D: cuda.CUdeviceptr | numpy.ndarray | torch.Tensor | cupy.ndarray :param gemm_mode: GEMM mode - :type gemm_mode: :class:`cutlass.GemmUniversalMode` + :type gemm_mode: :class:`cutlass_library.GemmUniversalMode` :param output_op: output operator, optional :type output_op: :class:`cutlass.backend.LinearCombinationFunctorArguments` @@ -1038,6 +1037,11 @@ class GemmRTUniversalStreamK(GemmRTUniversal): typename GemmType::Params params(*args, device_sms, sm_occupancy); return params.get_grid_dims(); } + + uint64_t ${operation_name}_get_kernel_workspace_size(GemmType::Arguments* args, int device_sms, int sm_occupancy) { + typename GemmType::Params params(*args, device_sms, sm_occupancy); + return params.get_workspace_size(); + } } """ @@ -1045,6 +1049,7 @@ def __init__(self, operation: "GemmOperation"): super(GemmRTUniversalStreamK, self).__init__(operation) self.extra_funcs = { "get_grid_shape": GemmCoord_, + "get_kernel_workspace_size": ctypes.c_uint64, } self._occupancy = None self.argument_type, self.epilogue_type = get_gemm_arguments_streamk(operation.epilogue_functor) @@ -1062,6 +1067,9 @@ def occupancy(self): f"{cuda.cuGetErrorString(err)[1]}") return self._occupancy + def get_device_workspace_size(self, arguments: GemmArguments2xStreamK, device_sms: int, sm_occupancy: int): + return self.get_kernel_workspace_size(ctypes.byref(arguments.get_arguments()), device_sms, sm_occupancy) + ################################################################################ # Runtime module for GEMM Universal within CUTLASS 3 @@ -1431,7 +1439,7 @@ def host_precompute(self, arguments, workspace_bytes): problem_info_array = bytearray(problem_info.contents) # copy to device memory - return rmm.DeviceBuffer.to_device(problem_info_array).ptr + return todevice(problem_info_array).ptr def plan(self, arguments): return LaunchConfiguration( @@ -1537,10 +1545,6 @@ def run(self, arguments: GemmArguments) -> cuda.CUresult: return err - def free(self): - if hasattr(self, "workspace_buffer"): - del self.workspace_buffer - def is_complex(self): complex_operators = [ MathOperation.multiply_add_complex, @@ -1627,7 +1631,7 @@ def extended_name_3x(self): element_b=DataTypeNames[self.B.element], element_acc=DataTypeNames[self.tile_description.math_instruction.element_accumulator], element_c=DataTypeNames[self.C.element], - element_d=DataTypeNames[self.C.element], + 
element_d=DataTypeNames[self.epilogue_functor.element_output], core_name=self.core_name()) return extended_name diff --git a/python/cutlass/backend/library.py b/python/cutlass/backend/library.py index 62939a521c..e78dd5391e 100644 --- a/python/cutlass/backend/library.py +++ b/python/cutlass/backend/library.py @@ -36,7 +36,7 @@ import enum -from cutlass import ( +from cutlass_library import ( ComplexTransform, DataType, DataTypeSize, @@ -94,18 +94,6 @@ def __class_getitem__(datatype): return bits // 8 -SharedMemPerCC = { - 70: 96 << 10, # 96KB of SMEM - 72: 96 << 10, # 96KB of SMEM - 75: 64 << 10, # 64KB of SMEM - 80: 160 << 10, # 164KB of SMEM - 4KB reserved for the driver - 86: 100 << 10, # 100KB of SMEM - 87: 160 << 10, # 164KB of SMEM - 4KB reserved for the driver - 89: 100 << 10, # 100KB of SMEM - 90: 227 << 10, # 228KB of SMEM - 1KB reserved for the driver -} - - class SchedulerMode(enum.Enum): Device = enum_auto() Host = enum_auto() @@ -277,11 +265,11 @@ def __init__( :type math_instruction: MathInstruction :param cluster_shape: number of threadblocks in the [X, Y, Z] dimensions of a threadblock cluster :param kernel_schedule: type of kernel schedule to use (only available for SM90+) - :type kernel_schedule: cutlass.KernelScheduleType + :type kernel_schedule: cutlass_library.KernelScheduleType :param epilogue_schedule: type of epilogue schedule to use (only available for SM90+) - :type epilogue_schedule: cutlass.EpilogueScheduleType + :type epilogue_schedule: cutlass_library.EpilogueScheduleType :param tile_scheduler: type of tile scheduler to use (only available for SM90+) - :type tile_scheduler: cutlass.TileSchedulerType + :type tile_scheduler: cutlass_library.TileSchedulerType """ if ((kernel_schedule is None and epilogue_schedule is not None) or (kernel_schedule is not None and epilogue_schedule is None)): @@ -413,7 +401,10 @@ class TensorDescription: def __init__(self, element, layout, alignment=1, complex_transform=ComplexTransform.none): self.element = element self.layout = layout - self.alignment = min(128 // DataTypeSize[self.element], alignment) + if element != DataType.void: + self.alignment = min(128 // DataTypeSize[self.element], alignment) + else: + self.alignment = alignment self.complex_transform = complex_transform @@ -473,9 +464,9 @@ def api_version(arch, opclass, dtype): :param arch: compute capability of device on which to run :type arch: int :param opclass: class of the operation being performed - :type opclass: cutlass.OpcodeClass + :type opclass: cutlass_library.OpcodeClass :param dtype: data type to be used in operation (assumes that ElementA and ElementB are the same) - :type dtype: cutlass.DataType + :type dtype: cutlass_library.DataType :return: API version to be used in code emission :rtype: ApiVersion diff --git a/python/cutlass/backend/memory_manager.py b/python/cutlass/backend/memory_manager.py index 7c759e64cc..d3bd4be361 100644 --- a/python/cutlass/backend/memory_manager.py +++ b/python/cutlass/backend/memory_manager.py @@ -31,7 +31,14 @@ ################################################################################################# import numpy as np -import rmm + +import cutlass +from cutlass.utils.datatypes import is_numpy_tensor + +if cutlass.use_rmm: + import rmm +else: + from cuda import cudart class PoolMemoryManager: @@ -44,31 +51,70 @@ def __init__(self, init_pool_size: int, max_pool_size: int) -> None: self.mr = rmm.mr.TrackingResourceAdaptor(self.pool) rmm.mr.set_current_device_resource(self.mr) - def get_allocated_size(self): - return 
self.mr.get_allocated_bytes() - def pool_size(self): return self.pool.pool_size() +class DevicePtrWrapper: + """ + Wrapper around a pointer to device memory to provide a uniform interface with the RMM DeviceBuffer + (at least in terms of the interface used by the CUTLASS Python interface) + """ + def __init__(self, dev_ptr): + self.dev_ptr = dev_ptr + + @property + def ptr(self): + return self.dev_ptr + + +def _todevice(host_data): + """ + Helper for transferring host data to device memory + """ + if cutlass.use_rmm: + return rmm.DeviceBuffer.to_device(host_data.tobytes()) + else: + nbytes = len(host_data.tobytes()) + dev_ptr_wrapper = device_mem_alloc(nbytes) + err, = cudart.cudaMemcpy( + dev_ptr_wrapper.ptr, + host_data.__array_interface__['data'][0], + nbytes, + cudart.cudaMemcpyKind.cudaMemcpyHostToDevice + ) + if err != cudart.cudaError_t.cudaSuccess: + raise Exception(f"cudaMemcpy failed with error {err}") + return dev_ptr_wrapper + + def todevice(host_data, dtype=np.float32): """ Pass the host_data to device memory """ if isinstance(host_data, list): - return rmm.DeviceBuffer.to_device(np.array(host_data, dtype=dtype).tobytes()) - elif isinstance(host_data, np.ndarray): - return rmm.DeviceBuffer.to_device(host_data.tobytes()) + return _todevice(np.array(host_data, dtype=dtype)) + elif is_numpy_tensor(host_data): + return _todevice(host_data) def device_mem_alloc(size): - return rmm.DeviceBuffer(size=size) + if cutlass.use_rmm: + return rmm.DeviceBuffer(size=size) + else: + err, ptr = cudart.cudaMalloc(size) + if err != cudart.cudaError_t.cudaSuccess: + raise Exception(f"cudaMalloc failed with error {err}") + return DevicePtrWrapper(ptr) def align_size(size, alignment=256): return ((size + alignment - 1) // alignment) * alignment -def get_allocated_size(): - device_resource = rmm.mr.get_current_device_resource() - return device_resource.get_allocated_bytes() +def create_memory_pool(init_pool_size=0, max_pool_size=2 ** 34): + if cutlass.use_rmm: + memory_pool = PoolMemoryManager(init_pool_size=init_pool_size, max_pool_size=max_pool_size) + return memory_pool + else: + return None diff --git a/python/cutlass/backend/operation.py b/python/cutlass/backend/operation.py index 8a4d57d649..426e721f1a 100644 --- a/python/cutlass/backend/operation.py +++ b/python/cutlass/backend/operation.py @@ -37,9 +37,15 @@ from cutlass.backend.utils.device import device_cc _version_splits = [int(x) for x in __version__.split("rc")[0].split(".")] -supports_cluster_launch = device_cc() >= 90 and ( - _version_splits[0] > 11 or (_version_splits[0] == 11 and _version_splits[1] >= 8) -) +_supports_cluster_launch = None + + +def supports_cluster_launch(): + global _supports_cluster_launch + if _supports_cluster_launch is None: + major, minor = _version_splits[0], _version_splits[1] + _supports_cluster_launch = device_cc() >= 90 and (major > 11 or (major == 11 and minor >= 8)) + return _supports_cluster_launch class LaunchConfiguration: @@ -121,7 +127,7 @@ def run(self, host_workspace, device_workspace, launch_config, stream=cuda.CUstr packed = (ctypes.c_void_p * 1)() packed[0] = ctypes.addressof(cArg) - if supports_cluster_launch: + if supports_cluster_launch(): return self.run_with_clusters(launch_config, packed, stream) else: return self.run_without_clusters(launch_config, packed, stream) diff --git a/python/cutlass/backend/reduction_operation.py b/python/cutlass/backend/reduction_operation.py index 9662017cc9..5d42cc5209 100644 --- a/python/cutlass/backend/reduction_operation.py +++ 
b/python/cutlass/backend/reduction_operation.py @@ -36,21 +36,22 @@ from cuda import cuda, cudart import numpy as np -from cutlass import ( +from cutlass_library import ( DataTypeNames, DataTypeSize, DataTypeTag, - LayoutType + LayoutType, + SubstituteTemplate ) + +import cutlass from cutlass.backend.c_types import MatrixCoord_, TensorRef2D_, get_reduction_params from cutlass.backend.frontend import NumpyFrontend, TorchFrontend from cutlass.backend.library import TensorDescription +from cutlass.backend.memory_manager import DevicePtrWrapper from cutlass.backend.operation import ExecutableOperation, LaunchConfiguration -from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate from cutlass.shape import MatrixCoord - -if CheckPackages().check_torch(): - import torch +from cutlass.utils.datatypes import is_numpy_tensor, is_torch_tensor class ReductionOperation: @@ -85,13 +86,13 @@ def __init__( # number of split-k partitions self.partitions = partitions - if isinstance(destination, np.ndarray): + if is_numpy_tensor(destination): self.host_D = destination self.destination_buffer = NumpyFrontend.argument(destination, True) self.source_buffer = NumpyFrontend.argument(source, False) self.ptr_destination = cuda.CUdeviceptr(self.destination_buffer.ptr) self.ptr_source = cuda.CUdeviceptr(self.source_buffer.ptr) - elif CheckPackages().check_torch() and isinstance(destination, torch.Tensor): + elif is_torch_tensor(destination): self.ptr_destination = TorchFrontend.argument(destination) self.ptr_source = TorchFrontend.argument(source) elif isinstance(destination, cuda.CUdeviceptr): @@ -185,11 +186,22 @@ def sync(self): if err != cuda.CUresult.CUDA_SUCCESS: raise RuntimeError("CUDA Error %s" % str(err)) + self.free() + def free(self): - if hasattr(self, "destination_buffer"): - del self.destination_buffer - if hasattr(self, "source_buffer"): - del self.source_buffer + """ + Frees allocated device-side memory + """ + # Free any device memory allocated manually + if not cutlass.use_rmm: + for attr in ["destination_buffer", "source_buffer"]: + if hasattr(self, attr): + buf = getattr(self, attr) + if isinstance(buf, DevicePtrWrapper): + err, = cudart.cudaFree(buf.ptr) + if err != cudart.cudaError_t.cudaSuccess: + raise RuntimeError(f"cudaFree failed with error {err}") + del buf class ReductionRT(ExecutableOperation): diff --git a/python/cutlass/backend/utils/__init__.py b/python/cutlass/backend/utils/__init__.py index be36ad8337..6141bb5903 100644 --- a/python/cutlass/backend/utils/__init__.py +++ b/python/cutlass/backend/utils/__init__.py @@ -30,11 +30,4 @@ # ################################################################################ -from cutlass.backend.utils.datatypes import * from cutlass.backend.utils.device import check_cuda_errors, device_cc -from cutlass.backend.utils.software import ( - CheckPackages, - SubstituteTemplate, - device_sm_count, - get_memory_pool, -) diff --git a/python/cutlass/backend/utils/datatypes.py b/python/cutlass/backend/utils/datatypes.py deleted file mode 100644 index 1140cb84ba..0000000000 --- a/python/cutlass/backend/utils/datatypes.py +++ /dev/null @@ -1,156 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. 
Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -""" -Utility functions for converting between frontend datatypes and CUTLASS datatypes -""" - -from cuda import cuda - -from cutlass import DataType -from cutlass.backend.utils.software import CheckPackages - -numpy_available = CheckPackages().check_numpy() -if numpy_available: - import numpy as np - - numpy_to_cutlass_dict = { - np.float16: DataType.f16, - np.float32: DataType.f32, - np.float64: DataType.f64, - np.int8: DataType.s8, - np.int32: DataType.s32, - np.dtype('float16'): DataType.f16, - np.dtype('float32'): DataType.f32, - np.dtype('float64'): DataType.f64, - np.dtype('int8'): DataType.s8, - np.dtype('int32'): DataType.s32, - } - - -def numpy_to_cutlass(inp): - numpy_available = CheckPackages().check_numpy() - if numpy_available: - return numpy_to_cutlass_dict.get(inp, None) - - -cupy_available = CheckPackages().check_cupy() -if cupy_available: - import cupy as cp - - cupy_to_cutlass_dict = { - cp.float16: DataType.f16, - cp.float32: DataType.f32, - cp.float64: DataType.f64, - } - - -def cupy_to_cutlass(inp): - cupy_available = CheckPackages().check_cupy() - if cupy_available: - return cupy_to_cutlass_dict.get(inp, None) - - -torch_available = CheckPackages().check_torch() -if torch_available: - import torch - - torch_to_cutlass_dict = { - torch.half: DataType.f16, - torch.float16: DataType.f16, - torch.float: DataType.f32, - torch.float32: DataType.f32, - torch.double: DataType.f64, - torch.float64: DataType.f64, - } - - -def torch_to_cutlass(inp): - if torch_available: - return torch_to_cutlass_dict.get(inp, None) - - -try: - import bfloat16 - - bfloat16_available = True - numpy_to_cutlass_dict[np.dtype(bfloat16.bfloat16)] = DataType.bf16 -except ImportError: - bfloat16_available = False - - -def bfloat16_to_cutlass(inp): - if bfloat16_available: - if inp == bfloat16.bfloat16: - return DataType.bf16 - - -def to_cutlass(inp): - for cvt_fn in [ - bfloat16_to_cutlass, - cupy_to_cutlass, - numpy_to_cutlass, - torch_to_cutlass, - ]: - out = cvt_fn(inp) - if out is not None: - return out - - raise Exception( - "No available conversion from type {} to a CUTLASS 
type.".format(inp) - ) - - -def to_device_ptr(tensor) -> cuda.CUdeviceptr: - """ - Converts a tensor to a CUdeviceptr - - :param tensor: tensor to convert - :type tensor: np.ndarray | torch.Tensor | cp.ndarray | int - - :return: device pointer - :rtype: cuda.CUdeviceptr - """ - if isinstance(tensor, np.ndarray): - ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0]) - elif torch_available and isinstance(tensor, torch.Tensor): - ptr = cuda.CUdeviceptr(tensor.data_ptr()) - elif cupy_available and isinstance(tensor, cp.ndarray): - ptr = cuda.CUdeviceptr(int(tensor.data.ptr)) - elif isinstance(tensor, cuda.CUdeviceptr): - ptr = tensor - elif isinstance(tensor, int): - ptr = cuda.CUdeviceptr(tensor) - else: - raise NotImplementedError(tensor) - - return ptr diff --git a/python/cutlass/backend/utils/device.py b/python/cutlass/backend/utils/device.py index 15e5457f55..f6c0f42c84 100644 --- a/python/cutlass/backend/utils/device.py +++ b/python/cutlass/backend/utils/device.py @@ -34,7 +34,10 @@ Utility functions for interacting with the device """ -from cuda import cudart +from cuda import cuda, cudart + +import cutlass +from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_tensor def check_cuda_errors(result: list): @@ -60,7 +63,7 @@ def check_cuda_errors(result: list): return result[1:] -def device_cc(device: int = 0) -> int: +def device_cc(device: int = -1) -> int: """ Returns the compute capability of the device with ID `device`. @@ -70,7 +73,51 @@ def device_cc(device: int = 0) -> int: :return: compute capability of the queried device (e.g., 80 for SM80) :rtype: int """ + if device == -1: + device = cutlass.device_id() + deviceProp = check_cuda_errors(cudart.cudaGetDeviceProperties(device)) major = str(deviceProp.major) minor = str(deviceProp.minor) return int(major + minor) + + +def device_sm_count(device: int = -1): + if device == -1: + device = cutlass.device_id() + err, device_sm_count = cuda.cuDeviceGetAttribute( + cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, device + ) + if err != cuda.CUresult.CUDA_SUCCESS: + raise Exception( + "Failed to retireve SM count. " + f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}" + ) + + return device_sm_count + + +def to_device_ptr(tensor) -> cuda.CUdeviceptr: + """ + Converts a tensor to a CUdeviceptr + + :param tensor: tensor to convert + :type tensor: np.ndarray | torch.Tensor | cp.ndarray | int + + :return: device pointer + :rtype: cuda.CUdeviceptr + """ + if is_numpy_tensor(tensor): + ptr = cuda.CUdeviceptr(tensor.__array_interface__["data"][0]) + elif is_torch_tensor(tensor): + ptr = cuda.CUdeviceptr(tensor.data_ptr()) + elif is_cupy_tensor(tensor): + ptr = cuda.CUdeviceptr(int(tensor.data.ptr)) + elif isinstance(tensor, cuda.CUdeviceptr): + ptr = tensor + elif isinstance(tensor, int): + ptr = cuda.CUdeviceptr(tensor) + else: + raise NotImplementedError(tensor) + + return ptr diff --git a/python/cutlass/backend/utils/software.py b/python/cutlass/backend/utils/software.py deleted file mode 100644 index 9f099b8a29..0000000000 --- a/python/cutlass/backend/utils/software.py +++ /dev/null @@ -1,111 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# 1. Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# 2. Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# 3. Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# -################################################################################################# - -import re -import sys - -from cutlass.backend.memory_manager import PoolMemoryManager - - -class CheckPackages: - def __init__(self) -> None: - pass - - def check_cupy(self): - if "cupy" in sys.modules: - return True - else: - try: - import cupy - - cupy_available = True - except ImportError: - print("cupy is not loaded.") - - def check_numpy(self): - if "numpy" in sys.modules: - return True - else: - try: - import numpy - - numpy_available = True - except ImportError: - print("numpy is not loaded.") - - def check_torch(self): - if "torch" in sys.modules: - return True - else: - try: - import torch - - torch_available = True - except ImportError: - print("torch is not loaded.") - - -def SubstituteTemplate(template, values): - text = template - changed = True - while changed: - changed = False - for key, value in values.items(): - regex = "\\$\\{%s\\}" % key - newtext = re.sub(regex, value, text) - if newtext != text: - changed = True - text = newtext - return text - - -def device_sm_count(): - from cuda import cuda - - _device = 0 - err, _device_sm_count = cuda.cuDeviceGetAttribute( - cuda.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT, _device - ) - if err != cuda.CUresult.CUDA_SUCCESS: - raise Exception( - "Failed to retireve SM count. " - f"cuDeviceGetAttribute() failed with error: {cuda.cuGetErrorString(err)[1]}" - ) - - return _device_sm_count - - -def get_memory_pool(init_pool_size=0, max_pool_size=2 ** 34): - memory_pool = PoolMemoryManager( - init_pool_size=init_pool_size, max_pool_size=max_pool_size - ) - return memory_pool diff --git a/python/cutlass/emit/pytorch.py b/python/cutlass/emit/pytorch.py index 737f5cdf34..91a7f94a85 100644 --- a/python/cutlass/emit/pytorch.py +++ b/python/cutlass/emit/pytorch.py @@ -39,7 +39,7 @@ .. highlight:: python .. 
code-block:: python - plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass.LayoutType.RowMajor) + plan = cutlass.op.Gemm(element=torch.float32, layout=cutlass_library.LayoutType.RowMajor) op = plan.construct() mod = cutlass.emit.pytorch(op, 'cutlass_gemm', 80, jit=True) @@ -81,15 +81,16 @@ import logging import os -from cutlass import CUTLASS_PATH, logger, swizzle, ConvKind, ConvKindNames, DataType +from cutlass_library import ConvKind, ConvKindNames, DataType, SubstituteTemplate + +from cutlass import CUTLASS_PATH, logger, swizzle from cutlass.backend.gemm_operation import GemmOperationGrouped, GemmOperationUniversal from cutlass.backend.conv2d_operation import Conv2dOperation from cutlass.backend.library import ApiVersion -from cutlass.backend.utils.software import CheckPackages, SubstituteTemplate from cutlass.emit import common +from cutlass.utils.datatypes import is_torch_available -torch_available = CheckPackages().check_torch() -if torch_available: +if is_torch_available(): import torch diff --git a/python/cutlass/epilogue/evt_ops.py b/python/cutlass/epilogue/evt_ops.py index 19f79a3dab..de900b715a 100644 --- a/python/cutlass/epilogue/evt_ops.py +++ b/python/cutlass/epilogue/evt_ops.py @@ -36,10 +36,9 @@ import numpy as np -from cutlass.backend.utils.software import CheckPackages +from cutlass.utils.datatypes import is_cupy_tensor, is_numpy_tensor, is_torch_available, is_torch_tensor -torch_available = CheckPackages().check_torch() -if torch_available: +if is_torch_available(): import torch @@ -48,16 +47,16 @@ def multiply_add(x, y, z): def sum(x, dim): - if isinstance(x, np.ndarray): + if is_numpy_tensor(x): return x.sum(axis=tuple(dim)) - elif torch_available and isinstance(x, torch.Tensor): + elif is_torch_tensor(x): return torch.sum(x, dim) def max(x, dim): - if isinstance(x, np.ndarray): + if is_numpy_tensor(x): return x.max(axis=tuple(dim)) - elif torch_available and isinstance(x, torch.Tensor): + elif is_torch_tensor(x): return torch.amax(x, dim) @@ -66,14 +65,14 @@ def max(x, dim): ############################################################################## def permute(x, indices: tuple): - if isinstance(x, np.ndarray): + if is_numpy_tensor(x): return np.transpose(x, axes=indices) - elif torch_available and isinstance(x, torch.Tensor): + elif is_torch_tensor(x): return x.permute(*indices) def reshape(x, new_shape: tuple): - if isinstance(x, np.ndarray): + if is_numpy_tensor(x): return np.reshape(x, newshape=new_shape) - elif torch_available and isinstance(x, torch.Tensor): + elif is_torch_tensor(x): return x.view(new_shape) diff --git a/python/cutlass/library_defaults.py b/python/cutlass/library_defaults.py index ad3e9ba8b3..ef1a8fce91 100644 --- a/python/cutlass/library_defaults.py +++ b/python/cutlass/library_defaults.py @@ -69,20 +69,23 @@ def add(self, operation): """ Add an operation to the list of supported kernels """ - alignment = operation.A.alignment - if alignment not in self.kernels_by_alignment: - self.kernels_by_alignment[alignment] = [] - self.kernels_by_alignment[alignment].append(operation) + alignment_key = f"{operation.A.alignment} {operation.B.alignment} {operation.C.alignment}" + if alignment_key not in self.kernels_by_alignment: + self.kernels_by_alignment[alignment_key] = [] + self.kernels_by_alignment[alignment_key].append(operation) - @property - def alignments(self): + def alignments(self, operand: str): """ Returns an unsorted list of alignments supported by this data type combination + :param operand: identifier of operand in question 
(e.g., A, B, C) + :type operand: str + :return: unsorted list of alignments supported by this data type combination :rtype: list """ - return list(self.kernels_by_alignment.keys()) + operand_idx = self._operand_idx(operand) + return [int(key.split(" ")[operand_idx]) for key in self.kernels_by_alignment.keys()] @property def all_operations(self): @@ -97,24 +100,48 @@ def all_operations(self): ops.extend(alignment_ops) return ops - def operations(self, alignment: int): + def default_operation(self): + key = sorted(list(self.kernels_by_alignment.keys()))[0] + return self.kernels_by_alignment[key][0] + + def operations(self, alignment_A: int, alignment_B: int, alignment_C: int): """ - Returns operations satisfying the alignment constraint indicated by `alignment` + Returns operations satisfying the alignment constraints - :param alignment: alignment constraint of operations to return - :type alignment: int + :param alignment_A: alignment constraint of operations to return + :type alignment_A: int + :param alignment_B: alignment constraint of operations to return + :type alignment_B: int + :param alignment_C: alignment constraint of operations to return + :type alignment_C: int :return: list of operations :rtype: list """ - if alignment not in self.kernels_by_alignment: - raise Exception( - f"No operations of alignment {alignment} found for data type and layout " - f"combination {self.datatype_comb} {self.layout_comb}" - ) - return self.kernels_by_alignment[alignment] + key = f"{alignment_A} {alignment_B} {alignment_C}" + + if key not in self.kernels_by_alignment: + og_key = key + # Reconcile A, B, and C alignments by trying to align to the minimum + min_alignment = min(alignment_A, alignment_B, alignment_C) + key = f"{min_alignment} {min_alignment} {min_alignment}" + if key not in self.kernels_by_alignment: + raise Exception( + f"No operations of alignment {og_key} found for data type and layout " + f"combination {self.datatype_comb} {self.layout_comb}. Tried to fall back " + f"to alignment {key}, but that was also not compatible. 
Compatible alignments " + f"are {self.kernels_by_alignment.keys()}" + ) + return self.kernels_by_alignment[key] - def find_alignment(self, shape: tuple, layout: cutlass.LayoutType) -> int: + def _operand_idx(self, key: str) -> int: + operand_list = ["A", "B", "C"] + if key not in operand_list: + raise Exception(f"Unexpected operand {operand}") + + return operand_list.index(key) + + def find_alignment(self, shape: tuple, layout: cutlass.LayoutType, operand=str) -> int: """ Returns the most preferable alignment for a given shape and layout @@ -122,10 +149,14 @@ def find_alignment(self, shape: tuple, layout: cutlass.LayoutType) -> int: :type shape: tuple :param layout: layout of the tensor :type layout: cutlass.LayoutType + :param operand: descriptor of the operand in question + :type operand: str :return: maximum alignment supported by the data type combination and tensor size :rtype: int """ + operand_idx = self._operand_idx(operand) + # Determine the leading dimension of the shape if layout == cutlass.LayoutType.ColumnMajor: ld = shape[-2] @@ -136,7 +167,8 @@ def find_alignment(self, shape: tuple, layout: cutlass.LayoutType) -> int: else: raise Exception(f"Unexpected or unsupported layout {layout}") - for alignment in sorted(list(self.kernels_by_alignment.keys()), reverse=True): + for alignments in sorted(list(self.kernels_by_alignment.keys()), reverse=True): + alignment = int(alignments.split(" ")[operand_idx]) if ld % alignment == 0: return alignment @@ -165,7 +197,7 @@ class ArchOptions: :param kernel_cc: compute capability of the kernels to generate :type kernel_cc: int :param operation_kind: type of operation to register - :type operation_kind: cutlass.OperationKind + :type operation_kind: cutlass_library.OperationKind :param gemm_kinds: types of GEMM operations that can be included :type gemm_kinds: list :param allowed_math_operations: types of primitive math operations allowed @@ -176,11 +208,12 @@ def __init__( self, target_cc: int, kernel_cc: int, - operation_kind: cutlass.OperationKind, + operation_kind: cutlass_library.OperationKind, gemm_kinds: list, allowed_math_operations: list = [ - cutlass.MathOperation.multiply_add, - cutlass.MathOperation.multiply_add_saturate, + cutlass_library.MathOperation.multiply_add, + cutlass_library.MathOperation.multiply_add_saturate, + cutlass_library.MathOperation.multiply_add_mixed_input_upcast ] ): self.cc = kernel_cc @@ -229,7 +262,7 @@ def __init__( # find available opclasses and data types for name, op_list in manifest.operations[operation_kind][kernel_cc].items(): for op in op_list: - if operation_kind == cutlass.OperationKind.Gemm: + if operation_kind == cutlass_library.OperationKind.Gemm: if op.gemm_kind not in gemm_kinds: continue @@ -237,15 +270,11 @@ def __init__( if mi.math_operation not in self.allowed_math_operations: continue - if op.C.element == cutlass.DataType.void: - # The CUTLASS Python interface currently does not support void-C kernels - continue - datatype_comb = (mi.element_a, mi.element_b, mi.element_accumulator) # Prune operations that don't fit in shared memory td = td_from_profiler_op(op) - if not valid_stage_count(target_cc, kernel_cc, td)[0]: + if not valid_stage_count(target_cc, kernel_cc, td, verbose=False)[0]: continue if mi.opcode_class not in self.operations_by_opclass: @@ -255,17 +284,17 @@ def __init__( layout_comb = (op.A.layout, op.B.layout) # Register TF32 kernels as F32 to enable F32 -> TF32 conversion + TF32 Tensor Core operations - if datatype_comb == (cutlass.DataType.tf32, cutlass.DataType.tf32, 
cutlass.DataType.f32): + if datatype_comb == (cutlass_library.DataType.tf32, cutlass_library.DataType.tf32, cutlass_library.DataType.f32): # TF32 kernels only supported on SM80 and beyond if self.cc < 80: continue elif self.cc == 90: - if (op.A.element != cutlass.DataType.f32 - or op.B.element != cutlass.DataType.f32 - or op.C.element != cutlass.DataType.f32): + if (op.A.element != cutlass_library.DataType.f32 + or op.B.element != cutlass_library.DataType.f32 + or op.C.element != cutlass_library.DataType.f32): continue - datatype_comb = (cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32) + datatype_comb = (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32) opclass_dict = self.operations_by_opclass[mi.opcode_class] key = (datatype_comb, layout_comb) @@ -274,82 +303,82 @@ def __init__( opclass_dict[key].add(op) # Set the default opclass to TensorOp, if available. Otherwise default to SIMT - if cutlass.OpcodeClass.TensorOp in self.operations_by_opclass: - self.op_class = cutlass.OpcodeClass.TensorOp + if cutlass_library.OpcodeClass.TensorOp in self.operations_by_opclass: + self.op_class = cutlass_library.OpcodeClass.TensorOp else: - self.op_class = cutlass.OpcodeClass.Simt + self.op_class = cutlass_library.OpcodeClass.Simt # The profiler's generator may generate only a limited set of combinations of operands for SIMT kernels. # Here, we generate additional versions via a generic TileDescription. - if cutlass.OpcodeClass.Simt not in self.operations_by_opclass: - self.operations_by_opclass[cutlass.OpcodeClass.Simt] = {} + if cutlass_library.OpcodeClass.Simt not in self.operations_by_opclass: + self.operations_by_opclass[cutlass_library.OpcodeClass.Simt] = {} - if operation_kind == cutlass.OperationKind.Gemm: + if operation_kind == cutlass_library.OperationKind.Gemm: types = [ - (cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s8), - (cutlass.DataType.s8, cutlass.DataType.s8, cutlass.DataType.s32), - (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16), - (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32), - (cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32), - (cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64), + (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s8), + (cutlass_library.DataType.s8, cutlass_library.DataType.s8, cutlass_library.DataType.s32), + (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16), + (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32), + (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32), + (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64), ] layouts = [ - (cutlass.LayoutType.RowMajor, cutlass.LayoutType.RowMajor), - (cutlass.LayoutType.RowMajor, cutlass.LayoutType.ColumnMajor), - (cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.RowMajor), - (cutlass.LayoutType.ColumnMajor, cutlass.LayoutType.ColumnMajor), + (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.RowMajor), + (cutlass_library.LayoutType.RowMajor, cutlass_library.LayoutType.ColumnMajor), + (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.RowMajor), + (cutlass_library.LayoutType.ColumnMajor, cutlass_library.LayoutType.ColumnMajor), ] - elif operation_kind == cutlass.OperationKind.Conv2d: + elif operation_kind == cutlass_library.OperationKind.Conv2d: types 
= [ - (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f16), - (cutlass.DataType.f16, cutlass.DataType.f16, cutlass.DataType.f32), - (cutlass.DataType.f32, cutlass.DataType.f32, cutlass.DataType.f32), - (cutlass.DataType.f64, cutlass.DataType.f64, cutlass.DataType.f64), + (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f16), + (cutlass_library.DataType.f16, cutlass_library.DataType.f16, cutlass_library.DataType.f32), + (cutlass_library.DataType.f32, cutlass_library.DataType.f32, cutlass_library.DataType.f32), + (cutlass_library.DataType.f64, cutlass_library.DataType.f64, cutlass_library.DataType.f64), ] layouts = [ - (cutlass.LayoutType.TensorNHWC, cutlass.LayoutType.TensorNHWC), + (cutlass_library.LayoutType.TensorNHWC, cutlass_library.LayoutType.TensorNHWC), ] else: raise NotImplementedError(f"Operation kind {operation_kind} is currently unsupported.") alignment = 1 - epilogue_functor = cutlass.EpilogueFunctor.LinearCombination - swizzling_functor = cutlass.SwizzlingFunctor.Identity8 + epilogue_functor = cutlass_library.EpilogueFunctor.LinearCombination + swizzling_functor = cutlass_library.SwizzlingFunctor.Identity8 for type_comb in types: for layout_comb in layouts: comb = (type_comb, layout_comb) - if comb in self.operations_by_opclass[cutlass.OpcodeClass.Simt]: + if comb in self.operations_by_opclass[cutlass_library.OpcodeClass.Simt]: continue - A = cutlass.TensorDescription(type_comb[0], layout_comb[0], alignment) - B = cutlass.TensorDescription(type_comb[1], layout_comb[1], alignment) - C = cutlass.TensorDescription(type_comb[2], cutlass.LayoutType.ColumnMajor, alignment) - math_inst = cutlass.MathInstruction( + A = cutlass_library.TensorDescription(type_comb[0], layout_comb[0], alignment) + B = cutlass_library.TensorDescription(type_comb[1], layout_comb[1], alignment) + C = cutlass_library.TensorDescription(type_comb[2], cutlass_library.LayoutType.ColumnMajor, alignment) + math_inst = cutlass_library.MathInstruction( [1, 1, 1], type_comb[0], type_comb[1], type_comb[2], - cutlass.OpcodeClass.Simt, - cutlass.MathOperation.multiply_add + cutlass_library.OpcodeClass.Simt, + cutlass_library.MathOperation.multiply_add ) - td = cutlass.TileDescription( + td = cutlass_library.TileDescription( [128, 128, 8], 2, [4, 2, 1], math_inst, 50, 1024) # Prune operations that don't fit in shared memory - if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td))[0]: + if not valid_stage_count(target_cc, kernel_cc, td_from_profiler_td(td), verbose=False)[0]: continue new_kernels = KernelsForDataType(type_comb, layout_comb) - if operation_kind == cutlass.OperationKind.Gemm: + if operation_kind == cutlass_library.OperationKind.Gemm: new_operation = cutlass_library.manifest.GemmOperation( - cutlass.GemmKind.Universal, td.minimum_compute_capability, + cutlass_library.GemmKind.Universal, td.minimum_compute_capability, td, A, B, C, type_comb[2], epilogue_functor, swizzling_functor) new_kernels.add(new_operation) - elif operation_kind == cutlass.OperationKind.Conv2d: + elif operation_kind == cutlass_library.OperationKind.Conv2d: for conv_kind in [ConvKind.Fprop, ConvKind.Dgrad, ConvKind.Wgrad]: new_operation = cutlass_library.manifest.Conv2dOperation( conv_kind, IteratorAlgorithm.Analytic, td.minimum_compute_capability, td, @@ -358,7 +387,7 @@ def __init__( ) new_kernels.add(new_operation) - self.operations_by_opclass[cutlass.OpcodeClass.Simt][comb] = new_kernels + self.operations_by_opclass[cutlass_library.OpcodeClass.Simt][comb] = new_kernels 
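The per-operand alignment bookkeeping introduced for KernelsForDataType earlier in this file (keys of the form "A B C", with a fallback to the minimum alignment) is easier to see in isolation. The following standalone sketch is illustrative only; AlignmentBuckets and Op are hypothetical stand-ins, not the actual classes.

from collections import namedtuple

# Hypothetical stand-in for a profiler-generated operation with per-operand alignments.
Op = namedtuple("Op", ["name", "align_A", "align_B", "align_C"])

class AlignmentBuckets:
    """Simplified sketch of bucketing kernels by an 'A B C' alignment key."""

    def __init__(self):
        self.kernels_by_alignment = {}

    def add(self, op):
        key = f"{op.align_A} {op.align_B} {op.align_C}"
        self.kernels_by_alignment.setdefault(key, []).append(op)

    def alignments(self, operand):
        # Project the composite key onto a single operand (A, B, or C).
        idx = ["A", "B", "C"].index(operand)
        return [int(key.split(" ")[idx]) for key in self.kernels_by_alignment]

    def operations(self, alignment_A, alignment_B, alignment_C):
        key = f"{alignment_A} {alignment_B} {alignment_C}"
        if key not in self.kernels_by_alignment:
            # Reconcile mismatched alignments by falling back to the minimum, as above.
            m = min(alignment_A, alignment_B, alignment_C)
            key = f"{m} {m} {m}"
        return self.kernels_by_alignment[key]

buckets = AlignmentBuckets()
buckets.add(Op("kernel_align8", 8, 8, 8))
buckets.add(Op("kernel_align4", 4, 4, 4))
print(buckets.alignments("A"))      # [8, 4]
print(buckets.operations(8, 4, 8))  # falls back to the 4/4/4 bucket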
# Sort all operations for oc in self.operations_by_opclass.keys(): @@ -366,17 +395,17 @@ def __init__( self.operations_by_opclass[oc][comb].sort() def opclass_supports_combination( - self, op_class: cutlass.OpcodeClass, datatype_comb: tuple, layout_comb: tuple + self, op_class: cutlass_library.OpcodeClass, datatype_comb: tuple, layout_comb: tuple ) -> bool: """ Returns whether the provided operation class supports the provided data type and layout combination :param op_class: operation class to consider - :type op_class: cutlass.OpcodeClass + :type op_class: cutlass_library.OpcodeClass :param datatype_comb: tuple of data types for (element_A, element_B, element_accumulator) - :type datatype_comb: tuple[cutlass.DataType] + :type datatype_comb: tuple[cutlass_library.DataType] :param layout_comb: tuple of data types for (layout_A, layout_B) - :type layout_comb: tuple[cutlass.LayoutType] + :type layout_comb: tuple[cutlass_library.LayoutType] :return: set of operation classes that support the provided data type and layout combination :rtype: set @@ -388,25 +417,25 @@ def opclass_supports_combination( def supporting_opclasses( self, - element_a: cutlass.DataType, - element_b: cutlass.DataType, - element_accumulator: cutlass.DataType, - layout_a: cutlass.LayoutType, - layout_b: cutlass.LayoutType, + element_a: cutlass_library.DataType, + element_b: cutlass_library.DataType, + element_accumulator: cutlass_library.DataType, + layout_a: cutlass_library.LayoutType, + layout_b: cutlass_library.LayoutType, ) -> set: """ Returns a set of operation classes that support the provided data type combination :param element_a: data type of operand A - :type element_a: cutlass.DataType + :type element_a: cutlass_library.DataType :param element_b: data type of operand B - :type element_b: cutlass.DataType + :type element_b: cutlass_library.DataType :param element_accumulator: data type of accumulator - :type element_accumulator: cutlass.DataType + :type element_accumulator: cutlass_library.DataType :param layout_a: layout of operand A - :type layout_a: cutlass.LayoutType + :type layout_a: cutlass_library.LayoutType :param layout_b: layout of operand B - :type layout_b: cutlass.LayoutType + :type layout_b: cutlass_library.LayoutType :return: set of operation classes that support the provided data type combination :rtype: set @@ -422,28 +451,28 @@ def supporting_opclasses( def operations( self, - op_class: cutlass.OpcodeClass, - element_a: cutlass.DataType, - element_b: cutlass.DataType, - element_accumulator: cutlass.DataType, - layout_a: cutlass.LayoutType, - layout_b: cutlass.LayoutType, + op_class: cutlass_library.OpcodeClass, + element_a: cutlass_library.DataType, + element_b: cutlass_library.DataType, + element_accumulator: cutlass_library.DataType, + layout_a: cutlass_library.LayoutType, + layout_b: cutlass_library.LayoutType, ) -> KernelsForDataType: """ Returns whether the provided operation class supports the provided data type combination :param op_class: operation class to consider - :type op_class: cutlass.OpcodeClass + :type op_class: cutlass_library.OpcodeClass :param element_a: data type of operand A - :type element_a: cutlass.DataType + :type element_a: cutlass_library.DataType :param element_b: data type of operand B - :type element_b: cutlass.DataType + :type element_b: cutlass_library.DataType :param element_accumulator: data type of accumulator - :type element_accumulator: cutlass.DataType + :type element_accumulator: cutlass_library.DataType :param layout_a: layout of operand A - :type 
layout_a: cutlass.LayoutType + :type layout_a: cutlass_library.LayoutType :param layout_b: layout of operand B - :type layout_b: cutlass.LayoutType + :type layout_b: cutlass_library.LayoutType :return: container of kernels by alignment supported by the provided combination of parameters :rtype: KernelsForDataType @@ -469,13 +498,13 @@ class OptionRegistry: def __init__(self, target_cc: int): self.registry = {} - gemm_kinds = [cutlass.GemmKind.Universal, cutlass.GemmKind.Universal3x] - operation_kinds = [cutlass.OperationKind.Gemm, cutlass.OperationKind.Conv2d] + gemm_kinds = [cutlass_library.GemmKind.Universal, cutlass_library.GemmKind.Universal3x] + operation_kinds = [cutlass_library.OperationKind.Gemm, cutlass_library.OperationKind.Conv2d] # Construct options for each CC for kernel_cc in _generator_ccs: self.registry[kernel_cc] = {} for opkind in operation_kinds: self.registry[kernel_cc][opkind] = ArchOptions(target_cc, kernel_cc, opkind, gemm_kinds) - def options_for_cc(self, cc: int, op_kind=cutlass.OperationKind.Gemm) -> ArchOptions: + def options_for_cc(self, cc: int, op_kind=cutlass_library.OperationKind.Gemm) -> ArchOptions: return self.registry.get(cc, None)[op_kind] diff --git a/python/cutlass/op/conv.py b/python/cutlass/op/conv.py index 3968785925..d7cd90ad0e 100644 --- a/python/cutlass/op/conv.py +++ b/python/cutlass/op/conv.py @@ -112,15 +112,18 @@ args.sync() """ -import cutlass -from cutlass import epilogue -from cutlass import ( +from cutlass_library import ( ConvKind, ConvMode, + DataTypeSize, IteratorAlgorithm, + OperationKind, SplitKMode, StrideSupport, ) + +import cutlass +from cutlass import epilogue from cutlass.backend import compiler from cutlass.backend.conv2d_operation import Conv2dArguments, Conv2dOperation from cutlass.backend.reduction_operation import ReductionOperation, ReductionArguments @@ -202,7 +205,7 @@ def __init__( element_accumulator=None, cc: int = None, kernel_cc: int = None ): - super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=cutlass.OperationKind.Conv2d) + super().__init__(cc=cc, kernel_cc=kernel_cc, operation_kind=OperationKind.Conv2d) # Verify the kernel cc if self.current_cc == 90: # The Conv2d kernel on Hopper (SM90) is currently unsupported @@ -305,11 +308,11 @@ def _reset_operations(self, reset_epilogue: bool = True): self._reset_epilogue_functor_activation(epilogue.identity) self.alignment_pref_A = min( - 128 // cutlass.DataTypeSize[self._element_a], max(self.possible_operations.alignments)) + 128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A"))) self.alignment_pref_B = min( - 128 // cutlass.DataTypeSize[self._element_b], max(self.possible_operations.alignments)) + 128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B"))) self.alignment_pref_C = min( - 128 // cutlass.DataTypeSize[self._element_c], max(self.possible_operations.alignments)) + 128 // DataTypeSize[self._element_c], max(self.possible_operations.alignments("C"))) # # Tile description Related @@ -342,8 +345,7 @@ def tile_description( return if isinstance(td, dict): if self._tile_description is None: - alignment = list(self.possible_operations.kernels_by_alignment.keys())[0] - op = self.possible_operations.operations(alignment)[0] + op = self.possible_operations.default_operation() self._tile_description = datatypes.td_from_profiler_op(op) if "cluster_shape" in td.keys(): if td["cluster_shape"] != [1, 1, 1]: @@ -567,8 +569,7 @@ def construct( if self.tile_description is not None: tile_description = 
self.tile_description else: - min_alignment = min([alignment_A, alignment_B, alignment_C]) - op = self.possible_operations.operations(min_alignment)[0] + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0] tile_description = datatypes.td_from_profiler_op(op) else: valid, err_str = self._valid_tile_description(tile_description) @@ -753,6 +754,8 @@ def run(self, A=None, B=None, C=None, D=None, :return: arguments passed in to the kernel :rtype: cutlass.backend.Conv2dArguments """ + super().run_setup() + A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") @@ -782,9 +785,9 @@ def run(self, A=None, B=None, C=None, D=None, shape_c = datatypes.get_tensor_shape(C, op="CONV") # Get the alignment - alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a) - alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b) - alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c) + alignment_a = self.possible_operations.find_alignment(shape_a, self._layout_a, operand="A") + alignment_b = self.possible_operations.find_alignment(shape_b, self._layout_b, operand="B") + alignment_c = self.possible_operations.find_alignment(shape_c, self._layout_c, operand="C") alignment_a = check.update_alignment(alignment_a, self.alignment_pref_A) alignment_b = check.update_alignment(alignment_b, self.alignment_pref_B) @@ -858,6 +861,10 @@ def run(self, A=None, B=None, C=None, D=None, if sync: if split_k[0] == "parallel" and split_k[1] > 1: reduction_arguments.sync() + + # Free memory allocated by args because we are not + # calling `arguments.sync()` in this case (which will free memory) + arguments.free() else: arguments.sync() diff --git a/python/cutlass/op/gemm.py b/python/cutlass/op/gemm.py index 3046e34dbb..718696f10a 100644 --- a/python/cutlass/op/gemm.py +++ b/python/cutlass/op/gemm.py @@ -116,12 +116,14 @@ from math import prod -import cutlass -from cutlass import ( - epilogue, - swizzle, +from cutlass_library import ( + DataType, + DataTypeSize, GemmUniversalMode, ) + +import cutlass +from cutlass import epilogue, swizzle from cutlass.backend import compiler from cutlass.backend.evt import EpilogueFunctorVisitor from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal @@ -292,7 +294,7 @@ def _reset_operations(self, reset_epilogue: bool = True): f'combination {datatype_comb}x{layout_comb}') if reset_epilogue: - self._reset_epilogue_functor_activation(epilogue.identity) + self._reset_epilogue_functor_activation(cutlass.epilogue.identity) @property def swizzling_functor(self): @@ -308,7 +310,7 @@ def swizzling_functor(self, swizzling_functor): """ Sets the swizzling functor to the type specified by `swizzling_functor` """ - if swizzling_functor == swizzle.ThreadblockSwizzleStreamK: + if swizzling_functor == cutlass.swizzle.ThreadblockSwizzleStreamK: if self.op_class == cutlass.OpcodeClass.Simt: raise Exception('ThreadblockSwizzleStreamK is currently only supported with opcode class TensorOp') @@ -347,8 +349,7 @@ def tile_description( return if isinstance(td, dict): if self._tile_description is None: - alignment = list(self.possible_operations.kernels_by_alignment.keys())[0] - op = self.possible_operations.operations(alignment)[0] + op = self.possible_operations.default_operation() self._tile_description = datatypes.td_from_profiler_op(op) 
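For context, the tile_description setter above starts from a representative kernel's tile description and overlays the user-supplied dict, as the clone_and_update call on the next line shows. A minimal sketch of that overlay pattern, using a hypothetical TileDesc dataclass rather than the real TileDescription:

from dataclasses import dataclass, replace

@dataclass(frozen=True)
class TileDesc:
    # Hypothetical, trimmed-down stand-in; the real TileDescription carries more fields.
    threadblock_shape: tuple
    stages: int
    cluster_shape: tuple = (1, 1, 1)

def clone_and_update(td, updates):
    # Copy the default description and overwrite only the keys the user provided.
    return replace(td, **updates)

default_td = TileDesc(threadblock_shape=(128, 128, 32), stages=3)
print(clone_and_update(default_td, {"stages": 4}))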
td = self._tile_description.clone_and_update(td) @@ -414,22 +415,25 @@ def construct( :return: operation that was constructed :rtype: cutlass.backend.GemmOperationUniversal """ - alignment_pref_A = min(128 // cutlass.DataTypeSize[self._element_a], max(self.possible_operations.alignments)) - alignment_pref_B = min(128 // cutlass.DataTypeSize[self._element_b], max(self.possible_operations.alignments)) - alignment_pref_C = min(128 // cutlass.DataTypeSize[self._element_c], max(self.possible_operations.alignments)) + alignment_pref_A = min(128 // DataTypeSize[self._element_a], max(self.possible_operations.alignments("A"))) + alignment_pref_B = min(128 // DataTypeSize[self._element_b], max(self.possible_operations.alignments("B"))) alignment_A = check.alignment_or_default(alignment_A, alignment_pref_A) alignment_B = check.alignment_or_default(alignment_B, alignment_pref_B) - alignment_C = check.alignment_or_default(alignment_C, alignment_pref_C) - - self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) tensor_A = TensorDescription(self._element_a, self._layout_a, alignment_A) tensor_B = TensorDescription(self._element_b, self._layout_b, alignment_B) + + alignment_pref_C = max(self.possible_operations.alignments("C")) + if self._element_c != DataType.void: + alignment_pref_C = min(128 // DataTypeSize[self._element_c], alignment_pref_C) + + alignment_C = check.alignment_or_default(alignment_C, alignment_pref_C) tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) + self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) if tile_description is None: if self._tile_description is None: - op = self.possible_operations.operations(alignment_A)[0] + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0] tile_description = datatypes.td_from_profiler_op(op) else: tile_description = self._tile_description @@ -527,7 +531,7 @@ def _get_batch_stride(self, tensor) -> int: :return: stride between each matrix in the batch :rtype: int """ - if len(tensor.shape) > 2: + if tensor is not None and len(tensor.shape) > 2: return tensor.shape[-2] * tensor.shape[-1] else: return 0 @@ -566,12 +570,14 @@ def _get_problem_args(self, A, B, C, D) -> tuple: B_row = self._layout_b == cutlass.LayoutType.RowMajor C_row = self._layout_c == cutlass.LayoutType.RowMajor - batched = lambda x : len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count + # Consider a Tensor to be batched if its rank is > 2 and + # the product of the modes beyond rank 2 equals our pre-determined batch size. 
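The comment above describes the broadcast check; below is a standalone sketch of that predicate (not the plan's actual lambda), treating a missing C operand as trivially batched.

from math import prod

def is_batched(shape, batch_count):
    # A tensor counts as batched when its modes beyond the last two multiply
    # out to the previously determined batch count; None (void-C) always passes.
    if shape is None:
        return True
    return len(shape) > 2 and prod(shape[:-2]) == batch_count

print(is_batched((3, 128, 64), 3))  # True: one leading batch mode of size 3
print(is_batched((128, 64), 3))     # False: rank-2 operand, implicitly broadcast
print(is_batched(None, 3))          # True: void-C kernels provide no C tensor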
+ batched = lambda x : x is None or (len(x.shape) > 2 and prod(x.shape[:-2]) == batch_count) - if batched(A) and not batched(B) and batched(C) and A_row and C_row: + if batched(A) and not batched(B) and (C is None or batched(C)) and A_row and C_row: M *= batch_count returned_batch_count = 1 - elif not batched(A) and batched(B) and batched(C) and not B_row and not C_row: + elif not batched(A) and batched(B) and (C is None or batched(C)) and not B_row and not C_row: N *= batch_count returned_batch_count = 1 else: @@ -625,6 +631,7 @@ def run(self, A=None, B=None, C=None, D=None, :return: arguments passed in to the kernel :rtype: cutlass.backend.GemmArguments """ + super().run_setup() A = self._verify_tensor(A, self.A, self._element_a, self._layout_a, "A") B = self._verify_tensor(B, self.B, self._element_b, self._layout_b, "B") C = self._verify_tensor(C, self.C, self._element_c, self._layout_c, "C") @@ -632,14 +639,20 @@ def run(self, A=None, B=None, C=None, D=None, alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") + is_void_c = self._element_c == DataType.void + self._verify_rank(A) self._verify_rank(B) - self._verify_rank(C) + if not is_void_c: + self._verify_rank(C) self._verify_rank(D) - alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a) - alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b) - alignment_c = self.possible_operations.find_alignment(C.shape, self._layout_c) + alignment_a = self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") + alignment_b = self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") + + # Set C alignment based on D.shape so as to correctly get an alignment with void-C + # kernels, for which `C` is None. 
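The find_alignment lookup used on the following lines boils down to choosing the largest supported alignment that evenly divides the tensor's leading dimension. A simplified, self-contained illustration (the supported alignments here are assumed values, not taken from the kernel registry):

def find_alignment(leading_dim, supported_alignments=(8, 4, 2, 1)):
    # Prefer the widest vectorized access that evenly divides the leading dimension.
    for alignment in sorted(supported_alignments, reverse=True):
        if leading_dim % alignment == 0:
            return alignment
    return 1

print(find_alignment(1024))  # 8: a 128-bit access for a 16-bit element type
print(find_alignment(100))   # 4
print(find_alignment(99))    # 1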
+ alignment_c = self.possible_operations.find_alignment(D.shape, self._layout_c, operand="C") self.compile(self._tile_description, alignment_A=alignment_a, alignment_B=alignment_b, alignment_C=alignment_c, print_module=print_module) diff --git a/python/cutlass/op/gemm_grouped.py b/python/cutlass/op/gemm_grouped.py index bc8c98693e..d20ac50797 100644 --- a/python/cutlass/op/gemm_grouped.py +++ b/python/cutlass/op/gemm_grouped.py @@ -51,7 +51,8 @@ plan.run([A0, A1], [B0, B1], [C0, C1], [D0, D1]) """ -from cutlass import DataTypeSize +from cutlass_library import DataTypeSize + from cutlass.backend.gemm_operation import ( GemmGroupedArguments, GemmOperationGrouped, @@ -162,10 +163,9 @@ def construct(self, tile_description: TileDescription = None, :return: operation that was constructed :rtype: cutlass.backend.GemmOperationGrouped """ - alignment_preference = max(self.possible_operations.alignments) - alignment_A = check.alignment_or_default(alignment_A, alignment_preference) - alignment_B = check.alignment_or_default(alignment_B, alignment_preference) - alignment_C = check.alignment_or_default(alignment_C, alignment_preference) + alignment_A = check.alignment_or_default(alignment_A, max(self.possible_operations.alignments("A"))) + alignment_B = check.alignment_or_default(alignment_B, max(self.possible_operations.alignments("B"))) + alignment_C = check.alignment_or_default(alignment_C, max(self.possible_operations.alignments("C"))) self.epilogue_functor = self._reset_epilogue_functor_alignment(alignment_C, self.epilogue_functor) @@ -174,7 +174,7 @@ def construct(self, tile_description: TileDescription = None, tensor_C = TensorDescription(self._element_c, self._layout_c, alignment_C) if tile_description is None: - op = self.possible_operations.operations(alignment_A)[0] + op = self.possible_operations.operations(alignment_A, alignment_B, alignment_C)[0] tile_description = datatypes.td_from_profiler_op(op) else: valid, err_str = self._valid_tile_description(tile_description) @@ -221,6 +221,8 @@ def run(self, A, B, C, D, :return: arguments passed in to the kernel :rtype: cutlass.backend.GemmGroupedArguments """ + super().run_setup() + if len(A) != len(B) or len(A) != len(C) or len(A) != len(D): raise Exception("Lengths of A, B, C, and D lists must be equal") @@ -236,9 +238,9 @@ def run(self, A, B, C, D, alpha = self._verify_scalar(alpha, self.alpha, self._element_c, "alpha") beta = self._verify_scalar(beta, self.beta, self._element_c, "beta") - alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a) for A in As)) - alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b) for B in Bs)) - alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c) for C in Cs)) + alignment_a = min((self.possible_operations.find_alignment(A.shape, self._layout_a, operand="A") for A in As)) + alignment_b = min((self.possible_operations.find_alignment(B.shape, self._layout_b, operand="B") for B in Bs)) + alignment_c = min((self.possible_operations.find_alignment(C.shape, self._layout_c, operand="C") for C in Cs)) self.compile(self.tile_description, alignment_A=alignment_a, alignment_B=alignment_b, alignment_C=alignment_c, print_module=print_module) diff --git a/python/cutlass/op/op.py b/python/cutlass/op/op.py index 7bd0e545e9..d0630d6795 100644 --- a/python/cutlass/op/op.py +++ b/python/cutlass/op/op.py @@ -36,11 +36,13 @@ from bisect import bisect_left +from cutlass_library import DataType, DataTypeSize, OperationKind, SharedMemPerCC + 
import cutlass -from cutlass import option_registry, epilogue +from cutlass import get_option_registry from cutlass.backend.evt import EpilogueFunctorVisitor from cutlass.backend.utils.device import device_cc -from cutlass.epilogue import get_activations +from cutlass.epilogue import get_activations, get_activation_epilogue, identity from cutlass.library_defaults import KernelsForDataType, _generator_ccs from cutlass.swizzle import get_swizzling_functors from cutlass.utils import datatypes, check @@ -51,12 +53,14 @@ class OperationBase: Base operation used for defining high-level CUTLASS operations (e.g., GEMM, Conv2d) """ - def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = cutlass.OperationKind.Gemm): + def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = OperationKind.Gemm): """ :param cc: compute capability of device for which kernels should be compiled. For example, if running on H100, this should be set to 90 :type cc: int :param kernel_cc: compute capability of kernels to generate. For example, if running on SM90, but desiring to use a CUTLASS 2.x-style Ampere kernel, this should be set to 80 :type kernel_cc: int + :param operation_kind: class of operation that will be performed (e.g., GEMM, Conv) + :type operation_kind: cutlass_library.OperationKind """ self.operation_kind = operation_kind self.cc = cc if cc is not None else device_cc() @@ -64,13 +68,13 @@ def __init__(self, cc: int = None, kernel_cc: int = None, operation_kind = cutla self.current_cc = kernel_cc if kernel_cc is not None else self._find_closest_cc(self.cc) self.tile_description = None - self.options = option_registry.options_for_cc(self.current_cc, operation_kind) + self.options = get_option_registry().options_for_cc(self.current_cc, operation_kind) if self.options is None: raise Exception(f"Invalid or unsupported compute capability: {self.current_cc}") # Default activation function: identity - self._activation = epilogue.identity + self._activation = identity def _find_closest_cc(self, cc: int) -> int: """ @@ -120,7 +124,7 @@ def _reset_options(self, cc: int): if cc not in _generator_ccs: raise Exception(f'Invalid CC for CUTLASS kernels: {cc}.') self.current_cc = cc - self.options = option_registry.options_for_cc(self.current_cc, self.operation_kind) + self.options = get_option_registry().options_for_cc(self.current_cc, self.operation_kind) def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name): """ @@ -158,9 +162,12 @@ def _verify_scalar(self, scalar, ref_scalar, ref_dtype, name): def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name): """ Verifies the following properties: - 1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``) - 2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions - set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``) + If ref_dtype is not void: + 1) Either ``tensor`` or ``ref_tensor`` must be set (i.e., not ``None``) + 2) If ``tensor`` is not ``None``, its datatype and layout must match matches the current versions + set by the plan (i.e., those in ``ref_dtype`` and ``ref_layout``) + If ref_dtype is void: + Neither ``tensor`` nor ``ref_tensor`` are set If either of these properties does not hold, an exception is raised. If these properties hold and ``tensor`` is not ``None``, ``tensor`` is returned. Otherwise, ``ref_tensor`` is returned. 
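The rule spelled out in this docstring (and implemented in the hunk that follows) can be summarized in a few lines. The sketch below is simplified: it uses plain strings in place of cutlass_library.DataType members and omits the dtype/layout consistency checks.

def verify_tensor(tensor, ref_tensor, ref_dtype, name):
    if ref_dtype == "void":
        # Void operands must not be given a tensor, either per-call or on the plan.
        if tensor is not None or ref_tensor is not None:
            raise Exception("Operands with element DataType.void must not be provided a tensor")
        return None
    if tensor is None:
        if ref_tensor is None:
            raise Exception(f"Tensor {name} must be set.")
        return ref_tensor
    return tensor

print(verify_tensor(None, None, "void", "C"))       # None: void-C kernels take no C operand
print(verify_tensor(None, "cached_C", "f16", "C"))  # falls back to the tensor stored on the plan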
@@ -177,6 +184,11 @@ def _verify_tensor(self, tensor, ref_tensor, ref_dtype, ref_layout, name): :return: valid tensor object to use :rtype: numpy/cupy/torch array/tensor object """ + if ref_dtype == DataType.void: + if tensor is not None or ref_tensor is not None: + raise Exception("Operands with element DataType.void must not be provided a tensor") + return None + if tensor is None: if ref_tensor is None: raise Exception(f"Tensor {name} must be set.") @@ -211,58 +223,60 @@ def opclass(self, oc: cutlass.OpcodeClass): f'({self._element_a}, {self._element_b}, {self._element_accumulator}) and ' f'layout combination ({self._layout_a}, {self._layout_b}).') - # Changing the op class changes the elements per access in the epilogue. Reset this. - if self.op_class == cutlass.OpcodeClass.Simt: - elements_per_access = 1 - else: - elements_per_access = 128 // cutlass.DataTypeSize[self._element_c] - - if self.epilogue_functor is not None: - self.epilogue_functor = self._reset_epilogue_functor_alignment(elements_per_access, self.epilogue_functor) - # Changing the op class also changes the possible operations available. Reset these. self.possible_operations = self.options.operations( self.op_class, self._element_a, self._element_b, self._element_accumulator, self._layout_a, self._layout_b) + # Changing the op class changes the elements per access in the epilogue. Reset this. + if self.epilogue_functor is not None: + self.epilogue_functor = self._reset_epilogue_functor_alignment(self._elements_per_access(), self.epilogue_functor) + # # Epilogue # + def _elements_per_access(self): + if self.op_class == cutlass.OpcodeClass.Simt: + return 1 + elif self._element_c != DataType.void: + return 128 // DataTypeSize[self._element_c] + else: + return 128 // max(self.possible_operations.alignments("C")) + def _create_epilogue_functor_activation(self, activation): """ Returns the epilogue functor with given activation function """ if self.epilogue_functor is None: - if self.op_class == cutlass.OpcodeClass.Simt: - elements_per_access = 1 - else: - elements_per_access = 128 // cutlass.DataTypeSize[self._element_c] + elements_per_access = self._elements_per_access() else: elements_per_access = self.epilogue_functor.epilogue_vector_length if not self.specified_kernel_cc: - if self.current_cc == 90 and activation != epilogue.identity: - # CUTLASS 3.0 kernels currently only support identity activation. If one requests a non-identity activation, + if self.current_cc == 90 and activation != identity: + # CUTLASS 3.0 kernels in Python currently only support identity activation. If one requests a non-identity activation, # revert to using a CUTLASS 2.x kernel by using SM80-tagged kernels. cutlass.logger.warning("Reverting to using SM80-tagged kernel. Opclass may change.") + if self._element_c != self._element_d: + raise Exception("CUTLASS 2.x kernels require element C to be the same as element D") self._reset_options(80) self._reset_operations(reset_epilogue=False) - elif (self.cc == 90 and self.current_cc != 90 and activation == epilogue.identity): + elif (self.cc == 90 and self.current_cc != 90 and activation == identity): # SM80 fallback kernels are currently used. Since an identity activation is requested, # we can switch back to using SM90 kernels. 
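The compute-capability fallback logic around this point is easier to read as a pure function. This is an illustrative condensation only; the real code mutates plan state via _reset_options and _reset_operations rather than returning a value.

def select_kernel_cc(device_cc, current_cc, activation_is_identity, c_equals_d):
    if current_cc == 90 and not activation_is_identity:
        # Fused activations are not yet supported for 3.x kernels in Python;
        # revert to SM80-tagged 2.x kernels, which require element C == element D.
        if not c_equals_d:
            raise Exception("CUTLASS 2.x kernels require element C to be the same as element D")
        return 80
    if device_cc == 90 and current_cc != 90 and activation_is_identity:
        # Identity activation requested again: switch back to SM90 kernels.
        return 90
    return current_cc

print(select_kernel_cc(90, 90, activation_is_identity=False, c_equals_d=True))  # 80
print(select_kernel_cc(90, 80, activation_is_identity=True, c_equals_d=True))   # 90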
self._reset_options(90) self._reset_operations(reset_epilogue=False) else: - if self.current_cc == 90 and activation != epilogue.identity: + if self.current_cc == 90 and activation != identity: raise Exception("Epilogues with elementwise fusion are not currently supported " "in the Python interface for 3.x kernels. To use 2.x kernels " "with fused elementwise epilogues, do not set the `kernel_cc` " "parameter when constructing the Gemm object.") - return epilogue.get_activation_epilogue( + return get_activation_epilogue( activation, - self._element_c, + self._element_d, elements_per_access, self._element_accumulator, self._element_accumulator, @@ -283,13 +297,13 @@ def _reset_epilogue_functor_alignment(self, alignment, epilogue_functor): if epilogue_functor is None or not hasattr(epilogue_functor, 'activation_functor'): # Identity epilogue does not have 'activation_functor' - activation = epilogue.identity + activation = identity else: activation = epilogue_functor.activation_functor - epilogue_functor = epilogue.get_activation_epilogue( + epilogue_functor = get_activation_epilogue( activation, - self._element_c, + self._element_d, alignment, self._element_accumulator, self._element_accumulator, @@ -304,7 +318,7 @@ def activation(self): if hasattr(self.epilogue_functor, "activation_functor"): return self.epilogue_functor.activation_functor else: - return epilogue.identity + return identity @activation.setter def activation(self, act): @@ -363,8 +377,8 @@ def epilogue_visitor(self, visitor): epilogue_smem_bytes = self.epilogue_functor.get_smem_size(td) # Verify the maximum number of mainloop stages - mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, cutlass.OperationKind.Gemm) - smem_capacity_bytes = cutlass.SharedMemPerCC[self.cc] << 10 + mainloop_smem_per_stage = check.calculate_smem_usage_per_stage(td, OperationKind.Gemm) + smem_capacity_bytes = SharedMemPerCC[self.cc] << 10 mainloop_stages = (smem_capacity_bytes - epilogue_smem_bytes) // mainloop_smem_per_stage if mainloop_stages < 2: # Mainloop stages must >= 2 @@ -376,3 +390,11 @@ def epilogue_visitor(self, visitor): "The epilogue consumes too much shared memory. " "No valid tile description is found in the generator.") self.possible_operations = new_possible_operations + + + def run_setup(self): + """ + Steps that must be taken before calling `plan.run()` + """ + # Initialize the memory pool, if not already done + cutlass.get_memory_pool() diff --git a/python/cutlass/shape.py b/python/cutlass/shape.py index 78e164d764..6e21dbbad9 100644 --- a/python/cutlass/shape.py +++ b/python/cutlass/shape.py @@ -34,7 +34,7 @@ Utilities for expressing shapes """ -from cutlass import ( +from cutlass_library import ( ConvMode, ConvKind, LayoutType @@ -64,7 +64,7 @@ def leading_dimension(self, layout: LayoutType) -> int: Returns the leading dimension for a matrix with layout ``layout`` and shape provided by the MatrixCoord.
:param layout: layout of matrix - :type layout: cutlass.LayoutType + :type layout: cutlass_library.LayoutType :returns: leading dimension :rtype: int diff --git a/python/cutlass/swizzle.py b/python/cutlass/swizzle.py index 498ab74eb5..ef5957c3e6 100644 --- a/python/cutlass/swizzle.py +++ b/python/cutlass/swizzle.py @@ -34,7 +34,7 @@ Registry of swizzling functions """ -from cutlass import SwizzlingFunctor +from cutlass_library import SwizzlingFunctor IdentitySwizzle1 = SwizzlingFunctor.Identity1 diff --git a/python/cutlass/utils/check.py b/python/cutlass/utils/check.py index 1ca0eb8a8c..e16fb05c57 100644 --- a/python/cutlass/utils/check.py +++ b/python/cutlass/utils/check.py @@ -36,26 +36,27 @@ import ctypes +from cutlass_library import DataTypeSize, OperationKind, SharedMemPerCC + import cutlass -from cutlass import DataTypeSize from cutlass.backend.library import TileDescription -def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: cutlass.OperationKind) -> int: +def calculate_smem_usage_per_stage(td: TileDescription, operation_kind: OperationKind) -> int: """ Returns the amount of shared memory in bytes consumed in a single stage of a kernel. :param td: tile description to compute shared memory of :type td: TileDescription :param operation_kind: identifier for the type of operation being performed - :type operation_kind: cutlass.OperationKind + :type operation_kind: cutlass_library.OperationKind :return: number of bytes of shared memory consumed by a single stage :rtype: int """ m, n, k = td.threadblock_shape - if operation_kind == cutlass.OperationKind.Gemm: + if operation_kind == OperationKind.Gemm: stage_barrier_bytes = 32 return ( (DataTypeSize[td.math_instruction.element_a] * m * k // 8) @@ -82,7 +83,8 @@ def valid_stage_count( kernel_cc: int, td: TileDescription, element_C: cutlass.DataType = None, - element_D: cutlass.DataType = None) -> tuple: + element_D: cutlass.DataType = None, + verbose: bool = True) -> tuple: """ Checks whether a device with `cc` supports the number of stages within `tile_description`, both based on raw limits on the number of stages and based on shared memory capacity @@ -97,6 +99,8 @@ def valid_stage_count( :type element_C: cutlass.DataType :param element_D: data type of operand D :type element_D: cutlass.DataType + :param verbose: whether to log warnings + :type verbose: bool :return: tuple with the first element indicating whether the provided tile description is valid for the provided device and the second element being an error message @@ -107,7 +111,7 @@ def valid_stage_count( # Stage count of None or 0 for SM90 indicates that the CollectiveBuilder automatically # determines the stage count to use. Thus, all settings are valid in these scenarios. return (True, "") - else: + elif verbose: cutlass.logger.warning( "Setting an explicit stage count for SM90 kernels currently may " "result in compilation errors if the combination of tile shape, " @@ -125,9 +129,9 @@ def valid_stage_count( # only catches cases in which the mainloop exceeds the device's shared memory capacity. # This is not a concern for CUTLASS 2.x kernels, for which the shared memory of the # mainloop and epilogue is shared. 
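# (For illustration: a self-contained sketch of the arithmetic behind the capacity check
#  just below. The threadblock shape and the 163 KiB capacity are example values only --
#  the real numbers come from the TileDescription and from SharedMemPerCC[cc].)
def gemm_smem_per_stage(bits_a, bits_b, m, n, k, stage_barrier_bytes=32):
    # One stage buffers an m-by-k tile of A and a k-by-n tile of B, plus a small barrier allocation.
    return (bits_a * m * k // 8) + (bits_b * k * n // 8) + stage_barrier_bytes

per_stage = gemm_smem_per_stage(16, 16, m=128, n=128, k=32)  # f16 A and B: 16,416 bytes per stage
smem_capacity = 163 << 10                                    # assume a device with 163 KiB per SM
print(per_stage, smem_capacity // per_stage)                 # -> 16416 10  (up to 10 stages fit)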
- smem_per_stage = calculate_smem_usage_per_stage(td, cutlass.OperationKind.Gemm) + smem_per_stage = calculate_smem_usage_per_stage(td, OperationKind.Gemm) smem_usage_mainloop = (smem_per_stage * td.stages) - smem_arch = cutlass.SharedMemPerCC[cc] << 10 + smem_arch = SharedMemPerCC[cc] << 10 if smem_usage_mainloop > smem_arch: return ( False, "Configuration uses too much shared memory. Consider reducing stage count or tile shape.\n" @@ -214,7 +218,9 @@ def valid_schedule( return (False, "Kernel and epilogue schedules must either both be auto or neither be auto") if not tile_scheduler_default: - if (tile_scheduler == cutlass.TileSchedulerType.StreamK) and (kernel_schedule != cutlass.KernelScheduleType.TmaWarpSpecializedCooperative): + cooperative_kernels = [cutlass.KernelScheduleType.TmaWarpSpecializedCooperative, + cutlass.KernelScheduleType.CpAsyncWarpSpecializedCooperative] + if (tile_scheduler == cutlass.TileSchedulerType.StreamK) and (kernel_schedule not in cooperative_kernels): return (False, "Stream-K tile scheduler is currently only supported with the cooperative kernel schedule") return (True, "") diff --git a/python/cutlass/utils/datatypes.py b/python/cutlass/utils/datatypes.py index fa229557f5..d26ada29d5 100644 --- a/python/cutlass/utils/datatypes.py +++ b/python/cutlass/utils/datatypes.py @@ -35,33 +35,55 @@ """ import cutlass -from cutlass import ( +from cutlass_library import ( DataTypeSize, + MathOperation, + MathInstruction ) from cutlass.backend.library import ( - MathInstruction, - MathOperation, TileDescription, ) -try: - import numpy as np - - numpy_available = True - _library_to_numpy_dict = { - cutlass.DataType.f16: np.float16, - cutlass.DataType.f32: np.float32, - cutlass.DataType.f64: np.float64, - cutlass.DataType.s8: np.int8, - cutlass.DataType.s32: np.int32, - } -except ImportError: - numpy_available = False - _library_to_numpy_dict = {} +bfloat16_available = None +cupy_available = None +numpy_available = None +torch_available = None +_library_to_cupy_dict = None +_library_to_numpy_dict = None +_library_to_torch_dict = None +_torch_to_library_dict = None + + +def is_numpy_available(): + global numpy_available, _library_to_numpy_dict + if numpy_available is None: + try: + import numpy as np + + numpy_available = True + _library_to_numpy_dict = { + cutlass.DataType.f16: np.float16, + cutlass.DataType.f32: np.float32, + cutlass.DataType.f64: np.float64, + cutlass.DataType.s8: np.int8, + cutlass.DataType.s32: np.int32, + } + except ImportError: + numpy_available = False + _library_to_numpy_dict = {} + return numpy_available + + +def is_numpy_tensor(inp) -> bool: + if is_numpy_available(): + import numpy as np + return isinstance(inp, np.ndarray) + return False def numpy_library_type(inp) -> cutlass.DataType: - if numpy_available: + if is_numpy_available(): + import numpy as np if inp == np.float16: return cutlass.DataType.f16 elif inp == np.float32: @@ -79,24 +101,36 @@ def numpy_type(inp): return _library_to_numpy_dict.get(inp, None) -try: - import cupy as cp +def is_cupy_available(): + global cupy_available + if cupy_available is None: + try: + import cupy as cp + + cupy_available = True + _library_to_cupy_dict = { + cutlass.DataType.f16: cp.float16, + cutlass.DataType.f32: cp.float32, + cutlass.DataType.f64: cp.float64, + cutlass.DataType.s8: cp.int8, + cutlass.DataType.s32: cp.int32, + } + except ImportError: + cupy_available = False + _library_to_cupy_dict = {} + return cupy_available - cupy_available = True - _library_to_cupy_dict = { - cutlass.DataType.f16: 
cp.float16, - cutlass.DataType.f32: cp.float32, - cutlass.DataType.f64: cp.float64, - cutlass.DataType.s8: cp.int8, - cutlass.DataType.s32: cp.int32, - } -except ImportError: - cupy_available = False - _library_to_cupy_dict = {} + +def is_cupy_tensor(inp) -> bool: + if is_cupy_available(): + import cupy as cp + return isinstance(inp, cp.ndarray) + return False def cupy_library_type(inp) -> cutlass.DataType: - if cupy_available: + if is_cupy_available(): + import cupy as cp if inp == cp.float16: return cutlass.DataType.f16 elif inp == cp.float32: @@ -110,39 +144,50 @@ def cupy_type(inp): return _library_to_cupy_dict.get(inp, None) -try: - import torch - - torch_available = True - _torch_to_library_dict = { - torch.half: cutlass.DataType.f16, - torch.float16: cutlass.DataType.f16, - torch.bfloat16: cutlass.DataType.bf16, - torch.float: cutlass.DataType.f32, - torch.float32: cutlass.DataType.f32, - torch.double: cutlass.DataType.f64, - torch.float64: cutlass.DataType.f64, - torch.int8: cutlass.DataType.s8, - torch.int32: cutlass.DataType.s32, - torch.uint8: cutlass.DataType.u8, - } - - _library_to_torch_dict = { - cutlass.DataType.f16: torch.half, - cutlass.DataType.f16: torch.float16, - cutlass.DataType.bf16: torch.bfloat16, - cutlass.DataType.f32: torch.float, - cutlass.DataType.f32: torch.float32, - cutlass.DataType.f64: torch.double, - cutlass.DataType.f64: torch.float64, - cutlass.DataType.s8: torch.int8, - cutlass.DataType.s32: torch.int32, - cutlass.DataType.u8: torch.uint8, - } -except ImportError: - torch_available = False - _torch_to_library_dict = {} - _library_to_torch_dict = {} +def is_torch_available(): + global torch_available, _library_to_torch_dict, _torch_to_library_dict + if torch_available is None: + try: + import torch + + torch_available = True + _torch_to_library_dict = { + torch.half: cutlass.DataType.f16, + torch.float16: cutlass.DataType.f16, + torch.bfloat16: cutlass.DataType.bf16, + torch.float: cutlass.DataType.f32, + torch.float32: cutlass.DataType.f32, + torch.double: cutlass.DataType.f64, + torch.float64: cutlass.DataType.f64, + torch.int8: cutlass.DataType.s8, + torch.int32: cutlass.DataType.s32, + torch.uint8: cutlass.DataType.u8, + } + + _library_to_torch_dict = { + cutlass.DataType.f16: torch.half, + cutlass.DataType.f16: torch.float16, + cutlass.DataType.bf16: torch.bfloat16, + cutlass.DataType.f32: torch.float, + cutlass.DataType.f32: torch.float32, + cutlass.DataType.f64: torch.double, + cutlass.DataType.f64: torch.float64, + cutlass.DataType.s8: torch.int8, + cutlass.DataType.s32: torch.int32, + cutlass.DataType.u8: torch.uint8, + } + except ImportError: + torch_available = False + _torch_to_library_dict = {} + _library_to_torch_dict = {} + return torch_available + + +def is_torch_tensor(inp) -> bool: + if is_torch_available(): + import torch + return isinstance(inp, torch.Tensor) + return False def torch_library_type(inp) -> cutlass.DataType: @@ -153,28 +198,35 @@ def torch_type(inp): return _library_to_torch_dict.get(inp, None) -try: - import bfloat16 +def is_bfloat16_available(): + global bfloat16_available + + if bfloat16_available is None: + try: + import bfloat16 - bfloat16_available = True -except ImportError: - bfloat16_available = False + bfloat16_available = True + except ImportError: + bfloat16_available = False + return bfloat16_available def bfloat16_library_type(inp) -> cutlass.DataType: - if bfloat16_available: + if is_bfloat16_available(): + import bfloat16 if inp == bfloat16.bfloat16: return cutlass.DataType.bf16 def bfloat16_type(inp): 
- if bfloat16_available: + if is_bfloat16_available(): + import bfloat16 if inp == cutlass.DataType.bf16: return bfloat16.bfloat16 def library_type(inp): - if inp in cutlass.DataTypeSize.keys(): + if inp in DataTypeSize: return inp for cvt_fn in [ @@ -205,23 +257,20 @@ def _tensor_from_torch(pt_tensor): def get_datatype_and_layout(tensor): - if (numpy_available and isinstance(tensor, np.ndarray)) or ( - cupy_available and isinstance(tensor, cp.ndarray) - ): + if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)): return _tensor_from_numpy(tensor) - elif torch_available and isinstance(tensor, torch.Tensor): + elif is_torch_tensor(tensor): return _tensor_from_torch(tensor) elif isinstance(tensor, float) or isinstance(tensor, int): return (cutlass.DataType.f32, cutlass.LayoutType.RowMajor) else: raise Exception(f"Unable to convert tensor of type {type(tensor)} to Python-bound CUTLASS datatype and layout.") + def get_tensor_shape(tensor, op="GEMM"): - if (numpy_available and isinstance(tensor, np.ndarray)) or ( - cupy_available and isinstance(tensor, cp.ndarray) - ): + if (is_numpy_tensor(tensor) or is_cupy_tensor(tensor)): return tensor.shape - elif torch_available and isinstance(tensor, torch.Tensor): + elif is_torch_tensor(tensor): size = tensor.size() if op == "CONV": # PyTorch Tensors have shape NCHW @@ -237,7 +286,7 @@ def get_tensor_shape(tensor, op="GEMM"): _math_operation_value_map = {x.value: x for x in MathOperation} -def backend_math_operation(math_op: cutlass.MathOperation): +def backend_math_operation(math_op: MathOperation): if math_op.value not in _math_operation_value_map.keys(): raise Exception(f"Unable to convert math operation of type {math_op} to backend math operation.") return _math_operation_value_map[math_op.value] diff --git a/python/cutlass/profiler/event_profiler.py b/python/cutlass/utils/profiler.py similarity index 99% rename from python/cutlass/profiler/event_profiler.py rename to python/cutlass/utils/profiler.py index 71f290c120..0843a0457e 100644 --- a/python/cutlass/profiler/event_profiler.py +++ b/python/cutlass/utils/profiler.py @@ -39,12 +39,12 @@ from cuda import cuda, cudart import numpy as np -import torch from cutlass import CUTLASS_PATH from cutlass.backend.library import DataTypeSize from cutlass.op.op import OperationBase from cutlass.shape import GemmCoord +from cutlass.utils.datatypes import is_numpy_tensor class GpuTimer: diff --git a/python/cutlass_library/__init__.py b/python/cutlass_library/__init__.py index dfc3154138..d948cff18a 100644 --- a/python/cutlass_library/__init__.py +++ b/python/cutlass_library/__init__.py @@ -30,6 +30,7 @@ # ################################################################################################# +import os import sys from . import conv2d_operation @@ -47,3 +48,16 @@ from . import rank_k_operation from . import symm_operation from . import trmm_operation + +# Make enum types from library.py accessible via cutlass_library.* +from .library import * + +# Set up `source` to point to the path containing the CUTLASS source. +# Check first if the path cotains a `source` subdirectory -- this will +# be the case when the package has been installed via pip. Otherwise, +# default to the root of CUTLASS. 
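# (For illustration: with the `from .library import *` re-export above, downstream code can
#  pull the library enums straight from the package, which is what the cutlass/*.py changes
#  in this patch rely on. `source_path` is assigned just below; which path it resolves to
#  depends on whether the package was installed via pip or is used from a source checkout.)
import cutlass_library
from cutlass_library import DataType, LayoutType

print(DataType.f16, LayoutType.RowMajor)  # enums re-exported from cutlass_library.library
print(cutlass_library.source_path)        # CUTLASS source location, set just below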
+install_source_path = os.path.join(__path__[0], 'source') +if os.path.isdir(install_source_path): + source_path = install_source_path +else: + source_path = os.path.join(__path__[0], '../..') diff --git a/python/cutlass_library/conv2d_operation.py b/python/cutlass_library/conv2d_operation.py index fcfcd24a76..60dc47b679 100644 --- a/python/cutlass_library/conv2d_operation.py +++ b/python/cutlass_library/conv2d_operation.py @@ -38,7 +38,13 @@ import os.path import shutil -from cutlass_library.library import * +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * ################################################################################################### @@ -62,11 +68,6 @@ def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, self.stride_support = stride_support self.swizzling_functor = swizzling_functor self.group_mode = group_mode - - # - def is_mixed_input(self): - return self.A.element != self.B.element - # def is_complex(self): complex_operators = [ @@ -75,6 +76,10 @@ def is_complex(self): ] return self.tile_description.math_instruction.math_operation in complex_operators + # + def is_mixed_input(self): + return self.A.element != self.B.element + # def accumulator_type(self): accum = self.tile_description.math_instruction.element_accumulator @@ -262,7 +267,7 @@ def __init__(self): 1, ${threadblock_output_shape_n}, ${threadblock_output_shape_p}, - ${threadblock_output_shape_q}>, + ${threadblock_output_shape_q}>, ${stages}, ${math_operator}, ${iterator_algorithm}, diff --git a/python/cutlass_library/conv3d_operation.py b/python/cutlass_library/conv3d_operation.py index 5ab1b900ae..10cb5e1446 100644 --- a/python/cutlass_library/conv3d_operation.py +++ b/python/cutlass_library/conv3d_operation.py @@ -38,7 +38,13 @@ import os.path import shutil -from cutlass_library.library import * +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * ################################################################################################### @@ -60,11 +66,11 @@ def __init__(self, conv_kind, iterator_algorithm, arch, tile_description, A, B, self.iterator_algorithm = iterator_algorithm self.stride_support = stride_support self.swizzling_functor = swizzling_functor - + # def is_mixed_input(self): return self.A.element != self.B.element - + # def core_name(self): ''' The basic operation kind is prefixed with a letter indicating the accumulation type. 
''' diff --git a/python/cutlass_library/gemm_operation.py b/python/cutlass_library/gemm_operation.py index ad62422c93..11691f42a7 100644 --- a/python/cutlass_library/gemm_operation.py +++ b/python/cutlass_library/gemm_operation.py @@ -34,14 +34,20 @@ Utilities for emitting GEMM kernels """ +import collections import enum -import os.path -import shutil import functools import operator -import collections +import os.path +import shutil -from cutlass_library.library import * +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * ################################################################################################### # @@ -55,9 +61,14 @@ class GemmOperation: def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, \ epilogue_functor = EpilogueFunctor.LinearCombination, swizzling_functor = SwizzlingFunctor.Identity8, D = None, kernel_schedule = KernelScheduleType.ScheduleAuto, epilogue_schedule = EpilogueScheduleType.ScheduleAuto, - tile_scheduler = TileSchedulerType.Default): + tile_scheduler = TileSchedulerType.Default, extra_args = None): - self.prefix = "3x" if gemm_kind == GemmKind.Universal3x else "" + kinds_3x = { + GemmKind.Universal3x, + GemmKind.SparseUniversal3x, + } + self.is_3x = gemm_kind in kinds_3x + self.prefix = "3x" if self.is_3x else "" self.operation_kind = OperationKind.Gemm self.arch = arch self.tile_description = tile_description @@ -66,10 +77,11 @@ def __init__(self, gemm_kind, arch, tile_description, A, B, C, element_epilogue, self.B = B self.C = C self.D = D + if self.D == None: self.D = self.C - if gemm_kind != GemmKind.Universal3x: + if not self.is_3x: assert(kernel_schedule == KernelScheduleType.ScheduleAuto) assert(epilogue_schedule == EpilogueScheduleType.ScheduleAuto) self.kernel_schedule = kernel_schedule @@ -91,7 +103,7 @@ def is_complex(self): # def is_mixed_input(self): return self.A.element != self.B.element - + # def is_planar_complex(self): return self.gemm_kind in (GemmKind.PlanarComplex, GemmKind.PlanarComplexArray) @@ -125,13 +137,20 @@ def core_name(self): MathOperation.and_popc: 'and' } - if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp or \ - self.tile_description.math_instruction.opcode_class == OpcodeClass.WmmaTensorOp: + tensor_ops = [ + OpcodeClass.TensorOp, + OpcodeClass.WmmaTensorOp, + OpcodeClass.SparseTensorOp, + ] + + is_tensor_op = self.tile_description.math_instruction.opcode_class in tensor_ops + + if is_tensor_op: math_op = self.tile_description.math_instruction.math_operation math_op_string = math_operations_map[math_op] if math_op in math_operations_map.keys() else '' - if self.gemm_kind == GemmKind.Universal3x: + if self.is_3x: inst_shape = "{0}x{1}x{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) else: inst_shape = "{0}{1}{2}".format(*tuple(self.tile_description.math_instruction.instruction_shape)) @@ -183,6 +202,16 @@ def extended_name_3x(self): core_name = self.core_name()) return extended_name + def datatype_name_3x(self): + '''Generates a string representing the MMA atom. 
Assumes accumulator type is C type.''' + datatype_name = "{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format( + element_a = DataTypeNames[self.A.element], + element_b = DataTypeNames[self.B.element], + element_acc = DataTypeNames[self.tile_description.math_instruction.element_accumulator], + element_c = DataTypeNames[self.C.element], + element_d = DataTypeNames[self.D.element]) + return datatype_name + # Generates a short string representing the AB layout tags (e.g. nt or tn) def layout_name(self): if self.is_complex() or self.is_planar_complex(): @@ -213,6 +242,10 @@ def kernel_schedule_name_3x(self): def epilogue_schedule_name_3x(self): return EpilogueScheduleSuffixes[self.epilogue_schedule] + # Generate a short string representing the operation class + def opcode_class_name(self): + return OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + # Generates the full kernel function name def procedural_name(self): ''' The full procedural name indicates architecture, extended name, tile size, and layout. ''' @@ -661,7 +694,6 @@ def emit(self, operation): ################################################################################################### -# class EmitGemmUniversal3xInstance: ''' Responsible for emitting a CUTLASS 3.x template definition''' @@ -687,10 +719,10 @@ def __init__(self, operation_suffix = ''): using ${operation_name}_epilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - ${arch}, ${opcode_class}, + ${arch}, ${opcode_class_epi}, cute::Shape, cute::Shape, - cutlass::epilogue::collective::EpilogueTileAuto, + ${epi_tile_mn}, ${element_accumulator}, ${element_epilogue}, ${element_c}, ${layout_c}, ${align_c}, ${element_d}, ${layout_d}, ${align_d}, @@ -699,7 +731,7 @@ def __init__(self, operation_suffix = ''): using ${operation_name}_mainloop = typename cutlass::gemm::collective::CollectiveBuilder< - ${arch}, ${opcode_class}, + ${arch}, ${opcode_class_main}, ${element_a}, ${layout_a}, ${align_a}, ${element_b}, ${layout_b}, ${align_b}, ${element_accumulator}, @@ -743,6 +775,10 @@ def emit(self, operation): stage_count_string = f"cutlass::gemm::collective::StageCountAutoCarveout" warp_shape = [tile_shape[idx] // warp_count[idx] for idx in range(3)] + epi_tile_mn = "cutlass::epilogue::collective::EpilogueTileAuto" + opcode_class_main = operation.tile_description.math_instruction.opcode_class + opcode_class_epi = opcode_class_main + instance_layout_A, instance_layout_B, instance_layout_C , instance_layout_D = \ (operation.A.layout, operation.B.layout, operation.C.layout, operation.D.layout) @@ -760,20 +796,23 @@ def emit(self, operation): else: epilogue_functor = self.epilogue_functor.emit_declaration() # - + element_a = DataTypeTag[operation.A.element] + element_b = DataTypeTag[operation.B.element] + epilogue_schedule_type = EpilogueScheduleTag[operation.epilogue_schedule] values = { 'operation_name': operation.procedural_name(), 'operation_suffix': self.operation_suffix, - 'element_a': DataTypeTag[operation.A.element], + 'element_a': element_a, 'layout_a': LayoutTag[instance_layout_A], - 'element_b': DataTypeTag[operation.B.element], + 'element_b': element_b, 'layout_b': LayoutTag[instance_layout_B], 'element_c': DataTypeTag[operation.C.element], 'layout_c': LayoutTag[instance_layout_C], 'element_d': DataTypeTag[operation.D.element], 'layout_d': LayoutTag[instance_layout_D], 'element_accumulator': DataTypeTag[operation.accumulator_type()], - 'opcode_class': 
OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], + 'opcode_class_main': OpcodeClassTag[opcode_class_main], + 'opcode_class_epi': OpcodeClassTag[opcode_class_epi], 'arch': "cutlass::arch::Sm%d" % operation.arch, 'tile_shape_m': str(operation.tile_description.tile_shape[0]), 'tile_shape_n': str(operation.tile_description.tile_shape[1]), @@ -788,7 +827,8 @@ def emit(self, operation): 'instruction_shape_n': str(operation.tile_description.math_instruction.instruction_shape[1]), 'instruction_shape_k': str(operation.tile_description.math_instruction.instruction_shape[2]), 'kernel_schedule' : str(KernelScheduleTag[operation.kernel_schedule]), - 'epilogue_schedule' : str(EpilogueScheduleTag[operation.epilogue_schedule]), + 'epilogue_schedule' : str(epilogue_schedule_type), + 'epi_tile_mn' : epi_tile_mn, 'epilogue_functor': epilogue_functor, 'stages': stage_count_string, 'align_a': str(operation.A.alignment), @@ -800,7 +840,7 @@ def emit(self, operation): 'math_operation': MathOperationTag[operation.tile_description.math_instruction.math_operation], 'epilogue_vector_length': str(epilogue_vector_length), 'element_epilogue': str(DataTypeTag[operation.element_epilogue]), - 'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]) + 'tile_scheduler': str(TileSchedulerTag[operation.tile_scheduler]), } return SubstituteTemplate(self.gemm_template, values) diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index 099c4271b7..1f07e76bc2 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -34,16 +34,52 @@ Utilities for enumerating CUTLASS library kernels """ +import argparse import enum +from itertools import product +import logging import os.path import shutil -import argparse -import logging - -from cutlass_library.library import * -from cutlass_library.manifest import * -from itertools import product +import sys + + +# Certain usecases of cutlass_library nearly always prefer to run as scripts with +# relative imports, rather than via an installed Python package. An example of this +# is using CUTLASS's CMake system to generate a library of kernels to be profiled. +# To make it easy to use these use cases when an existing installation of cutlass_library +# exists, this global flag can be set to true (via command-line arguments) to ensure +# that package-based installations are not used. + +# Create a temporary argument parser to check only for the availability of the +# --disable-cutlass-package-imports argument, which controls whether package-based +# imports are disabled. +def _add_package_disablement_flag(argparser): + argparser.add_argument("--disable-cutlass-package-imports", action='store_true', required=False, + help="Disable use of cutlass_library from Python package") + +_parser = argparse.ArgumentParser() +_add_package_disablement_flag(_parser) +_args, _ = _parser.parse_known_args() + +# Add `CUTLASS_IGNORE_PACKAGE` to `builtins` so that it is visible for gating future +# imports without requiring importing another module. Ideally, we would just place this +# as a global variable in a module to that could be imported and checked (e.g., +# utils.CUTLASS_IGNORE_PACKAGE). However, this raises the issue of determining +# where this module should be sourced (from the cutlass_library package or from +# a relative import), which is the problem this variable is being used to solve in the +# first place. 
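# (For illustration: the consumer-side idiom this flag enables, used by the emitter modules
#  elsewhere in this patch -- e.g. conv2d_operation.py and manifest.py. When the flag is set,
#  the package import is deliberately failed so that the plain relative import is used instead:
#
#      try:
#          import builtins
#          if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True:
#              raise ImportError("Disabling attempt to import cutlass_library")
#          from cutlass_library.library import *
#      except ImportError:
#          from library import *
# )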
+import builtins +builtins.CUTLASS_IGNORE_PACKAGE = _args.disable_cutlass_package_imports + +try: + if CUTLASS_IGNORE_PACKAGE: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.manifest import * +except ImportError: + from library import * + from manifest import * ################################################################################################### # @@ -79,7 +115,7 @@ def product(X, identity = 1): return min(max_alignment, elements_per_thread) def DefaultSwizzlingFunctor(): - return SwizzlingFunctor.Identity8; + return SwizzlingFunctor.Identity8 # To use StreamK decomposition for basic GEMMs, set `swizzling_functor = SwizzlingFunctor.StreamK` # @@ -103,7 +139,7 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ for tile_description in tile_descriptions: for alignment in alignment_constraints: for complex_transform in complex_transforms: - + # If alignment is a tuple or a list, then we have different alignments for A and B alignment_a = alignment if isinstance(alignment, int) else alignment[0] alignment_b = alignment if isinstance(alignment, int) else alignment[1] @@ -121,7 +157,6 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ return operations - # Generates 3.0 API based GemmUniversal API kernels. Alignment constraints are folded in with layouts def CreateGemmUniversal3xOperator( manifest, layouts, tile_descriptions, data_types, @@ -157,11 +192,14 @@ def CreateGemmUniversal3xOperator( C = TensorDescription(data_type["c_type"], layout[2][0], layout[2][1]) D = TensorDescription(data_type["d_type"], layout[2][0], layout[2][1]) + extra_args = {} + gemm_kind = GemmKind.Universal3x element_compute = data_type.get("epi_type", data_type["acc_type"]) + operation = GemmOperation( - GemmKind.Universal3x, tile_description.minimum_compute_capability, + gemm_kind, tile_description.minimum_compute_capability, tile_description, A, B, C, element_compute, epilogue_functor, swizzling_functor, D, - kernel_schedule, epilogue_schedule, tile_scheduler) + kernel_schedule, epilogue_schedule, tile_scheduler, extra_args) manifest.append(operation) operations.append(operation) @@ -2153,7 +2191,6 @@ def GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version): CreateGemmPlanarComplexOperator(manifest, layouts, tile_descriptions, \ data_type_mixed, alignment_constraints, complex_transforms) - # def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version): @@ -2225,8 +2262,9 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version): math_inst.element_accumulator, ] + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. 
F16 accumulation) if math_inst.element_a != math_inst.element_accumulator: @@ -2239,14 +2277,13 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version): ] operations += CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints) - + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + for op in operations: if (DataTypeSize[op.C.element] == 16) and \ (op.tile_description.threadblock_shape[1] <= 32): op.C.alignment = 4 - # def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): @@ -2287,8 +2324,7 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): # inner list contains the alignment constraints for operands/matrices # [[alignA, alignB, alignC],..] alignment_constraints = [[8, 16, 8],] - - + for math_inst in math_instructions: tile_descriptions = [ # 128x128 @@ -2321,8 +2357,9 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): math_inst.element_accumulator, ] + # streamk uses more regs which can cause spill for the biggest warp tile size when the accumulators are 32bit. CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type, alignment_constraints) + data_type, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) if math_inst.element_a != math_inst.element_accumulator: @@ -2335,12 +2372,12 @@ def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): ] operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ - data_type_mixed, alignment_constraints) - + data_type_mixed, alignment_constraints, None, EpilogueFunctor.LinearCombination, SwizzlingFunctor.Identity8) + for op in operations: if op.tile_description.threadblock_shape[1] <= 32: op.C.alignment = 4 - + # def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): @@ -2723,6 +2760,7 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, cuda_version): for op in operations: op.C.alignment = 16 +# # def GenerateSM80_TensorOp_168256(manifest, cuda_version): @@ -4458,6 +4496,154 @@ def GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version): [[KernelScheduleType.TmaWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]], tile_schedulers=[TileSchedulerType.StreamK]) +# +def GenerateSM90_TensorOp_16b_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. 
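# (For illustration: the "alignx" kernels generated below target inputs whose guaranteed
#  alignment is below the 16 bytes that the TMA-based kernels require, which is why their
#  schedules are cp.async based. A quick check of the A/B alignments used in `layouts` below,
#  assuming 16-bit (f16/bf16) operands:)
for align_elements in (8, 4, 2):
    align_bytes = align_elements * 16 // 8
    print(align_elements, align_bytes, "OK for TMA" if align_bytes >= 16 else "needs cp.async")
# -> 8 elements = 16 B (covered by the regular kernels); 4 and 2 elements = 8 B / 4 B (covered here)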
+ layouts = [ + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 4], [LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = [ + MathInstruction( + [64, 128, 16], + DataType.f16, DataType.f16, DataType.f16, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 16], + DataType.f16, DataType.f16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 16], + DataType.bf16, DataType.bf16, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + + min_cc = 90 + max_cc = 90 + + for math_inst in math_instructions: + tile_descriptions_small = [ + # TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + # 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + tile_descriptions_medium = [ + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + # TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1]*2, math_inst.instruction_shape[2]*4], + # 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + tile_descriptions = tile_descriptions_small + tile_descriptions_medium + + data_type = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + # Set alignment c based on Destination format. 
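# (For illustration: the hard-coded values in the loop below are simply 128-bit (16 B)
#  accesses expressed in elements of the destination type, i.e. 128 // DataTypeSize[c_type]:)
for c_bits, c_name in ((32, "f32 / s32"), (16, "f16 / bf16")):
    print(c_name, 128 // c_bits)   # -> 4 elements for 32-bit C/D, 8 elements for 16-bit C/D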
+ for layout in layouts: + if data_type["c_type"] in [DataType.s32, DataType.f32]: + layout[2][1] = 4 + elif data_type["c_type"] in [DataType.f16, DataType.bf16]: + layout[2][1] = 8 + + schedules = [ + # [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto], + [KernelScheduleType.CpAsyncWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized] + ] + stream_k_schedules = [] + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + schedules += [ + [KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized], + # [KernelScheduleType.CpAsyncWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized] + ] + stream_k_schedules += [[KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]] + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules) + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + # Add stream-K variants + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK]) + + # persistent kernels with TMA epilogues + # if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + # CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + # [[KernelScheduleType.CpAsyncWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], + # [KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + + # CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + # [[KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]], + # tile_schedulers=[TileSchedulerType.StreamK]) + + # # Emit instance without C allocation + load + # data_type["c_type"] = DataType.void + # CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + # [[KernelScheduleType.CpAsyncWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], + # [KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + + # CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + # [[KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]], + # tile_schedulers=[TileSchedulerType.StreamK]) + + # for mixed precision kernels, also generate kernels that write output matrix in the A/B format + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + data_type_mixed = { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_a, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + } + + # Set alignment c based on Destination format. 
+ for layout in layouts: + if data_type_mixed["c_type"] in [DataType.s32, DataType.f32]: + layout[2][1] = 4 + elif data_type_mixed["c_type"] in [DataType.f16, DataType.bf16]: + layout[2][1] = 8 + + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, schedules) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK]) + + # persistent kernels with TMA epilogues + # if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + # CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, + # [[KernelScheduleType.CpAsyncWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], + # [KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + + # CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, + # [[KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]], + # tile_schedulers=[TileSchedulerType.StreamK]) + + # # Emit instance without C allocation+load + # data_type_mixed["c_type"] = DataType.void + # CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, + # [[KernelScheduleType.CpAsyncWarpSpecializedPingpong, EpilogueScheduleType.TmaWarpSpecialized], + # [KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + + # CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type_mixed, + # [[KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]], + # tile_schedulers=[TileSchedulerType.StreamK]) + # def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): @@ -4582,6 +4768,91 @@ def GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version): CreateGemmUniversal3xOperator(manifest, layouts_tf32_tn_nn_nt, tile_descriptions, data_types, schedules_default) CreateGemmUniversal3xOperator(manifest, layouts_tf32_tt, tile_descriptions, data_types, schedules_transposed_epilogue) +# +def GenerateSM90_TensorOp_tf32_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments. 
+ layouts = [ + [[LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 2], [LayoutType.RowMajor, 2], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 1], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1], [LayoutType.ColumnMajor, 1]], + [[LayoutType.ColumnMajor, 1], [LayoutType.RowMajor, 1], [LayoutType.ColumnMajor, 1]], + ] + + math_inst = MathInstruction( + [64, 128, 8], + DataType.tf32, DataType.tf32, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add) + + min_cc = 90 + max_cc = 90 + + tile_descriptions_medium = [ + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]) + ] + + tile_descriptions_small = [ + # TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + # 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]) + ] + + tile_descriptions = tile_descriptions_medium + tile_descriptions_small + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : DataType.f32, + "b_type" : DataType.f32, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + } + ] + + is_tt_layout = lambda v: v[0][0] == LayoutType.RowMajor and v[1][0] == LayoutType.RowMajor + # Split kernels into TN/NT, NN or TT layouts + layouts_tn_nn_nt = filter(lambda v: not is_tt_layout(v), layouts) + layouts_tt = filter(is_tt_layout, layouts) + + CreateGemmUniversal3xOperator(manifest, layouts_tn_nn_nt, tile_descriptions, data_types, [ + # [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto], + [KernelScheduleType.CpAsyncWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized], + ]) + + # Kernels with TT layout use EpilogueTransposed (NoSmemWarpSpecialized with swapped strides), + # because they use NN kernels underneath and transposing its epilogue will get the correct output + CreateGemmUniversal3xOperator(manifest, layouts_tt, tile_descriptions, data_types, [ + # [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.EpilogueTransposed], + [KernelScheduleType.CpAsyncWarpSpecialized, EpilogueScheduleType.EpilogueTransposed] + ]) + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + CreateGemmUniversal3xOperator(manifest, layouts_tn_nn_nt, tile_descriptions, data_types, [ + # [KernelScheduleType.CpAsyncWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized], + [KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized] + ]) + + # Stream-K schedules + CreateGemmUniversal3xOperator(manifest, layouts_tn_nn_nt, tile_descriptions, data_types, [ + [KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized] + ], tile_schedulers=[TileSchedulerType.StreamK]) + # def 
GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): @@ -4677,6 +4948,81 @@ def GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version): tile_schedulers=[TileSchedulerType.Persistent, TileSchedulerType.StreamK] ) +# +def GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments + layouts = [ + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], + ] + + math_instructions = [ + MathInstruction( + [64, 128, 32], + DataType.s8, DataType.s8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 32], + DataType.u8, DataType.u8, DataType.s32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + ] + + min_cc = 90 + max_cc = 90 + + for math_inst in math_instructions: + tile_descriptions_small = [ + # TileDescription([math_inst.instruction_shape[0], math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + # 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + tile_descriptions_medium = [ + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + tile_descriptions = tile_descriptions_medium + tile_descriptions_small + + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : math_inst.element_accumulator, + "d_type" : math_inst.element_accumulator, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.s8, + "d_type" : math_inst.element_a, + "acc_type" : math_inst.element_accumulator, + "epi_type" : DataType.f32 + } + ] + + for data_type in data_types: + for layout in layouts: + layout[2][1] = 128 // DataTypeSize[data_type["d_type"]] + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, [ + # [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.NoSmemWarpSpecialized], + [KernelScheduleType.CpAsyncWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized] + ]) + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, [ + # [KernelScheduleType.CpAsyncWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized], + [KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized] + ]) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + [[KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]], + tile_schedulers=[TileSchedulerType.StreamK]) + +# def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): return @@ -4882,6 +5228,188 @@ def GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version): [KernelScheduleType.TmaWarpSpecializedCooperativeFP8FastAccum, EpilogueScheduleType.TmaWarpSpecializedCooperative]], tile_schedulers=[TileSchedulerType.StreamK]) +# +def GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version): + if not CudaToolkitVersionSatisfies(cuda_version, 12, 0): + return + + # layouts for ABC and their alignments + layouts 
= [ + [[LayoutType.RowMajor, 8], [LayoutType.ColumnMajor, 8], [LayoutType.ColumnMajor, 1]], # TN Layout + [[LayoutType.RowMajor, 4], [LayoutType.ColumnMajor, 4], [LayoutType.ColumnMajor, 1]], # TN Layout + ] + + math_instructions = [ + # inst 64x128x32 + MathInstruction( + [64, 128, 32], + DataType.e4m3, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 32], + DataType.e4m3, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 32], + DataType.e5m2, DataType.e4m3, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + MathInstruction( + [64, 128, 32], + DataType.e5m2, DataType.e5m2, DataType.f32, + OpcodeClass.TensorOp, + MathOperation.multiply_add), + # inst 64x64x32 + # MathInstruction( + # [64, 64, 32], + # DataType.e4m3, DataType.e4m3, DataType.f32, + # OpcodeClass.TensorOp, + # MathOperation.multiply_add), + # MathInstruction( + # [64, 64, 32], + # DataType.e4m3, DataType.e5m2, DataType.f32, + # OpcodeClass.TensorOp, + # MathOperation.multiply_add), + # MathInstruction( + # [64, 64, 32], + # DataType.e5m2, DataType.e4m3, DataType.f32, + # OpcodeClass.TensorOp, + # MathOperation.multiply_add), + # MathInstruction( + # [64, 64, 32], + # DataType.e5m2, DataType.e5m2, DataType.f32, + # OpcodeClass.TensorOp, + # MathOperation.multiply_add), + ] + + min_cc = 90 + max_cc = 90 + + for math_inst in math_instructions: + data_types = [ + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.f32, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f32, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.bf16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.bf16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.f16, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e4m3, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + { + "a_type" : math_inst.element_a, + "b_type" : math_inst.element_b, + "c_type" : DataType.f16, + "d_type" : DataType.e5m2, + "acc_type" : math_inst.element_accumulator, + "epi_type" : math_inst.element_accumulator + }, + ] + + if math_inst.instruction_shape[1] == 128: + 
tile_descriptions = [ + # 128x128x128 + TileDescription([math_inst.instruction_shape[0]*2, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + ] + + # elif math_inst.instruction_shape[1] == 64: + # tile_descriptions = [ + # # 256x64x128 + # TileDescription([math_inst.instruction_shape[0]*4, math_inst.instruction_shape[1], math_inst.instruction_shape[2]*4], + # 0, [4, 1, 1], math_inst, min_cc, max_cc, [1,1,1]), + # ] + + else: + assert False, "math inst is not supported" + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + schedules = [ + # [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto], + [KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized], + # [KernelScheduleType.CpAsyncWarpSpecializedPingpong, EpilogueScheduleType.NoSmemWarpSpecialized], + [KernelScheduleType.CpAsyncWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized], + ] + stream_k_schedules = [[KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.NoSmemWarpSpecialized]] + else: + schedules = [ + # [KernelScheduleType.ScheduleAuto, EpilogueScheduleType.ScheduleAuto], + [KernelScheduleType.CpAsyncWarpSpecialized, EpilogueScheduleType.NoSmemWarpSpecialized] + ] + stream_k_schedules = [] + + + for data_type in data_types: + # With No-SMEM epilogues + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, schedules) + + if CudaToolkitVersionSatisfies(cuda_version, 12, 1): + # Persistent kernels with TMA epilogues + # CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + # [[KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]]) + + # Add stream-K variants (with and without TMA epilogues) + CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, stream_k_schedules, tile_schedulers=[TileSchedulerType.StreamK]) + # CreateGemmUniversal3xOperator(manifest, layouts, tile_descriptions, data_type, + # [[KernelScheduleType.CpAsyncWarpSpecializedCooperative, EpilogueScheduleType.TmaWarpSpecializedCooperative]], + # tile_schedulers=[TileSchedulerType.StreamK]) # def GenerateSM90_TensorOp_1684(manifest, cuda_version): @@ -5488,9 +6016,13 @@ def GenerateSM90_TensorOp_1684_symm_complex_gaussian(manifest, cuda_version): # def GenerateSM90(manifest, cuda_version): GenerateSM90_TensorOp_16b_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_16b_WGMMA_alignx_gemm(manifest, cuda_version) GenerateSM90_TensorOp_tf32_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_tf32_WGMMA_alignx_gemm(manifest, cuda_version) GenerateSM90_TensorOp_int8_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_int8_WGMMA_alignx_gemm(manifest, cuda_version) GenerateSM90_TensorOp_fp8_WGMMA_gemm(manifest, cuda_version) + GenerateSM90_TensorOp_fp8_WGMMA_alignx_gemm(manifest, cuda_version) GenerateSM90_TensorOp_1684(manifest, cuda_version) GenerateSM90_TensorOp_1684_complex(manifest, cuda_version) GenerateSM90_TensorOp_1684_complex_gaussian(manifest, cuda_version) @@ -5543,6 +6075,7 @@ def define_parser(): parser.add_argument("--disable-full-archs-compilation", action="store_true", required=False, help="Disable compilation for every archs in --architectures") parser.add_argument("--log-level", default='info', type=numeric_log_level, required=False, help='Logging level to be used by the generator script') + _add_package_disablement_flag(parser) return parser diff 
--git a/python/cutlass_library/library.py b/python/cutlass_library/library.py index 66c7f940b6..c0c425c28e 100644 --- a/python/cutlass_library/library.py +++ b/python/cutlass_library/library.py @@ -400,6 +400,9 @@ class LayoutType(enum.Enum): class KernelScheduleType(enum.Enum): ScheduleAuto = enum_auto() Multistage = enum_auto() + CpAsyncWarpSpecialized = enum_auto() + CpAsyncWarpSpecializedPingpong = enum_auto() + CpAsyncWarpSpecializedCooperative = enum_auto() Tma = enum_auto() TmaWarpSpecialized = enum_auto() TmaWarpSpecializedPingpong = enum_auto() @@ -411,6 +414,9 @@ class KernelScheduleType(enum.Enum): KernelScheduleTag = { KernelScheduleType.ScheduleAuto: 'cutlass::gemm::collective::KernelScheduleAuto', KernelScheduleType.Multistage: 'cutlass::gemm::KernelMultistage', + KernelScheduleType.CpAsyncWarpSpecialized: 'cutlass::gemm::KernelCpAsyncWarpSpecialized', + KernelScheduleType.CpAsyncWarpSpecializedPingpong: 'cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong', + KernelScheduleType.CpAsyncWarpSpecializedCooperative: 'cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative', KernelScheduleType.Tma: 'cutlass::gemm::KernelTma', KernelScheduleType.TmaWarpSpecialized: 'cutlass::gemm::KernelTmaWarpSpecialized', KernelScheduleType.TmaWarpSpecializedPingpong: 'cutlass::gemm::KernelTmaWarpSpecializedPingpong', @@ -424,6 +430,9 @@ class KernelScheduleType(enum.Enum): KernelScheduleSuffixes = { KernelScheduleType.ScheduleAuto: '', KernelScheduleType.Multistage: '_cpasync', + KernelScheduleType.CpAsyncWarpSpecialized: '_cpasync_warpspecialized', + KernelScheduleType.CpAsyncWarpSpecializedPingpong: '_cpasync_warpspecialized_pingpong', + KernelScheduleType.CpAsyncWarpSpecializedCooperative: '_cpasync_warpspecialized_cooperative', KernelScheduleType.Tma: '_unspecialized', KernelScheduleType.TmaWarpSpecialized: '_warpspecialized', KernelScheduleType.TmaWarpSpecializedPingpong: '_warpspecialized_pingpong', @@ -541,7 +550,6 @@ class OpcodeClass(enum.Enum): WmmaTensorOp = enum_auto() SparseTensorOp = enum_auto() - OpcodeClassNames = { OpcodeClass.Simt: 'simt', OpcodeClass.TensorOp: 'tensorop', @@ -628,19 +636,20 @@ class GemmKind(enum.Enum): Sparse = enum_auto() Universal = enum_auto() Universal3x = enum_auto() + SparseUniversal3x = enum_auto() PlanarComplex = enum_auto() PlanarComplexArray = enum_auto() Grouped = enum_auto() - # GemmKindNames = { GemmKind.Gemm: "gemm", GemmKind.Sparse: "spgemm", GemmKind.Universal: "gemm", GemmKind.Universal3x: "gemm", + GemmKind.SparseUniversal3x: "spgemm", GemmKind.PlanarComplex: "gemm_planar_complex", GemmKind.PlanarComplexArray: "gemm_planar_complex_array", - GemmKind.Grouped: "gemm_grouped" + GemmKind.Grouped: "gemm_grouped", } # @@ -797,7 +806,7 @@ class GroupMode(enum.Enum): NoneGroup = enum_auto() # dense conv (G=1) SingleGroup = enum_auto() # grouped convolution (single group per CTA) MultipleGroup = enum_auto() # grouped convolution ( multiple groups per CTA) - Depthwise = enum_auto() # Depthwise convolution ( C=K=G ) + Depthwise = enum_auto() # Depthwise convolution ( C=K=G ) # GroupModeTag = { @@ -818,14 +827,18 @@ class GroupMode(enum.Enum): # class MathInstruction: - def __init__(self, instruction_shape, element_a, element_b, element_accumulator, opcode_class, math_operation = MathOperation.multiply_add): + def __init__(self, + instruction_shape, \ + element_a, element_b, element_accumulator, \ + opcode_class, math_operation = MathOperation.multiply_add \ + ): + self.instruction_shape = instruction_shape self.element_a = element_a self.element_b = 
element_b self.element_accumulator = element_accumulator self.opcode_class = opcode_class self.math_operation = math_operation - # class TileDescription: diff --git a/python/cutlass_library/manifest.py b/python/cutlass_library/manifest.py index 07427d6a88..3e3b477f41 100644 --- a/python/cutlass_library/manifest.py +++ b/python/cutlass_library/manifest.py @@ -36,18 +36,31 @@ """ import enum +import logging import os.path import shutil -from cutlass_library.library import * -from cutlass_library.gemm_operation import * -from cutlass_library.rank_k_operation import * -from cutlass_library.rank_2k_operation import * -from cutlass_library.trmm_operation import * -from cutlass_library.symm_operation import * -from cutlass_library.conv2d_operation import * -from cutlass_library.conv3d_operation import * -import logging +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * + from cutlass_library.gemm_operation import * + from cutlass_library.rank_k_operation import * + from cutlass_library.rank_2k_operation import * + from cutlass_library.trmm_operation import * + from cutlass_library.symm_operation import * + from cutlass_library.conv2d_operation import * + from cutlass_library.conv3d_operation import * +except ImportError: + from library import * + from gemm_operation import * + from rank_k_operation import * + from rank_2k_operation import * + from trmm_operation import * + from symm_operation import * + from conv2d_operation import * + from conv3d_operation import * ################################################################################################### _LOGGER = logging.getLogger(__name__) @@ -380,7 +393,6 @@ def __init__(self, args = None): architectures = args.architectures.split(';') if len(args.architectures) else ['50',] architectures = [x if x != '90a' else '90' for x in architectures] - self.compute_capabilities = [int(x) for x in architectures] if args.filter_by_cc in ['false', 'False', '0']: diff --git a/python/cutlass_library/rank_2k_operation.py b/python/cutlass_library/rank_2k_operation.py index dfa5f07068..977a169e8b 100644 --- a/python/cutlass_library/rank_2k_operation.py +++ b/python/cutlass_library/rank_2k_operation.py @@ -35,12 +35,18 @@ """ import enum -import os.path -import shutil import functools import operator +import os.path +import shutil -from cutlass_library.library import * +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * ################################################################################################### @@ -82,7 +88,7 @@ def is_complex(self): # def is_mixed_input(self): return self.A.element != self.B.element - + # def is_planar_complex(self): return False @@ -234,7 +240,7 @@ def __init__(self): """ self.rank_k_complex_template = """ // Rank K operator ${operation_name} -using Operation_${operation_name} = +using Operation_${operation_name} = typename cutlass::gemm::device::Rank2K< ${element_a}, ${layout_a}, ${element_b}, ${layout_b}, diff --git a/python/cutlass_library/rank_k_operation.py b/python/cutlass_library/rank_k_operation.py index 5868d20deb..91f9b15b35 100644 --- a/python/cutlass_library/rank_k_operation.py +++ b/python/cutlass_library/rank_k_operation.py @@ 
-35,12 +35,18 @@ """ import enum -import os.path -import shutil import functools import operator +import os.path +import shutil -from cutlass_library.library import * +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * ################################################################################################### @@ -80,7 +86,7 @@ def is_complex(self): # def is_mixed_input(self): return False - + # def is_planar_complex(self): return False @@ -259,7 +265,7 @@ def __init__(self): def emit(self, operation): threadblock_shape = operation.tile_description.threadblock_shape - + warp_count = operation.tile_description.warp_count warp_shape = [threadblock_shape[idx] // warp_count[idx] for idx in range(3)] diff --git a/python/cutlass_library/symm_operation.py b/python/cutlass_library/symm_operation.py index e97245b19d..2b8f83333a 100644 --- a/python/cutlass_library/symm_operation.py +++ b/python/cutlass_library/symm_operation.py @@ -35,12 +35,18 @@ """ import enum -import os.path -import shutil import functools import operator +import os.path +import shutil -from cutlass_library.library import * +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * ################################################################################################### @@ -82,7 +88,7 @@ def is_complex(self): # def is_mixed_input(self): return self.A.element != self.B.element - + # def is_planar_complex(self): return False @@ -241,7 +247,7 @@ def __init__(self): // Symm operator ${operation_name} using Operation_${operation_name} = typename cutlass::gemm::device::Symm< - ${element_a}, ${layout_a}, ${side_mode}, ${fill_mode}, + ${element_a}, ${layout_a}, ${side_mode}, ${fill_mode}, ${element_b}, ${layout_b}, ${element_c}, ${layout_c}, ${element_accumulator}, diff --git a/python/cutlass_library/trmm_operation.py b/python/cutlass_library/trmm_operation.py index fe2c1f9365..7be18a122f 100644 --- a/python/cutlass_library/trmm_operation.py +++ b/python/cutlass_library/trmm_operation.py @@ -35,12 +35,18 @@ """ import enum -import os.path -import shutil import functools import operator +import os.path +import shutil -from cutlass_library.library import * +try: + import builtins + if hasattr(builtins, "CUTLASS_IGNORE_PACKAGE") and CUTLASS_IGNORE_PACKAGE == True: + raise ImportError("Disabling attempt to import cutlass_library") + from cutlass_library.library import * +except ImportError: + from library import * ################################################################################################### @@ -84,7 +90,7 @@ def is_planar_complex(self): # def is_mixed_input(self): return self.A.element != self.B.element - + # def accumulator_type(self): accum = self.tile_description.math_instruction.element_accumulator diff --git a/python/docker/Dockerfile-cuda12.1-pytorch b/python/docker/Dockerfile-cuda12.1-pytorch deleted file mode 100644 index 884472f5ae..0000000000 --- a/python/docker/Dockerfile-cuda12.1-pytorch +++ /dev/null @@ -1,38 +0,0 @@ -################################################################################################# -# -# Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.
-# SPDX-License-Identifier: BSD-3-Clause
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions are met:
-#
-# 1. Redistributions of source code must retain the above copyright notice, this
-# list of conditions and the following disclaimer.
-#
-# 2. Redistributions in binary form must reproduce the above copyright notice,
-# this list of conditions and the following disclaimer in the documentation
-# and/or other materials provided with the distribution.
-#
-# 3. Neither the name of the copyright holder nor the names of its
-# contributors may be used to endorse or promote products derived from
-# this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
-# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
-# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
-# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
-# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
-# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
-# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
-# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-#
-#################################################################################################
-
-FROM nvcr.io/nvidia/pytorch:23.03-py3
-
-RUN chmod ugo+rwx /home
-ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH
-ENV LIBRARY_PATH=/usr/local/cuda/lib64:$LIBRARY_PATH
-ENV CUDA_INSTALL_PATH=/usr/local/cuda
diff --git a/python/docs_src/source/install.md b/python/docs_src/source/install.md
index 4b5da10517..5d30740dd4 100644
--- a/python/docs_src/source/install.md
+++ b/python/docs_src/source/install.md
@@ -9,28 +9,25 @@ Prior to installing the CUTLASS Python interface, one may optionally set the fol
* `CUDA_INSTALL_PATH`: the path to the installation of CUDA
If these environment variables are not set, the installation process will infer them to be the following:
-* `CUTLASS_PATH`: one directory level above the current directory (i.e., `$(pwd)/..`)
+* `CUTLASS_PATH`: either one directory level above the current directory (i.e., `$(pwd)/..`) if installed locally or in the `source` directory of the location in which `cutlass_library` was installed
* `CUDA_INSTALL_PATH`: the directory holding `/bin/nvcc` for the first version of `nvcc` on `$PATH` (i.e., `which nvcc | awk -F'/bin/nvcc' '{print $1}'`)
**NOTE:** The version of `cuda-python` installed must match the CUDA version in `CUDA_INSTALL_PATH`.
### Installing a developer-mode package
-The CUTLASS Python interface can currently be installed via:
+The CUTLASS Python interface can currently be installed by navigating to the root of the CUTLASS directory and performing
```bash
-python setup.py develop --user
+pip install .
```
-This will allow changes to the Python interface source to be reflected when using the Python interface.
-We plan to add support for installing via `python setup.py install` in a future release.
+If you would like to be able to make changes to the CUTLASS Python interface and have them reflected when using the interface, perform:
+```bash
+pip install -e .
+``` ## Docker -To ensure that you have all of the necessary Python modules for running the examples using the -CUTLASS Python interface, we recommend using one of the Docker images located in the docker directory. +We recommend using the CUTLASS Python interface via an [NGC PyTorch Docker container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch): -For example, to build and launch a container that uses CUDA 12.1 via an NGC PyTorch container, run: ```bash -docker build -t cutlass-cuda12.1:latest -f docker/Dockerfile-cuda12.1-pytorch . -docker run --gpus all -it --rm cutlass-cuda12.1:latest +docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.08-py3 ``` - -The CUTLASS Python interface has been tested with CUDA 11.8, 12.0, and 12.1 on Python 3.8.10 and 3.9.7. diff --git a/python/setup.py b/python/setup_cutlass.py similarity index 99% rename from python/setup.py rename to python/setup_cutlass.py index 6ff46c4419..cf57b223cc 100644 --- a/python/setup.py +++ b/python/setup_cutlass.py @@ -51,7 +51,7 @@ setup( name='cutlass', - version='3.2.1', + version='3.3.0', description='CUTLASS Pythonic Interface', package_dir={'': '.'}, packages=[ diff --git a/python/setup_library.py b/python/setup_library.py index 2aff4e0c5c..17905f40fa 100644 --- a/python/setup_library.py +++ b/python/setup_library.py @@ -36,7 +36,7 @@ def perform_setup(): setup( name='cutlass_library', - version='3.2.1', + version='3.3.0', description='CUTLASS library generation scripts', packages=['cutlass_library'] ) diff --git a/python/setup_pycute.py b/python/setup_pycute.py index ab7881f067..316dbc8880 100644 --- a/python/setup_pycute.py +++ b/python/setup_pycute.py @@ -36,7 +36,7 @@ def perform_setup(): setup( name='pycute', - version='3.2.1', + version='3.3.0', description='Python implementation of CuTe', packages=['pycute'], ) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000..99b7709629 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,30 @@ +[metadata] +name = cutlass +version = 3.3.0.0 + +[options] +packages = + cutlass + cutlass.backend + cutlass.backend.evt + cutlass.backend.evt.backend + cutlass.backend.evt.frontend + cutlass.backend.evt.ir + cutlass.backend.evt.passes + cutlass.backend.utils + cutlass.emit + cutlass.epilogue + cutlass.op + cutlass.utils + cutlass_library + cutlass_library.source + pycute +package_dir = + cutlass=python/cutlass + cutlass_library=python/cutlass_library + cutlass_library.source=. + pycute=python/pycute +include_package_data = True + +[options.package_data] +cutlass_library.source = include/**/*, examples/**/*, tools/**/* diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index bbc31de2a4..33bae61756 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -28,4 +28,8 @@ if (CUTLASS_ENABLE_GTEST_UNIT_TESTS) add_subdirectory(unit) +else() + # Always provide at least the phony test_unit target. 
+ add_custom_target(test_unit) endif() + diff --git a/test/python/cutlass/conv2d/conv2d_problem_sizes.py b/test/python/cutlass/conv2d/conv2d_problem_sizes.py index 3c16f406e9..bf16420726 100644 --- a/test/python/cutlass/conv2d/conv2d_problem_sizes.py +++ b/test/python/cutlass/conv2d/conv2d_problem_sizes.py @@ -36,8 +36,9 @@ This file was ported from the C++ version in test/unit/conv/device/conv2d_problems.h """ +from cutlass_library import ConvMode + import cutlass -from cutlass import ConvMode from cutlass.shape import Conv2DProblemSize diff --git a/test/python/cutlass/conv2d/conv2d_test_utils.py b/test/python/cutlass/conv2d/conv2d_test_utils.py index 8cc288d0c3..ca7a60750b 100644 --- a/test/python/cutlass/conv2d/conv2d_test_utils.py +++ b/test/python/cutlass/conv2d/conv2d_test_utils.py @@ -34,10 +34,11 @@ Utility functions for Conv2d tests. """ +from cutlass_library import SubstituteTemplate import torch import cutlass -from cutlass import ( +from cutlass_library import ( ConvKind, ConvMode, DataType, @@ -50,7 +51,6 @@ ShortLayoutTypeNames, SplitKMode, ) -from cutlass.backend.utils.software import SubstituteTemplate from cutlass.shape import Conv2DProblemSize from cutlass.utils.datatypes import numpy_type, torch_type @@ -301,17 +301,19 @@ def run(self, ps, split_k_mode=SplitKMode.Serial, split_k_slices=1, alpha=1.0, b tensor_B = self.uniform_init(size=tensor_B_size, dtype=self.dtype_B) tensor_C = self.uniform_init(size=tensor_C_size, dtype=self.dtype_C) tensor_D = torch.zeros_like(tensor_C).to(memory_format=torch.channels_last) - self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D, + args = self.operation.run(tensor_A, tensor_B, tensor_C, tensor_D, stride=(ps.stride_h, ps.stride_w), padding=(ps.pad_h, ps.pad_w), dilation=(ps.dilation_h, ps.dilation_w), alpha=alpha, beta=beta, split_k=(split_k_mode, split_k_slices)) + args.sync() + tensor_D_ref = self.reference(ps, tensor_A, tensor_B, tensor_C, alpha, beta, self.activation) torch.cuda.synchronize() - passed = torch.equal(tensor_D, tensor_D_ref) + passed = torch.allclose(tensor_D, tensor_D_ref, atol=2e-06) return passed @@ -378,7 +380,8 @@ def run(self): conv2d_launcher = Conv2dLauncherFrontend(plan, 80, backend="torch") for ps in problem_sizes: - if not validate_problem_size(ps, conv_kind, split_k_slices): continue + if not validate_problem_size(ps, conv_kind, split_k_slices): + continue self.assertTrue(conv2d_launcher.run(ps, split_k_mode, split_k_slices, 1.0, 2.0)) diff --git a/test/python/cutlass/emit/pytorch.py b/test/python/cutlass/emit/pytorch.py index 8f6c2c6db7..30c692a18a 100644 --- a/test/python/cutlass/emit/pytorch.py +++ b/test/python/cutlass/emit/pytorch.py @@ -38,9 +38,11 @@ import tempfile import unittest +from cutlass_library import ConvMode + import cutlass -if cutlass.utils.datatypes.torch_available: +if cutlass.utils.datatypes.is_torch_available(): import torch @@ -88,7 +90,7 @@ def _generate_problems(dtype, num): def _generate_conv2d_problem(conv_kind, dtype, ps): """ Utility function to generate conv2d inputs - + :param conv_kind: kind of convolution :type conv_kind: str :param dtype: data type of tensors @@ -114,7 +116,7 @@ def _generate_conv2d_problem(conv_kind, dtype, ps): return [torch.ceil(torch.empty(size, dtype=dtype, device='cuda').uniform_(-4.5, 3.5)).to(memory_format=torch.channels_last) for size in sizes] -@unittest.skipIf(not cutlass.utils.datatypes.torch_available, 'PyTorch must be available to run PyTorch extension tests') +@unittest.skipIf(not cutlass.utils.datatypes.is_torch_available(), 'PyTorch 
must be available to run PyTorch extension tests') class PyTorchExtensionTest(unittest.TestCase): def test_gemm(self): @@ -183,18 +185,18 @@ def check_all(X, Y): Ds_ref = [(a @ b) * alpha + (beta * c) for a, b, c in zip(As, Bs, Cs)] Ds = mod.run(As, Bs, Cs, alpha, beta) check_all(Ds, Ds_ref) - + def test_conv2d_fprop(self): torch.manual_seed(2023) - + dtype = torch.float16 plan = cutlass.op.Conv2d(kind="fprop", element=dtype, element_accumulator=torch.float32) plan.activation = "relu" - + op = plan.construct() with tempfile.TemporaryDirectory() as tmpdir: mod = cutlass.emit.pytorch(op, name="conv2d_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) - + problem_size = cutlass.shape.Conv2DProblemSize( 1, 4, 4, 16, 8, 3, 3, 16, @@ -202,50 +204,50 @@ def test_conv2d_fprop(self): 3, 3, 1, 1 ) - + A, B, C = _generate_conv2d_problem("fprop", dtype, problem_size) stride = (problem_size.stride_h, problem_size.stride_w) padding = (problem_size.pad_h, problem_size.pad_w) alpha = 1.0 beta = 0.5 - + D_ref = alpha * torch.ops.aten.conv2d( A, B, stride=stride, padding=padding ) + beta * C D_ref = torch.nn.functional.relu(D_ref) D = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta) - - assert torch.allclose(D, D_ref) - + + assert torch.allclose(D, D_ref) + # Test serial split-K D_serial_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3) assert torch.allclose(D, D_serial_split_k) - + # Test parallel split-K D_parallel_split_k = mod.run(A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7) assert torch.allclose(D, D_parallel_split_k) - - + + def test_conv2d_dgrad(self): torch.manual_seed(2023) dtype = torch.float16 plan = cutlass.op.Conv2d(kind="dgrad", element=dtype, element_accumulator=torch.float32) - + op = plan.construct() with tempfile.TemporaryDirectory() as tmpdir: mod = cutlass.emit.pytorch(op, name="conv2d_dgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) - + problem_size = cutlass.shape.Conv2DProblemSize( 1, 4, 4, 16, 8, 3, 3, 16, 0, 0, 3, 3, 1, 1, - cutlass.ConvMode.CrossCorrelation, + ConvMode.CrossCorrelation, 1, 1 ) - + A, B, C = _generate_conv2d_problem("dgrad", dtype, problem_size) stride = (problem_size.stride_h, problem_size.stride_w) padding = (problem_size.pad_h, problem_size.pad_w) @@ -254,32 +256,32 @@ def test_conv2d_dgrad(self): beta = 0.5 input_size = (problem_size.N, problem_size.C, problem_size.H, problem_size.W) D_ref = alpha * torch.nn.grad.conv2d_input( - input_size, B, A, + input_size, B, A, stride=stride, padding=padding ) + beta * C D = mod.run(input_size, A, B, C, stride, padding, alpha=alpha, beta=beta, ) - - assert torch.allclose(D, D_ref) - + + assert torch.allclose(D, D_ref) + def test_conv2d_wgrad(self): torch.manual_seed(2023) dtype = torch.float16 plan = cutlass.op.Conv2d(kind="wgrad", element=dtype, element_accumulator=torch.float32) - + op = plan.construct() with tempfile.TemporaryDirectory() as tmpdir: mod = cutlass.emit.pytorch(op, name="conv2d_wgrad_mod", cc=plan.cc, sourcedir=tmpdir, jit=True) - + problem_size = cutlass.shape.Conv2DProblemSize( 1, 4, 4, 16, 8, 3, 3, 16, 0, 0, 3, 3, 1, 1, - cutlass.ConvMode.CrossCorrelation, + ConvMode.CrossCorrelation, 1, 1 ) - + A, B, C = _generate_conv2d_problem("wgrad", dtype, problem_size) stride = (problem_size.stride_h, problem_size.stride_w) padding = (problem_size.pad_h, problem_size.pad_w) @@ -288,17 +290,17 @@ def test_conv2d_wgrad(self): beta = 0.5 weight_size = (problem_size.K, problem_size.C, problem_size.R, 
problem_size.S) D_ref = alpha * torch.nn.grad.conv2d_weight( - B, weight_size, A, + B, weight_size, A, stride=stride, padding=padding ) + beta * C D = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta) - - assert torch.allclose(D, D_ref) - + + assert torch.allclose(D, D_ref) + # Test serial split-K D_serial_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="serial", split_k_slices=3) assert torch.allclose(D, D_serial_split_k) - + # Test parallel split-K D_parallel_split_k = mod.run(weight_size, A, B, C, stride, padding, alpha=alpha, beta=beta, split_k_mode="parallel", split_k_slices=7) assert torch.allclose(D, D_parallel_split_k) diff --git a/test/python/cutlass/evt/utils/evt_testbed.py b/test/python/cutlass/evt/utils/evt_testbed.py index ea2ecc947e..555f860d30 100644 --- a/test/python/cutlass/evt/utils/evt_testbed.py +++ b/test/python/cutlass/evt/utils/evt_testbed.py @@ -40,9 +40,9 @@ import cutlass from cutlass import Tensor import cutlass.backend.evt -from cutlass.profiler import CUDAEventProfiler from cutlass.shape import GemmCoord from cutlass.utils.datatypes import torch_type +from cutlass.utils.profiler import CUDAEventProfiler class EVTReferenceModule: diff --git a/test/python/cutlass/gemm/gemm_batched.py b/test/python/cutlass/gemm/gemm_batched.py index 77592740b6..51798b32ed 100644 --- a/test/python/cutlass/gemm/gemm_batched.py +++ b/test/python/cutlass/gemm/gemm_batched.py @@ -43,7 +43,7 @@ from cutlass.backend.utils.device import device_cc import torch -from utils import LayoutCombination, add_test_gemm +from utils import LayoutCombination cutlass.set_log_level(logging.WARNING) diff --git a/test/python/cutlass/gemm/gemm_f16_sm80.py b/test/python/cutlass/gemm/gemm_f16_sm80.py index e2ec3718a3..02de6da480 100644 --- a/test/python/cutlass/gemm/gemm_f16_sm80.py +++ b/test/python/cutlass/gemm/gemm_f16_sm80.py @@ -67,58 +67,58 @@ class GemmF16Sm80StreamK(unittest.TestCase): # Tests using TensorOp add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.NTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, 
element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TTT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 32], warp_count=[2, 1, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[4, 4, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, 
element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) -add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass.DataType.f16, +add_test_tensorop(cls=GemmF16Sm80, layouts=LayoutCombination.TNT, alignments=[2, 2, 2], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) # Tests using SIMT add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) -add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, +add_test_simt(cls=GemmF16Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], 
element_output=cutlass.DataType.f16, +add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.NNN, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, +add_test_streamk(cls=GemmF16Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, element_C=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=5) if __name__ == '__main__': diff --git a/test/python/cutlass/gemm/gemm_f16_sm90.py b/test/python/cutlass/gemm/gemm_f16_sm90.py index 7df305267a..0e8fe94519 100644 --- a/test/python/cutlass/gemm/gemm_f16_sm90.py +++ b/test/python/cutlass/gemm/gemm_f16_sm90.py @@ -135,6 +135,10 @@ class GemmF16Sm90(unittest.TestCase): add_test_simt(layouts=LayoutCombination.TTN, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8]) add_test_simt(layouts=LayoutCombination.NNT, element_output=cutlass.DataType.f16, element_accumulator=cutlass.DataType.f16, threadblock_shape=[128, 128, 8]) +# Tests with void-C kernels +add_test_cluster_shape(layouts=LayoutCombination.NNT, alignments=[8, 8, 8], element_output=cutlass.DataType.f16, + element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], stages=None, + cluster_shape=[2, 1, 1], element_C=cutlass.DataType.void) if __name__ == '__main__': unittest.main() diff --git a/test/python/cutlass/gemm/gemm_f32_sm80.py b/test/python/cutlass/gemm/gemm_f32_sm80.py index 0dca12c00b..903965a3d7 100644 --- a/test/python/cutlass/gemm/gemm_f32_sm80.py +++ b/test/python/cutlass/gemm/gemm_f32_sm80.py @@ -68,31 +68,31 @@ class GemmF32Sm80StreamK(unittest.TestCase): # Tests using TensorOp add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) -add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) -add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 32], warp_count=[1, 2, 1], stages=3) -add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, +add_test_tensorop(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, 
element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 32], warp_count=[1, 1, 1], stages=4) # Tests using SIMT add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) -add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, +add_test_simt(cls=GemmF32Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, +add_test_streamk(cls=GemmF32Sm80StreamK, layouts=LayoutCombination.TTN, alignments=[4, 4, 4], element_output=cutlass.DataType.f32, element_C=cutlass.DataType.f32, element_accumulator=cutlass.DataType.f32, threadblock_shape=[128, 128, 32], warp_count=[2, 2, 1], stages=3) diff --git a/test/python/cutlass/gemm/gemm_f64_sm80.py b/test/python/cutlass/gemm/gemm_f64_sm80.py index 32c0348359..67049b94b6 100644 --- a/test/python/cutlass/gemm/gemm_f64_sm80.py +++ b/test/python/cutlass/gemm/gemm_f64_sm80.py @@ -68,30 +68,30 @@ class GemmF64Sm80StreamK(unittest.TestCase): # Tests using TensorOp add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) -add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) -add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, 
alignments=[1, 1, 1], element_output=cutlass.DataType.f64, +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 16], warp_count=[2, 2, 1], stages=4) -add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, +add_test_tensorop(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 32, 32, 16], warp_count=[2, 1, 1], stages=5) # Tests using SIMT add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, element_accumulator=cutlass.DataType.f64, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) -add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, +add_test_simt(cls=GemmF64Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, +add_test_streamk(cls=GemmF64Sm80StreamK, layouts=LayoutCombination.NTT, alignments=[1, 1, 1], element_output=cutlass.DataType.f64, element_C=cutlass.DataType.f64, element_accumulator=cutlass.DataType.f64, threadblock_shape=[128, 128, 16], warp_count=[4, 2, 1], stages=3) diff --git a/python/docker/Dockerfile-cuda11.8-pytorch b/test/python/cutlass/gemm/gemm_mixed_sm80.py similarity index 57% rename from python/docker/Dockerfile-cuda11.8-pytorch rename to test/python/cutlass/gemm/gemm_mixed_sm80.py index c573dfe7aa..152f8eb42b 100644 
--- a/python/docker/Dockerfile-cuda11.8-pytorch +++ b/test/python/cutlass/gemm/gemm_mixed_sm80.py @@ -30,11 +30,43 @@ # ################################################################################################# -FROM nvcr.io/nvidia/pytorch:22.11-py3 - -RUN chmod ugo+rwx /home -RUN pip uninstall -y rmm -RUN pip install rmm-cu11 --extra-index-url=https://pypi.ngc.nvidia.com -ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH -ENV LIBRARY_PATH=/usr/local/cuda/lib64:$LIBRARY_PATH -ENV CUDA_INSTALL_PATH=/usr/local/cuda +""" +Low-level functionality tests for GEMM with mixed operands on SM80 +""" + +from functools import partial +import logging +import unittest + +import cutlass +from cutlass.backend.utils.device import device_cc + +from utils import LayoutCombination, add_test_gemm + + +cutlass.set_log_level(logging.WARNING) +cc = 80 + +@unittest.skipIf(device_cc() < cc, 'Device compute capability is insufficient for SM80 tests.') +class GemmMixedSm80(unittest.TestCase): + """ + Wrapper class to which tests will be added dynamically in __main__ + """ + pass + + +add_test_mixed = partial(add_test_gemm, cls=GemmMixedSm80, element=cutlass.DataType.f16, cc=cc, cluster_shape=[1, 1, 1], + opclass=cutlass.OpcodeClass.TensorOp, threadblock_shape=[128, 128, 64], + warp_count=[2, 2, 1], stages=3, element_accumulator=cutlass.DataType.f32) + +# Test with upcast on A +add_test_mixed(element_A=cutlass.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNT) +add_test_mixed(element_A=cutlass.DataType.s8, alignments=[16, 8, 8], layouts=LayoutCombination.TNN) + +# Test with upcast on B +add_test_mixed(element_B=cutlass.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNT) +add_test_mixed(element_B=cutlass.DataType.s8, alignments=[8, 16, 8], layouts=LayoutCombination.TNN) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/python/cutlass/gemm/gemm_s8_sm80.py b/test/python/cutlass/gemm/gemm_s8_sm80.py index f98770a056..d9b929f9d6 100644 --- a/test/python/cutlass/gemm/gemm_s8_sm80.py +++ b/test/python/cutlass/gemm/gemm_s8_sm80.py @@ -68,30 +68,30 @@ class GemmS8Sm80StreamK(unittest.TestCase): # Tests using TensorOp add_test_tensorop = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp) -add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8, element_accumulator=cutlass.DataType.s32, threadblock_shape=[256, 128, 64], warp_count=[4, 2, 1], stages=3) -add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8, element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) -add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass.DataType.s32, +add_test_tensorop(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[16, 16, 4], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32, element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 64, 64], warp_count=[1, 1, 1], stages=4) # Tests using SIMT add_test_simt = partial(add_test_specialized, opclass=cutlass.OpcodeClass.Simt) 
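The new `gemm_mixed_sm80.py` tests above register mixed-input GEMMs by overriding `element_A` or `element_B` to `s8` while the other operand stays `f16`. As a rough sketch of what those registrations ultimately construct (illustrative only: the layouts, alignments, problem configuration, and an SM80-class device are assumptions, and only `cutlass.op.Gemm` arguments already used in the `utils.py` changes later in this patch are relied on):

```python
# Sketch: a mixed-input (s8 x f16) GEMM plan like the ones exercised by
# add_test_mixed(element_A=cutlass.DataType.s8, ...) above. Layouts follow the
# TNT combination (A row-major, B column-major, C row-major); the alignment
# and device assumptions are illustrative, not prescriptive.
import cutlass
from cutlass_library import LayoutType  # assumed top-level re-export of library.LayoutType

plan = cutlass.op.Gemm(
    element_A=cutlass.DataType.s8,   # narrow operand, upcast before the Tensor Core MMA
    element_B=cutlass.DataType.f16,
    element_C=cutlass.DataType.f16,
    element_D=cutlass.DataType.f16,
    layout_A=LayoutType.RowMajor,
    layout_B=LayoutType.ColumnMajor,
    layout_C=LayoutType.RowMajor,
    element_accumulator=cutlass.DataType.f32,
    kernel_cc=80,
)
# plan.run(A, B, C, D) would then JIT-compile and launch a kernel for concrete tensors.
```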
-add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8, element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TNN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8, element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 128, 8], warp_count=[1, 2, 1], stages=2) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8, element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 64, 8], warp_count=[2, 1, 1], stages=2) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.TTN, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32, element_accumulator=cutlass.DataType.s32, threadblock_shape=[ 64, 64, 8], warp_count=[1, 1, 1], stages=2) -add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, +add_test_simt(cls=GemmS8Sm80, layouts=LayoutCombination.NNT, alignments=[1, 1, 1], element_output=cutlass.DataType.s32, element_C=cutlass.DataType.s32, element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 128, 8], warp_count=[2, 2, 1], stages=2) # Stream K tests add_test_streamk = partial(add_test_specialized, opclass=cutlass.OpcodeClass.TensorOp, swizzle=cutlass.swizzle.ThreadblockSwizzleStreamK) -add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, +add_test_streamk(cls=GemmS8Sm80StreamK, layouts=LayoutCombination.TNT, alignments=[16, 16, 16], element_output=cutlass.DataType.s8, element_C=cutlass.DataType.s8, element_accumulator=cutlass.DataType.s32, threadblock_shape=[128, 256, 64], warp_count=[2, 4, 1], stages=3) diff --git a/test/python/cutlass/gemm/gemm_testbed.py b/test/python/cutlass/gemm/gemm_testbed.py index ac0a551d21..2507cd750b 100644 --- a/test/python/cutlass/gemm/gemm_testbed.py +++ b/test/python/cutlass/gemm/gemm_testbed.py @@ -37,7 +37,7 @@ import torch -from cutlass import ( +from cutlass_library import ( DataType, DataTypeSize, GemmUniversalMode, @@ -49,7 +49,6 @@ from cutlass.backend import compiler from cutlass.backend.gemm_operation import GemmArguments, GemmOperationUniversal -from cutlass.backend.memory_manager import get_allocated_size from cutlass.backend.reduction_operation import ReductionArguments, ReductionOperation from cutlass.shape import GemmCoord, MatrixCoord from cutlass.utils.datatypes import torch_type @@ -65,16 +64,6 @@ def __init__( compiler_mode= "nvcc", **kwargs, ) -> None: - # Create the reduction kernel, if needed - self.reduction_operation: ReductionOperation = ReductionOperation( - shape=MatrixCoord(4, 32 * operation.C.alignment), - C=operation.C, - element_accumulator=operation.tile_description.math_instruction.element_accumulator, - 
element_compute=operation.epilogue_functor.element_epilogue, - epilogue_functor=operation.epilogue_functor, - count=operation.C.alignment, - ) - self.math_operation = operation.tile_description.math_instruction.math_operation self.verification = verification @@ -88,19 +77,26 @@ def __init__( op_list = [operation] if operation.arch < 90: # Split K via Python is currently only supported for pre-SM90 kernels + self.reduction_operation: ReductionOperation = ReductionOperation( + shape=MatrixCoord(4, 32 * operation.C.alignment), + C=operation.C, + element_accumulator=operation.tile_description.math_instruction.element_accumulator, + element_compute=operation.epilogue_functor.element_epilogue, + epilogue_functor=operation.epilogue_functor, + count=operation.C.alignment, + ) op_list.append(self.reduction_operation) compiler.add_module(op_list, bypass_cache=False) self.operation = operation - self.dtype_A = torch_type(operation.A.element) - self.dtype_B = torch_type(operation.B.element) + self.dtype_A = torch_type(operation.A.element if not self.operation.switched else self.operation.B.element) + self.dtype_B = torch_type(operation.B.element if not self.operation.switched else self.operation.A.element) self.dtype_C = torch_type(operation.C.element) - self.dtype_D = torch_type(operation.C.element) + self.dtype_D = torch_type(operation.epilogue_functor.element_output) - accumulator_size = DataTypeSize[operation.tile_description.math_instruction.element_accumulator] - element_size = DataTypeSize[operation.A.element] + element_size = min(DataTypeSize[operation.A.element], DataTypeSize[operation.B.element]) if element_size == 1: self.rand_max = 1 @@ -154,7 +150,18 @@ def uniform_init(self, shape, dtype, layout): def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta): # If any tensor is on CPU, place all tensors on CPU unless only # tensor C is on CPU - devices = [x.device.type for x in [tensor_A, tensor_B, tensor_C]] + # Handle mixed-input cases by casting to the larger data type and overriding + # to whatever the data type of the larger type is + if self.dtype_A != self.dtype_B: + if DataTypeSize[self.operation.A.element] < DataTypeSize[self.operation.B.element]: + tensor_A = tensor_A.to(self.dtype_B).to(tensor_B.device) + else: + tensor_B = tensor_B.to(self.dtype_A).to(tensor_A.device) + + devices = [x.device.type for x in [tensor_A, tensor_B]] + if tensor_C is not None: + devices.append(tensor_C.device.type) + if "cpu" in devices and devices != ["cuda", "cuda", "cpu"]: device = torch.device("cpu") else: @@ -162,14 +169,17 @@ def reference(self, problem_size, tensor_A, tensor_B, tensor_C, alpha, beta): tensor_A = tensor_A.to(device) tensor_B = tensor_B.to(device) - tensor_C = tensor_C.to(device) + if tensor_C is not None: + tensor_C = tensor_C.to(device) dtype = torch_type(self.compute_type) alpha_torch = torch.tensor([alpha], device=device).to(dtype) beta_torch = torch.tensor([beta], device=device).to(dtype) tmp = tensor_A @ tensor_B - tensor_D_ref = (alpha_torch * tmp) + (tensor_C * beta_torch) + tensor_D_ref = (alpha_torch * tmp) + if tensor_C is not None: + tensor_D_ref += (tensor_C * beta_torch) return tensor_D_ref.to(self.dtype_D) def run(self, mode, problem_size, batch_count=1, split_k_slices=1, alpha=1.0, beta=0.0): @@ -199,12 +209,22 @@ def transpose(layout): self.dtype_B, self.operation.B.layout if not self.operation.switched else transpose(self.operation.A.layout), ) - tensor_C, tensor_C_ref = self.uniform_init( + if self.dtype_C is not None: + tensor_C, tensor_C_ref 
= self.uniform_init( + (true_batch_count, problem_size.m, problem_size.n), + self.dtype_C, + self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout), + ) + else: + tensor_C = None + tensor_C_ref = None + + tensor_D, _ = self.uniform_init( (true_batch_count, problem_size.m, problem_size.n), - self.dtype_C, + self.dtype_D, self.operation.C.layout if not self.operation.switched else transpose(self.operation.C.layout), ) - tensor_D = torch.zeros_like(tensor_C) + tensor_D = torch.zeros_like(tensor_D) if self.compute_type in [DataType.s8, DataType.s32, DataType.u8, DataType.u32]: alpha = int(alpha) @@ -248,6 +268,10 @@ def transpose(layout): if self.verification: if mode == GemmUniversalMode.GemmSplitKParallel: reduction_arguments.sync() + + # Free memory allocated by args because we are not + # calling `arguments.sync()` in this case (which will free memory) + arguments.free() else: arguments.sync() tensor_D_ref = self.reference( @@ -274,9 +298,6 @@ def transpose(layout): if mode == GemmUniversalMode.GemmSplitKParallel: del reduction_arguments - cur_size = get_allocated_size() - assert cur_size == 0, f"{cur_size} B of memory were not released after this run" - return passed diff --git a/test/python/cutlass/gemm/utils.py b/test/python/cutlass/gemm/utils.py index 7282fe5a50..abc8bac28a 100644 --- a/test/python/cutlass/gemm/utils.py +++ b/test/python/cutlass/gemm/utils.py @@ -30,9 +30,10 @@ # ################################################################################################# -import cutlass +from cutlass_library import SubstituteTemplate -from cutlass import ( +import cutlass +from cutlass_library import ( DataTypeNames, EpilogueScheduleSuffixes, KernelScheduleSuffixes, @@ -42,7 +43,6 @@ ShortLayoutTypeNames ) from cutlass.backend import library -from cutlass.backend.utils.software import SubstituteTemplate from gemm_testbed import test_all_gemm @@ -82,6 +82,7 @@ def get_name( stages, element_a, element_b, + element_c, arch, opclass, kernel_schedule=None, @@ -102,6 +103,7 @@ def get_name( :type stages: int :param element_a: data type of operand A :param element_b: data type of operand B + :param element_c: data type of operand C :param arch: compute capability of kernel being generated :type arch: int :param opclass: class of operation being performed (e.g., SIMT, Tensor Core) @@ -122,7 +124,7 @@ def get_name( "arch": str(arch), "eA": DataTypeNames[element_a], "eB": DataTypeNames[element_b], - "eC": DataTypeNames[element_output], + "eC": DataTypeNames[element_c], "lA": ShortLayoutTypeNames[layouts[0]], "lB": ShortLayoutTypeNames[layouts[1]], "lC": ShortLayoutTypeNames[layouts[2]], @@ -161,7 +163,10 @@ def add_test_gemm( swizzle=None, kernel_schedule=None, epilogue_schedule=None, - compilation_modes=['nvcc', 'nvrtc']): + compilation_modes=['nvcc', 'nvrtc'], + element_A=None, + element_B=None, + element_C=None): """ Create test-running functions with the given specification and set it as a method of ``cls``. @@ -195,22 +200,38 @@ def add_test_gemm( :param epilogue_schedule: epilogue schedule to use :type epilogue_schedule: cutlass.EpilogueScheduleType :param compilation_modes: list of compilers to used in testing the kernel (options: 'nvrtc', 'nvcc') - :type compilation_modes: list + :type compilation_modes: list, + :param element_A: data type of operand A. If set, overrides ``element`` + :type element_A: cutlass.DataType + :param element_B: data type of operand B. 
If set, overrides ``element`` + :type element_B: cutlass.DataType + :param element_C: data type of operand C. If set, overrides ``element`` + :type element_C: cutlass.DataType """ + if element_A is None: + element_A = element + if element_B is None: + element_B = element + if element_C is None: + element_C = element + if element_output is None: + element_output = element + if element_accumulator is None: + element_accumulator = element + for compilation_mode in compilation_modes: def run(self): """ Dynamically-generated function that constructs a GEMM operation and verifies it against multiple test cases. """ - element_A = element - element_B = element + layout_A, layout_B, layout_C = layouts alignment_A, alignment_B, alignment_C = alignments plan = cutlass.op.Gemm(element_A=element_A, element_B=element_B, - element_C=element_output, element_D=element_output, + element_C=element_C, element_D=element_output, layout_A=layout_A, layout_B=layout_B, layout_C=layout_C, element_accumulator=element_accumulator, kernel_cc=cc) @@ -233,7 +254,7 @@ def run(self): name = get_name( layouts=layouts, alignments=alignments, element_output=element_output, element_accumulator=element_accumulator, element_epilogue=element_epilogue, cluster_shape=cluster_shape, threadblock_shape=threadblock_shape, - stages=stages, element_a=element, element_b=element, arch=cc, opclass=opclass, + stages=stages, element_a=element_A, element_b=element_B, element_c=element_C, arch=cc, opclass=opclass, kernel_schedule=kernel_schedule, epilogue_schedule=epilogue_schedule, suffix=f'_{compilation_mode}') setattr(cls, name, run) diff --git a/python/docker/Dockerfile-cuda12.0-pytorch b/test/python/cutlass/installation.py similarity index 68% rename from python/docker/Dockerfile-cuda12.0-pytorch rename to test/python/cutlass/installation.py index a9a84bf36c..b63ee1c612 100644 --- a/python/docker/Dockerfile-cuda12.0-pytorch +++ b/test/python/cutlass/installation.py @@ -30,9 +30,28 @@ # ################################################################################################# -FROM nvcr.io/nvidia/pytorch:23.01-py3 +""" +Tests for a successful installation of the CUTLASS Python interface +""" -RUN chmod ugo+rwx /home -ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH -ENV LIBRARY_PATH=/usr/local/cuda/lib64:$LIBRARY_PATH -ENV CUDA_INSTALL_PATH=/usr/local/cuda +import os +import unittest + +import cutlass +import cutlass_library + + +class InstallationTest(unittest.TestCase): + def test_cutlass_source_paths(self): + """ + Tests that CUTLASS source is available as part of the cutlass and cutlass_library packages + """ + src_file = 'include/cutlass/cutlass.h' + library_file = os.path.join(cutlass_library.source_path, src_file) + cutlass_file = os.path.join(cutlass.CUTLASS_PATH, src_file) + assert os.path.isfile(library_file), f"Unable to locate file {library_file}. Installation has not succeeded." + assert os.path.isfile(cutlass_file), f"Unable to locate file {cutlass_file}. Installation has not succeeded." 
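The `installation.py` test above checks that the CUTLASS sources ship with both Python packages. A quick manual version of the same check (a sketch assuming `pip install .` has completed so that both packages import) looks like:

```python
# Sketch of the check installation.py performs above: the CUTLASS headers must be
# reachable through both cutlass.CUTLASS_PATH and cutlass_library.source_path.
import os

import cutlass
import cutlass_library

src_file = "include/cutlass/cutlass.h"
for root in (cutlass.CUTLASS_PATH, cutlass_library.source_path):
    path = os.path.join(root, src_file)
    print(path, "found" if os.path.isfile(path) else "MISSING")
```

The `cutlass_library.source` package and its `[options.package_data]` globs in the new `setup.cfg` are what make the second path resolvable from an installed package.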
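Looking back at the `gemm_testbed.py` changes earlier in this patch: the reference computation now promotes the narrower operand before multiplying and skips the C term entirely for void-C kernels. A standalone sketch of that logic (the function name and the use of `Tensor.element_size()` in place of the testbed's `DataTypeSize` lookup are illustrative):

```python
# Sketch of the reference strategy in gemm_testbed.py: promote the narrower
# operand so A @ B runs in the wider dtype, and treat C=None (void-C kernels)
# as "no source accumulation".
import torch

def reference_gemm(A, B, C=None, alpha=1.0, beta=0.0):
    if A.dtype != B.dtype:
        if A.element_size() < B.element_size():
            A = A.to(B.dtype)   # e.g., int8 -> float16
        else:
            B = B.to(A.dtype)
    D = alpha * (A @ B)
    if C is not None:
        D = D + beta * C
    return D
```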
+ + +if __name__ == "__main__": + unittest.main() diff --git a/test/python/cutlass/interface/conv2d_interface.py b/test/python/cutlass/interface/conv2d_interface.py index 4937c4a0e0..23da5c93fc 100644 --- a/test/python/cutlass/interface/conv2d_interface.py +++ b/test/python/cutlass/interface/conv2d_interface.py @@ -50,7 +50,7 @@ class Conv2dEquivalence: """ def __init__(self, conv_kind, element_A, element_B, element_C, element_D, element_accumulator, alignment_A, alignment_B, alignment_C): - + self.element_A = element_A self.element_B = element_B self.element_C = element_C @@ -59,21 +59,21 @@ def __init__(self, conv_kind, element_A, element_B, element_C, element_D, elemen self.alignment_A = alignment_A self.alignment_B = alignment_B self.alignment_C = alignment_C - + self.conv_kind = conv_kind - + self.plan = cutlass.op.Conv2d( kind=self.conv_kind, element_A=element_A, element_B=element_B, element_C=element_C, element_D=element_D, element_accumulator=element_accumulator) - + self.op = self.plan.construct( - alignment_A=self.alignment_A, alignment_B=self.alignment_B, + alignment_A=self.alignment_A, alignment_B=self.alignment_B, alignment_C=self.alignment_C) - + def _plans_equal(self, other_plan) -> bool: """ Compares whether two plans are equal - + :param other_plan: plan to compare against the default Conv2d :type other_plan: cutlass.op.Conv2d @@ -81,9 +81,9 @@ def _plans_equal(self, other_plan) -> bool: :rtype: bool """ other_op = other_plan.construct( - alignment_A=self.alignment_A, alignment_B=self.alignment_B, + alignment_A=self.alignment_A, alignment_B=self.alignment_B, alignment_C=self.alignment_C) - + return self.op.rt_module.emit() == other_op.rt_module.emit() def generic_test(self): @@ -91,16 +91,16 @@ def generic_test(self): Tests the equivalence of various constructions of the Conv2d interface when using CUTLASS data types and layouts for constructing the Conv2d interface """ - if not datatypes.numpy_available: + if not datatypes.is_numpy_available(): return - + # Test when specifying all parameters plan_other = cutlass.op.Conv2d( kind=self.conv_kind, element_A=self.element_A, element_B=self.element_B, element_C=self.element_C, element_D=self.element_D, element_accumulator=self.element_accumulator) assert self._plans_equal(plan_other) - + # Test when specifying all parameters but A plan_other = cutlass.op.Conv2d( kind=self.conv_kind, @@ -108,7 +108,7 @@ def generic_test(self): element_D=self.element_D, element_accumulator=self.element_accumulator, element=self.element_A) assert self._plans_equal(plan_other) - + # Test when specifying all parameters but A and B as tensors using generic element and output plan_other = cutlass.op.Conv2d( kind=self.conv_kind, @@ -116,7 +116,7 @@ def generic_test(self): element_D=self.element_D, element_accumulator=self.element_accumulator, element=self.element_A) assert self._plans_equal(plan_other) - + # Test without explicit accumulator. Only run if the type of C and the accumulator are equal if self.element_C == self.element_accumulator: plan_other = cutlass.op.Conv2d( @@ -125,18 +125,18 @@ def generic_test(self): element_D=self.element_D, element=self.element_A) assert self._plans_equal(plan_other) - + # Test with only the generic types. 
Only rune if the types of A, B, C, and D are the same if (self.element_A == self.element_B and self.element_A == self.element_C and self.element_A == self.element_D and self.element_A == self.element_accumulator): plan_other = cutlass.op.Conv2d(kind=self.conv_kind, element=self.element_A) assert self._plans_equal(plan_other) - + def numpy_test(self): """ Tests the equivalence of various constructions of the Conv2d interface when using numpy as a frontend """ - if not datatypes.numpy_available: + if not datatypes.is_numpy_available(): return import numpy as np @@ -145,7 +145,7 @@ def numpy_test(self): type_C = datatypes.numpy_type(self.element_C) type_D = datatypes.numpy_type(self.element_D) type_accum = datatypes.numpy_type(self.element_accumulator) - + size = (2, 2) A = np.zeros(size, dtype=type_A) B = np.zeros(size, dtype=type_B) @@ -153,49 +153,49 @@ def numpy_test(self): D = np.zeros(size, dtype=type_D) return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D) - + def torch_test(self): """ Tests the equivalence of various constructions of the Conv2d interface when using torch as a frontend """ - if not datatypes.torch_available: + if not datatypes.is_torch_available(): return - + import torch type_A = datatypes.torch_type(self.element_A) type_B = datatypes.torch_type(self.element_B) type_C = datatypes.torch_type(self.element_C) type_D = datatypes.torch_type(self.element_D) type_accum = datatypes.torch_type(self.element_accumulator) - + size = (2, 2) - + A = torch.empty(size, dtype=type_A) B = torch.empty(size, dtype=type_B) C = torch.empty(size, dtype=type_C) D = torch.empty(size, dtype=type_D) - + return self.tensor_test(type_A, type_B, type_C, type_D, type_accum, A, B, C, D) - + def tensor_test(self, type_A, type_B, type_C, type_D, type_accum, A, B, C, D): # Test when specifying all parameters via tensors plan_np = cutlass.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D, element_accumulator=type_accum) assert self._plans_equal(plan_np) - + # Test when specifying all parameters but A as tensors plan_np = cutlass.op.Conv2d(kind=self.conv_kind, B=B, C=C, D=D, element_accumulator=type_accum, element_A=type_A) assert self._plans_equal(plan_np) - + # Test when specifying all parameters but A and B as tensors and using generic element and output if type_A == type_B: plan_np = cutlass.op.Conv2d(kind=self.conv_kind, C=C, D=D, element_accumulator=type_accum, element=type_A) assert self._plans_equal(plan_np) - + # Test without explicit accumulator. Only run if the type of C and the accumulator. if type_C == type_accum: plan_np = cutlass.op.Conv2d(kind=self.conv_kind, A=A, B=B, C=C, D=D) assert self._plans_equal(plan_np) - + # Test with only the generic types and layouts. Only run if types and layouts of A, B, C, and D are the same. 
if (type_A == type_B and type_A == type_C and type_A == type_D and type_A == type_accum): plan_np = cutlass.op.Conv2d(kind=self.conv_kind, element=type_A) @@ -223,20 +223,20 @@ class ConvEquivalenceTest(unittest.TestCase): } def add_test(conv_kind, element_A, element_B, element_C, element_D, element_accumulator): - + test_name = f"test_conv2d_{conv_kind}_{element_A}_{element_B}_{element_C}_{element_D}_{element_accumulator}" - + def run(self): conv2d_eq = Conv2dEquivalence( - conv_kind=conv_kind, + conv_kind=conv_kind, element_A=element_A, element_B=element_B, element_C=element_C, element_D=element_D, - element_accumulator=element_accumulator, + element_accumulator=element_accumulator, alignment_A=type2alignment[element_A], alignment_B=type2alignment[element_B], alignment_C=type2alignment[element_C] ) conv2d_eq.test_all() - + setattr(ConvEquivalenceTest, test_name, run) for conv_kind in ["fprop", "wgrad", "dgrad"]: @@ -255,25 +255,25 @@ class Conv2dErrorTests(unittest.TestCase): """ Tests various error scenarios that arise with the high-level Gemm interface """ - + def test_alignment(self): """ Tests case in which the alignment specified is unsupported """ plan = cutlass.op.Conv2d(kind="fprop", element=cutlass.DataType.f16) - + with ExpectException(True, 'Alignment 3 is not supported for F16. The construction should fail.'): op = plan.construct(alignment_A=3, alignment_B=3, alignment_C=3) - + def test_invalid_tile_description(self): """ Tests scenarios in which an invalid tile description is provided for a given CC """ plan = cutlass.op.Conv2d(kind="fprop", element=cutlass.DataType.f16) - + td = plan.tile_descriptions()[0] td.threadblock_shape=[17, 32, 5] - + plan.tile_description = td with ExpectException(True, 'The threadblock shape is invalid. The compilation should fail.'): plan.compile() diff --git a/test/python/cutlass/interface/evt_interface.py b/test/python/cutlass/interface/evt_interface.py index bd284f9e9e..717c746a9c 100644 --- a/test/python/cutlass/interface/evt_interface.py +++ b/test/python/cutlass/interface/evt_interface.py @@ -93,13 +93,16 @@ def test_too_much_shared_memory(self): """ Test when the epilogue consumes too much shared memory """ - def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5): + def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5, C6, C7, C8): D1 = accum + C1 D2 = D1 + C2 D3 = D2 + C3 D4 = D3 + C4 - D = D4 + C5 - return D, D1, D2, D3, D4 + D5 = D4 + C5 + D6 = D5 + C6 + D7 = D6 + C7 + D = D7 + C8 + return D, D1, D2, D3, D4, D5, D6, D7 example_tensors = { "accum": self.fake_tensor(np.float16, (6, 512, 512)), @@ -108,10 +111,16 @@ def evt_too_much_shared_memory(accum, C1, C2, C3, C4, C5): "C3": self.fake_tensor(np.float16, (6, 512, 512)), "C4": self.fake_tensor(np.float16, (6, 512, 512)), "C5": self.fake_tensor(np.float16, (6, 512, 512)), + "C6": self.fake_tensor(np.float16, (6, 512, 512)), + "C7": self.fake_tensor(np.float16, (6, 512, 512)), + "C8": self.fake_tensor(np.float16, (6, 512, 512)), "D1": self.fake_tensor(np.float16, (6, 512, 512)), "D2": self.fake_tensor(np.float16, (6, 512, 512)), "D3": self.fake_tensor(np.float16, (6, 512, 512)), "D4": self.fake_tensor(np.float16, (6, 512, 512)), + "D5": self.fake_tensor(np.float16, (6, 512, 512)), + "D6": self.fake_tensor(np.float16, (6, 512, 512)), + "D7": self.fake_tensor(np.float16, (6, 512, 512)), "D": self.fake_tensor(np.float16, (6, 512, 512)) } diff --git a/test/python/cutlass/interface/gemm_interface.py b/test/python/cutlass/interface/gemm_interface.py index 2429718280..d1d7d16928 100644 
--- a/test/python/cutlass/interface/gemm_interface.py +++ b/test/python/cutlass/interface/gemm_interface.py @@ -85,7 +85,7 @@ def generic_test(self): Tests the equivalence of various constructions of the Gemm interface when using CUTLASS data types and layouts for constructing the Gemm interface """ - if not datatypes.numpy_available: + if not datatypes.is_numpy_available(): return # Test when specifying all parameters @@ -126,7 +126,7 @@ def numpy_test(self): """ Tests the equivalence of various constructions of the Gemm interface when using numpy as a frontend """ - if not datatypes.numpy_available: + if not datatypes.is_numpy_available(): return import numpy as np diff --git a/test/unit/core/fast_numeric_conversion.cu b/test/unit/core/fast_numeric_conversion.cu index 1eeb8e8d6b..ac3b5cf75d 100644 --- a/test/unit/core/fast_numeric_conversion.cu +++ b/test/unit/core/fast_numeric_conversion.cu @@ -114,7 +114,7 @@ void run_test_integer_range_all() { ); destination.sync_host(); - + // Verify conversion bool passed = true; for (int i = 0; i < kN; ++i) { @@ -124,7 +124,7 @@ void run_test_integer_range_all() { } } EXPECT_TRUE(passed) << " FastNumericArrayConverter failed"; - + // Print out results for the failed conversion. if (!passed) { for (int i = 0; i < kN; ++i) { diff --git a/test/unit/core/float8.cu b/test/unit/core/float8.cu index 79805031d3..7ed803e5ed 100644 --- a/test/unit/core/float8.cu +++ b/test/unit/core/float8.cu @@ -48,11 +48,20 @@ TEST(float_e4m3_t, host_conversion) { for (int i = -8; i < 8; ++i) { float f = static_cast(i); + cutlass::int4b_t s = static_cast(i); + FP8 w = static_cast(s); FP8 x = static_cast(i); FP8 y = static_cast(f); + EXPECT_TRUE(static_cast(w) == s); EXPECT_TRUE(static_cast(x) == i); EXPECT_TRUE(static_cast(y) == f); + + if (i >= 0) { + cutlass::uint4b_t u = static_cast(i); + FP8 z = static_cast(u); + EXPECT_TRUE(static_cast(z) == u); + } } // Try out default-ctor (zero initialization of primitive proxy type) @@ -72,11 +81,20 @@ TEST(float_e5m2_t, host_conversion) { for (int i = -8; i < 8; ++i) { float f = static_cast(i); + cutlass::int4b_t s = static_cast(i); + FP8 w = static_cast(s); FP8 x = static_cast(i); FP8 y = static_cast(f); + EXPECT_TRUE(static_cast(w) == s); EXPECT_TRUE(static_cast(x) == i); EXPECT_TRUE(static_cast(y) == f); + + if (i >= 0) { + cutlass::uint4b_t u = static_cast(i); + FP8 z = static_cast(u); + EXPECT_TRUE(static_cast(z) == u); + } } // Try out default-ctor (zero initialization of primitive proxy type) diff --git a/test/unit/core/numeric_conversion.cu b/test/unit/core/numeric_conversion.cu index 4faea52564..63b132f3de 100644 --- a/test/unit/core/numeric_conversion.cu +++ b/test/unit/core/numeric_conversion.cu @@ -60,7 +60,7 @@ __global__ void convert( ///////////////////////////////////////////////////////////////////////////////////////////////// -template +template void run_test(const char dest_name[], const char source_name[]) { const int kN = Count; @@ -69,9 +69,11 @@ void run_test(const char dest_name[], const char source_name[]) { cutlass::HostTensor destination({1, kN}); cutlass::HostTensor source({1, kN}); + auto source_ref = source.host_ref(); + auto destination_ref = destination.host_ref(); for (int i = 0; i < kN; ++i) { - source.host_data()[i] = Source(i % 4); + source_ref.at({0, i}) = Source(i % Range); } source.sync_device(); @@ -84,9 +86,67 @@ void run_test(const char dest_name[], const char source_name[]) { destination.sync_host(); for (int i = 0; i < kN; ++i) { - EXPECT_TRUE(float(destination.host_data()[i]) == 
float(source.host_data()[i])) - << "Destination type: " << dest_name - << ", Source type: " << source_name + EXPECT_TRUE(float(destination_ref.at({0, i})) == float(source_ref.at({0, i}))) + << "Destination type: " << dest_name << " "<< float(destination_ref.at({0, i})) + << ", Source type: " << source_name << " " << float(source_ref.at({0, i})) + << ", Count: " << Count; + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +__global__ void convert_with_scale_factor( + cutlass::Array *destination, + cutlass::Array const *source, + cutlass::Array const *scale_factor) { + + cutlass::NumericArrayConverter convert; + + *destination = convert(*source, *scale_factor); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +void run_test_with_scalefactor(const char dest_name[], const char source_name[], const char scale_factor_name[]) { + const int kN = Count; + + dim3 grid(1, 1); + dim3 block(1, 1); + + cutlass::HostTensor destination({1, kN}); + cutlass::HostTensor source({1, kN}); + cutlass::HostTensor scale_factor({1, kN}); + auto source_ref = source.host_ref(); + auto destination_ref = destination.host_ref(); + auto scale_factor_ref = scale_factor.host_ref(); + + + for (int i = 0; i < kN; ++i) { + source_ref.at({0, i}) = Source(i % Range); + } + + for (int i = 0; i < kN; ++i) { + scale_factor_ref.at({0, i}) = ScaleFactor(1 + i % 8); + } + + source.sync_device(); + scale_factor.sync_device(); + + convert_with_scale_factor<<< grid, block >>>( + reinterpret_cast *>(destination.device_data()), + reinterpret_cast const *>(source.device_data()), + reinterpret_cast const *>(scale_factor.device_data()) + ); + + destination.sync_host(); + + for (int i = 0; i < kN; ++i) { + float ref = float(source_ref.at({0, i})) / float(scale_factor_ref.at({0, i})); + EXPECT_TRUE(float(destination_ref.at({0, i})) == ref) + << "Destination type: " << dest_name << " "<< float(destination_ref.at({0, i})) + << ", Source type: " << source_name << " " << float(source_ref.at({0, i})) << ", Count: " << Count; } } @@ -98,7 +158,16 @@ void run_test(const char dest_name[], const char source_name[]) { ///////////////////////////////////////////////////////////////////////////////////////////////// TEST(NumericConversion, f32_to_f16_rn) { - int const kN = 1; + constexpr int kN = 1; + using Source = float; + const char source_name[] = "float"; + using Destination = cutlass::half_t; + const char dest_name[] = "half_t"; + test::core::kernel::run_test(dest_name, source_name); +} + +TEST(NumericConversion, f32x2_to_f16x2_rn) { + constexpr int kN = 2; using Source = float; const char source_name[] = "float"; using Destination = cutlass::half_t; @@ -107,7 +176,7 @@ TEST(NumericConversion, f32_to_f16_rn) { } TEST(NumericConversion, f32x8_to_f16x8_rn) { - int const kN = 8; + constexpr int kN = 8; using Source = float; const char source_name[] = "float"; using Destination = cutlass::half_t; @@ -394,4 +463,50 @@ TEST(NumericConversion, fe5m2_to_bf16_array) { test::core::kernel::run_test(dest_name, source_name); } -///////////////////////////////////////////////////////////////////////////////////////////////// +// These are included as regression tests for a special case when N = 4. 
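// A hedged sketch of the pattern each regression test below follows. It assumes
// run_test is parameterized by the destination element type, the source element
// type, and the array width (with the source value range defaulted); that
// parameter order is an assumption for illustration and is not taken verbatim
// from this patch.
//
//   TEST(NumericConversion, example_f32_to_f16_array_4) {
//     int const kN = 4;
//     using Source = float;
//     const char source_name[] = "float";
//     using Destination = cutlass::half_t;
//     const char dest_name[] = "half_t";
//     test::core::kernel::run_test<Destination, Source, kN>(dest_name, source_name);
//   }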
+TEST(NumericConversion, int4b_t_to_fe5m2_t_array_4) { + int const kN = 4; + using Source = cutlass::int4b_t; + const char source_name[] = "int4b_t"; + using Destination = cutlass::float_e5m2_t; + const char dest_name[] = "float_e5m2_t"; + test::core::kernel::run_test(dest_name, source_name); +} + +TEST(NumericConversion, int_to_fe4m3_t_array_4) { + int const kN = 4; + using Source = int; + const char source_name[] = "int"; + using Destination = cutlass::float_e4m3_t; + const char dest_name[] = "float_e4m3_t"; + test::core::kernel::run_test(dest_name, source_name); +} + +TEST(NumericConversion, int2b_t_to_fe4m3_t_array_4) { + int const kN = 4; + using Source = cutlass::int2b_t; + const char source_name[] = "int2b_t"; + using Destination = cutlass::float_e4m3_t; + const char dest_name[] = "float_e4m3_t"; + test::core::kernel::run_test(dest_name, source_name); +} + +TEST(NumericConversion, fe5m2_t_to_double_array_4) { + int const kN = 4; + using Source = cutlass::float_e5m2_t; + const char source_name[] = "float_e5m2_t"; + using Destination = double; + const char dest_name[] = "double"; + test::core::kernel::run_test(dest_name, source_name); +} + +TEST(NumericConversion, int_to_fe4m3_t_array_32) { + int const kN = 32; + using Source = int; + const char source_name[] = "int"; + using Destination = cutlass::float_e4m3_t; + const char dest_name[] = "float_e4m3_t"; + test::core::kernel::run_test(dest_name, source_name); +} + + diff --git a/test/unit/cute/core/pointer.cpp b/test/unit/cute/core/pointer.cpp index 26ccb8723f..b1e33a3591 100644 --- a/test/unit/cute/core/pointer.cpp +++ b/test/unit/cute/core/pointer.cpp @@ -32,6 +32,7 @@ #include "cutlass_unit_test.h" #include + #include TEST(CuTe_core, Pointer) @@ -45,7 +46,7 @@ TEST(CuTe_core, Pointer) // Test T* overloads (T can be nonconst or const) { using T = float; - using expected_type = cute::gmem_ptr; + using expected_type = cute::gmem_ptr; T* p = nullptr; // explicit template argument @@ -58,7 +59,7 @@ TEST(CuTe_core, Pointer) } { using T = float const; - using expected_type = cute::gmem_ptr; + using expected_type = cute::gmem_ptr; T* p = nullptr; // explicit template argument @@ -74,7 +75,7 @@ TEST(CuTe_core, Pointer) // (these require an explicit template argument) { using T = float; - using expected_type = cute::gmem_ptr; + using expected_type = cute::gmem_ptr; void* p = nullptr; auto gmem_p0 = cute::make_gmem_ptr(p); @@ -82,7 +83,7 @@ TEST(CuTe_core, Pointer) } { using T = float const; - using expected_type = cute::gmem_ptr; + using expected_type = cute::gmem_ptr; void const* p = nullptr; auto gmem_p0 = cute::make_gmem_ptr(p); @@ -92,14 +93,14 @@ TEST(CuTe_core, Pointer) // Test nullptr_t overload. 
{ using T = float; - using expected_type = cute::gmem_ptr; + using expected_type = cute::gmem_ptr; auto gmem_p0 = cute::make_gmem_ptr(nullptr); static_assert(cute::is_same_v); } { using T = float const; - using expected_type = cute::gmem_ptr; + using expected_type = cute::gmem_ptr; auto gmem_p0 = cute::make_gmem_ptr(nullptr); static_assert(cute::is_same_v); diff --git a/test/unit/cute/hopper/tma_load.cu b/test/unit/cute/hopper/tma_load.cu index 335d6091b7..c850d92cef 100644 --- a/test/unit/cute/hopper/tma_load.cu +++ b/test/unit/cute/hopper/tma_load.cu @@ -416,7 +416,6 @@ TEST(SM90_CuTe_Hopper, Tma_Load_InternalType) test_tma_load(gmem_layout, smem_layout); test_tma_load< float, uint64_t>(gmem_layout, smem_layout); test_tma_load(gmem_layout, smem_layout); - } // Complex is 128bit, which the TMA has no concept of diff --git a/test/unit/cute/hopper/tma_load_testbed.hpp b/test/unit/cute/hopper/tma_load_testbed.hpp index ce8aa8dd84..5e01345a5d 100644 --- a/test/unit/cute/hopper/tma_load_testbed.hpp +++ b/test/unit/cute/hopper/tma_load_testbed.hpp @@ -43,7 +43,7 @@ namespace cutlass::test { template struct SharedStorage { - cute::array_aligned> smem; + cute::ArrayEngine> smem; cute::uint64_t tma_load_mbar[1]; }; @@ -62,26 +62,26 @@ tma_test_device_cute(T const* g_in, T* g_out, extern __shared__ char shared_memory[]; using SharedStorage = SharedStorage; SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + // Construct SMEM tensor - Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + Tensor sA = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) // Shared memory barriers use 64bits in SMEM for synchronization uint64_t* tma_load_mbar = shared_storage.tma_load_mbar; // TMA requires special handling of strides to deal with coord codomain mapping // Represent the full tensors -- get these from TMA Tensor mA = tma.get_tma_tensor(shape(gmem_layout)); - Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); + Tensor mB = make_tensor(make_gmem_ptr(g_out), gmem_layout); constexpr int R = rank_v; - Tensor gA = local_tile(mA, cta_tiler, repeat(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) - Tensor gB = local_tile(mB, cta_tiler, repeat(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) 
// // Prepare the TMA_LOAD // auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice - Tensor tAgA_x = cta_tma.partition_S(gA); // (TMA,TMA_M,TMA_N,REST_M,REST_N) Tensor tAsA_x = cta_tma.partition_D(sA); // (TMA,TMA_M,TMA_N) @@ -89,11 +89,13 @@ tma_test_device_cute(T const* g_in, T* g_out, if (thread0()) { print(tma); print("TILE : "); print(cta_tiler); print("\n"); - print(" mA : "); print( mA.data()); print(" o "); print( mA.layout()); print("\n"); - print(" gA : "); print( gA.data()); print(" o "); print( gA.layout()); print("\n"); - print("tAgA_x: "); print(tAgA_x.data()); print(" o "); print(tAgA_x.layout()); print("\n"); - print(" sA : "); print( sA.data()); print(" o "); print( sA.layout()); print("\n"); - print("tAsA_x: "); print(tAsA_x.data()); print(" o "); print(tAsA_x.layout()); print("\n"); + print(" mA : "); print( mA); print("\n"); + print(" mB : "); print( mB); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sA : "); print( sA); print("\n"); + print("tAgA_x: "); print(tAgA_x); print("\n"); + print("tAsA_x: "); print(tAsA_x); print("\n"); } #endif @@ -111,9 +113,9 @@ tma_test_device_cute(T const* g_in, T* g_out, #if 0 if (thread0()) { - print("tAgA : "); print(tAgA.data()); print(" o "); print(tAgA.layout()); print("\n"); - print("tAsA : "); print(tAsA.data()); print(" o "); print(tAsA.layout()); print("\n"); - print("tBgB : "); print(tBgB.data()); print(" o "); print(tBgB.layout()); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); } #endif @@ -121,7 +123,7 @@ tma_test_device_cute(T const* g_in, T* g_out, for (int stage = 0; stage < size<1>(tAgA); ++stage) { // Set the bytes transferred in this TMA transaction (may involve multiple issues) - constexpr int kTmaTransactionBytes = size(sA) * sizeof_bits_v / 8; + constexpr int kTmaTransactionBytes = sizeof(ArrayEngine); if (threadIdx.x == 0) { @@ -146,9 +148,15 @@ tma_test_device_cute(T const* g_in, T* g_out, // print_tensor(sA); //} - for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { - tBgB(i,stage) = sA(i); + // for (int i = threadIdx.x; i < size(sA); i += blockDim.x) { + // tBgB(i,stage) = sA(i); + // } + + // Subbyte elements could cause race conditions, so be even more conservative + if (thread0()) { + copy(sA, tBgB(_,stage)); } + __syncthreads(); } } @@ -161,30 +169,38 @@ test_tma_load(CopyOp const& copy_op, CTA_Tile const& cta_tile) { using namespace cute; - thrust::host_vector h_in(cosize(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i % 13); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - Tensor gA = make_tensor(d_in.data().get(), gmem_layout); + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast(i % 13); } + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), char(-1)); + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_in.data())), gmem_layout); auto tma = make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); //print(tma); + // Launch int smem_size = int(sizeof(SharedStorage)); tma_test_device_cute<<<1, 128, smem_size>>>( - 
thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), tma, cta_tile, gmem_layout, smem_layout); - thrust::host_vector h_out = d_out; + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); - // Validate the results, and tolerate the first 3 errors: - Tensor hA_in = make_tensor(h_in.data(), gmem_layout); - Tensor hA_out = make_tensor(h_out.data(), gmem_layout); + // Validate the results. Print only the first 3 errors. int count = 3; - for (int i = 0; i < cute::size(gmem_layout) && count > 0; ++i) { + for (int i = 0; i < size(hA_out) && count > 0; ++i) { EXPECT_EQ(hA_in(i), hA_out(i)); if (hA_in(i) != hA_out(i)) { --count; diff --git a/test/unit/cute/hopper/tma_store_testbed.hpp b/test/unit/cute/hopper/tma_store_testbed.hpp index 990d625dd6..47a31d9b8a 100644 --- a/test/unit/cute/hopper/tma_store_testbed.hpp +++ b/test/unit/cute/hopper/tma_store_testbed.hpp @@ -43,7 +43,7 @@ namespace cutlass::test { template struct SharedStorage { - cute::array_aligned> smem; + cute::ArrayEngine> smem; }; #if CUDA_12_0_SM90_FEATURES_SUPPORTED @@ -61,24 +61,24 @@ tma_test_device_cute(T const* g_in, T* g_out, extern __shared__ char shared_memory[]; using SharedStorage = SharedStorage; SharedStorage& shared_storage = *reinterpret_cast(shared_memory); + // Construct SMEM tensor - Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem.data()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) + Tensor sB = make_tensor(make_smem_ptr(shared_storage.smem.begin()), smem_layout); // (CTA_TILE_M,CTA_TILE_N,...) // TMA requires special handling of strides to deal with coord codomain mapping // Represent the full tensors -- get these from TMA - Tensor mA = make_tensor(make_gmem_ptr(g_in), gmem_layout); + Tensor mA = make_tensor(make_gmem_ptr(g_in), gmem_layout); Tensor mB = tma.get_tma_tensor(shape(gmem_layout)); constexpr int R = rank_v; - Tensor gA = local_tile(mA, cta_tiler, repeat(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) - Tensor gB = local_tile(mB, cta_tiler, repeat(_)); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gA = flat_divide(mA, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) + Tensor gB = flat_divide(mB, cta_tiler); // (CTA_TILE_M,CTA_TILE_N,...REST_M,REST_N,...) 
// // Prepare the TMA_STORE // auto cta_tma = tma.get_slice(Int<0>{}); // CTA slice - Tensor tBsB_x = cta_tma.partition_S(sB); // (TMA,TMA_M,TMA_N) Tensor tBgB_x = cta_tma.partition_D(gB); // (TMA,TMA_M,TMA_N,REST_M,REST_N) @@ -121,11 +121,17 @@ tma_test_device_cute(T const* g_in, T* g_out, // Read in trivially gmem -> smem // - for (int i = threadIdx.x; i < size(sB); i += blockDim.x) { - sB(i) = tAgA(i,stage); + // for (int i = threadIdx.x; i < size(sB); i += blockDim.x) { + // sB(i) = tAgA(i,stage); + // } + + // Subbyte elements could cause race conditions, so be even more conservative + if (thread0()) { + copy(tAgA(_,stage), sB); } __syncthreads(); + cute::cp_async_wait<0>(); // // Perform the TMA_STORE @@ -148,30 +154,38 @@ test_tma_store(CopyOp const& copy_op, CTA_Tile const& cta_tile) { using namespace cute; - thrust::host_vector h_in(cosize(gmem_layout)); - for (int i = 0; i < h_in.size(); ++i) { h_in[i] = T(i % 13); } - thrust::device_vector d_in = h_in; - thrust::device_vector d_out(h_in.size(), T(-1)); - Tensor gA = make_tensor(d_out.data().get(), gmem_layout); + // Allocate and initialize host test data + size_t N = ceil_div(cosize(gmem_layout) * sizeof_bits::value, 8); + thrust::host_vector h_in(N); + Tensor hA_in = make_tensor(recast_ptr(h_in.data()), gmem_layout); + for (int i = 0; i < size(hA_in); ++i) { hA_in(i) = static_cast(i % 13); } + + // Allocate and initialize device test data + thrust::device_vector d_in = h_in; + thrust::device_vector d_out(h_in.size(), char(-1)); + + // Create TMA for this device Tensor + Tensor gA = make_tensor(make_gmem_ptr(raw_pointer_cast(d_out.data())), gmem_layout); auto tma = make_tma_copy(copy_op, gA, smem_layout, cta_tile, Int<1>{}); //print(tma); + // Launch int smem_size = int(sizeof(SharedStorage)); tma_test_device_cute<<<1, 128, smem_size>>>( - thrust::raw_pointer_cast(d_in.data()), - thrust::raw_pointer_cast(d_out.data()), + reinterpret_cast(raw_pointer_cast(d_in.data())), + reinterpret_cast (raw_pointer_cast(d_out.data())), tma, cta_tile, gmem_layout, smem_layout); - thrust::host_vector h_out = d_out; + // Copy results back to host + thrust::host_vector h_out = d_out; + Tensor hA_out = make_tensor(recast_ptr(h_out.data()), gmem_layout); - // Validate the results, and tolerate the first 3 errors: - Tensor hA_in = make_tensor(h_in.data(), gmem_layout); - Tensor hA_out = make_tensor(h_out.data(), gmem_layout); + // Validate the results. Print only the first 3 errors. 
int count = 3; - for (int i = 0; i < cute::size(gmem_layout) && count > 0; ++i) { + for (int i = 0; i < size(hA_out) && count > 0; ++i) { EXPECT_EQ(hA_in(i), hA_out(i)); if (hA_in(i) != hA_out(i)) { --count; diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 0bd60ed936..5bab6b7be0 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -242,6 +242,24 @@ cutlass_test_unit_add_executable( sm80_gemm_tf32_tf32_f32_tensor_op_f32.cu ) + +cutlass_test_unit_add_executable( + cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80 + + BATCH_SOURCES ON + BATCH_SIZE 4 + + # Upcast on Operand A + gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu + + # Upcast on Operand B + gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu +) + cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_sm90 @@ -272,10 +290,22 @@ cutlass_test_unit_add_executable( BATCH_SOURCES ON BATCH_SIZE 4 - sm90_gemm_f16_f16_f16_alignx_tensor_op.cu + sm90_gemm_f16_f16_f16_alignx_tensor_op_f32.cu + sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized.cu + sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized_cooperative.cu + sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized_pingpong.cu sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu + sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized.cu + sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu + sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_pingpong.cu sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu + sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized.cu + sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized_cooperative.cu + sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized_pingpong.cu sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu + sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized.cu + sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_cooperative.cu + sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_pingpong.cu ) # Fused epilogue tests @@ -298,7 +328,6 @@ cutlass_test_unit_add_executable( sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_dag.cu sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_dag.cu ) - cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_cluster_multicast_sm90 @@ -311,7 +340,6 @@ cutlass_test_unit_add_executable( sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu sm90_gemm_f8_f8_f32_tensor_op_f32_cluster_warpspecialized_cooperative.cu ) - cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_gmma_rs_warpspecialized_sm90 @@ -319,6 +347,8 @@ cutlass_test_unit_add_executable( BATCH_SIZE 4 sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu + sm90_gemm_f8_f8_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu + sm90_gemm_f16_f16_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu ) cutlass_test_unit_add_executable( @@ -341,23 +371,6 @@ cutlass_test_unit_add_executable( sm80_gemm_f16_f16_f32_tensor_op_f32.cu ) -cutlass_test_unit_add_executable( - cutlass_test_unit_gemm_device_mixed_input_tensorop_sm80 - - BATCH_SOURCES ON - BATCH_SIZE 4 
- - # Upcast on Operand A - gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu - gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu - gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu - - # Upcast on Operand B - gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu - gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu - gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu -) - cutlass_test_unit_add_executable( cutlass_test_unit_gemm_device_tensorop_f64 diff --git a/test/unit/gemm/device/gemm_testbed_3x.hpp b/test/unit/gemm/device/gemm_testbed_3x.hpp index 156913fdf3..6895380411 100644 --- a/test/unit/gemm/device/gemm_testbed_3x.hpp +++ b/test/unit/gemm/device/gemm_testbed_3x.hpp @@ -49,16 +49,18 @@ #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/host/gett.hpp" - #include "testbed_utils.h" #include "cutlass/kernel_hardware_info.hpp" #include "cutlass/layout/matrix.h" #include "cutlass/matrix_coord.h" #include "cutlass/gemm/gemm.h" +#include "cutlass/fast_math.h" +#include "cutlass/platform/platform.h" #include "cutlass/epilogue/fusion/operations.hpp" #include "cute/int_tuple.hpp" +#include "cute/layout.hpp" namespace test { namespace gemm { @@ -68,9 +70,9 @@ namespace device { namespace detail{ -// Helper classes that take default data type when +// Helper classes that take default data type when // the Gemm::EpilogueOutputOp does not have ElementCompute -// and ElementScalar. +// and ElementScalar. // (e.g. when Sm90TreeVisitor is used as FusionCallbacks) template struct ElementComputeType { @@ -138,6 +140,34 @@ class Iterations { int iterations_ = 20; }; +// The maxium swizzle size to use +// +// This class, like Splits above makes it harder to confuse +// the order of arguments of the various run(...) functions in this file. 
+class MaxSwizzleSize { +public: + MaxSwizzleSize() = default; + + template && + !std::is_same_v)) > + explicit MaxSwizzleSize(IntegralNotBool max_swizzle_size) : max_swizzle_size_(max_swizzle_size) {} + explicit operator int() const { return max_swizzle_size_; } +private: + int max_swizzle_size_ = 1; +}; + +template +auto make_iterator(T* ptr) { + using namespace cute; + if constexpr (is_subbyte_v) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} + template < typename Gemm, template class ActivationFunctor_ = cutlass::epilogue::thread::Identity @@ -161,6 +191,8 @@ struct TestbedImpl { using ElementScalar = typename ElementScalarType::Type; using ActivationFunctor = ActivationFunctor_; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + static_assert(rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); static_assert(rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); @@ -190,6 +222,7 @@ struct TestbedImpl { using LayoutTagB = cutlass::detail::StrideToLayoutTagB_t; using LayoutTagC = cutlass::detail::StrideToLayoutTagA_t; using LayoutTagD = cutlass::detail::StrideToLayoutTagA_t; + using LayoutTagVector = cutlass::layout::PackedVectorLayout; /// Initialization StrideA stride_a; @@ -323,10 +356,10 @@ struct TestbedImpl { // 2.x host tensor does not natively contain a batch stride or coord, so we spoof if by folding it into the outer mode auto a_coord = cutlass::make_Coord(M * L, K); auto c_coord = cutlass::make_Coord(M * L, N); - // Cutlass has Row/Col major refers to MxK times KxN matrix product, + // Cutlass has Row/Col major refers to MxK times KxN matrix product, // so the HostTensorB should be treated as KxN in "coord"'s view auto b_coord = cutlass::make_Coord(K, N * L); - + tensor_A.resize(a_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(a_coord, stride_factor_A)); tensor_B.resize(b_coord, cutlass::layout::Affine2Layout_Factory::layout_factory(b_coord, stride_factor_B)); @@ -387,7 +420,7 @@ struct TestbedImpl { std::ofstream file(fname.str()); file << "problem: " << ' ' << M << "x" << N << "x" << K << ", Batch count = " << L - << ", alpha: " << float(alpha) << ", beta: " << float(beta) << "\n\n"; + << ", alpha: " << alpha << ", beta: " << beta << "\n\n"; file << "A =\n" << tensor_A.host_view() @@ -404,7 +437,7 @@ struct TestbedImpl { bool verify( ProblemShapeType problem_size, ElementScalar alpha, - ElementScalar beta) + ElementScalar beta) { auto problem_shape_MNKL = cute::append<4>(problem_size, 1); auto M = cute::size<0>(problem_shape_MNKL); @@ -412,13 +445,13 @@ struct TestbedImpl { auto K = cute::size<2>(problem_shape_MNKL); auto L = cute::size<3>(problem_shape_MNKL); - auto A = cute::make_tensor(tensor_A.host_data(), + auto A = cute::make_tensor(detail::make_iterator(tensor_A.host_data()), cute::make_layout(cute::make_shape(M, K, L), stride_a)); - auto B = cute::make_tensor(tensor_B.host_data(), + auto B = cute::make_tensor(detail::make_iterator(tensor_B.host_data()), cute::make_layout(cute::make_shape(N, K, L), stride_b)); - auto C = cute::make_tensor(tensor_C.host_data(), + auto C = cute::make_tensor(detail::make_iterator(tensor_C.host_data()), cute::make_layout(cute::make_shape(M, N, L), stride_c)); - auto D = cute::make_tensor(reference_D.host_data(), + auto D = cute::make_tensor(detail::make_iterator(reference_D.host_data()), cute::make_layout(cute::make_shape(M, N, L), stride_d)); auto Bias = cute::make_tensor(static_cast(nullptr), 
cute::make_layout(cute::make_shape(M, cute::_1{}))); @@ -451,7 +484,6 @@ struct TestbedImpl { }; cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); - return compare_reference(problem_shape_MNKL, alpha, beta); } @@ -529,8 +561,10 @@ struct TestbedImpl { ElementScalar alpha = ElementScalar(1), ElementScalar beta = ElementScalar(0), bool profiling = false, - detail::Iterations iterations = Iterations{}, - detail::Splits splits = Splits{}) + detail::Iterations iterations = detail::Iterations{}, + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, + detail::Splits splits = detail::Splits{}) { // Fail test if insufficient CUDA device if (!sufficient()) { @@ -557,7 +591,10 @@ struct TestbedImpl { typename Gemm::GemmKernel::TileScheduler::Arguments scheduler_args; if constexpr (std::is_same_v) { - scheduler_args = { static_cast(splits) }; + scheduler_args = { static_cast(splits), static_cast(max_swizzle), raster_order }; + } + else { + scheduler_args = { static_cast(max_swizzle), raster_order }; } // DefaultEpilogue @@ -613,7 +650,7 @@ struct TestbedImpl { // bool passed = this->verify(problem_size, alpha, beta); if (!passed) { - std::cout << "Error : Failed : with alpha: " << float(alpha) << ", beta: " << float(beta) + std::cout << "Error : Failed : with alpha: " << alpha << ", beta: " << beta << "\n"; } @@ -648,6 +685,8 @@ struct Testbed3x { using LayoutTagC = typename TestBedImpl::LayoutTagC; using LayoutTagD = typename TestBedImpl::LayoutTagD; + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + // Detail Implementation TestBedImpl impl_; @@ -661,7 +700,7 @@ struct Testbed3x { uint64_t seed_ = TestBedImpl::kDefaultSeed) : impl_(init_A_, init_B_, init_C_, seed_) {} - Testbed3x( + Testbed3x( typename LayoutTagA::Stride stride_factor_A_, typename LayoutTagB::Stride stride_factor_B_, typename LayoutTagC::Stride stride_factor_C_, @@ -684,12 +723,14 @@ struct Testbed3x { typename TestBedImpl::ProblemShapeType problem_size, ElementScalar alpha = ElementScalar(1), ElementScalar beta = ElementScalar(0), + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, detail::Splits splits = detail::Splits{}, bool profiling = false, detail::Iterations iterations = detail::Iterations{}) { return impl_.run( - problem_size, alpha, beta, profiling, iterations, splits + problem_size, alpha, beta, profiling, iterations, raster_order, max_swizzle, splits ); } }; @@ -722,13 +763,15 @@ struct Testbed3xFusionOperation { using StrideD = typename Kernel::StrideD; using ProblemShapeType = typename Kernel::ProblemShape; using ElementAccumulator = typename Kernel::ElementAccumulator; - + // // FusionOperation derived types/queries // using FusionOp = typename Gemm::EpilogueOutputOp; static_assert(cute::is_base_of_v); + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + // fusion types are potentially void if the fusion is not supported // helper so we don't try to construct HostTensor with void type template @@ -744,11 +787,17 @@ struct Testbed3xFusionOperation { cutlass::epilogue::thread::Identity>; static constexpr bool IsBiasEnabled = FusionOp::IsPerRowBiasSupported; + static constexpr bool IsDeBiasEnabled = FusionOp::IsDePerRowBiasSupported; static constexpr bool IsPerRowScaleEnabled = FusionOp::IsPerRowScaleSupported; 
static constexpr bool IsScaleFactorEnabled = FusionOp::IsScaleFactorSupported; - static constexpr bool IsAuxEnabled = FusionOp::IsAuxOutSupported; - static constexpr bool IsAbsMaxEnabled = FusionOp::IsAbsMaxSupported; - + static constexpr bool IsAuxInEnabled = FusionOp::IsAuxInSupported; + static constexpr bool IsAuxOutEnabled = FusionOp::IsAuxOutSupported; + static constexpr bool IsAbsMaxEnabledD = FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); + static constexpr bool IsAbsMaxEnabledAux = IsAuxOutEnabled && FusionOp::IsAbsMaxSupported && + (cute::is_same_v || + cute::is_same_v); // Legacy support for deprecated bias-elementwise collective, will be removed next release using EpiloguePolicy = typename Epilogue::DispatchPolicy; static constexpr bool IsLegacy = @@ -773,6 +822,7 @@ struct Testbed3xFusionOperation { cutlass::HostTensor tensor_Aux; cutlass::gemm::TagToStrideC_t< LayoutTagAux > stride_Aux; // References + cutlass::HostTensor reference_dbias; cutlass::HostTensor reference_Aux; cutlass::HostTensor reference_abs_max_Aux; cutlass::HostTensor reference_abs_max_D; @@ -791,12 +841,6 @@ struct Testbed3xFusionOperation { // Random distribution with which to initialize the bias vector cutlass::Distribution::Kind init_bias = cutlass::Distribution::Uniform; - // Factors used for calculating relative equality. These default - // values are borrowed from those used by default in the CUTLASS - // profiler for performing relative equality checks. - float epsilon = 0.05f; - float nonzero_floor = 1.0f / 256.0f; - // // Methods // @@ -853,7 +897,7 @@ struct Testbed3xFusionOperation { else { beta.resize(col_vector_coord); EXPECT_TRUE(impl_.initialize_tensor(beta.host_view(), init_scale, impl_.seed + 2024)); - } + } } else { alpha.resize(scalar_coord, use_device_scalars); @@ -885,13 +929,34 @@ struct Testbed3xFusionOperation { bias.sync_device(); } - if constexpr (IsAbsMaxEnabled) { + if constexpr (IsDeBiasEnabled) { + bias.resize(col_vector_coord); + reference_dbias.resize(col_vector_coord); + cutlass::reference::host::TensorFill(bias.host_view(), ElementBias(0)); + cutlass::reference::host::TensorFill(reference_dbias.host_view(), ElementBias(0)); + bias.sync_device(); + } + + if constexpr (IsAbsMaxEnabledD) { abs_max_D.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_D.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); abs_max_D.sync_device(); reference_abs_max_D.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_D.host_view(), ElementAmax(0)); } - if constexpr (IsAuxEnabled) { + if constexpr (IsAuxInEnabled) { + auto aux_coord = cutlass::make_Coord(M * L, N); + auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); + tensor_Aux.resize(aux_coord, aux_layout); + EXPECT_TRUE(impl_.initialize_tensor(tensor_Aux.host_view(), impl_.init_C, impl_.seed + 2023)); + tensor_Aux.sync_device(); + stride_Aux = cutlass::make_cute_packed_stride(cutlass::gemm::TagToStrideC_t{}, cute::make_shape(M, N, L)); + } + + if constexpr (IsAuxOutEnabled) { auto aux_coord = cutlass::make_Coord(M * L, N); auto aux_layout = cutlass::layout::Affine2Layout_Factory::layout_factory(aux_coord, typename LayoutTagAux::Stride{}); tensor_Aux.resize(aux_coord, aux_layout); @@ -905,10 +970,14 @@ struct Testbed3xFusionOperation { scale_Aux.sync_device(); } - if constexpr (IsAbsMaxEnabled) { + if constexpr 
(IsAbsMaxEnabledAux) { abs_max_Aux.resize(scalar_coord); + // ensure in-place device reductions perform their own initialization + cutlass::reference::host::TensorFill(abs_max_Aux.host_view(), + CUTLASS_STL_NAMESPACE::numeric_limits::max()); abs_max_Aux.sync_device(); reference_abs_max_Aux.resize(scalar_coord); + cutlass::reference::host::TensorFill(reference_abs_max_Aux.host_view(), ElementAmax(0)); } } @@ -922,9 +991,16 @@ struct Testbed3xFusionOperation { cutlass::TensorView const& lhs, cutlass::TensorView const& rhs) const { + // Factors used for calculating relative equality. CUTLASS's relative-equality + // checks in include/cutlass/relatively_equal.h are inspired by + // https://floating-point-gui.de/errors/comparison/. This reference suggests using + // the minimum normal value of a given type as the nonzero_floor. + Element epsilon(0.1f); + Element nonzero_floor(std::numeric_limits::min()); + if (check_relative_equality) { return cutlass::reference::host::TensorRelativelyEquals( - lhs, rhs, Element(epsilon), Element(nonzero_floor)); + lhs, rhs, epsilon, nonzero_floor); } else { return cutlass::reference::host::TensorEquals(lhs, rhs); @@ -933,6 +1009,7 @@ struct Testbed3xFusionOperation { /// Compares computed reference with device reference and outputs to a file if incorrect bool compare_reference(cute::Shape problem_shape_MNKL) { + auto [M, N, K, L] = problem_shape_MNKL; auto coord_0 = cutlass::make_Coord(0); @@ -947,17 +1024,24 @@ struct Testbed3xFusionOperation { } bool passed = equality_check(impl_.reference_D.host_view(), impl_.tensor_D.host_view()); - if constexpr (IsAbsMaxEnabled) { + if constexpr (IsAbsMaxEnabledD) { abs_max_D.sync_host(); passed &= equality_check(reference_abs_max_D.host_view(), abs_max_D.host_view()); } - if constexpr (IsAuxEnabled) { + if constexpr (IsDeBiasEnabled) { + bias.sync_host(); + EXPECT_GT(cutlass::reference::host::TensorNorm(bias.host_view()), 0); + EXPECT_GT(cutlass::reference::host::TensorNorm(reference_dbias.host_view()), 0); + passed &= equality_check(reference_dbias.host_view(), bias.host_view()); + } + + if constexpr (IsAuxOutEnabled) { tensor_Aux.sync_host(); EXPECT_GT(cutlass::reference::host::TensorNorm(tensor_Aux.host_view()), 0); EXPECT_GT(cutlass::reference::host::TensorNorm(reference_Aux.host_view()), 0); passed &= equality_check(reference_Aux.host_view(), tensor_Aux.host_view()); - if constexpr (IsAbsMaxEnabled) { + if constexpr (IsAbsMaxEnabledAux) { abs_max_Aux.sync_host(); passed &= equality_check(reference_abs_max_Aux.host_view(), abs_max_Aux.host_view()); } @@ -990,7 +1074,7 @@ struct Testbed3xFusionOperation { } file << "\n\n"; - if constexpr (IsAbsMaxEnabled) { + if constexpr (IsAbsMaxEnabledD) { file << "scale_d: " << float(scale_D.at(coord_0)); file << "\nReference abs_max_D :"; file << " " << float(reference_abs_max_D.at(coord_0)); @@ -998,15 +1082,16 @@ struct Testbed3xFusionOperation { file << "\nComputed abs_max_D :"; file << " " << float(abs_max_D.at(coord_0)); file << "\n\n"; - if constexpr (IsAuxEnabled) { - file << "scale_aux: " << float(scale_Aux.at(coord_0)); - file << "\nReference abs_max_Aux :"; - file << " " << float(reference_abs_max_Aux.at(coord_0)); - - file << "\nComputed abs_max_Aux :"; - file << " " << float(abs_max_Aux.at(coord_0)); - file << "\n\n"; - } + } + + if constexpr (IsAbsMaxEnabledAux) { + file << "scale_aux: " << float(scale_Aux.at(coord_0)); + file << "\nReference abs_max_Aux :"; + file << " " << float(reference_abs_max_Aux.at(coord_0)); + + file << "\nComputed abs_max_Aux :"; + file << " 
" << float(abs_max_Aux.at(coord_0)); + file << "\n\n"; } file @@ -1018,7 +1103,16 @@ struct Testbed3xFusionOperation { file << "\n\nBias = \n" << bias.host_view(); } - if constexpr (IsAuxEnabled) { + if constexpr (IsAuxInEnabled) { + file << "\n\nAux Input = \n" << tensor_Aux.host_view(); + } + + if constexpr (IsDeBiasEnabled) { + file << "\n\nReference dBias = \n" << reference_dbias.host_view(); + file << "\n\nComputed dBias = \n" << bias.host_view(); + } + + if constexpr (IsAuxOutEnabled) { file << "\n\nReference Aux =\n" << reference_Aux.host_view() << "\n\nComputed Aux =\n" << tensor_Aux.host_view(); @@ -1041,21 +1135,21 @@ struct Testbed3xFusionOperation { auto L = cute::get<3>(problem_shape_MNKL); auto coord_0 = cutlass::make_Coord(0); - auto A = cute::make_tensor(impl_.tensor_A.host_data(), + auto A = cute::make_tensor(detail::make_iterator(impl_.tensor_A.host_data()), cute::make_layout(cute::make_shape(M, K, L), impl_.stride_a)); - auto B = cute::make_tensor(impl_.tensor_B.host_data(), + auto B = cute::make_tensor(detail::make_iterator(impl_.tensor_B.host_data()), cute::make_layout(cute::make_shape(N, K, L), impl_.stride_b)); - auto C = cute::make_tensor(impl_.tensor_C.host_data(), + auto C = cute::make_tensor(detail::make_iterator(impl_.tensor_C.host_data()), cute::make_layout(cute::make_shape(M, N, L), impl_.stride_c)); - auto D = cute::make_tensor(impl_.reference_D.host_data(), + auto D = cute::make_tensor(detail::make_iterator(impl_.reference_D.host_data()), cute::make_layout(cute::make_shape(M, N, L), impl_.stride_d)); - auto Bias = cute::make_tensor(bias.host_data(), + auto Bias = cute::make_tensor(detail::make_iterator(IsDeBiasEnabled ? reference_dbias.host_data() : bias.host_data()), cute::make_layout(cute::make_shape(M, cute::_1{}))); - auto Aux = cute::make_tensor(reference_Aux.host_data(), + auto Aux = cute::make_tensor(detail::make_iterator(IsAuxInEnabled ? 
tensor_Aux.host_data() : reference_Aux.host_data()), cute::make_layout(cute::make_shape(M, N, L), stride_Aux)); - auto Valpha = cute::make_tensor(alpha.host_data(), + auto Valpha = cute::make_tensor(detail::make_iterator(alpha.host_data()), cute::make_layout(cute::make_shape(M, cute::_1{}))); - auto Vbeta = cute::make_tensor(beta.host_data(), + auto Vbeta = cute::make_tensor(detail::make_iterator(beta.host_data()), cute::make_layout(cute::make_shape(M, cute::_1{}))); cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; @@ -1086,20 +1180,24 @@ struct Testbed3xFusionOperation { epilogue_params.scale_d = scale_D.at(coord_0); } - if constexpr (IsBiasEnabled) { + if constexpr (IsBiasEnabled or IsDeBiasEnabled) { epilogue_params.Bias = Bias; } - if constexpr (IsAbsMaxEnabled) { + if constexpr (IsAbsMaxEnabledD) { epilogue_params.abs_max_D = reference_abs_max_D.host_data(); } - if constexpr (IsAuxEnabled) { + if constexpr (IsAuxInEnabled) { + epilogue_params.Aux = Aux; + } + + if constexpr (IsAuxOutEnabled) { epilogue_params.Aux = Aux; if constexpr (IsScaleFactorEnabled) { epilogue_params.scale_aux = scale_Aux.at(coord_0); } - if constexpr (IsAbsMaxEnabled) { + if constexpr (IsAbsMaxEnabledAux) { epilogue_params.abs_max_Aux = reference_abs_max_Aux.host_data(); } } @@ -1121,6 +1219,8 @@ struct Testbed3xFusionOperation { ProblemShapeType problem_size, ElementScalar alpha_ = ElementScalar(1), ElementScalar beta_ = ElementScalar(0), + RasterOrderOptions raster_order = RasterOrderOptions::Heuristic, + detail::MaxSwizzleSize max_swizzle = detail::MaxSwizzleSize{}, detail::Splits splits = detail::Splits{}, bool profiling = false, detail::Iterations iterations = detail::Iterations{}) @@ -1136,6 +1236,8 @@ struct Testbed3xFusionOperation { typename Gemm::Arguments arguments; cutlass::KernelHardwareInfo hw_info; + cudaDeviceProp prop; + hw_info.device_id = 0; if (not profiling) { impl_.sm_count = min(impl_.MaxSmCount, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id)); @@ -1146,6 +1248,8 @@ struct Testbed3xFusionOperation { hw_info.sm_count = impl_.sm_count; } + cudaGetDeviceProperties(&prop, hw_info.device_id); + /// Initializes data structures /// A/B/C/D Tensor initialize(problem_size, alpha_, beta_); @@ -1172,7 +1276,7 @@ struct Testbed3xFusionOperation { hw_info, scheduler_args }; - + auto coord_0 = cutlass::make_Coord(0); if constexpr (IsLegacy) { arguments.epilogue.thread = { @@ -1186,7 +1290,7 @@ struct Testbed3xFusionOperation { } else { auto &fusion_args = arguments.epilogue.thread; - + fusion_args.alpha = alpha.at(coord_0); fusion_args.beta = beta.at(coord_0); fusion_args.alpha_ptr = alpha.device_data(); @@ -1207,6 +1311,10 @@ struct Testbed3xFusionOperation { fusion_args.bias_ptr = bias.device_data(); } + if constexpr (IsDeBiasEnabled) { + fusion_args.dbias_ptr = bias.device_data(); + } + // example of how to set kernel activation arguments if constexpr (cute::is_same_v>) { // see ActivationFunctor::Arguments in activation.h for definition @@ -1214,18 +1322,23 @@ struct Testbed3xFusionOperation { fusion_args.activation.scale = ElementCompute(1); } - if constexpr (IsAbsMaxEnabled) { + if constexpr (IsAbsMaxEnabledD) { fusion_args.amax_D_ptr = abs_max_D.device_data(); } - if constexpr (IsAuxEnabled) { + if constexpr (IsAuxInEnabled) { + fusion_args.aux_ptr = tensor_Aux.device_data(); + fusion_args.dAux = stride_Aux; + } + + if constexpr (IsAuxOutEnabled) { fusion_args.aux_ptr = tensor_Aux.device_data(); fusion_args.dAux = stride_Aux; if constexpr 
(IsScaleFactorEnabled) { fusion_args.scale_aux = scale_Aux.at(coord_0); fusion_args.scale_aux_ptr = scale_Aux.device_data(); } - if constexpr (IsAbsMaxEnabled) { + if constexpr (IsAbsMaxEnabledAux) { fusion_args.amax_aux_ptr = abs_max_Aux.device_data(); } } @@ -1277,6 +1390,7 @@ struct Testbed3xFusionOperation { } }; + ///////////////////////////////////////////////////////////////////////////////////////////////// template < @@ -1311,29 +1425,39 @@ bool TestAll(double alpha = 1.0, double beta = 0.0, Testbed testbed = {}) { problem_splits.push_back(Stages + 1); } + using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90::RasterOrderOptions; + std::vector raster_orders = {RasterOrderOptions::AlongM, RasterOrderOptions::AlongN}; + std::vector max_swizzle_sizes = {1, 4}; + bool passed = true; for (int m : problem_size_m) { for (int n : problem_size_n) { for (int k : problem_size_k) { - for (int splits : problem_splits) { - ProblemShapeType problem_size; - if constexpr (cute::rank(ProblemShapeType{}) == 4) { - problem_size = ProblemShapeType{m, n, k, /* l */ 1}; - } - else { - problem_size = ProblemShapeType{m, n, k}; - } - - passed = testbed.run( - problem_size, - cutlass::from_real(alpha), - cutlass::from_real(beta), - detail::Splits(splits) - ); - - if (!passed) { - return false; + for (auto raster_order : raster_orders) { + for (int max_swizzle_size : max_swizzle_sizes) { + for (int splits : problem_splits) { + ProblemShapeType problem_size; + if constexpr (cute::rank(ProblemShapeType{}) == 4) { + problem_size = ProblemShapeType{m, n, k, /* l */ 1}; + } + else { + problem_size = ProblemShapeType{m, n, k}; + } + + passed = testbed.run( + problem_size, + cutlass::from_real(alpha), + cutlass::from_real(beta), + raster_order, + detail::MaxSwizzleSize(max_swizzle_size), + detail::Splits(splits) + ); + + if (!passed) { + return false; + } + } } } } diff --git a/test/unit/gemm/device/gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu index 9a29512d5e..206f183f01 100644 --- a/test/unit/gemm/device/gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu +++ b/test/unit/gemm/device/gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu @@ -275,4 +275,4 @@ TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 16x128 #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu index 035a33b982..dfab4a09a7 100644 --- a/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu +++ b/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -94,4 +94,4 @@ TEST(SM80_Device_GemmUniversal_f16t_s8n_f16t_mixed_input_tensor_op_f16, 128x128x #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu 
b/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu index eae1cb1044..c560ef718c 100644 --- a/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu +++ b/test/unit/gemm/device/gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -94,4 +94,4 @@ TEST(SM80_Device_GemmUniversal_f16t_u8t_f16t_mixed_input_tensor_op_f16, 128x128x #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu index 340a5a1cf4..f4eab4a030 100644 --- a/test/unit/gemm/device/gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu +++ b/test/unit/gemm/device/gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu @@ -381,4 +381,4 @@ TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x16 #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu index 8e5a70e8a8..89cb98debd 100644 --- a/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu +++ b/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -94,4 +94,4 @@ TEST(SM80_Device_GemmUniversal_s8t_f16n_f16t_mixed_input_tensor_op_f16, 128x128x #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu index ad153ba33f..710ec1753b 100644 --- a/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu +++ b/test/unit/gemm/device/gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -94,4 +94,4 @@ TEST(SM80_Device_GemmUniversal_u8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x #endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) -//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file +//////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/device/sm90_evt_operations.hpp b/test/unit/gemm/device/sm90_evt_operations.hpp index 71425fee2c..b7e7231bd8 100644 --- a/test/unit/gemm/device/sm90_evt_operations.hpp +++ b/test/unit/gemm/device/sm90_evt_operations.hpp @@ -29,7 +29,7 @@ * **************************************************************************************************/ /*! 
\file - \brief Host reference and operations for Sm90 EVT unit test + \brief Host reference and operations for Sm90 EVT unit test */ #pragma once #include "gemm_testbed_3x_evt.hpp" @@ -53,10 +53,10 @@ class HostEVTAuxLoad { using ScalarAlpha = HostScalarBroadcast; using AccFetchNode = HostAccumulator; using AuxLoadNode = HostAuxLoad; - using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, AuxLoadNode>; + using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, AuxLoadNode>; using ScalarBeta = HostScalarBroadcast; using CLoadNode = HostAuxLoad; - using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; using EVTModule = HEVT, TernaryCompute1>; }; @@ -67,10 +67,10 @@ class HostPerColBias { using ScalarAlpha = HostScalarBroadcast; using AccFetchNode = HostAccumulator; using RowBroadcastNode = HostRowBroadcast; - using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, RowBroadcastNode>; + using TernaryCompute0 = HEVT, ScalarAlpha, AccFetchNode, RowBroadcastNode>; using ScalarBeta = HostScalarBroadcast; using CLoadNode = HostAuxLoad; - using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, TernaryCompute0>; using EVTModule = HEVT, TernaryCompute1>; }; @@ -95,13 +95,13 @@ class HostEVTDAG { ScalarAlpha, AccFetchNode, AuxLoadNode, - HostCompute, + HostCompute, HostCompute, HostCompute >; using ScalarBeta = HostScalarBroadcast; using CLoadNode = HostAuxLoad; - using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, DAGNode>; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, DAGNode>; using EVTModule = HEVT, TernaryCompute1>; }; @@ -114,7 +114,7 @@ class HostDAGEVT { using EVTNode = HEVT< HostAuxStore, HEVT< - HostCompute, + HostCompute, HostScalarBroadcast, HostAccumulator, HostAuxLoad @@ -133,7 +133,7 @@ class HostDAGEVT { EVTNode, HostColBroadcast, HostCompute, - HostCompute + HostCompute > >; }; @@ -147,13 +147,13 @@ class HostReduce { using BinaryCompute0 = HEVT, ScalarAlpha, AccFetchNode>; using ScalarBeta = HostScalarBroadcast; using CLoadNode = HostAuxLoad; - using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, BinaryCompute0>; + using TernaryCompute1 = HEVT, ScalarBeta, CLoadNode, BinaryCompute0>; using ReduceNode = HEVT, TernaryCompute1>; using EVTModule = HEVT, ReduceNode>; }; // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias -// if D is fp8 +// if D is fp8 // D = scale_d * activation(Z) // else // D = activation(Z) @@ -167,11 +167,11 @@ class HostScaledLinCombPerRowBiasEltAct { HEVT< HostCompute, // activation(Z) HEVT< - HostCompute, + HostCompute, HostScalarBroadcast, // scale_c * beta HostAuxLoad, // C HEVT< - HostCompute, + HostCompute, HostScalarBroadcast, // scale_a * scale_b * alpha HostAccumulator, HostColBroadcast, @@ -184,12 +184,12 @@ class HostScaledLinCombPerRowBiasEltAct { }; // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias -// if D is fp8 +// if D is fp8 // amax_d = max(abs(elements in activation(Z))) // D = scale_d * activation(Z) // else // D = activation(Z) -// if Aux is fp8 +// if Aux is fp8 // amax_aux = max(abs(elements in Z)) // Aux = scale_aux * Z // else @@ -204,11 +204,11 @@ class HostScaledLinCombPerRowBiasEltActAmaxAux { HST, + HostCompute, HostScalarBroadcast, // scale_c * beta HostAuxLoad, // C HEVT< - HostCompute, + HostCompute, HostScalarBroadcast, // scale_a * scale_b * alpha HostAccumulator, HostColBroadcast, @@ -218,7 +218,7 @@ class 
HostScaledLinCombPerRowBiasEltActAmaxAux { HEVT< HostCompute::Op>, HEVT< - HostScalarReduce, + HostScalarReduce, HEVT< HostCompute, //activation(Z) * scaled_d HostAccumulator, // Z @@ -247,6 +247,13 @@ class HostScaledLinCombPerRowBiasEltActAmaxAux { namespace cutlass::epilogue { namespace fusion { +namespace detail { + +template +struct maximum_with_default_nan_propagation : maximum {}; + +} // namespace detail + ////////////////////////////////////////////////////////////////////////////// /// D = alpha * acc + beta * C + AuxLoad template< @@ -258,16 +265,16 @@ template< FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > using Sm90LinCombAuxLoad = - Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90EVT, // beta * C + (alpha * acc + bias) Sm90ScalarBroadcast, // beta Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias + Sm90EVT, // alpha * acc + bias Sm90ScalarBroadcast, // alpha Sm90AccFetch, // acc Sm90AuxLoad< - AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, - typename AuxLoadDescriptor::Element, - typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom, + AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxLoadDescriptor::Element, + typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R // aux load > > @@ -286,7 +293,7 @@ template< FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > using Sm90LinCombEVTDAG = - Sm90EVT, // beta * C + (alpha * acc + aux) + Sm90EVT, // beta * C + (alpha * acc + aux) Sm90ScalarBroadcast, // beta Sm90SrcFetch, // C Sm90TopologicalVisitor< @@ -302,13 +309,13 @@ using Sm90LinCombEVTDAG = Sm90ScalarBroadcast, // alpha Sm90AccFetch, // acc Sm90AuxLoad< - AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, - typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride, + AuxLoadDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + typename AuxLoadDescriptor::Element, typename AuxLoadDescriptor::Stride, typename AuxLoadDescriptor::SmemLayoutAtom, typename AuxLoadDescriptor::CopyOpS2R>, - Sm90Compute, + Sm90Compute, Sm90Compute, Sm90Compute - > + > >; @@ -336,10 +343,10 @@ using Sm90LinCombDAGEVT = >, Sm90EVT< Sm90AuxStore< - AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, + AuxStoreDescriptor::Stages, typename EpilogueDescriptor::EpilogueTile, typename AuxStoreDescriptor::Element, RoundStyle, typename AuxStoreDescriptor::Stride, typename AuxStoreDescriptor::SmemLayoutAtom, typename AuxStoreDescriptor::CopyOpR2S>, - Sm90EVT, + Sm90EVT, Sm90ScalarBroadcast, Sm90AccFetch, Sm90SrcFetch @@ -347,7 +354,7 @@ using Sm90LinCombDAGEVT = >, Sm90ColBroadcast<0, typename EpilogueDescriptor::TileShape, ElementBias>, Sm90Compute, - Sm90Compute + Sm90Compute >; @@ -362,18 +369,18 @@ template< FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest > using Sm90LinCombPerColumnBias = - Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90EVT, // beta * C + (alpha * acc + bias) Sm90ScalarBroadcast, // beta Sm90SrcFetch, // C - Sm90EVT, // alpha * acc + bias + Sm90EVT, // alpha * acc + bias Sm90ScalarBroadcast, // alpha Sm90AccFetch, // acc Sm90RowBroadcast< ceil_div( - EpilogueDescriptor::StagesC, + EpilogueDescriptor::StagesC, size(shape_div(take<0, 2>(typename EpilogueDescriptor::TileShape{}), typename EpilogueDescriptor::EpilogueTile{})) - ) + 1, - typename EpilogueDescriptor::TileShape, + ) + 1, + typename EpilogueDescriptor::TileShape, 
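The HostScaledLinCombPerRowBiasEltAct and HostScaledLinCombPerRowBiasEltActAmaxAux trees above encode the formulas spelled out in their comments. The sketch below is a minimal scalar model of that math, not the testbed's implementation: it operates on single elements in float, uses ReLU as a stand-in activation, and models the "if D/Aux is fp8" branches with plain boolean flags.

// Scalar model of the fused epilogue described in the comments above:
//   Z   = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias
//   D   = scale_d * activation(Z) with amax_d tracked when D is fp8, else activation(Z)
//   Aux = scale_aux * Z with amax_aux tracked when Aux is fp8, else Z
#include <algorithm>
#include <cmath>

struct EpilogueScalars {
  float alpha, beta;
  float scale_a, scale_b, scale_c, scale_d, scale_aux;
  bool d_is_fp8, aux_is_fp8;
};

struct EpilogueResult {
  float D, Aux;
};

inline float relu(float z) { return std::max(z, 0.0f); }  // stand-in activation

EpilogueResult reference_epilogue(
    float acc, float c, float row_bias, EpilogueScalars const& s,
    float& amax_d, float& amax_aux) {

  float Z = s.scale_a * s.scale_b * s.alpha * acc
          + s.scale_c * s.beta * c
          + row_bias;

  float activated = relu(Z);

  EpilogueResult out;
  if (s.d_is_fp8) {
    amax_d = std::max(amax_d, std::fabs(activated));  // amax over activation(Z)
    out.D = s.scale_d * activated;
  }
  else {
    out.D = activated;
  }

  if (s.aux_is_fp8) {
    amax_aux = std::max(amax_aux, std::fabs(Z));      // amax over Z
    out.Aux = s.scale_aux * Z;
  }
  else {
    out.Aux = Z;
  }
  return out;
}

A caller would loop this over every (m, n) element of the accumulator, seeding amax_d and amax_aux with zero before the sweep.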
ElementBias > > @@ -385,7 +392,7 @@ using Sm90LinCombPerColumnBias = template< template class RegReduceFn, template class GmemReduceFn, - class ElementReduce, + class ElementReduce, class CtaTileShapeMNK, class ElementOutput, class ElementCompute, @@ -394,7 +401,7 @@ template< > using Sm90LinCombPerColumnReduce = Sm90EVT, // per column reduce - Sm90EVT, // beta * C + alpha * acc + Sm90EVT, // beta * C + alpha * acc Sm90ScalarBroadcast, // beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc @@ -410,7 +417,7 @@ using Sm90LinCombPerColumnReduce = template< template class RegReduceFn, template class GmemReduceFn, - class ElementReduce, + class ElementReduce, class CtaTileShapeMNK, class ElementOutput, class ElementCompute, @@ -419,7 +426,7 @@ template< > using Sm90LinCombPerRowReduce = Sm90EVT, // per column reduce - Sm90EVT, // beta * C + alpha * acc + Sm90EVT, // beta * C + alpha * acc Sm90ScalarBroadcast, // beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc @@ -435,7 +442,7 @@ using Sm90LinCombPerRowReduce = template< template class RegReduceFn, template class GmemReduceFn, - class ElementReduce, + class ElementReduce, class ElementOutput, class ElementCompute, class ElementScalar = ElementCompute, @@ -443,7 +450,7 @@ template< > using Sm90LinCombScalarReduce = Sm90EVT, // per column reduce - Sm90EVT, // beta * C + alpha * acc + Sm90EVT, // beta * C + alpha * acc Sm90ScalarBroadcast, // beta Sm90SrcFetch, // C Sm90EVT, // alpha * acc diff --git a/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu index b0279f0fff..6d29b2315f 100644 --- a/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32.cu @@ -58,44 +58,7 @@ using namespace cute; /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_bf16t_bf16t_bf16n_align8_tensor_op_gmma_f32, 64x128x64) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::RowMajor; - using LayoutC = cutlass::layout::ColumnMajor; - - using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::bfloat16_t, LayoutA, 8, - cutlass::bfloat16_t, LayoutB, 8, - float, - Shape<_64,_128,_64>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto - >::CollectiveOp; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - Shape<_64,_128,_64>, Shape<_1,_1,_1>, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::bfloat16_t, LayoutC, 8, - cutlass::bfloat16_t, LayoutC, 8, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveOp, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); -} - -/////////////////////////////////////////////////////////////////////////////// - -TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 128x128x64) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -105,14 +68,14 @@ 
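The reduce epilogue aliases above (Sm90LinCombPerColumnReduce, Sm90LinCombPerRowReduce, Sm90LinCombScalarReduce) all reduce the linear combination beta * C + alpha * acc; the per-column and per-row variants presumably keep one reduced value per column or row, while the scalar variant collapses the whole output. Below is a minimal sketch of the scalar case only, assuming maximum as the reduction (the real aliases are parameterized by RegReduceFn / GmemReduceFn) and flattened 1-D views of acc and C of equal, nonzero length.

// Scalar model of a reduce epilogue: element-wise linear combination, then a
// single reduction over all elements. maximum is a stand-in reduction.
#include <algorithm>
#include <vector>

float lin_comb_scalar_reduce(std::vector<float> const& acc,
                             std::vector<float> const& C,
                             float alpha, float beta) {
  float reduced = beta * C[0] + alpha * acc[0];   // assumes at least one element
  for (size_t i = 1; i < acc.size(); ++i) {
    reduced = std::max(reduced, beta * C[i] + alpha * acc[i]);
  }
  return reduced;
}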
TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) { cutlass::bfloat16_t, LayoutA, 4, cutlass::bfloat16_t, LayoutB, 4, float, - Shape<_64,_128,_64>, Shape<_1,_1,_1>, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - Shape<_64,_128,_64>, Shape<_1,_1,_1>, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto, float, float, cutlass::bfloat16_t, LayoutC, 4, @@ -132,9 +95,9 @@ TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32, 64x128x64) { /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) { - using LayoutA = cutlass::layout::ColumnMajor; - using LayoutB = cutlass::layout::RowMajor; +TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align2_tensor_op_gmma_f32, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< @@ -142,14 +105,14 @@ TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) { cutlass::bfloat16_t, LayoutA, 2, cutlass::bfloat16_t, LayoutB, 2, float, - Shape<_64,_128,_64>, Shape<_1,_1,_1>, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - Shape<_64,_128,_64>, Shape<_1,_1,_1>, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto, float, float, cutlass::bfloat16_t, LayoutC, 2, @@ -169,41 +132,4 @@ TEST(SM90_Device_Gemm_bf16n_bf16t_bf16n_align2_tensor_op_gmma_f32, 64x128x64) { /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_bf16n_bf16n_bf16n_align8_tensor_op_gmma_f32, 64x128x64) { - using LayoutA = cutlass::layout::ColumnMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - - using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::bfloat16_t, LayoutA, 8, - cutlass::bfloat16_t, LayoutB, 8, - float, - Shape<_64,_128,_64>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto - >::CollectiveOp; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - Shape<_64,_128,_64>, Shape<_1,_1,_1>, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - cutlass::bfloat16_t, LayoutC, 8, - cutlass::bfloat16_t, LayoutC, 8, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveOp, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); -} - -/////////////////////////////////////////////////////////////////////////////// - #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git 
a/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized.cu b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized.cu new file mode 100644 index 0000000000..c8b2964cd5 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized.cu @@ -0,0 +1,172 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 8, + cutlass::bfloat16_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 4, + cutlass::bfloat16_t, LayoutB, 4, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 4, + cutlass::bfloat16_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) { + using LayoutA = 
cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 2, + cutlass::bfloat16_t, LayoutB, 2, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 2, + cutlass::bfloat16_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu new file mode 100644 index 0000000000..fdfcc54e81 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu @@ -0,0 +1,172 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 8, + cutlass::bfloat16_t, LayoutB, 8, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 4, + cutlass::bfloat16_t, LayoutB, 4, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 4, + cutlass::bfloat16_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + 
+TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 2, + cutlass::bfloat16_t, LayoutB, 2, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 2, + cutlass::bfloat16_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_pingpong.cu b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_pingpong.cu new file mode 100644 index 0000000000..8e125c9023 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_pingpong.cu @@ -0,0 +1,172 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 8, + cutlass::bfloat16_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::bfloat16_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 4, + cutlass::bfloat16_t, LayoutB, 4, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 4, + cutlass::bfloat16_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + 
+/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_bf16t_bf16n_bf16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::bfloat16_t, LayoutA, 2, + cutlass::bfloat16_t, LayoutB, 2, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::bfloat16_t, LayoutC, 2, + cutlass::bfloat16_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op_f32.cu new file mode 100644 index 0000000000..ee1a3d17f2 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op_f32.cu @@ -0,0 +1,365 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// TT ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} 
+ +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// TN ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// NT ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 128x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using 
CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 128x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// NN ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 128x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 128x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = 
cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized.cu similarity index 91% rename from test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu rename to test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized.cu index 1e0c395b8b..324c3527e3 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized.cu @@ -60,7 +60,7 @@ using namespace cute; ///////////////////////////////////// TT ////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -72,7 +72,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { float, Shape<_64,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelMultistage + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -95,7 +95,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { EXPECT_TRUE(test::gemm::device::TestAll()); } -TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -107,7 +107,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { float, Shape<_64,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -131,7 +131,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { } 
-TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -143,7 +143,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { float, Shape<_64,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -170,7 +170,7 @@ TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { ///////////////////////////////////// TN ////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -182,7 +182,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { float, Shape<_64,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelMultistage + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -207,7 +207,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -219,7 +219,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { float, Shape<_64,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -244,7 +244,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -256,7 +256,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { float, Shape<_64,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -283,7 +283,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { ///////////////////////////////////// NT ////////////////////////////////////// 
/////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) { using LayoutA = cutlass::layout::ColumnMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -295,7 +295,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { float, Shape<_64,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelMultistage + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -320,7 +320,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32, 64x128x64) { /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) { using LayoutA = cutlass::layout::ColumnMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -332,7 +332,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { float, Shape<_64,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -357,7 +357,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32, 64x128x64) { /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) { using LayoutA = cutlass::layout::ColumnMajor; using LayoutB = cutlass::layout::RowMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -369,7 +369,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { float, Shape<_64,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -396,7 +396,7 @@ TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32, 64x128x64) { ///////////////////////////////////// NN ////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized, 64x128x64) { using LayoutA = cutlass::layout::ColumnMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -408,7 +408,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { float, Shape<_64,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelMultistage + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -433,7 +433,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32, 64x128x64) { 
/////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized, 64x128x64) { using LayoutA = cutlass::layout::ColumnMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -445,7 +445,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { float, Shape<_64,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -470,7 +470,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32, 64x128x64) { /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized, 64x128x64) { using LayoutA = cutlass::layout::ColumnMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -482,7 +482,7 @@ TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32, 64x128x64) { float, Shape<_64,_128,_64>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto + cutlass::gemm::KernelCpAsyncWarpSpecialized >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized_cooperative.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized_cooperative.cu new file mode 100644 index 0000000000..410e6fd964 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized_cooperative.cu @@ -0,0 +1,510 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// TT ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + 
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// TN ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using 
CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// NT ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = 
cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// NN ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_128,_128,_64>, 
Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git 
a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized_pingpong.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized_pingpong.cu new file mode 100644 index 0000000000..ec8eef31e6 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_alignx_tensor_op_f32_warpspecialized_pingpong.cu @@ -0,0 +1,510 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// TT ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + + +TEST(SM90_Device_Gemm_f16t_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = 
cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// TN ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, 
+ cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16t_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// NT ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + 
cutlass::half_t, LayoutB, 4, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16t_f16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// +///////////////////////////////////// NN ////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align8_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + 
EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 4, + cutlass::half_t, LayoutB, 4, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 4, + cutlass::half_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_f16n_f16n_f16n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 64x128x64) { + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 2, + cutlass::half_t, LayoutB, 2, + float, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_64,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + cutlass::half_t, LayoutC, 2, + cutlass::half_t, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu index a24d8d2b34..ac90820d70 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu @@ -469,7 +469,7 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 EXPECT_TRUE(passed); } -TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasS8_ReLU_VoidC) 
{ +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_BiasS8_ReLU_VoidC_U1Aux) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; @@ -478,8 +478,9 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 using ClusterShape_MNK = Shape<_2,_2,_1>; using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + // ReLU with uint1b_t aux will compute dReLU/dZ as the aux output, i.e. Aux(i) = (Z(i) >= 0) ? 1 : 0 using FusionOperation = cutlass::epilogue::fusion::LinCombPerRowBiasEltActAux< - LayoutC, cutlass::epilogue::thread::ReLu, cutlass::half_t, float, cutlass::half_t, int8_t>; + LayoutC, cutlass::epilogue::thread::ReLU, cutlass::half_t, float, cutlass::uint1b_t, int8_t>; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, @@ -514,4 +515,94 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 25 EXPECT_TRUE(passed); } +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_dReLU_dBias_VoidC) { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltActDePerRowBias< + LayoutC, cutlass::epilogue::thread::dReLU, cutlass::half_t, float, cutlass::uint1b_t, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_cooperative_epilogue, 256x128x64_2x2x1_dGELU_VoidC) { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_256,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltAct< + LayoutC, cutlass::epilogue::thread::dGELU, cutlass::half_t, float, cutlass::half_t>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + 
EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedCooperative + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(1.0, 0.0, /*check_relative_equality=*/true); + EXPECT_TRUE(passed); +} + #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu index b3af865116..18660318f2 100644 --- a/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_pingpong_bias_elementwise.cu @@ -460,4 +460,49 @@ TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_persistent_epilogue, 128 EXPECT_TRUE(passed); } +TEST(SM90_Device_Gemm_f16t_f16n_f32t_tensor_op_gmma_f32_pingpong_epilogue, 128x128x64_2x2x1_dReLU_dBias_VoidC) { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + using TileShape_MNK = Shape<_128,_128,_64>; + using ClusterShape_MNK = Shape<_2,_2,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + using FusionOperation = cutlass::epilogue::fusion::LinCombDeEltActDePerRowBias< + LayoutC, cutlass::epilogue::thread::dReLU, cutlass::half_t, float, cutlass::uint1b_t, float>; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + void, LayoutC, 8, + cutlass::half_t, LayoutC, 8, + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::half_t, LayoutA, 8, + cutlass::half_t, LayoutB, 8, + float, + TileShape_MNK, ClusterShape_MNK, + cutlass::gemm::collective::StageCountAutoCarveout, + cutlass::gemm::KernelTmaWarpSpecializedPingpong + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + bool passed = test::gemm::device::TestAllBiasElementwise(); + EXPECT_TRUE(passed); +} + #endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f16_f16_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu b/test/unit/gemm/device/sm90_gemm_f16_f16_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu new file mode 100644 index 0000000000..8da1c0a281 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f16_f16_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu @@ -0,0 +1,209 @@ 
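// The dReLU/dBias and dGELU tests added above exercise the new backprop
// fusions: the GEMM accumulator carries the incoming gradient dY, and the aux
// tensor carries data saved from the forward pass -- a uint1b_t bitmap of
// (Z >= 0) for dReLU, or the pre-activation Z for dGELU -- while the dBias
// variant additionally reduces the gated gradient into a per-row bias
// gradient. The hypothetical host helpers below are not part of this patch
// and only illustrate the math being checked; the exact erf-based dGELU shown
// here can differ slightly from the device-side implementation, which is one
// reason the dGELU test above compares with relative rather than bitwise
// equality.

#include <cmath>
#include <cstdint>
#include <vector>

// dReLU: pass the gradient through wherever the forward ReLU was active.
inline float drelu_ref(float dy, bool forward_nonnegative) {
  return forward_nonnegative ? dy : 0.0f;
}

// dGELU (exact form): d/dz [z * Phi(z)] = Phi(z) + z * phi(z).
inline float dgelu_ref(float dy, float z) {
  float Phi = 0.5f * (1.0f + std::erf(z * 0.70710678f));  // standard normal CDF
  float phi = std::exp(-0.5f * z * z) * 0.39894228f;      // standard normal PDF
  return dy * (Phi + z * phi);
}

// Per-row bias gradient (assumption: one bias value per output row, so the
// gated gradient dX is reduced along each row of the M x N output tile).
inline std::vector<float> dbias_per_row_ref(const std::vector<float>& dX, int M, int N) {
  std::vector<float> db(M, 0.0f);
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      db[i] += dX[i * N + j];
    }
  }
  return db;
}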
+/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + /*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative, 128x192x64_1x1x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_192,_64>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using AtomLayoutMNK = Layout>; + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{})); + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_B(); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>()); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>()); + + using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::compute_stage_count_or_override< + cutlass::gemm::collective::detail::sm90_smem_capacity_bytes, + ElementA, ElementB, TileShape_MNK>(StageCountType{}); + using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecialized< + PipelineStages, ClusterShape_MNK, KernelScheduleType>; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + cute::Copy_Atom, + cute::identity, + 
GmemTiledCopyB, + SmemLayoutAtomB, + void, + cute::identity + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +TEST(SM90_Device_Gemm_f16t_f16n_f32n_tensor_op_gmma_f32_cooperative, 128x192x64_2x1x1) { + using ElementA = cutlass::half_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::half_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_192,_64>; + using ClusterShape_MNK = Shape<_2,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using AtomLayoutMNK = Layout>; + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{})); + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_B(); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>()); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>()); + + using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::compute_stage_count_or_override< + cutlass::gemm::collective::detail::sm90_smem_capacity_bytes, + ElementA, ElementB, TileShape_MNK>(StageCountType{}); + using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecialized< + PipelineStages, ClusterShape_MNK, KernelScheduleType>; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + cute::Copy_Atom, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, + cute::identity + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +#endif 
// defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu new file mode 100644 index 0000000000..8204b7f901 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_f8_f8_f32_tensor_op_f32_rs_cluster_warpspecialized_cooperative.cu @@ -0,0 +1,209 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + /*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 128x128x128_1x1x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_1,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using AtomLayoutMNK = Layout>; + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{})); + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_B(); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>()); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>()); + + using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::compute_stage_count_or_override< + cutlass::gemm::collective::detail::sm90_smem_capacity_bytes, + ElementA, ElementB, TileShape_MNK>(StageCountType{}); + using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecialized< + PipelineStages, ClusterShape_MNK, KernelScheduleType>; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + cute::Copy_Atom, + cute::identity, 
+ GmemTiledCopyB, + SmemLayoutAtomB, + void, + cute::identity + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +TEST(SM90_Device_Gemm_e4m3t_e4m3n_f32n_tensor_op_gmma_f32_cooperative, 128x128x128_2x1x1) { + using ElementA = cutlass::float_e4m3_t; + using LayoutA = cutlass::layout::RowMajor; + using ElementB = cutlass::float_e4m3_t; + using LayoutB = cutlass::layout::ColumnMajor; + using ElementC = float; + using LayoutC = cutlass::layout::ColumnMajor; + using ElementAccumulator = float; + + using TileShape_MNK = Shape<_128,_128,_128>; + using ClusterShape_MNK = Shape<_2,_1,_1>; + + using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; + using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltAct< + cutlass::epilogue::thread::Identity, ElementC, ElementAccumulator, ElementAccumulator>; + + using EpilogueOp = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape_MNK, ClusterShape_MNK, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, 16 / sizeof(ElementC), + ElementC, LayoutC, 16 / sizeof(ElementC), + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + + using KernelScheduleType = cutlass::gemm::KernelTmaWarpSpecializedCooperative; + using AtomLayoutMNK = Layout>; + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::rs_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GMMA::Major::K, GMMA::Major::K>(), AtomLayoutMNK{})); + using GmemTiledCopyA = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + static constexpr cute::GMMA::Major GmmaMajorA = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = cutlass::gemm::collective::detail::gmma_rs_tag_to_major_B(); + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>()); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::rs_smem_selector(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{})), false>()); + + using StageCountType = cutlass::gemm::collective::StageCountAutoCarveout; + + static constexpr int PipelineStages = cutlass::gemm::collective::detail::compute_stage_count_or_override< + cutlass::gemm::collective::detail::sm90_smem_capacity_bytes, + ElementA, ElementB, TileShape_MNK>(StageCountType{}); + using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaRmemAWarpSpecialized< + PipelineStages, ClusterShape_MNK, KernelScheduleType>; + + using CollectiveOp = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + cute::Copy_Atom, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, + cute::identity + >; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + EpilogueOp + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + 
EXPECT_TRUE(test::gemm::device::TestAllBiasElementwise()); +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu index 174fb1b496..e51b11fe80 100644 --- a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu +++ b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32.cu @@ -58,7 +58,7 @@ using namespace cute; /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 64x128x128) { +TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 128x128x128) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -68,44 +68,9 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32, 64x128x128) { int8_t, LayoutA, 8, int8_t, LayoutB, 8, int32_t, - Shape<_64,_128,_128>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::collective::KernelScheduleAuto - >::CollectiveOp; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - Shape<_64,_128,_128>, Shape<_1,_1,_1>, - cutlass::epilogue::collective::EpilogueTileAuto, - int32_t, int32_t, - int8_t, LayoutC, 8, - int8_t, LayoutC, 8, - cutlass::epilogue::collective::EpilogueScheduleAuto - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveOp, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); -} - -TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32, 128x128x128) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - - using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - int8_t, LayoutA, 16, - int8_t, LayoutB, 16, - int32_t, Shape<_128,_128,_128>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelMultistage + cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< @@ -115,7 +80,7 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32, 128x128x128) { int32_t, int32_t, int8_t, LayoutC, 8, int8_t, LayoutC, 8, - cutlass::epilogue::NoSmemWarpSpecialized + cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< @@ -128,7 +93,7 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32, 128x128x128) { EXPECT_TRUE(test::gemm::device::TestAll()); } -TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x64x128) { +TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x128x128) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -138,14 +103,14 @@ TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32, 128x64x128) { int8_t, LayoutA, 4, int8_t, LayoutB, 4, int32_t, - Shape<_128,_64,_128>, Shape<_1,_1,_1>, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; using CollectiveEpilogue = typename 
cutlass::epilogue::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - Shape<_128,_64,_128>, Shape<_1,_1,_1>, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto, int32_t, int32_t, int8_t, LayoutC, 4, diff --git a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized.cu b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized.cu new file mode 100644 index 0000000000..82d3ad681b --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized.cu @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file
+    \brief Tests for device-wide GEMM interface
+*/
+
+#include <iostream>
+
+#include "cutlass/cutlass.h"
+#include "cute/tensor.hpp"
+#include "cute/atom/mma_atom.hpp"
+
+#include "cutlass/numeric_types.h"
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/device/gemm_universal_adapter.h"
+#include "cutlass/gemm/kernel/gemm_universal.hpp"
+#include "cutlass/gemm/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/collective_builder.hpp"
+#include "cutlass/epilogue/collective/default_epilogue.hpp"
+#include "cutlass/epilogue/thread/linear_combination.h"
+
+#include "../../common/cutlass_unit_test.h"
+
+#include "gemm_testbed_3x.hpp"
+
+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+
+using namespace cute;
+
+///////////////////////////////////////////////////////////////////////////////
+
+TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32_warpspecialized, 128x128x128) {
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      int8_t, LayoutA, 16,
+      int8_t, LayoutB, 16,
+      int32_t,
+      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
+      cutlass::gemm::collective::StageCountAuto,
+      cutlass::gemm::KernelCpAsyncWarpSpecialized
+    >::CollectiveOp;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      int32_t, int32_t,
+      int8_t, LayoutC, 16,
+      int8_t, LayoutC, 16,
+      cutlass::epilogue::NoSmemWarpSpecialized
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveOp,
+      CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
+}
+
+TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32_warpspecialized, 128x128x128) {
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      int8_t, LayoutA, 8,
+      int8_t, LayoutB, 8,
+      int32_t,
+      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
+      cutlass::gemm::collective::StageCountAuto,
+      cutlass::gemm::KernelCpAsyncWarpSpecialized
+    >::CollectiveOp;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      int32_t, int32_t,
+      int8_t, LayoutC, 8,
+      int8_t, LayoutC, 8,
+      cutlass::epilogue::collective::EpilogueScheduleAuto
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveOp,
+      CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
+}
+
+TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32_warpspecialized, 128x128x128) {
+  using LayoutA = cutlass::layout::RowMajor;
+  using LayoutB = cutlass::layout::ColumnMajor;
+  using LayoutC = cutlass::layout::ColumnMajor;
+
+  using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      int8_t, LayoutA, 4,
+      int8_t, LayoutB, 4,
+      int32_t,
+      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
+      cutlass::gemm::collective::StageCountAuto,
+      cutlass::gemm::KernelCpAsyncWarpSpecialized
+    >::CollectiveOp;
+
+  using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
+      cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+      Shape<_128,_128,_128>, Shape<_1,_1,_1>,
+      cutlass::epilogue::collective::EpilogueTileAuto,
+      int32_t, int32_t,
+      int8_t, LayoutC, 4,
+      int8_t, LayoutC, 4,
+      cutlass::epilogue::collective::EpilogueScheduleAuto
+    >::CollectiveOp;
+
+  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+      Shape<int,int,int,int>,
+      CollectiveOp,
+      CollectiveEpilogue
+  >;
+
+  using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+  EXPECT_TRUE(test::gemm::device::TestAll<Gemm>());
+}
+
+///////////////////////////////////////////////////////////////////////////////
+
+#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
diff --git a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized_cooperative.cu b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized_cooperative.cu
new file mode 100644
index 0000000000..e69d2fded2
--- /dev/null
+++ b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized_cooperative.cu
@@ -0,0 +1,168 @@
+/***************************************************************************************************
+ * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*!
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32_warpspecialized_cooperative, 128x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 16, + int8_t, LayoutB, 16, + int32_t, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 16, + int8_t, LayoutC, 16, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32_warpspecialized_cooperative, 128x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 8, + int8_t, LayoutB, 8, + int32_t, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 8, + int8_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32_warpspecialized_cooperative, 128x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + 
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 4, + int8_t, LayoutB, 4, + int32_t, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 4, + int8_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized_pingpong.cu b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized_pingpong.cu new file mode 100644 index 0000000000..96cf81ebd1 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_s8_s8_s8_alignx_tensor_op_s32_warpspecialized_pingpong.cu @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_align16_tensor_op_gmma_s32_warpspecialized_pingpong, 128x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 16, + int8_t, LayoutB, 16, + int32_t, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 16, + int8_t, LayoutC, 16, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_align8_tensor_op_gmma_s32_warpspecialized_pingpong, 128x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 8, + int8_t, LayoutB, 8, + int32_t, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 8, + int8_t, LayoutC, 8, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_s8t_s8n_s8n_align4_tensor_op_gmma_s32_warpspecialized_pingpong, 128x128x128) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, 
cutlass::arch::OpClassTensorOp, + int8_t, LayoutA, 4, + int8_t, LayoutB, 4, + int32_t, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_128>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + int32_t, int32_t, + int8_t, LayoutC, 4, + int8_t, LayoutC, 4, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu b/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu index 989d60aab7..9942f2d755 100644 --- a/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu +++ b/test/unit/gemm/device/sm90_gemm_stream_k_scheduler.cu @@ -62,10 +62,10 @@ run_scheduler(int* visit_counters, typename Scheduler::Params params, TileShape Scheduler scheduler{params}; auto work_tile_info = scheduler.get_current_work(); - while (work_tile_info.is_valid_tile) { + while (work_tile_info.is_valid()) { // Increment counters to indicate coverage auto tile_idx = Scheduler::output_tile_index(params, work_tile_info); - auto offset = tile_idx * params.k_tiles_per_output_tile_ + work_tile_info.K_idx; + auto offset = tile_idx * params.divmod_tiles_per_output_tile_.divisor + work_tile_info.K_idx; for (auto i = 0; i < work_tile_info.k_tile_count; ++i) { // Use atomicAdd because the visit counters are shared by multiple thread blocks. // While having more than one block increment the same counter indicates failure, @@ -108,7 +108,7 @@ test_scheduler( // Allocate counters indicating the number of times each k iteration of each output tile has been visited auto [blk_m, blk_n, blk_l] = Scheduler::get_tiled_cta_shape_mnl(problem_shape_mnkl, tile_shape, cluster_shape); - auto total_counters = blk_m * blk_n * blk_l * params.k_tiles_per_output_tile_; + auto total_counters = blk_m * blk_n * blk_l * params.divmod_tiles_per_output_tile_.divisor; cutlass::DeviceAllocation visit_counters(total_counters); // Initialize counters to zero @@ -181,8 +181,6 @@ test_scheduler( for (size_t i = 0; i < host_visit_counts.size(); ++i) { if (host_visit_counts[i] != 1) { - // for (int count : host_visit_counts) { - // if (count != 1) { std::cout << "Failed with problem size " << size<0>(problem_shape_mnkl) << "x" << size<1>(problem_shape_mnkl) << "x" @@ -191,11 +189,12 @@ test_scheduler( << " and grid size " << grid.x << "x" << grid.y << "x" << grid.z << " splits=" << params.splits_ - << " k_iter=" << params.k_tiles_per_output_tile_ + << " k_iter=" << params.divmod_tiles_per_output_tile_.divisor << " big_units=" << params.big_units_ << " sk_tiles=" << params.sk_tiles_ << " sk_units=" << params.sk_units_ - << " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ << std::endl; + << " k_tiles_per_sk_unit=" << params.k_tiles_per_sk_unit_ + << " units_per_problem=" << params.units_per_problem_ << std::endl; std::cout << "Error at idx: " << i << ". 
Got count " << host_visit_counts[i] << std::endl; return false; } diff --git a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu index bb25de29b2..65a9328df8 100644 --- a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32.cu @@ -57,42 +57,7 @@ using namespace cute; /////////////////////////////////////////////////////////////////////////////// -TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32, 64x128x32) { - using LayoutA = cutlass::layout::RowMajor; - using LayoutB = cutlass::layout::ColumnMajor; - using LayoutC = cutlass::layout::ColumnMajor; - - using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - tfloat32_t, LayoutA, 4, - tfloat32_t, LayoutB, 4, - float, - Shape<_64,_128,_32>, Shape<_1,_1,_1>, - cutlass::gemm::collective::StageCountAuto, - cutlass::gemm::KernelMultistage - >::CollectiveOp; - - using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< - cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - Shape<_64,_128,_32>, Shape<_1,_1,_1>, - cutlass::epilogue::collective::EpilogueTileAuto, - float, float, - float, LayoutC, 4, - float, LayoutC, 4, - cutlass::epilogue::NoSmemWarpSpecialized - >::CollectiveOp; - - using GemmKernel = cutlass::gemm::kernel::GemmUniversal< - Shape, - CollectiveOp, - CollectiveEpilogue - >; - - using Gemm = cutlass::gemm::device::GemmUniversalAdapter; - EXPECT_TRUE(test::gemm::device::TestAll()); -} - -TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 64x64x32) { +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 128x64x32) { using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; @@ -102,7 +67,7 @@ TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32, 64x64x32) { cutlass::tfloat32_t, LayoutA, 2, cutlass::tfloat32_t, LayoutB, 2, float, - Shape<_64,_64,_32>, Shape<_1,_1,_1>, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, cutlass::gemm::collective::KernelScheduleAuto >::CollectiveOp; diff --git a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized.cu b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized.cu new file mode 100644 index 0000000000..fe25cc707c --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized.cu @@ -0,0 +1,167 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32_warpspecialized, 128x64x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + tfloat32_t, LayoutA, 4, + tfloat32_t, LayoutB, 4, + float, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32_warpspecialized, 128x64x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::tfloat32_t, LayoutA, 2, + cutlass::tfloat32_t, LayoutB, 2, + float, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + 
cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 2, + float, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32_warpspecialized, 128x64x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::tfloat32_t, LayoutA, 1, + cutlass::tfloat32_t, LayoutB, 1, + float, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecialized + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 1, + float, LayoutC, 1, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_cooperative.cu b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_cooperative.cu new file mode 100644 index 0000000000..06025f9b32 --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_cooperative.cu @@ -0,0 +1,167 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32_warpspecialized_cooperative, 128x64x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + tfloat32_t, LayoutA, 4, + tfloat32_t, LayoutB, 4, + float, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32_warpspecialized_cooperative, 128x64x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::tfloat32_t, LayoutA, 2, + cutlass::tfloat32_t, LayoutB, 2, + float, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename 
cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 2, + float, LayoutC, 2, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32_warpspecialized_cooperative, 128x64x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::tfloat32_t, LayoutA, 1, + cutlass::tfloat32_t, LayoutB, 1, + float, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedCooperative + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 1, + float, LayoutC, 1, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_pingpong.cu b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_pingpong.cu new file mode 100644 index 0000000000..1d530d5f9c --- /dev/null +++ b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_alignx_tensor_op_f32_warpspecialized_pingpong.cu @@ -0,0 +1,167 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" + +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "../../common/cutlass_unit_test.h" + +#include "gemm_testbed_3x.hpp" + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////// + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align4_tensor_op_gmma_f32_warpspecialized_pingpong, 128x64x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + tfloat32_t, LayoutA, 4, + tfloat32_t, LayoutB, 4, + float, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 4, + float, LayoutC, 4, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align2_tensor_op_gmma_f32_warpspecialized_pingpong, 128x64x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::tfloat32_t, LayoutA, 2, + cutlass::tfloat32_t, LayoutB, 2, + float, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 2, + float, LayoutC, 2, + 
cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +TEST(SM90_Device_Gemm_tf32t_tf32n_f32n_align1_tensor_op_gmma_f32_warpspecialized_pingpong, 128x64x32) { + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + cutlass::tfloat32_t, LayoutA, 1, + cutlass::tfloat32_t, LayoutB, 1, + float, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecializedPingpong + >::CollectiveOp; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_64,_32>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + float, float, + float, LayoutC, 1, + float, LayoutC, 1, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveOp, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + EXPECT_TRUE(test::gemm::device::TestAll()); +} + +/////////////////////////////////////////////////////////////////////////////// + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) diff --git a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu index bc31d24a68..194ec04a41 100644 --- a/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu +++ b/test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32.cu @@ -99,7 +99,7 @@ TEST(SM90_Device_Gemm_tf32n_tf32n_f32n_tensor_op_gmma_f32, 64x128x32) { using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::tfloat32_t, LayoutA, 1, + cutlass::tfloat32_t, LayoutA, 4, cutlass::tfloat32_t, LayoutB, 4, float, Shape<_64,_128,_32>, Shape<_1,_1,_1>, @@ -136,8 +136,8 @@ TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) { using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, - cutlass::tfloat32_t, LayoutA, 1, - cutlass::tfloat32_t, LayoutB, 1, + cutlass::tfloat32_t, LayoutA, 4, + cutlass::tfloat32_t, LayoutB, 4, float, Shape<_64,_128,_32>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, @@ -149,8 +149,8 @@ TEST(SM90_Device_Gemm_tf32n_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) { Shape<_64,_128,_32>, Shape<_1,_1,_1>, cutlass::epilogue::collective::EpilogueTileAuto, float, float, - float, LayoutC, 1, - float, LayoutC, 1, + float, LayoutC, 4, + float, LayoutC, 4, cutlass::epilogue::collective::EpilogueScheduleAuto >::CollectiveOp; @@ -174,7 +174,7 @@ TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) { using CollectiveOp = typename cutlass::gemm::collective::CollectiveBuilder< cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, cutlass::tfloat32_t, LayoutA, 4, - cutlass::tfloat32_t, LayoutB, 1, + cutlass::tfloat32_t, LayoutB, 4, float, Shape<_64,_128,_32>, Shape<_1,_1,_1>, cutlass::gemm::collective::StageCountAuto, @@ -188,7 +188,7 @@ 
TEST(SM90_Device_Gemm_tf32t_tf32t_f32n_tensor_op_gmma_f32, 64x128x32) { float, float, float, LayoutC, 4, float, LayoutC, 4, - cutlass::epilogue::collective::EpilogueScheduleAuto + cutlass::gemm::EpilogueTransposed >::CollectiveOp; using GemmKernel = cutlass::gemm::kernel::GemmUniversal< diff --git a/test/unit/gemm/device/trmm_tf32t_tf32n_f32t_tensor_op_f32_ls_sm80.cu b/test/unit/gemm/device/trmm_tf32t_tf32n_f32t_tensor_op_f32_ls_sm80.cu index ba3f7f3864..9ca0015396 100644 --- a/test/unit/gemm/device/trmm_tf32t_tf32n_f32t_tensor_op_f32_ls_sm80.cu +++ b/test/unit/gemm/device/trmm_tf32t_tf32n_f32t_tensor_op_f32_ls_sm80.cu @@ -337,6 +337,8 @@ TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_un_tensor_op_f32_align1_align4, 128x ///////////////////////////////////////////////////////////////////////////////////////////////// +// This test fails on Ada when running with 11.8 +#if ((__CUDACC_VER_MAJOR__ != 11) || (__CUDACC_VER_MINOR__ != 8) || (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 890))) TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x128x16_128x64x16) { using ElementOutput = float; @@ -374,6 +376,7 @@ TEST(SM80_Device_Trmm_tf32t_tf32n_f32t_ls_u_nu_tensor_op_f32_align1_align4, 256x EXPECT_TRUE(test::gemm::device::TestAllTrmmUniversal()); } +#endif ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/test/unit/gemm/warp/gemm_mixed_input_sm80.cu b/test/unit/gemm/warp/gemm_mixed_input_sm80.cu index 56f6fb742f..c83f99f305 100644 --- a/test/unit/gemm/warp/gemm_mixed_input_sm80.cu +++ b/test/unit/gemm/warp/gemm_mixed_input_sm80.cu @@ -55,6 +55,7 @@ //////////////////////////////////////////////////////////////////////////////// /// F32 <= F16 * I8 + F32 (Upcast on Operand B) //////////////////////////////////////////////////////////////////////////////// + TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_16x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; @@ -98,6 +99,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 64x64x64_64x64x64_16 //////////////////////////////////////////////////////////////////////////////// /// F32 <= I8 * F16 + F32 (Upcast on Operand A) //////////////////////////////////////////////////////////////////////////////// + TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_16x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; @@ -118,7 +120,6 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_ .run(); } - TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; @@ -142,6 +143,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16 //////////////////////////////////////////////////////////////////////////////// /// F32 <= F16 * U8 + F32 (Upcast on Operand B) //////////////////////////////////////////////////////////////////////////////// + TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 64x64x64_64x64x64_16x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; @@ -185,6 +187,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 128x128x64_64x64x64_ 
//////////////////////////////////////////////////////////////////////////////// /// F32 <= U8 * F16 + F32 (Upcast on Operand A) //////////////////////////////////////////////////////////////////////////////// + TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 64x64x64_64x64x64_16x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; @@ -225,10 +228,10 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 128x128x64_64x64x64_ .run(); } - //////////////////////////////////////////////////////////////////////////////// /// F32 <= B16 * U8 + F32 (Upcast on Operand B) //////////////////////////////////////////////////////////////////////////////// + TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_16x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; @@ -252,6 +255,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_1 //////////////////////////////////////////////////////////////////////////////// /// F32 <= U8 * BF16 + F32 (Upcast on Operand A) //////////////////////////////////////////////////////////////////////////////// + TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_16x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; @@ -273,8 +277,9 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_1 } //////////////////////////////////////////////////////////////////////////////// -/// F32 <= B16 * I8 + F32 (Upcast on Operand B) +/// F32 <= I8 * BF16 + F32 (Upcast on Operand A) //////////////////////////////////////////////////////////////////////////////// + TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_16x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; @@ -296,8 +301,9 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_1 } //////////////////////////////////////////////////////////////////////////////// -/// F32 <= I8 * BF16 + F32 (Upcast on Operand A) +/// F32 <= B16 * I8 + F32 (Upcast on Operand B) //////////////////////////////////////////////////////////////////////////////// + TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_16x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; @@ -318,4 +324,4 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_1 .run(); } -#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) \ No newline at end of file +#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) diff --git a/test/unit/util/rms_norm.cu b/test/unit/util/rms_norm.cu index 7897de5104..7ff533da2f 100644 --- a/test/unit/util/rms_norm.cu +++ b/test/unit/util/rms_norm.cu @@ -44,7 +44,7 @@ void rmsnorm_host(cutlass::MatrixCoord tensor_size, cutlass::TensorRef output, cutlass::TensorRef input, cutlass::TensorRef weight, - float epsilon) { + float epsilon) { const int M = tensor_size.row(); const int N = tensor_size.column(); @@ -94,7 +94,7 @@ void run_test(int M, int N) { rmsnorm_host({M, N}, output_ref.host_ref(), input.host_ref(), weight.host_ref(), (float)1e-5); cutlass::rmsnorm({M, N}, output.device_ref(), - input.device_ref(), weight.device_ref(), NULL, (float)1e-5); + input.device_ref(), 
weight.device_ref(), NULL, (float)1e-5L); output.sync_host(); diff --git a/tools/library/CMakeLists.txt b/tools/library/CMakeLists.txt index d4b00c9209..fd6a0a0ff9 100644 --- a/tools/library/CMakeLists.txt +++ b/tools/library/CMakeLists.txt @@ -264,7 +264,8 @@ set(CUTLASS_LIBRARY_GENERATED_KERNEL_LIST_FILE ${CMAKE_CURRENT_BINARY_DIR}/gener # in ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log, set this parameter to INFO execute_process( WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/../../python/cutlass_library - COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/../../python/cutlass_library/generator.py + COMMAND ${CMAKE_COMMAND} -E env PYTHONPATH=${CUTLASS_LIBRARY_PACKAGE_DIR} + ${Python3_EXECUTABLE} ${CUTLASS_SOURCE_DIR}/python/cutlass_library/generator.py --operations "${CUTLASS_LIBRARY_OPERATIONS}" --build-dir ${PROJECT_BINARY_DIR} --curr-build-dir ${CMAKE_CURRENT_BINARY_DIR} @@ -275,6 +276,7 @@ execute_process( --selected-kernel-list "${CUTLASS_LIBRARY_GENERATED_KERNEL_LIST_FILE}" --cuda-version "${CUTLASS_GENERATOR_CUDA_COMPILER_VERSION}" --log-level DEBUG + --disable-cutlass-package-imports RESULT_VARIABLE cutlass_lib_INSTANCE_GENERATION_RESULT OUTPUT_VARIABLE cutlass_lib_INSTANCE_GENERATION_OUTPUT OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/library_instance_generation.log diff --git a/tools/library/include/cutlass/library/operation_table.h b/tools/library/include/cutlass/library/operation_table.h index b23036fda4..a9b4cbd13b 100644 --- a/tools/library/include/cutlass/library/operation_table.h +++ b/tools/library/include/cutlass/library/operation_table.h @@ -215,21 +215,21 @@ struct GemmPreferenceKey { return compute_capability == rhs.compute_capability; } }; + ///////////////////////////////////////////////////////////////////////////////////////////////// + inline std::ostream& operator<< (std::ostream& out, const cutlass::library::GemmPreferenceKey& key) { out << "{\n" << "compute_capability : " << key.compute_capability << std::endl << "alignment : " << key.alignment << std::endl << "}"; - + return out; } ///////////////////////////////////////////////////////////////////////////////////////////////// -///////////////////////////////////////////////////////////////////////////////////////////////// - /// Maps minimum compute capability onto a vector of possible operations using GemmOperationVectorMap = std::map< GemmPreferenceKey, @@ -242,7 +242,6 @@ using GemmOperationFunctionalMap = std::unordered_map< GemmOperationVectorMap, GemmFunctionalKeyHasher >; -///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// // Data Structures for Conv Functional Maps diff --git a/tools/library/src/handle.cu b/tools/library/src/handle.cu index a24e4e03b0..38d53f4e03 100644 --- a/tools/library/src/handle.cu +++ b/tools/library/src/handle.cu @@ -1183,7 +1183,7 @@ Operation const* find_gemm_operation_for_parallel_reduction(Operation const *ope return nullptr; } - // return matching gemm operation (same tile shape, stages, warp count, and instruction) + // return matching gemm opertion (same tile shape, stages, warp count, and instruction) for (auto op : it->second) { if (op->description().tile_description == operation->description().tile_description) { return op; diff --git a/tools/library/src/library_internal.h b/tools/library/src/library_internal.h index f45b7d1cf3..b42d8fe4d4 100644 --- a/tools/library/src/library_internal.h +++ 
b/tools/library/src/library_internal.h @@ -287,6 +287,10 @@ template <> struct OpcodeClassMap { static OpcodeClassID const kId = OpcodeClassID::kTensorOp; }; +template <> struct OpcodeClassMap { + static OpcodeClassID const kId = OpcodeClassID::kSparseTensorOp; +}; + template <> struct OpcodeClassMap { static OpcodeClassID const kId = OpcodeClassID::kWmmaTensorOp; }; diff --git a/tools/library/src/operation_table.cu b/tools/library/src/operation_table.cu index 113e48d20d..2f772591a5 100644 --- a/tools/library/src/operation_table.cu +++ b/tools/library/src/operation_table.cu @@ -47,9 +47,7 @@ void OperationTable::append(Manifest const &manifest) { // Insert operations into appropriate data structure for (auto const & operation : manifest) { - OperationDescription const &desc = operation->description(); - // insert all gemm operation into operation table if (desc.kind == OperationKind::kGemm) { GemmDescription const &gemm_desc = static_cast(desc); diff --git a/tools/library/src/reduction/init_reduction_operations.cu b/tools/library/src/reduction/init_reduction_operations.cu index 0d7ce6af13..0fc2d202a9 100644 --- a/tools/library/src/reduction/init_reduction_operations.cu +++ b/tools/library/src/reduction/init_reduction_operations.cu @@ -42,6 +42,7 @@ namespace library { /////////////////////////////////////////////////////////////////////////////////////////////// // CUTLASS Reduction Instances // /////////////////////////////////////////////////////////////////////////////////////////////// + void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest); void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest); void initialize_reduce_add_linear_combination_f32_f32_bf16(Manifest &manifest); diff --git a/tools/library/src/reduction/reduction_device.cu b/tools/library/src/reduction/reduction_device.cu index 34852c4ace..41758d26f4 100644 --- a/tools/library/src/reduction/reduction_device.cu +++ b/tools/library/src/reduction/reduction_device.cu @@ -146,7 +146,6 @@ void initialize_reduce_add_linear_combination_f32_f32_bf16(Manifest &manifest) { )); } - void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest) { using ElementWorkspace = float; diff --git a/tools/library/src/reference/gemm_fp_mixed_input.cu b/tools/library/src/reference/gemm_fp_mixed_input.cu index 786b610187..bda34ac643 100644 --- a/tools/library/src/reference/gemm_fp_mixed_input.cu +++ b/tools/library/src/reference/gemm_fp_mixed_input.cu @@ -78,7 +78,7 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) { float, float >(manifest); - + make_gemm_real_canonical_layouts< half_t, uint8_t, @@ -151,4 +151,3 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) { } // namespace cutlass /////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/tools/library/src/reference/initialize_reference_operations.cu b/tools/library/src/reference/initialize_reference_operations.cu index cc92f91faf..33612570f6 100644 --- a/tools/library/src/reference/initialize_reference_operations.cu +++ b/tools/library/src/reference/initialize_reference_operations.cu @@ -84,6 +84,7 @@ void initialize_reference_operations(Manifest &manifest) { initialize_gemm_reference_operations_fp32out(manifest); initialize_gemm_reference_operations_fp_other(manifest); initialize_gemm_reference_operations_fp_mixed_input(manifest); + } 
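The OpcodeClassMap specialization added in library_internal.h above follows the library's usual trait-map pattern: a compile-time opcode-class tag is mapped to the OpcodeClassID enumerant that the manifest, operation table, and profiler use at runtime. A minimal sketch of that pattern, using hypothetical stand-in tag and enum names rather than the real cutlass::arch and cutlass::library definitions:

#include <iostream>

// Stand-ins for the library's runtime enum and the arch-level tag types.
enum class OpcodeClassID { kSimt, kTensorOp, kSparseTensorOp };

struct OpClassSimt {};
struct OpClassTensorOp {};
struct OpClassSparseTensorOp {};

// Primary template is declared but not defined, so unmapped tags fail to compile.
template <typename OpcodeClass> struct OpcodeClassMap;

template <> struct OpcodeClassMap<OpClassSimt> {
  static OpcodeClassID const kId = OpcodeClassID::kSimt;
};

template <> struct OpcodeClassMap<OpClassTensorOp> {
  static OpcodeClassID const kId = OpcodeClassID::kTensorOp;
};

// The specialization added by this patch plays the same role for sparse tensor ops.
template <> struct OpcodeClassMap<OpClassSparseTensorOp> {
  static OpcodeClassID const kId = OpcodeClassID::kSparseTensorOp;
};

int main() {
  // Resolve the runtime ID from the compile-time tag.
  std::cout << static_cast<int>(OpcodeClassMap<OpClassSparseTensorOp>::kId) << "\n";
  return 0;
}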
/////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/src/util.cu b/tools/library/src/util.cu index f734fb8f66..319e1fd618 100644 --- a/tools/library/src/util.cu +++ b/tools/library/src/util.cu @@ -333,7 +333,7 @@ static struct { } OperationKind_enumerants[] = { {"eq_gemm", "EqGemm", OperationKind::kEqGemm}, - {"gemm", "Gemm", OperationKind::kGemm}, + {"gemm", "Gemm", OperationKind::kGemm}, {"rank_k", "RankK", OperationKind::kRankK}, {"rank_2k", "Rank2K", OperationKind::kRank2K}, {"trmm", "Trmm", OperationKind::kTrmm}, diff --git a/tools/profiler/CMakeLists.txt b/tools/profiler/CMakeLists.txt index 16d94db294..cca41d1d13 100644 --- a/tools/profiler/CMakeLists.txt +++ b/tools/profiler/CMakeLists.txt @@ -97,7 +97,12 @@ install( RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR} ) -set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) +if (CUDA_VERSION VERSION_GREATER_EQUAL 12.3 AND CUDA_VERSION VERSION_LESS 12.4 AND (90a IN_LIST CUTLASS_NVCC_ARCHS_ENABLED OR (90 IN_LIST CUTLASS_NVCC_ARCHS_ENABLED))) + set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,host --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) +else() + set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_GEMM --operation=Gemm --providers=cutlass --verification-providers=cublas,device --junit-output=test_cutlass_profiler_gemm --print-kernel-before-running=true) +endif() + set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_CONV2D --operation=Conv2d --providers=cutlass --verification-providers=cudnn,device --junit-output=test_cutlass_profiler_conv2d --print-kernel-before-running=true) set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_CONV3D --operation=Conv3d --providers=cutlass --verification-providers=cudnn,device,host --junit-output=test_cutlass_profiler_conv3d --print-kernel-before-running=true) set(CUTLASS_PROFILER_TEST_COMMAND_OPTIONS_SPGEMM --operation=SparseGemm --providers=cutlass --verification-providers=cublas,device,host --junit-output=test_cutlass_profiler_spgemm --print-kernel-before-running=true) diff --git a/tools/profiler/src/device_allocation.cu b/tools/profiler/src/device_allocation.cu index 088358278a..f16ccf7d6a 100644 --- a/tools/profiler/src/device_allocation.cu +++ b/tools/profiler/src/device_allocation.cu @@ -2005,7 +2005,6 @@ void DeviceAllocation::write_tensor_csv( case library::NumericTypeID::kFE5M2: write_tensor_csv_static_type(out, *this); break; - case library::NumericTypeID::kF16: write_tensor_csv_static_type(out, *this); break; diff --git a/tools/profiler/src/gemm_operation_profiler.cu b/tools/profiler/src/gemm_operation_profiler.cu index a67118cef4..f50b4d4a5e 100644 --- a/tools/profiler/src/gemm_operation_profiler.cu +++ b/tools/profiler/src/gemm_operation_profiler.cu @@ -585,11 +585,11 @@ Status GemmOperationProfiler::initialize_workspace( workspace_size = underlying_operation->get_device_workspace_size(&gemm_workspace_.configuration, &gemm_workspace_.arguments); gemm_workspace_.device_workspace.reset(library::NumericTypeID::kU8, workspace_size); + status = underlying_operation->initialize( &gemm_workspace_.configuration, gemm_workspace_.host_workspace.data(), gemm_workspace_.device_workspace.data()); - if (status != Status::kSuccess) { return status; } diff --git a/tools/util/include/cutlass/util/device_rmsnorm.h 
b/tools/util/include/cutlass/util/device_rmsnorm.h index 18b5c33f5a..034ab1ddcd 100644 --- a/tools/util/include/cutlass/util/device_rmsnorm.h +++ b/tools/util/include/cutlass/util/device_rmsnorm.h @@ -118,7 +118,7 @@ __global__ void rmsnorm_twoPassAlgo_e1(T* output, const T* input, const T* weight, const int m, const int n, - float epsilon) + float epsilon) { const int m_idx = blockIdx.x; const int tid = threadIdx.x; diff --git a/tools/util/include/cutlass/util/host_tensor.h b/tools/util/include/cutlass/util/host_tensor.h index 4b2b8d152b..7592c81aea 100644 --- a/tools/util/include/cutlass/util/host_tensor.h +++ b/tools/util/include/cutlass/util/host_tensor.h @@ -112,7 +112,7 @@ class HostTensor { /// Example /// int2: kBitsStoredVec = 8; kElementsPerStoredVec = 4; kNumStoragePerStoredVec = 1 uint8_t; /// int4: kBitsStoredVec = 8; kElementsPerStoredVec = 2; kNumStoragePerStoredVec = 1 uint8_t; - static int const kBitsStoredVec = (sizeof_bits::value < 8) ? cutlass::lcm(sizeof_bits::value, 8) : sizeof_bits::value; + static int const kBitsStoredVec = (sizeof_bits::value < 8) ? cutlass::lcm(static_cast(sizeof_bits::value), 8) : sizeof_bits::value; static int const kElementsPerStoredVec = kBitsStoredVec / sizeof_bits::value; static int const kNumStoragePerStoredVec = kBitsStoredVec / (sizeof(Element) * 8); @@ -129,7 +129,8 @@ class HostTensor { Layout layout_; /// Host-side memory allocation - std::vector host_; + /// avoid the std::vector specialization + std::vector, uint8_t, Element>> host_; /// Device-side memory device_memory::allocation device_; @@ -250,10 +251,10 @@ class HostTensor { } /// Gets pointer to host data - Element * host_data() { return host_.data(); } + Element * host_data() { return reinterpret_cast(host_.data()); } /// Gets pointer to host data with a pointer offset - Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(host_.data(), ptr_element_offset); } + Element * host_data_ptr_offset(LongIndex ptr_element_offset) { return &ReferenceFactory::get(host_data(), ptr_element_offset); } /// Gets a reference to an element in host memory Reference host_data(LongIndex idx) { @@ -261,10 +262,10 @@ class HostTensor { } /// Gets pointer to host data - Element const * host_data() const { return host_.data(); } + Element const * host_data() const { return reinterpret_cast(host_.data()); } /// Gets pointer to host data with a pointer offset - Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(host_.data(), ptr_element_offset); } + Element const * host_data_ptr_offset(LongIndex ptr_element_offset) const { return &ReferenceFactory::get(host_data(), ptr_element_offset); } /// Gets a constant reference to an element in host memory ConstReference host_data(LongIndex idx) const { diff --git a/tools/util/include/cutlass/util/print_error.hpp b/tools/util/include/cutlass/util/print_error.hpp index 4d84af8a65..4488c5b1e3 100644 --- a/tools/util/include/cutlass/util/print_error.hpp +++ b/tools/util/include/cutlass/util/print_error.hpp @@ -63,6 +63,7 @@ matrix_inf_norm(cute::Tensor const& host_matrix) { using std::abs; using error_type = decltype(std::declval().inf_norm); + using element_type = typename EngineType::value_type; error_type inf_norm = 0.0; bool found_nan = false; @@ -95,6 +96,7 @@ matrix_diff_inf_norm(cute::Tensor const& X, { using std::abs; using error_type = decltype(std::declval().inf_norm); + using element_type = typename EngineType::value_type; assert(cute::size<0>(X) == 
cute::size<0>(Y)); assert(cute::size<1>(X) == cute::size<1>(Y)); @@ -110,7 +112,8 @@ matrix_diff_inf_norm(cute::Tensor const& X, for(int64_t i = 0; i < num_rows; ++i) { error_type row_abs_sum = 0.0; for(int64_t j = 0; j < num_cols; ++j) { - row_abs_sum += abs(X(i,j) - Y(i,j)); + row_abs_sum += error_type(abs(element_type(X(i,j)) - + element_type(Y(i,j)))); } if(std::isnan(row_abs_sum)) { found_nan = true; diff --git a/tools/util/include/cutlass/util/reference/host/gett.hpp b/tools/util/include/cutlass/util/reference/host/gett.hpp index 7b52dc5874..60a2281471 100644 --- a/tools/util/include/cutlass/util/reference/host/gett.hpp +++ b/tools/util/include/cutlass/util/reference/host/gett.hpp @@ -46,6 +46,16 @@ namespace cutlass::reference::host { +template +struct ElementTraits { + using type = T; +}; + +template +struct ElementTraits().get()), void> > > { + using type = decltype(std::declval().get()); +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// template< @@ -170,8 +180,8 @@ void gett_mainloop( static_assert(cute::rank(typename MainloopParams::LayoutA{}) == 3, "M, K, B"); static_assert(cute::rank(typename MainloopParams::LayoutB{}) == 3, "N, K, B"); - using ElementA = typename MainloopParams::TensorA::value_type; - using ElementB = typename MainloopParams::TensorB::value_type; + using ElementA = typename ElementTraits::type; + using ElementB = typename ElementTraits::type; using RingOp = multiply_add; RingOp fma_op; @@ -189,7 +199,8 @@ void gett_mainloop( ElementAccumulator a_frag[kBlockM]; for (int m_b = 0; m_b < kBlockM; ++m_b) { if (m + m_b < cute::size<0>(mainloop_params.A.layout())) { - a_frag[m_b] = static_cast(mainloop_params.A(m + m_b, k, l)); + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. + a_frag[m_b] = static_cast(ElementA(mainloop_params.A(m + m_b, k, l))); if (mainloop_params.transform_A == ComplexTransform::kConjugate) { a_frag[m_b] = conj(a_frag[m_b]); } @@ -202,7 +213,8 @@ void gett_mainloop( ElementAccumulator b_frag[kBlockN]; for (int n_b = 0; n_b < kBlockN; ++n_b) { if (n + n_b < cute::size<0>(mainloop_params.B.layout())) { - b_frag[n_b] = static_cast(mainloop_params.B(n + n_b, k, l)); + // Perform reference GEMM calculations at the accumulator's precision. Cast A value to accumulator type. 
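// (The ElementTraits helper introduced above resolves the tensor's value_type to
// its underlying element when that type is a reference proxy exposing get(), such
// as sub-byte storage references, so the explicit element conversion below is
// well-defined before widening to ElementAccumulator.)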
+ b_frag[n_b] = static_cast(ElementB(mainloop_params.B(n + n_b, k, l))); if (mainloop_params.transform_B == ComplexTransform::kConjugate) { b_frag[n_b] = conj(b_frag[n_b]); } @@ -252,10 +264,19 @@ void gett_epilogue( std::is_same_v or std::is_same_v; + constexpr bool IsReLUAuxNeeded = + cute::is_same_v> and + cute::is_same_v; + + constexpr bool IsBackpropFusion = + cute::is_same_v> or + cute::is_same_v>; + // Input related converter NumericConverter accumulator_converter; NumericConverter source_converter; NumericConverter bias_converter; + NumericConverter aux_source_converter; // Scale related converter NumericConverter scale_converter; @@ -267,10 +288,12 @@ void gett_epilogue( // Output related converter NumericConverter destination_converter; NumericConverter aux_destination_converter; + NumericConverter dBias_converter; // Epilogue operations multiply_add epilogue_fma; multiplies mul; + plus add; // Activation operation ActivationFunctor activation; @@ -294,23 +317,25 @@ void gett_epilogue( converted_alpha = mul(converted_alpha, mul(converted_scale_a, converted_scale_b)); converted_beta = mul(converted_beta, converted_scale_c); - for (int n_b = 0; n_b < kBlockN; ++n_b) { - for (int m_b = 0; m_b < kBlockM; ++m_b) { + for (int m_b = 0; m_b < kBlockM; ++m_b) { + ElementCompute local_dBias = ElementCompute(0); + + for (int n_b = 0; n_b < kBlockN; ++n_b) { if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n + n_b < cute::size<1>(epilogue_params.D.layout())) { // Convert every type to ElementCompute first, do compute, convert to output type, write it out ElementCompute converted_acc = accumulator_converter(acc[m_b][n_b]); // per-row alpha - if (epilogue_params.Valpha.data()) { + if (raw_pointer_cast(epilogue_params.Valpha.data())) { converted_alpha = scale_converter(epilogue_params.Valpha(m + m_b)); } ElementCompute output = mul(converted_alpha, converted_acc); - if (epilogue_params.Bias.data()) { + if (raw_pointer_cast(epilogue_params.Bias.data()) && not IsBackpropFusion) { ElementCompute converted_bias = bias_converter(epilogue_params.Bias(m + m_b)); output = bias_op(output, converted_bias); } - if (epilogue_params.C.data()) { + if (raw_pointer_cast(epilogue_params.C.data())) { ElementCompute converted_src = source_converter(epilogue_params.C(m + m_b, n + n_b, l)); // per-row beta if (epilogue_params.Vbeta.data()) { @@ -319,18 +344,33 @@ void gett_epilogue( output = epilogue_fma(converted_beta, converted_src, output); } - if (epilogue_params.Aux.data()) { - auto aux_output = output; - if constexpr (IsScalingAndAmaxAuxOutputNeeded) { - maximum_absolute_value_reduction amax_op; - local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); - aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); + if constexpr (IsBackpropFusion) { + ElementAux aux_input = ElementAux(0); + if (raw_pointer_cast(epilogue_params.Aux.data())) { + aux_input = epilogue_params.Aux(m + m_b, n + n_b, l); } - epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); + output = activation(output, aux_source_converter(aux_input)); + local_dBias = add(local_dBias, output); } + else { + if (raw_pointer_cast(epilogue_params.Aux.data())) { + auto aux_output = output; + if constexpr (IsScalingAndAmaxAuxOutputNeeded) { + maximum_absolute_value_reduction amax_op; + local_abs_max_aux_output = amax_op(local_abs_max_aux_output, aux_output); + aux_output = epilogue_fma(converted_scale_aux, aux_output, ElementCompute(0)); + } + + if constexpr (IsReLUAuxNeeded) { 
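// Note on the reference epilogue semantics in this block: when IsReLUAuxNeeded
// holds, the store below writes a 1-bit keep/drop predicate per element
// (uint1b_t(1) when the pre-activation output is not negative) instead of a
// converted auxiliary value; a dReLU backward fusion can later consume that
// bitmap as its mask. In the IsBackpropFusion path above, Aux is instead read as
// the activation's second operand and each row's activated outputs are summed
// into local_dBias, which is added into epilogue_params.Bias once the n_b loop ends.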
+ epilogue_params.Aux(m + m_b, n + n_b, l) = not (aux_output < 0) ? uint1b_t(1) : uint1b_t(0); + } else { + epilogue_params.Aux(m + m_b, n + n_b, l) = aux_destination_converter(aux_output); + } + } - output = activation(output); + output = activation(output); + } if constexpr (IsScalingAndAmaxOutputNeeded) { maximum_absolute_value_reduction amax_op; @@ -340,8 +380,16 @@ void gett_epilogue( epilogue_params.D(m + m_b, n + n_b, l) = destination_converter(output); } + } // n_b + + if (m + m_b < cute::size<0>(epilogue_params.D.layout()) && n < cute::size<1>(epilogue_params.D.layout())) { + if (raw_pointer_cast(epilogue_params.Bias.data()) && IsBackpropFusion) { + ElementCompute converted_dBias = bias_converter(epilogue_params.Bias(m + m_b)); + local_dBias = add(local_dBias, converted_dBias); + epilogue_params.Bias(m + m_b) = dBias_converter(local_dBias); + } } - } + } // m_b #if defined(_OPENMP) #pragma omp critical(Abs_Max_Data_Update) #endif diff --git a/tools/util/include/cutlass/util/reference/host/tensor_fill.h b/tools/util/include/cutlass/util/reference/host/tensor_fill.h index 9b0dcdb374..3d776e286c 100644 --- a/tools/util/include/cutlass/util/reference/host/tensor_fill.h +++ b/tools/util/include/cutlass/util/reference/host/tensor_fill.h @@ -566,7 +566,6 @@ struct RandomUniformFunc { // Random values are cast to integer after scaling by a power of two to facilitate error // testing Element result; - if (int_scale >= 0) { rnd = double(int64_t(rnd * double(1 << int_scale))) / double(1 << int_scale); result = static_cast(Real(rnd)); @@ -1253,15 +1252,15 @@ void TensorFillRandom( TensorFillRandomGaussian( view, seed, - static_cast(dist.gaussian.mean), - static_cast(dist.gaussian.stddev), + dist.gaussian.mean, + dist.gaussian.stddev, dist.int_scale); } else if (dist.kind == Distribution::Uniform) { TensorFillRandomUniform( view, seed, - static_cast(dist.uniform.max), - static_cast(dist.uniform.min), + dist.uniform.max, + dist.uniform.min, dist.int_scale); } }
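The tensor_fill.h hunk above forwards the distribution parameters straight through to the typed fill routines instead of narrowing them through intermediate casts. A minimal host-side sketch of how those utilities are typically used together with HostTensor (whose host-side storage the earlier host_tensor.h hunk reworks to avoid the std::vector<bool> specialization); the shape, seed, and bounds below are illustrative values only:

#include <cstdint>

#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

int main() {
  int M = 128, N = 64;

  // Paired host/device allocations managed by HostTensor.
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor({M, N});

  // Uniform fill over [min, max]; a non-negative final argument snaps values to
  // an integer grid (scaled by 2^bits), which keeps reference verification exact.
  uint64_t seed = 2023;
  cutlass::reference::host::TensorFillRandomUniform(
      tensor.host_view(), seed, /*max=*/4, /*min=*/-4, /*bits=*/0);

  // Copy the freshly filled host buffer to the device allocation.
  tensor.sync_device();

  return 0;
}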