Refactor the baseclass related to transformer #978

Merged: 47 commits, Jun 11, 2021

Commits
d3977cd · minor changes (Apr 14, 2021)
36418e1 · change to modulist (Apr 23, 2021)
806fed0 · change to Sequential (HIT-cwh, Apr 23, 2021)
12140f8 · replace dropout with attn_drop and proj_drop in MultiheadAttention (HIT-cwh, Apr 23, 2021)
b15616e · add operation_name for attn (Apr 26, 2021)
6bc3254 · Merge pull request #1 from HIT-cwh/refactor_transformer (jshilong, Apr 26, 2021)
869e87f · add drop path and move all ffn args to ffncfgs (Apr 30, 2021)
a022220 · add drop path and move all ffn args to ffncfgs (Apr 30, 2021)
8b5d9b4 · fix typo (Apr 30, 2021)
7ab36c0 · fix a bug when use default value of ffn_cfgs (Apr 30, 2021)
a1659aa · fix ffns (Apr 30, 2021)
f808e2a · add deprecate warning (May 7, 2021)
47cbf8b · pull master (May 7, 2021)
282a2e4 · fix deprecate warning (May 7, 2021)
8b15261 · change to pop kwargs (May 7, 2021)
be135f3 · support register FFN of transformer (congee524, May 8, 2021)
0f11d0f · support batch first (May 13, 2021)
f3635bd · fix batch first wapper (May 13, 2021)
5df1df0 · fix forward wapper (May 15, 2021)
5e5b47e · fix typo (May 15, 2021)
8eb66e5 · Merge pull request #2 from congee524/transformer (jshilong, May 18, 2021)
464f5fe · fix lint (May 18, 2021)
5e966cb · add unitest for transformer (May 18, 2021)
b81e893 · fix unitest (May 19, 2021)
c2de1e2 · resolve conflict (May 19, 2021)
98ab503 · fix equal (May 19, 2021)
65cccc5 · use allclose (May 19, 2021)
37fc581 · fix comments (May 19, 2021)
7cf2445 · fix comments (May 20, 2021)
ffe1571 · change configdict to dict (May 25, 2021)
c16e4bb · Merge branch 'master' into refactor_tramsformer_base (ZwwWayne, May 25, 2021)
28d0782 · move drop to a file (jshilong, Jun 8, 2021)
b1ab260 · add comments for drop path (jshilong, Jun 8, 2021)
d410613 · add noqa 501 (jshilong, Jun 8, 2021)
2444084 · move bnc wapper to MultiheadAttention (jshilong, Jun 10, 2021)
d33f54e · move bnc wapper to MultiheadAttention (jshilong, Jun 10, 2021)
f5622d1 · use dep warning (jshilong, Jun 10, 2021)
08f3739 · resolve comments (jshilong, Jun 10, 2021)
987a11c · add unitest: (jshilong, Jun 10, 2021)
41830d5 · rename residual to identity (jshilong, Jun 10, 2021)
db3f857 · revert runner (jshilong, Jun 11, 2021)
d59c7fd · msda residual to identity (jshilong, Jun 11, 2021)
d4871b7 · rename inp_identity to identity (jshilong, Jun 11, 2021)
4cb38de · fix name (jshilong, Jun 11, 2021)
2fb7ae0 · fix transformer (jshilong, Jun 11, 2021)
8b46c09 · remove key in msda (jshilong, Jun 11, 2021)
65bcd3f · remove assert for key (jshilong, Jun 11, 2021)
mmcv/cnn/bricks/__init__.py (3 changes: 2 additions & 1 deletion)

@@ -5,6 +5,7 @@
 from .conv_module import ConvModule
 from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
 from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
+from .drop import Dropout, DropPath
 from .generalized_attention import GeneralizedAttention
 from .hsigmoid import HSigmoid
 from .hswish import HSwish
@@ -29,5 +30,5 @@
     'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
     'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
     'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
-    'ConvTranspose3d', 'MaxPool3d', 'Conv3d'
+    'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
 ]
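
With this change both layers are re-exported from the bricks package, so downstream code can import them directly. A minimal sketch:

from mmcv.cnn.bricks import Dropout, DropPath

layer = DropPath(drop_prob=0.2)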
mmcv/cnn/bricks/drop.py (new file: 64 additions)

@@ -0,0 +1,64 @@
import torch
import torch.nn as nn

from mmcv import build_from_cfg
from .registry import DROPOUT_LAYERS


def drop_path(x, drop_prob=0., training=False):
    """Drop paths (Stochastic Depth) per sample (when applied in the main
    path of residual blocks).

    We follow the implementation
    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
    """
    if drop_prob == 0. or not training:
        return x
    keep_prob = 1 - drop_prob
    # handle tensors with different dimensions, not just 4D tensors
    shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    random_tensor = keep_prob + torch.rand(
        shape, dtype=x.dtype, device=x.device)
    # floor() binarizes the per-sample mask; div(keep_prob) rescales the
    # surviving samples so the expected output is unchanged
    output = x.div(keep_prob) * random_tensor.floor()
    return output


@DROPOUT_LAYERS.register_module()
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main
    path of residual blocks).

    We follow the implementation
    https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501

    Args:
        drop_prob (float): Probability of the path to be zeroed.
            Default: 0.1
    """

    def __init__(self, drop_prob=0.1):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)


@DROPOUT_LAYERS.register_module()
class Dropout(nn.Dropout):
    """A wrapper for ``torch.nn.Dropout``. We rename the ``p`` of
    ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
    ``DropPath``.

    Args:
        drop_prob (float): Probability of the elements to be
            zeroed. Default: 0.5.
        inplace (bool): Do the operation inplace or not. Default: False.
    """

    def __init__(self, drop_prob=0.5, inplace=False):
        super().__init__(p=drop_prob, inplace=inplace)


def build_dropout(cfg, default_args=None):
    """Builder for dropout layers."""
    return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
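
With the builder in place, any registered dropout variant can be constructed from a config dict: during training, DropPath zeroes each sample's whole residual path with probability drop_prob and rescales survivors by 1/keep_prob, and at inference it is the identity. A minimal usage sketch (the tensor shape here is illustrative, not prescribed by this PR):

import torch

from mmcv.cnn.bricks.drop import build_dropout

# 'type' selects the registered class; remaining keys go to __init__
drop_layer = build_dropout(dict(type='DropPath', drop_prob=0.1))

x = torch.randn(4, 196, 256)  # (batch, num_tokens, embed_dims)
drop_layer.train()
out = drop_layer(x)  # about 10% of samples have their path zeroed

drop_layer.eval()
assert torch.equal(drop_layer(x), x)  # identity at inference time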
mmcv/cnn/bricks/registry.py (10 changes: 6 additions & 4 deletions)

@@ -7,7 +7,9 @@
 UPSAMPLE_LAYERS = Registry('upsample layer')
 PLUGIN_LAYERS = Registry('plugin layer')
 
-POSITIONAL_ENCODING = Registry('Position encoding')
-ATTENTION = Registry('Attention')
-TRANSFORMER_LAYER = Registry('TransformerLayer')
-TRANSFORMER_LAYER_SEQUENCE = Registry('TransformerLayerSequence')
+DROPOUT_LAYERS = Registry('drop out layers')
+POSITIONAL_ENCODING = Registry('position encoding')
+ATTENTION = Registry('attention')
+FEEDFORWARD_NETWORK = Registry('feed-forward Network')
+TRANSFORMER_LAYER = Registry('transformerLayer')
+TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
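
These registries make each transformer component pluggable: a user-defined module can be registered and then instantiated from a config dict. A minimal sketch (MyAttention is a hypothetical example class, not part of this PR):

import torch.nn as nn

from mmcv import build_from_cfg
from mmcv.cnn.bricks.registry import ATTENTION


@ATTENTION.register_module()
class MyAttention(nn.Module):  # hypothetical custom attention
    def __init__(self, embed_dims=256):
        super().__init__()
        self.proj = nn.Linear(embed_dims, embed_dims)

    def forward(self, x):
        return self.proj(x)


attn = build_from_cfg(dict(type='MyAttention', embed_dims=128), ATTENTION)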