Skip to content

Commit

Permalink
src: common: rnn: return invalid when sic!=dic for lbr_gru
Browse files Browse the repository at this point in the history
  • Loading branch information
mgouicem committed Aug 12, 2019
1 parent 4eb9f56 commit 1188010
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
5 changes: 3 additions & 2 deletions src/common/rnn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,8 +146,9 @@ status_t check_dim_consistency(mkldnn_alg_kind_t cell_kind,

// * algorithm specific
args_ok = true
&& IMPLICATION(cell_kind == alg_kind::vanilla_gru,
DIC == SIC);
&& IMPLICATION(utils::one_of(cell_kind, alg_kind::vanilla_gru,
alg_kind::lbr_gru),
DIC == SIC);
if (!args_ok) return invalid_arguments;
int extra_bias =
cell_kind == alg_kind::lbr_gru;
Expand Down
4 changes: 2 additions & 2 deletions tests/benchdnn/inputs/rnn/test_rnn_small
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
--direction=left2right
--activation=TANH
--prop=FWD_D,BWD_DW
--alg=VANILLA_LSTM,LBR_GRU --batch=rnn_small
--alg=VANILLA_GRU --batch=rnn_gru_small
--alg=VANILLA_LSTM --batch=rnn_small
--alg=VANILLA_GRU,LBR_GRU --batch=rnn_gru_small

# LSTM int8
--alg=VANILLA_LSTM
Expand Down

0 comments on commit 1188010

Please sign in to comment.