Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CUTLASS 3.3.0 #1167

Merged
merged 2 commits into from
Nov 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ blank_issues_enabled: true
contact_links:
- name: CUTLASS Discord
url: https://discord.gg/nvidiadeveloper
about: Come chat about using and contributing to CUTLASS!
about: Come chat about using and contributing to CUTLASS!
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
# NVIDIA CUTLASS Changelog

## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3) (2023-10-31)
* [Mixed Precision Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
* [Mixed Precision Ampere GEMMs](https://github.com/NVIDIA/cutlass/commit/7d8317a63e0a978a8dbb3c1fb7af4dbe4f286616) with support for canonical layouts (TN) and {fp16, bf16} x {s8/u8}.
* [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors.
* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors.
* Profiler support for lower-aligned Hopper GEMMs.
* Performance Improvements to [Scatter-Gather Hopper Example](/examples/52_hopper_gather_scatter_fusion)
* Sub-Byte type fixes and improvements
* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
* Fusion support for backprop fusions including drelu, dgelu, and dbias.
* Support for void-C kernels and SM80 mixed-precision GEMMs in the CUTLASS Python interface

## [3.2.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-10-25)
* Minor patch for issue/1138

## [3.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-09-22)
* Python support SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0.
* SM80 EVT support in C++ and Python.
Expand Down
24 changes: 9 additions & 15 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")

project(CUTLASS VERSION 3.2.1 LANGUAGES CXX)
project(CUTLASS VERSION 3.3.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

if (CUDA_VERSION VERSION_LESS 11.3)
Expand Down Expand Up @@ -87,19 +87,6 @@ set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")

find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)

# Install cutlass_library Python package
execute_process(
WORKING_DIRECTORY ${CUTLASS_DIR}/python
COMMAND ${Python3_EXECUTABLE} ${CUTLASS_DIR}/python/setup_library.py develop --user
RESULT_VARIABLE cutlass_lib_GENERATOR_INSTALL_RESULT
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log
)

if(NOT cutlass_lib_GENERATOR_INSTALL_RESULT EQUAL 0)
message(FATAL_ERROR "Error installing cutlass_library package. See ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log")
endif()

################################################################################
set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")

Expand Down Expand Up @@ -829,6 +816,8 @@ function(cutlass_add_executable_tests NAME TARGET)
set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME})
file(MAKE_DIRECTORY ${TEST_GEN_DIR})

set(TEST_SETS_SUPPORTED default)

set(TEST_EXE_PATH $<TARGET_FILE:${TARGET}>)
set(TEST_USE_EXTENDED_FORMAT ON)
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY)
Expand Down Expand Up @@ -883,7 +872,12 @@ if (CUTLASS_INSTALL_TESTS)

file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/ctest")

file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n")
file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n\n")

file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SET})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SET} \"default\")\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "endif()\n\n")

foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES})
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "include(${GENERATED_FILE})\n")
endforeach()
Expand Down
40 changes: 16 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

# CUTLASS 3.2
# CUTLASS 3.3

_CUTLASS 3.2 - August 2023_
_CUTLASS 3.3 - October 2023_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
Expand Down Expand Up @@ -41,26 +41,17 @@ and improves code composability and readability. More documentation specific to

In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.

# What's New in CUTLASS 3.2

CUTLASS 3.2.0 is an update to CUTLASS adding:
- New warp-specialized persistent FP8 GEMM kernel [kernel schedules](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](/examples/54_hopper_fp8_warp_specialized_gemm).
- New [Epilogue Visitor Tree (EVT)](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue.
- [Stream-K](/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
- Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
- Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
- [Hopper GEMM+Permute](/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
- New CUTLASS 2D Convolution Python interface. New [example](/examples/python/03_basic_conv2d.ipynb) here.
- Support for Windows (MSVC) builds.

CUTLASS 3.2.1 is an update to CUTLASS adding:
- Python support SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0.
- SM80 EVT support in C++ and Python.
- Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
- Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](/python/README.md) for details.
- SM90 TF32 kernel improvements for all layouts.
- SM90 rasterization direction support in the CUTLASS profiler.
- Improvement for CUTLASS profiler build times.
# What's New in CUTLASS 3.3

CUTLASS 3.3.0 is an update to CUTLASS adding:

- New [Mixed Precision Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input types with optimal performance.
- New [Mixed Precision Ampere GEMMs](https://github.com/NVIDIA/cutlass/commit/7d8317a63e0a978a8dbb3c1fb7af4dbe4f286616) with support for canonical layouts (TN) and {fp16, bf16} x {s8/u8}. They also include fast numeric conversion recipes and warp level shuffles to achieve optimal performance.
- New [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors (across s8/fp8/fp16/bf16/tf32 types) with optimal performance. As a part of this, new kernel schedules, and Copy Ops [SM80\_CP\_ASYNC\_CACHE\_\*](/include/cute/arch/copy_sm80.hpp) were also added.
- EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
- Various subbyte enhancements like tagged device ptrs, support for vectorized copy, various operators to treat subbyte iterators as pointers, and full-fledged CuTe Tensor support.
- Support for Clang as a host compiler
- Support for void-C kernels and SM80 mixed-precision GEMMs in the CUTLASS Python interface

Minimum requirements:

Expand Down Expand Up @@ -103,7 +94,7 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA
# Compatibility

CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
performs best when built with the [**CUDA 12.2.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit-archive).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0 and CUDA 12.1.

## Operating Systems
Expand All @@ -114,9 +105,10 @@ We have tested the following environments.
| Ubuntu 18.04 | GCC 7.5.0 |
| Ubuntu 20.04 | GCC 10.3.0 |
| Ubuntu 22.04 | GCC 11.2.0 |
| Ubuntu 22.04 | Clang 10.0.0 |
| Ubuntu 22.04 | Clang 14.0.6 |
| Windows 10.0 | Visual Studio 2019 v16.11.27 |

Note: We plan to add Clang compiler support soon.
Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.

## Hardware
Expand Down
2 changes: 1 addition & 1 deletion bin2hex.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED)
endif()

string(REGEX REPLACE "(....)" "\\1\n" HEX_OUTPUT ${HEX_INPUT})
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," HEX_OUTPUT ${HEX_OUTPUT})
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "char(0x\\1)," HEX_OUTPUT ${HEX_OUTPUT})

set(HEX_OUTPUT "static char const ${VARIABLE_NAME}[] = {\n ${HEX_OUTPUT}\n};\n")

Expand Down
8 changes: 7 additions & 1 deletion cmake/CTestTestfile.configure.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Generated file

set(TEST_SETS_SUPPORTED @TEST_SETS_SUPPORTED@)

#? if (DEFINED ENV{CUTLASS_TEST_SET} AND NOT ENV{CUTLASS_TEST_SET} IN_LIST TEST_SETS_SUPPORTED)
#? message(STATUS "Skipping tests for @TEST_EXE_PATH@ as $ENV{CUTLASS_TEST_SET} is not in the set of ${TEST_SETS_SUPPORTED}.")
#? return()
#? endif()

set(TEST_EXE_PATH @TEST_EXE_PATH@)
set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@)
set(CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT @TEST_USE_EXTENDED_FORMAT@)
Expand All @@ -11,4 +18,3 @@ else()
endif()

@_INLINE_PER_TEST_CODE@

6 changes: 6 additions & 0 deletions examples/47_ampere_gemm_universal_streamk/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ cutlass_example_add_executable(
ampere_gemm_universal_streamk.cu
)

# Deliberately test non-square sizes to ensure that internal transpose is
# not triggered when using SM80 EVT
set(TEST_COMMAND_00 --m=512 --n=768 --k=1152)

cutlass_example_add_executable(
47_ampere_gemm_universal_streamk_broadcast
ampere_gemm_universal_streamk_broadcast.cu
TEST_COMMAND_OPTIONS
TEST_COMMAND_00
)
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,6 @@ using ClusterShape = Shape<_1,_2,_1>; // S
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
Expand All @@ -127,6 +117,17 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
Expand Down Expand Up @@ -397,6 +398,7 @@ int run(Options &options)
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
CUTLASS builders make an attempt to pick the best schedule when `Auto` is provided such that the
assembled collectives have the best performance, but this is not a guarantee. A user relying on `Auto`
may get a free performance upgrade with newer CUTLASS releases in case we can provide more optimized
implementations that the builder can transparently assemble for `Auto`. But a user should not rely on
implementations that the builder can transparently assemble for `Auto`. But a user should not rely on
`Auto` if they require a specific scheduling policy and/or stage count to be used.

If a user decides to let the builders pick the collective specialization via `Auto` schedules,
Expand Down Expand Up @@ -289,7 +289,7 @@ struct ExampleRunner {
// EVTs can be constructed by composing the fundamental load/store/compute visitor operations defined in include/cutlass/epilogue/fusion
// For more complex examples of EVT construction please refer to include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp
using CustomEVT = // alpha * acc + beta * C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiply_add, ElementD, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add, ElementD, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // beta
cutlass::epilogue::fusion::Sm90SrcFetch, // C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
Expand Down Expand Up @@ -322,7 +322,7 @@ struct ExampleRunner {
ElementAccumulator,
Shape<_128,_128,_64>, Shape<_2,_1,_1>,
cute::conditional_t<cute::is_same_v<StageCountType, cutlass::gemm::collective::StageCountAuto>,
cutlass::gemm::collective::StageCountAutoCarveout<(int)sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
StageCountType>,
MainloopScheduleType
>::CollectiveOp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ struct Options {
int mode = 1; // N-mode gather/scatter by default

float alpha = 1.0f;
float beta = 1.0f;
float beta = 0.0f;

bool reference_check = true;
int iterations = 20;
Expand Down Expand Up @@ -179,19 +179,27 @@ struct ExampleRunner
{
// Useful aliases

using ProblemShape = Shape<int,int,int,int>;

using StrideA = cutlass::gemm::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::gemm::TagToStrideB_t<LayoutB>;
using StrideC = cutlass::gemm::TagToStrideC_t<LayoutC>;
using StrideD = cutlass::gemm::TagToStrideC_t<LayoutD>;

// Alias to for the epilogue type that supports gather/scatter
using Epilogue = cutlass::epilogue::collective::EpilogueGatherScatter<
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::gemm::TagToStrideC_t<LayoutD>,
cutlass::epilogue::thread::LinearCombination<
ElementD, 1,
ElementAccumulator, ElementComputeEpilogue,
cutlass::epilogue::thread::ScaleType::Default,
cutlass::FloatRoundStyle::round_to_nearest, ElementC
>,
cutlass::gemm::EpilogueDefault,
GatherC,
ScatterD
using Epilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
cutlass::epilogue::collective::EpilogueGatherScatter<
StrideC, StrideD,
cutlass::epilogue::thread::LinearCombination<
ElementD, 1,
ElementAccumulator, ElementComputeEpilogue,
cutlass::epilogue::thread::ScaleType::Default,
cutlass::FloatRoundStyle::round_to_nearest, ElementC
>,
cutlass::gemm::EpilogueDefault,
GatherC,
ScatterD
>
>;

// Alias to for the mainloop type
Expand All @@ -202,27 +210,21 @@ struct ExampleRunner
ElementAccumulator,
Shape<_128,_128,_64>,
Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCount<5>,
cutlass::gemm::KernelMultistage
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;

using ProblemShape = Shape<int,int,int,int>;

using Kernel = cutlass::gemm::kernel::GemmGather<
ProblemShape,
Mainloop,
Epilogue,
void,
GatherA,
GatherB
>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;

using StrideA = typename Kernel::StrideA;
using StrideB = typename Kernel::StrideB;
using StrideC = typename Kernel::StrideC;
using StrideD = typename Kernel::StrideD;

static constexpr bool DoGatherA = not cutlass::platform::is_same<GatherA, NoGather>::value;
static constexpr bool DoGatherB = not cutlass::platform::is_same<GatherB, NoGather>::value;
static constexpr bool DoGatherC = not cutlass::platform::is_same<GatherC, NoGather>::value;
Expand Down Expand Up @@ -250,16 +252,19 @@ struct ExampleRunner

using MainloopRef = Mainloop;

using EpilogueRef = typename cutlass::epilogue::collective::DefaultEpilogue<
StrideC, StrideD,
typename Epilogue::ThreadEpilogueOp,
typename Epilogue::EpilogueSchedule
using EpilogueRef = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
cutlass::epilogue::collective::DefaultEpilogue<
StrideC, StrideD,
typename Epilogue::ThreadEpilogueOp,
typename Epilogue::EpilogueSchedule
>
>;

using KernelRef = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
MainloopRef,
EpilogueRef
ProblemShape,
MainloopRef,
EpilogueRef,
void
>;

using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<KernelRef>;
Expand Down Expand Up @@ -289,9 +294,10 @@ struct ExampleRunner
>::CollectiveOp;

using KernelOpt = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
MainloopOpt,
EpilogueOpt
ProblemShape,
MainloopOpt,
EpilogueOpt,
void
>;

using GemmOpt = cutlass::gemm::device::GemmUniversalAdapter<KernelOpt>;
Expand Down Expand Up @@ -404,6 +410,7 @@ struct ExampleRunner
typename Epilogue::ScatterD{gather_indices.get()}
},
hw_info,
{},
typename Kernel::GatherA{gather_indices.get()},
typename Kernel::GatherB{gather_indices.get()}
},
Expand Down
Loading