
[MXNET-408] inplace ReLU activation #10847

Merged: 7 commits merged into apache:master on May 8, 2018

Conversation

@eric-haibin-lin (Member) commented May 8, 2018

Description

For y = relu(x), this PR computes dx from (dy, y) instead of (dy, x), which allows y = relu(x) to be computed in place (i.e. y and x share the same memory).
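
The identity that makes the in-place rewrite safe: since y = max(x, 0), the mask (x > 0) equals (y > 0), so the backward pass can be computed from y alone. A minimal NumPy sketch of the identity (illustrative only, not the actual kernel code):

import numpy as np

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
dy = np.array([1.0, 2.0, 3.0, 4.0, 5.0])    # upstream gradient

y = np.maximum(x, 0.0)                      # forward: y = relu(x)
dx_from_x = dy * (x > 0)                    # conventional backward: needs x
dx_from_y = dy * (y > 0)                    # this PR's formulation: needs only y
assert np.allclose(dx_from_x, dx_from_y)    # identical, so x's buffer can be reused for y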

Checklist

Essentials

Please feel free to remove inapplicable items for your PR.

  • The PR title starts with [MXNET-$JIRA_ID], where $JIRA_ID refers to the relevant JIRA issue created (except PRs with tiny changes)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage:
  • Unit tests are added for small changes to verify correctness (e.g. adding a new operator)
  • Nightly tests are added for complicated/long-running ones (e.g. changing distributed kvstore)
  • Build tests will be added for build configuration changes (e.g. adding a new build option with NCCL)
  • Code is well-documented:
  • For user-facing API changes, API doc string has been updated.
  • For new C++ functions in header files, their functionalities and arguments are documented.
  • For new examples, README.md is added to explain what the example does, the source of the dataset, expected performance on the test set, and a reference to the original paper if applicable
  • Check the API doc at http://mxnet-ci-doc.s3-accelerate.dualstack.amazonaws.com/PR-$PR_ID/$BUILD_ID/index.html
  • To the best of my knowledge, examples are either not affected by this change or have been fixed to be compatible with it

Changes

  • Feature1, tests, (and when applicable, API doc)
  • Feature2, tests, (and when applicable, API doc)

Comments

  • If this change is a backward incompatible change, why must this change be made.
  • Interesting edge cases to note here
import mxnet as mx
import numpy as np
import argparse, os

parser = argparse.ArgumentParser(description="Memory benchmark",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--batch-size', type=int, default=128,
                    help='batch size')
parser.add_argument('--num-classes', type=int, default=2000,
                    help='number of classes')
parser.add_argument('--num-layers', type=int, default=11,
                    help='number of layers')
class SyntheticDataIter(mx.io.DataIter):
    def __init__(self, num_classes, data_shape, max_iter, dtype):
        self.batch_size = data_shape[0]
        self.cur_iter = 0
        self.max_iter = max_iter
        self.dtype = dtype
        label = np.random.randint(0, num_classes, [self.batch_size,])
        data = np.random.uniform(-1, 1, data_shape)
        self.data = mx.nd.array(data, dtype=self.dtype, ctx=mx.Context('cpu_pinned', 0))
        self.label = mx.nd.array(label, dtype=self.dtype, ctx=mx.Context('cpu_pinned', 0))

    def __iter__(self):
        return self

    @property
    def provide_data(self):
        return [mx.io.DataDesc('data', self.data.shape, self.dtype)]

    @property
    def provide_label(self):
        return [mx.io.DataDesc('softmax_label', (self.batch_size,), self.dtype)]

    def next(self):
        self.cur_iter += 1
        if self.cur_iter <= self.max_iter:
            return mx.io.DataBatch(data=(self.data,),
                                   label=(self.label,),
                                   pad=0,
                                   index=None,
                                   provide_data=self.provide_data,
                                   provide_label=self.provide_label)
        else:
            raise StopIteration

    def __next__(self):
        return self.next()

    def reset(self):
        self.cur_iter = 0
def get_feature(internal_layer, layers, filters, batch_norm=False, **kwargs):
    cudnn_off = False
    workspace = 1024
    cudnn_tune = None
    for i, num in enumerate(layers):
        for j in range(num):
            internal_layer = mx.sym.Convolution(data=internal_layer, kernel=(3, 3), pad=(1, 1),
                                                num_filter=filters[i], name="conv%s_%s" % (i + 1, j + 1),
                                                cudnn_off=cudnn_off, workspace=workspace,
                                                cudnn_tune=cudnn_tune)
            if batch_norm:
                internal_layer = mx.sym.BatchNorm(data=internal_layer, name="bn%s_%s" % (i + 1, j + 1))
            internal_layer = mx.sym.Activation(data=internal_layer, act_type="relu",
                                               name="relu%s_%s" % (i + 1, j + 1))
        internal_layer = mx.sym.Pooling(data=internal_layer, pool_type="max",
                                        kernel=(2, 2), stride=(2, 2), name="pool%s" % (i + 1))
    return internal_layer

def get_classifier(input_data, num_classes, **kwargs):
    flatten = mx.sym.Flatten(data=input_data, name="flatten")
    fc6 = mx.sym.FullyConnected(data=flatten, num_hidden=4096, name="fc6")
    relu6 = mx.sym.Activation(data=fc6, act_type="relu", name="relu6")
    drop6 = mx.sym.Dropout(data=relu6, p=0.5, name="drop6")
    fc7 = mx.sym.FullyConnected(data=drop6, num_hidden=4096, name="fc7")
    relu7 = mx.sym.Activation(data=fc7, act_type="relu", name="relu7")
    drop7 = mx.sym.Dropout(data=relu7, p=0.5, name="drop7")
    fc8 = mx.sym.FullyConnected(data=drop7, num_hidden=num_classes, name="fc8")
    return fc8

def get_symbol(num_classes, num_layers=11, batch_norm=False, dtype='float32', **kwargs):
    """
    Parameters
    ----------
    num_classes : int
        Number of classification classes.
    num_layers : int
        Number of layers for the variant of VGG. Options are 11, 13, 16, 19.
    batch_norm : bool, default False
        Use batch normalization.
    dtype: str, float32 or float16
        Data precision.
    """
    vgg_spec = {11: ([1, 1, 2, 2, 2], [64, 128, 256, 512, 512]),
                13: ([2, 2, 2, 2, 2], [64, 128, 256, 512, 512]),
                16: ([2, 2, 3, 3, 3], [64, 128, 256, 512, 512]),
                1: ([1], [64]),
                2: ([1], [64]),
                19: ([2, 2, 4, 4, 4], [64, 128, 256, 512, 512])}
    if num_layers not in vgg_spec:
        raise ValueError("Invalide num_layers {}. Possible choices are 11,13,16,19.".format(num_layers))
    layers, filters = vgg_spec[num_layers]
    data = mx.sym.Variable(name="data")
    if dtype == 'float16':
        data = mx.sym.Cast(data=data, dtype=np.float16)
    feature = get_feature(data, layers, filters, batch_norm)
    classifier = get_classifier(feature, num_classes)
    if dtype == 'float16':
        classifier = mx.sym.Cast(data=classifier, dtype=np.float32)
    symbol = mx.sym.SoftmaxOutput(data=classifier, name='softmax')
    return symbol

args = parser.parse_args()
print(args)
bs = args.batch_size
num_classes = args.num_classes
data_shape = (bs, 3, 224, 224)
num_layers = args.num_layers
ctx = mx.gpu()

net = get_symbol(num_classes, num_layers=num_layers, dtype='float32')

# initialize the module
mod = mx.module.Module(net, context=ctx)
train_iter = SyntheticDataIter(num_classes, data_shape, 500, np.float32)
mod.bind(data_shapes=train_iter.provide_data, label_shapes=train_iter.provide_label)
mod.init_params(initializer=mx.init.Xavier(factor_type="in", magnitude=2.34))
optim = mx.optimizer.create('sgd')
mod.init_optimizer(optimizer=optim)

metric = mx.metric.create(['accuracy'])
nbatch = 0
metric.reset()
for batch in train_iter:
    nbatch += 1
    mod.forward_backward(batch)
    # update all parameters
    mod.update()
    # update training metric
    mod.update_metric(metric, batch.label)
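
To reproduce the numbers below, the script above can be saved as, e.g., mem_benchmark.py (filename is illustrative) and run as python mem_benchmark.py --batch-size 120 --num-layers 11 --num-classes 2000, observing peak GPU memory externally (e.g. with nvidia-smi).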

Memory usage measured with the benchmark script above:

Before:
Batch size = 120, Memory = 11249 MB

After:
Batch size = 120, Memory = 7856 MB (about a 30% reduction at the same batch size)
Batch size = 190, Memory = 11250 MB

For train_imagenet.py --network resnet --benchmark 1 --gpus=0 --batch-size=xx with kvstore=None, the maximum batch size that fits in GPU memory increased from 88 to 128.

Measured with cuDNN 7, CUDA 9, on a K80 GPU.

Review thread (quoted test snippet):

{'ctx': mx.cpu(0), 'act_data': shape, 'type_dict': {'act_data': np.float64}},
{'ctx': mx.cpu(0), 'act_data': shape, 'type_dict': {'act_data': np.float32}},
{'ctx': mx.cpu(0), 'act_data': shape, 'type_dict': {'act_data': np.float16}}]
check_consistency(sym, ctx_list)

Contributor:
we need to add gradient tests

Member Author:
Good point. I didn't know it doesn't check grad
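
A numerical gradient check was added as one of this PR's commits ("add check_numerical_grad test" in the commit list below). A minimal sketch of such a test, assuming mx.test_utils.check_numeric_gradient with illustrative shapes and tolerances:

import numpy as np
import mxnet as mx
from mxnet.test_utils import check_numeric_gradient

shape = (3, 4)
sym = mx.sym.Activation(mx.sym.Variable('act_data'), act_type='relu', name='act')
# keep inputs away from 0, where the relu derivative is not numerically well defined
x = np.random.uniform(0.1, 1.0, shape) * np.random.choice([-1.0, 1.0], shape)
check_numeric_gradient(sym, [x], numeric_eps=1e-3, rtol=1e-2, atol=1e-4)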

@piiswrong merged commit b2ec05b into apache:master May 8, 2018
@eric-haibin-lin changed the title from "[MXNET-408] [WIP] inplace ReLU activation" to "[MXNET-408] inplace ReLU activation" May 8, 2018
jinhuang415 pushed a commit to jinhuang415/incubator-mxnet that referenced this pull request May 29, 2018
* inplace version of activation(relu)
* inplace relu
* add comments
* add commnet
* comments
* fix compilation error
* add check_numerical_grad test
rahul003 pushed a commit to rahul003/mxnet that referenced this pull request Jun 4, 2018
eric-haibin-lin added a commit to eric-haibin-lin/mxnet that referenced this pull request Jun 4, 2018
anirudh2290 pushed a commit that referenced this pull request Jun 11, 2018
zheng-da pushed a commit to zheng-da/incubator-mxnet that referenced this pull request Jun 28, 2018
@eric-haibin-lin deleted the relu branch September 18, 2018 23:33