From bca77ba0ede93611e5776f1344f991b8ca353e74 Mon Sep 17 00:00:00 2001
From: Pedro Larroy
Date: Tue, 23 Apr 2019 14:15:06 -0700
Subject: [PATCH] Add backward to fully connected. (_backward_FullyConnected)

---
 src/operator/nn/fully_connected.cc  | 15 +++++++++++++++
 tests/python/unittest/test_gluon.py | 20 +++++++++++++++++++-
 2 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc
index a097357ef5a3..111c18ef0aff 100644
--- a/src/operator/nn/fully_connected.cc
+++ b/src/operator/nn/fully_connected.cc
@@ -176,6 +176,20 @@ struct FullyConnectedGrad {
   }
 };
 
+std::vector<nnvm::NodeEntry> FullyConnectedBackwardGrad(
+    const nnvm::NodePtr& n,
+    const std::vector<nnvm::NodeEntry>& ograds) {
+  std::vector<nnvm::NodeEntry> ret;
+  size_t i = 0;
+  for (const auto& x : n->inputs) {
+    std::ostringstream os;
+    os << n->attrs.name << "_backward_" << i;
+    ret.emplace_back(nnvm::NodeEntry{MakeNode("zeros_like", os.str(), {x}, nullptr, &n), 0, 0});
+    ++i;
+  }
+  return ret;
+}
+
 inline static bool FCStorageType(const nnvm::NodeAttrs& attrs,
                                  const int dev_mask,
                                  DispatchMode* dispatch_mode,
@@ -325,6 +339,7 @@ NNVM_REGISTER_OP(_backward_FullyConnected)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
   return std::vector<std::pair<int, int> >{{1, 0}};
 })
+.set_attr<nnvm::FGradient>("FGradient", FullyConnectedBackwardGrad)
 .set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
 .set_attr_parser(ParamParser<FullyConnectedParam>)
 #if MXNET_USE_MKLDNN == 1
diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py
index efa04f4fa47a..755ca3a4add1 100644
--- a/tests/python/unittest/test_gluon.py
+++ b/tests/python/unittest/test_gluon.py
@@ -21,7 +21,7 @@
 import mxnet as mx
 from mxnet import gluon
 from mxnet.gluon import nn
-from mxnet.test_utils import assert_almost_equal
+from mxnet.test_utils import assert_almost_equal, same
 from mxnet.ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID
 from common import (setup_module, with_seed, assertRaises, teardown,
                     assert_raises_cudnn_not_satisfied)
@@ -915,6 +915,24 @@ def test_sequential_warning():
         assert len(w) == 1
 
 
+@with_seed()
+def test_dense_backward():
+    import mxnet.autograd as ag
+    import mxnet.ndarray as nd
+    x = nd.array([[1,2,3,400]])
+    net = gluon.nn.Sequential()
+    with net.name_scope():
+        net.add(gluon.nn.Dense(1, in_units=x.shape[1]))
+    net.initialize(mx.initializer.Constant(.5))
+    params = [p.data() for p in net.collect_params().values()]
+    x.attach_grad()
+    with ag.record():
+        y = net.forward(x)
+        y_grad = ag.grad(y, x, create_graph=True, retain_graph=True)[0]
+    y_grad.backward()
+    same(x.grad, nd.zeros(4))
+
+
 @with_seed()
 def test_global_norm_clip():
     stypes = ['default', 'row_sparse']