Skip to content

Commit

Permalink
Update sequencer2D
Browse files Browse the repository at this point in the history
  • Loading branch information
liuruiyang98 committed May 6, 2022
1 parent 8feb07f commit b86656b
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion models_jittor/sequencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def execute(self, x):
return x

class Sequencer2D(nn.Module):
def __init__(self, model_name: str = 'T', pretrained: str = None, num_classes: int = 1000, in_channels = 3, *args, **kwargs) -> None:
def __init__(self, model_name: str = 'M', pretrained: str = None, num_classes: int = 1000, in_channels = 3, *args, **kwargs) -> None:
super().__init__()
assert model_name in sequencer_settings.keys(), f"Sequencer model name should be in {list(sequencer_settings.keys())}"
depth, embed_dims, hidden_dims, expansion_factor = sequencer_settings[model_name]
Expand Down
2 changes: 1 addition & 1 deletion models_pytorch/sequencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def forward(self, x):
return x

class Sequencer2D(nn.Module):
def __init__(self, model_name: str = 'T', pretrained: str = None, num_classes: int = 1000, in_channels = 3, *args, **kwargs) -> None:
def __init__(self, model_name: str = 'M', pretrained: str = None, num_classes: int = 1000, in_channels = 3, *args, **kwargs) -> None:
super().__init__()
assert model_name in sequencer_settings.keys(), f"Sequencer model name should be in {list(sequencer_settings.keys())}"
depth, embed_dims, hidden_dims, expansion_factor = sequencer_settings[model_name]
Expand Down

0 comments on commit b86656b

Please sign in to comment.