diff --git a/gluoncv/model_zoo/action_recognition/i3d_resnet.py b/gluoncv/model_zoo/action_recognition/i3d_resnet.py
index a93f837101..fe9d03f3ea 100644
--- a/gluoncv/model_zoo/action_recognition/i3d_resnet.py
+++ b/gluoncv/model_zoo/action_recognition/i3d_resnet.py
@@ -1,7 +1,9 @@
 # pylint: disable=line-too-long,too-many-lines,missing-docstring,arguments-differ,unused-argument
 # Code partially borrowed from https://github.com/open-mmlab/mmaction.
 
-__all__ = ['I3D_ResNetV1', 'i3d_resnet50_v1_kinetics400', 'i3d_resnet101_v1_kinetics400']
+__all__ = ['I3D_ResNetV1', 'i3d_resnet50_v1_kinetics400', 'i3d_resnet101_v1_kinetics400',
+           'i3d_nl5_resnet50_v1_kinetics400', 'i3d_nl10_resnet50_v1_kinetics400',
+           'i3d_nl5_resnet101_v1_kinetics400', 'i3d_nl10_resnet101_v1_kinetics400']
 
 from mxnet import nd
 from mxnet import init
@@ -10,6 +12,7 @@
 from mxnet.gluon import nn
 from mxnet.gluon.nn import BatchNorm
 from ..resnetv1b import resnet50_v1b, resnet101_v1b
+from .non_local import build_nonlocal_block
 
 def conv3x3x3(in_planes, out_planes, spatial_stride=1, temporal_stride=1, dilation=1):
     "3x3x3 convolution with padding"
@@ -185,6 +188,7 @@ def __init__(self,
             nonlocal_cfg_ = nonlocal_cfg.copy()
             nonlocal_cfg_['in_channels'] = planes * self.expansion
             self.nonlocal_block = build_nonlocal_block(nonlocal_cfg_)
+            self.bottleneck.add(self.nonlocal_block)
         else:
             self.nonlocal_block = None
 
@@ -424,8 +428,19 @@ def init_weights(self, ctx):
                 resnet2d = resnet101_v1b(pretrained=True)
             else:
                 print('No such 2D pre-trained network of depth %d.' % (self.depth))
+
             weights2d = resnet2d.collect_params()
-            weights3d = self.collect_params()
+            if self.nonlocal_cfg is None:
+                weights3d = self.collect_params()
+            else:
+                train_params_list = []
+                raw_params = self.collect_params()
+                for raw_name in raw_params.keys():
+                    if 'nonlocal' in raw_name:
+                        continue
+                    train_params_list.append(raw_name)
+                init_patterns = '|'.join(train_params_list)
+                weights3d = self.collect_params(init_patterns)
             assert len(weights2d.keys()) == len(weights3d.keys()), 'Number of parameters should be same.'
 
             dict2d = {}
@@ -579,3 +594,215 @@ def i3d_resnet101_v1_kinetics400(nclass=400, pretrained=False, pretrained_base=T
         model.collect_params().reset_ctx(ctx)
 
     return model
+
+def i3d_nl5_resnet50_v1_kinetics400(nclass=400, pretrained=False, pretrained_base=True, ctx=cpu(),
+                                    root='~/.mxnet/models', tsn=False, num_segments=1, partial_bn=False, **kwargs):
+    r"""Inflated 3D model (I3D) from
+    `"Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset"
+    <https://arxiv.org/abs/1705.07750>`_ paper.
+    `"Non-local Neural Networks"
+    <https://arxiv.org/abs/1711.07971>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool or str
+        Boolean value controls whether to load the default pretrained weights for model.
+        String value represents the hashtag for a certain version of pretrained weights.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default $MXNET_HOME/models
+        Location for keeping the model parameters.
+    partial_bn : bool, default False
+        Freeze all batch normalization layers during training except the first layer.
+    norm_layer : object
+        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
+        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
+    norm_kwargs : dict
+        Additional `norm_layer` arguments, for example `num_devices=4`
+        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
+    """
+
+    model = I3D_ResNetV1(nclass=nclass,
+                         depth=50,
+                         pretrained=pretrained,
+                         pretrained_base=pretrained_base,
+                         num_segments=num_segments,
+                         out_indices=[3],
+                         inflate_freq=((1, 1, 1), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0), (0, 1, 0)),
+                         nonlocal_stages=(1, 2),
+                         nonlocal_cfg=dict(nonlocal_type="gaussian"),
+                         nonlocal_freq=((0, 0, 0), (0, 1, 0, 1), (0, 1, 0, 1, 0, 1), (0, 0, 0)),
+                         bn_eval=False,
+                         partial_bn=partial_bn,
+                         ctx=ctx,
+                         **kwargs)
+
+    if pretrained:
+        from ..model_store import get_model_file
+        model.load_parameters(get_model_file('i3d_nl5_resnet50_v1_kinetics400',
+                                             tag=pretrained, root=root), ctx=ctx)
+        from ...data import Kinetics400Attr
+        attrib = Kinetics400Attr()
+        model.classes = attrib.classes
+    model.collect_params().reset_ctx(ctx)
+
+    return model
+
+def i3d_nl10_resnet50_v1_kinetics400(nclass=400, pretrained=False, pretrained_base=True, ctx=cpu(),
+                                     root='~/.mxnet/models', tsn=False, num_segments=1, partial_bn=False, **kwargs):
+    r"""Inflated 3D model (I3D) from
+    `"Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset"
+    <https://arxiv.org/abs/1705.07750>`_ paper.
+    `"Non-local Neural Networks"
+    <https://arxiv.org/abs/1711.07971>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool or str
+        Boolean value controls whether to load the default pretrained weights for model.
+        String value represents the hashtag for a certain version of pretrained weights.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default $MXNET_HOME/models
+        Location for keeping the model parameters.
+    partial_bn : bool, default False
+        Freeze all batch normalization layers during training except the first layer.
+    norm_layer : object
+        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
+        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
+    norm_kwargs : dict
+        Additional `norm_layer` arguments, for example `num_devices=4`
+        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
+    """
+
+    model = I3D_ResNetV1(nclass=nclass,
+                         depth=50,
+                         pretrained=pretrained,
+                         pretrained_base=pretrained_base,
+                         num_segments=num_segments,
+                         out_indices=[3],
+                         inflate_freq=((1, 1, 1), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0), (0, 1, 0)),
+                         nonlocal_stages=(1, 2),
+                         nonlocal_cfg=dict(nonlocal_type="gaussian"),
+                         nonlocal_freq=((0, 0, 0), (1, 1, 1, 1), (1, 1, 1, 1, 1, 1), (0, 0, 0)),
+                         bn_eval=False,
+                         partial_bn=partial_bn,
+                         ctx=ctx,
+                         **kwargs)
+
+    if pretrained:
+        from ..model_store import get_model_file
+        model.load_parameters(get_model_file('i3d_nl10_resnet50_v1_kinetics400',
+                                             tag=pretrained, root=root), ctx=ctx)
+        from ...data import Kinetics400Attr
+        attrib = Kinetics400Attr()
+        model.classes = attrib.classes
+    model.collect_params().reset_ctx(ctx)
+
+    return model
+
+def i3d_nl5_resnet101_v1_kinetics400(nclass=400, pretrained=False, pretrained_base=True, ctx=cpu(),
+                                     root='~/.mxnet/models', tsn=False, num_segments=1, partial_bn=False, **kwargs):
+    r"""Inflated 3D model (I3D) from
+    `"Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset"
+    <https://arxiv.org/abs/1705.07750>`_ paper.
+    `"Non-local Neural Networks"
+    <https://arxiv.org/abs/1711.07971>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool or str
+        Boolean value controls whether to load the default pretrained weights for model.
+        String value represents the hashtag for a certain version of pretrained weights.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default $MXNET_HOME/models
+        Location for keeping the model parameters.
+    partial_bn : bool, default False
+        Freeze all batch normalization layers during training except the first layer.
+    norm_layer : object
+        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
+        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
+    norm_kwargs : dict
+        Additional `norm_layer` arguments, for example `num_devices=4`
+        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
+    """
+
+    model = I3D_ResNetV1(nclass=nclass,
+                         depth=101,
+                         pretrained=pretrained,
+                         pretrained_base=pretrained_base,
+                         num_segments=num_segments,
+                         out_indices=[3],
+                         inflate_freq=((1, 1, 1), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1), (0, 1, 0)),
+                         nonlocal_stages=(1, 2),
+                         nonlocal_cfg=dict(nonlocal_type="gaussian"),
+                         nonlocal_freq=((0, 0, 0), (0, 1, 0, 1), (0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0), (0, 0, 0)),
+                         bn_eval=False,
+                         partial_bn=partial_bn,
+                         ctx=ctx,
+                         **kwargs)
+
+    if pretrained:
+        from ..model_store import get_model_file
+        model.load_parameters(get_model_file('i3d_nl5_resnet101_v1_kinetics400',
+                                             tag=pretrained, root=root), ctx=ctx)
+        from ...data import Kinetics400Attr
+        attrib = Kinetics400Attr()
+        model.classes = attrib.classes
+    model.collect_params().reset_ctx(ctx)
+
+    return model
+
+def i3d_nl10_resnet101_v1_kinetics400(nclass=400, pretrained=False, pretrained_base=True, ctx=cpu(),
+                                      root='~/.mxnet/models', tsn=False, num_segments=1, partial_bn=False, **kwargs):
+    r"""Inflated 3D model (I3D) from
+    `"Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset"
+    <https://arxiv.org/abs/1705.07750>`_ paper.
+    `"Non-local Neural Networks"
+    <https://arxiv.org/abs/1711.07971>`_ paper.
+
+    Parameters
+    ----------
+    pretrained : bool or str
+        Boolean value controls whether to load the default pretrained weights for model.
+        String value represents the hashtag for a certain version of pretrained weights.
+    ctx : Context, default CPU
+        The context in which to load the pretrained weights.
+    root : str, default $MXNET_HOME/models
+        Location for keeping the model parameters.
+    partial_bn : bool, default False
+        Freeze all batch normalization layers during training except the first layer.
+    norm_layer : object
+        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
+        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
+    norm_kwargs : dict
+        Additional `norm_layer` arguments, for example `num_devices=4`
+        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
+    """
+
+    model = I3D_ResNetV1(nclass=nclass,
+                         depth=101,
+                         pretrained=pretrained,
+                         pretrained_base=pretrained_base,
+                         num_segments=num_segments,
+                         out_indices=[3],
+                         inflate_freq=((1, 1, 1), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1), (0, 1, 0)),
+                         nonlocal_stages=(1, 2),
+                         nonlocal_cfg=dict(nonlocal_type="gaussian"),
+                         nonlocal_freq=((0, 0, 0), (1, 1, 1, 1), (0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1), (0, 0, 0)),
+                         bn_eval=False,
+                         partial_bn=partial_bn,
+                         ctx=ctx,
+                         **kwargs)
+
+    if pretrained:
+        from ..model_store import get_model_file
+        model.load_parameters(get_model_file('i3d_nl10_resnet101_v1_kinetics400',
+                                             tag=pretrained, root=root), ctx=ctx)
+        from ...data import Kinetics400Attr
+        attrib = Kinetics400Attr()
+        model.classes = attrib.classes
+    model.collect_params().reset_ctx(ctx)
+
+    return model
diff --git a/gluoncv/model_zoo/action_recognition/non_local.py b/gluoncv/model_zoo/action_recognition/non_local.py
new file mode 100644
index 0000000000..a99ca1b424
--- /dev/null
+++ b/gluoncv/model_zoo/action_recognition/non_local.py
@@ -0,0 +1,150 @@
+"""Non-local block for video action recognition"""
+# pylint: disable=line-too-long,too-many-lines,missing-docstring,arguments-differ,unused-argument
+from mxnet.gluon.block import HybridBlock
+from mxnet import init
+from mxnet.gluon import nn
+from mxnet.gluon.nn import BatchNorm
+
+def build_nonlocal_block(cfg):
+    """ Build nonlocal block from
+    `"Non-local Neural Networks"
+    <https://arxiv.org/abs/1711.07971>`_ paper.
+    Code adapted from mmaction.
+    """
+    assert isinstance(cfg, dict)
+    cfg_ = cfg.copy()
+    return NonLocal(**cfg_)
+
+class NonLocal(HybridBlock):
+    r"""Non-local block from
+    `"Non-local Neural Networks"
+    <https://arxiv.org/abs/1711.07971>`_ paper.
+
+    Only the ``'gaussian'`` type is implemented in ``hybrid_forward``;
+    ``'concat'`` and ``'dot'`` raise ``NotImplementedError``.
+
+    Parameters
+    ----------
+    in_channels : int, default 1024
+        Number of channels of the input feature map.
+    nonlocal_type : str, default 'gaussian'
+        One of 'gaussian', 'dot' or 'concat'.
+    dim : int, default 3
+        2 builds Conv2D/MaxPool2D layers, 3 builds Conv3D/MaxPool3D layers.
+    embed : bool, default True
+        Whether to project the input with 1x1 convolutions (theta, phi, g).
+    embed_dim : int, default None
+        Number of embedding channels, `in_channels // 2` when None.
+    sub_sample : bool, default True
+        Whether to max-pool the phi and g embeddings.
+    use_bn : bool, default True
+        Whether to apply `norm_layer` (gamma initialized to zeros) after `W`.
+    norm_layer : object
+        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
+    norm_kwargs : dict
+        Additional `norm_layer` arguments.
+    """
+    def __init__(self, in_channels=1024, nonlocal_type="gaussian", dim=3, embed=True, embed_dim=None, sub_sample=True, use_bn=True,
+                 norm_layer=BatchNorm, norm_kwargs=None, ctx=None, **kwargs):
+        super(NonLocal, self).__init__()
+
+        assert nonlocal_type in ['gaussian', 'dot', 'concat']
+        self.nonlocal_type = nonlocal_type
+        self.embed = embed
+        self.embed_dim = embed_dim if embed_dim is not None else in_channels // 2
+        self.sub_sample = sub_sample
+        self.use_bn = use_bn
+
+        with self.name_scope():
+            if self.embed:
+                if dim == 2:
+                    self.theta = nn.Conv2D(in_channels=in_channels, channels=self.embed_dim, kernel_size=(1, 1),
+                                           strides=(1, 1), padding=(0, 0), weight_initializer=init.MSRAPrelu())
+                    self.phi = nn.Conv2D(in_channels=in_channels, channels=self.embed_dim, kernel_size=(1, 1),
+                                         strides=(1, 1), padding=(0, 0), weight_initializer=init.MSRAPrelu())
+                    self.g = nn.Conv2D(in_channels=in_channels, channels=self.embed_dim, kernel_size=(1, 1),
+                                       strides=(1, 1), padding=(0, 0), weight_initializer=init.MSRAPrelu())
+                elif dim == 3:
+                    self.theta = nn.Conv3D(in_channels=in_channels, channels=self.embed_dim, kernel_size=(1, 1, 1),
+                                           strides=(1, 1, 1), padding=(0, 0, 0), weight_initializer=init.MSRAPrelu())
+                    self.phi = nn.Conv3D(in_channels=in_channels, channels=self.embed_dim, kernel_size=(1, 1, 1),
+                                         strides=(1, 1, 1), padding=(0, 0, 0), weight_initializer=init.MSRAPrelu())
+                    self.g = nn.Conv3D(in_channels=in_channels, channels=self.embed_dim, kernel_size=(1, 1, 1),
+                                       strides=(1, 1, 1), padding=(0, 0, 0), weight_initializer=init.MSRAPrelu())
+
+            if self.nonlocal_type == 'concat':
+                if dim == 2:
+                    self.concat_proj = nn.HybridSequential()
+                    self.concat_proj.add(nn.Conv2D(in_channels=self.embed_dim * 2, channels=1, kernel_size=(1, 1),
+                                                   strides=(1, 1), padding=(0, 0), weight_initializer=init.MSRAPrelu()))
+                    self.concat_proj.add(nn.Activation('relu'))
+                elif dim == 3:
+                    self.concat_proj = nn.HybridSequential()
+                    self.concat_proj.add(nn.Conv3D(in_channels=self.embed_dim * 2, channels=1, kernel_size=(1, 1, 1),
+                                                   strides=(1, 1, 1), padding=(0, 0, 0), weight_initializer=init.MSRAPrelu()))
+                    self.concat_proj.add(nn.Activation('relu'))
+
+            if sub_sample:
+                if dim == 2:
+                    self.max_pool = nn.MaxPool2D(pool_size=(2, 2))
+                elif dim == 3:
+                    self.max_pool = nn.MaxPool3D(pool_size=(1, 2, 2))
+                self.sub_phi = nn.HybridSequential()
+                self.sub_phi.add(self.phi)
+                self.sub_phi.add(self.max_pool)
+                self.sub_g = nn.HybridSequential()
+                self.sub_g.add(self.g)
+                self.sub_g.add(self.max_pool)
+
+            if dim == 2:
+                self.W = nn.Conv2D(in_channels=self.embed_dim, channels=in_channels, kernel_size=(1, 1),
+                                   strides=(1, 1), padding=(0, 0), weight_initializer=init.MSRAPrelu())
+            elif dim == 3:
+                self.W = nn.Conv3D(in_channels=self.embed_dim, channels=in_channels, kernel_size=(1, 1, 1),
+                                   strides=(1, 1, 1), padding=(0, 0, 0), weight_initializer=init.MSRAPrelu())
+
+            if use_bn:
+                self.bn = norm_layer(in_channels=in_channels, gamma_initializer='zeros', **({} if norm_kwargs is None else norm_kwargs))
+                self.W_bn = nn.HybridSequential()
+                self.W_bn.add(self.W)
+                self.W_bn.add(self.bn)
+
+    def hybrid_forward(self, F, x):
+        """Apply the non-local operation to `x` and add the result back to `x`."""
+        if self.embed:
+            theta = self.theta(x)
+            if self.sub_sample:
+                phi = self.sub_phi(x)
+                g = self.sub_g(x)
+            else:
+                phi = self.phi(x)
+                g = self.g(x)
+        else:
+            theta = x
+            phi = x
+            g = x
+
+        if self.nonlocal_type == 'gaussian':
+            # reshape [BxCxTxHxW] to [BxCxTHW]
+            theta = F.reshape(theta, (0, 0, -1))
+            phi = F.reshape(phi, (0, 0, -1))
+            g = F.reshape(g, (0, 0, -1))
+            # Direct transpose is slow, merge it into `batch_dot` operation.
+            # theta_phi = nd.batch_dot(F.transpose(theta, axes=(0, 2, 1)), phi)
+            theta_phi = F.batch_dot(theta, phi, transpose_a=True)
+            attn = F.softmax(theta_phi, axis=2)
+        elif self.nonlocal_type == 'concat':
+            raise NotImplementedError
+        elif self.nonlocal_type == 'dot':
+            raise NotImplementedError
+        else:
+            raise NotImplementedError
+
+        y = F.batch_dot(g, attn, transpose_b=True)
+        y = F.reshape_like(y, x, lhs_begin=2, lhs_end=None, rhs_begin=2, rhs_end=None)
+
+        if self.use_bn:
+            z = self.W_bn(y) + x
+        else:
+            z = self.W(y) + x
+        return z
diff --git a/gluoncv/model_zoo/model_zoo.py b/gluoncv/model_zoo/model_zoo.py
index 98d5228f53..e2317a4e9f 100644
--- a/gluoncv/model_zoo/model_zoo.py
+++ b/gluoncv/model_zoo/model_zoo.py
@@ -231,6 +231,10 @@
     'i3d_resnet101_v1_kinetics400': i3d_resnet101_v1_kinetics400,
     'i3d_inceptionv1_kinetics400': i3d_inceptionv1_kinetics400,
     'i3d_inceptionv3_kinetics400': i3d_inceptionv3_kinetics400,
+    'i3d_nl5_resnet50_v1_kinetics400': i3d_nl5_resnet50_v1_kinetics400,
+    'i3d_nl10_resnet50_v1_kinetics400': i3d_nl10_resnet50_v1_kinetics400,
+    'i3d_nl5_resnet101_v1_kinetics400': i3d_nl5_resnet101_v1_kinetics400,
+    'i3d_nl10_resnet101_v1_kinetics400': i3d_nl10_resnet101_v1_kinetics400,
     'fcn_resnet101_voc_int8': fcn_resnet101_voc_int8,
     'fcn_resnet101_coco_int8': fcn_resnet101_coco_int8,
     'psp_resnet101_voc_int8': psp_resnet101_voc_int8,