Skip to content

Commit

Permalink
Handle mixed-input upcast on OperandA (Support [S8|U8]*[F16|BF16]
Browse files Browse the repository at this point in the history
  • Loading branch information
Manish Gupta committed Sep 12, 2023
1 parent 9d705de commit 46d8dd9
Show file tree
Hide file tree
Showing 5 changed files with 254 additions and 7 deletions.
96 changes: 94 additions & 2 deletions include/cutlass/gemm/warp/mma_mixed_input_tensor_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ template <
int NumElementsInWarpFragment,
/// Number of elements in mma fragment
int NumElementsInMmaFragment,
/// Identifies A or B multiplicand
Operand Operand_,
///
typename Enable = void >
struct FragmentShuffler {
Expand All @@ -89,6 +91,7 @@ struct FragmentShuffler {
static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand_;

using WarpFragment = Array<ElementMma, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementMma, kNumElementsInMmaFragment>;
Expand All @@ -101,6 +104,7 @@ struct FragmentShuffler {
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
/// for operand A multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
typename ElementMma_,
Expand All @@ -117,6 +121,7 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
NumMmaInstructions,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kA,
typename std::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 8)>::type> {
public:
Expand All @@ -126,6 +131,93 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand::kA;

using WarpFragment = Array<ElementMma, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementMma, kNumElementsInMmaFragment>;

private:
int delta_up_;
int delta_down_;
int odd_even_lane_id_;

public:
CUTLASS_DEVICE
FragmentShuffler() {
  int const lane = cutlass::arch::LaneId();
  // Whether this lane holds the "odd" half of a register pair (lsb of lane id).
  odd_even_lane_id_ = lane & 1;
  // Shuffle distances for the up/down warp exchanges below; the two always sum to 2.
  delta_up_ = (lane & 1) + ((lane >> 1) & 1);
  delta_down_ = 2 - delta_up_;
}

CUTLASS_DEVICE
// Shuffles the warp fragment `src` into the register layout expected by mma.sync,
// writing the result to `dst`. Safe to call in place (dst and src may alias):
// all __shfl_*_sync reads of src complete before the corresponding dst writes.
void operator()(WarpFragment &dst, WarpFragment &src) {

// Reinterpret the warp-wide fragments as arrays of per-mma-instruction fragments.
MmaFragment *ptr_mma_frag_src = reinterpret_cast<MmaFragment *>(&src);
MmaFragment *ptr_mma_frag_dst = reinterpret_cast<MmaFragment *>(&dst);

CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kNumMmaInstructions; n++) {

// View each mma fragment as 32-bit registers, the unit exchanged by warp shuffles.
// NOTE(review): indices ptr_src[0..3] assume MmaFragment spans four 32b registers — confirm against fragment sizes.
uint32_t *ptr_src = reinterpret_cast<uint32_t *>(&ptr_mma_frag_src[n]);
uint32_t *ptr_dst = reinterpret_cast<uint32_t *>(&ptr_mma_frag_dst[n]);

// Fetch registers from the lane `delta_up_` positions lower in the warp.
uint32_t even_thread_r0 = __shfl_up_sync(0xFFFFFFFF, ptr_src[0], delta_up_);
uint32_t odd_thread_r0 = __shfl_up_sync(0xFFFFFFFF, ptr_src[1], delta_up_);

uint32_t even_thread_r1 = __shfl_up_sync(0xFFFFFFFF, ptr_src[2], delta_up_);
uint32_t odd_thread_r1 = __shfl_up_sync(0xFFFFFFFF, ptr_src[3], delta_up_);

// Fetch registers from the lane `delta_down_` positions higher in the warp.
uint32_t even_thread_r2 = __shfl_down_sync(0xFFFFFFFF, ptr_src[0], delta_down_);
uint32_t odd_thread_r2 = __shfl_down_sync(0xFFFFFFFF, ptr_src[1], delta_down_);

uint32_t even_thread_r3 = __shfl_down_sync(0xFFFFFFFF, ptr_src[2], delta_down_);
uint32_t odd_thread_r3 = __shfl_down_sync(0xFFFFFFFF, ptr_src[3], delta_down_);

// Branchless select: odd lanes (odd_even_lane_id_ == 1) keep the odd_thread_*
// values, even lanes keep the even_thread_* values.
ptr_dst[0] = odd_even_lane_id_ * odd_thread_r0 +
(1 - odd_even_lane_id_) * even_thread_r0;
ptr_dst[1] = odd_even_lane_id_ * odd_thread_r1 +
(1 - odd_even_lane_id_) * even_thread_r1;
ptr_dst[2] = odd_even_lane_id_ * odd_thread_r2 +
(1 - odd_even_lane_id_) * even_thread_r2;
ptr_dst[3] = odd_even_lane_id_ * odd_thread_r3 +
(1 - odd_even_lane_id_) * even_thread_r3;

}
}

};
////////////////////////////////////////////////////////////////////////////////

/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
/// for operand B multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
typename ElementMma_,
/// Element type for the operand in shared memory for ldmatrix
typename ElementLoad_,
/// Number of mma.sync operations performed along rows or columns
int NumMmaInstructions,
/// Number of elements in warp fragment
int NumElementsInWarpFragment,
/// Number of elements in mma fragment
int NumElementsInMmaFragment
>
struct FragmentShuffler <ElementMma_, ElementLoad_,
NumMmaInstructions,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kB,
typename std::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 8)>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;

static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand::kB;

using WarpFragment = Array<ElementMma, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementMma, kNumElementsInMmaFragment>;
Expand Down Expand Up @@ -433,7 +525,7 @@ class MmaMixedInputTensorOp {

// Shuffle data within warp to obtain the mma.sync operand layout
detail::FragmentShuffler<MmaElementA, ElementA, MmaIterations::kRow,
FragmentA::kElements, MmaOperandA::kElements> shuffler_A;
FragmentA::kElements, MmaOperandA::kElements, Operand::kA> shuffler_A;

// Shuffle the A operand, inplace, to the Mma Instruction operand layout
shuffler_A(dst_A, dst_A);
Expand All @@ -444,7 +536,7 @@ class MmaMixedInputTensorOp {

// Shuffle data within warp to obtain the mma.sync operand layout
detail::FragmentShuffler<MmaElementB, ElementB, MmaIterations::kColumn,
FragmentB::kElements, MmaOperandB::kElements> shuffler_B;
FragmentB::kElements, MmaOperandB::kElements, Operand::kB> shuffler_B;

// Shuffle the B operand, inplace, to the Mma Instruction operand layout
shuffler_B(dst_B, dst_B);
Expand Down
4 changes: 2 additions & 2 deletions test/unit/gemm/device/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -347,8 +347,8 @@ cutlass_test_unit_add_executable(
BATCH_SOURCES ON
BATCH_SIZE 4

gemm_universal_f16n_u8t_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_f16n_s8t_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_f16t_u8n_f16t_mixed_input_tensor_op_f16_sm80.cu
gemm_universal_f16t_s8n_f16t_mixed_input_tensor_op_f16_sm80.cu
)

cutlass_test_unit_add_executable(
Expand Down
161 changes: 158 additions & 3 deletions test/unit/gemm/warp/gemm_mixed_input_sm80.cu
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@


////////////////////////////////////////////////////////////////////////////////
/// F32 <= F16 * I8 + F32
/// F32 <= F16 * I8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 128x128x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
Expand Down Expand Up @@ -96,9 +96,53 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_i8, 64x64x64_64x64x64_16
.run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= I8 * F16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
// Warp-level mixed-input GEMM: s8 operand A upcast to f16 to match operand B.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 128x128x64_64x64x64_16x8x16) {
  // Warp tile (M, N, K) and mma.sync instruction shape.
  using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
  using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;
  using ElementA = int8_t;
  using ElementB = cutlass::half_t;  // fixed: stray double semicolon
  using ElementC = float;
  // Crosswise shared-memory layouts for a 64-element contiguous dimension.
  using LayoutA = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementA>::value, 64>;
  using LayoutB = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<ElementB>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      Shape, InstructionShape, ElementA, LayoutA, ElementB, LayoutB, ElementC,
      cutlass::layout::RowMajor, cutlass::arch::OpMultiplyAddMixedInput>::Type;

  // Run the transform testbed over a 128x128x64 threadblock tile.
  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<128, 128, 64> >()
      .run();
}


// Warp-level mixed-input GEMM: narrow s8 operand A is upcast to f16.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_f16, 64x64x64_64x64x64_16x8x16) {
  // Warp-level tile and per-instruction mma shape.
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Operand and accumulator element types.
  using OperandAType = int8_t;
  using OperandBType = cutlass::half_t;
  using AccumulatorType = float;

  // Crosswise shared-memory layouts for operands A and B.
  using OperandALayout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<OperandAType>::value, 64>;
  using OperandBLayout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<OperandBType>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, MmaInstructionShape, OperandAType, OperandALayout, OperandBType,
      OperandBLayout, AccumulatorType, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAddMixedInput>::Type;

  // Exercise the operand transform path on a 64x64x64 threadblock tile.
  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}


////////////////////////////////////////////////////////////////////////////////
/// F32 <= F16 * U8 + F32
/// F32 <= F16 * U8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
Expand Down Expand Up @@ -141,7 +185,50 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_f16_u8, 128x128x64_64x64x64_
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= B16 * U8 + F32
/// F32 <= U8 * F16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
// Warp-level mixed-input GEMM: narrow u8 operand A is upcast to f16.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 64x64x64_64x64x64_16x8x16) {
  // Warp-level tile and per-instruction mma shape.
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Operand and accumulator element types.
  using OperandAType = uint8_t;
  using OperandBType = cutlass::half_t;
  using AccumulatorType = float;

  // Crosswise shared-memory layouts for operands A and B.
  using OperandALayout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<OperandAType>::value, 64>;
  using OperandBLayout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<OperandBType>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, MmaInstructionShape, OperandAType, OperandALayout, OperandBType,
      OperandBLayout, AccumulatorType, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAddMixedInput>::Type;

  // Exercise the operand transform path on a 64x64x64 threadblock tile.
  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}

// Warp-level mixed-input GEMM: narrow u8 operand A is upcast to f16.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_f16, 128x128x64_64x64x64_16x8x16) {
  // Warp-level tile and per-instruction mma shape.
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Operand and accumulator element types.
  using OperandAType = uint8_t;
  using OperandBType = cutlass::half_t;
  using AccumulatorType = float;

  // Crosswise shared-memory layouts for operands A and B.
  using OperandALayout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<OperandAType>::value, 64>;
  using OperandBLayout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<OperandBType>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, MmaInstructionShape, OperandAType, OperandALayout, OperandBType,
      OperandBLayout, AccumulatorType, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAddMixedInput>::Type;

  // Exercise the operand transform path on a 128x128x64 threadblock tile.
  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<128, 128, 64> >()
      .run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= BF16 * U8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_16x8x16) {
using Shape = cutlass::gemm::GemmShape<64, 64, 64>;
Expand All @@ -163,5 +250,73 @@ TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_u8, 64x64x64_64x64x64_1
.run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= U8 * BF16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
// Warp-level mixed-input GEMM: narrow u8 operand A is upcast to bf16.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_u8_bf16, 64x64x64_64x64x64_16x8x16) {
  // Warp-level tile and per-instruction mma shape.
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Operand and accumulator element types.
  using OperandAType = uint8_t;
  using OperandBType = cutlass::bfloat16_t;
  using AccumulatorType = float;

  // Crosswise shared-memory layouts for operands A and B.
  using OperandALayout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<OperandAType>::value, 64>;
  using OperandBLayout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<OperandBType>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, MmaInstructionShape, OperandAType, OperandALayout, OperandBType,
      OperandBLayout, AccumulatorType, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAddMixedInput>::Type;

  // Exercise the operand transform path on a 64x64x64 threadblock tile.
  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= BF16 * I8 + F32 (Upcast on Operand B)
////////////////////////////////////////////////////////////////////////////////
// Warp-level mixed-input GEMM: narrow s8 operand B is upcast to bf16.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_bf16_i8, 64x64x64_64x64x64_16x8x16) {
  // Warp-level tile and per-instruction mma shape.
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Operand and accumulator element types.
  using OperandAType = cutlass::bfloat16_t;
  using OperandBType = int8_t;
  using AccumulatorType = float;

  // Crosswise shared-memory layouts for operands A and B.
  using OperandALayout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<OperandAType>::value, 64>;
  using OperandBLayout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<OperandBType>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, MmaInstructionShape, OperandAType, OperandALayout, OperandBType,
      OperandBLayout, AccumulatorType, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAddMixedInput>::Type;

  // Exercise the operand transform path on a 64x64x64 threadblock tile.
  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}

////////////////////////////////////////////////////////////////////////////////
/// F32 <= I8 * BF16 + F32 (Upcast on Operand A)
////////////////////////////////////////////////////////////////////////////////
// Warp-level mixed-input GEMM: narrow s8 operand A is upcast to bf16.
TEST(SM80_warp_gemm_mixed_input_tensor_op_crosswise_i8_bf16, 64x64x64_64x64x64_16x8x16) {
  // Warp-level tile and per-instruction mma shape.
  using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
  using MmaInstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

  // Operand and accumulator element types.
  using OperandAType = int8_t;
  using OperandBType = cutlass::bfloat16_t;
  using AccumulatorType = float;

  // Crosswise shared-memory layouts for operands A and B.
  using OperandALayout = cutlass::layout::RowMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<OperandAType>::value, 64>;
  using OperandBLayout = cutlass::layout::ColumnMajorTensorOpMultiplicandCrosswise<
      cutlass::sizeof_bits<OperandBType>::value, 64>;

  using MmaTensorOp = typename cutlass::gemm::warp::DefaultMmaTensorOp<
      WarpShape, MmaInstructionShape, OperandAType, OperandALayout, OperandBType,
      OperandBLayout, AccumulatorType, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAddMixedInput>::Type;

  // Exercise the operand transform path on a 64x64x64 threadblock tile.
  test::gemm::warp::TransformTestbed<MmaTensorOp,
                                     cutlass::gemm::GemmShape<64, 64, 64> >()
      .run();
}

#endif // if defined(CUTLASS_ARCH_MMA_SM80_SUPPORTED)

0 comments on commit 46d8dd9

Please sign in to comment.