From 017950a16168640764d17558e41010d0ae038377 Mon Sep 17 00:00:00 2001 From: Andrey Kalinin Date: Thu, 13 Apr 2023 15:13:29 -0700 Subject: [PATCH] x64: brgemm bwd_w conv: use ih_block instead of ih for tr_src scratchpad --- src/cpu/x64/jit_brgemm_conv_bwd_w.cpp | 8 ++++---- src/cpu/x64/jit_brgemm_conv_utils.cpp | 23 +++++++++++++---------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp b/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp index cc2ae19ff11..666eed571ba 100644 --- a/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp +++ b/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp @@ -136,7 +136,7 @@ status_t brgemm_convolution_bwd_weights_t::pd_t::init(engine_t *engine) { brgattr.max_top_vpad = 0; brgattr.max_bottom_vpad = 0; - brgattr.LDA2 = jcp_.tr_iw * jcp_.ih * jcp_.id; + brgattr.LDA2 = jcp_.tr_iw * jcp_.ih_block * jcp_.id; brgattr.LDB2 = jcp_.tr_ow * jcp_.oc_block * jcp_.oh * jcp_.od; brgattr.LDC2_M = jcp_.oc_block * jcp_.kd * jcp_.kh * jcp_.kw; brgattr.LDC2_N = jcp_.nb_ic * jcp_.ic_block * jcp_.oc_block @@ -465,7 +465,7 @@ struct brgemm_convolution_bwd_weights_t::thread_info_t { size_t tr_src_off(int g, int icb, int id, int ih) const { const size_t tr_row_size = jcp.tr_iw * jcp.ic_block; - const size_t tr_3d_size = tr_row_size * jcp.ih; + const size_t tr_3d_size = tr_row_size * jcp.ih_block; int adj = (jcp.global_transpose) ? 1 : jcp.nb_ic_blocking; // Aligned to buffer end to use guard elements return tr_src_buf_number(g, icb) * adj * jcp.tr_src_buf_size @@ -1023,7 +1023,7 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d( + _pd->filter_w_to_src(kw) / jcp.stride_w + (kw % jcp.stride_w) * src_stride_w_shift + (bs_ih_s - ih_s) * jcp.tr_iw * jcp.ic_block - + (bs_id_s - id_s) * jcp.ih * jcp.tr_iw * jcp.ic_block; + + (bs_id_s - id_s) * jcp.ih_block * jcp.tr_iw * jcp.ic_block; const void *ptr_B = ((diff_dst_data_t *)p_dst) + (bs_oh_s - oh_s) * jcp.tr_ow * jcp.oc_block + (bs_od_s - od_s) * jcp.oh * jcp.tr_ow * jcp.oc_block; @@ -1045,7 +1045,7 @@ void brgemm_convolution_bwd_weights_t::compute_diff_weights_3d( ti->brg_batch[odb * bs_h + ohb].ptr.A = (char *)ptr_A + ohb * jcp.typesize_in * jcp.tr_iw * jcp.ic_block * jcp.stride_h - + odb * jcp.typesize_in * jcp.ih * jcp.tr_iw + + odb * jcp.typesize_in * jcp.ih_block * jcp.tr_iw * jcp.ic_block * jcp.stride_d; ti->brg_batch[odb * bs_h + ohb].ptr.B = (char *)ptr_B + ohb * jcp.typesize_in * jcp.tr_ow * jcp.oc_block diff --git a/src/cpu/x64/jit_brgemm_conv_utils.cpp b/src/cpu/x64/jit_brgemm_conv_utils.cpp index 9fc86b07d9b..fff992e1740 100644 --- a/src/cpu/x64/jit_brgemm_conv_utils.cpp +++ b/src/cpu/x64/jit_brgemm_conv_utils.cpp @@ -2518,16 +2518,6 @@ void balance_bwd_w(jit_brgemm_conv_conf_t &jcp) { jcp.nthr_g = nthr_g; jcp.nthr_oc_b = nthr_oc_b; jcp.nthr_ic_b = nthr_ic_b; - - // TODO: Optimize memory allocation when threaded on height and depth - jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih * jcp.id; - jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od; - jcp.tr_src_buf_count = jcp.global_transpose - ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups - : jcp.nthr; - jcp.tr_diff_dst_buf_count = jcp.global_transpose - ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups - : jcp.nthr; } status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp, @@ -2758,6 +2748,19 @@ status_t init_conf_bwd_w(jit_brgemm_conv_conf_t &jcp, // try to split oh by equal oh blocks oh_block_limit = div_up(jcp.oh, div_up(jcp.oh, oh_block_limit)); jcp.oh_block = utils::saturate(1, jcp.oh, oh_block_limit); + jcp.ih_block = nstl::min(jcp.ih, + jcp.stride_h + * brg_blocking_t::get_inp_size(jcp.ih, jcp.oh_block, jcp.kh, + jcp.stride_h, jcp.dilate_h)); + // TODO: Optimize memory allocation when threaded on height and depth + jcp.tr_src_buf_count = jcp.global_transpose + ? jcp.nthr_mb * jcp.nb_ic * jcp.ngroups + : jcp.nthr; + jcp.tr_diff_dst_buf_count = jcp.global_transpose + ? jcp.nthr_mb * jcp.nb_oc * jcp.ngroups + : jcp.nthr; + jcp.tr_src_buf_size = jcp.tr_iw * jcp.ic_block * jcp.ih_block * jcp.id; + jcp.tr_diff_dst_buf_size = jcp.tr_ow * jcp.oc_block * jcp.oh * jcp.od; const int iframe_size = irow_size * jcp.id; const int oframe_size = orow_size * jcp.od;