From b2e948f931367c81a6887d4e0e544a9f50dcd673 Mon Sep 17 00:00:00 2001 From: Andrey Kalinin Date: Fri, 10 Feb 2023 17:51:54 -0800 Subject: [PATCH] cpu: x64: brgemm bwd_w convolution: fix batchsizes indexes --- src/cpu/x64/jit_brgemm_conv_bwd_w.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp b/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp index 0c6237c5658..fd3602f9442 100644 --- a/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp +++ b/src/cpu/x64/jit_brgemm_conv_bwd_w.cpp @@ -98,7 +98,8 @@ status_t brgemm_convolution_bwd_weights_t::pd_t::init(engine_t *engine) { auto M = (i) ? jcp_.M_tail : jcp_.M; if (M <= 0) continue; // init only needed brgemm descriptors - for (int bs = 0; bs <= jcp_.max_batch; bs++) { + const auto bs_end = jcp_.var_bs ? 1 : jcp_.max_batch; + for (int bs = 0; bs <= bs_end; bs++) { if (batchsizes[bs] == -1) continue; for_(int i_init = init_begin; i_init < init_end; i_init++) for_(int i_N = N_begin; i_N < N_end; i_N++) @@ -274,7 +275,8 @@ status_t brgemm_convolution_bwd_weights_t::init(engine_t *engine) { int init_begin = 0; int init_end = 2; - for (int bs = 0; bs <= jcp.max_batch; bs++) { + const auto bs_end = jcp.var_bs ? 1 : jcp.max_batch; + for (int bs = 0; bs <= bs_end; bs++) { if (_pd->batchsizes[bs] == -1) continue; for_(int i_N = N_begin; i_N < N_end; i_N++)