Merge pull request #652 from dreamerlin/3d

[Feature] Add 3D support in wrapper

ZwwWayne authored Nov 20, 2020
2 parents ec43b67 + 1a12ac7 commit dfa36df
Showing 4 changed files with 349 additions and 132 deletions.
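This commit extends mmcv's empty-tensor-safe wrapper layers from 2D to 3D: Conv3d, ConvTranspose3d and MaxPool3d mirror the existing 2D wrappers and compute the output shape by hand when the input is empty and the installed PyTorch cannot infer it natively. A minimal usage sketch (the channel and spatial sizes here are illustrative, not from the commit):

import torch
from mmcv.cnn import Conv3d

# A batch with zero samples: (N, C, D, H, W) = (0, 3, 8, 16, 16).
x = torch.empty(0, 3, 8, 16, 16)

conv = Conv3d(3, 8, kernel_size=3, padding=1)
out = conv(x)  # empty result with the correct shape (0, 8, 8, 16, 16)
print(out.shape)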
12 changes: 8 additions & 4 deletions mmcv/cnn/__init__.py
@@ -1,14 +1,17 @@
 # Copyright (c) Open-MMLab. All rights reserved.
 from .alexnet import AlexNet
+# yapf: disable
 from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                      PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
-                     ContextBlock, Conv2d, ConvAWS2d, ConvModule,
-                     ConvTranspose2d, ConvWS2d, DepthwiseSeparableConvModule,
-                     GeneralizedAttention, HSigmoid, HSwish, Linear, MaxPool2d,
+                     ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
+                     ConvTranspose2d, ConvTranspose3d, ConvWS2d,
+                     DepthwiseSeparableConvModule, GeneralizedAttention,
+                     HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
                      NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
                      build_activation_layer, build_conv_layer,
                      build_norm_layer, build_padding_layer, build_plugin_layer,
                      build_upsample_layer, conv_ws_2d, is_norm)
+# yapf: enable
 from .resnet import ResNet, make_res_layer
 from .utils import (bias_init_with_prob, caffe2_xavier_init, constant_init,
                     fuse_conv_bn, get_model_complexity_info, kaiming_init,
@@ -26,5 +29,6 @@
     'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS', 'UPSAMPLE_LAYERS',
     'PLUGIN_LAYERS', 'Scale', 'get_model_complexity_info', 'conv_ws_2d',
     'ConvAWS2d', 'ConvWS2d', 'fuse_conv_bn', 'DepthwiseSeparableConvModule',
-    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d'
+    'Linear', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d',
+    'MaxPool3d', 'Conv3d'
 ]
6 changes: 4 additions & 2 deletions mmcv/cnn/bricks/__init__.py
@@ -17,7 +17,8 @@
 from .scale import Scale
 from .swish import Swish
 from .upsample import build_upsample_layer
-from .wrappers import Conv2d, ConvTranspose2d, Linear, MaxPool2d
+from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
+                       Linear, MaxPool2d, MaxPool3d)
 
 __all__ = [
     'ConvModule', 'build_activation_layer', 'build_conv_layer',
@@ -27,5 +28,6 @@
     'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
     'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
     'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
-    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d'
+    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
+    'ConvTranspose3d', 'MaxPool3d', 'Conv3d'
 ]
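Because the new wrappers are registered in CONV_LAYERS and UPSAMPLE_LAYERS (see wrappers.py below), they can also be built from config dicts. A short sketch, assuming mmcv's standard build helpers; the channel and kernel sizes are made up:

from mmcv.cnn import build_conv_layer, build_upsample_layer

# 'Conv3d' resolves to the empty-tensor-safe wrapper (registered with
# force=True); 'deconv3d' resolves to the ConvTranspose3d wrapper.
conv = build_conv_layer(dict(type='Conv3d'), 3, 8, kernel_size=3)
deconv = build_upsample_layer(
    dict(type='deconv3d', in_channels=8, out_channels=3, kernel_size=2,
         stride=2))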
67 changes: 65 additions & 2 deletions mmcv/cnn/bricks/wrappers.py
@@ -8,7 +8,7 @@
 
 import torch
 import torch.nn as nn
-from torch.nn.modules.utils import _pair
+from torch.nn.modules.utils import _pair, _triple
 
 from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
 
@@ -58,6 +58,27 @@ def forward(self, x):
         return super().forward(x)
 
 
+@CONV_LAYERS.register_module('Conv3d', force=True)
+class Conv3d(nn.Conv3d):
+
+    def forward(self, x):
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+            out_shape = [x.shape[0], self.out_channels]
+            for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
+                                     self.padding, self.stride, self.dilation):
+                o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
+                out_shape.append(o)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            if self.training:
+                # produce dummy gradient to avoid DDP warning.
+                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                return empty + dummy
+            else:
+                return empty
+
+        return super().forward(x)
+
+
 @CONV_LAYERS.register_module()
 @CONV_LAYERS.register_module('deconv')
 @UPSAMPLE_LAYERS.register_module('deconv', force=True)
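The out_shape loop above applies the standard convolution output formula, o = (i + 2p - (d(k - 1) + 1)) // s + 1, to each of the three spatial dimensions. A quick sanity check of that arithmetic against a non-empty forward pass (sizes are hypothetical):

import torch
import torch.nn as nn

conv = nn.Conv3d(3, 8, kernel_size=3, stride=2, padding=1)
x = torch.randn(2, 3, 9, 17, 17)

expected = [2, 8]
for i, k, p, s, d in zip(x.shape[-3:], conv.kernel_size, conv.padding,
                         conv.stride, conv.dilation):
    expected.append((i + 2 * p - (d * (k - 1) + 1)) // s + 1)

assert list(conv(x).shape) == expected  # [2, 8, 5, 9, 9]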
@@ -78,7 +99,30 @@ def forward(self, x):
         else:
             return empty
 
-        return super(ConvTranspose2d, self).forward(x)
+        return super().forward(x)
 
 
+@CONV_LAYERS.register_module()
+@CONV_LAYERS.register_module('deconv3d')
+@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
+class ConvTranspose3d(nn.ConvTranspose3d):
+
+    def forward(self, x):
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
+            out_shape = [x.shape[0], self.out_channels]
+            for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
+                                         self.padding, self.stride,
+                                         self.dilation, self.output_padding):
+                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            if self.training:
+                # produce dummy gradient to avoid DDP warning.
+                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
+                return empty + dummy
+            else:
+                return empty
+
+        return super().forward(x)
+
+
 class MaxPool2d(nn.MaxPool2d):
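ConvTranspose3d inverts that relation: o = (i - 1)s - 2p + (d(k - 1) + 1) + output_padding per spatial dimension. The same kind of check, again with hypothetical sizes:

import torch
import torch.nn as nn

deconv = nn.ConvTranspose3d(8, 3, kernel_size=2, stride=2)
x = torch.randn(2, 8, 5, 9, 9)

expected = [2, 3]
for i, k, p, s, d, op in zip(x.shape[-3:], deconv.kernel_size,
                             deconv.padding, deconv.stride, deconv.dilation,
                             deconv.output_padding):
    expected.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)

assert list(deconv(x).shape) == expected  # [2, 3, 10, 18, 18]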
@@ -99,6 +143,25 @@ def forward(self, x):
         return super().forward(x)
 
 
+class MaxPool3d(nn.MaxPool3d):
+
+    def forward(self, x):
+        # PyTorch 1.7 does not support empty tensor inference yet
+        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 7)):
+            out_shape = list(x.shape[:2])
+            for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
+                                     _triple(self.padding),
+                                     _triple(self.stride),
+                                     _triple(self.dilation)):
+                o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
+                o = math.ceil(o) if self.ceil_mode else math.floor(o)
+                out_shape.append(o)
+            empty = NewEmptyTensorOp.apply(x, out_shape)
+            return empty
+
+        return super().forward(x)
+
+
 class Linear(torch.nn.Linear):
 
     def forward(self, x):
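Unlike the conv layers, nn.MaxPool3d stores scalar constructor arguments as-is, so the wrapper normalizes them with _triple before the shape loop, and ceil_mode picks ceil over floor. A small check of that arithmetic with hypothetical sizes:

import math

import torch
import torch.nn as nn
from torch.nn.modules.utils import _triple

pool = nn.MaxPool3d(kernel_size=2, stride=2, ceil_mode=True)
x = torch.randn(2, 8, 5, 9, 9)

expected = list(x.shape[:2])
for i, k, p, s, d in zip(x.shape[-3:], _triple(pool.kernel_size),
                         _triple(pool.padding), _triple(pool.stride),
                         _triple(pool.dilation)):
    o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
    expected.append(math.ceil(o) if pool.ceil_mode else math.floor(o))

assert list(pool(x).shape) == expected  # [2, 8, 3, 5, 5]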