Skip to content

Commit

Permalink
CUTLASS 3.3.0 (#1167)
Browse files Browse the repository at this point in the history
* Release 3.3.0

Adds support for mixed precision GEMMs On Hopper and Ampere
Adds support for < 16B aligned GEMMs on Hopper
Enhancements to EVT
Enhancements to Python interface
Enhancements to Sub-byte type handling in CuTe
Several other bug-fixes and performance improvements.

* minor doc update
  • Loading branch information
IonThruster authored Nov 2, 2023
1 parent 922fb51 commit c008b4a
Show file tree
Hide file tree
Showing 263 changed files with 16,203 additions and 4,997 deletions.
2 changes: 1 addition & 1 deletion .github/ISSUE_TEMPLATE/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@ blank_issues_enabled: true
contact_links:
- name: CUTLASS Discord
url: https://discord.gg/nvidiadeveloper
about: Come chat about using and contributing to CUTLASS!
about: Come chat about using and contributing to CUTLASS!
15 changes: 15 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
# NVIDIA CUTLASS Changelog

## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3) (2023-10-31)
* [Mixed Precision Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types.
* [Mixed Precision Ampere GEMMs](https://github.com/NVIDIA/cutlass/commit/7d8317a63e0a978a8dbb3c1fb7af4dbe4f286616) with support for canonical layouts (TN) and {fp16, bf16} x {s8/u8}.
* [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors.
* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors.
* Profiler support for lower-aligned Hopper GEMMs.
* Performance Improvements to [Scatter-Gather Hopper Example](/examples/52_hopper_gather_scatter_fusion)
* Sub-Byte type fixes and improvements
* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
* Fusion support for backprop fusions including drelu, dgelu, and dbias.
* Support for void-C kernels and SM80 mixed-precision GEMMs in the CUTLASS Python interface

## [3.2.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-10-25)
* Minor patch for issue/1138

## [3.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-09-22)
* Python support SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0.
* SM80 EVT support in C++ and Python.
Expand Down
24 changes: 9 additions & 15 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ endif()
message(STATUS "CMake Version: ${CMAKE_VERSION}")
set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++11 if set")

project(CUTLASS VERSION 3.2.1 LANGUAGES CXX)
project(CUTLASS VERSION 3.3.0 LANGUAGES CXX)
include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake)

if (CUDA_VERSION VERSION_LESS 11.3)
Expand Down Expand Up @@ -87,19 +87,6 @@ set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.")

find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED)

# Install cutlass_library Python package
execute_process(
WORKING_DIRECTORY ${CUTLASS_DIR}/python
COMMAND ${Python3_EXECUTABLE} ${CUTLASS_DIR}/python/setup_library.py develop --user
RESULT_VARIABLE cutlass_lib_GENERATOR_INSTALL_RESULT
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log
)

if(NOT cutlass_lib_GENERATOR_INSTALL_RESULT EQUAL 0)
message(FATAL_ERROR "Error installing cutlass_library package. See ${CMAKE_CURRENT_BINARY_DIR}/cutlass_library_installation.log")
endif()

################################################################################
set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library")

Expand Down Expand Up @@ -829,6 +816,8 @@ function(cutlass_add_executable_tests NAME TARGET)
set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME})
file(MAKE_DIRECTORY ${TEST_GEN_DIR})

set(TEST_SETS_SUPPORTED default)

set(TEST_EXE_PATH $<TARGET_FILE:${TARGET}>)
set(TEST_USE_EXTENDED_FORMAT ON)
configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY)
Expand Down Expand Up @@ -883,7 +872,12 @@ if (CUTLASS_INSTALL_TESTS)

file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/ctest")

file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n")
file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n\n")

file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SET})\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SET} \"default\")\n")
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "endif()\n\n")

foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES})
file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "include(${GENERATED_FILE})\n")
endforeach()
Expand Down
40 changes: 16 additions & 24 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition")

# CUTLASS 3.2
# CUTLASS 3.3

_CUTLASS 3.2 - August 2023_
_CUTLASS 3.3 - October 2023_

CUTLASS is a collection of CUDA C++ template abstractions for implementing
high-performance matrix-matrix multiplication (GEMM) and related computations at all levels
Expand Down Expand Up @@ -41,26 +41,17 @@ and improves code composability and readability. More documentation specific to

In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components.

# What's New in CUTLASS 3.2

CUTLASS 3.2.0 is an update to CUTLASS adding:
- New warp-specialized persistent FP8 GEMM kernel [kernel schedules](/include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](/include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](/examples/54_hopper_fp8_warp_specialized_gemm).
- New [Epilogue Visitor Tree (EVT)](/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue.
- [Stream-K](/include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release.
- Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](/include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp).
- Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA.
- [Hopper GEMM+Permute](/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue.
- New CUTLASS 2D Convolution Python interface. New [example](/examples/python/03_basic_conv2d.ipynb) here.
- Support for Windows (MSVC) builds.

CUTLASS 3.2.1 is an update to CUTLASS adding:
- Python support SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0.
- SM80 EVT support in C++ and Python.
- Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details.
- Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](/python/README.md) for details.
- SM90 TF32 kernel improvements for all layouts.
- SM90 rasterization direction support in the CUTLASS profiler.
- Improvement for CUTLASS profiler build times.
# What's New in CUTLASS 3.3

CUTLASS 3.3.0 is an update to CUTLASS adding:

- New [Mixed Precision Hopper GEMMs](/examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input types with optimal performance.
- New [Mixed Precision Ampere GEMMs](https://github.com/NVIDIA/cutlass/commit/7d8317a63e0a978a8dbb3c1fb7af4dbe4f286616) with support for canonical layouts (TN) and {fp16, bf16} x {s8/u8}. They also include fast numeric conversion recipes and warp level shuffles to achieve optimal performance.
- New [Copy Async based Hopper GEMMs](/test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors (across s8/fp8/fp16/bf16/tf32 types) with optimal performance. As a part of this, new kernel schedules, and Copy Ops [SM80\_CP\_ASYNC\_CACHE\_\*](/include/cute/arch/copy_sm80.hpp) were also added.
- EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details.
- Various subbyte enhancements like tagged device ptrs, support for vectorized copy, various operators to treat subbyte iterators as pointers, and full-fledged CuTe Tensor support.
- Support for Clang as a host compiler
- Support for void-C kernels and SM80 mixed-precision GEMMs in the CUTLASS Python interface

Minimum requirements:

Expand Down Expand Up @@ -103,7 +94,7 @@ as shown in the above figure. Tensor Core operations are implemented using CUDA
# Compatibility

CUTLASS requires a C++17 host compiler and
performs best when built with the [**CUDA 12.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit).
performs best when built with the [**CUDA 12.2.2 Toolkit**](https://developer.nvidia.com/cuda-toolkit-archive).
It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0 and CUDA 12.1.

## Operating Systems
Expand All @@ -114,9 +105,10 @@ We have tested the following environments.
| Ubuntu 18.04 | GCC 7.5.0 |
| Ubuntu 20.04 | GCC 10.3.0 |
| Ubuntu 22.04 | GCC 11.2.0 |
| Ubuntu 22.04 | Clang 10.0.0 |
| Ubuntu 22.04 | Clang 14.0.6 |
| Windows 10.0 | Visual Studio 2019 v16.11.27 |

Note: We plan to add Clang compiler support soon.
Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended.

## Hardware
Expand Down
2 changes: 1 addition & 1 deletion bin2hex.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED)
endif()

string(REGEX REPLACE "(....)" "\\1\n" HEX_OUTPUT ${HEX_INPUT})
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," HEX_OUTPUT ${HEX_OUTPUT})
string(REGEX REPLACE "([0-9a-f][0-9a-f])" "char(0x\\1)," HEX_OUTPUT ${HEX_OUTPUT})

set(HEX_OUTPUT "static char const ${VARIABLE_NAME}[] = {\n ${HEX_OUTPUT}\n};\n")

Expand Down
8 changes: 7 additions & 1 deletion cmake/CTestTestfile.configure.cmake
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Generated file

set(TEST_SETS_SUPPORTED @TEST_SETS_SUPPORTED@)

#? if (DEFINED ENV{CUTLASS_TEST_SET} AND NOT ENV{CUTLASS_TEST_SET} IN_LIST TEST_SETS_SUPPORTED)
#? message(STATUS "Skipping tests for @TEST_EXE_PATH@ as $ENV{CUTLASS_TEST_SET} is not in the set of ${TEST_SETS_SUPPORTED}.")
#? return()
#? endif()

set(TEST_EXE_PATH @TEST_EXE_PATH@)
set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@)
set(CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT @TEST_USE_EXTENDED_FORMAT@)
Expand All @@ -11,4 +18,3 @@ else()
endif()

@_INLINE_PER_TEST_CODE@

6 changes: 6 additions & 0 deletions examples/47_ampere_gemm_universal_streamk/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,13 @@ cutlass_example_add_executable(
ampere_gemm_universal_streamk.cu
)

# Deliberately test non-square sizes to ensure that internal transpose is
# not triggered when using SM80 EVT
set(TEST_COMMAND_00 --m=512 --n=768 --k=1152)

cutlass_example_add_executable(
47_ampere_gemm_universal_streamk_broadcast
ampere_gemm_universal_streamk_broadcast.cu
TEST_COMMAND_OPTIONS
TEST_COMMAND_00
)
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,6 @@ using ClusterShape = Shape<_1,_2,_1>; // S
using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
TileShape, ClusterShape,
Expand All @@ -127,6 +117,17 @@ using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBui
cutlass::epilogue::collective::EpilogueScheduleAuto
>::CollectiveOp;

using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
ArchTag, OperatorClass,
ElementA, LayoutA, AlignmentA,
ElementB, LayoutB, AlignmentB,
ElementAccumulator,
TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<
static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
cutlass::gemm::collective::KernelScheduleAuto
>::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int,int,int>, // Indicates ProblemShape
CollectiveMainloop,
Expand Down Expand Up @@ -397,6 +398,7 @@ int run(Options &options)
GpuTimer timer;
timer.start();
for (int iter = 0; iter < options.iterations; ++iter) {
CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
CUTLASS_CHECK(gemm.run());
}
timer.stop();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@
CUTLASS builders make an attempt to pick the best schedule when `Auto` is provided such that the
assembled collectives have the best performance, but this is not a guarantee. A user relying on `Auto`
may get a free performance upgrade with newer CUTLASS releases in case we can provide more optimized
implementations that the builder can transparently assemble for `Auto`. But a user should not rely on
implementations that the builder can transparently assemble for `Auto`. But a user should not rely on
`Auto` if they require a specific scheduling policy and/or stage count to be used.
If a user decides to let the builders pick the collective specialization via `Auto` schedules,
Expand Down Expand Up @@ -289,7 +289,7 @@ struct ExampleRunner {
// EVTs can be constructed by composing the fundamental load/store/compute visitor operations defined in include/cutlass/epilogue/fusion
// For more complex examples of EVT construction please refer to include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp
using CustomEVT = // alpha * acc + beta * C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiply_add, ElementD, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add, ElementD, ElementCompute, RoundStyle>, // beta * C + (alpha * acc)
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>, // beta
cutlass::epilogue::fusion::Sm90SrcFetch, // C
cutlass::epilogue::fusion::Sm90EVT<cutlass::epilogue::fusion::Sm90Compute<cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>, // alpha * acc
Expand Down Expand Up @@ -322,7 +322,7 @@ struct ExampleRunner {
ElementAccumulator,
Shape<_128,_128,_64>, Shape<_2,_1,_1>,
cute::conditional_t<cute::is_same_v<StageCountType, cutlass::gemm::collective::StageCountAuto>,
cutlass::gemm::collective::StageCountAutoCarveout<(int)sizeof(typename CollectiveEpilogue::SharedStorage)>,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
StageCountType>,
MainloopScheduleType
>::CollectiveOp;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ struct Options {
int mode = 1; // N-mode gather/scatter by default

float alpha = 1.0f;
float beta = 1.0f;
float beta = 0.0f;

bool reference_check = true;
int iterations = 20;
Expand Down Expand Up @@ -179,19 +179,27 @@ struct ExampleRunner
{
// Useful aliases

using ProblemShape = Shape<int,int,int,int>;

using StrideA = cutlass::gemm::TagToStrideA_t<LayoutA>;
using StrideB = cutlass::gemm::TagToStrideB_t<LayoutB>;
using StrideC = cutlass::gemm::TagToStrideC_t<LayoutC>;
using StrideD = cutlass::gemm::TagToStrideC_t<LayoutD>;

// Alias to for the epilogue type that supports gather/scatter
using Epilogue = cutlass::epilogue::collective::EpilogueGatherScatter<
cutlass::gemm::TagToStrideC_t<LayoutC>,
cutlass::gemm::TagToStrideC_t<LayoutD>,
cutlass::epilogue::thread::LinearCombination<
ElementD, 1,
ElementAccumulator, ElementComputeEpilogue,
cutlass::epilogue::thread::ScaleType::Default,
cutlass::FloatRoundStyle::round_to_nearest, ElementC
>,
cutlass::gemm::EpilogueDefault,
GatherC,
ScatterD
using Epilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
cutlass::epilogue::collective::EpilogueGatherScatter<
StrideC, StrideD,
cutlass::epilogue::thread::LinearCombination<
ElementD, 1,
ElementAccumulator, ElementComputeEpilogue,
cutlass::epilogue::thread::ScaleType::Default,
cutlass::FloatRoundStyle::round_to_nearest, ElementC
>,
cutlass::gemm::EpilogueDefault,
GatherC,
ScatterD
>
>;

// Alias to for the mainloop type
Expand All @@ -202,27 +210,21 @@ struct ExampleRunner
ElementAccumulator,
Shape<_128,_128,_64>,
Shape<_1,_1,_1>,
cutlass::gemm::collective::StageCount<5>,
cutlass::gemm::KernelMultistage
cutlass::gemm::collective::StageCountAuto,
cutlass::gemm::KernelCpAsyncWarpSpecialized
>::CollectiveOp;

using ProblemShape = Shape<int,int,int,int>;

using Kernel = cutlass::gemm::kernel::GemmGather<
ProblemShape,
Mainloop,
Epilogue,
void,
GatherA,
GatherB
>;

using Gemm = cutlass::gemm::device::GemmUniversalAdapter<Kernel>;

using StrideA = typename Kernel::StrideA;
using StrideB = typename Kernel::StrideB;
using StrideC = typename Kernel::StrideC;
using StrideD = typename Kernel::StrideD;

static constexpr bool DoGatherA = not cutlass::platform::is_same<GatherA, NoGather>::value;
static constexpr bool DoGatherB = not cutlass::platform::is_same<GatherB, NoGather>::value;
static constexpr bool DoGatherC = not cutlass::platform::is_same<GatherC, NoGather>::value;
Expand Down Expand Up @@ -250,16 +252,19 @@ struct ExampleRunner

using MainloopRef = Mainloop;

using EpilogueRef = typename cutlass::epilogue::collective::DefaultEpilogue<
StrideC, StrideD,
typename Epilogue::ThreadEpilogueOp,
typename Epilogue::EpilogueSchedule
using EpilogueRef = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter<
cutlass::epilogue::collective::DefaultEpilogue<
StrideC, StrideD,
typename Epilogue::ThreadEpilogueOp,
typename Epilogue::EpilogueSchedule
>
>;

using KernelRef = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
MainloopRef,
EpilogueRef
ProblemShape,
MainloopRef,
EpilogueRef,
void
>;

using GemmRef = cutlass::gemm::device::GemmUniversalAdapter<KernelRef>;
Expand Down Expand Up @@ -289,9 +294,10 @@ struct ExampleRunner
>::CollectiveOp;

using KernelOpt = cutlass::gemm::kernel::GemmUniversal<
ProblemShape,
MainloopOpt,
EpilogueOpt
ProblemShape,
MainloopOpt,
EpilogueOpt,
void
>;

using GemmOpt = cutlass::gemm::device::GemmUniversalAdapter<KernelOpt>;
Expand Down Expand Up @@ -404,6 +410,7 @@ struct ExampleRunner
typename Epilogue::ScatterD{gather_indices.get()}
},
hw_info,
{},
typename Kernel::GatherA{gather_indices.get()},
typename Kernel::GatherB{gather_indices.get()}
},
Expand Down
Loading

0 comments on commit c008b4a

Please sign in to comment.