Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Check OpReqTpye in CommitOutput automatically
Browse files Browse the repository at this point in the history
  • Loading branch information
xziya committed Jul 11, 2019
1 parent 8b2cee4 commit 9ca0428
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
Original file line number Diff line number Diff line change
Expand Up @@ -662,21 +662,20 @@ void MKLDNNConvolutionBackward(const nnvm::NodeAttrs& attrs, const OpContext &ct
in_grad[conv::kWeight],
convBwdWeight.bwdWeights_pd.diff_weights_primitive_desc(),
req[conv::kWeight]);

if (!param.no_bias && req[conv::kBias]) {
if (param.no_bias) {
convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
*in_grad_weight.second);
MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
} else {
auto in_grad_bias = CreateMKLDNNMem(
in_grad[conv::kBias],
convBwdWeight.bwdWeights_pd.diff_bias_primitive_desc(), req[conv::kBias]);
convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
*in_grad_weight.second, *in_grad_bias.second);
MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
CommitOutput(in_grad[conv::kBias], in_grad_bias);
} else if (req[conv::kWeight]) {
convBwdWeight.SetWeightNewMem(*data_mem, *out_grad_mem,
*in_grad_weight.second);
MKLDNNStream::Get()->RegisterPrim(convBwdWeight.GetBwdWeights());
}
if (req[conv::kWeight]) CommitOutput(in_grad[conv::kWeight], in_grad_weight);
CommitOutput(in_grad[conv::kWeight], in_grad_weight);
}
MKLDNNStream::Get()->Submit();
}
Expand Down

0 comments on commit 9ca0428

Please sign in to comment.