From 8c20f62dbcd4d622c8a279b7c81dacb629f1de41 Mon Sep 17 00:00:00 2001
From: Andrey Kharitonchik
Date: Wed, 24 Aug 2022 09:52:09 -0700
Subject: [PATCH] x64: brgemm matmul: enable blocked B for 3D problems

---
 src/cpu/x64/matmul/brgemm_matmul.cpp        | 23 +++++----
 src/cpu/x64/matmul/brgemm_matmul_utils.cpp  | 47 +++++++++++++------
 .../inputs/matmul/harness_matmul_data_tags  | 16 +++++++
 3 files changed, 63 insertions(+), 23 deletions(-)

diff --git a/src/cpu/x64/matmul/brgemm_matmul.cpp b/src/cpu/x64/matmul/brgemm_matmul.cpp
index 075b31c2228..c650b2cc862 100644
--- a/src/cpu/x64/matmul/brgemm_matmul.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul.cpp
@@ -305,7 +305,8 @@ void brgemm_matmul_t<isa>::compute_kernel(
             ? brgmm_ctx.get_buf_C_ptr(ithr, m_blk_idx, n_blk_idx)
             : ptr_D;
 
-    const auto zp_comp_a = brgmm_ctx.get_zp_a_compensation_ptr(ithr, n_blk_idx);
+    const auto zp_comp_a
+            = brgmm_ctx.get_zp_a_compensation_ptr(ithr, b_idx, n_blk_idx);
     const auto zp_comp_b
             = brgmm_ctx.get_zp_b_compensation_result_ptr(ithr, m_blk_idx);
     const auto zp_c_val_ptr = brgmm_ctx.get_zp_c_val_ptr();
@@ -475,7 +476,8 @@ void brgemm_matmul_t<isa>::maybe_reduce_partial_results_and_apply_postops(
                     // TODO: support reduction for zp/s8s8 compensations
                     // computed in copy routines
                     const auto zp_comp_a
-                            = brgmm_ctx.get_zp_a_compensation_ptr(ithr, nb);
+                            = brgmm_ctx.get_zp_a_compensation_ptr(
+                                    ithr, b, nb);
                     const auto zp_comp_b
                             = brgmm_ctx.get_zp_b_compensation_result_ptr(
                                     ithr, mb);
@@ -579,8 +581,8 @@ void brgemm_matmul_t<isa>::copy_b_chunk_in_buffer(
     const int n = n_blk_idx * bgmmc.N_blk;
     const bool is_N_tail = (bgmmc.N - n < bgmmc.N_blk);
     ctx.current_N_blk = is_N_tail ? bgmmc.N_tail : bgmmc.N_blk;
-    ctx.zp_a_compensation_ptr
-            = (void *)brgmm_ctx.get_zp_a_compensation_ptr(ithr, n_blk_idx);
+    ctx.zp_a_compensation_ptr = (void *)brgmm_ctx.get_zp_a_compensation_ptr(
+            ithr, b_idx, n_blk_idx);
     ctx.zp_a_neg_value_ptr = (void *)brgmm_ctx.get_zp_a_neg_val_ptr();
 
     int gb = 0;
@@ -709,8 +711,10 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
             // multitreaded execution mode
            const size_t reorder_zp_a_comp_offset
                     = weights_d.size() - weights_d.additional_buffer_size();
+            const size_t b_batch
+                    = get_bb_idx(bgmmc.batch - 1, bgmmc_.bcast_B_desc) + 1;
             const size_t s8s8_buffer_sz = bgmmc.s8s8_compensation_required
-                    ? bgmmc.s8s8_comp_b_str * sizeof(int32_t)
+                    ? sizeof(int32_t) * b_batch * bgmmc.s8s8_comp_b_str
                     : 0;
             reorder_zp_a_comp_ptr_
                     = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(
@@ -965,7 +969,7 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
                 ? n_blk_idx % bgmmc_.N_chunk_size
                 : n_blk_idx;
         return s8s8_compensation_ptr_ + ithr * bgmmc_.s8s8_comp_ithr_str
-                + b * bgmmc_.s8s8_comp_b_str
+                + get_bb_idx(b, bgmmc_.bcast_B_desc) * bgmmc_.s8s8_comp_b_str
                 + n_blk_local * bgmmc_.s8s8_comp_n_str;
     }
 
@@ -987,7 +991,8 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
 
     const int32_t *get_zp_c_val_ptr() const { return &zero_point_c_val_; }
 
-    int32_t *get_zp_a_compensation_ptr(int ithr, int n_blk_idx) const {
+    int32_t *get_zp_a_compensation_ptr(
+            int ithr, int b_idx, int n_blk_idx) const {
         if (!bgmmc_.has_zero_point_a) return nullptr;
 
         const int n_blk_local = n_blk_idx % bgmmc_.N_chunk_size;
@@ -1000,7 +1005,9 @@ struct brgemm_matmul_t<isa>::brg_matmul_exec_ctx_t {
         // locally just before usage. Using the single global scaling before
         // parallel section might produce significant overhead for small
         // problems running in multitreaded execution mode
-        const int base_offset = n_blk_idx * bgmmc_.wei_n_blk;
+        const int base_offset = get_bb_idx(b_idx, bgmmc_.bcast_B_desc)
+                        * rnd_up(bgmmc_.N, bgmmc_.wei_n_blk)
+                + n_blk_idx * bgmmc_.wei_n_blk;
         PRAGMA_OMP_SIMD()
         for (int b = 0; b < bgmmc_.wei_n_blk; b++)
             zp_comp[b] = -zero_point_a_negative_val_
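The changes above make the zp_a compensation buffer per-batch: get_zp_a_compensation_ptr now takes b_idx and offsets by one block-padded row of N per B batch. A minimal standalone sketch of that base_offset arithmetic, with a hypothetical conf_t standing in for the bgmmc_ fields it reads (illustrative only, not oneDNN code):

    #include <cstdint>

    // Illustrative stand-in for the two bgmmc_ fields the offset reads.
    struct conf_t {
        int64_t N;     // full N dimension of B
        int wei_n_blk; // inner N block of the blocked B layout (16/32/48/64)
    };

    // Round v up to a multiple of step, as oneDNN's rnd_up does.
    static int64_t rnd_up(int64_t v, int64_t step) {
        return (v + step - 1) / step * step;
    }

    // base_offset from the patch: each B batch owns a block-padded row of
    // rnd_up(N, wei_n_blk) int32 entries; n_blk_idx selects a wei_n_blk-wide
    // slice inside it. bb is the broadcast-resolved batch (get_bb_idx result).
    int64_t zp_a_comp_base_offset(const conf_t &c, int64_t bb, int64_t n_blk_idx) {
        return bb * rnd_up(c.N, c.wei_n_blk) + n_blk_idx * c.wei_n_blk;
    }

Padding each batch's row to a multiple of wei_n_blk keeps the PRAGMA_OMP_SIMD fill loop over wei_n_blk entries from straddling batches.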
diff --git a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
index dca1423a47e..c24b30093e3 100644
--- a/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
+++ b/src/cpu/x64/matmul/brgemm_matmul_utils.cpp
@@ -42,15 +42,27 @@ int get_default_n_block(format_tag_t matrix_b_tag) {
     // Note: consider using weights mem_descriptor 'inner_blks' to
     // return B's inner block for non-default cases.
     switch (matrix_b_tag) {
+        case aCB16b64c:
+        case aCB16b64c2b:
+        case aCB16b64c4b:
         case BA16a64b4a:
         case BA16a64b2a:
         case BA16a64b: return 64;
+        case aCB16b48c:
+        case aCB16b48c2b:
+        case aCB16b48c4b:
         case BA16a48b:
         case BA16a48b2a:
         case BA16a48b4a: return 48;
+        case aCB16b32c:
+        case aCB16b32c2b:
+        case aCB16b32c4b:
         case BA16a32b:
         case BA16a32b2a:
         case BA16a32b4a: return 32;
+        case aCB16b16c:
+        case aCB16b16c2b:
+        case aCB16b16c4b:
         case BA16a16b:
         case BA16a16b2a:
         case BA16a16b4a: return 16;
@@ -242,14 +254,17 @@ status_t brgemm_matmul_conf_utils_t::set_or_check_tags(memory_desc_t &A_md,
 
 status_t brgemm_matmul_conf_utils_t::set_B_flags(memory_desc_t &B_md) const {
     memory_desc_t want_B_md = B_md;
+    // Set bits for all dimensions except k dimension
+    const int compensation_mask
+            = ((1 << bgmmc.ndims) - 1 - (1 << (bgmmc.ndims - 2)));
     if (bgmmc.s8s8_compensation_required && bgmmc.blocked_B) {
         want_B_md.extra.flags |= memory_extra_flags::compensation_conv_s8s8;
-        want_B_md.extra.compensation_mask = (1 << 1);
+        want_B_md.extra.compensation_mask = compensation_mask;
     }
     if (bgmmc.src_zp_type != brgemm_broadcast_t::none && bgmmc.blocked_B) {
         want_B_md.extra.flags
                 |= memory_extra_flags::compensation_conv_asymmetric_src;
-        want_B_md.extra.asymm_compensation_mask = (1 << 1);
+        want_B_md.extra.asymm_compensation_mask = compensation_mask;
     }
 
     if (B_any_layout) {
@@ -262,27 +277,29 @@ status_t brgemm_matmul_conf_utils_t::set_B_flags(memory_desc_t &B_md) const {
 
 format_tag_t brgemm_matmul_conf_utils_t::pick_blocked_B_layout(
         int n_blk) const {
-    if (bgmmc.ndims > 2) return format_tag::undef;
+
+    if (bgmmc.ndims > 3) return format_tag::undef;
     if (this->is_int8()) switch (n_blk) {
-            case 64: return BA16a64b4a;
-            case 48: return BA16a48b4a;
-            case 32: return BA16a32b4a;
-            case 16: return BA16a16b4a;
+            case 64: return bgmmc.ndims == 3 ? aCB16b64c4b : BA16a64b4a;
+            case 48: return bgmmc.ndims == 3 ? aCB16b48c4b : BA16a48b4a;
+            case 32: return bgmmc.ndims == 3 ? aCB16b32c4b : BA16a32b4a;
+            case 16: return bgmmc.ndims == 3 ? aCB16b16c4b : BA16a16b4a;
             default: return format_tag::undef;
         }
+
     if (this->is_bf16()) switch (n_blk) {
-            case 64: return BA16a64b2a;
-            case 48: return BA16a48b2a;
-            case 32: return BA16a32b2a;
-            case 16: return BA16a16b2a;
+            case 64: return bgmmc.ndims == 3 ? aCB16b64c2b : BA16a64b2a;
+            case 48: return bgmmc.ndims == 3 ? aCB16b48c2b : BA16a48b2a;
+            case 32: return bgmmc.ndims == 3 ? aCB16b32c2b : BA16a32b2a;
+            case 16: return bgmmc.ndims == 3 ? aCB16b16c2b : BA16a16b2a;
             default: return format_tag::undef;
         }
     // Note: bf32 assumes f32 blocking
     if (this->is_f32() || this->is_bf32()) switch (n_blk) {
-            case 64: return BA16a64b;
-            case 48: return BA16a48b;
-            case 32: return BA16a32b;
-            case 16: return BA16a16b;
+            case 64: return bgmmc.ndims == 3 ? aCB16b64c : BA16a64b;
+            case 48: return bgmmc.ndims == 3 ? aCB16b48c : BA16a48b;
+            case 32: return bgmmc.ndims == 3 ? aCB16b32c : BA16a32b;
+            case 16: return bgmmc.ndims == 3 ? aCB16b16c : BA16a16b;
             default: return format_tag::undef;
         }
     return format_tag::undef;
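The compensation_mask expression in set_B_flags generalizes the old hard-coded (1 << 1): bit i marks dimension i of B as kept by the compensation, and only K (dimension ndims - 2) is reduced over. A small self-contained check; the compensation_mask function below is a local copy of the expression for illustration, not the library API:

    #include <cassert>

    // Bit i set <=> dimension i of B is kept; K (dim ndims - 2) is cleared.
    constexpr int compensation_mask(int ndims) {
        return (1 << ndims) - 1 - (1 << (ndims - 2));
    }

    int main() {
        // 2D B is [K, N]: only the N bit remains -- the old (1 << 1) mask.
        assert(compensation_mask(2) == (1 << 1));
        // 3D B is [batch, K, N]: batch bit plus N bit.
        assert(compensation_mask(3) == ((1 << 0) | (1 << 2)));
        return 0;
    }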
diff --git a/tests/benchdnn/inputs/matmul/harness_matmul_data_tags b/tests/benchdnn/inputs/matmul/harness_matmul_data_tags
index e049576ea9f..e24c2e3497a 100644
--- a/tests/benchdnn/inputs/matmul/harness_matmul_data_tags
+++ b/tests/benchdnn/inputs/matmul/harness_matmul_data_tags
@@ -13,6 +13,7 @@
 --attr-fpmath=,bf16
 --wtag=BA16a64b,BA16a48b,BA16a32b,BA16a16b
 --batch=shapes_2d
+--attr-fpmath=
 
 --cfg=bf16bf16bf16
 --wtag=BA16a64b2a,BA16a48b2a,BA16a32b2a,BA16a16b2a
@@ -21,3 +22,18 @@
 --cfg=u8s8f32
 --wtag=BA16a64b4a,BA16a48b4a,BA16a32b4a,BA16a16b4a
 --batch=shapes_2d
+
+--stag=abc --dtag=abc
+--cfg=f32
+--attr-fpmath=,bf16
+--wtag=aCB16b16c,aCB16b32c,aCB16b48c,aCB16b64c
+--batch=shapes_3d
+--attr-fpmath=
+
+--cfg=bf16bf16bf16
+--wtag=aCB16b16c2b,aCB16b32c2b,aCB16b48c2b,aCB16b64c2b
+--batch=shapes_3d
+
+--cfg=u8s8f32
+--wtag=aCB16b16c4b,aCB16b32c4b,aCB16b48c4b,aCB16b64c4b
+--batch=shapes_3d
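The new 3D weight tags read like their 2D counterparts with a leading batch dimension: in aCB16b64c4b, a is batch, b is K, c is N, and the blocks nest as 16 of K, then 64 of N, then 4 of K (the int8 VNNI granule). As rough intuition for where element (batch, k, n) lands, here is a sketch assuming dense, fully padded outer strides; in the library the actual strides come from the memory descriptor, so this is illustrative only:

    #include <cstdint>

    // Offset of B(batch, k, n) in an aCB16b64c4b buffer, assuming K_padded
    // and N_padded are multiples of 64 and outer blocks are stored densely.
    int64_t offset_aCB16b64c4b(int64_t batch, int64_t k, int64_t n,
            int64_t K_padded, int64_t N_padded) {
        const int64_t blk = 16 * 64 * 4; // elements in one 16b64c4b block
        const int64_t stride_B = blk;                        // next K block
        const int64_t stride_C = (K_padded / 64) * blk;      // next N block
        const int64_t stride_a = (N_padded / 64) * stride_C; // next batch
        return batch * stride_a + (n / 64) * stride_C + (k / 64) * stride_B
                + ((k % 64) / 4) * (64 * 4) // 16b: K strides over a 64c4b tile
                + (n % 64) * 4              // 64c: N strides over the 4b granule
                + k % 4;                    // 4b: innermost K granule
    }

The 2b bf16 variants use a granule of 2 instead of 4, and the plain f32 tags (aCB16b64c etc.) drop the trailing granule entirely.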