Skip to content

Commit

Permalink
Merge pull request #3553 from reyoung/feature/unittest_for_mean_grad
Browse files Browse the repository at this point in the history
Add MeanOp's Gradient Test And Fix Mean Op Gradient
  • Loading branch information
gangliao authored Aug 17, 2017
2 parents c68bfc3 + 7f8c3f8 commit 62aedce
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 1 deletion.
3 changes: 2 additions & 1 deletion paddle/operators/mean_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@ class MeanGradKernel : public framework::OpKernel {
IG->mutable_data<T>(context.GetPlace());

T ig_size = (T)framework::product(IG->dims());
Eigen::DSizes<int, 1> bcast(ig_size);

EigenVector<T>::Flatten(*IG).device(context.GetEigenDevice<Place>()) =
EigenScalar<T>::From(*OG) / ig_size;
(EigenVector<T>::From(*OG) / ig_size).broadcast(bcast);
}
};

Expand Down
8 changes: 8 additions & 0 deletions python/paddle/v2/framework/tests/test_mean_op.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest
from op_test_util import OpTestMeta
from gradient_checker import GradientChecker, create_op
import numpy as np


Expand All @@ -12,5 +13,12 @@ def setUp(self):
self.outputs = {'Out': np.mean(self.inputs['X'])}


class MeanGradOpTest(GradientChecker):
def test_normal(self):
op = create_op("mean")
inputs = {"X": np.random.random((10, 10)).astype("float32")}
self.check_grad(op, inputs, set("X"), "Out")


if __name__ == '__main__':
unittest.main()

0 comments on commit 62aedce

Please sign in to comment.