From 49da71eef966cfb242d2042ecfa26dbb01f770a4 Mon Sep 17 00:00:00 2001 From: "Kim, Sungchul" Date: Tue, 13 Aug 2024 16:16:04 +0900 Subject: [PATCH] Update to use pre-assigned norm layer in `ConvModule` --- .../action_classification/backbones/x3d.py | 43 +++-- src/otx/algo/action_classification/x3d.py | 3 +- .../classification/backbones/efficientnet.py | 10 +- .../algo/classification/utils/swiglu_ffn.py | 6 +- src/otx/algo/common/backbones/cspnext.py | 26 ++- .../common/backbones/pytorchcv_backbones.py | 12 +- src/otx/algo/common/backbones/resnet.py | 29 ++- src/otx/algo/common/backbones/resnext.py | 8 +- src/otx/algo/common/layers/res_layer.py | 16 +- src/otx/algo/common/layers/spp_layer.py | 9 +- .../algo/detection/backbones/csp_darknet.py | 19 +- src/otx/algo/detection/backbones/presnet.py | 62 +++---- src/otx/algo/detection/heads/atss_head.py | 18 +- src/otx/algo/detection/heads/rtmdet_head.py | 16 +- src/otx/algo/detection/heads/yolox_head.py | 11 +- src/otx/algo/detection/layers/csp_layer.py | 61 +++---- src/otx/algo/detection/necks/cspnext_pafpn.py | 15 +- src/otx/algo/detection/necks/fpn.py | 13 +- .../algo/detection/necks/hybrid_encoder.py | 20 +- src/otx/algo/detection/necks/yolox_pafpn.py | 15 +- src/otx/algo/detection/rtdetr.py | 9 +- src/otx/algo/detection/rtmdet.py | 6 +- .../instance_segmentation/backbones/swin.py | 26 +-- .../heads/convfc_bbox_head.py | 4 +- .../heads/fcn_mask_head.py | 7 +- .../heads/rtmdet_ins_head.py | 26 +-- .../layers/transformer.py | 16 +- .../algo/instance_segmentation/maskrcnn.py | 5 +- .../algo/instance_segmentation/necks/fpn.py | 11 +- .../algo/instance_segmentation/rtmdet_inst.py | 6 +- src/otx/algo/keypoint_detection/rtmpose.py | 2 +- src/otx/algo/modules/conv_module.py | 6 +- src/otx/algo/modules/norm.py | 5 +- src/otx/algo/modules/transformer.py | 8 +- .../algo/segmentation/backbones/litehrnet.py | 172 ++++++++++-------- src/otx/algo/segmentation/backbones/mscan.py | 40 ++-- src/otx/algo/segmentation/dino_v2_seg.py | 3 +- .../algo/segmentation/heads/base_segm_head.py | 6 +- src/otx/algo/segmentation/heads/fcn_head.py | 15 +- src/otx/algo/segmentation/heads/ham_head.py | 13 +- src/otx/algo/segmentation/litehrnet.py | 11 +- .../algo/segmentation/modules/aggregators.py | 11 +- src/otx/algo/segmentation/modules/blocks.py | 21 ++- src/otx/algo/segmentation/segnext.py | 13 +- .../algo/detection/backbones/test_presnet.py | 8 +- tests/unit/algo/modules/test_conv_module.py | 3 +- tests/unit/algo/modules/test_transformer.py | 10 +- .../algo/segmentation/heads/test_ham_head.py | 3 +- .../algo/segmentation/modules/test_blokcs.py | 8 +- 49 files changed, 471 insertions(+), 415 deletions(-) diff --git a/src/otx/algo/action_classification/backbones/x3d.py b/src/otx/algo/action_classification/backbones/x3d.py index 12afa2455c3..ff16a225be2 100644 --- a/src/otx/algo/action_classification/backbones/x3d.py +++ b/src/otx/algo/action_classification/backbones/x3d.py @@ -16,6 +16,7 @@ from otx.algo.modules.activation import Swish from otx.algo.modules.conv_module import Conv3dModule +from otx.algo.modules.norm import build_norm_layer from otx.algo.utils.mmengine_utils import load_checkpoint from otx.algo.utils.weight_init import constant_init, kaiming_init @@ -73,7 +74,7 @@ class BlockX3D(nn.Module): unit. If set as None, it means not using SE unit. Default: None. use_swish (bool): Whether to use swish as the activation function before and after the 3x3x3 conv. Default: True. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. 
+ normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to ``nn.BatchNorm3d``. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to ``nn.ReLU``. @@ -90,7 +91,7 @@ def __init__( downsample: nn.Module | None = None, se_ratio: float | None = None, use_swish: bool = True, - norm_callable: Callable[..., nn.Module] | None = nn.BatchNorm3d, + normalization_callable: Callable[..., nn.Module] | None = nn.BatchNorm3d, activation_callable: Callable[..., nn.Module] | None = nn.ReLU, with_cp: bool = False, ): @@ -103,7 +104,7 @@ def __init__( self.downsample = downsample self.se_ratio = se_ratio self.use_swish = use_swish - self.norm_callable = norm_callable + self.normalization_callable = normalization_callable self.activation_callable = activation_callable self.with_cp = with_cp @@ -114,7 +115,7 @@ def __init__( stride=1, padding=0, bias=False, - norm_callable=self.norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=planes), activation_callable=self.activation_callable, ) # Here we use the channel-wise conv @@ -126,7 +127,7 @@ def __init__( padding=1, groups=planes, bias=False, - norm_callable=self.norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=planes), activation_callable=None, ) @@ -139,7 +140,7 @@ def __init__( stride=1, padding=0, bias=False, - norm_callable=self.norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=outplanes), activation_callable=None, ) @@ -196,8 +197,8 @@ class X3DBackbone(nn.Module): unit. If set as None, it means not using SE unit. Default: 1 / 16. use_swish (bool): Whether to use swish as the activation function before and after the 3x3x3 conv. Default: True. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. - Defaults to ``partial(nn.BatchNorm3d, requires_grad=True)``. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. + Defaults to ``partial(build_norm_layer, nn.BatchNorm3d, requires_grad=True)``. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to ``nn.ReLU``. 
norm_eval (bool): Whether to set BN layers to eval mode, namely, freeze @@ -223,7 +224,11 @@ def __init__( se_style: str = "half", se_ratio: float = 1 / 16, use_swish: bool = True, - norm_callable: Callable[..., nn.Module] | None = partial(nn.BatchNorm3d, requires_grad=True), + normalization_callable: Callable[..., nn.Module] | None = partial( + build_norm_layer, + nn.BatchNorm3d, + requires_grad=True, + ), activation_callable: Callable[..., nn.Module] | None = nn.ReLU, norm_eval: bool = False, with_cp: bool = False, @@ -266,7 +271,7 @@ def __init__( raise ValueError(msg) self.use_swish = use_swish - self.norm_callable = norm_callable + self.normalization_callable = normalization_callable self.activation_callable = activation_callable self.norm_eval = norm_eval self.with_cp = with_cp @@ -293,7 +298,7 @@ def __init__( se_style=self.se_style, se_ratio=self.se_ratio, use_swish=self.use_swish, - norm_callable=self.norm_callable, + normalization_callable=self.normalization_callable, activation_callable=self.activation_callable, with_cp=with_cp, **kwargs, @@ -311,7 +316,7 @@ def __init__( stride=1, padding=0, bias=False, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=int(self.feat_dim * self.gamma_b)), activation_callable=self.activation_callable, ) self.feat_dim = int(self.feat_dim * self.gamma_b) @@ -349,7 +354,7 @@ def make_res_layer( se_style: str = "half", se_ratio: float | None = None, use_swish: bool = True, - norm_callable: Callable[..., nn.Module] | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, activation_callable: Callable[..., nn.Module] | None = nn.ReLU, with_cp: bool = False, **kwargs, @@ -375,7 +380,7 @@ def make_res_layer( Default: None. use_swish (bool): Whether to use swish as the activation function before and after the 3x3x3 conv. Default: True. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to None. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to ``nn.ReLU``. 
@@ -395,7 +400,7 @@ def make_res_layer( stride=(1, spatial_stride, spatial_stride), padding=0, bias=False, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=inplanes), activation_callable=None, ) @@ -417,7 +422,7 @@ def make_res_layer( downsample=downsample, se_ratio=se_ratio if use_se[0] else None, use_swish=use_swish, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, with_cp=with_cp, **kwargs, @@ -433,7 +438,7 @@ def make_res_layer( spatial_stride=1, se_ratio=se_ratio if use_se[i] else None, use_swish=use_swish, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, with_cp=with_cp, **kwargs, @@ -451,7 +456,7 @@ def _make_stem_layer(self) -> None: stride=(1, 2, 2), padding=(0, 1, 1), bias=False, - norm_callable=None, + normalization=None, activation_callable=None, ) self.conv1_t = Conv3dModule( @@ -462,7 +467,7 @@ def _make_stem_layer(self) -> None: padding=(2, 0, 0), groups=self.base_channels, bias=False, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=self.base_channels), activation_callable=self.activation_callable, ) diff --git a/src/otx/algo/action_classification/x3d.py b/src/otx/algo/action_classification/x3d.py index e5e8e9d12e1..ef8e9a6e714 100644 --- a/src/otx/algo/action_classification/x3d.py +++ b/src/otx/algo/action_classification/x3d.py @@ -13,6 +13,7 @@ from otx.algo.action_classification.backbones.x3d import X3DBackbone from otx.algo.action_classification.heads.x3d_head import X3DHead from otx.algo.action_classification.recognizers.recognizer import BaseRecognizer +from otx.algo.modules.norm import build_norm_layer from otx.algo.utils.mmengine_utils import load_checkpoint from otx.algo.utils.support_otx_v1 import OTXv1Helper from otx.core.metrics.accuracy import MultiClassClsMetricCallable @@ -65,7 +66,7 @@ def _build_model(self, num_classes: int) -> nn.Module: gamma_b=2.25, gamma_d=2.2, gamma_w=1, - norm_callable=partial(nn.BatchNorm3d, requires_grad=True), + normalization_callable=partial(build_norm_layer, nn.BatchNorm3d, requires_grad=True), activation_callable=partial(nn.ReLU, inplace=True), ), cls_head=X3DHead( diff --git a/src/otx/algo/classification/backbones/efficientnet.py b/src/otx/algo/classification/backbones/efficientnet.py index 2cb8b6fed1e..61f22231cef 100644 --- a/src/otx/algo/classification/backbones/efficientnet.py +++ b/src/otx/algo/classification/backbones/efficientnet.py @@ -6,7 +6,6 @@ from __future__ import annotations import math -from functools import partial from pathlib import Path from typing import Callable, Literal @@ -17,6 +16,7 @@ from otx.algo.modules.activation import Swish from otx.algo.modules.conv_module import Conv2dModule +from otx.algo.modules.norm import build_norm_layer from otx.algo.utils.mmengine_utils import load_checkpoint_to_model PRETRAINED_ROOT = "https://github.com/osmr/imgclsmob/releases/download/v0.0.364/" @@ -45,7 +45,7 @@ def conv1x1_block( padding=padding, groups=groups, bias=bias, - norm_callable=partial(nn.BatchNorm2d, eps=bn_eps) if use_bn else None, + normalization=build_norm_layer(nn.BatchNorm2d, num_features=out_channels, eps=bn_eps) if use_bn else None, activation_callable=activation_callable, ) @@ -72,7 +72,7 @@ def conv3x3_block( dilation=dilation, groups=groups, bias=bias, - norm_callable=partial(nn.BatchNorm2d, eps=bn_eps) if use_bn else None, + 
normalization=build_norm_layer(nn.BatchNorm2d, num_features=out_channels, eps=bn_eps) if use_bn else None, activation_callable=activation_callable, ) @@ -98,7 +98,7 @@ def dwconv3x3_block( dilation=dilation, groups=out_channels, bias=bias, - norm_callable=partial(nn.BatchNorm2d, eps=bn_eps) if use_bn else None, + normalization=build_norm_layer(nn.BatchNorm2d, num_features=out_channels, eps=bn_eps) if use_bn else None, activation_callable=activation_callable, ) @@ -124,7 +124,7 @@ def dwconv5x5_block( dilation=dilation, groups=out_channels, bias=bias, - norm_callable=partial(nn.BatchNorm2d, eps=bn_eps) if use_bn else None, + normalization=build_norm_layer(nn.BatchNorm2d, num_features=out_channels, eps=bn_eps) if use_bn else None, activation_callable=activation_callable, ) diff --git a/src/otx/algo/classification/utils/swiglu_ffn.py b/src/otx/algo/classification/utils/swiglu_ffn.py index f205a72ee7e..fab2a37cd17 100644 --- a/src/otx/algo/classification/utils/swiglu_ffn.py +++ b/src/otx/algo/classification/utils/swiglu_ffn.py @@ -28,7 +28,7 @@ def __init__( out_dims: int | None = None, bias: bool = True, dropout_layer: dict | None = None, - norm_callable: Callable[..., nn.Module] | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, add_identity: bool = True, ) -> None: super().__init__() @@ -38,8 +38,8 @@ def __init__( self.w12 = nn.Linear(self.embed_dims, 2 * hidden_dims, bias=bias) - if norm_callable is not None: - _, self.norm = build_norm_layer(norm_callable, hidden_dims) + if normalization_callable is not None: + _, self.norm = build_norm_layer(normalization_callable, hidden_dims) else: self.norm = nn.Identity() diff --git a/src/otx/algo/common/backbones/cspnext.py b/src/otx/algo/common/backbones/cspnext.py index 1224faf50eb..2cae930ea66 100644 --- a/src/otx/algo/common/backbones/cspnext.py +++ b/src/otx/algo/common/backbones/cspnext.py @@ -16,6 +16,7 @@ from otx.algo.detection.layers import CSPLayer from otx.algo.modules.base_module import BaseModule from otx.algo.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule +from otx.algo.modules.norm import build_norm_layer from torch import Tensor, nn from torch.nn.modules.batchnorm import _BatchNorm @@ -44,7 +45,7 @@ class CSPNeXt(BaseModule): layers. Defaults to (5, 9, 13). channel_attention (bool): Whether to add channel attention in each stage. Defaults to True. - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to ``nn.SiLU``. 
@@ -84,7 +85,7 @@ def __init__( arch_ovewrite: dict | None = None, spp_kernel_sizes: tuple[int, int, int] = (5, 9, 13), channel_attention: bool = True, - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), + normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), activation_callable: Callable[..., nn.Module] = nn.SiLU, norm_eval: bool = False, init_cfg: dict | None = None, @@ -123,7 +124,10 @@ def __init__( 3, padding=1, stride=2, - norm_callable=norm_callable, + normalization=build_norm_layer( + normalization_callable, + num_features=int(arch_setting[0][0] * widen_factor // 2), + ), activation_callable=activation_callable, ), Conv2dModule( @@ -132,7 +136,10 @@ def __init__( 3, padding=1, stride=1, - norm_callable=norm_callable, + normalization=build_norm_layer( + normalization_callable, + num_features=int(arch_setting[0][0] * widen_factor // 2), + ), activation_callable=activation_callable, ), Conv2dModule( @@ -141,7 +148,10 @@ def __init__( 3, padding=1, stride=1, - norm_callable=norm_callable, + normalization=build_norm_layer( + normalization_callable, + num_features=int(arch_setting[0][0] * widen_factor), + ), activation_callable=activation_callable, ), ) @@ -158,7 +168,7 @@ def __init__( 3, stride=2, padding=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), activation_callable=activation_callable, ) stage.append(conv_layer) @@ -167,7 +177,7 @@ def __init__( out_channels, out_channels, kernel_sizes=spp_kernel_sizes, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, ) stage.append(spp) @@ -180,7 +190,7 @@ def __init__( use_cspnext_block=True, expand_ratio=expand_ratio, channel_attention=channel_attention, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, ) stage.append(csp_layer) diff --git a/src/otx/algo/common/backbones/pytorchcv_backbones.py b/src/otx/algo/common/backbones/pytorchcv_backbones.py index 1f0f6de3967..63d56175412 100644 --- a/src/otx/algo/common/backbones/pytorchcv_backbones.py +++ b/src/otx/algo/common/backbones/pytorchcv_backbones.py @@ -29,13 +29,13 @@ def replace_activation(model: nn.Module, activation_callable: Callable[..., nn.M return model -def replace_norm(model: nn.Module, norm_callable: Callable[..., nn.Module]) -> nn.Module: +def replace_norm(model: nn.Module, normalization_callable: Callable[..., nn.Module]) -> nn.Module: """Replace norm funtion.""" for name, module in model._modules.items(): if len(list(module.children())) > 0: - model._modules[name] = replace_norm(module, norm_callable) + model._modules[name] = replace_norm(module, normalization_callable) if "bn" in name: - model._modules[name] = build_norm_layer(norm_callable, num_features=module.num_features)[1] + model._modules[name] = build_norm_layer(normalization_callable, num_features=module.num_features)[1] return model @@ -120,7 +120,7 @@ def _build_pytorchcv_model( norm_eval: bool = False, verbose: bool = False, activation_callable: Callable[..., nn.Module] | None = None, - norm_callable: Callable[..., nn.Module] | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, **kwargs, ) -> nn.Module: """Build pytorchcv model.""" @@ -132,8 +132,8 @@ def _build_pytorchcv_model( model = _models[type](**kwargs) if activation_callable: model = replace_activation(model, activation_callable) - 
if norm_callable: - model = replace_norm(model, norm_callable) + if normalization_callable: + model = replace_norm(model, normalization_callable) model.out_indices = out_indices model.frozen_stages = frozen_stages model.norm_eval = norm_eval diff --git a/src/otx/algo/common/backbones/resnet.py b/src/otx/algo/common/backbones/resnet.py index 19951522e5e..984608af8a9 100644 --- a/src/otx/algo/common/backbones/resnet.py +++ b/src/otx/algo/common/backbones/resnet.py @@ -34,7 +34,7 @@ def __init__( self, inplanes: int, planes: int, - norm_callable: Callable[..., nn.Module], + normalization_callable: Callable[..., nn.Module], stride: int = 1, dilation: int = 1, downsample: nn.Module | None = None, @@ -48,14 +48,14 @@ def __init__( self.stride = stride self.dilation = dilation self.with_cp = with_cp - self.norm_callable = norm_callable + self.normalization_callable = normalization_callable self.conv1_stride = 1 self.conv2_stride = stride - self.norm1_name, norm1 = build_norm_layer(norm_callable, planes, postfix=1) - self.norm2_name, norm2 = build_norm_layer(norm_callable, planes, postfix=2) - self.norm3_name, norm3 = build_norm_layer(norm_callable, planes * self.expansion, postfix=3) + self.norm1_name, norm1 = build_norm_layer(normalization_callable, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(normalization_callable, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer(normalization_callable, planes * self.expansion, postfix=3) self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=self.conv1_stride, bias=False) self.add_module(self.norm1_name, norm1) @@ -141,7 +141,7 @@ class ResNet(BaseModule): downsampling in the bottleneck. frozen_stages (int): Stages to be frozen (stop grad and set eval mode). -1 means not freezing any parameters. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, requires_grad=True)``. norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). 
Note: Effect on Batch Norm @@ -180,7 +180,11 @@ def __init__( out_indices: tuple[int, int, int, int] = (0, 1, 2, 3), avg_down: bool = False, frozen_stages: int = -1, - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, requires_grad=True), + normalization_callable: Callable[..., nn.Module] = partial( + build_norm_layer, + nn.BatchNorm2d, + requires_grad=True, + ), norm_eval: bool = True, with_cp: bool = False, zero_init_residual: bool = True, @@ -233,7 +237,7 @@ def __init__( raise ValueError(msg) self.avg_down = avg_down self.frozen_stages = frozen_stages - self.norm_callable = norm_callable + self.normalization_callable = normalization_callable self.with_cp = with_cp self.norm_eval = norm_eval self.block, stage_blocks = self.arch_settings[depth] @@ -256,7 +260,7 @@ def __init__( dilation=dilation, avg_down=self.avg_down, with_cp=with_cp, - norm_callable=norm_callable, + normalization_callable=normalization_callable, init_cfg=block_init_cfg, ) self.inplanes = planes * self.block.expansion @@ -286,7 +290,12 @@ def _make_stem_layer(self, in_channels: int, stem_channels: int) -> None: padding=3, bias=False, ) - self.norm1_name, norm1 = build_norm_layer(self.norm_callable, stem_channels, postfix=1) + self.norm1_name, norm1 = build_norm_layer( + self.normalization_callable, + stem_channels, + postfix=1, + requires_grad=True, + ) self.add_module(self.norm1_name, norm1) self.relu = nn.ReLU(inplace=True) self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) diff --git a/src/otx/algo/common/backbones/resnext.py b/src/otx/algo/common/backbones/resnext.py index 8a7a1a5b374..f827d0508fa 100644 --- a/src/otx/algo/common/backbones/resnext.py +++ b/src/otx/algo/common/backbones/resnext.py @@ -42,9 +42,9 @@ def __init__( width = self.planes if groups == 1 else math.floor(self.planes * (base_width / base_channels)) * groups - self.norm1_name, norm1 = build_norm_layer(self.norm_callable, width, postfix=1) - self.norm2_name, norm2 = build_norm_layer(self.norm_callable, width, postfix=2) - self.norm3_name, norm3 = build_norm_layer(self.norm_callable, self.planes * self.expansion, postfix=3) + self.norm1_name, norm1 = build_norm_layer(self.normalization_callable, width, postfix=1) + self.norm2_name, norm2 = build_norm_layer(self.normalization_callable, width, postfix=2) + self.norm3_name, norm3 = build_norm_layer(self.normalization_callable, self.planes * self.expansion, postfix=3) self.conv1 = nn.Conv2d( self.inplanes, @@ -95,7 +95,7 @@ class ResNeXt(ResNet): the first 1x1 conv layer. frozen_stages (int): Stages to be frozen (all param fixed). -1 means not freezing any parameters. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, requires_grad=True)``. norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm diff --git a/src/otx/algo/common/layers/res_layer.py b/src/otx/algo/common/layers/res_layer.py index 154cad18150..0e204d86030 100644 --- a/src/otx/algo/common/layers/res_layer.py +++ b/src/otx/algo/common/layers/res_layer.py @@ -26,7 +26,7 @@ class ResLayer(Sequential): stride (int): stride of the first block. Defaults to 1 avg_down (bool): Use AvgPool instead of stride conv when downsampling in the bottleneck. Defaults to False - norm_callable (Callable[..., nn.Module]): Normalization layer module. 
+ normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``nn.BatchNorm2d``. downsample_first (bool): Downsample at the first block or last block. False for Hourglass, True for ResNet. Defaults to True @@ -38,7 +38,7 @@ def __init__( inplanes: int, planes: int, num_blocks: int, - norm_callable: Callable[..., nn.Module], + normalization_callable: Callable[..., nn.Module], stride: int = 1, avg_down: bool = False, downsample_first: bool = True, @@ -69,7 +69,7 @@ def __init__( stride=conv_stride, bias=False, ), - build_norm_layer(norm_callable, planes * block.expansion)[1], + build_norm_layer(normalization_callable, planes * block.expansion)[1], ], ) downsample = nn.Sequential(*downsample) @@ -82,14 +82,20 @@ def __init__( planes=planes, stride=stride, downsample=downsample, - norm_callable=norm_callable, + normalization_callable=normalization_callable, **kwargs, ), ) inplanes = planes * block.expansion layers.extend( [ - block(inplanes=inplanes, planes=planes, stride=1, norm_callable=norm_callable, **kwargs) + block( + inplanes=inplanes, + planes=planes, + stride=1, + normalization_callable=normalization_callable, + **kwargs, + ) for _ in range(1, num_blocks) ], ) diff --git a/src/otx/algo/common/layers/spp_layer.py b/src/otx/algo/common/layers/spp_layer.py index cb580ebdf64..dbd0799cfda 100644 --- a/src/otx/algo/common/layers/spp_layer.py +++ b/src/otx/algo/common/layers/spp_layer.py @@ -15,6 +15,7 @@ from otx.algo.modules.activation import Swish from otx.algo.modules.base_module import BaseModule from otx.algo.modules.conv_module import Conv2dModule +from otx.algo.modules.norm import build_norm_layer from torch import Tensor, nn @@ -26,7 +27,7 @@ class SPPBottleneck(BaseModule): out_channels (int): The output channels of this Module. kernel_sizes (tuple[int]): Sequential of kernel sizes of pooling layers. Default: (5, 9, 13). - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to ``Swish``. @@ -39,7 +40,7 @@ def __init__( in_channels: int, out_channels: int, kernel_sizes: tuple[int, ...] 
= (5, 9, 13), - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), + normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), activation_callable: Callable[..., nn.Module] | None = Swish, init_cfg: dict | list[dict] | None = None, ): @@ -50,7 +51,7 @@ def __init__( mid_channels, 1, stride=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=mid_channels), activation_callable=activation_callable, ) self.poolings = nn.ModuleList([nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2) for ks in kernel_sizes]) @@ -59,7 +60,7 @@ def __init__( conv2_channels, out_channels, 1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), activation_callable=activation_callable, ) diff --git a/src/otx/algo/detection/backbones/csp_darknet.py b/src/otx/algo/detection/backbones/csp_darknet.py index a0c4c2d3b66..c4000599505 100644 --- a/src/otx/algo/detection/backbones/csp_darknet.py +++ b/src/otx/algo/detection/backbones/csp_darknet.py @@ -21,6 +21,7 @@ from otx.algo.modules.activation import Swish from otx.algo.modules.base_module import BaseModule from otx.algo.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule +from otx.algo.modules.norm import build_norm_layer class Focus(nn.Module): @@ -31,7 +32,7 @@ class Focus(nn.Module): out_channels (int): The output channels of this Module. kernel_size (int): The kernel size of the convolution. Default: 1 stride (int): The stride of the convolution. Default: 1 - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to ``Swish``. @@ -43,7 +44,7 @@ def __init__( out_channels: int, kernel_size: int = 1, stride: int = 1, - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), + normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), activation_callable: Callable[..., nn.Module] | None = Swish, ): super().__init__() @@ -53,7 +54,7 @@ def __init__( kernel_size, stride, padding=(kernel_size - 1) // 2, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), activation_callable=activation_callable, ) @@ -108,7 +109,7 @@ class CSPDarknet(BaseModule): arch_ovewrite(list): Overwrite default arch settings. Default: None. spp_kernal_sizes: (tuple[int]): Sequential of kernel sizes of SPP layers. Default: (5, 9, 13). - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to ``Swish``. @@ -147,7 +148,7 @@ def __init__( use_depthwise: bool = False, arch_ovewrite: list | None = None, spp_kernal_sizes: tuple[int, ...] 
= (5, 9, 13), - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), + normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), activation_callable: Callable[..., nn.Module] = Swish, norm_eval: bool = False, init_cfg: dict | list[dict] | None = None, @@ -180,7 +181,7 @@ def __init__( 3, int(arch_setting[0][0] * widen_factor), kernel_size=3, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, ) self.layers = ["stem"] @@ -196,7 +197,7 @@ def __init__( 3, stride=2, padding=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), activation_callable=activation_callable, ) stage.append(conv_layer) @@ -205,7 +206,7 @@ def __init__( out_channels, out_channels, kernel_sizes=spp_kernal_sizes, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, ) stage.append(spp) @@ -215,7 +216,7 @@ def __init__( num_blocks=num_blocks, add_identity=add_identity, use_depthwise=use_depthwise, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, ) stage.append(csp_layer) diff --git a/src/otx/algo/detection/backbones/presnet.py b/src/otx/algo/detection/backbones/presnet.py index e97e39badff..1c97be6c457 100644 --- a/src/otx/algo/detection/backbones/presnet.py +++ b/src/otx/algo/detection/backbones/presnet.py @@ -13,6 +13,7 @@ from otx.algo.modules.base_module import BaseModule from otx.algo.modules.conv_module import Conv2dModule +from otx.algo.modules.norm import build_norm_layer __all__ = ["PResNet"] @@ -30,8 +31,7 @@ def __init__( shortcut: bool, activation_callable: Callable[..., nn.Module] | None = None, variant: str = "b", - norm_callable: Callable[..., nn.Module] | None = None, - norm_name: str | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, ) -> None: super().__init__() @@ -51,8 +51,7 @@ def __init__( 1, 1, activation_callable=None, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=ch_out), ), ), ], @@ -65,8 +64,7 @@ def __init__( 1, stride, activation_callable=None, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=ch_out), ) self.branch2a = Conv2dModule( @@ -76,8 +74,7 @@ def __init__( stride, padding=1, activation_callable=activation_callable, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=ch_out), ) self.branch2b = Conv2dModule( ch_out, @@ -86,8 +83,7 @@ def __init__( 1, padding=1, activation_callable=None, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=ch_out), ) self.act = activation_callable() if activation_callable else nn.Identity() @@ -115,8 +111,7 @@ def __init__( shortcut: bool, activation_callable: Callable[..., nn.Module] | None = None, variant: str = "b", - norm_callable: Callable[..., nn.Module] | None = None, - norm_name: str | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, ) -> None: super().__init__() @@ -133,8 +128,7 @@ def __init__( 1, stride1, activation_callable=activation_callable, - norm_callable=norm_callable, - norm_name=norm_name, + 
normalization=build_norm_layer(normalization_callable, num_features=width), ) self.branch2b = Conv2dModule( width, @@ -143,8 +137,7 @@ def __init__( stride2, padding=1, activation_callable=activation_callable, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=width), ) self.branch2c = Conv2dModule( width, @@ -152,8 +145,10 @@ def __init__( 1, 1, activation_callable=None, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer( + normalization_callable, + num_features=ch_out * self.expansion, + ), ) self.shortcut = shortcut @@ -171,8 +166,10 @@ def __init__( 1, 1, activation_callable=None, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer( + normalization_callable, + num_features=ch_out * self.expansion, + ), ), ), ], @@ -185,8 +182,10 @@ def __init__( 1, stride, activation_callable=None, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer( + normalization_callable, + num_features=ch_out * self.expansion, + ), ) self.act = activation_callable() if activation_callable else nn.Identity() @@ -213,8 +212,7 @@ def __init__( stage_num: int, activation_callable: Callable[..., nn.Module] | None = None, variant: str = "b", - norm_callable: Callable[..., nn.Module] | None = None, - norm_name: str | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, ) -> None: super().__init__() @@ -228,8 +226,7 @@ def __init__( shortcut=i != 0, variant=variant, activation_callable=activation_callable, - norm_callable=norm_callable, - norm_name=norm_name, + normalization_callable=normalization_callable, ), ) @@ -254,10 +251,8 @@ class PResNet(BaseModule): return_idx (list[int]): The indices of the stages to return as output. Defaults to [0, 1, 2, 3]. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to None. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to ``nn.BatchNorm2d``. - norm_name (str | None): The name of the normalization layer fpr ``build_norm_layer``. - Defaults to 'norm'. freeze_at (int): The stage at which to freeze the parameters. Defaults to -1. pretrained (bool): Whether to load pretrained weights. Defaults to False. 
""" @@ -283,8 +278,7 @@ def __init__( num_stages: int = 4, return_idx: list[int] = [0, 1, 2, 3], # noqa: B006 activation_callable: Callable[..., nn.Module] | None = nn.ReLU, - norm_callable: Callable[..., nn.Module] = nn.BatchNorm2d, - norm_name: str = "norm", + normalization_callable: Callable[..., nn.Module] = nn.BatchNorm2d, freeze_at: int = -1, pretrained: bool = False, ) -> None: @@ -314,8 +308,7 @@ def __init__( s, padding=(k - 1) // 2, activation_callable=activation_callable, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=c_out), ), ) for c_in, c_out, k, s, _name in conv_def @@ -341,8 +334,7 @@ def __init__( stage_num, activation_callable=activation_callable, variant=variant, - norm_callable=norm_callable, - norm_name=norm_name, + normalization_callable=normalization_callable, ), ) ch_in = _out_channels[i] diff --git a/src/otx/algo/detection/heads/atss_head.py b/src/otx/algo/detection/heads/atss_head.py index fd8748dc463..f9549c605a1 100644 --- a/src/otx/algo/detection/heads/atss_head.py +++ b/src/otx/algo/detection/heads/atss_head.py @@ -24,6 +24,7 @@ from otx.algo.detection.utils.prior_generators.utils import anchor_inside_flags from otx.algo.detection.utils.utils import unmap from otx.algo.modules.conv_module import Conv2dModule +from otx.algo.modules.norm import build_norm_layer from otx.algo.modules.scale import Scale from otx.algo.utils.mmengine_utils import InstanceData @@ -42,8 +43,8 @@ class ATSSHead(ClassIncrementalMixin, AnchorHead): in_channels (int): Number of channels in the input feature map. pred_kernel_size (int): Kernel size of ``nn.Conv2d``. Defaults to 3. stacked_convs (int): Number of stacking convs of the head. Defaults to 4. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. - Defaults to ``partial(nn.GroupNorm, num_groups=32, requires_grad=True)``. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. + Defaults to ``partial(build_norm_layer, nn.GroupNorm, num_groups=32, requires_grad=True)``. 
reg_decoded_bbox (bool): If true, the regression loss would be applied directly on decoded bounding boxes, converting both the predicted boxes and regression targets to absolute @@ -59,7 +60,12 @@ def __init__( in_channels: int, pred_kernel_size: int = 3, stacked_convs: int = 4, - norm_callable: Callable[..., nn.Module] = partial(nn.GroupNorm, num_groups=32, requires_grad=True), + normalization_callable: Callable[..., nn.Module] = partial( + build_norm_layer, + nn.GroupNorm, + num_groups=32, + requires_grad=True, + ), reg_decoded_bbox: bool = True, loss_centerness: nn.Module | None = None, init_cfg: dict | None = None, @@ -70,7 +76,7 @@ def __init__( ) -> None: self.pred_kernel_size = pred_kernel_size self.stacked_convs = stacked_convs - self.norm_callable = norm_callable + self.normalization_callable = normalization_callable init_cfg = init_cfg or { "type": "Normal", "layer": "Conv2d", @@ -116,7 +122,7 @@ def _init_layers(self) -> None: 3, stride=1, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=self.feat_channels), ), ) self.reg_convs.append( @@ -126,7 +132,7 @@ def _init_layers(self) -> None: 3, stride=1, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=self.feat_channels), ), ) pred_pad_size = self.pred_kernel_size // 2 diff --git a/src/otx/algo/detection/heads/rtmdet_head.py b/src/otx/algo/detection/heads/rtmdet_head.py index 73a431f24ab..9b18790b18a 100644 --- a/src/otx/algo/detection/heads/rtmdet_head.py +++ b/src/otx/algo/detection/heads/rtmdet_head.py @@ -24,7 +24,7 @@ unmap, ) from otx.algo.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule -from otx.algo.modules.norm import is_norm +from otx.algo.modules.norm import build_norm_layer, is_norm from otx.algo.modules.scale import Scale from otx.algo.utils.mmengine_utils import InstanceData from otx.algo.utils.weight_init import bias_init_with_prob, constant_init, normal_init @@ -69,7 +69,7 @@ def _init_layers(self) -> None: 3, stride=1, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=self.feat_channels), activation_callable=self.activation_callable, ), ) @@ -80,7 +80,7 @@ def _init_layers(self) -> None: 3, stride=1, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=self.feat_channels), activation_callable=self.activation_callable, ), ) @@ -644,7 +644,7 @@ class RTMDetSepBNHead(RTMDetHead): Defaults to True. use_depthwise (bool): Whether to use depthwise separable convolution in head. Defaults to False. - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``nn.SiLU``. 
@@ -658,7 +658,7 @@ def __init__( in_channels: int, share_conv: bool = True, use_depthwise: bool = False, - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), + normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), activation_callable: Callable[..., nn.Module] = nn.SiLU, pred_kernel_size: int = 1, exp_on_reg: bool = False, @@ -670,7 +670,7 @@ def __init__( super().__init__( num_classes, in_channels, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, pred_kernel_size=pred_kernel_size, **kwargs, @@ -698,7 +698,7 @@ def _init_layers(self) -> None: 3, stride=1, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=self.feat_channels), activation_callable=self.activation_callable, ), ) @@ -709,7 +709,7 @@ def _init_layers(self) -> None: 3, stride=1, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=self.feat_channels), activation_callable=self.activation_callable, ), ) diff --git a/src/otx/algo/detection/heads/yolox_head.py b/src/otx/algo/detection/heads/yolox_head.py index f1c7ae7732c..8ad42a3d9e1 100644 --- a/src/otx/algo/detection/heads/yolox_head.py +++ b/src/otx/algo/detection/heads/yolox_head.py @@ -27,6 +27,7 @@ from otx.algo.detection.losses import IoULoss from otx.algo.modules.activation import Swish from otx.algo.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule +from otx.algo.modules.norm import build_norm_layer from otx.algo.utils.mmengine_utils import InstanceData logger = logging.getLogger() @@ -49,9 +50,9 @@ class YOLOXHead(BaseDenseHead): dcn_on_last_conv (bool): If true, use dcn in the last layer of towers. Defaults to False. conv_bias (bool or str): If specified as `auto`, it will be decided by - the norm_callable. Bias of conv will be set as True if `norm_callable` is + the normalization_callable. Bias of conv will be set as True if `normalization_callable` is None, otherwise False. Defaults to "auto". - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``Swish``. 
@@ -77,7 +78,7 @@ def __init__( use_depthwise: bool = False, dcn_on_last_conv: bool = False, conv_bias: bool | str = "auto", - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), + normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), activation_callable: Callable[..., nn.Module] = Swish, loss_cls: nn.Module | None = None, loss_bbox: nn.Module | None = None, @@ -113,7 +114,7 @@ def __init__( self.conv_bias = conv_bias self.use_sigmoid_cls = True - self.norm_callable = norm_callable + self.normalization_callable = normalization_callable self.activation_callable = activation_callable self.loss_cls = loss_cls or CrossEntropyLoss(use_sigmoid=True, reduction="sum", loss_weight=1.0) @@ -171,7 +172,7 @@ def _build_stacked_convs(self) -> nn.Sequential: 3, stride=1, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=self.feat_channels), activation_callable=self.activation_callable, bias=self.conv_bias, ), diff --git a/src/otx/algo/detection/layers/csp_layer.py b/src/otx/algo/detection/layers/csp_layer.py index 6d29b14e5a3..fbcecb41100 100644 --- a/src/otx/algo/detection/layers/csp_layer.py +++ b/src/otx/algo/detection/layers/csp_layer.py @@ -15,6 +15,7 @@ from otx.algo.modules.activation import Swish from otx.algo.modules.base_module import BaseModule from otx.algo.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule +from otx.algo.modules.norm import build_norm_layer class DarknetBottleneck(BaseModule): @@ -34,7 +35,7 @@ class DarknetBottleneck(BaseModule): Defaults to True. use_depthwise (bool): Whether to use depthwise separable convolution. Defaults to False. - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``Swish``. @@ -47,7 +48,7 @@ def __init__( expansion: float = 0.5, add_identity: bool = True, use_depthwise: bool = False, - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), + normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), activation_callable: Callable[..., nn.Module] = Swish, init_cfg: dict | list[dict] | None = None, ) -> None: @@ -59,7 +60,7 @@ def __init__( in_channels, hidden_channels, 1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=hidden_channels), activation_callable=activation_callable, ) self.conv2 = conv( @@ -68,7 +69,7 @@ def __init__( 3, stride=1, padding=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), activation_callable=activation_callable, ) self.add_identity = add_identity and in_channels == out_channels @@ -97,7 +98,7 @@ class CSPNeXtBlock(BaseModule): Defaults to False. kernel_size (int): The kernel size of the second convolution layer. Defaults to 5. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``nn.SiLU``. 
@@ -113,7 +114,7 @@ def __init__( add_identity: bool = True, use_depthwise: bool = False, kernel_size: int = 5, - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), + normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), activation_callable: Callable[..., nn.Module] = nn.SiLU, init_cfg: dict | list[dict] | None = None, ) -> None: @@ -127,7 +128,7 @@ def __init__( 3, stride=1, padding=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=hidden_channels), activation_callable=activation_callable, ) self.conv2 = DepthwiseSeparableConvModule( @@ -136,7 +137,7 @@ def __init__( kernel_size, stride=1, padding=kernel_size // 2, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), activation_callable=activation_callable, ) self.add_identity = add_identity and in_channels == out_channels @@ -160,9 +161,7 @@ class RepVggBlock(nn.Module): ch_out (int): The output channels of this Module. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to None. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. - Defaults to None. - norm_name (str | None): The name of the normalization layer fpr ``build_norm_layer``. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to None. """ @@ -171,8 +170,7 @@ def __init__( ch_in: int, ch_out: int, activation_callable: Callable[..., nn.Module] | None = None, - norm_callable: Callable[..., nn.Module] | None = None, - norm_name: str | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, ) -> None: """Initialize RepVggBlock.""" super().__init__() @@ -185,8 +183,7 @@ def __init__( 1, padding=1, activation_callable=None, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=ch_out), ) self.conv2 = Conv2dModule( ch_in, @@ -194,8 +191,7 @@ def __init__( 1, 1, activation_callable=None, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=ch_out), ) self.act = activation_callable() if activation_callable else nn.Identity() @@ -249,7 +245,7 @@ class CSPLayer(BaseModule): blocks. Defaults to False. channel_attention (bool): Whether to add channel attention in each stage. Defaults to True. - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to ``Swish``. 
@@ -267,7 +263,7 @@ def __init__( use_depthwise: bool = False, use_cspnext_block: bool = False, channel_attention: bool = False, - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), + normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), activation_callable: Callable[..., nn.Module] | None = Swish, init_cfg: dict | list[dict] | None = None, ) -> None: @@ -280,21 +276,21 @@ def __init__( in_channels, mid_channels, 1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=mid_channels), activation_callable=activation_callable, ) self.short_conv = Conv2dModule( in_channels, mid_channels, 1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=mid_channels), activation_callable=activation_callable, ) self.final_conv = Conv2dModule( 2 * mid_channels, out_channels, 1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), activation_callable=activation_callable, ) @@ -306,7 +302,7 @@ def __init__( 1.0, add_identity, use_depthwise, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, ) for _ in range(num_blocks) @@ -342,9 +338,7 @@ class CSPRepLayer(nn.Module): Defaults to False. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to None. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. - Defaults to None. - norm_name (str | None): The name of the normalization layer fpr ``build_norm_layer``. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to None. 
""" @@ -356,8 +350,7 @@ def __init__( expansion: float = 1.0, bias: bool = False, activation_callable: Callable[..., nn.Module] | None = None, - norm_callable: Callable[..., nn.Module] | None = None, - norm_name: str | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, ) -> None: """Initialize CSPRepLayer.""" super().__init__() @@ -369,8 +362,7 @@ def __init__( 1, bias=bias, activation_callable=activation_callable, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=hidden_channels), ) self.conv2 = Conv2dModule( in_channels, @@ -379,8 +371,7 @@ def __init__( 1, bias=bias, activation_callable=activation_callable, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=hidden_channels), ) self.bottlenecks = nn.Sequential( *[ @@ -388,8 +379,7 @@ def __init__( hidden_channels, hidden_channels, activation_callable=activation_callable, - norm_callable=norm_callable, - norm_name=norm_name, + normalization_callable=normalization_callable, ) for _ in range(num_blocks) ], @@ -402,8 +392,7 @@ def __init__( 1, bias=bias, activation_callable=activation_callable, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), ) else: self.conv3 = nn.Identity() diff --git a/src/otx/algo/detection/necks/cspnext_pafpn.py b/src/otx/algo/detection/necks/cspnext_pafpn.py index 80acd42c8ee..a7236472268 100644 --- a/src/otx/algo/detection/necks/cspnext_pafpn.py +++ b/src/otx/algo/detection/necks/cspnext_pafpn.py @@ -21,6 +21,7 @@ from otx.algo.modules.activation import Swish from otx.algo.modules.base_module import BaseModule from otx.algo.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule +from otx.algo.modules.norm import build_norm_layer class CSPNeXtPAFPN(BaseModule): @@ -33,7 +34,7 @@ class CSPNeXtPAFPN(BaseModule): use_depthwise (bool): Whether to use depthwise separable convolution in blocks. Defaults to False. expand_ratio (float): Ratio to adjust the number of channels of the hidden layer. Default: 0.5 upsample_cfg (dict): Config dict for interpolate layer. Default: `dict(scale_factor=2, mode='nearest')` - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``Swish``. 
@@ -48,7 +49,7 @@ def __init__( use_depthwise: bool = False, expand_ratio: float = 0.5, upsample_cfg: dict | None = None, - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), + normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), activation_callable: Callable[..., nn.Module] = Swish, init_cfg: dict | None = None, ) -> None: @@ -78,7 +79,7 @@ def __init__( in_channels[idx], in_channels[idx - 1], 1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=in_channels[idx - 1]), activation_callable=activation_callable, ), ) @@ -91,7 +92,7 @@ def __init__( use_depthwise=use_depthwise, use_cspnext_block=True, expand_ratio=expand_ratio, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, ), ) @@ -107,7 +108,7 @@ def __init__( 3, stride=2, padding=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=in_channels[idx]), activation_callable=activation_callable, ), ) @@ -120,7 +121,7 @@ def __init__( use_depthwise=use_depthwise, use_cspnext_block=True, expand_ratio=expand_ratio, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, ), ) @@ -133,7 +134,7 @@ def __init__( out_channels, 3, padding=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), activation_callable=activation_callable, ), ) diff --git a/src/otx/algo/detection/necks/fpn.py b/src/otx/algo/detection/necks/fpn.py index 672f28b8709..d678d8f5e6b 100644 --- a/src/otx/algo/detection/necks/fpn.py +++ b/src/otx/algo/detection/necks/fpn.py @@ -16,6 +16,7 @@ from otx.algo.modules.base_module import BaseModule from otx.algo.modules.conv_module import Conv2dModule +from otx.algo.modules.norm import build_norm_layer class FPN(BaseModule): @@ -45,7 +46,7 @@ class FPN(BaseModule): conv. Defaults to False. no_norm_on_lateral (bool): Whether to apply norm on lateral. Defaults to False. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to None. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to None. 
@@ -64,7 +65,7 @@ def __init__( add_extra_convs: bool | str = False, relu_before_extra_convs: bool = False, no_norm_on_lateral: bool = False, - norm_callable: Callable[..., nn.Module] | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, activation_callable: Callable[..., nn.Module] | None = None, upsample_cfg: dict | None = None, init_cfg: dict | list[dict] | None = None, @@ -104,7 +105,9 @@ def __init__( in_channels[i], out_channels, 1, - norm_callable=norm_callable if not self.no_norm_on_lateral else None, + normalization=build_norm_layer(normalization_callable, num_features=out_channels) + if not self.no_norm_on_lateral + else None, activation_callable=activation_callable, inplace=False, ) @@ -113,7 +116,7 @@ def __init__( out_channels, 3, padding=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), activation_callable=activation_callable, inplace=False, ) @@ -135,7 +138,7 @@ def __init__( 3, stride=2, padding=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), activation_callable=activation_callable, inplace=False, ) diff --git a/src/otx/algo/detection/necks/hybrid_encoder.py b/src/otx/algo/detection/necks/hybrid_encoder.py index 03a600ae9f0..0466e0d9958 100644 --- a/src/otx/algo/detection/necks/hybrid_encoder.py +++ b/src/otx/algo/detection/necks/hybrid_encoder.py @@ -14,6 +14,7 @@ from otx.algo.detection.layers import CSPRepLayer from otx.algo.modules import Conv2dModule from otx.algo.modules.base_module import BaseModule +from otx.algo.modules.norm import build_norm_layer __all__ = ["HybridEncoder"] @@ -113,10 +114,8 @@ class HybridEncoder(BaseModule): dropout (float, optional): Dropout rate. Defaults to 0.0. enc_activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``nn.GELU``. - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``nn.BatchNorm2d``. - norm_name (str): The name of the normalization layer fpr ``build_norm_layer``. - Defaults to 'norm'. use_encoder_idx (list[int], optional): List of indices of the encoder to use. Defaults to [2]. num_encoder_layers (int, optional): Number of layers in the transformer encoder. 
@@ -142,8 +141,7 @@ def __init__( dim_feedforward: int = 1024, dropout: float = 0.0, enc_activation_callable: Callable[..., nn.Module] = nn.GELU, - norm_callable: Callable[..., nn.Module] = nn.BatchNorm2d, - norm_name: str = "norm", + normalization_callable: Callable[..., nn.Module] = nn.BatchNorm2d, use_encoder_idx: list[int] = [2], # noqa: B006 num_encoder_layers: int = 1, pe_temperature: float = 10000, @@ -198,8 +196,7 @@ def __init__( 1, 1, activation_callable=activation_callable, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=hidden_dim), ), ) self.fpn_blocks.append( @@ -209,8 +206,7 @@ def __init__( round(3 * depth_mult), activation_callable=activation_callable, expansion=expansion, - norm_callable=norm_callable, - norm_name=norm_name, + normalization_callable=normalization_callable, ), ) @@ -226,8 +222,7 @@ def __init__( 2, padding=1, activation_callable=activation_callable, - norm_callable=norm_callable, - norm_name=norm_name, + normalization=build_norm_layer(normalization_callable, num_features=hidden_dim), ), ) self.pan_blocks.append( @@ -237,8 +232,7 @@ def __init__( round(3 * depth_mult), activation_callable=activation_callable, expansion=expansion, - norm_callable=norm_callable, - norm_name=norm_name, + normalization_callable=normalization_callable, ), ) diff --git a/src/otx/algo/detection/necks/yolox_pafpn.py b/src/otx/algo/detection/necks/yolox_pafpn.py index 3d744749823..c44a65db1ef 100644 --- a/src/otx/algo/detection/necks/yolox_pafpn.py +++ b/src/otx/algo/detection/necks/yolox_pafpn.py @@ -19,6 +19,7 @@ from otx.algo.modules.activation import Swish from otx.algo.modules.base_module import BaseModule from otx.algo.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule +from otx.algo.modules.norm import build_norm_layer class YOLOXPAFPN(BaseModule): @@ -32,7 +33,7 @@ class YOLOXPAFPN(BaseModule): blocks. Default: False upsample_cfg (dict): Config dict for interpolate layer. Default: `dict(scale_factor=2, mode='nearest')` - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, momentum=0.03, eps=0.001)``. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``nn.Swish``. 
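Note the two-level convention these hunks follow: leaf Conv2dModules now receive a ready ``normalization`` pair, while composite blocks such as CSPRepLayer keep accepting ``normalization_callable``, because they instantiate several convs of different widths internally and must size a fresh norm for each one. A rough sketch of a composite block written under that convention (the block and the ``build_norm_layer`` stand-in below are illustrative, not the otx classes):

    from functools import partial

    from torch import nn


    def build_norm_layer(normalization_callable, num_features):
        # stand-in: instantiate the callable and return a (name, module) pair
        return "bn", normalization_callable(num_features)


    class TinyCSPBlock(nn.Module):
        # Illustrative composite block: it keeps the callable and builds a fresh,
        # correctly sized norm for every internal conv, which is why blocks like
        # CSPRepLayer stay on `normalization_callable` while the leaf convs they
        # create move to the pre-assigned `normalization` form.
        def __init__(self, in_channels, out_channels, normalization_callable=nn.BatchNorm2d):
            super().__init__()
            mid_channels = out_channels // 2
            self.reduce = nn.Sequential(
                nn.Conv2d(in_channels, mid_channels, 1, bias=False),
                build_norm_layer(normalization_callable, mid_channels)[1],
                nn.SiLU(inplace=True),
            )
            self.expand = nn.Sequential(
                nn.Conv2d(mid_channels, out_channels, 3, padding=1, bias=False),
                build_norm_layer(normalization_callable, out_channels)[1],
                nn.SiLU(inplace=True),
            )

        def forward(self, x):
            return self.expand(self.reduce(x))


    block = TinyCSPBlock(64, 128, normalization_callable=partial(nn.BatchNorm2d, momentum=0.03, eps=0.001))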
@@ -47,7 +48,7 @@ def __init__( num_csp_blocks: int = 3, use_depthwise: bool = False, upsample_cfg: dict | None = None, - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), + normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, momentum=0.03, eps=0.001), activation_callable: Callable[..., nn.Module] = Swish, init_cfg: dict | list[dict] | None = None, ): @@ -78,7 +79,7 @@ def __init__( in_channels[idx], in_channels[idx - 1], 1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=in_channels[idx - 1]), activation_callable=activation_callable, ), ) @@ -89,7 +90,7 @@ def __init__( num_blocks=num_csp_blocks, add_identity=False, use_depthwise=use_depthwise, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, ), ) @@ -105,7 +106,7 @@ def __init__( 3, stride=2, padding=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=in_channels[idx]), activation_callable=activation_callable, ), ) @@ -116,7 +117,7 @@ def __init__( num_blocks=num_csp_blocks, add_identity=False, use_depthwise=use_depthwise, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, ), ) @@ -128,7 +129,7 @@ def __init__( in_channels[i], out_channels, 1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), activation_callable=activation_callable, ), ) diff --git a/src/otx/algo/detection/rtdetr.py b/src/otx/algo/detection/rtdetr.py index 753547a6a45..bee4cd41d53 100644 --- a/src/otx/algo/detection/rtdetr.py +++ b/src/otx/algo/detection/rtdetr.py @@ -7,6 +7,7 @@ import copy import re +from functools import partial from typing import Any import torch @@ -18,7 +19,7 @@ from otx.algo.detection.base_models.detection_transformer import DETR from otx.algo.detection.heads import RTDETRTransformer from otx.algo.detection.necks import HybridEncoder -from otx.algo.modules.norm import FrozenBatchNorm2d +from otx.algo.modules.norm import FrozenBatchNorm2d, build_norm_layer from otx.core.data.entity.base import OTXBatchLossEntity from otx.core.data.entity.detection import DetBatchDataEntity, DetBatchPredEntity from otx.core.exporter.base import OTXModelExporter @@ -253,8 +254,7 @@ def _build_model(self, num_classes: int) -> nn.Module: return_idx=[1, 2, 3], pretrained=True, freeze_at=0, - norm_callable=FrozenBatchNorm2d, - norm_name="norm", + normalization_callable=partial(build_norm_layer, FrozenBatchNorm2d, layer_name="norm"), ) encoder = HybridEncoder( eval_spatial_size=self.image_size[2:], @@ -295,8 +295,7 @@ def _build_model(self, num_classes: int) -> nn.Module: backbone = PResNet( depth=101, return_idx=[1, 2, 3], - norm_callable=FrozenBatchNorm2d, - norm_name="norm", + normalization_callable=partial(build_norm_layer, FrozenBatchNorm2d, layer_name="norm"), pretrained=True, freeze_at=0, ) diff --git a/src/otx/algo/detection/rtmdet.py b/src/otx/algo/detection/rtmdet.py index fc2c181512a..1b9a4326994 100644 --- a/src/otx/algo/detection/rtmdet.py +++ b/src/otx/algo/detection/rtmdet.py @@ -93,7 +93,7 @@ def _build_model(self, num_classes: int) -> RTMDet: backbone = CSPNeXt( deepen_factor=0.167, widen_factor=0.375, - norm_callable=nn.BatchNorm2d, + normalization_callable=nn.BatchNorm2d, activation_callable=partial(nn.SiLU, inplace=True), ) @@ -101,7 +101,7 @@ def _build_model(self, num_classes: 
int) -> RTMDet: in_channels=(96, 192, 384), out_channels=96, num_csp_blocks=1, - norm_callable=nn.BatchNorm2d, + normalization_callable=nn.BatchNorm2d, activation_callable=partial(nn.SiLU, inplace=True), ) @@ -116,7 +116,7 @@ def _build_model(self, num_classes: int) -> RTMDet: loss_cls=QualityFocalLoss(use_sigmoid=True, beta=2.0, loss_weight=1.0), loss_bbox=GIoULoss(loss_weight=2.0), loss_centerness=CrossEntropyLoss(use_sigmoid=True, loss_weight=1.0), - norm_callable=nn.BatchNorm2d, + normalization=nn.BatchNorm2d, activation_callable=partial(nn.SiLU, inplace=True), train_cfg=train_cfg, test_cfg=test_cfg, diff --git a/src/otx/algo/instance_segmentation/backbones/swin.py b/src/otx/algo/instance_segmentation/backbones/swin.py index 6761a74ca66..26803aab71b 100644 --- a/src/otx/algo/instance_segmentation/backbones/swin.py +++ b/src/otx/algo/instance_segmentation/backbones/swin.py @@ -320,7 +320,7 @@ class SwinBlock(BaseModule): drop_path_rate (float, optional): Stochastic depth rate. Default: 0. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``nn.GELU``. - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``nn.LayerNorm``. with_cp (bool, optional): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. @@ -342,7 +342,7 @@ def __init__( attn_drop_rate: float = 0.0, drop_path_rate: float = 0.0, activation_callable: Callable[..., nn.Module] = nn.GELU, - norm_callable: Callable[..., nn.Module] = nn.LayerNorm, + normalization_callable: Callable[..., nn.Module] = nn.LayerNorm, with_cp: bool = False, init_cfg: None = None, ): @@ -351,7 +351,7 @@ def __init__( self.init_cfg = init_cfg self.with_cp = with_cp - self.norm1 = build_norm_layer(norm_callable, embed_dims)[1] + self.norm1 = build_norm_layer(normalization_callable, embed_dims)[1] self.attn = ShiftWindowMSA( embed_dims=embed_dims, num_heads=num_heads, @@ -365,7 +365,7 @@ def __init__( init_cfg=None, ) - self.norm2 = build_norm_layer(norm_callable, embed_dims)[1] + self.norm2 = build_norm_layer(normalization_callable, embed_dims)[1] self.ffn = FFN( embed_dims=embed_dims, feedforward_channels=feedforward_channels, @@ -415,7 +415,7 @@ class SwinBlockSequence(BaseModule): module. Default: None. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``nn.GELU``. - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``nn.LayerNorm``. with_cp (bool, optional): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. @@ -438,7 +438,7 @@ def __init__( drop_path_rate: list[float] | float = 0.0, downsample: BaseModule | None = None, activation_callable: Callable[..., nn.Module] = nn.GELU, - norm_callable: Callable[..., nn.Module] = nn.LayerNorm, + normalization_callable: Callable[..., nn.Module] = nn.LayerNorm, with_cp: bool = False, init_cfg: None = None, ): @@ -466,7 +466,7 @@ def __init__( attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rates[i], activation_callable=activation_callable, - norm_callable=norm_callable, + normalization_callable=normalization_callable, with_cp=with_cp, init_cfg=None, ) @@ -525,7 +525,7 @@ class SwinTransformer(BaseModule): drop_path_rate (float): Stochastic depth rate. Defaults: 0.1. 
activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``nn.GELU``. - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``nn.LayerNorm``. with_cp (bool, optional): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. @@ -560,7 +560,7 @@ def __init__( attn_drop_rate: float = 0.0, drop_path_rate: float = 0.1, activation_callable: Callable[..., nn.Module] = nn.GELU, - norm_callable: Callable[..., nn.Module] = nn.LayerNorm, + normalization_callable: Callable[..., nn.Module] = nn.LayerNorm, with_cp: bool = False, pretrained: str | None = None, convert_weights: bool = False, @@ -607,7 +607,7 @@ def __init__( embed_dims=embed_dims, kernel_size=patch_size, stride=strides[0], - norm_callable=norm_callable if patch_norm else None, + normalization_callable=normalization_callable if patch_norm else None, init_cfg=None, ) @@ -625,7 +625,7 @@ def __init__( in_channels=in_channels, out_channels=2 * in_channels, stride=strides[i + 1], - norm_callable=norm_callable if patch_norm else None, + normalization_callable=normalization_callable if patch_norm else None, init_cfg=None, ) else: @@ -644,7 +644,7 @@ def __init__( drop_path_rate=dpr[sum(depths[:i]) : sum(depths[: i + 1])], downsample=downsample, activation_callable=activation_callable, - norm_callable=norm_callable, + normalization_callable=normalization_callable, with_cp=with_cp, init_cfg=None, ) @@ -655,7 +655,7 @@ def __init__( self.num_features = [int(embed_dims * 2**i) for i in range(num_layers)] # Add a norm layer for each output for i in out_indices: - layer = build_norm_layer(norm_callable, self.num_features[i])[1] + layer = build_norm_layer(normalization_callable, self.num_features[i])[1] layer_name = f"norm{i}" self.add_module(layer_name, layer) diff --git a/src/otx/algo/instance_segmentation/heads/convfc_bbox_head.py b/src/otx/algo/instance_segmentation/heads/convfc_bbox_head.py index 3d1fe221976..692bbabda71 100644 --- a/src/otx/algo/instance_segmentation/heads/convfc_bbox_head.py +++ b/src/otx/algo/instance_segmentation/heads/convfc_bbox_head.py @@ -35,7 +35,7 @@ def __init__( num_reg_fcs: int = 0, conv_out_channels: int = 256, fc_out_channels: int = 1024, - norm_callable: Callable[..., nn.Module] | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, init_cfg: dict | None = None, *args, **kwargs, @@ -64,7 +64,7 @@ def __init__( self.num_reg_fcs = num_reg_fcs self.conv_out_channels = conv_out_channels self.fc_out_channels = fc_out_channels - self.norm_callable = norm_callable + self.normalization_callable = normalization_callable # add shared convs and fcs self.shared_convs, self.shared_fcs, last_layer_dim = self._add_conv_fc_branch( diff --git a/src/otx/algo/instance_segmentation/heads/fcn_mask_head.py b/src/otx/algo/instance_segmentation/heads/fcn_mask_head.py index e68c922da57..d8f3928dbee 100644 --- a/src/otx/algo/instance_segmentation/heads/fcn_mask_head.py +++ b/src/otx/algo/instance_segmentation/heads/fcn_mask_head.py @@ -22,6 +22,7 @@ from otx.algo.instance_segmentation.utils.utils import empty_instances from otx.algo.modules.base_module import BaseModule, ModuleList from otx.algo.modules.conv_module import Conv2dModule +from otx.algo.modules.norm import build_norm_layer BYTES_PER_FLOAT = 4 # determine it based on available resources. 
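In the Swin hunks above the helper is used as ``build_norm_layer(normalization_callable, embed_dims)[1]``: only the module from the ``(name, module)`` pair is kept when the norm is registered directly as an attribute. A small sketch of that usage with ``nn.LayerNorm`` (stand-in helper and illustrative block, not the otx SwinBlock):

    from torch import nn


    def build_norm_layer(normalization_callable, num_features):
        # stand-in returning (name, module); callers that register the norm as a
        # plain attribute keep only index [1], as the Swin blocks above do
        return "ln", normalization_callable(num_features)


    class TinyTransformerBlock(nn.Module):
        def __init__(self, embed_dims=96, normalization_callable=nn.LayerNorm):
            super().__init__()
            self.norm1 = build_norm_layer(normalization_callable, embed_dims)[1]
            self.attn = nn.MultiheadAttention(embed_dims, num_heads=3, batch_first=True)
            self.norm2 = build_norm_layer(normalization_callable, embed_dims)[1]
            self.ffn = nn.Sequential(
                nn.Linear(embed_dims, 4 * embed_dims),
                nn.GELU(),
                nn.Linear(4 * embed_dims, embed_dims),
            )

        def forward(self, x):
            y = self.norm1(x)
            attn_out, _ = self.attn(y, y, y)
            x = x + attn_out
            return x + self.ffn(self.norm2(x))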
@@ -45,7 +46,7 @@ def __init__( conv_out_channels: int = 256, num_classes: int = 80, class_agnostic: int = False, - norm_callable: Callable[..., nn.Module] | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, init_cfg: dict | list[dict] | None = None, ) -> None: if init_cfg is not None: @@ -61,7 +62,7 @@ def __init__( self.conv_out_channels = conv_out_channels self.num_classes = num_classes self.class_agnostic = class_agnostic - self.norm_callable = norm_callable + self.normalization_callable = normalization_callable self.loss_mask = loss_mask @@ -75,7 +76,7 @@ def __init__( self.conv_out_channels, self.conv_kernel_size, padding=padding, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=self.conv_out_channels), ), ) upsample_in_channels = self.conv_out_channels if self.num_convs > 0 else in_channels diff --git a/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py b/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py index 6aca961ae14..abd439ff63e 100644 --- a/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py +++ b/src/otx/algo/instance_segmentation/heads/rtmdet_ins_head.py @@ -34,7 +34,7 @@ from otx.algo.instance_segmentation.utils.utils import unpack_inst_seg_entity from otx.algo.modules.base_module import BaseModule from otx.algo.modules.conv_module import Conv2dModule -from otx.algo.modules.norm import is_norm +from otx.algo.modules.norm import build_norm_layer, is_norm from otx.algo.utils.mmengine_utils import InstanceData from otx.algo.utils.weight_init import bias_init_with_prob, constant_init, normal_init from otx.core.data.entity.instance_segmentation import InstanceSegBatchDataEntity @@ -110,7 +110,7 @@ def _init_layers(self) -> None: 3, stride=1, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=self.feat_channels), activation_callable=self.activation_callable, ), ) @@ -128,7 +128,7 @@ def _init_layers(self) -> None: num_levels=len(self.prior_generator.strides), num_prototypes=self.num_prototypes, activation_callable=self.activation_callable, - norm_callable=self.norm_callable, + normalization_callable=self.normalization_callable, ) def forward(self, feats: tuple[Tensor, ...]) -> tuple: @@ -714,7 +714,7 @@ class MaskFeatModule(BaseModule): stacked_convs (int): Number of convs in mask feature branch. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``partial(nn.ReLU, inplace=True)``. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to ``nn.BatchNorm2d``. """ @@ -726,7 +726,7 @@ def __init__( num_levels: int = 3, num_prototypes: int = 8, activation_callable: Callable[..., nn.Module] = partial(nn.ReLU, inplace=True), - norm_callable: Callable[..., nn.Module] = nn.BatchNorm2d, + normalization_callable: Callable[..., nn.Module] = nn.BatchNorm2d, ) -> None: super().__init__(init_cfg=None) @@ -742,7 +742,7 @@ def __init__( 3, padding=1, activation_callable=activation_callable, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=feat_channels), ), ) self.stacked_convs = nn.Sequential(*convs) @@ -772,7 +772,7 @@ class RTMDetInsSepBNHead(RTMDetInsHead): in_channels (int): Number of channels in the input feature map. share_conv (bool): Whether to share conv layers between stages. Defaults to True. 
- norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``partial(nn.BatchNorm2d, requires_grad=True)``. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``partial(nn.SiLU, inplace=True)``. @@ -785,7 +785,7 @@ def __init__( in_channels: int, share_conv: bool = True, with_objectness: bool = False, - norm_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, requires_grad=True), + normalization_callable: Callable[..., nn.Module] = partial(nn.BatchNorm2d, requires_grad=True), activation_callable: Callable[..., nn.Module] = partial(nn.SiLU, inplace=True), pred_kernel_size: int = 1, **kwargs, @@ -794,7 +794,7 @@ def __init__( super().__init__( num_classes, in_channels, - norm_callable=norm_callable, + normalization_callable=normalization_callable, activation_callable=activation_callable, pred_kernel_size=pred_kernel_size, with_objectness=with_objectness, @@ -849,7 +849,7 @@ def _init_layers(self) -> None: 3, stride=1, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=self.feat_channels), activation_callable=self.activation_callable, ), ) @@ -860,7 +860,7 @@ def _init_layers(self) -> None: 3, stride=1, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=self.feat_channels), activation_callable=self.activation_callable, ), ) @@ -871,7 +871,7 @@ def _init_layers(self) -> None: 3, stride=1, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=self.feat_channels), activation_callable=self.activation_callable, ), ) @@ -909,7 +909,7 @@ def _init_layers(self) -> None: num_levels=len(self.prior_generator.strides), num_prototypes=self.num_prototypes, activation_callable=self.activation_callable, - norm_callable=self.norm_callable, + normalization_callable=self.normalization_callable, ) def init_weights(self) -> None: diff --git a/src/otx/algo/instance_segmentation/layers/transformer.py b/src/otx/algo/instance_segmentation/layers/transformer.py index 011a17b20ab..e7ccc0f6015 100644 --- a/src/otx/algo/instance_segmentation/layers/transformer.py +++ b/src/otx/algo/instance_segmentation/layers/transformer.py @@ -104,7 +104,7 @@ class PatchEmbed(BaseModule): Default: "corner". dilation (int): The dilation rate of embedding conv. Default: 1. bias (bool): Bias of embed conv. Default: True. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to None. init_cfg (dict, optional): The Config for initialization. Default: None. @@ -119,7 +119,7 @@ def __init__( padding: int | tuple | str = "corner", dilation: int = 1, bias: bool = True, - norm_callable: Callable[..., nn.Module] | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, init_cfg: dict | None = None, ) -> None: super().__init__(init_cfg=init_cfg) @@ -156,8 +156,8 @@ def __init__( bias=bias, ) - if norm_callable is not None: - self.norm = build_norm_layer(norm_callable, embed_dims)[1] + if normalization_callable is not None: + self.norm = build_norm_layer(normalization_callable, embed_dims)[1] else: self.norm = None @@ -210,7 +210,7 @@ class PatchMerging(BaseModule): layer. Default: 1. bias (bool, optional): Whether to add bias in linear layer or not. 
Defaults: False. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to ``nn.LayerNorm``. init_cfg (dict, optional): The extra config for initialization. Default: None. @@ -225,7 +225,7 @@ def __init__( padding: int | tuple | str = "corner", dilation: int | tuple = 1, bias: bool = False, - norm_callable: Callable[..., nn.Module] | None = nn.LayerNorm, + normalization_callable: Callable[..., nn.Module] | None = nn.LayerNorm, init_cfg: dict | None = None, ) -> None: super().__init__(init_cfg=init_cfg) @@ -255,8 +255,8 @@ def __init__( sample_dim = _kernel_size[0] * _kernel_size[1] * in_channels - if norm_callable is not None: - self.norm = build_norm_layer(norm_callable, sample_dim)[1] + if normalization_callable is not None: + self.norm = build_norm_layer(normalization_callable, sample_dim)[1] else: self.norm = None diff --git a/src/otx/algo/instance_segmentation/maskrcnn.py b/src/otx/algo/instance_segmentation/maskrcnn.py index 86d1329a274..3438a2cd20c 100644 --- a/src/otx/algo/instance_segmentation/maskrcnn.py +++ b/src/otx/algo/instance_segmentation/maskrcnn.py @@ -22,6 +22,7 @@ from otx.algo.instance_segmentation.necks import FPN from otx.algo.instance_segmentation.two_stage import TwoStageDetector from otx.algo.instance_segmentation.utils.roi_extractors import SingleRoIExtractor +from otx.algo.modules.norm import build_norm_layer from otx.algo.utils.support_otx_v1 import OTXv1Helper from otx.core.exporter.base import OTXModelExporter from otx.core.exporter.native import OTXNativeModelExporter @@ -154,7 +155,7 @@ def _build_model(self, num_classes: int) -> TwoStageDetector: backbone = ResNet( depth=50, frozen_stages=1, - norm_callable=partial(nn.BatchNorm2d, requires_grad=True), + normalization_callable=partial(nn.BatchNorm2d, requires_grad=True), norm_eval=True, num_stages=4, out_indices=(0, 1, 2, 3), @@ -333,7 +334,7 @@ def _build_model(self, num_classes: int) -> TwoStageDetector: "frozen_stages": -1, "pretrained": True, "activation_callable": nn.SiLU, - "norm_callable": partial(nn.BatchNorm2d, requires_grad=True), + "normalization_callable": partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True), }, ) diff --git a/src/otx/algo/instance_segmentation/necks/fpn.py b/src/otx/algo/instance_segmentation/necks/fpn.py index 62df69bfad1..98060241024 100644 --- a/src/otx/algo/instance_segmentation/necks/fpn.py +++ b/src/otx/algo/instance_segmentation/necks/fpn.py @@ -15,6 +15,7 @@ from otx.algo.modules.base_module import BaseModule from otx.algo.modules.conv_module import Conv2dModule +from otx.algo.modules.norm import build_norm_layer class FPN(BaseModule): @@ -36,7 +37,7 @@ class FPN(BaseModule): conv. Defaults to False. no_norm_on_lateral (bool): Whether to apply norm on lateral. Defaults to False. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to None. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to None. 
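One detail worth calling out in the MaskRCNN defaults above: ``requires_grad`` is not an ``nn.BatchNorm2d`` constructor argument, so ``partial(nn.BatchNorm2d, requires_grad=True)`` can only work if the norm-building helper pops that keyword and applies it to the created layer's parameters (the mmcv-style convention). That handling is not visible in this patch, so the sketch below is an assumption about the helper's behaviour, not a copy of it:

    from functools import partial

    from torch import nn


    def build_norm_layer(normalization_callable, num_features):
        # Assumed behaviour (mmcv-style): strip `requires_grad` from the callable's
        # keywords, instantiate the layer, then toggle requires_grad on its parameters.
        requires_grad = True
        if isinstance(normalization_callable, partial):
            keywords = dict(normalization_callable.keywords)
            requires_grad = keywords.pop("requires_grad", True)
            normalization_callable = partial(normalization_callable.func, *normalization_callable.args, **keywords)
        layer = normalization_callable(num_features)
        for param in layer.parameters():
            param.requires_grad = requires_grad
        return "bn", layer


    name, bn = build_norm_layer(partial(nn.BatchNorm2d, requires_grad=True), num_features=256)
    assert all(p.requires_grad for p in bn.parameters())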
@@ -54,7 +55,7 @@ def __init__( end_level: int = -1, relu_before_extra_convs: bool = False, no_norm_on_lateral: bool = False, - norm_callable: Callable[..., nn.Module] | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, activation_callable: Callable[..., nn.Module] | None = None, upsample_cfg: dict | None = None, init_cfg: dict | list[dict] | None = None, @@ -98,7 +99,9 @@ def __init__( in_channels[i], out_channels, 1, - norm_callable=norm_callable if not self.no_norm_on_lateral else None, + normalization=build_norm_layer(normalization_callable, num_features=out_channels) + if not self.no_norm_on_lateral + else None, activation_callable=activation_callable, inplace=False, ) @@ -107,7 +110,7 @@ def __init__( out_channels, 3, padding=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=out_channels), activation_callable=activation_callable, inplace=False, ) diff --git a/src/otx/algo/instance_segmentation/rtmdet_inst.py b/src/otx/algo/instance_segmentation/rtmdet_inst.py index f86d01b09f1..1b996811bae 100644 --- a/src/otx/algo/instance_segmentation/rtmdet_inst.py +++ b/src/otx/algo/instance_segmentation/rtmdet_inst.py @@ -114,7 +114,7 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: deepen_factor=0.167, widen_factor=0.375, channel_attention=True, - norm_callable=nn.BatchNorm2d, + normalization_callable=nn.BatchNorm2d, activation_callable=partial(nn.SiLU, inplace=True), ) @@ -123,7 +123,7 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: out_channels=96, num_csp_blocks=1, expand_ratio=0.5, - norm_callable=nn.BatchNorm2d, + normalization_callable=nn.BatchNorm2d, activation_callable=partial(nn.SiLU, inplace=True), ) @@ -135,7 +135,7 @@ def _build_model(self, num_classes: int) -> SingleStageDetector: pred_kernel_size=1, feat_channels=96, activation_callable=partial(nn.SiLU, inplace=True), - norm_callable=partial(nn.BatchNorm2d, requires_grad=True), + normalization=partial(nn.BatchNorm2d, requires_grad=True), anchor_generator=MlvlPointGenerator( offset=0, strides=[8, 16, 32], diff --git a/src/otx/algo/keypoint_detection/rtmpose.py b/src/otx/algo/keypoint_detection/rtmpose.py index 7c4f7842c52..52003a2300f 100644 --- a/src/otx/algo/keypoint_detection/rtmpose.py +++ b/src/otx/algo/keypoint_detection/rtmpose.py @@ -78,7 +78,7 @@ def _build_model(self, num_classes: int) -> RTMPose: widen_factor=0.375, out_indices=(4,), channel_attention=True, - norm_callable=nn.BatchNorm2d, + normalization_callable=nn.BatchNorm2d, activation_callable=partial(nn.SiLU, inplace=True), ) head = RTMCCHead( diff --git a/src/otx/algo/modules/conv_module.py b/src/otx/algo/modules/conv_module.py index 28b1458408d..05dbd3e03d9 100644 --- a/src/otx/algo/modules/conv_module.py +++ b/src/otx/algo/modules/conv_module.py @@ -8,9 +8,9 @@ from __future__ import annotations -from copy import deepcopy import inspect import warnings +from copy import deepcopy from functools import partial from typing import TYPE_CHECKING, Callable @@ -324,9 +324,9 @@ def __init__( dilation: int | tuple[int, int] = 1, normalization: tuple[str, nn.Module] | None = None, activation_callable: Callable[..., nn.Module] = nn.ReLU, - dw_normalization: Callable[..., nn.Module] | tuple[str, nn.Module] | None = None, + dw_normalization: tuple[str, nn.Module] | None = None, dw_activation_callable: Callable[..., nn.Module] | None = None, - pw_normalization: Callable[..., nn.Module] | tuple[str, nn.Module] | None = None, + pw_normalization: tuple[str, 
nn.Module] | None = None, pw_activation_callable: Callable[..., nn.Module] | None = None, **kwargs, ): diff --git a/src/otx/algo/modules/norm.py b/src/otx/algo/modules/norm.py index ffe204ef39c..3011116dae6 100644 --- a/src/otx/algo/modules/norm.py +++ b/src/otx/algo/modules/norm.py @@ -171,8 +171,8 @@ def build_norm_layer( """Build normalization layer. Args: - normalization_callable (Callable[..., nn.Module] | tuple[str, nn.Module] | nn.Module): Normalization layer module. - If tuple is given, return it as is. If callable is given, create the layer. + normalization_callable (Callable[..., nn.Module] | tuple[str, nn.Module] | nn.Module): Normalization layer + module. If tuple is given, return it as is. If callable is given, create the layer. num_features (int): Number of input channels. postfix (int | str): The postfix to be appended into norm abbreviation to create named layer. @@ -207,6 +207,7 @@ def _build_layer(normalization_callable: Callable[..., nn.Module]) -> nn.Module: if isinstance(normalization_callable, partial) and normalization_callable.func.__name__ == "build_norm_layer": # add `num_features` to `normalization_callable` and return it + # TODO (sungchul): is adding more arguments needed? return normalization_callable(num_features=num_features) if not callable(normalization_callable): diff --git a/src/otx/algo/modules/transformer.py b/src/otx/algo/modules/transformer.py index 935a9c1f3b9..3bad72cb1a1 100644 --- a/src/otx/algo/modules/transformer.py +++ b/src/otx/algo/modules/transformer.py @@ -137,7 +137,7 @@ class PatchEmbed(BaseModule): Default: "corner". dilation (int): The dilation rate of embedding conv. Default: 1. bias (bool): Bias of embed conv. Default: True. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to None. input_size (int | tuple | None): The size of input, which will be used to calculate the out size. Only works when `dynamic_size` @@ -155,7 +155,7 @@ def __init__( padding: str | int | tuple[int, int] = "corner", dilation: int | tuple[int, int] = 1, bias: bool = True, - norm_callable: Callable[..., nn.Module] | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, input_size: int | tuple[int, int] | None = None, init_cfg: dict | None = None, ): @@ -194,8 +194,8 @@ def __init__( ) self.norm: nn.Module | None - if norm_callable is not None: - self.norm = build_norm_layer(norm_callable, embed_dims)[1] + if normalization_callable is not None: + self.norm = build_norm_layer(normalization_callable, embed_dims)[1] else: self.norm = None diff --git a/src/otx/algo/segmentation/backbones/litehrnet.py b/src/otx/algo/segmentation/backbones/litehrnet.py index ded71591cd4..97e854af66f 100644 --- a/src/otx/algo/segmentation/backbones/litehrnet.py +++ b/src/otx/algo/segmentation/backbones/litehrnet.py @@ -36,7 +36,7 @@ class NeighbourSupport(nn.Module): kernel_size (int): Kernel size for convolutional layers. Default is 3. key_ratio (int): Ratio of input channels to key channels. Default is 8. value_ratio (int): Ratio of input channels to value channels. Default is 8. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to None. 
""" @@ -46,7 +46,7 @@ def __init__( kernel_size: int = 3, key_ratio: int = 8, value_ratio: int = 8, - norm_callable: Callable[..., nn.Module] | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, ) -> None: super().__init__() @@ -61,7 +61,7 @@ def __init__( out_channels=self.key_channels, kernel_size=1, stride=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=self.key_channels), activation_callable=nn.ReLU, ), Conv2dModule( @@ -71,7 +71,7 @@ def __init__( stride=1, padding=(self.kernel_size - 1) // 2, groups=self.key_channels, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=self.key_channels), activation_callable=None, ), Conv2dModule( @@ -79,7 +79,10 @@ def __init__( out_channels=self.kernel_size * self.kernel_size, kernel_size=1, stride=1, - norm_callable=norm_callable, + normalization=build_norm_layer( + normalization_callable, + num_features=self.kernel_size * self.kernel_size, + ), activation_callable=None, ), ) @@ -89,7 +92,7 @@ def __init__( out_channels=self.value_channels, kernel_size=1, stride=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=self.value_channels), activation_callable=None, ), nn.Unfold(kernel_size=self.kernel_size, stride=1, padding=1), @@ -99,7 +102,7 @@ def __init__( out_channels=self.in_channels, kernel_size=1, stride=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=self.in_channels), activation_callable=None, ) @@ -123,7 +126,7 @@ class CrossResolutionWeighting(nn.Module): Args: channels (list[int]): Number of channels for each stage. ratio (int): Reduction ratio of the bottleneck block. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to None. activation_callable (Callable[..., nn.Module] | tuple[Callable[..., nn.Module], Callable[..., nn.Module]]): \ Activation layer module or a tuple of activation layer modules. @@ -134,7 +137,7 @@ def __init__( self, channels: list[int], ratio: int = 16, - norm_callable: Callable[..., nn.Module] | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, activation_callable: Callable[..., nn.Module] | tuple[Callable[..., nn.Module], Callable[..., nn.Module]] = ( nn.ReLU, nn.Sigmoid, @@ -157,7 +160,7 @@ def __init__( out_channels=int(total_channel / ratio), kernel_size=1, stride=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=int(total_channel / ratio)), activation_callable=activation_callable[0], ) self.conv2 = Conv2dModule( @@ -165,7 +168,7 @@ def __init__( out_channels=total_channel, kernel_size=1, stride=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=total_channel), activation_callable=activation_callable[1], ) @@ -250,7 +253,7 @@ class SpatialWeightingV2(nn.Module): Args: channels (int): Number of input channels. ratio (int): Reduction ratio of internal channels. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to None. enable_norm (bool): Whether to enable normalization layers. 
""" @@ -259,7 +262,7 @@ def __init__( self, channels: int, ratio: int = 16, - norm_callable: Callable[..., nn.Module] | None = None, + normalization_callable: Callable[..., nn.Module] | None = None, enable_norm: bool = False, ) -> None: super().__init__() @@ -274,7 +277,9 @@ def __init__( kernel_size=1, stride=1, bias=False, - norm_callable=norm_callable if enable_norm else None, + normalization=build_norm_layer(normalization_callable, num_features=self.internal_channels) + if enable_norm + else None, activation_callable=None, ) self.q_channel = Conv2dModule( @@ -283,7 +288,7 @@ def __init__( kernel_size=1, stride=1, bias=False, - norm_callable=norm_callable if enable_norm else None, + normalization=build_norm_layer(normalization_callable, num_features=1) if enable_norm else None, activation_callable=None, ) self.out_channel = Conv2dModule( @@ -291,7 +296,7 @@ def __init__( out_channels=self.in_channels, kernel_size=1, stride=1, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=self.in_channels), activation_callable=nn.Sigmoid, ) @@ -302,7 +307,9 @@ def __init__( kernel_size=1, stride=1, bias=False, - norm_callable=norm_callable if enable_norm else None, + normalization=build_norm_layer(normalization_callable, num_features=self.internal_channels) + if enable_norm + else None, activation_callable=None, ) self.q_spatial = Conv2dModule( @@ -311,7 +318,9 @@ def __init__( kernel_size=1, stride=1, bias=False, - norm_callable=norm_callable if enable_norm else None, + normalization=build_norm_layer(normalization_callable, num_features=self.internal_channels) + if enable_norm + else None, activation_callable=None, ) self.global_avgpool = nn.AdaptiveAvgPool2d(1) @@ -378,7 +387,7 @@ class ConditionalChannelWeighting(nn.Module): in_channels (list[int]): Number of input channels for each input feature map. stride (int): Stride used in the first convolutional layer. reduce_ratio (int): Reduction ratio used in the cross-resolution weighting module. - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``nn.BatchNorm2d``. with_cp (bool): Whether to use checkpointing to save memory. dropout (float | None): Dropout probability used in the depthwise convolutional layers. 
@@ -395,7 +404,7 @@ def __init__( in_channels: list[int], stride: int, reduce_ratio: int, - norm_callable: Callable[..., nn.Module] = nn.BatchNorm2d, + normalization_callable: Callable[..., nn.Module] = nn.BatchNorm2d, with_cp: bool = False, dropout: float | None = None, weighting_module_version: str = "v1", @@ -416,7 +425,7 @@ def __init__( self.cross_resolution_weighting = CrossResolutionWeighting( branch_channels, ratio=reduce_ratio, - norm_callable=norm_callable, + normalization_callable=normalization_callable, ) self.depthwise_convs = nn.ModuleList( [ @@ -427,7 +436,7 @@ def __init__( stride=self.stride, padding=dw_ksize // 2, groups=channel, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=channel), activation_callable=None, ) for channel in branch_channels @@ -438,7 +447,7 @@ def __init__( spatial_weighting_module( # type: ignore[call-arg] channels=channel, ratio=4, - norm_callable=norm_callable, + normalization_callable=normalization_callable, enable_norm=True, ) for channel in branch_channels @@ -454,7 +463,7 @@ def __init__( kernel_size=3, key_ratio=8, value_ratio=4, - norm_callable=norm_callable, + normalization_callable=normalization_callable, ) for channel in branch_channels ], @@ -505,7 +514,7 @@ class Stem(nn.Module): stem_channels (int): Number of output channels of the stem layer. out_channels (int): Number of output channels of the backbone network. expand_ratio (int): Expansion ratio of the internal channels. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to ``nn.BatchNorm2d``. with_cp (bool): Use checkpointing to save memory during forward pass. num_stages (int): Number of stages in the backbone network. 
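The LiteHRNet hunks make the caller responsible for sizing every norm layer: ``num_features`` has to match the out_channels of the conv it follows, and where normalization is optional the caller passes None instead (the ``... if enable_norm else None`` pattern in SpatialWeightingV2 above). A small helper sketch of that pairing rule; the helper and names below are illustrative, not part of otx:

    from torch import nn


    def build_norm_layer(normalization_callable, num_features):
        # stand-in returning a (name, module) pair
        return "bn", normalization_callable(num_features)


    def conv_with_norm(in_channels, out_channels, normalization_callable=nn.BatchNorm2d, enable_norm=True, **conv_kwargs):
        # The pre-assigned norm must be sized with the conv's *output* channels;
        # when normalization is disabled the norm slot simply stays empty, matching
        # the `... if enable_norm else None` call sites above.
        conv = nn.Conv2d(in_channels, out_channels, bias=not enable_norm, **conv_kwargs)
        norm = build_norm_layer(normalization_callable, out_channels)[1] if enable_norm else nn.Identity()
        return nn.Sequential(conv, norm)


    layer = conv_with_norm(40, 80, kernel_size=3, stride=2, padding=1, groups=40)
    assert layer[1].num_features == layer[0].out_channels  # both 80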
@@ -524,7 +533,7 @@ def __init__( stem_channels: int, out_channels: int, expand_ratio: int, - norm_callable: Callable[..., nn.Module] = nn.BatchNorm2d, + normalization_callable: Callable[..., nn.Module] = nn.BatchNorm2d, with_cp: bool = False, strides: tuple[int, int] = (2, 2), extra_stride: bool = False, @@ -542,7 +551,7 @@ def __init__( self.in_channels = in_channels self.out_channels = out_channels - self.norm_callable = norm_callable + self.normalization_callable = normalization_callable self.with_cp = with_cp self.input_norm = None @@ -555,7 +564,7 @@ def __init__( kernel_size=3, stride=strides[0], padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=stem_channels), activation_callable=nn.ReLU, ) @@ -567,7 +576,7 @@ def __init__( kernel_size=3, stride=2, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=stem_channels), activation_callable=nn.ReLU, ) @@ -586,7 +595,7 @@ def __init__( stride=strides[1], padding=1, groups=branch_channels, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=branch_channels), activation_callable=None, ), Conv2dModule( @@ -595,7 +604,7 @@ def __init__( kernel_size=1, stride=1, padding=0, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=inc_channels), activation_callable=nn.ReLU, ), ) @@ -606,7 +615,7 @@ def __init__( kernel_size=1, stride=1, padding=0, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=mid_channels), activation_callable=nn.ReLU, ) self.depthwise_conv = Conv2dModule( @@ -616,7 +625,7 @@ def __init__( stride=strides[1], padding=1, groups=mid_channels, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=mid_channels), activation_callable=None, ) self.linear_conv = Conv2dModule( @@ -625,7 +634,10 @@ def __init__( kernel_size=1, stride=1, padding=0, - norm_callable=norm_callable, + normalization=build_norm_layer( + normalization_callable, + num_features=branch_channels if stem_channels == self.out_channels else stem_channels, + ), activation_callable=nn.ReLU, ) @@ -670,7 +682,7 @@ class StemV2(nn.Module): stem_channels (int): Number of output channels of the stem layer. out_channels (int): Number of output channels of the backbone network. expand_ratio (int): Expansion ratio of the internal channels. - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``nn.BatchNorm2d``. with_cp (bool): Use checkpointing to save memory during forward pass. num_stages (int): Number of stages in the backbone network. 
@@ -690,7 +702,7 @@ def __init__( stem_channels: int, out_channels: int, expand_ratio: int, - norm_callable: Callable[..., nn.Module] = nn.BatchNorm2d, + normalization_callable: Callable[..., nn.Module] = nn.BatchNorm2d, with_cp: bool = False, num_stages: int = 1, strides: tuple[int, int] = (2, 2), @@ -713,7 +725,7 @@ def __init__( self.in_channels = in_channels self.out_channels = out_channels - self.norm_callable = norm_callable + self.normalization_callable = normalization_callable self.with_cp = with_cp self.num_stages = num_stages @@ -727,7 +739,7 @@ def __init__( kernel_size=3, stride=strides[0], padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=stem_channels), activation_callable=nn.ReLU, ) @@ -739,7 +751,7 @@ def __init__( kernel_size=3, stride=2, padding=1, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=stem_channels), activation_callable=nn.ReLU, ) @@ -758,7 +770,7 @@ def __init__( stride=strides[stage], padding=1, groups=internal_branch_channels, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=internal_branch_channels), activation_callable=None, ), Conv2dModule( @@ -767,7 +779,10 @@ def __init__( kernel_size=1, stride=1, padding=0, - norm_callable=norm_callable, + normalization=build_norm_layer( + normalization_callable, + num_features=out_branch_channels if stage == num_stages else internal_branch_channels, + ), activation_callable=nn.ReLU, ), ), @@ -781,7 +796,7 @@ def __init__( kernel_size=1, stride=1, padding=0, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=mid_channels), activation_callable=nn.ReLU, ), Conv2dModule( @@ -791,7 +806,7 @@ def __init__( stride=strides[stage], padding=1, groups=mid_channels, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=mid_channels), activation_callable=None, ), Conv2dModule( @@ -800,7 +815,10 @@ def __init__( kernel_size=1, stride=1, padding=0, - norm_callable=norm_callable, + normalization=build_norm_layer( + normalization_callable, + num_features=out_branch_channels if stage == num_stages else internal_branch_channels, + ), activation_callable=nn.ReLU, ), ), @@ -847,7 +865,7 @@ class ShuffleUnit(nn.Module): in_channels (int): The input channels of the block. out_channels (int): The output channels of the block. stride (int): Stride of the 3x3 convolution layer. Default: 1 - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``nn.BatchNorm2d``. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``nn.ReLU``. 
@@ -860,7 +878,7 @@ def __init__( in_channels: int, out_channels: int, stride: int = 1, - norm_callable: Callable[..., nn.Module] = nn.BatchNorm2d, + normalization_callable: Callable[..., nn.Module] = nn.BatchNorm2d, activation_callable: Callable[..., nn.Module] = nn.ReLU, with_cp: bool = False, ) -> None: @@ -887,7 +905,7 @@ def __init__( stride=self.stride, padding=1, groups=in_channels, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=in_channels), activation_callable=None, ), Conv2dModule( @@ -896,7 +914,7 @@ def __init__( kernel_size=1, stride=1, padding=0, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=branch_features), activation_callable=activation_callable, ), ) @@ -908,7 +926,7 @@ def __init__( kernel_size=1, stride=1, padding=0, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=branch_features), activation_callable=activation_callable, ), Conv2dModule( @@ -918,7 +936,7 @@ def __init__( stride=self.stride, padding=1, groups=branch_features, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=branch_features), activation_callable=None, ), Conv2dModule( @@ -927,7 +945,7 @@ def __init__( kernel_size=1, stride=1, padding=0, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=branch_features), activation_callable=activation_callable, ), ) @@ -958,7 +976,7 @@ class LiteHRModule(nn.Module): module_type (str): Type of module to use for the network. Can be "LITE" or "NAIVE". multiscale_output (bool, optional): Whether to output features from all branches. Defaults to False. with_fuse (bool, optional): Whether to use the fuse layer. Defaults to True. - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``nn.BatchNorm2d``. with_cp (bool, optional): Whether to use checkpointing. Defaults to False. dropout (float, optional): Dropout rate. Defaults to None. 
@@ -975,7 +993,7 @@ def __init__( module_type: str, multiscale_output: bool = False, with_fuse: bool = True, - norm_callable: Callable[..., nn.Module] = nn.BatchNorm2d, + normalization_callable: Callable[..., nn.Module] = nn.BatchNorm2d, with_cp: bool = False, dropout: float | None = None, weighting_module_version: str = "v1", @@ -991,7 +1009,7 @@ def __init__( self.module_type = module_type self.multiscale_output = multiscale_output self.with_fuse = with_fuse - self.norm_callable = norm_callable + self.normalization_callable = normalization_callable self.with_cp = with_cp self.weighting_module_version = weighting_module_version self.neighbour_weighting = neighbour_weighting @@ -1024,7 +1042,7 @@ def _make_weighting_blocks( self.in_channels, stride=stride, reduce_ratio=reduce_ratio, - norm_callable=self.norm_callable, + normalization_callable=self.normalization_callable, with_cp=self.with_cp, dropout=dropout, weighting_module_version=self.weighting_module_version, @@ -1042,7 +1060,7 @@ def _make_one_branch(self, branch_index: int, num_blocks: int, stride: int = 1) self.in_channels[branch_index], self.in_channels[branch_index], stride=stride, - norm_callable=self.norm_callable, + normalization_callable=self.normalization_callable, activation_callable=nn.ReLU, with_cp=self.with_cp, ), @@ -1051,7 +1069,7 @@ def _make_one_branch(self, branch_index: int, num_blocks: int, stride: int = 1) self.in_channels[branch_index], self.in_channels[branch_index], stride=1, - norm_callable=self.norm_callable, + normalization_callable=self.normalization_callable, activation_callable=nn.ReLU, with_cp=self.with_cp, ) @@ -1089,7 +1107,7 @@ def _make_fuse_layers(self) -> nn.ModuleList: padding=0, bias=False, ), - build_norm_layer(self.norm_callable, in_channels[i])[1], + build_norm_layer(self.normalization_callable, in_channels[i])[1], ), ) elif j == i: @@ -1109,7 +1127,7 @@ def _make_fuse_layers(self) -> nn.ModuleList: groups=in_channels[j], bias=False, ), - build_norm_layer(self.norm_callable, in_channels[j])[1], + build_norm_layer(self.normalization_callable, in_channels[j])[1], nn.Conv2d( in_channels[j], in_channels[i], @@ -1118,7 +1136,7 @@ def _make_fuse_layers(self) -> nn.ModuleList: padding=0, bias=False, ), - build_norm_layer(self.norm_callable, in_channels[i])[1], + build_norm_layer(self.normalization_callable, in_channels[i])[1], ), ) else: @@ -1133,7 +1151,7 @@ def _make_fuse_layers(self) -> nn.ModuleList: groups=in_channels[j], bias=False, ), - build_norm_layer(self.norm_callable, in_channels[j])[1], + build_norm_layer(self.normalization_callable, in_channels[j])[1], nn.Conv2d( in_channels[j], in_channels[j], @@ -1142,7 +1160,7 @@ def _make_fuse_layers(self) -> nn.ModuleList: padding=0, bias=False, ), - build_norm_layer(self.norm_callable, in_channels[j])[1], + build_norm_layer(self.normalization_callable, in_channels[j])[1], nn.ReLU(inplace=True), ), ) @@ -1192,7 +1210,7 @@ class LiteHRNet(BaseModule): Args: extra (dict): detailed configuration for each stage of HRNet. in_channels (int): Number of input image channels. Default: 3. - norm_callable (Callable[..., nn.Module]): Normalization layer module. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. Defaults to ``nn.BatchNorm2d``. norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). 
Note: Effect on Batch Norm @@ -1207,7 +1225,7 @@ def __init__( self, extra: dict, in_channels: int = 3, - norm_callable: Callable[..., nn.Module] = nn.BatchNorm2d, + normalization_callable: Callable[..., nn.Module] = nn.BatchNorm2d, norm_eval: bool = False, with_cp: bool = False, zero_init_residual: bool = False, @@ -1219,7 +1237,7 @@ def __init__( super().__init__(init_cfg=init_cfg) self.extra = extra - self.norm_callable = norm_callable + self.normalization_callable = normalization_callable self.norm_eval = norm_eval self.with_cp = with_cp self.zero_init_residual = zero_init_residual @@ -1231,7 +1249,7 @@ def __init__( expand_ratio=self.extra["stem"]["expand_ratio"], strides=self.extra["stem"]["strides"], extra_stride=self.extra["stem"]["extra_stride"], - norm_callable=self.norm_callable, + normalization_callable=self.normalization_callable, ) self.enable_stem_pool = self.extra["stem"].get("out_pool", False) @@ -1276,7 +1294,7 @@ def __init__( kernel_size=1, stride=1, padding=0, - norm_callable=self.norm_callable, + normalization=build_norm_layer(self.normalization_callable, num_features=out_modules_channels), activation_callable=nn.ReLU, ), ) @@ -1288,14 +1306,14 @@ def __init__( key_channels=self.extra["out_modules"]["position_att"]["key_channels"], value_channels=self.extra["out_modules"]["position_att"]["value_channels"], psp_size=self.extra["out_modules"]["position_att"]["psp_size"], - norm_callable=self.norm_callable, + normalization_callable=self.normalization_callable, ), ) if self.extra["out_modules"]["local_att"]["enable"]: out_modules.append( LocalAttentionModule( num_channels=in_modules_channels, - norm_callable=self.norm_callable, + normalization_callable=self.normalization_callable, ), ) @@ -1313,7 +1331,7 @@ def __init__( stride=1, padding=1, groups=self.stem.out_channels, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=self.stem.out_channels), activation_callable=None, ), Conv2dModule( @@ -1322,7 +1340,7 @@ def __init__( kernel_size=1, stride=1, padding=0, - norm_callable=norm_callable, + normalization=build_norm_layer(normalization_callable, num_features=num_channels_last[0]), activation_callable=nn.ReLU, ), ) @@ -1334,7 +1352,7 @@ def __init__( self.aggregator = IterativeAggregator( in_channels=num_channels_last, min_channels=self.extra["out_aggregator"].get("min_channels", None), - norm_callable=self.norm_callable, + normalization_callable=self.normalization_callable, ) if pretrained_weights is not None: @@ -1364,7 +1382,7 @@ def _make_transition_layer( groups=num_channels_pre_layer[i], bias=False, ), - build_norm_layer(self.norm_callable, num_channels_pre_layer[i])[1], + build_norm_layer(self.normalization_callable, num_channels_pre_layer[i])[1], nn.Conv2d( num_channels_pre_layer[i], num_channels_cur_layer[i], @@ -1373,7 +1391,7 @@ def _make_transition_layer( padding=0, bias=False, ), - build_norm_layer(self.norm_callable, num_channels_cur_layer[i])[1], + build_norm_layer(self.normalization_callable, num_channels_cur_layer[i])[1], nn.ReLU(), ), ) @@ -1395,7 +1413,7 @@ def _make_transition_layer( groups=in_channels, bias=False, ), - build_norm_layer(self.norm_callable, in_channels)[1], + build_norm_layer(self.normalization_callable, in_channels)[1], nn.Conv2d( in_channels, out_channels, @@ -1404,7 +1422,7 @@ def _make_transition_layer( padding=0, bias=False, ), - build_norm_layer(self.norm_callable, out_channels)[1], + build_norm_layer(self.normalization_callable, out_channels)[1], nn.ReLU(), ), ) @@ -1455,7 +1473,7 
@@ def _make_stage( module_type, multiscale_output=reset_multiscale_output, with_fuse=with_fuse, - norm_callable=self.norm_callable, + normalization_callable=self.normalization_callable, with_cp=self.with_cp, dropout=dropout, weighting_module_version=weighting_module_version, diff --git a/src/otx/algo/segmentation/backbones/mscan.py b/src/otx/algo/segmentation/backbones/mscan.py index 21a005b07b8..0d07d961dab 100644 --- a/src/otx/algo/segmentation/backbones/mscan.py +++ b/src/otx/algo/segmentation/backbones/mscan.py @@ -108,8 +108,8 @@ class StemConv(BaseModule): out_channels (int): The dimension of output channels. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``nn.GELU``. - norm_callable (Callable[..., nn.Module]): Normalization layer module. - Defaults to ``partial(SyncBatchNorm, requires_grad=True)``. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. + Defaults to ``partial(build_norm_layer, SyncBatchNorm, requires_grad=True)``. """ def __init__( @@ -117,15 +117,15 @@ def __init__( in_channels: int, out_channels: int, activation_callable: Callable[..., nn.Module] = nn.GELU, - norm_callable: Callable[..., nn.Module] = partial(SyncBatchNorm, requires_grad=True), + normalization_callable: Callable[..., nn.Module] = partial(build_norm_layer, SyncBatchNorm, requires_grad=True), ) -> None: super().__init__() self.proj = nn.Sequential( nn.Conv2d(in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), - build_norm_layer(norm_callable, out_channels // 2)[1], + build_norm_layer(normalization_callable, num_features=out_channels // 2)[1], activation_callable(), nn.Conv2d(out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)), - build_norm_layer(norm_callable, out_channels)[1], + build_norm_layer(normalization_callable, num_features=out_channels)[1], ) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]: @@ -242,8 +242,8 @@ class MSCABlock(BaseModule): drop_path (float): The dropout rate for the path. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``nn.GELU``. - norm_callable (Callable[..., nn.Module]): Normalization layer module. - Defaults to ``partial(SyncBatchNorm, requires_grad=True)``. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. + Defaults to ``partial(build_norm_layer, SyncBatchNorm, requires_grad=True)``. 
""" def __init__( @@ -255,11 +255,11 @@ def __init__( drop: float = 0.0, drop_path: float = 0.0, activation_callable: Callable[..., nn.Module] = nn.GELU, - norm_callable: Callable[..., nn.Module] = partial(SyncBatchNorm, requires_grad=True), + normalization_callable: Callable[..., nn.Module] = partial(build_norm_layer, SyncBatchNorm, requires_grad=True), ) -> None: """Initialize a MSCABlock.""" super().__init__() - self.norm1 = build_norm_layer(norm_callable, channels)[1] # type: nn.Module + self.norm1 = build_norm_layer(normalization_callable, num_features=channels)[1] # type: nn.Module self.attn = MSCASpatialAttention( channels, attention_kernel_sizes, @@ -267,7 +267,7 @@ def __init__( activation_callable, ) # type: MSCAAttention self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() # type: nn.Module - self.norm2 = build_norm_layer(norm_callable, channels)[1] # type: nn.Module + self.norm2 = build_norm_layer(normalization_callable, num_features=channels)[1] # type: nn.Module mlp_hidden_channels = int(channels * mlp_ratio) # type: int self.mlp = Mlp( in_features=channels, @@ -296,8 +296,8 @@ class OverlapPatchEmbed(BaseModule): stride (int, optional): Stride of the convolutional layer. Defaults to 4. in_channels (int, optional): The number of input channels. Defaults to 3. embed_dim (int, optional): The dimensions of embedding. Defaults to 768. - norm_callable (Callable[..., nn.Module]): Normalization layer module. - Defaults to ``partial(SyncBatchNorm, requires_grad=True)``. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. + Defaults to ``partial(build_norm_layer, SyncBatchNorm, requires_grad=True)``. """ def __init__( @@ -306,12 +306,12 @@ def __init__( stride: int = 4, in_channels: int = 3, embed_dim: int = 768, - norm_callable: Callable[..., nn.Module] = partial(SyncBatchNorm, requires_grad=True), + normalization_callable: Callable[..., nn.Module] = partial(build_norm_layer, SyncBatchNorm, requires_grad=True), ): """Initializes the OverlapPatchEmbed module.""" super().__init__() self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride, padding=patch_size // 2) - self.norm = build_norm_layer(norm_callable, embed_dim)[1] + self.norm = build_norm_layer(normalization_callable, num_features=embed_dim)[1] def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, int, int]: """Forward function.""" @@ -346,8 +346,8 @@ class MSCAN(BaseModule): in Attention Module (Figure 2(b) of original paper). Defaults to [2, [0, 3], [0, 5], [0, 10]]. activation_callable (Callable[..., nn.Module]): Activation layer module. Defaults to ``nn.GELU``. - norm_callable (Callable[..., nn.Module]): Normalization layer module. - Defaults to ``partial(SyncBatchNorm, requires_grad=True)``. + normalization_callable (Callable[..., nn.Module]): Normalization layer module. + Defaults to ``partial(build_norm_layer, SyncBatchNorm, requires_grad=True)``. init_cfg (Optional[Union[Dict[str, str], List[Dict[str, str]]]]): Initialization config dict. Defaults to None. 
""" @@ -364,7 +364,7 @@ def __init__( attention_kernel_sizes: list[int | list[int]] = [5, [1, 7], [1, 11], [1, 21]], # noqa: B006 attention_kernel_paddings: list[int | list[int]] = [2, [0, 3], [0, 5], [0, 10]], # noqa: B006 activation_callable: Callable[..., nn.Module] = nn.GELU, - norm_callable: Callable[..., nn.Module] = partial(SyncBatchNorm, requires_grad=True), + normalization_callable: Callable[..., nn.Module] = partial(build_norm_layer, SyncBatchNorm, requires_grad=True), init_cfg: dict[str, str] | list[dict[str, str]] | None = None, pretrained_weights: str | None = None, ) -> None: @@ -378,14 +378,14 @@ def __init__( for i in range(num_stages): if i == 0: - patch_embed = StemConv(in_channels, embed_dims[0], norm_callable=norm_callable) + patch_embed = StemConv(in_channels, embed_dims[0], normalization_callable=normalization_callable) else: patch_embed = OverlapPatchEmbed( patch_size=7 if i == 0 else 3, stride=4 if i == 0 else 2, in_channels=in_channels if i == 0 else embed_dims[i - 1], embed_dim=embed_dims[i], - norm_callable=norm_callable, + normalization_callable=normalization_callable, ) block = nn.ModuleList( [ @@ -397,7 +397,7 @@ def __init__( drop=drop_rate, drop_path=dpr[cur + j], activation_callable=activation_callable, - norm_callable=norm_callable, + normalization_callable=normalization_callable, ) for j in range(depths[i]) ], diff --git a/src/otx/algo/segmentation/dino_v2_seg.py b/src/otx/algo/segmentation/dino_v2_seg.py index 60fed311664..c26ff365858 100644 --- a/src/otx/algo/segmentation/dino_v2_seg.py +++ b/src/otx/algo/segmentation/dino_v2_seg.py @@ -10,6 +10,7 @@ from torch.nn import SyncBatchNorm +from otx.algo.modules.norm import build_norm_layer from otx.algo.segmentation.backbones import DinoVisionTransformer from otx.algo.segmentation.heads import FCNHead from otx.core.model.segmentation import TorchVisionCompatibleModel @@ -30,7 +31,7 @@ class DinoV2Seg(BaseSegmModel): "out_index": [8, 9, 10, 11], } default_decode_head_configuration: ClassVar[dict[str, Any]] = { - "norm_callable": partial(SyncBatchNorm, requires_grad=True), + "normalization_callable": partial(build_norm_layer, SyncBatchNorm, requires_grad=True), "in_channels": [384, 384, 384, 384], "in_index": [0, 1, 2, 3], "input_transform": "resize_concat", diff --git a/src/otx/algo/segmentation/heads/base_segm_head.py b/src/otx/algo/segmentation/heads/base_segm_head.py index 0b7360d4a5e..9e8fb4a9818 100644 --- a/src/otx/algo/segmentation/heads/base_segm_head.py +++ b/src/otx/algo/segmentation/heads/base_segm_head.py @@ -24,7 +24,7 @@ class BaseSegmHead(nn.Module, metaclass=ABCMeta): channels (int): Number of channels in the feature map. num_classes (int): Number of classes for segmentation. dropout_ratio (float, optional): The dropout ratio. Defaults to 0.1. - norm_callable (Callable[..., nn.Module] | None): Normalization layer module. + normalization_callable (Callable[..., nn.Module] | None): Normalization layer module. Defaults to None. activation_callable (Callable[..., nn.Module] | None): Activation layer module. Defaults to ``nn.ReLU``. 
@@ -41,7 +41,7 @@ def __init__(
         channels: int,
         num_classes: int,
         dropout_ratio: float = 0.1,
-        norm_callable: Callable[..., nn.Module] | None = None,
+        normalization_callable: Callable[..., nn.Module] | None = None,
         activation_callable: Callable[..., nn.Module] | None = nn.ReLU,
         in_index: int | list[int] = -1,
         input_transform: str | None = None,
@@ -55,7 +55,7 @@ def __init__(
         self.num_classes = num_classes
         self.input_transform = input_transform
         self.dropout_ratio = dropout_ratio
-        self.norm_callable = norm_callable
+        self.normalization_callable = normalization_callable
         self.activation_callable = activation_callable
         if self.input_transform is not None and not isinstance(in_index, list):
             msg = f'"in_index" expects a list, but got {type(in_index)}'
diff --git a/src/otx/algo/segmentation/heads/fcn_head.py b/src/otx/algo/segmentation/heads/fcn_head.py
index 9bc0331738f..3fabbe58632 100644
--- a/src/otx/algo/segmentation/heads/fcn_head.py
+++ b/src/otx/algo/segmentation/heads/fcn_head.py
@@ -11,6 +11,7 @@
 from torch import Tensor, nn

 from otx.algo.modules import Conv2dModule
+from otx.algo.modules.norm import build_norm_layer
 from otx.algo.segmentation.modules import IterativeAggregator

 from .base_segm_head import BaseSegmHead
@@ -22,7 +23,7 @@ class FCNHead(BaseSegmHead):
     This head is implemented of `FCNNet `_.

     Args:
-        norm_callable (Callable[..., nn.Module] | None): Normalization layer module.
+        normalization_callable (Callable[..., nn.Module] | None): Normalization layer module.
             Defaults to None.
         num_convs (int): Number of convs in the head. Default: 2.
         kernel_size (int): The kernel size for convs in the head. Default: 3.
@@ -35,7 +36,7 @@ def __init__(
         self,
         in_channels: list[int] | int,
         in_index: list[int] | int,
-        norm_callable: Callable[..., nn.Module] | None = None,
+        normalization_callable: Callable[..., nn.Module] | None = None,
         input_transform: str | None = None,
         num_convs: int = 2,
         kernel_size: int = 3,
@@ -66,7 +67,7 @@ def __init__(
             aggregator = IterativeAggregator(
                 in_channels=in_channels,
                 min_channels=aggregator_min_channels,
-                norm_callable=norm_callable,
+                normalization_callable=normalization_callable,
                 merge_norm=aggregator_merge_norm,
                 use_concat=aggregator_use_concat,
             )
@@ -82,7 +83,7 @@ def __init__(

         super().__init__(
             in_index=in_index,
-            norm_callable=norm_callable,
+            normalization_callable=normalization_callable,
             input_transform=input_transform,
             in_channels=in_channels,
             **kwargs,
@@ -102,7 +103,7 @@ def __init__(
                 kernel_size=kernel_size,
                 padding=conv_padding,
                 dilation=dilation,
-                norm_callable=self.norm_callable,
+                normalization=build_norm_layer(self.normalization_callable, num_features=self.channels),
                 activation_callable=self.activation_callable,
             ),
         ]
@@ -114,7 +115,7 @@ def __init__(
                     kernel_size=kernel_size,
                     padding=conv_padding,
                     dilation=dilation,
-                    norm_callable=self.norm_callable,
+                    normalization=build_norm_layer(self.normalization_callable, num_features=self.channels),
                     activation_callable=self.activation_callable,
                 )
                 for _ in range(num_convs - 1)
@@ -130,7 +131,7 @@ def __init__(
                 self.channels,
                 kernel_size=kernel_size,
                 padding=kernel_size // 2,
-                norm_callable=self.norm_callable,
+                normalization=build_norm_layer(self.normalization_callable, num_features=self.channels),
                 activation_callable=self.activation_callable,
             )

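Aside (illustrative, not part of the patch): in heads such as FCNHead above, the stored self.normalization_callable is now resolved once per convolution, so each Conv2dModule owns its own normalization instance instead of sharing one module (and its parameters and running statistics) across layers. A minimal sketch of that loop; the helper name and channel count are invented for illustration.

    from torch import nn

    from otx.algo.modules import Conv2dModule
    from otx.algo.modules.norm import build_norm_layer

    def make_convs(channels: int, num_convs: int, normalization_callable) -> nn.ModuleList:
        # Rebuild the norm layer for every conv, as the FCNHead hunks above do.
        return nn.ModuleList(
            [
                Conv2dModule(
                    channels,
                    channels,
                    kernel_size=3,
                    padding=1,
                    normalization=build_norm_layer(normalization_callable, num_features=channels),
                    activation_callable=nn.ReLU,
                )
                for _ in range(num_convs)
            ],
        )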
diff --git a/src/otx/algo/segmentation/heads/ham_head.py b/src/otx/algo/segmentation/heads/ham_head.py
index 3bd563d16df..db4d072f176 100644
--- a/src/otx/algo/segmentation/heads/ham_head.py
+++ b/src/otx/algo/segmentation/heads/ham_head.py
@@ -12,6 +12,7 @@
 from torch import nn

 from otx.algo.modules import Conv2dModule
+from otx.algo.modules.norm import build_norm_layer
 from otx.algo.segmentation.modules import resize

 from .base_segm_head import BaseSegmHead
@@ -26,7 +27,7 @@ class Hamburger(nn.Module):
     Args:
         ham_channels (int): Input and output channels of feature.
         ham_kwargs (dict): Config of matrix decomposition module.
-        norm_callable (Callable[..., nn.Module] | None): Normalization layer module.
+        normalization_callable (Callable[..., nn.Module] | None): Normalization layer module.
             Defaults to None.
     """

@@ -34,13 +35,13 @@ def __init__(
         self,
         ham_channels: int,
         ham_kwargs: dict[str, Any],
-        norm_callable: Callable[..., nn.Module] | None = None,
+        normalization_callable: Callable[..., nn.Module] | None = None,
         **kwargs: Any,  # noqa: ANN401
     ) -> None:
         """Initialize Hamburger Module."""
         super().__init__()

-        self.ham_in = Conv2dModule(ham_channels, ham_channels, 1, norm_callable=None, activation_callable=None)
+        self.ham_in = Conv2dModule(ham_channels, ham_channels, 1, normalization=None, activation_callable=None)

         self.ham = NMF2D(ham_channels=ham_channels, **ham_kwargs)

@@ -48,7 +49,7 @@ def __init__(
             ham_channels,
             ham_channels,
             1,
-            norm_callable=norm_callable,
+            normalization=build_norm_layer(normalization_callable, num_features=ham_channels),
             activation_callable=None,
         )

@@ -102,7 +103,7 @@ def __init__(
             sum(self.in_channels),
             self.ham_channels,
             1,
-            norm_callable=self.norm_callable,
+            normalization=build_norm_layer(self.normalization_callable, num_features=self.ham_channels),
            activation_callable=self.activation_callable,
         )
@@ -112,7 +113,7 @@ def __init__(
             self.ham_channels,
             self.channels,
             1,
-            norm_callable=self.norm_callable,
+            normalization=build_norm_layer(self.normalization_callable, num_features=self.channels),
             activation_callable=self.activation_callable,
         )
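Aside (illustrative, not part of the patch): as the Hamburger hunks above show, a conv block without normalization is now requested explicitly with normalization=None instead of norm_callable=None, while normalized blocks receive a layer built for their channel count. A minimal sketch with an example channel count:

    from torch import nn

    from otx.algo.modules import Conv2dModule
    from otx.algo.modules.norm import build_norm_layer

    ham_channels = 512  # example value only

    # No normalization and no activation: both are passed explicitly as None.
    ham_in = Conv2dModule(ham_channels, ham_channels, 1, normalization=None, activation_callable=None)

    # With normalization: the layer is pre-built for the known channel count.
    ham_out = Conv2dModule(
        ham_channels,
        ham_channels,
        1,
        normalization=build_norm_layer(nn.BatchNorm2d, num_features=ham_channels),
        activation_callable=None,
    )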
diff --git a/src/otx/algo/segmentation/litehrnet.py b/src/otx/algo/segmentation/litehrnet.py
index 3d84cd174fa..11e4fe3e9bf 100644
--- a/src/otx/algo/segmentation/litehrnet.py
+++ b/src/otx/algo/segmentation/litehrnet.py
@@ -11,6 +11,7 @@
 from torch import nn
 from torch.onnx import OperatorExportTypes

+from otx.algo.modules.norm import build_norm_layer
 from otx.algo.segmentation.backbones import LiteHRNet
 from otx.algo.segmentation.heads import FCNHead
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -25,7 +26,7 @@ class LiteHRNetS(BaseSegmModel):
     """LiteHRNetS Model."""

     default_backbone_configuration: ClassVar[dict[str, Any]] = {
-        "norm_callable": partial(nn.BatchNorm2d, requires_grad=True),
+        "normalization_callable": partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
         "norm_eval": False,
         "extra": {
             "stem": {
@@ -57,7 +58,7 @@ class LiteHRNetS(BaseSegmModel):
         "pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnetsv2_imagenet1k_rsc.pth",
     }
     default_decode_head_configuration: ClassVar[dict[str, Any]] = {
-        "norm_callable": partial(nn.BatchNorm2d, requires_grad=True),
+        "normalization_callable": partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
         "in_channels": [60, 120, 240],
         "in_index": [0, 1, 2],
         "input_transform": "multiple_select",
@@ -179,7 +180,7 @@ class LiteHRNet18(BaseSegmModel):
         "pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnet18_imagenet1k_rsc.pth",
     }
     default_decode_head_configuration: ClassVar[dict[str, Any]] = {
-        "norm_callable": partial(nn.BatchNorm2d, requires_grad=True),
+        "normalization_callable": partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
         "in_channels": [40, 80, 160, 320],
         "in_index": [0, 1, 2, 3],
         "input_transform": "multiple_select",
@@ -289,7 +290,7 @@ class LiteHRNetX(BaseSegmModel):
     """LiteHRNetX Model."""

     default_backbone_configuration: ClassVar[dict[str, Any]] = {
-        "norm_callable": partial(nn.BatchNorm2d, requires_grad=True),
+        "normalization_callable": partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
         "norm_eval": False,
         "extra": {
             "stem": {
@@ -322,7 +323,7 @@ class LiteHRNetX(BaseSegmModel):
         "pretrained_weights": "https://storage.openvinotoolkit.org/repositories/openvino_training_extensions/models/custom_semantic_segmentation/litehrnetxv3_imagenet1k_rsc.pth",
     }
     default_decode_head_configuration: ClassVar[dict[str, Any]] = {
-        "norm_callable": partial(nn.BatchNorm2d, requires_grad=True),
+        "normalization_callable": partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
         "in_channels": [18, 60, 80, 160, 320],
         "in_index": [0, 1, 2, 3, 4],
         "input_transform": "multiple_select",
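Aside (illustrative, not part of the patch): default model configurations such as the LiteHRNet ones above no longer store a bare norm class under "norm_callable"; they store a partial of build_norm_layer, so layer options (here requires_grad) travel with the callable while num_features is still bound later at each call site. A hedged sketch of how such an entry is consumed; the config key is taken from the hunks above, the channel count is an example only.

    from functools import partial

    from torch import nn

    from otx.algo.modules.norm import build_norm_layer

    default_backbone_configuration = {
        # Stored in the config exactly as in the hunks above.
        "normalization_callable": partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
    }

    # Later, inside the backbone, the channel count becomes known and the concrete
    # layer is built, mirroring the num_features keyword used throughout this patch.
    normalization_callable = default_backbone_configuration["normalization_callable"]
    stem_norm = build_norm_layer(normalization_callable, num_features=64)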
diff --git a/src/otx/algo/segmentation/modules/aggregators.py b/src/otx/algo/segmentation/modules/aggregators.py
index 52b606a26e8..22ea47ecfd9 100644
--- a/src/otx/algo/segmentation/modules/aggregators.py
+++ b/src/otx/algo/segmentation/modules/aggregators.py
@@ -12,6 +12,7 @@
 from torch.nn import functional as f

 from otx.algo.modules import Conv2dModule, DepthwiseSeparableConvModule
+from otx.algo.modules.norm import build_norm_layer

 from .utils import normalize
@@ -24,7 +25,7 @@ class IterativeAggregator(nn.Module):
     Args:
         in_channels (list[int]): List of input channels for each branch.
         min_channels (int | None): Minimum number of channels. Defaults to None.
-        norm_callable (Callable[..., nn.Module] | None): Normalization layer module.
+        normalization_callable (Callable[..., nn.Module] | None): Normalization layer module.
            Defaults to ``nn.BatchNorm2d``.
         merge_norm (str | None): Whether to merge normalization layers. Defaults to None.
         use_concat (bool): Whether to use concatenation. Defaults to False.
@@ -34,7 +35,7 @@ def __init__(
         self,
         in_channels: list[int],
         min_channels: int | None = None,
-        norm_callable: Callable[..., nn.Module] | None = nn.BatchNorm2d,
+        normalization_callable: Callable[..., nn.Module] | None = nn.BatchNorm2d,
         merge_norm: str | None = None,
         use_concat: bool = False,
     ) -> None:
@@ -63,7 +64,7 @@ def __init__(
                     out_channels=out_channels,
                     kernel_size=1,
                     stride=1,
-                    norm_callable=norm_callable,
+                    normalization=build_norm_layer(normalization_callable, num_features=out_channels),
                     activation_callable=nn.ReLU,
                 ),
             )
@@ -80,7 +81,7 @@ def __init__(
                     kernel_size=3,
                     stride=1,
                     padding=1,
-                    norm_callable=norm_callable,
+                    normalization=build_norm_layer(normalization_callable, num_features=out_channels),
                     activation_callable=nn.ReLU,
                     dw_activation_callable=None,
                     pw_activation_callable=nn.ReLU,
@@ -94,7 +95,7 @@ def __init__(
                     out_channels=min_channels,
                     kernel_size=1,
                     stride=1,
-                    norm_callable=norm_callable,
+                    normalization=build_norm_layer(normalization_callable, num_features=min_channels),
                     activation_callable=nn.ReLU,
                 ),
             )
diff --git a/src/otx/algo/segmentation/modules/blocks.py b/src/otx/algo/segmentation/modules/blocks.py
index 06439ecf928..cb16ca47409 100644
--- a/src/otx/algo/segmentation/modules/blocks.py
+++ b/src/otx/algo/segmentation/modules/blocks.py
@@ -13,6 +13,7 @@
 from torch.nn import AdaptiveAvgPool2d, AdaptiveMaxPool2d

 from otx.algo.modules import Conv2dModule
+from otx.algo.modules.norm import build_norm_layer


 class PSPModule(nn.Module):
@@ -54,7 +55,7 @@ def __init__(
         key_channels: int,
         value_channels: int | None = None,
         psp_size: tuple | None = None,
-        norm_callable: Callable[..., nn.Module] = nn.BatchNorm2d,
+        normalization_callable: Callable[..., nn.Module] = nn.BatchNorm2d,
     ):
         super().__init__()

@@ -63,14 +64,14 @@ def __init__(
         self.value_channels = value_channels if value_channels is not None else in_channels
         if psp_size is None:
             psp_size = (1, 3, 6, 8)
-        self.norm_callable = norm_callable
+        self.normalization_callable = normalization_callable
         self.query_key = Conv2dModule(
             in_channels=self.in_channels,
             out_channels=self.key_channels,
             kernel_size=1,
             stride=1,
             padding=0,
-            norm_callable=self.norm_callable,
+            normalization=build_norm_layer(self.normalization_callable, num_features=self.key_channels),
             activation_callable=nn.ReLU,
         )
         self.key_psp = PSPModule(psp_size, method="max")
@@ -81,7 +82,7 @@ def __init__(
             kernel_size=1,
             stride=1,
             padding=0,
-            norm_callable=self.norm_callable,
+            normalization=build_norm_layer(self.normalization_callable, num_features=self.value_channels),
             activation_callable=nn.ReLU,
         )
         self.value_psp = PSPModule(psp_size, method="max")
@@ -92,7 +93,7 @@ def __init__(
             kernel_size=1,
             stride=1,
             padding=0,
-            norm_callable=self.norm_callable,
+            normalization=build_norm_layer(self.normalization_callable, num_features=self.in_channels),
             activation_callable=None,
         )
@@ -156,12 +157,12 @@ class LocalAttentionModule(nn.Module):
     def __init__(
         self,
         num_channels: int,
-        norm_callable: Callable[..., nn.Module] = nn.BatchNorm2d,
+        normalization_callable: Callable[..., nn.Module] = nn.BatchNorm2d,
     ):
         super().__init__()

         self.num_channels = num_channels
-        self.norm_callable = norm_callable
+        self.normalization_callable = normalization_callable

         self.dwconv1 = Conv2dModule(
             in_channels=self.num_channels,
@@ -170,7 +171,7 @@ def __init__(
             stride=2,
             padding=1,
             groups=self.num_channels,
-            norm_callable=self.norm_callable,
+            normalization=build_norm_layer(self.normalization_callable, num_features=self.num_channels),
             activation_callable=nn.ReLU,
         )
         self.dwconv2 = Conv2dModule(
@@ -180,7 +181,7 @@ def __init__(
             stride=2,
             padding=1,
             groups=self.num_channels,
-            norm_callable=self.norm_callable,
+            normalization=build_norm_layer(self.normalization_callable, num_features=self.num_channels),
             activation_callable=nn.ReLU,
         )
         self.dwconv3 = Conv2dModule(
@@ -190,7 +191,7 @@ def __init__(
             stride=2,
             padding=1,
             groups=self.num_channels,
-            norm_callable=self.norm_callable,
+            normalization=build_norm_layer(self.normalization_callable, num_features=self.num_channels),
             activation_callable=nn.ReLU,
         )
         self.sigmoid_spatial = nn.Sigmoid()
diff --git a/src/otx/algo/segmentation/segnext.py b/src/otx/algo/segmentation/segnext.py
index fb92ef29b21..f4a0a92e7b5 100644
--- a/src/otx/algo/segmentation/segnext.py
+++ b/src/otx/algo/segmentation/segnext.py
@@ -10,6 +10,7 @@
 from torch import nn

+from otx.algo.modules.norm import build_norm_layer
 from otx.algo.segmentation.backbones import MSCAN
 from otx.algo.segmentation.heads import LightHamHead
 from otx.algo.utils.support_otx_v1 import OTXv1Helper
@@ -30,14 +31,14 @@ class SegNextB(BaseSegmModel):
         "drop_rate": 0.0,
         "embed_dims": [64, 128, 320, 512],
         "mlp_ratios": [8, 8, 4, 4],
-        "norm_callable": partial(nn.BatchNorm2d, requires_grad=True),
+        "normalization_callable": partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
         "pretrained_weights": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_b_20230227-3ab7d230.pth",
     }
     default_decode_head_configuration: ClassVar[dict[str, Any]] = {
         "ham_kwargs": {"md_r": 16, "md_s": 1, "eval_steps": 7, "train_steps": 6},
         "in_channels": [128, 320, 512],
         "in_index": [1, 2, 3],
-        "norm_callable": partial(nn.GroupNorm, num_groups=32, requires_grad=True),
+        "normalization_callable": partial(build_norm_layer, nn.GroupNorm, num_groups=32, requires_grad=True),
         "align_corners": False,
         "channels": 512,
         "dropout_ratio": 0.1,
@@ -57,11 +58,11 @@ class SegNextS(BaseSegmModel):
         "drop_rate": 0.0,
         "embed_dims": [64, 128, 320, 512],
         "mlp_ratios": [8, 8, 4, 4],
-        "norm_callable": partial(nn.BatchNorm2d, requires_grad=True),
+        "normalization_callable": partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
         "pretrained_weights": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_s_20230227-f33ccdf2.pth",
     }
     default_decode_head_configuration: ClassVar[dict[str, Any]] = {
-        "norm_callable": partial(nn.GroupNorm, num_groups=32, requires_grad=True),
+        "normalization_callable": partial(build_norm_layer, nn.GroupNorm, num_groups=32, requires_grad=True),
         "ham_kwargs": {"md_r": 16, "md_s": 1, "eval_steps": 7, "rand_init": True, "train_steps": 6},
         "in_channels": [128, 320, 512],
         "in_index": [1, 2, 3],
@@ -84,12 +85,12 @@ class SegNextT(BaseSegmModel):
         "drop_rate": 0.0,
         "embed_dims": [32, 64, 160, 256],
         "mlp_ratios": [8, 8, 4, 4],
-        "norm_callable": partial(nn.BatchNorm2d, requires_grad=True),
+        "normalization_callable": partial(build_norm_layer, nn.BatchNorm2d, requires_grad=True),
         "pretrained_weights": "https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/segnext/mscan_t_20230227-119e8c9f.pth",
     }
     default_decode_head_configuration: ClassVar[dict[str, Any]] = {
         "ham_kwargs": {"md_r": 16, "md_s": 1, "eval_steps": 7, "rand_init": True, "train_steps": 6},
-        "norm_callable": partial(nn.GroupNorm, num_groups=32, requires_grad=True),
+        "normalization_callable": partial(build_norm_layer, nn.GroupNorm, num_groups=32, requires_grad=True),
         "in_channels": [64, 160, 256],
         "in_index": [1, 2, 3],
         "align_corners": False,
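Aside (illustrative, not part of the patch): the SegNext decode-head entries above keep GroupNorm but move num_groups into the partial, so the group count is fixed by the config while the channel count is still supplied at build time. How build_norm_layer maps num_features onto GroupNorm's own constructor arguments is internal to that helper and is assumed here rather than shown in this diff.

    from functools import partial

    from torch import nn

    from otx.algo.modules.norm import build_norm_layer

    # num_groups rides inside the partial, exactly as in the configs above.
    normalization_callable = partial(build_norm_layer, nn.GroupNorm, num_groups=32, requires_grad=True)

    # The decode head later binds its channel count (e.g. its channels=512 setting).
    norm = build_norm_layer(normalization_callable, num_features=512)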
diff --git a/tests/unit/algo/detection/backbones/test_presnet.py b/tests/unit/algo/detection/backbones/test_presnet.py
index 3cbdb6a5609..27cbb32a641 100644
--- a/tests/unit/algo/detection/backbones/test_presnet.py
+++ b/tests/unit/algo/detection/backbones/test_presnet.py
@@ -3,9 +3,12 @@
 #
 """Test of Presnet."""

+from functools import partial
+
 import torch
 from otx.algo.detection.backbones.presnet import PResNet
 from otx.algo.modules import FrozenBatchNorm2d
+from otx.algo.modules.norm import build_norm_layer


 class TestPresnet:
@@ -26,7 +29,10 @@ def test_presnet_freeze_parameters(self):
             assert not param.requires_grad

     def test_presnet_freeze_norm(self):
-        model = PResNet(depth=50, norm_callable=FrozenBatchNorm2d, norm_name="norm")
+        model = PResNet(
+            depth=50,
+            normalization_callable=partial(build_norm_layer, FrozenBatchNorm2d, layer_name="norm"),
+        )
         for name, param in model.named_parameters():
             if "norm" in name:
                 assert isinstance(param, FrozenBatchNorm2d)
diff --git a/tests/unit/algo/modules/test_conv_module.py b/tests/unit/algo/modules/test_conv_module.py
index b5ff496b863..5f32efbea7e 100644
--- a/tests/unit/algo/modules/test_conv_module.py
+++ b/tests/unit/algo/modules/test_conv_module.py
@@ -8,9 +8,8 @@
 import pytest
 import torch
 from otx.algo.modules.conv_module import Conv2dModule, DepthwiseSeparableConvModule
-from torch import nn
-
 from otx.algo.modules.norm import build_norm_layer
+from torch import nn


 def test_conv_module_with_unsupported_activation():
diff --git a/tests/unit/algo/modules/test_transformer.py b/tests/unit/algo/modules/test_transformer.py
index 6898eb0b9f6..8c13549804d 100644
--- a/tests/unit/algo/modules/test_transformer.py
+++ b/tests/unit/algo/modules/test_transformer.py
@@ -88,7 +88,7 @@ def test_patch_embed():
         stride=stride,
         padding=0,
         dilation=1,
-        norm_callable=None,
+        normalization_callable=None,
     )

     x1, shape = patch_merge_1(dummy_input)
@@ -115,7 +115,7 @@ def test_patch_embed():
         stride=stride,
         padding=0,
         dilation=2,
-        norm_callable=None,
+        normalization_callable=None,
     )

     x2, shape = patch_merge_2(dummy_input)
@@ -138,7 +138,7 @@ def test_patch_embed():
         stride=stride,
         padding=0,
         dilation=2,
-        norm_callable=nn.LayerNorm,
+        normalization_callable=nn.LayerNorm,
         input_size=input_size,
     )

@@ -165,7 +165,7 @@ def test_patch_embed():
         stride=stride,
         padding=0,
         dilation=2,
-        norm_callable=nn.LayerNorm,
+        normalization_callable=nn.LayerNorm,
         input_size=input_size,
     )

@@ -184,7 +184,7 @@ def test_patch_embed():
         stride=stride,
         padding=0,
         dilation=2,
-        norm_callable=nn.LayerNorm,
+        normalization_callable=nn.LayerNorm,
         input_size=input_size,
     )

diff --git a/tests/unit/algo/segmentation/heads/test_ham_head.py b/tests/unit/algo/segmentation/heads/test_ham_head.py
index bee93378a71..67835746139 100644
--- a/tests/unit/algo/segmentation/heads/test_ham_head.py
+++ b/tests/unit/algo/segmentation/heads/test_ham_head.py
@@ -5,6 +5,7 @@
 import pytest
 import torch
+from otx.algo.modules.norm import build_norm_layer
 from otx.algo.segmentation.heads.ham_head import LightHamHead
 from torch import nn

@@ -16,7 +17,7 @@ def head_config(self) -> dict[str, Any]:
             "ham_kwargs": {"md_r": 16, "md_s": 1, "eval_steps": 7, "train_steps": 6},
             "in_channels": [128, 320, 512],
             "in_index": [1, 2, 3],
-            "norm_callable": partial(nn.GroupNorm, num_groups=32, requires_grad=True),
+            "normalization_callable": partial(build_norm_layer, nn.GroupNorm, num_groups=32, requires_grad=True),
             "align_corners": False,
             "channels": 512,
             "dropout_ratio": 0.1,
diff --git a/tests/unit/algo/segmentation/modules/test_blokcs.py b/tests/unit/algo/segmentation/modules/test_blokcs.py
index dbd8190a1f0..e752da0fc03 100644
--- a/tests/unit/algo/segmentation/modules/test_blokcs.py
+++ b/tests/unit/algo/segmentation/modules/test_blokcs.py
@@ -16,7 +16,7 @@ def init_cfg(self) -> dict[str, Any]:
             "key_channels": 128,
             "value_channels": 320,
             "psp_size": [1, 3, 6, 8],
-            "norm_callable": nn.BatchNorm2d,
+            "normalization_callable": nn.BatchNorm2d,
         }

     def test_init(self, init_cfg):
@@ -25,7 +25,7 @@ def test_init(self, init_cfg):
         assert module.in_channels == init_cfg["in_channels"]
         assert module.key_channels == init_cfg["key_channels"]
         assert module.value_channels == init_cfg["value_channels"]
-        assert module.norm_callable == init_cfg["norm_callable"]
+        assert module.normalization_callable == init_cfg["normalization_callable"]

     @pytest.fixture()
     def fake_input(self) -> torch.Tensor:
@@ -43,14 +43,14 @@ class TestLocalAttentionModule:
     def init_cfg(self) -> dict[str, Any]:
         return {
             "num_channels": 320,
-            "norm_callable": nn.BatchNorm2d,
+            "normalization_callable": nn.BatchNorm2d,
         }

     def test_init(self, init_cfg):
         module = LocalAttentionModule(**init_cfg)

         assert module.num_channels == init_cfg["num_channels"]
-        assert module.norm_callable == init_cfg["norm_callable"]
+        assert module.normalization_callable == init_cfg["normalization_callable"]

     @pytest.fixture()
     def fake_input(self) -> torch.Tensor:
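Migration note for downstream callers (illustrative, not part of the patch): the test updates above suggest three recurring changes when moving to this API. The layer_name keyword is copied from the updated PResNet test; its exact semantics are not shown in this diff and are assumed to replace the old norm_name argument. SomeHead below is a hypothetical caller used only for contrast.

    from functools import partial

    from otx.algo.detection.backbones.presnet import PResNet
    from otx.algo.modules import FrozenBatchNorm2d
    from otx.algo.modules.norm import build_norm_layer

    # 1. Keyword rename where a callable is still accepted:
    #        SomeHead(..., norm_callable=nn.BatchNorm2d)           <- before
    #        SomeHead(..., normalization_callable=nn.BatchNorm2d)  <- after
    #
    # 2. ConvModule-style blocks take a pre-built layer instead of a callable:
    #        Conv2dModule(..., norm_callable=nn.BatchNorm2d)                                   <- before
    #        Conv2dModule(..., normalization=build_norm_layer(nn.BatchNorm2d, num_features=C)) <- after
    #
    # 3. Extra norm options move into the partial, as in the PResNet test above:
    model = PResNet(
        depth=50,
        normalization_callable=partial(build_norm_layer, FrozenBatchNorm2d, layer_name="norm"),
    )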