From 757275f2796bb901575c633e2a32bc76ca84ffec Mon Sep 17 00:00:00 2001 From: Manish Gupta Date: Fri, 13 Oct 2023 08:33:15 -0700 Subject: [PATCH] Adding more Threadblock Tiles for Mixed-input TensorOp (BF16 * S8) in cutlass_library (#1132) * Adding more tiles in the cutlass_library for mixed-input support. * fix rebase issue * more tiles to upcast a --- .gitignore | 1 + python/cutlass_library/generator.py | 105 ++++- test/unit/gemm/device/CMakeLists.txt | 2 + ...8n_bf16t_mixed_input_tensor_op_f32_sm80.cu | 278 +++++++++++++ ...s8n_f16t_mixed_input_tensor_op_f16_sm80.cu | 2 +- ...6n_bf16t_mixed_input_tensor_op_f32_sm80.cu | 384 ++++++++++++++++++ ...16n_f16t_mixed_input_tensor_op_f16_sm80.cu | 2 +- test/unit/gemm/warp/gemm_mixed_input_sm80.cu | 7 +- .../include/cutlass/library/operation_table.h | 12 + tools/library/include/cutlass/library/types.h | 1 + tools/library/src/library_internal.h | 4 + .../reduction/init_reduction_operations.cu | 2 + .../library/src/reduction/reduction_device.cu | 34 ++ .../src/reference/gemm_fp_mixed_input.cu | 16 + 14 files changed, 830 insertions(+), 20 deletions(-) create mode 100644 test/unit/gemm/device/gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu create mode 100644 test/unit/gemm/device/gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu diff --git a/.gitignore b/.gitignore index 1328f6b7d6..acddb1f9d1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ # PyCache files __pycache__/ +cutlass_library.egg-info/ \ No newline at end of file diff --git a/python/cutlass_library/generator.py b/python/cutlass_library/generator.py index ee6bb2ced1..365ea6cf9e 100644 --- a/python/cutlass_library/generator.py +++ b/python/cutlass_library/generator.py @@ -107,7 +107,7 @@ def CreateGemmOperator(manifest, layouts, tile_descriptions, data_type, \ # If alignment is a tuple or a list, then we have different alignments for A and B alignment_a = alignment if isinstance(alignment, int) else alignment[0] alignment_b = alignment if isinstance(alignment, int) else alignment[1] - alignment_c = min(8, alignment_a) + alignment_c = min(8, alignment_a) if isinstance(alignment, int) else alignment[2] A = TensorDescription(element_a, layout[0], alignment_a, complex_transform[0]) B = TensorDescription(element_b, layout[1], alignment_b, complex_transform[1]) @@ -2155,7 +2155,7 @@ def GenerateSM80_PlanarComplexTensorOp_16816(manifest, cuda_version): # -def GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version): +def GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version): if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): return @@ -2196,27 +2196,66 @@ def GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version): min_cc = 80 max_cc = 1024 - # For mixed-input alignment constraints are a list of lists, where the inner list - # contains the alignment constraints for [operandA, operandB]. - alignment_constraints = [[16, 8],] + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] + alignment_constraints = [[16, 8, 8],] for math_inst in math_instructions: tile_descriptions = [ + # 128x128 TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x64 + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x32 + TileDescription([128, 32, 64], 9, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x16 + TileDescription([128, 16, 64], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 64], 3, [2, 1, 1], math_inst, min_cc, max_cc), ] data_type = [ math_inst.element_a, math_inst.element_b, - math_inst.element_b, + math_inst.element_accumulator, math_inst.element_accumulator, ] CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_b, + math_inst.element_accumulator, + ] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) - # Upcast on Operand B + for op in operations: + if op.tile_description.threadblock_shape[1] <= 32: + op.C.alignment = 4 + + +# +def GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version): + + if not CudaToolkitVersionSatisfies(cuda_version, 11, 0): + return + + layouts = [ + (LayoutType.RowMajor, LayoutType.ColumnMajor, LayoutType.ColumnMajor), + ] + math_instructions = [ MathInstruction( \ [16, 8, 16], \ @@ -2243,26 +2282,64 @@ def GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version): min_cc = 80 max_cc = 1024 - # For mixed-input alignment constraints are a list of lists, where the inner list - # contains the alignment constraints for [operandA, operandB]. - alignment_constraints = [[8, 16],] - + # For mixed-input alignment constraints are a list of lists, where the + # inner list contains the alignment constraints for operands/matrices + # [[alignA, alignB, alignC],..] + alignment_constraints = [[8, 16, 8],] + + for math_inst in math_instructions: tile_descriptions = [ + # 128x128 TileDescription([128, 128, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), TileDescription([128, 128, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x64 + TileDescription([128, 64, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 4, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 64, 64], 3, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x32 + TileDescription([128, 32, 64], 9, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 64], 5, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 9, [2, 2, 1], math_inst, min_cc, max_cc), + TileDescription([128, 32, 32], 5, [2, 2, 1], math_inst, min_cc, max_cc), + # 128x16 + TileDescription([128, 16, 64], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 64], 3, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 32], 9, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 32], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([128, 16, 32], 3, [2, 1, 1], math_inst, min_cc, max_cc), + # 256x16 + TileDescription([256, 16, 32], 5, [2, 1, 1], math_inst, min_cc, max_cc), + TileDescription([256, 16, 32], 3, [2, 1, 1], math_inst, min_cc, max_cc), ] data_type = [ math_inst.element_a, math_inst.element_b, - math_inst.element_a, + math_inst.element_accumulator, math_inst.element_accumulator, ] CreateGemmOperator(manifest, layouts, tile_descriptions, \ data_type, alignment_constraints) + + # Avoid emitting two kernels if the accumulator type does not differ from the input type (e.g. F16 accumulation) + if math_inst.element_a != math_inst.element_accumulator: + + data_type_mixed = [ + math_inst.element_a, + math_inst.element_b, + math_inst.element_a, + math_inst.element_accumulator, + ] + + operations = CreateGemmOperator(manifest, layouts, tile_descriptions, \ + data_type_mixed, alignment_constraints) + for op in operations: + if op.tile_description.threadblock_shape[1] <= 32: + op.C.alignment = 4 + # def GenerateSM80_TensorOp_16832_TN(manifest, cuda_version): @@ -2645,7 +2722,6 @@ def GenerateSM80_TensorOp_16864_Interleaved(manifest, cuda_version): for op in operations: op.C.alignment = 16 -# # def GenerateSM80_TensorOp_168256(manifest, cuda_version): @@ -4196,7 +4272,8 @@ def GenerateSM80(manifest, cuda_version): GenerateSM80_TensorOp_884_symm(manifest, cuda_version) GenerateSM80_TensorOp_884_symm_complex(manifest, cuda_version) GenerateSM80_TensorOp_884_symm_complex_gaussian(manifest, cuda_version) - GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version) + GenerateSM80_TensorOp_16816_mixed_input_upcast_a(manifest, cuda_version) + GenerateSM80_TensorOp_16816_mixed_input_upcast_b(manifest, cuda_version) GenerateSM80_TensorOp_16832_TN(manifest, cuda_version) GenerateSM80_SparseTensorOp_16864_TN(manifest, cuda_version) GenerateSM80_TensorOp_16832_Interleaved(manifest, cuda_version) diff --git a/test/unit/gemm/device/CMakeLists.txt b/test/unit/gemm/device/CMakeLists.txt index 752239ab45..0bd60ed936 100644 --- a/test/unit/gemm/device/CMakeLists.txt +++ b/test/unit/gemm/device/CMakeLists.txt @@ -350,10 +350,12 @@ cutlass_test_unit_add_executable( # Upcast on Operand A gemm_universal_u8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu # Upcast on Operand B gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu + gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu ) cutlass_test_unit_add_executable( diff --git a/test/unit/gemm/device/gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..9a29512d5e --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,278 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = cutlass::bfloat16_t; + using ElementB = int8_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 128x128x32_64x64x32) { + + using ElementA = cutlass::bfloat16_t; + using ElementB = int8_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 64x128x32_32x64x32) { + + using ElementA = cutlass::bfloat16_t; + using ElementB = int8_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + + +TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 128x64x32_64x32x32) { + + using ElementA = cutlass::bfloat16_t; + using ElementB = int8_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 8, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 64x64x32_32x32x32) { + + using ElementA = cutlass::bfloat16_t; + using ElementB = int8_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 8, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_bf16t_s8n_bf16t_mixed_input_tensor_op_f32, 16x128x32_16x64x32) { + + using ElementA = cutlass::bfloat16_t; + using ElementB = int8_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<16, 128, 32>, + cutlass::gemm::GemmShape<16, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 8, // Stages + 8, // AlignmentA + 16, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu index e991c02773..035a33b982 100644 --- a/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu +++ b/test/unit/gemm/device/gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -56,7 +56,7 @@ //////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_GemmUniversal_f16t_s8t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) { +TEST(SM80_Device_GemmUniversal_f16t_s8n_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) { using ElementA = cutlass::half_t; using ElementB = int8_t; diff --git a/test/unit/gemm/device/gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu new file mode 100644 index 0000000000..340a5a1cf4 --- /dev/null +++ b/test/unit/gemm/device/gemm_universal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32_sm80.cu @@ -0,0 +1,384 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Tests for device-wide GEMM interface + +*/ + +#include + +#include "../../common/cutlass_unit_test.h" +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "testbed_universal.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// +TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x128x64_64x64x64) { + + using ElementA = int8_t; + using ElementB = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x128x32_64x64x32) { + + using ElementA = int8_t; + using ElementB = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 64x128x32_32x64x32) { + + using ElementA = int8_t; + using ElementB = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 64x64x32_32x32x32) { + + using ElementA = int8_t; + using ElementB = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 8, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x64x32_64x32x32) { + + using ElementA = int8_t; + using ElementB = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 8, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x64x32_64x64x32) { + + using ElementA = int8_t; + using ElementB = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 8, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} + +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x32x32_64x32x32) { + + using ElementA = int8_t; + using ElementB = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 32, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 8, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x16x32_64x16x32) { + + using ElementA = int8_t; + using ElementB = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 16, 32>, + cutlass::gemm::GemmShape<64, 16, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 4, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +TEST(SM80_Device_GemmUniversal_s8t_bf16n_bf16t_mixed_input_tensor_op_f32, 128x16x64_64x16x64) { + + using ElementA = int8_t; + using ElementB = cutlass::bfloat16_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using Gemm = cutlass::gemm::device::GemmUniversal< + ElementA, + cutlass::layout::RowMajor, + ElementB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 16, 64>, + cutlass::gemm::GemmShape<64, 16, 64>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 4, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 4, // Stages + 16, // AlignmentA + 8, // AlignmentB + cutlass::arch::OpMultiplyAddMixedInputUpcast, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone + >; + + EXPECT_TRUE(test::gemm::device::TestAllGemmUniversal()); +} +//////////////////////////////////////////////////////////////////////////////// + +#endif // #if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED) + +//////////////////////////////////////////////////////////////////////////////// \ No newline at end of file diff --git a/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu b/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu index a0753cda40..8e5a70e8a8 100644 --- a/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu +++ b/test/unit/gemm/device/gemm_universal_s8t_f16n_f16t_mixed_input_tensor_op_f16_sm80.cu @@ -56,7 +56,7 @@ //////////////////////////////////////////////////////////////////////////////// -TEST(SM80_Device_GemmUniversal_s8t_f16t_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) { +TEST(SM80_Device_GemmUniversal_s8t_f16n_f16t_mixed_input_tensor_op_f16, 128x128x64_64x64x64) { using ElementA = int8_t; using ElementB = cutlass::half_t; diff --git a/test/unit/gemm/warp/gemm_mixed_input_sm80.cu b/test/unit/gemm/warp/gemm_mixed_input_sm80.cu index 89d56e10ab..56f6fb742f 100644 --- a/test/unit/gemm/warp/gemm_mixed_input_sm80.cu +++ b/test/unit/gemm/warp/gemm_mixed_input_sm80.cu @@ -75,7 +75,6 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_ .run(); } - TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 64x64x64_64x64x64_16x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; @@ -140,7 +139,6 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16 .run(); } - //////////////////////////////////////////////////////////////////////////////// /// F32 <= F16 * U8 + F32 (Upcast on Operand B) //////////////////////////////////////////////////////////////////////////////// @@ -227,6 +225,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 128x128x64_64x64x64_ .run(); } + //////////////////////////////////////////////////////////////////////////////// /// F32 <= B16 * U8 + F32 (Upcast on Operand B) //////////////////////////////////////////////////////////////////////////////// @@ -251,7 +250,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_1 } //////////////////////////////////////////////////////////////////////////////// -/// F32 <= B16 * U8 + F32 (Upcast on Operand B) +/// F32 <= U8 * BF16 + F32 (Upcast on Operand A) //////////////////////////////////////////////////////////////////////////////// TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_16x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; @@ -297,7 +296,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_1 } //////////////////////////////////////////////////////////////////////////////// -/// F32 <= B16 * I8 + F32 (Upcast on Operand B) +/// F32 <= I8 * BF16 + F32 (Upcast on Operand A) //////////////////////////////////////////////////////////////////////////////// TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_16x8x16) { using Shape = cutlass::gemm::GemmShape<64, 64, 64>; diff --git a/tools/library/include/cutlass/library/operation_table.h b/tools/library/include/cutlass/library/operation_table.h index 06ea28b00a..b23036fda4 100644 --- a/tools/library/include/cutlass/library/operation_table.h +++ b/tools/library/include/cutlass/library/operation_table.h @@ -215,6 +215,18 @@ struct GemmPreferenceKey { return compute_capability == rhs.compute_capability; } }; +///////////////////////////////////////////////////////////////////////////////////////////////// +inline +std::ostream& operator<< (std::ostream& out, const cutlass::library::GemmPreferenceKey& key) { + out << "{\n" + << "compute_capability : " << key.compute_capability << std::endl + << "alignment : " << key.alignment << std::endl + << "}"; + + return out; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/library/include/cutlass/library/types.h b/tools/library/include/cutlass/library/types.h index c28efef539..46f9bf54ba 100644 --- a/tools/library/include/cutlass/library/types.h +++ b/tools/library/include/cutlass/library/types.h @@ -172,6 +172,7 @@ enum class MathOperationID { kAdd, kMultiplyAdd, kMultiplyAddSaturate, + kMultiplyAddMixedInputUpcast, kMultiplyAddFastBF16, kMultiplyAddFastF16, kMultiplyAddFastF32, diff --git a/tools/library/src/library_internal.h b/tools/library/src/library_internal.h index 4e4e09d2e1..f45b7d1cf3 100644 --- a/tools/library/src/library_internal.h +++ b/tools/library/src/library_internal.h @@ -174,6 +174,10 @@ template <> struct MathOperationMap { static MathOperationID const kId = MathOperationID::kMultiplyAddSaturate; }; +template <> struct MathOperationMap { + static MathOperationID const kId = MathOperationID::kMultiplyAddMixedInputUpcast; +}; + template <> struct MathOperationMap { static MathOperationID const kId = MathOperationID::kMultiplyAddComplex; }; diff --git a/tools/library/src/reduction/init_reduction_operations.cu b/tools/library/src/reduction/init_reduction_operations.cu index 2d14166b13..0d7ce6af13 100644 --- a/tools/library/src/reduction/init_reduction_operations.cu +++ b/tools/library/src/reduction/init_reduction_operations.cu @@ -44,6 +44,7 @@ namespace library { /////////////////////////////////////////////////////////////////////////////////////////////// void initialize_reduce_add_linear_combination_f16_f16_f16(Manifest &manifest); void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest); +void initialize_reduce_add_linear_combination_f32_f32_bf16(Manifest &manifest); void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest); void initialize_reduce_add_linear_combination_f64_f64_f64(Manifest &manifest); void initialize_reduce_add_linear_combination_cf32_cf32_cf32(Manifest &manifest); @@ -55,6 +56,7 @@ void initialize_all_reduction_op(Manifest &manifest) { initialize_reduce_add_linear_combination_f16_f16_f16(manifest); initialize_reduce_add_linear_combination_f32_f32_f16(manifest); + initialize_reduce_add_linear_combination_f32_f32_bf16(manifest); initialize_reduce_add_linear_combination_f32_f32_f32(manifest); initialize_reduce_add_linear_combination_f64_f64_f64(manifest); initialize_reduce_add_linear_combination_cf32_cf32_cf32(manifest); diff --git a/tools/library/src/reduction/reduction_device.cu b/tools/library/src/reduction/reduction_device.cu index 5ede2fdfae..34852c4ace 100644 --- a/tools/library/src/reduction/reduction_device.cu +++ b/tools/library/src/reduction/reduction_device.cu @@ -112,6 +112,40 @@ void initialize_reduce_add_linear_combination_f32_f32_f16(Manifest &manifest) { )); } +void initialize_reduce_add_linear_combination_f32_f32_bf16(Manifest &manifest) { + + using ElementWorkspace = float; + using ElementAccumulator = float; + using ElementOutput = cutlass::bfloat16_t; + using ElementCompute = float; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute + >; + + using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ElementAccumulator, + typename EpilogueOutputOp::ElementAccumulator, + EpilogueOutputOp::kCount + >; + + using Operation_reduce_add_linear_combination_f32_f32_bf16 = cutlass::reduction::device::ReduceSplitK< + cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + > + >; + + manifest.append(new ReductionOperation< + Operation_reduce_add_linear_combination_f32_f32_bf16>( + "reduce_add_linear_combination_f32_f32_bf16" + )); +} + void initialize_reduce_add_linear_combination_f32_f32_f32(Manifest &manifest) { diff --git a/tools/library/src/reference/gemm_fp_mixed_input.cu b/tools/library/src/reference/gemm_fp_mixed_input.cu index ea1c88bab0..786b610187 100644 --- a/tools/library/src/reference/gemm_fp_mixed_input.cu +++ b/tools/library/src/reference/gemm_fp_mixed_input.cu @@ -104,6 +104,14 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) { float >(manifest); + make_gemm_real_canonical_layouts< + int8_t, + bfloat16_t, + float, + float, + float + >(manifest); + make_gemm_real_canonical_layouts< int8_t, bfloat16_t, @@ -112,6 +120,14 @@ void initialize_gemm_reference_operations_fp_mixed_input(Manifest &manifest) { float >(manifest); + make_gemm_real_canonical_layouts< + bfloat16_t, + uint8_t, + float, + float, + float + >(manifest); + make_gemm_real_canonical_layouts< bfloat16_t, uint8_t,