From b86656b65cf5f18ba9eb760d1f7565ed95e7e96e Mon Sep 17 00:00:00 2001 From: liuruiyang98 <865296294@qq.com> Date: Fri, 6 May 2022 15:14:45 +0800 Subject: [PATCH] Update sequencer2D --- models_jittor/sequencer.py | 2 +- models_pytorch/sequencer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/models_jittor/sequencer.py b/models_jittor/sequencer.py index fe47e0e..03682ac 100644 --- a/models_jittor/sequencer.py +++ b/models_jittor/sequencer.py @@ -72,7 +72,7 @@ def execute(self, x): return x class Sequencer2D(nn.Module): - def __init__(self, model_name: str = 'T', pretrained: str = None, num_classes: int = 1000, in_channels = 3, *args, **kwargs) -> None: + def __init__(self, model_name: str = 'M', pretrained: str = None, num_classes: int = 1000, in_channels = 3, *args, **kwargs) -> None: super().__init__() assert model_name in sequencer_settings.keys(), f"Sequencer model name should be in {list(sequencer_settings.keys())}" depth, embed_dims, hidden_dims, expansion_factor = sequencer_settings[model_name] diff --git a/models_pytorch/sequencer.py b/models_pytorch/sequencer.py index 0eac833..deb6486 100644 --- a/models_pytorch/sequencer.py +++ b/models_pytorch/sequencer.py @@ -72,7 +72,7 @@ def forward(self, x): return x class Sequencer2D(nn.Module): - def __init__(self, model_name: str = 'T', pretrained: str = None, num_classes: int = 1000, in_channels = 3, *args, **kwargs) -> None: + def __init__(self, model_name: str = 'M', pretrained: str = None, num_classes: int = 1000, in_channels = 3, *args, **kwargs) -> None: super().__init__() assert model_name in sequencer_settings.keys(), f"Sequencer model name should be in {list(sequencer_settings.keys())}" depth, embed_dims, hidden_dims, expansion_factor = sequencer_settings[model_name]