Skip to content

Commit

Permalink
[CPU] Remove 8x8x16 i8mm microkernel
Browse files (browse the repository at this point in the history)
  • Loading branch information
mariecwhite committed Mar 19, 2024
1 parent a2ed5d1 commit 2eae9b3
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 16 deletions.
Original file line number | Diff line number | Diff line change
Expand Up @@ -286,8 +286,6 @@ iree_uk_mmt4d_select_tile_func_arm_64_i8i4i32_M0x8x16(
return iree_uk_mmt4d_tile_s8s4s32_2x8x16_arm_64_i8mm;
case 4:
return iree_uk_mmt4d_tile_s8s4s32_4x8x16_arm_64_i8mm;
case 8:
return iree_uk_mmt4d_tile_s8s4s32_8x8x16_arm_64_i8mm;
}
}
#endif
Expand Down
21 changes: 10 additions & 11 deletions runtime/src/iree/builtins/ukernel/arch/arm_64/mmt4d_arm_64_i8mm.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -171,11 +171,12 @@ iree_uk_mmt4d_tile_s8s4s32_1x8x16_arm_64_i8mm(
}

IREE_UK_ATTRIBUTE_ALWAYS_INLINE static inline void
iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm(
iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_4x8x16_arm_64_i8mm(
void* IREE_UK_RESTRICT out_tile, const void* IREE_UK_RESTRICT lhs_panel,
const void* IREE_UK_RESTRICT rhs_panel,
const iree_uk_mmt4d_params_t* params, int M0) {
IREE_UK_ASSERT(M0 >= 2 && M0 <= 8 && iree_uk_is_po2_u32(M0));
// We support M0 up to 4 in order to fit within the register budget.
IREE_UK_ASSERT(M0 >= 2 && M0 <= 4 && iree_uk_is_po2_u32(M0));
IREE_UK_ASSERT(!(params->K0 % 16));
const iree_uk_int8_t* IREE_UK_RESTRICT lhs_ptr = lhs_panel;
const iree_uk_int8_t* IREE_UK_RESTRICT rhs_ptr = rhs_panel;
Expand All @@ -184,7 +185,7 @@ iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm(
const int8x16_t vmask = vmovq_n_s8(0xF0);
const int mtiles = M0 / 2;

int32x4_t acc[4][4];
int32x4_t acc[2][4];
IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
IREE_UK_UNROLL for (int j = 0; j < 4; j++) {
// We start with zero accumulators and add the value of *out_ptr later.
Expand All @@ -202,23 +203,24 @@ iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm(
}
rhs_ptr += 64;

int8x16_t lhs[2][4];
int8x16_t lhs[2][2];
if (M0 == 2) {
int8x8x2_t lhs_uzp[2];
IREE_UK_UNROLL for (int i = 0; i < 2; i++) {
lhs_uzp[i] = vld2_s8(lhs_ptr + 16 * i);
}
lhs[0][0] = vcombine_s8(lhs_uzp[0].val[0], lhs_uzp[1].val[0]);
lhs[1][0] = vcombine_s8(lhs_uzp[0].val[1], lhs_uzp[1].val[1]);
lhs_ptr += 32;
} else {
IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
IREE_UK_UNROLL for (int i = 0; i < 2; i++) {
int8x8x2_t lhs_0 = vld2_s8(lhs_ptr + 16 * 2 * i);
int8x8x2_t lhs_1 = vld2_s8(lhs_ptr + 16 * (2 * i + 1));
lhs[0][i] = vcombine_s8(lhs_0.val[0], lhs_1.val[0]);
lhs[1][i] = vcombine_s8(lhs_0.val[1], lhs_1.val[1]);
}
lhs_ptr += 64;
}
lhs_ptr += 32 * mtiles;

IREE_UK_UNROLL for (int i = 0; i < mtiles; i++) {
IREE_UK_UNROLL for (int j = 0; j < 4; j++) {
Expand Down Expand Up @@ -255,11 +257,8 @@ iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm(
}

IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm,
iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_4x8x16_arm_64_i8mm,
iree_uk_mmt4d_tile_s8s4s32_2x8x16_arm_64_i8mm, 2)
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm,
iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_4x8x16_arm_64_i8mm,
iree_uk_mmt4d_tile_s8s4s32_4x8x16_arm_64_i8mm, 4)
IREE_UK_MMT4D_TILE_FUNC_IMPL_FOR_M0(
iree_uk_mmt4d_tile_s8s4s32_2x8x16_to_8x8x16_arm_64_i8mm,
iree_uk_mmt4d_tile_s8s4s32_8x8x16_arm_64_i8mm, 8)
Original file line number | Diff line number | Diff line change
Expand Up @@ -59,6 +59,5 @@ IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_8x8x8_arm_64_dotprod)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_1x8x16_arm_64_i8mm)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_2x8x16_arm_64_i8mm)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_4x8x16_arm_64_i8mm)
IREE_UK_MMT4D_TILE_FUNC_DECL(iree_uk_mmt4d_tile_s8s4s32_8x8x16_arm_64_i8mm)

#endif // IREE_BUILTINS_UKERNEL_ARCH_ARM_64_MMT4D_ARM_64_INTERNAL_H_
2 changes: 1 addition & 1 deletion runtime/src/iree/builtins/ukernel/tools/mmt4d_benchmark.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -149,7 +149,7 @@ int main(int argc, char** argv) {
"");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8,
"dotprod");
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 16,
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 4, 8, 16,
"i8mm");
#elif defined(IREE_ARCH_X86_64)
iree_uk_benchmark_register_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_F32F32F32, 8, 8, 1,
Expand Down
2 changes: 1 addition & 1 deletion runtime/src/iree/builtins/ukernel/tools/mmt4d_test.c
Original file line number | Diff line number | Diff line change
Expand Up @@ -571,7 +571,7 @@ int main(int argc, char** argv) {
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 4, "dotprod");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S8S32, 8, 8, 8, "i8mm");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 8, "dotprod");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 8, 8, 16, "i8mm");
iree_uk_test_mmt4d(IREE_UK_FLAG_MMT4D_TYPE_S8S4S32, 4, 8, 16, "i8mm");

#elif defined(IREE_ARCH_X86_64)

Expand Down

0 comments on commit 2eae9b3

Please sign in to comment.