Feature/sg 849 add replace in channels #1557

Merged
merged 37 commits into master from feature/SG-849-add_replace_in_channels on Oct 29, 2023
Changes from all commits (37 commits)
1085517
wip
Louis-Dupont Oct 19, 2023
906b351
first draft
Louis-Dupont Oct 21, 2023
356fb18
wip
Louis-Dupont Oct 21, 2023
0c9c4d4
remove self.in_channels
Louis-Dupont Oct 21, 2023
2f2c43b
add get_input_channels to all SgModel
Louis-Dupont Oct 21, 2023
0cb05a9
ADD TESTS
Louis-Dupont Oct 21, 2023
3be5c64
SupportsReplaceInChannels -> SupportsReplaceInputChannels
Louis-Dupont Oct 21, 2023
8d654eb
remove unwanted comment
Louis-Dupont Oct 21, 2023
79e861e
Merge branch 'master' into feature/SG-849-add_replace_in_channels
Louis-Dupont Oct 21, 2023
975b0ed
replace_in_channels -> replace_input_channels
Louis-Dupont Oct 21, 2023
6d43aae
remove unwanted comment
Louis-Dupont Oct 21, 2023
87ce5a2
add docstring
Louis-Dupont Oct 22, 2023
622c61e
rename replace_input_channels_with_random_weights -> replace_conv2d_i…
Louis-Dupont Oct 23, 2023
6964249
Merge branch 'master' into feature/SG-849-add_replace_in_channels
Louis-Dupont Oct 23, 2023
db9cec1
fix
Louis-Dupont Oct 23, 2023
cd0b67e
Merge branch 'master' into hotfix/SG-000-fix_csp_darknet53_forward
Louis-Dupont Oct 23, 2023
4afccd0
update test to also run foward
Louis-Dupont Oct 23, 2023
b5144b9
Merge branch 'master' into feature/SG-849-add_replace_in_channels
Louis-Dupont Oct 23, 2023
e74b8be
Merge branch 'hotfix/SG-000-fix_csp_darknet53_forward' into feature/S…
Louis-Dupont Oct 23, 2023
d54f268
add pretrained test
Louis-Dupont Oct 23, 2023
37ab77b
add minor docstring
Louis-Dupont Oct 23, 2023
c386a30
Merge branch 'master' into feature/SG-849-add_replace_in_channels
Louis-Dupont Oct 23, 2023
3d114af
use existing channels when replacing
Louis-Dupont Oct 23, 2023
ab760e8
set self.in_channels in replace_input_channels
Louis-Dupont Oct 24, 2023
287b28f
add num_input_channels in models.get
Louis-Dupont Oct 24, 2023
12170e7
Merge branch 'master' into feature/SG-849-add_replace_in_channels
Louis-Dupont Oct 24, 2023
c56d553
automatically set self.input_channels when calling replace_input_chan…
Louis-Dupont Oct 24, 2023
b075062
Merge branch 'master' into feature/SG-849-add_replace_in_channels
Louis-Dupont Oct 25, 2023
cccecb6
Merge branch 'master' into feature/SG-849-add_replace_in_channels
Louis-Dupont Oct 25, 2023
9318806
Merge branch 'master' into feature/SG-849-add_replace_in_channels
Louis-Dupont Oct 26, 2023
bf64cad
remove #TODO
Louis-Dupont Oct 26, 2023
9d86c3e
add to train_from_recipe
Louis-Dupont Oct 26, 2023
0a38118
add to
Louis-Dupont Oct 26, 2023
a5e3c86
add to kd
Louis-Dupont Oct 26, 2023
01ed612
Merge branch 'master' into feature/SG-849-add_replace_in_channels
Louis-Dupont Oct 29, 2023
503129f
make inherit from
Louis-Dupont Oct 29, 2023
df11873
Merge branch 'master' into feature/SG-849-add_replace_in_channels
Louis-Dupont Oct 29, 2023
33 changes: 33 additions & 0 deletions src/super_gradients/common/deprecate.py
@@ -135,3 +135,36 @@ def wrapper(*args, **training_params):
return wrapper

return decorator


def deprecate_param(
deprecated_param_name: str,
new_param_name: str = "",
deprecated_since: str = "",
removed_from: str = "",
reason: str = "",
):
"""
Utility function to warn about a deprecated parameter (or dictionary key).

:param deprecated_param_name: Name of the deprecated parameter.
:param new_param_name: Name of the new parameter/key that should replace the deprecated one.
:param deprecated_since: Version number when the parameter was deprecated.
:param removed_from: Version number when the parameter will be removed.
:param reason: Additional information or reason for the deprecation.
"""
is_still_supported = deprecated_since < removed_from
status_msg = "is deprecated" if is_still_supported else "was deprecated and has been removed"
message = f"Parameter `{deprecated_param_name}` {status_msg} " f"since version `{deprecated_since}` and will be removed in version `{removed_from}`.\n"

if reason:
message += f"Reason: {reason}.\n"

if new_param_name:
message += f"Please update your code to use the `{new_param_name}` instead of `{deprecated_param_name}`."

if is_still_supported:
warnings.simplefilter("once", DeprecationWarning) # Required, otherwise the warning may never be displayed.
warnings.warn(message, DeprecationWarning)
else:
raise ValueError(message)
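The new `deprecate_param` helper can be used wherever a parameter or dictionary key is renamed. A minimal usage sketch (the key names, version numbers, and `normalize_arch_params` function below are illustrative, not taken from the PR):

from super_gradients.common.deprecate import deprecate_param


def normalize_arch_params(arch_params: dict) -> dict:
    # Accept the old key for now, but warn that it was renamed (illustrative keys/versions).
    if "in_channels" in arch_params:
        deprecate_param(
            deprecated_param_name="in_channels",
            new_param_name="num_input_channels",
            deprecated_since="3.3.0",
            removed_from="3.5.0",
            reason="the parameter was renamed for consistency",
        )
        arch_params["num_input_channels"] = arch_params.pop("in_channels")
    return arch_params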
3 changes: 2 additions & 1 deletion src/super_gradients/module_interfaces/__init__.py
@@ -1,4 +1,4 @@
from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses
from .module_interfaces import HasPredict, HasPreprocessingParams, SupportsReplaceNumClasses, SupportsReplaceInputChannels
from .exceptions import ModelHasNoPreprocessingParamsException
from .exportable_detector import ExportableObjectDetectionModel, AbstractObjectDetectionDecodingModule
from .exportable_pose_estimation import ExportablePoseEstimationModel, PoseEstimationModelExportResult, AbstractPoseEstimationDecodingModule
@@ -8,6 +8,7 @@
"HasPredict",
"HasPreprocessingParams",
"SupportsReplaceNumClasses",
"SupportsReplaceInputChannels",
"ExportableObjectDetectionModel",
"AbstractObjectDetectionDecodingModule",
"ModelHasNoPreprocessingParamsException",
32 changes: 32 additions & 0 deletions src/super_gradients/module_interfaces/module_interfaces.py
@@ -1,3 +1,4 @@
from abc import ABC
from typing import Callable, Optional, TYPE_CHECKING

from torch import nn
@@ -69,3 +70,34 @@ def replace_num_classes(self, num_classes: int, compute_new_weights_fn: Callable
:return: None
"""
raise NotImplementedError(f"replace_num_classes is not implemented in the derived class {self.__class__.__name__}")


class SupportsReplaceInputChannels(ABC):
"""
Protocol interface for modules that support replacing the number of input channels.
Derived classes should implement the `replace_input_channels` method.

This interface class explicitly indicates whether a class supports optimized input channel replacement:

>>> class InputLayer(nn.Module, SupportsReplaceInputChannels):
>>> def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Callable[[nn.Module, int], nn.Module] = None):
>>> ...

"""

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]]):
"""
Replace the number of input channels in the module.

:param in_channels: New number of input channels.
:param compute_new_weights_fn: (Optional) function that computes the new weights for the new input channels.
It takes the existing nn.Module and returns a new one.
"""
raise NotImplementedError(f"`replace_input_channels` is not implemented in the derived class `{self.__class__.__name__}`")

def get_input_channels(self) -> int:
"""Get the number of input channels for the model.

:return: Number of input channels.
"""
raise NotImplementedError(f"`get_input_channels` is not implemented in the derived class `{self.__class__.__name__}`")
23 changes: 20 additions & 3 deletions src/super_gradients/modules/conv_bn_act_block.py
@@ -1,11 +1,12 @@
from typing import Union, Tuple, Type
from typing import Union, Tuple, Type, Callable, Optional

from torch import nn

from super_gradients.modules.utils import autopad
from super_gradients.module_interfaces import SupportsReplaceInputChannels


class ConvBNAct(nn.Module):
class ConvBNAct(nn.Module, SupportsReplaceInputChannels):
"""
Class for Convolution2d-Batchnorm2d-Activation layer.
Default behaviour is Conv-BN-Act. To exclude Batchnorm module use
@@ -67,8 +68,16 @@ def __init__(
def forward(self, x):
return self.seq(x)

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
from super_gradients.modules.weight_replacement_utils import replace_conv2d_input_channels

self.seq[0] = replace_conv2d_input_channels(conv=self.seq[0], in_channels=in_channels, fn=compute_new_weights_fn)

def get_input_channels(self) -> int:
return self.seq[0].in_channels


class Conv(nn.Module):
class Conv(nn.Module, SupportsReplaceInputChannels):
# STANDARD CONVOLUTION
# TODO: This class is illegally similar to ConvBNAct, and the only reason it exists is that some models were using it
# previously and one has to find a bulletproof way to drop this class but still be able to load models that were using it. Perhaps
@@ -85,3 +94,11 @@ def forward(self, x):

def fuseforward(self, x):
return self.act(self.conv(x))

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
from super_gradients.modules.weight_replacement_utils import replace_conv2d_input_channels

self.conv = replace_conv2d_input_channels(conv=self.conv, in_channels=in_channels, fn=compute_new_weights_fn)

def get_input_channels(self) -> int:
return self.conv.in_channels
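As a usage sketch, any block exposing these two methods can be re-wired to a different number of input channels and still run a forward pass. Reusing the illustrative `TinyStem` from the sketch above:

import torch

stem = TinyStem(in_channels=3)
stem.replace_input_channels(in_channels=1)  # e.g. switch from RGB to grayscale input

assert stem.get_input_channels() == 1
out = stem(torch.randn(2, 1, 64, 64))  # forward pass works with the new channel count
print(out.shape)  # torch.Size([2, 16, 64, 64])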
31 changes: 26 additions & 5 deletions src/super_gradients/modules/detection_modules.py
@@ -1,16 +1,19 @@
from abc import ABC, abstractmethod
from typing import Union, List
from typing import Union, List, Optional, Callable

import torch
from torch import nn
from omegaconf import DictConfig
from omegaconf.listconfig import ListConfig

from super_gradients.common.registry.registry import register_detection_module
from super_gradients.modules.base_modules import BaseDetectionModule
from super_gradients.modules.multi_output_modules import MultiOutputModule
from super_gradients.training.models import MobileNet, MobileNetV2
from super_gradients.training.models.classification_models.mobilenetv2 import InvertedResidual
from super_gradients.training.utils.utils import HpmStruct
from torch import nn
from super_gradients.module_interfaces import SupportsReplaceInputChannels


__all__ = [
"PANNeck",
@@ -28,7 +31,7 @@


@register_detection_module()
class NStageBackbone(BaseDetectionModule):
class NStageBackbone(BaseDetectionModule, SupportsReplaceInputChannels):
"""
A backbone with a stem -> N stages -> context module
Returns outputs of the layers listed in out_layers
@@ -83,6 +86,18 @@ def forward(self, x):

return outputs

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
if isinstance(self.stem, SupportsReplaceInputChannels):
self.stem.replace_input_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn)
else:
raise NotImplementedError(f"`{self.stem.__class__.__name__}` does not support `replace_input_channels`")

def get_input_channels(self) -> int:
if isinstance(self.stem, SupportsReplaceInputChannels):
return self.stem.get_input_channels()
else:
raise NotImplementedError(f"`{self.stem.__class__.__name__}` does not support `get_input_channels`")


@register_detection_module()
class PANNeck(BaseDetectionModule):
@@ -176,14 +191,14 @@ def combine_preds(self, preds):
return outputs if self.training else (torch.cat(outputs, 1), outputs_logits)


class MultiOutputBackbone(BaseDetectionModule):
class MultiOutputBackbone(BaseDetectionModule, SupportsReplaceInputChannels):
"""
Defines a backbone using MultiOutputModule with the interface of BaseDetectionModule
"""

def __init__(self, in_channels: int, backbone: nn.Module, out_layers: List):
super().__init__(in_channels)
self.multi_output_backbone = MultiOutputModule(backbone, out_layers)
self.multi_output_backbone = MultiOutputModule(module=backbone, output_paths=out_layers)
self._out_channels = [x.shape[1] for x in self.forward(torch.empty((1, in_channels, 64, 64)))]

@property
@@ -193,6 +208,12 @@ def out_channels(self) -> Union[List[int], int]:
def forward(self, x):
return self.multi_output_backbone(x)

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
self.multi_output_backbone.replace_input_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn)

def get_input_channels(self) -> int:
return self.multi_output_backbone.get_input_channels()


@register_detection_module()
class MobileNetV1Backbone(MultiOutputBackbone):
19 changes: 18 additions & 1 deletion src/super_gradients/modules/multi_output_modules.py
@@ -1,9 +1,12 @@
from collections import OrderedDict
from typing import Optional, Callable
from torch import nn
from omegaconf.listconfig import ListConfig

from super_gradients.module_interfaces import SupportsReplaceInputChannels

class MultiOutputModule(nn.Module):

class MultiOutputModule(nn.Module, SupportsReplaceInputChannels):
"""
This module wraps around a container nn.Module (such as Module, Sequential and ModuleList) and allows to extract
multiple output from its inner modules on each forward call() (as a list of output tensors)
@@ -99,3 +102,17 @@ def _prune(self, module: nn.Module, output_paths: list):
def _slice_odict(self, odict: OrderedDict, start: int, end: int):
"""Slice an OrderedDict in the same logic list,tuple... are sliced"""
return OrderedDict([(k, v) for (k, v) in odict.items() if k in list(odict.keys())[start:end]])

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
module = self._modules["0"]
if isinstance(module, SupportsReplaceInputChannels):
module.replace_input_channels(in_channels=in_channels, compute_new_weights_fn=compute_new_weights_fn)
else:
raise NotImplementedError(f"`{module.__class__.__name__}` does not support `replace_input_channels`")

def get_input_channels(self) -> int:
module = self._modules["0"]
if isinstance(module, SupportsReplaceInputChannels):
return module.get_input_channels()
else:
raise NotImplementedError(f"`{module.__class__.__name__}` does not support `get_input_channels`")
65 changes: 65 additions & 0 deletions src/super_gradients/modules/weight_replacement_utils.py
@@ -0,0 +1,65 @@
from typing import Optional, Callable

import torch
from torch import nn

__all__ = ["replace_conv2d_input_channels", "replace_conv2d_input_channels_with_random_weights"]


def replace_conv2d_input_channels(conv: nn.Conv2d, in_channels: int, fn: Optional[Callable[[nn.Conv2d, int], nn.Conv2d]] = None) -> nn.Module:
"""Instantiate a new Conv2d module with same attributes as input Conv2d, except for the input channels.

:param conv: Conv2d to replace the input channels in.
:param in_channels: New number of input channels.
:param fn: (Optional) Function to instantiate the new Conv2d.
By default, it will initialize the new weights with the same mean and std as the original weights.
:return: Conv2d with new number of input channels.
"""
if fn:
return fn(conv, in_channels)
else:
return replace_conv2d_input_channels_with_random_weights(conv=conv, in_channels=in_channels)


def replace_conv2d_input_channels_with_random_weights(conv: nn.Conv2d, in_channels: int) -> nn.Conv2d:
"""
Replace the input channels in the input Conv2d with random weights.
Returned module will have the same device and dtype as the original module.
Random weights are initialized with the same mean and std as the original weights.

:param conv: Conv2d to replace the input channels in.
:param in_channels: New number of input channels.
:return: Conv2d with new number of input channels.
"""

if in_channels % conv.groups != 0:
raise ValueError(
f"Incompatible number of input channels ({in_channels}) with the number of groups ({conv.groups})."
f"The number of input channels must be divisible by the number of groups."
)

new_conv = nn.Conv2d(
in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
dilation=conv.dilation,
groups=conv.groups,
bias=conv.bias is not None,
device=conv.weight.device,
dtype=conv.weight.dtype,
)

if in_channels <= conv.in_channels:
new_conv.weight.data = conv.weight.data[:, :in_channels, ...]
else:
new_conv.weight.data[:, : conv.in_channels, ...] = conv.weight.data

# Pad the remaining channels with random weights
torch.nn.init.normal_(new_conv.weight.data[:, conv.in_channels :, ...], mean=conv.weight.mean().item(), std=conv.weight.std().item())

if conv.bias is not None:
torch.nn.init.normal_(new_conv.bias, mean=conv.bias.mean().item(), std=conv.bias.std().item())

return new_conv
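A short illustration of the default behaviour of these helpers on a plain `nn.Conv2d` (not part of the diff):

import torch
from torch import nn

from super_gradients.modules.weight_replacement_utils import replace_conv2d_input_channels

conv = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
new_conv = replace_conv2d_input_channels(conv=conv, in_channels=4)

assert new_conv.in_channels == 4 and new_conv.out_channels == 64
# The first 3 input channels keep the original weights; the extra channel is drawn
# from a normal distribution with the original weights' mean and std.
assert torch.equal(new_conv.weight[:, :3, ...], conv.weight)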
2 changes: 2 additions & 0 deletions src/super_gradients/training/kd_trainer/kd_trainer.py
@@ -77,6 +77,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
checkpoint_path=cfg.student_checkpoint_params.checkpoint_path,
load_backbone=cfg.student_checkpoint_params.load_backbone,
checkpoint_num_classes=get_param(cfg.student_checkpoint_params, "checkpoint_num_classes"),
num_input_channels=get_param(cfg.student_arch_params, "num_input_channels"),
)

teacher = models.get(
@@ -87,6 +88,7 @@ def train_from_config(cls, cfg: Union[DictConfig, dict]) -> None:
checkpoint_path=cfg.teacher_checkpoint_params.checkpoint_path,
load_backbone=cfg.teacher_checkpoint_params.load_backbone,
checkpoint_num_classes=get_param(cfg.teacher_checkpoint_params, "checkpoint_num_classes"),
num_input_channels=get_param(cfg.teacher_arch_params, "num_input_channels"),
)

recipe_logged_cfg = {"recipe_config": OmegaConf.to_container(cfg, resolve=True)}
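With this wiring in place, the number of input channels can be requested directly through `models.get`. A hedged sketch (the model name and exact keyword usage are illustrative; see the `models.get` changes in this PR for the authoritative signature):

from super_gradients.training import models

# Illustrative: build a model that accepts 4-channel input (e.g. RGB + depth).
model = models.get("resnet18", num_classes=10, num_input_channels=4)
assert model.get_input_channels() == 4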
@@ -20,7 +20,7 @@
# --------------------------------------------------------'
import math
from functools import partial
from typing import Optional, Tuple
from typing import Optional, Tuple, Callable

import torch
import torch.nn as nn
@@ -322,7 +322,9 @@ def __init__(
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.grad_checkpointing = False

self.patch_embed = PatchEmbed(img_size=image_size, patch_size=patch_size, in_channels=in_chans, hidden_dim=embed_dim)
self.image_size = image_size
self.patch_size = patch_size
self.patch_embed = PatchEmbed(img_size=self.image_size, patch_size=self.patch_size, in_channels=in_chans, hidden_dim=self.embed_dim)
num_patches = self.patch_embed.num_patches

self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
@@ -453,6 +455,12 @@ def replace_head(self, new_num_classes=None, new_head=None):
else:
self.head = nn.Linear(self.head.in_features, new_num_classes)

def replace_input_channels(self, in_channels: int, compute_new_weights_fn: Optional[Callable[[nn.Module, int], nn.Module]] = None):
self.patch_embed = PatchEmbed(img_size=self.image_size, patch_size=self.patch_size, in_channels=in_channels, hidden_dim=self.embed_dim)

def get_input_channels(self) -> int:
return self.patch_embed.get_input_channels()


@register_model(Models.BEIT_BASE_PATCH16_224)
class BeitBasePatch16_224(Beit):