From 6682748c888d8272f9a0ff2cbc31ccece6778551 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Jan 2022 14:18:47 +0000 Subject: [PATCH 01/17] Adding CNBlock and skeleton architecture --- torchvision/models/convnext.py | 105 +++++++++++++++++++++++++++++++++ 1 file changed, 105 insertions(+) create mode 100644 torchvision/models/convnext.py diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py new file mode 100644 index 00000000000..029fd6acae0 --- /dev/null +++ b/torchvision/models/convnext.py @@ -0,0 +1,105 @@ +import torch + +from functools import partial +from torch import nn, Tensor +from typing import Any, Callable, List, Optional, Sequence + +from ..ops.misc import ConvNormActivation +from ..ops.stochastic_depth import StochasticDepth +from ..utils import _log_api_usage_once + + +class CNBlock(nn.Module): + def __init__(self, dim, + stochastic_depth_prob: float, + norm_layer: Callable[..., nn.Module], + layer_scale: Optional[float] = 1e-6): + super().__init__() + self.block = nn.Sequential( + ConvNormActivation( + dim, + dim, + kernel_size=7, + groups=dim, + norm_layer=norm_layer, + activation_layer=None, + bias=True, # TODO: check + ), + ConvNormActivation( + dim, + 4 * dim, + kernel_size=1, + norm_layer=None, + activation_layer=nn.GELU, + ), + ConvNormActivation( + 4 * dim, + dim, + kernel_size=1, + norm_layer=None, + activation_layer=None, + ) + ) + self.layer_scale = nn.Parameter(torch.ones(dim) * layer_scale) + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + + def forward(self, input: Tensor) -> Tensor: + result = self.layer_scale * self.block(input) + result = self.stochastic_depth(result) + result += input + return result + + +class CNBlockConfig: + # Stores information listed at Section 3 of the ConvNeXt paper + def __init__( + self, + input_channels: int, + out_channels: int, + num_layers: int, + ) -> None: + self.input_channels = input_channels + self.out_channels = out_channels + self.num_layers = num_layers + + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "input_channels={input_channels}" + s += ", out_channels={out_channels}" + s += ", num_layers={num_layers}" + s += ")" + return s.format(**self.__dict__) + + + +class ConvNeXt(nn.Module): + def __init__( + self, + block_setting: List[CNBlockConfig], + stochastic_depth_prob: float = 0.0, + layer_scale: float = 1e-6, + num_classes: int = 1000, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + _log_api_usage_once(self) + + if not block_setting: + raise ValueError("The block_setting should not be empty") + elif not ( + isinstance(block_setting, Sequence) + and all([isinstance(s, CNBlockConfig) for s in block_setting]) + ): + raise TypeError("The block_setting should be List[CNBlockConfig]") + + if block is None: + block = CNBlock + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + layers: List[nn.Module] = [ + + ] From 6c49ef8551426c16d215257a097266eb56dc1235 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Jan 2022 20:51:43 +0000 Subject: [PATCH 02/17] Completed implementation. 
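
The core trick in this implementation is the channels-first LayerNorm: nn.LayerNorm
normalizes only the trailing dimension, so NCHW activations are permuted to NHWC,
normalized, and permuted back. A minimal standalone sketch of the idea (for
illustration only, not part of the patch):

```
import torch
from torch import nn

norm = nn.LayerNorm(96, eps=1e-6)  # normalizes over the last dimension
x = torch.randn(2, 96, 56, 56)     # NCHW feature map
y = norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)  # LayerNorm over channels
assert y.shape == x.shape
```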
--- torchvision/models/__init__.py | 1 + torchvision/models/convnext.py | 124 ++++++++++++++++++++++++++------- torchvision/ops/misc.py | 5 +- 3 files changed, 103 insertions(+), 27 deletions(-) diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 22e2e45d4ce..16495e8552e 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -1,4 +1,5 @@ from .alexnet import * +from .convnext import * from .resnet import * from .vgg import * from .squeezenet import * diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 029fd6acae0..ab69c0046c2 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -1,19 +1,36 @@ -import torch - from functools import partial -from torch import nn, Tensor from typing import Any, Callable, List, Optional, Sequence +import torch +from torch import nn, Tensor + from ..ops.misc import ConvNormActivation from ..ops.stochastic_depth import StochasticDepth from ..utils import _log_api_usage_once +__all__ = [ + "ConvNeXt", + "convnext_tiny", +] + + +class LayerNorm(nn.LayerNorm): + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.channels_last = kwargs.pop("channels_last", False) + super().__init__(*args, **kwargs) + + def forward(self, x): + if not self.channels_last: + x = x.permute(0, 2, 3, 1) + x = super().forward(x) + if not self.channels_last: + x = x.permute(0, 3, 1, 2) + return x + + class CNBlock(nn.Module): - def __init__(self, dim, - stochastic_depth_prob: float, - norm_layer: Callable[..., nn.Module], - layer_scale: Optional[float] = 1e-6): + def __init__(self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]): super().__init__() self.block = nn.Sequential( ConvNormActivation( @@ -25,22 +42,16 @@ def __init__(self, dim, activation_layer=None, bias=True, # TODO: check ), - ConvNormActivation( - dim, - 4 * dim, - kernel_size=1, - norm_layer=None, - activation_layer=nn.GELU, - ), + ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None), ConvNormActivation( 4 * dim, dim, kernel_size=1, norm_layer=None, activation_layer=None, - ) + ), ) - self.layer_scale = nn.Parameter(torch.ones(dim) * layer_scale) + self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale) self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") def forward(self, input: Tensor) -> Tensor: @@ -55,7 +66,7 @@ class CNBlockConfig: def __init__( self, input_channels: int, - out_channels: int, + out_channels: Optional[int], num_layers: int, ) -> None: self.input_channels = input_channels @@ -71,7 +82,6 @@ def __repr__(self) -> str: return s.format(**self.__dict__) - class ConvNeXt(nn.Module): def __init__( self, @@ -88,18 +98,82 @@ def __init__( if not block_setting: raise ValueError("The block_setting should not be empty") - elif not ( - isinstance(block_setting, Sequence) - and all([isinstance(s, CNBlockConfig) for s in block_setting]) - ): + elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])): raise TypeError("The block_setting should be List[CNBlockConfig]") if block is None: block = CNBlock if norm_layer is None: - norm_layer = partial(nn.LayerNorm, eps=1e-6) + norm_layer = partial(LayerNorm, eps=1e-6) + + layers: List[nn.Module] = [] + + firstconv_output_channels = block_setting[0].input_channels + layers.append( + ConvNormActivation( + 3, + firstconv_output_channels, + kernel_size=4, + stride=4, + 
padding=0, + norm_layer=norm_layer, + activation_layer=None, + bias=True, + ) + ) - layers: List[nn.Module] = [ + total_stage_blocks = sum(cnf.num_layers for cnf in block_setting) + stage_block_id = 0 + for cnf in block_setting: + stage: List[nn.Module] = [] + for _ in range(cnf.num_layers): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks + stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer)) + stage_block_id += 1 + layers.append(nn.Sequential(*stage)) + if cnf.out_channels is not None: + layers.append( + nn.Sequential( + norm_layer(cnf.input_channels), + nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2), + ) + ) + + self.features = nn.Sequential(*layers) + self.avgpool = nn.AdaptiveAvgPool2d(1) + + lastblock = block_setting[-1] + lastconv_output_channels = ( + lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels + ) + self.classifier = nn.Sequential( + norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes) + ) - ] + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _forward_impl(self, x: Tensor) -> Tensor: + x = self.features(x) + x = self.avgpool(x) + x = self.classifier(x) + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def convnext_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: + block_setting = [ + CNBlockConfig(96, 192, 3), + CNBlockConfig(192, 384, 3), + CNBlockConfig(384, 768, 9), + CNBlockConfig(768, None, 3), + ] + model = ConvNeXt(block_setting, **kwargs) + return model diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 392517cb772..6725da8e8f9 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -131,7 +131,7 @@ def __init__( norm_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.BatchNorm2d, activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU, dilation: int = 1, - inplace: bool = True, + inplace: Optional[bool] = True, bias: Optional[bool] = None, ) -> None: if padding is None: @@ -153,7 +153,8 @@ def __init__( if norm_layer is not None: layers.append(norm_layer(out_channels)) if activation_layer is not None: - layers.append(activation_layer(inplace=inplace)) + params = {} if inplace is None else {"inplace": inplace} + layers.append(activation_layer(*params)) super().__init__(*layers) _log_api_usage_once(self) self.out_channels = out_channels From a3034c47d27939fb4b10fb812b5a0821be72d1f9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Jan 2022 21:02:52 +0000 Subject: [PATCH 03/17] Adding model in prototypes. 
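
`model_urls` is left empty on purpose, so `pretrained=True` fails loudly until a
checkpoint is published. The guard boils down to the following sketch (using the
public torch.hub helper in place of the internal one; illustrative only):

```
from torch.hub import load_state_dict_from_url

model_urls = {}  # no convnext_tiny checkpoint is available yet

def _load_pretrained(model, arch, progress=True):  # hypothetical helper
    if arch not in model_urls:
        raise ValueError(f"No checkpoint is available for model type {arch}")
    model.load_state_dict(load_state_dict_from_url(model_urls[arch], progress=progress))
```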
--- torchvision/models/convnext.py | 17 +++++++++++++++ torchvision/prototype/models/__init__.py | 1 + torchvision/prototype/models/convnext.py | 27 ++++++++++++++++++++++++ 3 files changed, 45 insertions(+) create mode 100644 torchvision/prototype/models/convnext.py diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index ab69c0046c2..30b287984d1 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -4,6 +4,7 @@ import torch from torch import nn, Tensor +from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import ConvNormActivation from ..ops.stochastic_depth import StochasticDepth from ..utils import _log_api_usage_once @@ -15,6 +16,9 @@ ] +model_urls = {} + + class LayerNorm(nn.LayerNorm): def __init__(self, *args: Any, **kwargs: Any) -> None: self.channels_last = kwargs.pop("channels_last", False) @@ -169,6 +173,13 @@ def forward(self, x: Tensor) -> Tensor: def convnext_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: + r"""ConvNeXt model architecture from the + `"A ConvNet for the 2020s" `_ paper. + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + progress (bool): If True, displays a progress bar of the download to stderr + """ block_setting = [ CNBlockConfig(96, 192, 3), CNBlockConfig(192, 384, 3), @@ -176,4 +187,10 @@ def convnext_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any CNBlockConfig(768, None, 3), ] model = ConvNeXt(block_setting, **kwargs) + if pretrained: + arch = "convnext_tiny" + if arch not in model_urls: + raise ValueError(f"No checkpoint is available for model type {arch}") + state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) + model.load_state_dict(state_dict) return model diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index bfa44ffa720..83e49908348 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -1,4 +1,5 @@ from .alexnet import * +from .convnext import * from .densenet import * from .efficientnet import * from .googlenet import * diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py new file mode 100644 index 00000000000..c84fb6735c5 --- /dev/null +++ b/torchvision/prototype/models/convnext.py @@ -0,0 +1,27 @@ +from typing import Any, Optional + +from ...models.convnext import ConvNeXt +from ._api import WeightsEnum +from ._utils import handle_legacy_interface, _ovewrite_named_param + + +__all__ = ["ConvNeXt", "ConvNeXt_Weights", "convnext_tiny"] + + +class ConvNeXt_Weights(WeightsEnum): + pass + + +@handle_legacy_interface(weights=("pretrained", None)) +def convnext_tiny(*, weights: Optional[ConvNeXt_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: + weights = ConvNeXt_Weights.verify(weights) + + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = ConvNeXt(**kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress)) + + return model From e57b64f5b2a43304d37b364353f1dfb83668f73f Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Jan 2022 21:07:40 +0000 Subject: [PATCH 04/17] Add test and minor refactor for JIT. 
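
The JIT refactor swaps `super().forward(x)` for an explicit `F.layer_norm` call.
A hypothetical check (not included in this patch) that the functional form scripts
cleanly under TorchScript:

```
import torch
from torch import nn
from torch.nn import functional as F

class LayerNorm(nn.LayerNorm):  # simplified version of the class refactored below
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.permute(0, 2, 3, 1)  # NCHW -> NHWC so the last dim is channels
        x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return x.permute(0, 3, 1, 2)  # back to NCHW

scripted = torch.jit.script(LayerNorm(96, eps=1e-6))
print(scripted(torch.randn(1, 96, 7, 7)).shape)  # torch.Size([1, 96, 7, 7])
```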
--- .../ModelTester.test_convnext_tiny_expect.pkl | Bin 0 -> 939 bytes torchvision/models/convnext.py | 3 ++- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 test/expect/ModelTester.test_convnext_tiny_expect.pkl diff --git a/test/expect/ModelTester.test_convnext_tiny_expect.pkl b/test/expect/ModelTester.test_convnext_tiny_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..c6fb873f12f17656ab2f10e83328b29a0a7807aa GIT binary patch literal 939 zcmWIWW@cev;NW1u00Im`42ea_8JT6N`YDMeiFyUuIc`pT3{fbcfhoBpAE-(%zO*DW zr zf)S|3ppZF&8AvA=loqmh8ZvUemW=jY_4CYNO9=M{7L z7p0^YrKY%KCYNv(a%ct>a+VZw1r>7Z1$eV_Fj*X^nFTZrgadH;l#f9R#i#lPZcb`w z{zUOK5~pX2oc;3-p8ag{>ib@=Jz#hCzq*~lE3v%?to7}bQWN)oN!PPCbz}Knv=b(ovPi8du{uUDQw<%DdE1|wf#K%ZM|LXS0D4UtFYd? z_fpNgeHO~!_JTsIFz}=B31CQpFz(ReXRwBcR#|FMF)$X~oXm*~E~JoyFparDHeZ~V z9?Ar?6@&x489@|0O(Msk07wD_pr=rD-N=68L(%yP$V1kxZ-A~B*;V`~dL@7^gz1Hb zL4Y?Kn+{Zw9J4N5IVdrM0F2%a;WA7DdlKYbHc;MR@PsM=Wr6^2RyL3rGZ2E*L(~EQ DyXN}0 literal 0 HcmV?d00001 diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 30b287984d1..f75454286dd 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -3,6 +3,7 @@ import torch from torch import nn, Tensor +from torch.nn import functional as F from .._internally_replaced_utils import load_state_dict_from_url from ..ops.misc import ConvNormActivation @@ -27,7 +28,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: def forward(self, x): if not self.channels_last: x = x.permute(0, 2, 3, 1) - x = super().forward(x) + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) if not self.channels_last: x = x.permute(0, 3, 1, 2) return x From 8cddcac0ac04dec3378711352d3761c2e61339f0 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Jan 2022 21:26:35 +0000 Subject: [PATCH 05/17] Fix mypy. --- torchvision/models/convnext.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index f75454286dd..2f7acdce45d 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Callable, List, Optional, Sequence +from typing import Any, Callable, Dict, List, Optional, Sequence import torch from torch import nn, Tensor @@ -17,7 +17,7 @@ ] -model_urls = {} +model_urls: Dict[str, Optional[str]] = {} class LayerNorm(nn.LayerNorm): From 0bef112c5ff34523bc3900e3bcae516c658f4578 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 14 Jan 2022 21:49:47 +0000 Subject: [PATCH 06/17] Fixing naming conventions. 
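
Besides renaming the enum to ConvNeXt_Tiny_Weights, the builder now spells out the
stage plan of the paper's "tiny" row: channels (96, 192, 384, 768) with depths
(3, 3, 9, 3). Illustrative usage, assuming the import path at this point in the
series:

```
from torchvision.models.convnext import CNBlockConfig

block_setting = [
    CNBlockConfig(96, 192, 3),
    CNBlockConfig(192, 384, 3),
    CNBlockConfig(384, 768, 9),
    CNBlockConfig(768, None, 3),  # out_channels=None: no downsampling afterwards
]
print(block_setting[0])
# CNBlockConfig(input_channels=96, out_channels=192, num_layers=3)
```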
--- torchvision/prototype/models/convnext.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index c84fb6735c5..4bf1842cbb9 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -1,25 +1,31 @@ from typing import Any, Optional -from ...models.convnext import ConvNeXt +from ...models.convnext import ConvNeXt, CNBlockConfig from ._api import WeightsEnum from ._utils import handle_legacy_interface, _ovewrite_named_param -__all__ = ["ConvNeXt", "ConvNeXt_Weights", "convnext_tiny"] +__all__ = ["ConvNeXt", "ConvNeXt_Tiny_Weights", "convnext_tiny"] -class ConvNeXt_Weights(WeightsEnum): +class ConvNeXt_Tiny_Weights(WeightsEnum): pass @handle_legacy_interface(weights=("pretrained", None)) -def convnext_tiny(*, weights: Optional[ConvNeXt_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: - weights = ConvNeXt_Weights.verify(weights) +def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: + weights = ConvNeXt_Tiny_Weights.verify(weights) if weights is not None: _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) - model = ConvNeXt(**kwargs) + block_setting = [ + CNBlockConfig(96, 192, 3), + CNBlockConfig(192, 384, 3), + CNBlockConfig(384, 768, 9), + CNBlockConfig(768, None, 3), + ] + model = ConvNeXt(block_setting, **kwargs) if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) From cf698320bee9babfcdb50b26c7a33e62a7873b8c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 15 Jan 2022 01:32:28 +0000 Subject: [PATCH 07/17] Fixing tests. --- torchvision/ops/misc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 6725da8e8f9..6fe16b0e757 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -154,7 +154,7 @@ def __init__( layers.append(norm_layer(out_channels)) if activation_layer is not None: params = {} if inplace is None else {"inplace": inplace} - layers.append(activation_layer(*params)) + layers.append(activation_layer(**params)) super().__init__(*layers) _log_api_usage_once(self) self.out_channels = out_channels From eb4c8252d3690f039d04d5750776c427758d8764 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 15 Jan 2022 01:53:29 +0000 Subject: [PATCH 08/17] Fix stochastic depth percentages. --- torchvision/models/convnext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 2f7acdce45d..b0665b6e3c5 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -134,7 +134,7 @@ def __init__( stage: List[nn.Module] = [] for _ in range(cnf.num_layers): # adjust stochastic depth probability based on the depth of the stage block - sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks + sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer)) stage_block_id += 1 layers.append(nn.Sequential(*stage)) From 52960cf00203fceac144ebe5f6f69502716dbe68 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sat, 15 Jan 2022 09:43:50 +0000 Subject: [PATCH 09/17] Adding stochastic depth to tiny variant. 
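
Together with the schedule fixed in the previous commit, the default
`stochastic_depth_prob=0.1` chosen here gives drop rates growing linearly from 0 on
the first block to the full 0.1 on the last. A quick sketch for the 18 blocks of
convnext_tiny:

```
total_stage_blocks = 3 + 3 + 9 + 3  # depths of the tiny variant
probs = [0.1 * i / (total_stage_blocks - 1.0) for i in range(total_stage_blocks)]
print(round(probs[0], 4), round(probs[-1], 4))  # 0.0 0.1
```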
--- torchvision/models/convnext.py | 3 ++- torchvision/prototype/models/convnext.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index b0665b6e3c5..7d9ab767642 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -187,7 +187,8 @@ def convnext_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any CNBlockConfig(384, 768, 9), CNBlockConfig(768, None, 3), ] - model = ConvNeXt(block_setting, **kwargs) + stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) + model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) if pretrained: arch = "convnext_tiny" if arch not in model_urls: diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index 4bf1842cbb9..1d5f2b668c2 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -25,7 +25,8 @@ def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: CNBlockConfig(384, 768, 9), CNBlockConfig(768, None, 3), ] - model = ConvNeXt(block_setting, **kwargs) + stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) + model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) if weights is not None: model.load_state_dict(weights.get_state_dict(progress=progress)) From 8ddc17c2a360f2d153e2cc0062c835424523cffc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Sun, 16 Jan 2022 11:41:36 +0000 Subject: [PATCH 10/17] Minor refactoring and adding comments. --- torchvision/models/convnext.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py index 7d9ab767642..51739f238ab 100644 --- a/torchvision/models/convnext.py +++ b/torchvision/models/convnext.py @@ -20,7 +20,7 @@ model_urls: Dict[str, Optional[str]] = {} -class LayerNorm(nn.LayerNorm): +class LayerNorm2d(nn.LayerNorm): def __init__(self, *args: Any, **kwargs: Any) -> None: self.channels_last = kwargs.pop("channels_last", False) super().__init__(*args, **kwargs) @@ -110,10 +110,11 @@ def __init__( block = CNBlock if norm_layer is None: - norm_layer = partial(LayerNorm, eps=1e-6) + norm_layer = partial(LayerNorm2d, eps=1e-6) layers: List[nn.Module] = [] + # Stem firstconv_output_channels = block_setting[0].input_channels layers.append( ConvNormActivation( @@ -131,6 +132,7 @@ def __init__( total_stage_blocks = sum(cnf.num_layers for cnf in block_setting) stage_block_id = 0 for cnf in block_setting: + # Bottlenecks stage: List[nn.Module] = [] for _ in range(cnf.num_layers): # adjust stochastic depth probability based on the depth of the stage block @@ -139,6 +141,7 @@ def __init__( stage_block_id += 1 layers.append(nn.Sequential(*stage)) if cnf.out_channels is not None: + # Downsampling layers.append( nn.Sequential( norm_layer(cnf.input_channels), From ce05e242704a8213614abb221c444ebe7d884d87 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 20 Jan 2022 11:30:14 +0000 Subject: [PATCH 11/17] Adding weights. 
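
The recorded metadata (28589128 parameters, acc@1 82.520 / acc@5 96.146 under the
linked recipe) can be sanity-checked locally; a hedged sketch, assuming a build of
this branch:

```
from torchvision.models import convnext_tiny

model = convnext_tiny()
num_params = sum(p.numel() for p in model.parameters())
print(num_params)  # expected to match the 28589128 in the weights metadata
```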
--- docs/source/models.rst | 14 +++++++++++++ hubconf.py | 1 + references/classification/README.md | 14 +++++++++++++ torchvision/models/convnext.py | 5 ++++- torchvision/prototype/models/convnext.py | 26 ++++++++++++++++++++++-- 5 files changed, 57 insertions(+), 3 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 62c104cf927..014c50c1070 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -41,6 +41,7 @@ architectures for image classification: - `EfficientNet`_ - `RegNet`_ - `VisionTransformer`_ +- `ConvNeXt`_ You can construct a model with random weights by calling its constructor: @@ -88,6 +89,7 @@ You can construct a model with random weights by calling its constructor: vit_b_32 = models.vit_b_32() vit_l_16 = models.vit_l_16() vit_l_32 = models.vit_l_32() + convnext_tiny = models.convnext_tiny() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing ``pretrained=True``: @@ -135,6 +137,7 @@ These can be constructed by passing ``pretrained=True``: vit_b_32 = models.vit_b_32(pretrained=True) vit_l_16 = models.vit_l_16(pretrained=True) vit_l_32 = models.vit_l_32(pretrained=True) + convnext_tiny = models.convnext_tiny(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. This directory can be set using the `TORCH_HOME` environment variable. See @@ -247,6 +250,7 @@ vit_b_16 81.072 95.318 vit_b_32 75.912 92.466 vit_l_16 79.662 94.638 vit_l_32 76.972 93.070 +convnext_tiny 82.520 96.146 ================================ ============= ============= @@ -265,6 +269,7 @@ vit_l_32 76.972 93.070 .. _EfficientNet: https://arxiv.org/abs/1905.11946 .. _RegNet: https://arxiv.org/abs/2003.13678 .. _VisionTransformer: https://arxiv.org/abs/2010.11929 +.. _ConvNeXt: https://arxiv.org/abs/2201.03545 .. currentmodule:: torchvision.models @@ -461,6 +466,15 @@ VisionTransformer vit_l_16 vit_l_32 +ConvNeXt +-------- + +.. autosummary:: + :toctree: generated/ + :template: function.rst + + convnext_tiny + Quantized Models ---------------- diff --git a/hubconf.py b/hubconf.py index 2b2eeb1c166..6ed1bb9129e 100644 --- a/hubconf.py +++ b/hubconf.py @@ -2,6 +2,7 @@ dependencies = ["torch"] from torchvision.models.alexnet import alexnet +from torchvision.models.convnext import convnext_tiny from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161 from torchvision.models.efficientnet import ( efficientnet_b0, diff --git a/references/classification/README.md b/references/classification/README.md index 48b20a30242..bab845cfe6e 100644 --- a/references/classification/README.md +++ b/references/classification/README.md @@ -197,6 +197,20 @@ Note that the above command corresponds to training on a single node with 8 GPUs For generatring the pre-trained weights, we trained with 8 nodes, each with 8 GPUs (for a total of 64 GPUs), and `--batch_size 64`. + +### ConvNeXt +``` +torchrun --nproc_per_node=8 train.py\ +--model convnext_tiny --batch-size 128 --opt adamw --lr 1e-3 --lr-scheduler cosineannealinglr \ +--lr-warmup-epochs 5 --lr-warmup-method linear --auto-augment ta_wide --epochs 600 --random-erase 0.1 \ +--label-smoothing 0.1 --mixup-alpha 0.2 --cutmix-alpha 1.0 --weight-decay 0.05 --norm-weight-decay 0.0 \ +--train-crop-size 176 --model-ema --val-resize-size 236 --ra-sampler --ra-reps 4 +``` + +Note that the above command corresponds to training on a single node with 8 GPUs. 
For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 32 GPUs),
and `--batch_size 64`.

## Mixed precision training

Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [torch.cuda.amp](https://pytorch.org/docs/stable/amp.html?highlight=amp#module-torch.cuda.amp).

diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py
index 51739f238ab..475e32727c8 100644
--- a/torchvision/models/convnext.py
+++ b/torchvision/models/convnext.py
@@ -17,7 +17,9 @@
 
 
-model_urls: Dict[str, Optional[str]] = {}
+model_urls: Dict[str, Optional[str]] = {
+    "convnext_tiny": "https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
+}
 
 
 class LayerNorm2d(nn.LayerNorm):
@@ -26,6 +28,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
 
     def forward(self, x):
+        # TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298
         if not self.channels_last:
             x = x.permute(0, 2, 3, 1)
         x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py
index 1d5f2b668c2..f4144e5d273 100644
--- a/torchvision/prototype/models/convnext.py
+++ b/torchvision/prototype/models/convnext.py
@@ -1,7 +1,12 @@
+from functools import partial
 from typing import Any, Optional
 
+from torchvision.prototype.transforms import ImageNetEval
+from torchvision.transforms.functional import InterpolationMode
+
 from ...models.convnext import ConvNeXt, CNBlockConfig
-from ._api import WeightsEnum
+from ._api import WeightsEnum, Weights
+from ._meta import _IMAGENET_CATEGORIES
 from ._utils import handle_legacy_interface, _ovewrite_named_param
 
 
@@ -13,7 +14,24 @@
 
 
 class ConvNeXt_Tiny_Weights(WeightsEnum):
-    pass
+    ImageNet1K_V1 = Weights(
+        url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth",
+        transforms=partial(ImageNetEval, crop_size=236),
+        meta={
+            "task": "image_classification",
+            "architecture": "ConvNeXt",
+            "publication_year": 2022,
+            "num_params": 28589128,
+            "size": (224, 224),
+            "min_size": (32, 32),
+            "categories": _IMAGENET_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#convnext",
+            "acc@1": 82.520,
+            "acc@5": 96.146,
+        },
+    )
+    default = ImageNet1K_V1

From c4ffc8481e02510f90af7b2ded1f68f7a86e6360 Mon Sep 17 00:00:00 2001
From: Vasilis Vryniotis
Date: Thu, 20 Jan 2022 11:35:36 +0000
Subject: [PATCH 12/17] Update default weights.
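
Defaulting the legacy interface to ImageNet1K_V1 means the old boolean flag now
resolves to the new enum. Sketch of the intended equivalence (prototype API, so
subject to change):

```
from torchvision.prototype.models import convnext_tiny, ConvNeXt_Tiny_Weights

m1 = convnext_tiny(pretrained=True)  # legacy flag, mapped by the decorator...
m2 = convnext_tiny(weights=ConvNeXt_Tiny_Weights.ImageNet1K_V1)  # ...to this call
```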
--- torchvision/prototype/models/convnext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index f4144e5d273..cd2f505bb32 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -34,7 +34,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): default = ImageNet1K_V1 -@handle_legacy_interface(weights=("pretrained", None)) +@handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.ImageNet1K_V1)) def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: weights = ConvNeXt_Tiny_Weights.verify(weights) From 7af0e2001bfcd5ab7401642457db81f8faf859a8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 20 Jan 2022 11:44:54 +0000 Subject: [PATCH 13/17] Fix transforms issue --- torchvision/prototype/models/convnext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index cd2f505bb32..b16e81343b0 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -16,7 +16,7 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", - transforms=partial(ImageNetEval, crop_size=236), + transforms=partial(ImageNetEval, crop_size=224, resize_size=236), meta={ "task": "image_classification", "architecture": "ConvNeXt", From 1ee5b0fbb88a74aba0cbbaa0565b16b93e239a32 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 20 Jan 2022 11:55:59 +0000 Subject: [PATCH 14/17] Move convnext to prototype. --- hubconf.py | 2 +- torchvision/models/__init__.py | 1 - torchvision/models/convnext.py | 204 ----------------------- torchvision/prototype/models/convnext.py | 173 ++++++++++++++++++- 4 files changed, 172 insertions(+), 208 deletions(-) delete mode 100644 torchvision/models/convnext.py diff --git a/hubconf.py b/hubconf.py index 48731449fb0..45a8f73dda6 100644 --- a/hubconf.py +++ b/hubconf.py @@ -2,7 +2,7 @@ dependencies = ["torch"] from torchvision.models.alexnet import alexnet -from torchvision.models.convnext import convnext_tiny +# from torchvision.models.convnext import convnext_tiny from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161 from torchvision.models.efficientnet import ( efficientnet_b0, diff --git a/torchvision/models/__init__.py b/torchvision/models/__init__.py index 16495e8552e..22e2e45d4ce 100644 --- a/torchvision/models/__init__.py +++ b/torchvision/models/__init__.py @@ -1,5 +1,4 @@ from .alexnet import * -from .convnext import * from .resnet import * from .vgg import * from .squeezenet import * diff --git a/torchvision/models/convnext.py b/torchvision/models/convnext.py deleted file mode 100644 index 475e32727c8..00000000000 --- a/torchvision/models/convnext.py +++ /dev/null @@ -1,204 +0,0 @@ -from functools import partial -from typing import Any, Callable, Dict, List, Optional, Sequence - -import torch -from torch import nn, Tensor -from torch.nn import functional as F - -from .._internally_replaced_utils import load_state_dict_from_url -from ..ops.misc import ConvNormActivation -from ..ops.stochastic_depth import StochasticDepth -from ..utils import _log_api_usage_once - - -__all__ = [ - "ConvNeXt", - "convnext_tiny", -] - - -model_urls: Dict[str, Optional[str]] = { - "convnext_tiny": 
"https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", -} - - -class LayerNorm2d(nn.LayerNorm): - def __init__(self, *args: Any, **kwargs: Any) -> None: - self.channels_last = kwargs.pop("channels_last", False) - super().__init__(*args, **kwargs) - - def forward(self, x): - # TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298 - if not self.channels_last: - x = x.permute(0, 2, 3, 1) - x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) - if not self.channels_last: - x = x.permute(0, 3, 1, 2) - return x - - -class CNBlock(nn.Module): - def __init__(self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]): - super().__init__() - self.block = nn.Sequential( - ConvNormActivation( - dim, - dim, - kernel_size=7, - groups=dim, - norm_layer=norm_layer, - activation_layer=None, - bias=True, # TODO: check - ), - ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None), - ConvNormActivation( - 4 * dim, - dim, - kernel_size=1, - norm_layer=None, - activation_layer=None, - ), - ) - self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale) - self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") - - def forward(self, input: Tensor) -> Tensor: - result = self.layer_scale * self.block(input) - result = self.stochastic_depth(result) - result += input - return result - - -class CNBlockConfig: - # Stores information listed at Section 3 of the ConvNeXt paper - def __init__( - self, - input_channels: int, - out_channels: Optional[int], - num_layers: int, - ) -> None: - self.input_channels = input_channels - self.out_channels = out_channels - self.num_layers = num_layers - - def __repr__(self) -> str: - s = self.__class__.__name__ + "(" - s += "input_channels={input_channels}" - s += ", out_channels={out_channels}" - s += ", num_layers={num_layers}" - s += ")" - return s.format(**self.__dict__) - - -class ConvNeXt(nn.Module): - def __init__( - self, - block_setting: List[CNBlockConfig], - stochastic_depth_prob: float = 0.0, - layer_scale: float = 1e-6, - num_classes: int = 1000, - block: Optional[Callable[..., nn.Module]] = None, - norm_layer: Optional[Callable[..., nn.Module]] = None, - **kwargs: Any, - ) -> None: - super().__init__() - _log_api_usage_once(self) - - if not block_setting: - raise ValueError("The block_setting should not be empty") - elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])): - raise TypeError("The block_setting should be List[CNBlockConfig]") - - if block is None: - block = CNBlock - - if norm_layer is None: - norm_layer = partial(LayerNorm2d, eps=1e-6) - - layers: List[nn.Module] = [] - - # Stem - firstconv_output_channels = block_setting[0].input_channels - layers.append( - ConvNormActivation( - 3, - firstconv_output_channels, - kernel_size=4, - stride=4, - padding=0, - norm_layer=norm_layer, - activation_layer=None, - bias=True, - ) - ) - - total_stage_blocks = sum(cnf.num_layers for cnf in block_setting) - stage_block_id = 0 - for cnf in block_setting: - # Bottlenecks - stage: List[nn.Module] = [] - for _ in range(cnf.num_layers): - # adjust stochastic depth probability based on the depth of the stage block - sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) - stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer)) - stage_block_id += 1 - 
layers.append(nn.Sequential(*stage)) - if cnf.out_channels is not None: - # Downsampling - layers.append( - nn.Sequential( - norm_layer(cnf.input_channels), - nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2), - ) - ) - - self.features = nn.Sequential(*layers) - self.avgpool = nn.AdaptiveAvgPool2d(1) - - lastblock = block_setting[-1] - lastconv_output_channels = ( - lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels - ) - self.classifier = nn.Sequential( - norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes) - ) - - for m in self.modules(): - if isinstance(m, (nn.Conv2d, nn.Linear)): - nn.init.trunc_normal_(m.weight, std=0.02) - if m.bias is not None: - nn.init.zeros_(m.bias) - - def _forward_impl(self, x: Tensor) -> Tensor: - x = self.features(x) - x = self.avgpool(x) - x = self.classifier(x) - return x - - def forward(self, x: Tensor) -> Tensor: - return self._forward_impl(x) - - -def convnext_tiny(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ConvNeXt: - r"""ConvNeXt model architecture from the - `"A ConvNet for the 2020s" `_ paper. - - Args: - pretrained (bool): If True, returns a model pre-trained on ImageNet - progress (bool): If True, displays a progress bar of the download to stderr - """ - block_setting = [ - CNBlockConfig(96, 192, 3), - CNBlockConfig(192, 384, 3), - CNBlockConfig(384, 768, 9), - CNBlockConfig(768, None, 3), - ] - stochastic_depth_prob = kwargs.pop("stochastic_depth_prob", 0.1) - model = ConvNeXt(block_setting, stochastic_depth_prob=stochastic_depth_prob, **kwargs) - if pretrained: - arch = "convnext_tiny" - if arch not in model_urls: - raise ValueError(f"No checkpoint is available for model type {arch}") - state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) - model.load_state_dict(state_dict) - return model diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py index b16e81343b0..d91f25175c9 100644 --- a/torchvision/prototype/models/convnext.py +++ b/torchvision/prototype/models/convnext.py @@ -1,10 +1,15 @@ from functools import partial -from typing import Any, Optional +from typing import Any, Callable, List, Optional, Sequence +import torch +from torch import nn, Tensor +from torch.nn import functional as F from torchvision.prototype.transforms import ImageNetEval from torchvision.transforms.functional import InterpolationMode -from ...models.convnext import ConvNeXt, CNBlockConfig +from ...ops.misc import ConvNormActivation +from ...ops.stochastic_depth import StochasticDepth +from ...utils import _log_api_usage_once from ._api import WeightsEnum, Weights from ._meta import _IMAGENET_CATEGORIES from ._utils import handle_legacy_interface, _ovewrite_named_param @@ -13,6 +18,163 @@ __all__ = ["ConvNeXt", "ConvNeXt_Tiny_Weights", "convnext_tiny"] +class LayerNorm2d(nn.LayerNorm): + def __init__(self, *args: Any, **kwargs: Any) -> None: + self.channels_last = kwargs.pop("channels_last", False) + super().__init__(*args, **kwargs) + + def forward(self, x): + # TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298 + if not self.channels_last: + x = x.permute(0, 2, 3, 1) + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + if not self.channels_last: + x = x.permute(0, 3, 1, 2) + return x + + +class CNBlock(nn.Module): + def __init__(self, dim, layer_scale: float, 
stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]): + super().__init__() + self.block = nn.Sequential( + ConvNormActivation( + dim, + dim, + kernel_size=7, + groups=dim, + norm_layer=norm_layer, + activation_layer=None, + bias=True, + ), + ConvNormActivation(dim, 4 * dim, kernel_size=1, norm_layer=None, activation_layer=nn.GELU, inplace=None), + ConvNormActivation( + 4 * dim, + dim, + kernel_size=1, + norm_layer=None, + activation_layer=None, + ), + ) + self.layer_scale = nn.Parameter(torch.ones(dim, 1, 1) * layer_scale) + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + + def forward(self, input: Tensor) -> Tensor: + result = self.layer_scale * self.block(input) + result = self.stochastic_depth(result) + result += input + return result + + +class CNBlockConfig: + # Stores information listed at Section 3 of the ConvNeXt paper + def __init__( + self, + input_channels: int, + out_channels: Optional[int], + num_layers: int, + ) -> None: + self.input_channels = input_channels + self.out_channels = out_channels + self.num_layers = num_layers + + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "input_channels={input_channels}" + s += ", out_channels={out_channels}" + s += ", num_layers={num_layers}" + s += ")" + return s.format(**self.__dict__) + + +class ConvNeXt(nn.Module): + def __init__( + self, + block_setting: List[CNBlockConfig], + stochastic_depth_prob: float = 0.0, + layer_scale: float = 1e-6, + num_classes: int = 1000, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + _log_api_usage_once(self) + + if not block_setting: + raise ValueError("The block_setting should not be empty") + elif not (isinstance(block_setting, Sequence) and all([isinstance(s, CNBlockConfig) for s in block_setting])): + raise TypeError("The block_setting should be List[CNBlockConfig]") + + if block is None: + block = CNBlock + + if norm_layer is None: + norm_layer = partial(LayerNorm2d, eps=1e-6) + + layers: List[nn.Module] = [] + + # Stem + firstconv_output_channels = block_setting[0].input_channels + layers.append( + ConvNormActivation( + 3, + firstconv_output_channels, + kernel_size=4, + stride=4, + padding=0, + norm_layer=norm_layer, + activation_layer=None, + bias=True, + ) + ) + + total_stage_blocks = sum(cnf.num_layers for cnf in block_setting) + stage_block_id = 0 + for cnf in block_setting: + # Bottlenecks + stage: List[nn.Module] = [] + for _ in range(cnf.num_layers): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * stage_block_id / (total_stage_blocks - 1.0) + stage.append(block(cnf.input_channels, layer_scale, sd_prob, norm_layer)) + stage_block_id += 1 + layers.append(nn.Sequential(*stage)) + if cnf.out_channels is not None: + # Downsampling + layers.append( + nn.Sequential( + norm_layer(cnf.input_channels), + nn.Conv2d(cnf.input_channels, cnf.out_channels, kernel_size=2, stride=2), + ) + ) + + self.features = nn.Sequential(*layers) + self.avgpool = nn.AdaptiveAvgPool2d(1) + + lastblock = block_setting[-1] + lastconv_output_channels = ( + lastblock.out_channels if lastblock.out_channels is not None else lastblock.input_channels + ) + self.classifier = nn.Sequential( + norm_layer(lastconv_output_channels), nn.Flatten(1), nn.Linear(lastconv_output_channels, num_classes) + ) + + for m in self.modules(): + if isinstance(m, (nn.Conv2d, nn.Linear)): + 
nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def _forward_impl(self, x: Tensor) -> Tensor: + x = self.features(x) + x = self.avgpool(x) + x = self.classifier(x) + return x + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + class ConvNeXt_Tiny_Weights(WeightsEnum): ImageNet1K_V1 = Weights( url="https://download.pytorch.org/models/convnext_tiny-47b116bd.pth", @@ -36,6 +198,13 @@ class ConvNeXt_Tiny_Weights(WeightsEnum): @handle_legacy_interface(weights=("pretrained", ConvNeXt_Tiny_Weights.ImageNet1K_V1)) def convnext_tiny(*, weights: Optional[ConvNeXt_Tiny_Weights] = None, progress: bool = True, **kwargs: Any) -> ConvNeXt: + r"""ConvNeXt model architecture from the + `"A ConvNet for the 2020s" `_ paper. + + Args: + weights (ConvNeXt_Tiny_Weights, optional): The pre-trained weights of the model + progress (bool): If True, displays a progress bar of the download to stderr + """ weights = ConvNeXt_Tiny_Weights.verify(weights) if weights is not None: From be2972e403bb86e834d464595562dc1db40d4df1 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 20 Jan 2022 12:06:29 +0000 Subject: [PATCH 15/17] linter fix --- hubconf.py | 1 - 1 file changed, 1 deletion(-) diff --git a/hubconf.py b/hubconf.py index 45a8f73dda6..1b3b191efa4 100644 --- a/hubconf.py +++ b/hubconf.py @@ -2,7 +2,6 @@ dependencies = ["torch"] from torchvision.models.alexnet import alexnet -# from torchvision.models.convnext import convnext_tiny from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161 from torchvision.models.efficientnet import ( efficientnet_b0, From f47a59098dbaaad2881db3960d6d01b9d0200f06 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 20 Jan 2022 12:17:41 +0000 Subject: [PATCH 16/17] fix docs --- docs/source/models.rst | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index e7e06ad938c..4c65eac8135 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -89,8 +89,7 @@ You can construct a model with random weights by calling its constructor: vit_b_32 = models.vit_b_32() vit_l_16 = models.vit_l_16() vit_l_32 = models.vit_l_32() - vit_h_14 = models.vit_h_14() - convnext_tiny = models.convnext_tiny() + vit_h_14 = models.vit_h_14() We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`. These can be constructed by passing ``pretrained=True``: @@ -138,7 +137,6 @@ These can be constructed by passing ``pretrained=True``: vit_b_32 = models.vit_b_32(pretrained=True) vit_l_16 = models.vit_l_16(pretrained=True) vit_l_32 = models.vit_l_32(pretrained=True) - convnext_tiny = models.convnext_tiny(pretrained=True) Instancing a pre-trained model will download its weights to a cache directory. This directory can be set using the `TORCH_HOME` environment variable. See @@ -251,7 +249,7 @@ vit_b_16 81.072 95.318 vit_b_32 75.912 92.466 vit_l_16 79.662 94.638 vit_l_32 76.972 93.070 -convnext_tiny 82.520 96.146 +convnext_tiny (prototype) 82.520 96.146 ================================ ============= ============= @@ -468,15 +466,6 @@ VisionTransformer vit_l_32 vit_h_14 -ConvNeXt --------- - -.. autosummary:: - :toctree: generated/ - :template: function.rst - - convnext_tiny - Quantized Models ---------------- From 9e6fda124543c50ceffc6f1f0a1743be35700c42 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 20 Jan 2022 15:02:31 +0000 Subject: [PATCH 17/17] Addressing code review comments. 
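
With the series complete, intended end-state usage looks roughly like this (a
sketch against the prototype namespace, which is unstable by definition; the
preprocessing assumes the crop_size=224 / resize_size=236 fix from earlier in the
series):

```
import torch
from torchvision.prototype.models import convnext_tiny, ConvNeXt_Tiny_Weights

weights = ConvNeXt_Tiny_Weights.ImageNet1K_V1
model = convnext_tiny(weights=weights).eval()
preprocess = weights.transforms()  # ImageNetEval: resize to 236, center-crop 224
batch = preprocess(torch.randn(1, 3, 256, 256))  # stand-in for a real image batch
with torch.no_grad():
    print(model(batch).shape)  # torch.Size([1, 1000])
```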
--- references/classification/README.md | 2 +-
 torchvision/prototype/models/convnext.py | 6 ++++--
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/references/classification/README.md b/references/classification/README.md
index bab845cfe6e..0fb27eac7cc 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -208,7 +208,7 @@ torchrun --nproc_per_node=8 train.py\
 ```
 
 Note that the above command corresponds to training on a single node with 8 GPUs.
-For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 32 GPUs),
+For generating the pre-trained weights, we trained with 2 nodes, each with 8 GPUs (for a total of 16 GPUs),
 and `--batch_size 64`.
 
 ## Mixed precision training
diff --git a/torchvision/prototype/models/convnext.py b/torchvision/prototype/models/convnext.py
index d91f25175c9..788dcbc2cd1 100644
--- a/torchvision/prototype/models/convnext.py
+++ b/torchvision/prototype/models/convnext.py
@@ -23,7 +23,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.channels_last = kwargs.pop("channels_last", False)
         super().__init__(*args, **kwargs)
 
-    def forward(self, x):
+    def forward(self, x: Tensor) -> Tensor:
         # TODO: Benchmark this against the approach described at https://github.com/pytorch/vision/pull/5197#discussion_r786251298
         if not self.channels_last:
             x = x.permute(0, 2, 3, 1)
@@ -34,7 +34,9 @@
 
 class CNBlock(nn.Module):
-    def __init__(self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]):
+    def __init__(
+        self, dim, layer_scale: float, stochastic_depth_prob: float, norm_layer: Callable[..., nn.Module]
+    ) -> None:
         super().__init__()
         self.block = nn.Sequential(
             ConvNormActivation(