Add non-local NN for video action recognition (dmlc#978)
* add non-local

* add nonlocal

* rm ascii char
bryanyzhu authored Oct 10, 2019
1 parent e980f47 commit f6a97ba
Showing 3 changed files with 354 additions and 2 deletions.
231 changes: 229 additions & 2 deletions gluoncv/model_zoo/action_recognition/i3d_resnet.py
@@ -1,7 +1,9 @@
# pylint: disable=line-too-long,too-many-lines,missing-docstring,arguments-differ,unused-argument
# Code partially borrowed from https://github.com/open-mmlab/mmaction.

__all__ = ['I3D_ResNetV1', 'i3d_resnet50_v1_kinetics400', 'i3d_resnet101_v1_kinetics400']
__all__ = ['I3D_ResNetV1', 'i3d_resnet50_v1_kinetics400', 'i3d_resnet101_v1_kinetics400',
'i3d_nl5_resnet50_v1_kinetics400', 'i3d_nl10_resnet50_v1_kinetics400',
'i3d_nl5_resnet101_v1_kinetics400', 'i3d_nl10_resnet101_v1_kinetics400']

from mxnet import nd
from mxnet import init
@@ -10,6 +12,7 @@
from mxnet.gluon import nn
from mxnet.gluon.nn import BatchNorm
from ..resnetv1b import resnet50_v1b, resnet101_v1b
from .non_local import build_nonlocal_block

def conv3x3x3(in_planes, out_planes, spatial_stride=1, temporal_stride=1, dilation=1):
"3x3x3 convolution with padding"
@@ -185,6 +188,7 @@ def __init__(self,
nonlocal_cfg_ = nonlocal_cfg.copy()
nonlocal_cfg_['in_channels'] = planes * self.expansion
self.nonlocal_block = build_nonlocal_block(nonlocal_cfg_)
self.bottleneck.add(self.nonlocal_block)
else:
self.nonlocal_block = None

@@ -424,8 +428,19 @@ def init_weights(self, ctx):
resnet2d = resnet101_v1b(pretrained=True)
else:
print('No such 2D pre-trained network of depth %d.' % (self.depth))

weights2d = resnet2d.collect_params()
weights3d = self.collect_params()
if self.nonlocal_cfg is None:
weights3d = self.collect_params()
else:
train_params_list = []
raw_params = self.collect_params()
for raw_name in raw_params.keys():
if 'nonlocal' in raw_name:
continue
train_params_list.append(raw_name)
init_patterns = '|'.join(train_params_list)
weights3d = self.collect_params(init_patterns)
assert len(weights2d.keys()) == len(weights3d.keys()), 'Number of parameters should be same.'

dict2d = {}
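Side note (not part of the diff): the `'|'.join(...)` pattern above relies on Gluon's `Block.collect_params(select)` accepting a regex over parameter names, so any parameter whose name contains 'nonlocal' is excluded from the 2D-to-3D weight copy and left to its fresh initialization. A toy sketch of the same filtering, with hypothetical layer prefixes and the Gluon 1.x naming API:

from mxnet.gluon import nn

net = nn.HybridSequential(prefix='demo_')
with net.name_scope():
    net.add(nn.Dense(8, prefix='fc1_'))
    net.add(nn.Dense(8, prefix='nonlocal_fc_'))

# Keep only parameters whose names do not mention 'nonlocal', mirroring init_weights.
keep = [name for name in net.collect_params().keys() if 'nonlocal' not in name]
subset = net.collect_params('|'.join(keep))
print(sorted(subset.keys()))  # only the demo_fc1_* weight and bias remain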
@@ -579,3 +594,215 @@ def i3d_resnet101_v1_kinetics400(nclass=400, pretrained=False, pretrained_base=T
model.collect_params().reset_ctx(ctx)

return model

def i3d_nl5_resnet50_v1_kinetics400(nclass=400, pretrained=False, pretrained_base=True, ctx=cpu(),
root='~/.mxnet/models', tsn=False, num_segments=1, partial_bn=False, **kwargs):
r"""Inflated 3D model (I3D) from
`"Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset"
<https://arxiv.org/abs/1705.07750>`_ paper.
`"Non-local Neural Networks"
<https://arxiv.org/abs/1711.07971>`_ paper.
Parameters
----------
pretrained : bool or str
Boolean value controls whether to load the default pretrained weights for model.
String value represents the hashtag for a certain version of pretrained weights.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
partial_bn : bool, default False
Freeze all batch normalization layers during training except the first layer.
norm_layer : object
Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
norm_kwargs : dict
Additional `norm_layer` arguments, for example `num_devices=4`
for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
"""

model = I3D_ResNetV1(nclass=nclass,
depth=50,
pretrained=pretrained,
pretrained_base=pretrained_base,
num_segments=num_segments,
out_indices=[3],
inflate_freq=((1, 1, 1), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0), (0, 1, 0)),
nonlocal_stages=(1, 2),
nonlocal_cfg=dict(nonlocal_type="gaussian"),
nonlocal_freq=((0, 0, 0), (0, 1, 0, 1), (0, 1, 0, 1, 0, 1), (0, 0, 0)),
bn_eval=False,
partial_bn=partial_bn,
ctx=ctx,
**kwargs)

if pretrained:
from ..model_store import get_model_file
model.load_parameters(get_model_file('i3d_nl5_resnet50_v1_kinetics400',
tag=pretrained, root=root), ctx=ctx)
from ...data import Kinetics400Attr
attrib = Kinetics400Attr()
model.classes = attrib.classes
model.collect_params().reset_ctx(ctx)

return model
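A minimal usage sketch (not part of the diff), assuming the pretrained weights for `i3d_nl5_resnet50_v1_kinetics400` are published in the model store; the clip layout follows the usual (batch, channel, time, height, width) convention and the 32-frame length here is only illustrative:

import mxnet as mx
from gluoncv.model_zoo import get_model

net = get_model('i3d_nl5_resnet50_v1_kinetics400', pretrained=True)

# One dummy clip: 3 channels, 32 frames, 224x224 crops (illustrative sizes).
clip = mx.nd.random.uniform(shape=(1, 3, 32, 224, 224))
scores = net(clip)
print(scores.shape)  # (1, 400) logits over the Kinetics400 classes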

def i3d_nl10_resnet50_v1_kinetics400(nclass=400, pretrained=False, pretrained_base=True, ctx=cpu(),
root='~/.mxnet/models', tsn=False, num_segments=1, partial_bn=False, **kwargs):
r"""Inflated 3D model (I3D) from
`"Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset"
<https://arxiv.org/abs/1705.07750>`_ paper.
`"Non-local Neural Networks"
<https://arxiv.org/abs/1711.07971>`_ paper.
Parameters
----------
pretrained : bool or str
Boolean value controls whether to load the default pretrained weights for model.
String value represents the hashtag for a certain version of pretrained weights.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
partial_bn : bool, default False
Freeze all batch normalization layers during training except the first layer.
norm_layer : object
Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
norm_kwargs : dict
Additional `norm_layer` arguments, for example `num_devices=4`
for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
"""

model = I3D_ResNetV1(nclass=nclass,
depth=50,
pretrained=pretrained,
pretrained_base=pretrained_base,
num_segments=num_segments,
out_indices=[3],
inflate_freq=((1, 1, 1), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0), (0, 1, 0)),
nonlocal_stages=(1, 2),
nonlocal_cfg=dict(nonlocal_type="gaussian"),
nonlocal_freq=((0, 0, 0), (1, 1, 1, 1), (1, 1, 1, 1, 1, 1), (0, 0, 0)),
bn_eval=False,
partial_bn=partial_bn,
ctx=ctx,
**kwargs)

if pretrained:
from ..model_store import get_model_file
model.load_parameters(get_model_file('i3d_nl10_resnet50_v1_kinetics400',
tag=pretrained, root=root), ctx=ctx)
from ...data import Kinetics400Attr
attrib = Kinetics400Attr()
model.classes = attrib.classes
model.collect_params().reset_ctx(ctx)

return model

def i3d_nl5_resnet101_v1_kinetics400(nclass=400, pretrained=False, pretrained_base=True, ctx=cpu(),
root='~/.mxnet/models', tsn=False, num_segments=1, partial_bn=False, **kwargs):
r"""Inflated 3D model (I3D) from
`"Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset"
<https://arxiv.org/abs/1705.07750>`_ paper.
`"Non-local Neural Networks"
<https://arxiv.org/abs/1711.07971>`_ paper.
Parameters
----------
pretrained : bool or str
Boolean value controls whether to load the default pretrained weights for model.
String value represents the hashtag for a certain version of pretrained weights.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
partial_bn : bool, default False
Freeze all batch normalization layers during training except the first layer.
norm_layer : object
Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
norm_kwargs : dict
Additional `norm_layer` arguments, for example `num_devices=4`
for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
"""

model = I3D_ResNetV1(nclass=nclass,
depth=101,
pretrained=pretrained,
pretrained_base=pretrained_base,
num_segments=num_segments,
out_indices=[3],
inflate_freq=((1, 1, 1), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1), (0, 1, 0)),
nonlocal_stages=(1, 2),
nonlocal_cfg=dict(nonlocal_type="gaussian"),
nonlocal_freq=((0, 0, 0), (0, 1, 0, 1), (0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0), (0, 0, 0)),
bn_eval=False,
partial_bn=partial_bn,
ctx=ctx,
**kwargs)

if pretrained:
from ..model_store import get_model_file
model.load_parameters(get_model_file('i3d_nl5_resnet101_v1_kinetics400',
tag=pretrained, root=root), ctx=ctx)
from ...data import Kinetics400Attr
attrib = Kinetics400Attr()
model.classes = attrib.classes
model.collect_params().reset_ctx(ctx)

return model

def i3d_nl10_resnet101_v1_kinetics400(nclass=400, pretrained=False, pretrained_base=True, ctx=cpu(),
root='~/.mxnet/models', tsn=False, num_segments=1, partial_bn=False, **kwargs):
r"""Inflated 3D model (I3D) from
`"Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset"
<https://arxiv.org/abs/1705.07750>`_ paper.
`"Non-local Neural Networks"
<https://arxiv.org/abs/1711.07971>`_ paper.
Parameters
----------
pretrained : bool or str
Boolean value controls whether to load the default pretrained weights for model.
String value represents the hashtag for a certain version of pretrained weights.
ctx : Context, default CPU
The context in which to load the pretrained weights.
root : str, default $MXNET_HOME/models
Location for keeping the model parameters.
partial_bn : bool, default False
Freeze all batch normalization layers during training except the first layer.
norm_layer : object
Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
norm_kwargs : dict
Additional `norm_layer` arguments, for example `num_devices=4`
for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
"""

model = I3D_ResNetV1(nclass=nclass,
depth=101,
pretrained=pretrained,
pretrained_base=pretrained_base,
num_segments=num_segments,
out_indices=[3],
inflate_freq=((1, 1, 1), (1, 0, 1, 0), (1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1), (0, 1, 0)),
nonlocal_stages=(1, 2),
nonlocal_cfg=dict(nonlocal_type="gaussian"),
nonlocal_freq=((0, 0, 0), (1, 1, 1, 1), (0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1), (0, 0, 0)),
bn_eval=False,
partial_bn=partial_bn,
ctx=ctx,
**kwargs)

if pretrained:
from ..model_store import get_model_file
model.load_parameters(get_model_file('i3d_nl10_resnet101_v1_kinetics400',
tag=pretrained, root=root), ctx=ctx)
from ...data import Kinetics400Attr
attrib = Kinetics400Attr()
model.classes = attrib.classes
model.collect_params().reset_ctx(ctx)

return model
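For reference, the `nl5`/`nl10` suffixes count the non-local blocks that the `nonlocal_freq` tuples switch on inside the stages listed in `nonlocal_stages`. A quick sanity check (plain Python, values copied from the ResNet-50 variants above):

# 1 enables a non-local block after that bottleneck, 0 leaves it plain.
nl5_freq = ((0, 0, 0), (0, 1, 0, 1), (0, 1, 0, 1, 0, 1), (0, 0, 0))
nl10_freq = ((0, 0, 0), (1, 1, 1, 1), (1, 1, 1, 1, 1, 1), (0, 0, 0))
nonlocal_stages = (1, 2)  # only the res3 and res4 stages get non-local blocks

for name, freq in (('nl5', nl5_freq), ('nl10', nl10_freq)):
    total = sum(sum(freq[stage]) for stage in nonlocal_stages)
    print(name, total)  # nl5 -> 5, nl10 -> 10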
121 changes: 121 additions & 0 deletions gluoncv/model_zoo/action_recognition/non_local.py
@@ -0,0 +1,121 @@
"""Non-local block for video action recognition"""
# pylint: disable=line-too-long,too-many-lines,missing-docstring,arguments-differ,unused-argument
from mxnet.gluon.block import HybridBlock
from mxnet import init
from mxnet.gluon import nn
from mxnet.gluon.nn import BatchNorm

def build_nonlocal_block(cfg):
""" Build nonlocal block from
`"Non-local Neural Networks"
<https://arxiv.org/abs/1711.07971>`_ paper.
Code adapted from mmaction.
"""
assert isinstance(cfg, dict)
cfg_ = cfg.copy()
return NonLocal(**cfg_)
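A minimal sketch (not part of the file) of how the builder is driven from `i3d_resnet.py`: the bottleneck copies its `nonlocal_cfg`, injects the stage's output channel count, and hands the dict to `build_nonlocal_block`; the 1024 channels below are only an illustrative res4 width:

from gluoncv.model_zoo.action_recognition.non_local import build_nonlocal_block

cfg = dict(nonlocal_type='gaussian')  # what the factory functions pass in
cfg['in_channels'] = 1024             # filled in by the bottleneck as planes * expansion
block = build_nonlocal_block(cfg)
block.initialize()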

class NonLocal(HybridBlock):
def __init__(self, in_channels=1024, nonlocal_type="gaussian", dim=3, embed=True, embed_dim=None, sub_sample=True, use_bn=True,
norm_layer=BatchNorm, norm_kwargs=None, ctx=None, **kwargs):
super(NonLocal, self).__init__()

assert nonlocal_type in ['gaussian', 'dot', 'concat']
self.nonlocal_type = nonlocal_type
self.embed = embed
self.embed_dim = embed_dim if embed_dim is not None else in_channels // 2
self.sub_sample = sub_sample
self.use_bn = use_bn

with self.name_scope():
if self.embed:
if dim == 2:
self.theta = nn.Conv2D(in_channels=in_channels, channels=self.embed_dim, kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), weight_initializer=init.MSRAPrelu())
self.phi = nn.Conv2D(in_channels=in_channels, channels=self.embed_dim, kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), weight_initializer=init.MSRAPrelu())
self.g = nn.Conv2D(in_channels=in_channels, channels=self.embed_dim, kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), weight_initializer=init.MSRAPrelu())
elif dim == 3:
self.theta = nn.Conv3D(in_channels=in_channels, channels=self.embed_dim, kernel_size=(1, 1, 1),
strides=(1, 1, 1), padding=(0, 0, 0), weight_initializer=init.MSRAPrelu())
self.phi = nn.Conv3D(in_channels=in_channels, channels=self.embed_dim, kernel_size=(1, 1, 1),
strides=(1, 1, 1), padding=(0, 0, 0), weight_initializer=init.MSRAPrelu())
self.g = nn.Conv3D(in_channels=in_channels, channels=self.embed_dim, kernel_size=(1, 1, 1),
strides=(1, 1, 1), padding=(0, 0, 0), weight_initializer=init.MSRAPrelu())

if self.nonlocal_type == 'concat':
if dim == 2:
self.concat_proj = nn.HybridSequential()
self.concat_proj.add(nn.Conv2D(in_channels=self.embed_dim * 2, channels=1, kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), weight_initializer=init.MSRAPrelu()))
self.concat_proj.add(nn.Activation('relu'))
elif dim == 3:
self.concat_proj = nn.HybridSequential()
self.concat_proj.add(nn.Conv3D(in_channels=self.embed_dim * 2, channels=1, kernel_size=(1, 1, 1),
strides=(1, 1, 1), padding=(0, 0, 0), weight_initializer=init.MSRAPrelu()))
self.concat_proj.add(nn.Activation('relu'))

if sub_sample:
if dim == 2:
self.max_pool = nn.MaxPool2D(pool_size=(2, 2))
elif dim == 3:
self.max_pool = nn.MaxPool3D(pool_size=(1, 2, 2))
self.sub_phi = nn.HybridSequential()
self.sub_phi.add(self.phi)
self.sub_phi.add(self.max_pool)
self.sub_g = nn.HybridSequential()
self.sub_g.add(self.g)
self.sub_g.add(self.max_pool)

if dim == 2:
self.W = nn.Conv2D(in_channels=self.embed_dim, channels=in_channels, kernel_size=(1, 1),
strides=(1, 1), padding=(0, 0), weight_initializer=init.MSRAPrelu())
elif dim == 3:
self.W = nn.Conv3D(in_channels=self.embed_dim, channels=in_channels, kernel_size=(1, 1, 1),
strides=(1, 1, 1), padding=(0, 0, 0), weight_initializer=init.MSRAPrelu())

if use_bn:
self.bn = norm_layer(in_channels=in_channels, gamma_initializer='zeros', **({} if norm_kwargs is None else norm_kwargs))
self.W_bn = nn.HybridSequential()
self.W_bn.add(self.W)
self.W_bn.add(self.bn)

def hybrid_forward(self, F, x):
if self.embed:
theta = self.theta(x)
if self.sub_sample:
phi = self.sub_phi(x)
g = self.sub_g(x)
else:
phi = self.phi(x)
g = self.g(x)
else:
theta = x
phi = x
g = x

if self.nonlocal_type == 'gaussian':
# reshape [BxCxTxHxW] to [BxCxTHW]
theta = F.reshape(theta, (0, 0, -1))
phi = F.reshape(phi, (0, 0, -1))
g = F.reshape(g, (0, 0, -1))
# Direct transpose is slow, merge it into `batch_dot` operation.
# theta_phi = nd.batch_dot(F.transpose(theta, axes=(0, 2, 1)), phi)
theta_phi = F.batch_dot(theta, phi, transpose_a=True)
attn = F.softmax(theta_phi, axis=2)
        elif self.nonlocal_type == 'concat':
            raise NotImplementedError
        elif self.nonlocal_type == 'dot':
            raise NotImplementedError
else:
raise NotImplementedError

y = F.batch_dot(g, attn, transpose_b=True)
y = F.reshape_like(y, x, lhs_begin=2, lhs_end=None, rhs_begin=2, rhs_end=None)

if self.use_bn:
z = self.W_bn(y) + x
else:
z = self.W(y) + x
return z
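As a standalone shape check (a sketch, not part of the commit), the block maps a 5-D feature map back to its own shape; with the default `sub_sample=True` the spatial dimensions should be divisible by 2:

import mxnet as mx
from gluoncv.model_zoo.action_recognition.non_local import NonLocal

block = NonLocal(in_channels=256, nonlocal_type='gaussian', dim=3)
block.initialize()

x = mx.nd.random.uniform(shape=(2, 256, 4, 14, 14))  # (N, C, T, H, W), illustrative sizes
z = block(x)
print(z.shape)  # (2, 256, 4, 14, 14): residual output matches the input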
4 changes: 4 additions & 0 deletions gluoncv/model_zoo/model_zoo.py
@@ -231,6 +231,10 @@
'i3d_resnet101_v1_kinetics400': i3d_resnet101_v1_kinetics400,
'i3d_inceptionv1_kinetics400': i3d_inceptionv1_kinetics400,
'i3d_inceptionv3_kinetics400': i3d_inceptionv3_kinetics400,
'i3d_nl5_resnet50_v1_kinetics400': i3d_nl5_resnet50_v1_kinetics400,
'i3d_nl10_resnet50_v1_kinetics400': i3d_nl10_resnet50_v1_kinetics400,
'i3d_nl5_resnet101_v1_kinetics400': i3d_nl5_resnet101_v1_kinetics400,
'i3d_nl10_resnet101_v1_kinetics400': i3d_nl10_resnet101_v1_kinetics400,
'fcn_resnet101_voc_int8': fcn_resnet101_voc_int8,
'fcn_resnet101_coco_int8': fcn_resnet101_coco_int8,
'psp_resnet101_voc_int8': psp_resnet101_voc_int8,
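Once registered here, the four variants resolve by name through the model zoo. A quick sketch (assuming `pretrained_base=False` lets the builder skip downloading the 2D backbone weights):

from gluoncv.model_zoo import get_model, get_model_list

print([name for name in get_model_list() if 'i3d_nl' in name])
# -> the four non-local I3D entries registered above

# Build without any pretrained weights; pass pretrained=True instead to load the Kinetics400 weights.
net = get_model('i3d_nl10_resnet101_v1_kinetics400', nclass=400,
                pretrained=False, pretrained_base=False)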
