diff --git a/src/cpu/jit_uni_pool_kernel.cpp b/src/cpu/jit_uni_pool_kernel.cpp index 4d5be2f9bf5..5f65c39cbc4 100644 --- a/src/cpu/jit_uni_pool_kernel.cpp +++ b/src/cpu/jit_uni_pool_kernel.cpp @@ -641,7 +641,7 @@ void jit_uni_pool_kernel::generate() { vmovups(vmm_idx(), ptr[tmp_gpr]); } - if (jpp.is_backward) + if (jpp.is_backward && jpp.simple_alg) maybe_zero_diff_src(); if (jpp.alg == pooling_max && (jpp.is_training || jpp.is_backward)) { diff --git a/src/cpu/jit_uni_pooling.cpp b/src/cpu/jit_uni_pooling.cpp index 691837dbead..86499fa7820 100644 --- a/src/cpu/jit_uni_pooling.cpp +++ b/src/cpu/jit_uni_pooling.cpp @@ -225,6 +225,19 @@ void jit_uni_pooling_bwd_t::execute_backward_3d() const { if (jpp.simple_alg) { + int back_pad = (jpp.od - 1) * jpp.stride_d + jpp.kd + - jpp.f_pad - jpp.id; + // zero-out untouched portions of diff_src (when back_pad is negative) + if (back_pad < 0) + parallel_nd(jpp.mb, jpp.nb_c, -back_pad, jpp.ih, jpp.iw, + [&](int n, int b_c, int id_e, int ih, int iw) { + int id_s = jpp.id + back_pad; + auto ds = &diff_src[diff_src_d.blk_off(n, b_c, + id_s + id_e, ih, iw)]; + for (int i = 0; i < jpp.c_block; ++i) + ds[i] = (data_t)0.f; + }); + parallel_nd(jpp.mb, jpp.nb_c, jpp.od, [&](int n, int b_c, int od) { const int ik = od * jpp.stride_d;