Skip to content

Commit

Permalink
[Relay][Training] Add and fix gradients (#4126)
Browse files Browse the repository at this point in the history
* add and fix gradients

* fix linter issues
  • Loading branch information
altanh authored and vinx13 committed Oct 16, 2019
1 parent 1c0e743 commit 46fa6ee
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 19 deletions.
82 changes: 74 additions & 8 deletions python/tvm/relay/op/_tensor_grad.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@
tile,
transpose,
where,
repeat,
expand_dims,
full_like
)


Expand Down Expand Up @@ -198,6 +201,7 @@ def clip_grad(orig, grad):

@register_gradient("nn.max_pool2d")
def max_pool2d_grad(orig, grad):
"""Returns the gradient of max_pool2d."""
attrs = orig.attrs
pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
strides=attrs.strides, padding=attrs.padding,
Expand All @@ -207,6 +211,7 @@ def max_pool2d_grad(orig, grad):

@register_gradient("nn.avg_pool2d")
def avg_pool2d_grad(orig, grad):
"""Returns the gradient of avg_pool2d."""
attrs = orig.attrs
pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
strides=attrs.strides, padding=attrs.padding,
Expand All @@ -215,6 +220,26 @@ def avg_pool2d_grad(orig, grad):
return [pool_grad]


@register_gradient("nn.global_avg_pool2d")
def global_avg_pool2d_grad(orig, grad):
"""Returns the gradient of global_avg_pool2d."""
data = orig.args[0]
shape = data.checked_type.shape
layout = orig.attrs.layout

# we assume NCHW or NHWC layout for now, but easy to add more
assert layout in ["NCHW", "NHWC"]
if layout == "NCHW":
pool_size = shape[2], shape[3]
elif layout == "NHWC":
pool_size = shape[1], shape[2]

pool_grad = _nn.avg_pool2d_grad(grad, data, pool_size=pool_size,
strides=(1, 1), padding=(0, 0),
layout=layout)
return [pool_grad]


# not implemented, this is only for testing.
@register_gradient("concatenate")
def concatenate_grad(orig, grad):
Expand Down Expand Up @@ -287,16 +312,53 @@ def conv2d_grad(orig, grad):
return [backward_data, backward_weight]


def _get_reduce_axis(call):
"""Helper function that returns the reduce axis of the call as plain python ints."""
x, axis = call.args[0], call.attrs.axis
shape = x.checked_type.concrete_shape

# should never exclude when axis is None
assert not (axis is None and call.attrs.exclude)

if axis is None:
return None

# convert to nonnegative integers and sort
axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)])
if call.attrs.exclude:
axis = [ax for ax in range(len(shape)) if ax not in axis]
return axis


def _unreduce_expand(x, axis):
"""Helper function that returns x expanded on the reduced dimensions in axis."""
# assume axis is sorted nonnegative ints
for ax in axis:
x = expand_dims(x, ax)
return x


@register_gradient("max")
def max_grad(orig, grad):
"""Returns the gradient of max"""
# Only support axis=0, since broadcasting orig to x behaves incorrectly
x, axis = orig.args[0], orig.attrs.axis
assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0)
orig = broadcast_to_like(orig, x)
grad = broadcast_to_like(grad, x)
indicators = cast_like(equal(orig, x), grad)
return [indicators * grad]
x, axis = orig.args[0], _get_reduce_axis(orig)
shape = x.checked_type.concrete_shape

repeated = orig
if axis is None:
repeated = full_like(x, repeated)
else:
# expand dims (if necessary) and repeat along each axis
if not orig.attrs.keepdims:
repeated = _unreduce_expand(repeated, axis)
grad = _unreduce_expand(grad, axis)
for ax in axis:
repeated = repeat(repeated, shape[ax], ax)

indicators = cast_like(equal(repeated, x), grad)
num_selected = _sum(indicators, axis, keepdims=True)
# spread error across all max weights
return [indicators * grad / num_selected]


@register_gradient("nn.softmax")
Expand Down Expand Up @@ -372,7 +434,11 @@ def negative_grad(orig, grad):
@register_gradient("sum")
def sum_grad(orig, grad):
"""Returns grad broadcasted to data dims"""
data = orig.args[0]
data, axis = orig.args[0], _get_reduce_axis(orig)
if not orig.attrs.keepdims:
if axis is None:
axis = list(range(len(data.checked_type.concrete_shape)))
grad = _unreduce_expand(grad, axis)
return [broadcast_to_like(grad, data)]


Expand Down
29 changes: 26 additions & 3 deletions tests/python/relay/test_op_grad_level2.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):


def test_max_pool2d_grad():
verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
ceil_mode=False)
verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False)
verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1), ceil_mode=False)


Expand All @@ -75,14 +74,37 @@ def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, coun
op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)


def test_avg_pool2d_grad():
verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
ceil_mode=False, count_include_pad=True)
verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
ceil_mode=False, count_include_pad=False)


def verify_global_avg_pool2d_grad(x_shape):
x = relay.var("x", relay.TensorType(x_shape, "float32"))
y = tvm.relay.nn.global_avg_pool2d(x)

fwd_func = relay.Function([x], y)
fwd_func = run_infer_type(fwd_func)
bwd_func = run_infer_type(gradient(fwd_func))

data = np.random.rand(*x_shape).astype("float32")
y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
out_grad = np.ones(shape=y_shape)
ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=(x_shape[2], x_shape[3]),
strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg',
ceil_mode=False)

for target, ctx in ctx_list():
intrp = relay.create_executor(ctx=ctx, target=target)
op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)

def test_global_avg_pool2d_grad():
verify_global_avg_pool2d_grad((1, 4, 16, 16))
verify_global_avg_pool2d_grad((1, 8, 8, 24))

def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
try:
import torch
Expand Down Expand Up @@ -155,6 +177,7 @@ def test_batch_flatten_grad():
if __name__ == "__main__":
test_max_pool2d_grad()
test_avg_pool2d_grad()
test_global_avg_pool2d_grad()
test_conv2d_grad()
test_dense_grad()
test_batch_flatten_grad()
19 changes: 11 additions & 8 deletions tests/python/relay/test_op_grad_level4.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,21 @@ def test_sum_grad():
verify_sum_grad((4, 2))
verify_sum_grad((4, 2), axis=-1, keepdims=True)
verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
verify_sum_grad((4, 2, 1), axis=1)


def test_max_grad():
s = (10, 10)
t = relay.TensorType(s)
x = relay.var("x", t)
axis = 0
z = relay.max(x, axis)

fwd_func = relay.Function([x], z)
def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):
data = relay.var("data", relay.TensorType(d_shape, "float32"))
fwd_func = relay.Function([data], relay.max(data, axis=axis, keepdims=keepdims, exclude=exclude))
check_grad(fwd_func, scale=1e-3)


def test_max_grad():
verify_max_grad((10, 10), axis=None)
verify_max_grad((10, 10), axis=-1)
verify_max_grad((6, 3, 2), axis=(1, 2), keepdims=True)
verify_max_grad((5, 4, 3), axis=(0, 2), exclude=True)


if __name__ == "__main__":
pytest.main()

0 comments on commit 46fa6ee

Please sign in to comment.