Skip to content

Commit

Permalink
rename OpMultiplyAddMixedInput to OpMultiplyAddMixedInputUpcast
Browse files Browse the repository at this point in the history
  • Loading branch information
Manish Gupta committed Sep 12, 2023
1 parent 46d8dd9 commit 9c6b967
Show file tree
Hide file tree
Showing 7 changed files with 31 additions and 31 deletions.
4 changes: 2 additions & 2 deletions include/cutlass/arch/mma.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ struct OpMultiplyAddFastF16 {};
/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tag indicating the input data types are mixed and the narrower type is
/// converted to the wider type
struct OpMultiplyAddMixedInput {};
/// upcasted to the wider type
struct OpMultiplyAddMixedInputUpcast {};

/////////////////////////////////////////////////////////////////////////////////////////////////

Expand Down
18 changes: 9 additions & 9 deletions include/cutlass/gemm/warp/default_mma_tensor_op_sm80.h
Original file line number Diff line number Diff line change
Expand Up @@ -252,20 +252,20 @@ template <
bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<
WarpShape_,
GemmShape<16, 8, 16>, // InstructionShape
ElementA, // Element type of A matrix in Global Memory
LayoutA, // Layout of A matrix in Global Memory
ElementB, // Element type of B matrix in Global Memory
LayoutB, // Layout of B matrix in Global Memory
ElementC, // Element type of C matrix in Global Memory
LayoutC, // Layout of C matrix in Global Memory
arch::OpMultiplyAddMixedInput, // Tag to indicate mixed-input datatype
GemmShape<16, 8, 16>, // InstructionShape
ElementA, // Element type of A matrix in Global Memory
LayoutA, // Layout of A matrix in Global Memory
ElementB, // Element type of B matrix in Global Memory
LayoutB, // Layout of B matrix in Global Memory
ElementC, // Element type of C matrix in Global Memory
LayoutC, // Layout of C matrix in Global Memory
arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
PartitionsK, AccumulatorsInRowMajor> {


// Check if the ElementA and ElementB are of different data types
static_assert(!std::is_same<ElementA, ElementB>::value,
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInput ElementA and ElementB cannot be of the same data type");
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");

// Data type used for internal computation - use the wider of the two data types for mma.sync operands
using ElementOperand = typename std::conditional<(sizeof(ElementA) > sizeof(ElementB)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TEST(SM80_Device_GemmUniversal_f16t_s8t_f16t_mixed_input_tensor_op_f16, 128x128x
4, // Stages
8, // AlignmentA
16, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInput,
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ TEST(SM80_Device_GemmUniversal_f16t_u8t_f16t_mixed_input_tensor_op_f16, 128x128x
4, // Stages
8, // AlignmentA
16, // AlignmentB
cutlass::arch::OpMultiplyAddMixedInput,
cutlass::arch::OpMultiplyAddMixedInputUpcast,
cutlass::ComplexTransform::kNone,
cutlass::ComplexTransform::kNone
>;
Expand Down
24 changes: 12 additions & 12 deletions test/unit/gemm/warp/gemm_mixed_input_sm80.cu
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
Expand All @@ -89,7 +89,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 64x64x64_64x64x64_16

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
Expand All @@ -112,7 +112,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
Expand All @@ -133,7 +133,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
Expand All @@ -157,7 +157,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 64x64x64_64x64x64_16

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
Expand All @@ -177,7 +177,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 128x128x64_64x64x64_

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
Expand All @@ -200,7 +200,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 64x64x64_64x64x64_16

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
Expand All @@ -220,7 +220,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 128x128x64_64x64x64_

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<128, 128, 64> >()
Expand All @@ -243,7 +243,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_1

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
Expand All @@ -266,7 +266,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_1

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
Expand All @@ -289,7 +289,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_1

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
Expand All @@ -312,7 +312,7 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_1

using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;
cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInputUpcast>::Type;

test::gemm::warp::TransformTestbed<MmaTensorOp,
cutlass::gemm::GemmShape<64, 64, 64> >()
Expand Down
8 changes: 4 additions & 4 deletions tools/library/scripts/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2172,22 +2172,22 @@ def GenerateSM80_MixedInputTensorOp_16816(manifest, cuda_version):
[16, 8, 16], \
DataType.s8, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input),
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.u8, DataType.f16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input),
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.u8, DataType.bf16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input),
MathOperation.multiply_add_mixed_input_upcast),
MathInstruction( \
[16, 8, 16], \
DataType.s8, DataType.bf16, DataType.f32, \
OpcodeClass.TensorOp, \
MathOperation.multiply_add_mixed_input),
MathOperation.multiply_add_mixed_input_upcast),
]

min_cc = 80
Expand Down
4 changes: 2 additions & 2 deletions tools/library/scripts/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class ComplexMultiplyOp(enum.Enum):
class MathOperation(enum.Enum):
multiply_add = enum_auto()
multiply_add_saturate = enum_auto()
multiply_add_mixed_input = enum_auto()
multiply_add_mixed_input_upcast = enum_auto()
xor_popc = enum_auto()
and_popc = enum_auto()
multiply_add_fast_bf16 = enum_auto()
Expand All @@ -276,7 +276,7 @@ class MathOperation(enum.Enum):
MathOperationTag = {
MathOperation.multiply_add: 'cutlass::arch::OpMultiplyAdd',
MathOperation.multiply_add_saturate: 'cutlass::arch::OpMultiplyAddSaturate',
MathOperation.multiply_add_mixed_input: 'cutlass::arch::OpMultiplyAddMixedInput',
MathOperation.multiply_add_mixed_input_upcast: 'cutlass::arch::OpMultiplyAddMixedInputUpcast',
MathOperation.xor_popc: 'cutlass::arch::OpXorPopc',
MathOperation.and_popc: 'cutlass::arch::OpAndPopc',
MathOperation.multiply_add_fast_bf16: 'cutlass::arch::OpMultiplyAddFastBF16',
Expand Down

0 comments on commit 9c6b967

Please sign in to comment.