Merge branch 'master' into zz/base
innerlee authored Jun 25, 2021
2 parents 2f4ac79 + eb08835 commit 5a41a39
Showing 65 changed files with 2,553 additions and 499 deletions.
6 changes: 5 additions & 1 deletion .github/pull_request_template.md
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

## Motivation

Please describe the motivation of this PR and the goal you want to achieve through this PR.

## Modification

Please briefly describe what modification is made in this PR.

## BC-breaking (Optional)
Does the modification introduce changes that break the back-compatibility of the downstream repos?

Does the modification introduce changes that break the backward-compatibility of the downstream repos?
If so, please describe how it breaks the compatibility and how the downstream projects should modify their code to keep compatibility with this PR.

## Use cases (Optional)

If this PR introduces a new feature, it is better to list some use cases here, and update the documentation.

## Checklist
Expand Down
163 changes: 160 additions & 3 deletions docs/runner.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,163 @@
## Runner

The runner module aims to help users to start training with less code, while stays
flexible and configurable.
The runner class is designed to manage the training. It eases the training process by requiring less code from users while staying flexible and configurable. The main features are listed below:

Documentation and examples are still on going.
- Support `EpochBasedRunner` and `IterBasedRunner` for different scenarios. Implementing customized runners is also allowed to meet customized needs.
- Support customized workflow to allow switching between different modes while training. Currently, supported modes are train and val.
- Enable extensibility through various hooks, including hooks defined in MMCV and customized ones.

### EpochBasedRunner

As its name indicates, the workflow in `EpochBasedRunner` should be set based on epochs. For example, `[('train', 2), ('val', 1)]` means running 2 epochs for training and 1 epoch for validation, iteratively. Each epoch may contain multiple iterations. Currently, MMDetection uses `EpochBasedRunner` by default.

Let's take a look at its core logic:

```python
# the condition to stop training
while curr_epoch < max_epochs:
    # traverse the workflow.
    # e.g. workflow = [('train', 2), ('val', 1)]
    for i, flow in enumerate(workflow):
        # mode (e.g. train) determines which function to run
        mode, epochs = flow
        # epoch_runner will be either self.train() or self.val()
        epoch_runner = getattr(self, mode)
        # execute the corresponding function
        for _ in range(epochs):
            epoch_runner(data_loaders[i], **kwargs)
```

Currently, we support 2 modes: train and val. Let's take the train function as an example and have a look at its core logic:

```python
# Currently, epoch_runner could be either train or val
def train(self, data_loader, **kwargs):
    # traverse the dataset and get batch data for 1 epoch
    for i, data_batch in enumerate(data_loader):
        # this executes all before_train_iter functions in the registered hooks;
        # you may want to watch out for their order
        self.call_hook('before_train_iter')
        # train_mode is set to False in the val function
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
    self.call_hook('after_train_epoch')
```

### IterBasedRunner

Different from `EpochBasedRunner`, the workflow in `IterBasedRunner` should be set based on iterations. For example, `[('train', 2), ('val', 1)]` means running 2 iterations for training and 1 iteration for validation, iteratively. Currently, MMSegmentation uses `IterBasedRunner` by default.

Let's take a look at its core logic:

```python
# Although we set the workflow by iters here, we might also need epoch
# information in some use cases. That can be provided by IterLoader.
iter_loaders = [IterLoader(x) for x in data_loaders]
# the condition to stop training
while curr_iter < max_iters:
    # traverse the workflow.
    # e.g. workflow = [('train', 2), ('val', 1)]
    for i, flow in enumerate(workflow):
        # mode (e.g. train) determines which function to run
        mode, iters = flow
        # iter_runner will be either self.train() or self.val()
        iter_runner = getattr(self, mode)
        # execute the corresponding function
        for _ in range(iters):
            iter_runner(iter_loaders[i], **kwargs)
```

Currently, we support 2 modes: train and val. Let's take the val function as an example and have a look at its core logic:

```python
# Currently, iter_runner could be either train or val
def val(self, data_loader, **kwargs):
    # get batch data for 1 iter
    data_batch = next(data_loader)
    # this executes all before_val_iter functions in the registered hooks;
    # you may want to watch out for their order
    self.call_hook('before_val_iter')
    outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
    self.outputs = outputs
    self.call_hook('after_val_iter')
```

Other than the basic functionalities explained above, `EpochBasedRunner` and `IterBasedRunner` provide methods such as `resume`, `save_checkpoint` and `register_hook`. In case you are not familiar with the term hook mentioned earlier, we will also provide a tutorial about it (coming soon...). Essentially, a hook is a piece of functionality that alters or augments the code's behavior through predefined APIs. It allows users to have their own code called under certain circumstances, making the code extensible in a non-intrusive manner.
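
Below is a minimal sketch of a customized hook, assuming the standard `Hook` interface from `mmcv.runner`; the hook name and the logged message are made up purely for illustration.

```python
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class EpochLoggerHook(Hook):  # hypothetical hook, for illustration only
    """Logs a short message at the end of every training epoch."""

    def after_train_epoch(self, runner):
        # the runner exposes the current state, e.g. epoch/iter counters and a logger
        runner.logger.info(f'Finished epoch {runner.epoch + 1}')
```

It can then be attached with `runner.register_hook(EpochLoggerHook())`, or through the `custom_hooks` field in the config as shown in the example below.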

### A Simple Example

We will walk you through the usage of the runner with a classification task. The following code only contains the essential steps for demonstration purposes; these steps are necessary for any training task.

**(1) Initialize dataloader, model, optimizer, etc.**

```python
# initialize the model
model = ...
# initialize the optimizer; typically, we set:
# cfg.optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer = build_optimizer(model, cfg.optimizer)
# initialize the dataloaders corresponding to the workflow (train/val)
data_loaders = [
    build_dataloader(
        ds,
        cfg.data.samples_per_gpu,
        cfg.data.workers_per_gpu,
        ...) for ds in dataset
]
```

**(2) Initialize runner**

```python
runner = build_runner(
    # cfg.runner is typically set as:
    # runner = dict(type='EpochBasedRunner', max_epochs=200)
    cfg.runner,
    default_args=dict(
        model=model,
        batch_processor=None,
        optimizer=optimizer,
        logger=logger))
```

**(3) Register training hooks and customized hooks.**

```python
# register default hooks necessary for training
runner.register_training_hooks(
    # config of learning rate, it is typically set as:
    # lr_config = dict(policy='step', step=[100, 150])
    cfg.lr_config,
    # configuration of the optimizer, e.g. grad_clip
    optimizer_config,
    # configuration of saving checkpoints, it is typically set as:
    # checkpoint_config = dict(interval=1), i.e. saving checkpoints every epoch
    cfg.checkpoint_config,
    # configuration of logs
    cfg.log_config,
    ...)

# register customized hooks
# say we want to enable EMA, then we could set custom_hooks = [dict(type='EMAHook')]
if cfg.get('custom_hooks', None):
    custom_hooks = cfg.custom_hooks
    for hook_cfg in cfg.custom_hooks:
        hook_cfg = hook_cfg.copy()
        priority = hook_cfg.pop('priority', 'NORMAL')
        hook = build_from_cfg(hook_cfg, HOOKS)
        runner.register_hook(hook, priority=priority)
```

Then, we can use `resume` or `load_checkpoint` to load existing weights.
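
For example, a minimal sketch following the usual pattern in `train.py` (assuming `cfg.resume_from` and `cfg.load_from` hold paths to existing checkpoints):

```python
if cfg.resume_from:
    # resume restores the model weights, optimizer states and the epoch/iter counters
    runner.resume(cfg.resume_from)
elif cfg.load_from:
    # load_checkpoint only loads the model weights
    runner.load_checkpoint(cfg.load_from)
```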

**(4) Start training**

```python
# workflow is typically set as: workflow = [('train', 1)]
# here the training begins.
runner.run(data_loaders, cfg.workflow)
```

Let's take `EpochBasedRunner` as an example and go into a bit more detail about setting the workflow:

- Say we only want to put train in the workflow, then we can set: `workflow = [('train', 1)]`. The runner will only execute train iteratively in this case.
- Say we want to put both train and val in the workflow, then we can set: `workflow = [('train', 3), ('val', 1)]`. The runner will first execute train for 3 epochs and then switch to val mode and execute val for 1 epoch. The workflow will be repeated until the current epoch hits `max_epochs`.
- The workflow is highly flexible. Therefore, you can set `workflow = [('val', 1), ('train', 1)]` if you would like the runner to validate first and train after.

The code we demonstrated above is already in `train.py` in MM repositories. Simply modify the corresponding keys in the configuration files and the script will execute the expected workflow automatically.
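
For reference, here is a minimal sketch of how the relevant keys may look in a config file; the values are illustrative, not recommendations:

```python
# illustrative values only
runner = dict(type='EpochBasedRunner', max_epochs=200)
workflow = [('train', 3), ('val', 1)]  # 3 training epochs, then 1 validation epoch, repeated
lr_config = dict(policy='step', step=[100, 150])
checkpoint_config = dict(interval=1)  # save a checkpoint every epoch
log_config = dict(interval=50, hooks=[dict(type='TextLoggerHook')])
custom_hooks = [dict(type='EMAHook')]  # optional
```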
48 changes: 48 additions & 0 deletions docs/tensorrt_custom_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@
- [Inputs](#inputs-7)
- [Outputs](#outputs-7)
- [Type Constraints](#type-constraints-7)
- [MMCVModulatedDeformConv2d](#mmcvmodulateddeformconv2d)
- [Description](#description-8)
- [Parameters](#parameters-8)
- [Inputs](#inputs-8)
- [Outputs](#outputs-8)
- [Type Constraints](#type-constraints-8)

<!-- TOC -->

Expand Down Expand Up @@ -345,3 +351,45 @@ y = scale * (x - mean) / sqrt(variance + epsilon) + B, where mean and variance a
### Type Constraints

- T:tensor(float32, Linear)

## MMCVModulatedDeformConv2d

### Description

Perform Modulated Deformable Convolution on the input feature map; read [Deformable ConvNets v2: More Deformable, Better Results](https://arxiv.org/abs/1811.11168?from=timeline) for details.

### Parameters

| Type | Parameter | Description |
| -------------- | ------------------ | ------------------------------------------------------------------------------------- |
| `list of ints` | `stride` | The stride of the convolving kernel. (sH, sW) |
| `list of ints` | `padding` | Paddings on both sides of the input. (padH, padW) |
| `list of ints` | `dilation` | The spacing between kernel elements. (dH, dW) |
| `int` | `deformable_group` | Groups of deformable offset. |
| `int` | `group` | Split input into groups. `input_channel` should be divisible by the number of groups. |

### Inputs

<dl>
<dt><tt>inputs[0]</tt>: T</dt>
<dd>Input feature; 4-D tensor of shape (N, C, inH, inW), where N is the batch size, C is the number of channels, and inH and inW are the height and width of the data.</dd>
<dt><tt>inputs[1]</tt>: T</dt>
<dd>Input offset; 4-D tensor of shape (N, deformable_group * 2 * kH * kW, outH, outW), where kH and kW are the height and width of the weight, and outH and outW are the height and width of the offset and output.</dd>
<dt><tt>inputs[2]</tt>: T</dt>
<dd>Input mask; 4-D tensor of shape (N, deformable_group * kH * kW, outH, outW), where kH and kW are the height and width of the weight, and outH and outW are the height and width of the offset and output.</dd>
<dt><tt>inputs[3]</tt>: T</dt>
<dd>Input weight; 4-D tensor of shape (output_channel, input_channel, kH, kW).</dd>
<dt><tt>inputs[4]</tt>: T, optional</dt>
<dd>Input bias; 1-D tensor of shape (output_channel).</dd>
</dl>

### Outputs

<dl>
<dt><tt>outputs[0]</tt>: T</dt>
<dd>Output feature; 4-D tensor of shape (N, output_channel, outH, outW).</dd>
</dl>

### Type Constraints

- T:tensor(float32, Linear)
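
The following sketch (not part of the op definition above) constructs tensors with the documented shapes, just to make the shape conventions concrete; all sizes are hypothetical.

```python
import torch

# hypothetical sizes, purely for illustration
N, C, inH, inW = 1, 16, 32, 32
out_channels, kH, kW = 32, 3, 3
stride, padding, dilation = (1, 1), (1, 1), (1, 1)
deformable_group = 1

# output spatial size follows the usual convolution arithmetic
outH = (inH + 2 * padding[0] - dilation[0] * (kH - 1) - 1) // stride[0] + 1
outW = (inW + 2 * padding[1] - dilation[1] * (kW - 1) - 1) // stride[1] + 1

feature = torch.randn(N, C, inH, inW)                                # inputs[0]
offset = torch.randn(N, deformable_group * 2 * kH * kW, outH, outW)  # inputs[1]
mask = torch.randn(N, deformable_group * kH * kW, outH, outW)        # inputs[2]
weight = torch.randn(out_channels, C, kH, kW)                        # inputs[3]
bias = torch.randn(out_channels)                                     # inputs[4], optional
# the op would then produce outputs[0] of shape (N, out_channels, outH, outW)
```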
7 changes: 4 additions & 3 deletions docs/tensorrt_plugin.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ To ease the deployment of trained models with custom operators from `mmcv.ops` u
| NonMaxSuppression | [NonMaxSuppression](./tensorrt_custom_ops.md#nonmaxsuppression) | 1.3.0 |
| MMCVDeformConv2d | [MMCVDeformConv2d](./tensorrt_custom_ops.md#mmcvdeformconv2d) | 1.3.0 |
| grid_sampler | [grid_sampler](./tensorrt_custom_ops.md#grid-sampler) | 1.3.1 |
| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | master |
| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | master |
| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | master |
| cummax | [cummax](./tensorrt_custom_ops.md#cummax) | 1.3.5 |
| cummin | [cummin](./tensorrt_custom_ops.md#cummin) | 1.3.5 |
| MMCVInstanceNormalization | [MMCVInstanceNormalization](./tensorrt_custom_ops.md#mmcvinstancenormalization) | 1.3.5 |
| MMCVModulatedDeformConv2d | [MMCVModulatedDeformConv2d](./tensorrt_custom_ops.md#mmcvmodulateddeformconv2d) | master |

Notes

Expand Down
3 changes: 2 additions & 1 deletion mmcv/cnn/bricks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .conv_module import ConvModule
from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
from .drop import Dropout, DropPath
from .generalized_attention import GeneralizedAttention
from .hsigmoid import HSigmoid
from .hswish import HSwish
Expand All @@ -29,5 +30,5 @@
'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
'ConvTranspose3d', 'MaxPool3d', 'Conv3d'
'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
]
64 changes: 64 additions & 0 deletions mmcv/cnn/bricks/drop.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import torch
import torch.nn as nn

from mmcv import build_from_cfg
from .registry import DROPOUT_LAYERS


def drop_path(x, drop_prob=0., training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).
    We follow the implementation
    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # handle tensors with different dimensions, not just 4D tensors.
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    output = x.div(keep_prob) * random_tensor.floor()
    return output


@DROPOUT_LAYERS.register_module()
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of
    residual blocks).
    We follow the implementation
    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py  # noqa: E501
    Args:
        drop_prob (float): Probability of the path to be zeroed. Default: 0.1
    """

    def __init__(self, drop_prob=0.1):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


@DROPOUT_LAYERS.register_module()
class Dropout(nn.Dropout):
    """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
    ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
    ``DropPath``
    Args:
        drop_prob (float): Probability of the elements to be
            zeroed. Default: 0.5.
        inplace (bool): Do the operation inplace or not. Default: False.
    """

    def __init__(self, drop_prob=0.5, inplace=False):
        super().__init__(p=drop_prob, inplace=inplace)


def build_dropout(cfg, default_args=None):
    """Builder for drop out layers."""
    return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
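
A quick usage sketch (not part of the commit) of the new builder, assuming the module path shown above:

```python
import torch
from mmcv.cnn.bricks.drop import build_dropout

# build a DropPath layer from a config dict via the DROPOUT_LAYERS registry
drop_layer = build_dropout(dict(type='DropPath', drop_prob=0.2))
out = drop_layer(torch.randn(4, 16, 8, 8))  # same shape as the input
```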
10 changes: 6 additions & 4 deletions mmcv/cnn/bricks/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
UPSAMPLE_LAYERS = Registry('upsample layer')
PLUGIN_LAYERS = Registry('plugin layer')

POSITIONAL_ENCODING = Registry('Position encoding')
ATTENTION = Registry('Attention')
TRANSFORMER_LAYER = Registry('TransformerLayer')
TRANSFORMER_LAYER_SEQUENCE = Registry('TransformerLayerSequence')
DROPOUT_LAYERS = Registry('drop out layers')
POSITIONAL_ENCODING = Registry('position encoding')
ATTENTION = Registry('attention')
FEEDFORWARD_NETWORK = Registry('feed-forward Network')
TRANSFORMER_LAYER = Registry('transformerLayer')
TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')