Update to use pre-assigned norm layer in ConvModule
sungchul2 committed Aug 13, 2024
1 parent e53ac57 commit 49da71e
Showing 49 changed files with 471 additions and 415 deletions.
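The change applied across every file below follows one pattern: call sites stop handing ConvModule a norm-layer factory (`norm_callable`) and instead pre-build the layer with `build_norm_layer`, passing the result through the new `normalization` argument with `num_features` bound to the conv's output channels. A minimal before/after sketch of the pattern — channel numbers are illustrative, and the exact keyword set is inferred from the call sites in this diff, not from the ConvModule source:

from functools import partial

from torch import nn

from otx.algo.modules.conv_module import Conv2dModule
from otx.algo.modules.norm import build_norm_layer

# Before: ConvModule received a factory and instantiated the norm layer itself.
conv = Conv2dModule(
    32,
    64,
    3,
    padding=1,
    norm_callable=partial(nn.BatchNorm2d, momentum=0.03, eps=0.001),
    activation_callable=nn.SiLU,
)

# After: the caller pre-builds the layer, making num_features (= out_channels)
# explicit at the call site rather than resolved inside ConvModule.
conv = Conv2dModule(
    32,
    64,
    3,
    padding=1,
    normalization=build_norm_layer(
        partial(nn.BatchNorm2d, momentum=0.03, eps=0.001),
        num_features=64,
    ),
    activation_callable=nn.SiLU,
)

Composite modules that build their own norm layers internally (e.g. CSPLayer, the SPP block in CSPNeXt, SwiGLUFFN) keep receiving the factory, renamed from `norm_callable` to `normalization_callable`.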
43 changes: 24 additions & 19 deletions src/otx/algo/action_classification/backbones/x3d.py
@@ -16,6 +16,7 @@

from otx.algo.modules.activation import Swish
from otx.algo.modules.conv_module import Conv3dModule
+ from otx.algo.modules.norm import build_norm_layer
from otx.algo.utils.mmengine_utils import load_checkpoint
from otx.algo.utils.weight_init import constant_init, kaiming_init

@@ -73,7 +74,7 @@ class BlockX3D(nn.Module):
unit. If set as None, it means not using SE unit. Default: None.
use_swish (bool): Whether to use swish as the activation function
before and after the 3x3x3 conv. Default: True.
- norm_callable (Callable[..., nn.Module] | None): Normalization layer module.
+ normalization_callable (Callable[..., nn.Module] | None): Normalization layer module.
Defaults to ``nn.BatchNorm3d``.
activation_callable (Callable[..., nn.Module] | None): Activation layer module.
Defaults to ``nn.ReLU``.
@@ -90,7 +91,7 @@ def __init__(
downsample: nn.Module | None = None,
se_ratio: float | None = None,
use_swish: bool = True,
- norm_callable: Callable[..., nn.Module] | None = nn.BatchNorm3d,
+ normalization_callable: Callable[..., nn.Module] | None = nn.BatchNorm3d,
activation_callable: Callable[..., nn.Module] | None = nn.ReLU,
with_cp: bool = False,
):
@@ -103,7 +104,7 @@ def __init__(
self.downsample = downsample
self.se_ratio = se_ratio
self.use_swish = use_swish
- self.norm_callable = norm_callable
+ self.normalization_callable = normalization_callable
self.activation_callable = activation_callable
self.with_cp = with_cp

@@ -114,7 +115,7 @@ def __init__(
stride=1,
padding=0,
bias=False,
- norm_callable=self.norm_callable,
+ normalization=build_norm_layer(normalization_callable, num_features=planes),
activation_callable=self.activation_callable,
)
# Here we use the channel-wise conv
@@ -126,7 +127,7 @@ def __init__(
padding=1,
groups=planes,
bias=False,
- norm_callable=self.norm_callable,
+ normalization=build_norm_layer(normalization_callable, num_features=planes),
activation_callable=None,
)

@@ -139,7 +140,7 @@ def __init__(
stride=1,
padding=0,
bias=False,
- norm_callable=self.norm_callable,
+ normalization=build_norm_layer(normalization_callable, num_features=outplanes),
activation_callable=None,
)

@@ -196,8 +197,8 @@ class X3DBackbone(nn.Module):
unit. If set as None, it means not using SE unit. Default: 1 / 16.
use_swish (bool): Whether to use swish as the activation function
before and after the 3x3x3 conv. Default: True.
- norm_callable (Callable[..., nn.Module] | None): Normalization layer module.
-     Defaults to ``partial(nn.BatchNorm3d, requires_grad=True)``.
+ normalization_callable (Callable[..., nn.Module] | None): Normalization layer module.
+     Defaults to ``partial(build_norm_layer, nn.BatchNorm3d, requires_grad=True)``.
activation_callable (Callable[..., nn.Module] | None): Activation layer module.
Defaults to ``nn.ReLU``.
norm_eval (bool): Whether to set BN layers to eval mode, namely, freeze
@@ -223,7 +224,11 @@ def __init__(
se_style: str = "half",
se_ratio: float = 1 / 16,
use_swish: bool = True,
- norm_callable: Callable[..., nn.Module] | None = partial(nn.BatchNorm3d, requires_grad=True),
+ normalization_callable: Callable[..., nn.Module] | None = partial(
+     build_norm_layer,
+     nn.BatchNorm3d,
+     requires_grad=True,
+ ),
activation_callable: Callable[..., nn.Module] | None = nn.ReLU,
norm_eval: bool = False,
with_cp: bool = False,
@@ -266,7 +271,7 @@ def __init__(
raise ValueError(msg)
self.use_swish = use_swish

- self.norm_callable = norm_callable
+ self.normalization_callable = normalization_callable
self.activation_callable = activation_callable
self.norm_eval = norm_eval
self.with_cp = with_cp
@@ -293,7 +298,7 @@ def __init__(
se_style=self.se_style,
se_ratio=self.se_ratio,
use_swish=self.use_swish,
- norm_callable=self.norm_callable,
+ normalization_callable=self.normalization_callable,
activation_callable=self.activation_callable,
with_cp=with_cp,
**kwargs,
@@ -311,7 +316,7 @@ def __init__(
stride=1,
padding=0,
bias=False,
- norm_callable=self.norm_callable,
+ normalization=build_norm_layer(self.normalization_callable, num_features=int(self.feat_dim * self.gamma_b)),
activation_callable=self.activation_callable,
)
self.feat_dim = int(self.feat_dim * self.gamma_b)
@@ -349,7 +354,7 @@ def make_res_layer(
se_style: str = "half",
se_ratio: float | None = None,
use_swish: bool = True,
- norm_callable: Callable[..., nn.Module] | None = None,
+ normalization_callable: Callable[..., nn.Module] | None = None,
activation_callable: Callable[..., nn.Module] | None = nn.ReLU,
with_cp: bool = False,
**kwargs,
@@ -375,7 +380,7 @@ def make_res_layer(
Default: None.
use_swish (bool): Whether to use swish as the activation function
before and after the 3x3x3 conv. Default: True.
- norm_callable (Callable[..., nn.Module] | None): Normalization layer module.
+ normalization_callable (Callable[..., nn.Module] | None): Normalization layer module.
Defaults to None.
activation_callable (Callable[..., nn.Module] | None): Activation layer module.
Defaults to ``nn.ReLU``.
@@ -395,7 +400,7 @@ def make_res_layer(
stride=(1, spatial_stride, spatial_stride),
padding=0,
bias=False,
- norm_callable=norm_callable,
+ normalization=build_norm_layer(normalization_callable, num_features=inplanes),
activation_callable=None,
)

@@ -417,7 +422,7 @@ def make_res_layer(
downsample=downsample,
se_ratio=se_ratio if use_se[0] else None,
use_swish=use_swish,
- norm_callable=norm_callable,
+ normalization_callable=normalization_callable,
activation_callable=activation_callable,
with_cp=with_cp,
**kwargs,
@@ -433,7 +438,7 @@ def make_res_layer(
spatial_stride=1,
se_ratio=se_ratio if use_se[i] else None,
use_swish=use_swish,
- norm_callable=norm_callable,
+ normalization_callable=normalization_callable,
activation_callable=activation_callable,
with_cp=with_cp,
**kwargs,
@@ -451,7 +456,7 @@ def _make_stem_layer(self) -> None:
stride=(1, 2, 2),
padding=(0, 1, 1),
bias=False,
- norm_callable=None,
+ normalization=None,
activation_callable=None,
)
self.conv1_t = Conv3dModule(
Expand All @@ -462,7 +467,7 @@ def _make_stem_layer(self) -> None:
padding=(2, 0, 0),
groups=self.base_channels,
bias=False,
- norm_callable=self.norm_callable,
+ normalization=build_norm_layer(self.normalization_callable, num_features=self.base_channels),
activation_callable=self.activation_callable,
)

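One subtlety worth spelling out: X3DBackbone's new default is `partial(build_norm_layer, nn.BatchNorm3d, requires_grad=True)`, yet the call sites above pass that same object back into `build_norm_layer(self.normalization_callable, num_features=...)`, so the helper evidently unwraps a partial of itself. A hypothetical sketch consistent with the call sites in this commit — the real implementation lives in `otx.algo.modules.norm` and may differ:

from functools import partial
from typing import Callable

from torch import nn


def build_norm_layer(
    normalization_callable: Callable[..., nn.Module],
    num_features: int,
    **kwargs,
) -> tuple[str, nn.Module]:
    """Sketch only: instantiate a norm layer and return a (name, module) pair."""
    # Unwrap nested defaults such as partial(build_norm_layer, nn.BatchNorm3d, ...).
    if isinstance(normalization_callable, partial) and normalization_callable.func is build_norm_layer:
        return normalization_callable(num_features=num_features, **kwargs)
    # Assumed here: requires_grad configures the layer's parameters, not its constructor.
    requires_grad = kwargs.pop("requires_grad", True)
    layer = normalization_callable(num_features, **kwargs)
    for param in layer.parameters():
        param.requires_grad = requires_grad
    return layer.__class__.__name__.lower(), layer

The `(name, module)` pair return is grounded in this diff: `swiglu_ffn.py` unpacks `_, self.norm = build_norm_layer(...)` and `pytorchcv_backbones.py` indexes `[1]`, which implies the `normalization=` parameter of the conv modules accepts the pair as-is.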
3 changes: 2 additions & 1 deletion src/otx/algo/action_classification/x3d.py
@@ -13,6 +13,7 @@
from otx.algo.action_classification.backbones.x3d import X3DBackbone
from otx.algo.action_classification.heads.x3d_head import X3DHead
from otx.algo.action_classification.recognizers.recognizer import BaseRecognizer
+ from otx.algo.modules.norm import build_norm_layer
from otx.algo.utils.mmengine_utils import load_checkpoint
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.metrics.accuracy import MultiClassClsMetricCallable
@@ -65,7 +66,7 @@ def _build_model(self, num_classes: int) -> nn.Module:
gamma_b=2.25,
gamma_d=2.2,
gamma_w=1,
- norm_callable=partial(nn.BatchNorm3d, requires_grad=True),
+ normalization_callable=partial(build_norm_layer, nn.BatchNorm3d, requires_grad=True),
activation_callable=partial(nn.ReLU, inplace=True),
),
cls_head=X3DHead(
10 changes: 5 additions & 5 deletions src/otx/algo/classification/backbones/efficientnet.py
@@ -6,7 +6,6 @@
from __future__ import annotations

import math
- from functools import partial
from pathlib import Path
from typing import Callable, Literal

@@ -17,6 +16,7 @@

from otx.algo.modules.activation import Swish
from otx.algo.modules.conv_module import Conv2dModule
+ from otx.algo.modules.norm import build_norm_layer
from otx.algo.utils.mmengine_utils import load_checkpoint_to_model

PRETRAINED_ROOT = "https://github.com/osmr/imgclsmob/releases/download/v0.0.364/"
@@ -45,7 +45,7 @@ def conv1x1_block(
padding=padding,
groups=groups,
bias=bias,
- norm_callable=partial(nn.BatchNorm2d, eps=bn_eps) if use_bn else None,
+ normalization=build_norm_layer(nn.BatchNorm2d, num_features=out_channels, eps=bn_eps) if use_bn else None,
activation_callable=activation_callable,
)

@@ -72,7 +72,7 @@ def conv3x3_block(
dilation=dilation,
groups=groups,
bias=bias,
- norm_callable=partial(nn.BatchNorm2d, eps=bn_eps) if use_bn else None,
+ normalization=build_norm_layer(nn.BatchNorm2d, num_features=out_channels, eps=bn_eps) if use_bn else None,
activation_callable=activation_callable,
)

@@ -98,7 +98,7 @@ def dwconv3x3_block(
dilation=dilation,
groups=out_channels,
bias=bias,
- norm_callable=partial(nn.BatchNorm2d, eps=bn_eps) if use_bn else None,
+ normalization=build_norm_layer(nn.BatchNorm2d, num_features=out_channels, eps=bn_eps) if use_bn else None,
activation_callable=activation_callable,
)

@@ -124,7 +124,7 @@ def dwconv5x5_block(
dilation=dilation,
groups=out_channels,
bias=bias,
- norm_callable=partial(nn.BatchNorm2d, eps=bn_eps) if use_bn else None,
+ normalization=build_norm_layer(nn.BatchNorm2d, num_features=out_channels, eps=bn_eps) if use_bn else None,
activation_callable=activation_callable,
)

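The practical effect in these helpers: `partial(nn.BatchNorm2d, eps=bn_eps)` deferred the channel count to ConvModule, whereas extra keyword arguments such as `eps` now flow through `build_norm_layer` with the channel count bound immediately. Spelled out with illustrative values:

from functools import partial

from torch import nn

from otx.algo.modules.norm import build_norm_layer

# Old style: a factory; num_features supplied later inside ConvModule.
factory = partial(nn.BatchNorm2d, eps=1e-5)

# New style: channels bound at the call site, eps forwarded to the constructor.
norm = build_norm_layer(nn.BatchNorm2d, num_features=64, eps=1e-5)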
6 changes: 3 additions & 3 deletions src/otx/algo/classification/utils/swiglu_ffn.py
@@ -28,7 +28,7 @@ def __init__(
out_dims: int | None = None,
bias: bool = True,
dropout_layer: dict | None = None,
- norm_callable: Callable[..., nn.Module] | None = None,
+ normalization_callable: Callable[..., nn.Module] | None = None,
add_identity: bool = True,
) -> None:
super().__init__()
@@ -38,8 +38,8 @@ def __init__(

self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias)

- if norm_callable is not None:
-     _, self.norm = build_norm_layer(norm_callable, hidden_dims)
+ if normalization_callable is not None:
+     _, self.norm = build_norm_layer(normalization_callable, hidden_dims)
else:
self.norm = nn.Identity()

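Contrast with the conv modules: SwiGLUFFN keeps accepting the factory rather than a pre-built layer, since `hidden_dims` is computed inside `__init__`. A usage sketch with illustrative values:

from torch import nn

# Hypothetical dimensions; the norm layer is built internally via
# build_norm_layer(normalization_callable, hidden_dims).
ffn = SwiGLUFFN(
    embed_dims=384,
    normalization_callable=nn.LayerNorm,
)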
26 changes: 18 additions & 8 deletions src/otx/algo/common/backbones/cspnext.py
@@ -16,6 +16,7 @@
from otx.algo.detection.layers import CSPLayer
from otx.algo.modules.base_module import BaseModule
from otx.algo.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule
+ from otx.algo.modules.norm import build_norm_layer
from torch import Tensor, nn
from torch.nn.modules.batchnorm import _BatchNorm

@@ -44,7 +45,7 @@ class CSPNeXt(BaseModule):
layers. Defaults to (5, 9, 13).
channel_attention (bool): Whether to add channel attention in each
stage. Defaults to True.
- norm_callable (Callable[..., nn.Module]): Normalization layer module.
+ normalization_callable (Callable[..., nn.Module]): Normalization layer module.
Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``.
activation_callable (Callable[..., nn.Module] | None): Activation layer module.
Defaults to ``nn.SiLU``.
@@ -84,7 +85,7 @@ def __init__(
arch_ovewrite: dict | None = None,
spp_kernel_sizes: tuple[int, int, int] = (5, 9, 13),
channel_attention: bool = True,
- norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001),
+ normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001),
activation_callable: Callable[..., nn.Module] = nn.SiLU,
norm_eval: bool = False,
init_cfg: dict | None = None,
@@ -123,7 +124,7 @@ def __init__(
3,
padding=1,
stride=2,
- norm_callable=norm_callable,
+ normalization=build_norm_layer(
+     normalization_callable,
+     num_features=int(arch_setting[0][0] * widen_factor // 2),
+ ),
activation_callable=activation_callable,
),
Conv2dModule(
Expand All @@ -132,7 +136,10 @@ def __init__(
3,
padding=1,
stride=1,
- norm_callable=norm_callable,
+ normalization=build_norm_layer(
+     normalization_callable,
+     num_features=int(arch_setting[0][0] * widen_factor // 2),
+ ),
activation_callable=activation_callable,
),
Conv2dModule(
Expand All @@ -141,7 +148,10 @@ def __init__(
3,
padding=1,
stride=1,
- norm_callable=norm_callable,
+ normalization=build_norm_layer(
+     normalization_callable,
+     num_features=int(arch_setting[0][0] * widen_factor),
+ ),
activation_callable=activation_callable,
),
)
@@ -158,7 +168,7 @@ def __init__(
3,
stride=2,
padding=1,
- norm_callable=norm_callable,
+ normalization=build_norm_layer(normalization_callable, num_features=out_channels),
activation_callable=activation_callable,
)
stage.append(conv_layer)
@@ -167,7 +177,7 @@ def __init__(
out_channels,
out_channels,
kernel_sizes=spp_kernel_sizes,
- norm_callable=norm_callable,
+ normalization_callable=normalization_callable,
activation_callable=activation_callable,
)
stage.append(spp)
@@ -180,7 +190,7 @@ def __init__(
use_cspnext_block=True,
expand_ratio=expand_ratio,
channel_attention=channel_attention,
- norm_callable=norm_callable,
+ normalization_callable=normalization_callable,
activation_callable=activation_callable,
)
stage.append(csp_layer)
12 changes: 6 additions & 6 deletions src/otx/algo/common/backbones/pytorchcv_backbones.py
@@ -29,13 +29,13 @@ def replace_activation(model: nn.Module, activation_callable: Callable[..., nn.Module]) -> nn.Module:
return model


- def replace_norm(model: nn.Module, norm_callable: Callable[..., nn.Module]) -> nn.Module:
+ def replace_norm(model: nn.Module, normalization_callable: Callable[..., nn.Module]) -> nn.Module:
"""Replace norm funtion."""
for name, module in model._modules.items():
if len(list(module.children())) > 0:
- model._modules[name] = replace_norm(module, norm_callable)
+ model._modules[name] = replace_norm(module, normalization_callable)
if "bn" in name:
- model._modules[name] = build_norm_layer(norm_callable, num_features=module.num_features)[1]
+ model._modules[name] = build_norm_layer(normalization_callable, num_features=module.num_features)[1]
return model


@@ -120,7 +120,7 @@ def _build_pytorchcv_model(
norm_eval: bool = False,
verbose: bool = False,
activation_callable: Callable[..., nn.Module] | None = None,
- norm_callable: Callable[..., nn.Module] | None = None,
+ normalization_callable: Callable[..., nn.Module] | None = None,
**kwargs,
) -> nn.Module:
"""Build pytorchcv model."""
@@ -132,8 +132,8 @@ def _build_pytorchcv_model(
model = _models[type](**kwargs)
if activation_callable:
model = replace_activation(model, activation_callable)
- if norm_callable:
-     model = replace_norm(model, norm_callable)
+ if normalization_callable:
+     model = replace_norm(model, normalization_callable)
model.out_indices = out_indices
model.frozen_stages = frozen_stages
model.norm_eval = norm_eval
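A usage sketch for the renamed plumbing: `replace_norm` walks the module tree and rebuilds every `"bn"` child with its original `num_features`. Argument values below are illustrative, and parameters outside the visible signature fragment (`type`, `out_indices`, `frozen_stages`) are assumptions based on the function body:

from functools import partial

from torch import nn

# Hypothetical call: swap every BatchNorm in a pytorchcv backbone for one
# with non-default momentum; replace_norm rebuilds each "bn" submodule via
# build_norm_layer(normalization_callable, num_features=module.num_features).
model = _build_pytorchcv_model(
    "efficientnet_b0",
    out_indices=(2, 3, 4),
    frozen_stages=0,
    normalization_callable=partial(nn.BatchNorm2d, momentum=0.01),
)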
(The diffs for the remaining 43 changed files are not shown here.)
