Skip to content

Commit

Permalink
fix dropout mask output (apache#15697)
Browse files Browse the repository at this point in the history
  • Loading branch information
wuxun-zhang authored and test committed Aug 8, 2019
1 parent 8f6c815 commit 73a0400
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
17 changes: 11 additions & 6 deletions src/operator/nn/dropout-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,18 @@ class DropoutOp {
DType *dataptr = data.dptr_;
auto maskptr = reinterpret_cast<int *>(mask.dptr_);
int count = mask.shape_[0] * mask.shape_[1];
if (sizeof(DType) > sizeof(int)) {
// allocate a new buffer to avoid memory overlap between `mask.dptr_` and `maskptr`
Tensor<xpu, 1, int> temp = ctx.requested[1].get_space_typed<xpu, 1, int>(Shape1(count), s);
maskptr = temp.dptr_;
}
BernoulliGenerate(*pgen, count, this->pkeep_, maskptr);
const float pk_1 = 1.0f / this->pkeep_;
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
for (int i = 0; i < count; ++i) {
outptr[i] = dataptr[i] * maskptr[i] * pk_1;
const DType maskVal = static_cast<DType>(maskptr[i]) * pk_1;
outptr[i] = dataptr[i] * maskVal;
mask.dptr_[i] = maskVal;
}
}

Expand All @@ -149,12 +156,11 @@ class DropoutOp {
Tensor<xpu, 2, DType> gdata = in_grad[dropout::kData].FlatTo2D<xpu, DType>(s);
DType *ingradptr = gdata.dptr_;
const DType *outgradptr = grad.dptr_;
auto maskptr = reinterpret_cast<int *>(mask.dptr_);
int count = mask.shape_[0] * mask.shape_[1];
const float pk_1 = 1.0f / this->pkeep_;
const DType *maskptr = mask.dptr_;
const int count = mask.shape_[0] * mask.shape_[1];
#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
for (int i = 0; i < count; ++i) {
ingradptr[i] = outgradptr[i] * maskptr[i] * pk_1;
ingradptr[i] = outgradptr[i] * maskptr[i];
}
}

Expand Down Expand Up @@ -527,5 +533,4 @@ void DropoutGradCompute(const OpStatePtr& state,
} // namespace op
} // namespace mxnet

#undef MXNET_USE_MKL_DROPOUT
#endif // MXNET_OPERATOR_NN_DROPOUT_INL_H_
5 changes: 3 additions & 2 deletions src/operator/nn/dropout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
#include "../operator_common.h"
#include "mxnet/op_attr_types.h"



namespace mxnet {
namespace op {

Expand Down Expand Up @@ -163,6 +161,9 @@ Example::
#endif
}
request.emplace_back(ResourceRequest::kParallelRandom);
#if MXNET_USE_MKL_DROPOUT
request.emplace_back(ResourceRequest::kTempSpace);
#endif
return request;
})
.add_argument("data", "NDArray-or-Symbol", "Input array to which dropout will be applied.")
Expand Down

0 comments on commit 73a0400

Please sign in to comment.