diff --git a/python/tvm/relay/op/_tensor_grad.py b/python/tvm/relay/op/_tensor_grad.py
index 3a82e46e6a7d..1c94162d87d9 100644
--- a/python/tvm/relay/op/_tensor_grad.py
+++ b/python/tvm/relay/op/_tensor_grad.py
@@ -48,6 +48,9 @@
     tile,
     transpose,
     where,
+    repeat,
+    expand_dims,
+    full_like
 )
 
 
@@ -198,6 +201,7 @@ def clip_grad(orig, grad):
 
 @register_gradient("nn.max_pool2d")
 def max_pool2d_grad(orig, grad):
+    """Returns the gradient of max_pool2d."""
     attrs = orig.attrs
     pool_grad = _nn.max_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                     strides=attrs.strides, padding=attrs.padding,
@@ -207,6 +211,7 @@ def max_pool2d_grad(orig, grad):
 
 @register_gradient("nn.avg_pool2d")
 def avg_pool2d_grad(orig, grad):
+    """Returns the gradient of avg_pool2d."""
     attrs = orig.attrs
     pool_grad = _nn.avg_pool2d_grad(grad, orig.args[0], pool_size=attrs.pool_size,
                                     strides=attrs.strides, padding=attrs.padding,
@@ -215,6 +220,26 @@ def avg_pool2d_grad(orig, grad):
     return [pool_grad]
 
 
+@register_gradient("nn.global_avg_pool2d")
+def global_avg_pool2d_grad(orig, grad):
+    """Returns the gradient of global_avg_pool2d."""
+    data = orig.args[0]
+    shape = data.checked_type.shape
+    layout = orig.attrs.layout
+
+    # we assume NCHW or NHWC layout for now, but easy to add more
+    assert layout in ["NCHW", "NHWC"]
+    if layout == "NCHW":
+        pool_size = shape[2], shape[3]
+    elif layout == "NHWC":
+        pool_size = shape[1], shape[2]
+
+    pool_grad = _nn.avg_pool2d_grad(grad, data, pool_size=pool_size,
+                                    strides=(1, 1), padding=(0, 0),
+                                    layout=layout)
+    return [pool_grad]
+
+
 # not implemented, this is only for testing.
 @register_gradient("concatenate")
 def concatenate_grad(orig, grad):
@@ -287,16 +312,53 @@ def conv2d_grad(orig, grad):
     return [backward_data, backward_weight]
 
 
+def _get_reduce_axis(call):
+    """Helper function that returns the reduce axis of the call as plain python ints."""
+    x, axis = call.args[0], call.attrs.axis
+    shape = x.checked_type.concrete_shape
+
+    # should never exclude when axis is None
+    assert not (axis is None and call.attrs.exclude)
+
+    if axis is None:
+        return None
+
+    # convert to nonnegative integers and sort
+    axis = sorted([ax if ax >= 0 else len(shape) + ax for ax in map(int, axis)])
+    if call.attrs.exclude:
+        axis = [ax for ax in range(len(shape)) if ax not in axis]
+    return axis
+
+
+def _unreduce_expand(x, axis):
+    """Helper function that returns x expanded on the reduced dimensions in axis."""
+    # assume axis is sorted nonnegative ints
+    for ax in axis:
+        x = expand_dims(x, ax)
+    return x
+
+
 @register_gradient("max")
 def max_grad(orig, grad):
     """Returns the gradient of max"""
-    # Only support axis=0, since broadcasting orig to x behaves incorrectly
-    x, axis = orig.args[0], orig.attrs.axis
-    assert(axis is not None and len(axis) == 1 and int(axis[0]) == 0)
-    orig = broadcast_to_like(orig, x)
-    grad = broadcast_to_like(grad, x)
-    indicators = cast_like(equal(orig, x), grad)
-    return [indicators * grad]
+    x, axis = orig.args[0], _get_reduce_axis(orig)
+    shape = x.checked_type.concrete_shape
+
+    repeated = orig
+    if axis is None:
+        repeated = full_like(x, repeated)
+    else:
+        # expand dims (if necessary) and repeat along each axis
+        if not orig.attrs.keepdims:
+            repeated = _unreduce_expand(repeated, axis)
+            grad = _unreduce_expand(grad, axis)
+        for ax in axis:
+            repeated = repeat(repeated, shape[ax], ax)
+
+    indicators = cast_like(equal(repeated, x), grad)
+    num_selected = _sum(indicators, axis, keepdims=True)
+    # spread error across all max weights
+    return [indicators * grad / num_selected]
 
 
 @register_gradient("nn.softmax")
@@ -372,7 +434,11 @@ def negative_grad(orig, grad):
 @register_gradient("sum")
 def sum_grad(orig, grad):
     """Returns grad broadcasted to data dims"""
-    data = orig.args[0]
+    data, axis = orig.args[0], _get_reduce_axis(orig)
+    if not orig.attrs.keepdims:
+        if axis is None:
+            axis = list(range(len(data.checked_type.concrete_shape)))
+        grad = _unreduce_expand(grad, axis)
     return [broadcast_to_like(grad, data)]
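For reviewers, the rule that the new `max_grad` implements can be sanity-checked outside Relay. The sketch below is a NumPy reference for a single integer axis; it is not part of the patch and `numpy_max_grad` is an illustrative name. The incoming gradient is routed only to positions that attain the maximum and is split evenly when several positions tie, which is exactly what `indicators * grad / num_selected` computes.

```python
import numpy as np

def numpy_max_grad(x, grad, axis):
    """NumPy reference for the gradient of max over a single axis (illustrative only)."""
    reduced = np.max(x, axis=axis, keepdims=True)           # plays the role of `repeated`
    indicators = (x == reduced).astype(grad.dtype)          # 1 where the max is attained
    num_selected = indicators.sum(axis=axis, keepdims=True)
    # spread the incoming gradient evenly across tied maxima
    return indicators * np.expand_dims(grad, axis) / num_selected

x = np.array([[1.0, 3.0, 3.0],
              [2.0, 2.0, 0.0]])
print(numpy_max_grad(x, np.ones(2), axis=1))
# [[0.  0.5 0.5]
#  [0.5 0.5 0. ]]
```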
diff --git a/tests/python/relay/test_op_grad_level2.py b/tests/python/relay/test_op_grad_level2.py
index 8e809250d1de..57b1e2c676ac 100644
--- a/tests/python/relay/test_op_grad_level2.py
+++ b/tests/python/relay/test_op_grad_level2.py
@@ -48,8 +48,7 @@ def verify_max_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode):
 
 
 def test_max_pool2d_grad():
-    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
-                           ceil_mode=False)
+    verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0), ceil_mode=False)
     verify_max_pool2d_grad((1, 4, 16, 16), pool_size=(1, 1), strides=(1, 1), padding=(1, 1),
                            ceil_mode=False)
 
@@ -75,7 +74,6 @@ def verify_avg_pool2d_grad(x_shape, pool_size, strides, padding, ceil_mode, coun
             op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
             np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
 
-
 def test_avg_pool2d_grad():
     verify_avg_pool2d_grad((1, 4, 16, 16), pool_size=(2, 2), strides=(2, 2), padding=(0, 0),
                            ceil_mode=False, count_include_pad=True)
@@ -83,6 +81,30 @@ def test_avg_pool2d_grad():
                            ceil_mode=False, count_include_pad=False)
 
 
+def verify_global_avg_pool2d_grad(x_shape):
+    x = relay.var("x", relay.TensorType(x_shape, "float32"))
+    y = tvm.relay.nn.global_avg_pool2d(x)
+
+    fwd_func = relay.Function([x], y)
+    fwd_func = run_infer_type(fwd_func)
+    bwd_func = run_infer_type(gradient(fwd_func))
+
+    data = np.random.rand(*x_shape).astype("float32")
+    y_shape = topi.util.get_const_tuple(fwd_func.ret_type.shape)
+    out_grad = np.ones(shape=y_shape)
+    ref_grad = topi.testing.pool_grad_nchw(data, out_grad, pool_size=(x_shape[2], x_shape[3]),
+                                           strides=(1, 1), padding=[0, 0, 0, 0], pool_type='avg',
+                                           ceil_mode=False)
+
+    for target, ctx in ctx_list():
+        intrp = relay.create_executor(ctx=ctx, target=target)
+        op_res, (op_grad, ) = intrp.evaluate(bwd_func)(data)
+        np.testing.assert_allclose(op_grad.asnumpy(), ref_grad, rtol=0.01)
+
+def test_global_avg_pool2d_grad():
+    verify_global_avg_pool2d_grad((1, 4, 16, 16))
+    verify_global_avg_pool2d_grad((1, 8, 8, 24))
+
 def verify_conv2d_grad(dshape, wshape, strides, padding, dilation, groups=1, mode='higher_order'):
     try:
         import torch
@@ -155,6 +177,7 @@ def test_batch_flatten_grad():
 if __name__ == "__main__":
     test_max_pool2d_grad()
     test_avg_pool2d_grad()
+    test_global_avg_pool2d_grad()
     test_conv2d_grad()
     test_dense_grad()
     test_batch_flatten_grad()
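A quick way to see why `global_avg_pool2d_grad` can delegate to `avg_pool2d_grad` with `pool_size` equal to the full spatial extent: for an NCHW input, every element of a channel contributes with weight 1/(H*W) to that channel's single pooled output, so the input gradient is the (broadcast) output gradient scaled by 1/(H*W). Below is a minimal NumPy sketch with an illustrative helper name, assuming NCHW layout and the same all-ones output gradient the test above uses.

```python
import numpy as np

def numpy_global_avg_pool2d_grad(out_grad, x_shape):
    """NumPy reference for the NCHW global-average-pool gradient (illustrative only)."""
    _, _, h, w = x_shape
    # out_grad has shape (N, C, 1, 1); broadcast it over the window and scale by 1/(H*W)
    return np.broadcast_to(out_grad, x_shape) / (h * w)

x_shape = (1, 4, 16, 16)
ref = numpy_global_avg_pool2d_grad(np.ones((1, 4, 1, 1)), x_shape)
assert np.allclose(ref, 1.0 / (16 * 16))
```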
diff --git a/tests/python/relay/test_op_grad_level4.py b/tests/python/relay/test_op_grad_level4.py
index f8d6c3a56c93..f690a186ea41 100644
--- a/tests/python/relay/test_op_grad_level4.py
+++ b/tests/python/relay/test_op_grad_level4.py
@@ -29,18 +29,21 @@ def test_sum_grad():
     verify_sum_grad((4, 2))
     verify_sum_grad((4, 2), axis=-1, keepdims=True)
     verify_sum_grad((4, 2, 1), axis=(1, 2), exclude=True)
+    verify_sum_grad((4, 2, 1), axis=1)
 
 
-def test_max_grad():
-    s = (10, 10)
-    t = relay.TensorType(s)
-    x = relay.var("x", t)
-    axis = 0
-    z = relay.max(x, axis)
-
-    fwd_func = relay.Function([x], z)
+def verify_max_grad(d_shape, axis=None, keepdims=False, exclude=False):
+    data = relay.var("data", relay.TensorType(d_shape, "float32"))
+    fwd_func = relay.Function([data], relay.max(data, axis=axis, keepdims=keepdims, exclude=exclude))
     check_grad(fwd_func, scale=1e-3)
 
 
+def test_max_grad():
+    verify_max_grad((10, 10), axis=None)
+    verify_max_grad((10, 10), axis=-1)
+    verify_max_grad((6, 3, 2), axis=(1, 2), keepdims=True)
+    verify_max_grad((5, 4, 3), axis=(0, 2), exclude=True)
+
+
 if __name__ == "__main__":
     pytest.main()
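The `sum_grad` change follows the same pattern as `max_grad`: when `keepdims=False`, the reduced axes are first re-inserted into the incoming gradient (the job of `_unreduce_expand`) and the result is then broadcast back to the input shape. A hedged NumPy sketch, with a made-up helper name, of the behavior the new `verify_sum_grad` cases exercise:

```python
import numpy as np

def numpy_sum_grad(grad, x_shape, axis=None, keepdims=False):
    """NumPy reference for the gradient of sum: broadcast grad back to x_shape (illustrative only)."""
    if axis is None:
        axis = range(len(x_shape))
    elif isinstance(axis, int):
        axis = (axis,)
    if not keepdims:
        # re-insert the reduced dimensions, mirroring _unreduce_expand
        for ax in sorted(a % len(x_shape) for a in axis):
            grad = np.expand_dims(grad, ax)
    return np.broadcast_to(grad, x_shape)

g = numpy_sum_grad(np.ones((4,)), x_shape=(4, 2), axis=1)
assert g.shape == (4, 2) and np.all(g == 1.0)
```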