diff --git a/src/transformers/models/vit_msn/modeling_vit_msn.py b/src/transformers/models/vit_msn/modeling_vit_msn.py
index 140864a06a46f6..01068196423b30 100644
--- a/src/transformers/models/vit_msn/modeling_vit_msn.py
+++ b/src/transformers/models/vit_msn/modeling_vit_msn.py
@@ -36,6 +36,7 @@
 )
 from .configuration_vit_msn import ViTMSNConfig
 
+
 logger = logging.get_logger(__name__)
 
 
@@ -56,22 +57,14 @@ def __init__(self, config: ViTMSNConfig, use_mask_token: bool = False) -> None:
         super().__init__()
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
-        self.mask_token = (
-            nn.Parameter(torch.zeros(1, 1, config.hidden_size))
-            if use_mask_token
-            else None
-        )
+        self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None
         self.patch_embeddings = ViTMSNPatchEmbeddings(config)
         num_patches = self.patch_embeddings.num_patches
-        self.position_embeddings = nn.Parameter(
-            torch.zeros(1, num_patches + 1, config.hidden_size)
-        )
+        self.position_embeddings = nn.Parameter(torch.zeros(1, num_patches + 1, config.hidden_size))
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
         self.config = config
 
-    def interpolate_pos_encoding(
-        self, embeddings: torch.Tensor, height: int, width: int
-    ) -> torch.Tensor:
+    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
         """
         This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
         resolution images.
@@ -95,9 +88,7 @@ def interpolate_pos_encoding(
             patch_window_height + 0.1,
             patch_window_width + 0.1,
         )
-        patch_pos_embed = patch_pos_embed.reshape(
-            1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim
-        )
+        patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
         patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
         patch_pos_embed = nn.functional.interpolate(
             patch_pos_embed,
@@ -118,9 +109,7 @@ def forward(
         interpolate_pos_encoding: bool = False,
     ) -> torch.Tensor:
         batch_size, num_channels, height, width = pixel_values.shape
-        embeddings = self.patch_embeddings(
-            pixel_values, interpolate_pos_encoding=interpolate_pos_encoding
-        )
+        embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
 
         if bool_masked_pos is not None:
             seq_length = embeddings.shape[1]
@@ -135,9 +124,7 @@ def forward(
 
         # add positional encoding to each token
         if interpolate_pos_encoding:
-            embeddings = embeddings + self.interpolate_pos_encoding(
-                embeddings, height, width
-            )
+            embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
         else:
             embeddings = embeddings + self.position_embeddings
 
@@ -159,31 +146,17 @@ def __init__(self, config):
         image_size, patch_size = config.image_size, config.patch_size
         num_channels, hidden_size = config.num_channels, config.hidden_size
 
-        image_size = (
-            image_size
-            if isinstance(image_size, collections.abc.Iterable)
-            else (image_size, image_size)
-        )
-        patch_size = (
-            patch_size
-            if isinstance(patch_size, collections.abc.Iterable)
-            else (patch_size, patch_size)
-        )
-        num_patches = (image_size[1] // patch_size[1]) * (
-            image_size[0] // patch_size[0]
-        )
+        image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
+        patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
+        num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
         self.image_size = image_size
         self.patch_size = patch_size
         self.num_channels = num_channels
         self.num_patches = num_patches
 
-        self.projection = nn.Conv2d(
-            num_channels, hidden_size, kernel_size=patch_size, stride=patch_size
-        )
+        self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
 
-    def forward(
-        self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False
-    ) -> torch.Tensor:
+    def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor:
         batch_size, num_channels, height, width = pixel_values.shape
         if num_channels != self.num_channels:
             raise ValueError(
@@ -204,9 +177,7 @@ def forward(
 class ViTMSNSelfAttention(nn.Module):
     def __init__(self, config: ViTMSNConfig) -> None:
         super().__init__()
-        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
-            config, "embedding_size"
-        ):
+        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
             raise ValueError(
                 f"The hidden size {config.hidden_size,} is not a multiple of the number of attention "
                 f"heads {config.num_attention_heads}."
@@ -216,15 +187,9 @@ def __init__(self, config: ViTMSNConfig) -> None:
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
         self.all_head_size = self.num_attention_heads * self.attention_head_size
 
-        self.query = nn.Linear(
-            config.hidden_size, self.all_head_size, bias=config.qkv_bias
-        )
-        self.key = nn.Linear(
-            config.hidden_size, self.all_head_size, bias=config.qkv_bias
-        )
-        self.value = nn.Linear(
-            config.hidden_size, self.all_head_size, bias=config.qkv_bias
-        )
+        self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+        self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
 
         self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
 
@@ -270,9 +235,7 @@ def forward(
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)
 
-        outputs = (
-            (context_layer, attention_probs) if output_attentions else (context_layer,)
-        )
+        outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
 
         return outputs
 
@@ -324,9 +287,7 @@ def __init__(self, config: ViTMSNConfig) -> None:
         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(
-        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
 
@@ -358,12 +319,8 @@ def prune_heads(self, heads: Set[int]) -> None:
         self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
 
         # Update hyper params and store pruned heads
-        self.attention.num_attention_heads = self.attention.num_attention_heads - len(
-            heads
-        )
-        self.attention.all_head_size = (
-            self.attention.attention_head_size * self.attention.num_attention_heads
-        )
+        self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads)
+        self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads
         self.pruned_heads = self.pruned_heads.union(heads)
 
     def forward(
@@ -376,9 +333,7 @@ def forward(
 
         attention_output = self.output(self_outputs[0], hidden_states)
 
-        outputs = (attention_output,) + self_outputs[
-            1:
-        ]  # add attentions if we output them
+        outputs = (attention_output,) + self_outputs[1:]  # add attentions if we output them
         return outputs
 
 
@@ -413,9 +368,7 @@ def __init__(self, config: ViTMSNConfig) -> None:
         self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
-    def forward(
-        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
-    ) -> torch.Tensor:
+    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
 
@@ -438,12 +391,8 @@ def __init__(self, config: ViTMSNConfig) -> None:
         self.attention = VITMSN_ATTENTION_CLASSES[config._attn_implementation](config)
         self.intermediate = ViTMSNIntermediate(config)
         self.output = ViTMSNOutput(config)
-        self.layernorm_before = nn.LayerNorm(
-            config.hidden_size, eps=config.layer_norm_eps
-        )
-        self.layernorm_after = nn.LayerNorm(
-            config.hidden_size, eps=config.layer_norm_eps
-        )
+        self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
 
     def forward(
         self,
@@ -452,16 +401,12 @@ def forward(
         output_attentions: bool = False,
     ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]:
         self_attention_outputs = self.attention(
-            self.layernorm_before(
-                hidden_states
-            ),  # in ViTMSN, layernorm is applied before self-attention
+            self.layernorm_before(hidden_states),  # in ViTMSN, layernorm is applied before self-attention
             head_mask,
             output_attentions=output_attentions,
         )
         attention_output = self_attention_outputs[0]
-        outputs = self_attention_outputs[
-            1:
-        ]  # add self attentions if we output attention weights
+        outputs = self_attention_outputs[1:]  # add self attentions if we output attention weights
 
         # first residual connection
         hidden_states = attention_output + hidden_states
@@ -483,9 +428,7 @@ class ViTMSNEncoder(nn.Module):
     def __init__(self, config: ViTMSNConfig) -> None:
         super().__init__()
         self.config = config
-        self.layer = nn.ModuleList(
-            [ViTMSNLayer(config) for _ in range(config.num_hidden_layers)]
-        )
+        self.layer = nn.ModuleList([ViTMSNLayer(config) for _ in range(config.num_hidden_layers)])
         self.gradient_checkpointing = False
 
     def forward(
@@ -513,9 +456,7 @@ def forward(
                         output_attentions,
                     )
                 else:
-                    layer_outputs = layer_module(
-                        hidden_states, layer_head_mask, output_attentions
-                    )
+                    layer_outputs = layer_module(hidden_states, layer_head_mask, output_attentions)
 
             hidden_states = layer_outputs[0]
 
@@ -526,11 +467,7 @@ def forward(
             all_hidden_states = all_hidden_states + (hidden_states,)
 
         if not return_dict:
-            return tuple(
-                v
-                for v in [hidden_states, all_hidden_states, all_self_attentions]
-                if v is not None
-            )
+            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None)
         return BaseModelOutput(
             last_hidden_state=hidden_states,
             hidden_states=all_hidden_states,
@@ -631,9 +568,7 @@ class PreTrainedModel
         self.encoder.layer[layer].attention.prune_heads(heads)
 
     @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
-    @replace_return_docstrings(
-        output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC
-    )
+    @replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values: Optional[torch.Tensor] = None,
@@ -668,19 +603,11 @@ def forward(
         ...     outputs = model(**inputs)
         >>> last_hidden_states = outputs.last_hidden_state
         ```"""
-        output_attentions = (
-            output_attentions
-            if output_attentions is not None
-            else self.config.output_attentions
-        )
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
-            output_hidden_states
-            if output_hidden_states is not None
-            else self.config.output_hidden_states
-        )
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         if pixel_values is None:
             raise ValueError("You have to specify pixel_values")
@@ -735,19 +662,13 @@ def __init__(self, config: ViTMSNConfig) -> None:
         self.vit = ViTMSNModel(config)
 
         # Classifier head
-        self.classifier = (
-            nn.Linear(config.hidden_size, config.num_labels)
-            if config.num_labels > 0
-            else nn.Identity()
-        )
+        self.classifier = nn.Linear(config.hidden_size, config.num_labels) if config.num_labels > 0 else nn.Identity()
 
         # Initialize weights and apply final processing
         self.post_init()
 
     @add_start_docstrings_to_model_forward(VIT_MSN_INPUTS_DOCSTRING)
-    @replace_return_docstrings(
-        output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC
-    )
+    @replace_return_docstrings(output_type=ImageClassifierOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         pixel_values: Optional[torch.Tensor] = None,
@@ -785,9 +706,7 @@ def forward(
         >>> print(model.config.id2label[predicted_label])
         tusker
         ```"""
-        return_dict = (
-            return_dict if return_dict is not None else self.config.use_return_dict
-        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
         outputs = self.vit(
             pixel_values,
@@ -807,9 +726,7 @@ def forward(
             if self.config.problem_type is None:
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
-                elif self.num_labels > 1 and (
-                    labels.dtype == torch.long or labels.dtype == torch.int
-                ):
+                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                     self.config.problem_type = "single_label_classification"
                 else:
                     self.config.problem_type = "multi_label_classification"
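
For reviewers who want to exercise the reformatted `interpolate_pos_encoding` path end to end, here is a minimal sketch. It is not part of the patch: the `facebook/vit-msn-small` checkpoint, the 384x384 input size, and the 16-pixel patch size are illustrative assumptions rather than values taken from the hunks above.

```python
# Hypothetical smoke test for the interpolate_pos_encoding code touched above.
# Checkpoint name, input resolution, and patch size are illustrative assumptions.
import torch
from transformers import ViTMSNModel

model = ViTMSNModel.from_pretrained("facebook/vit-msn-small")
model.eval()

# 384x384 is larger than the 224x224 pre-training resolution, so the position
# embeddings get bicubically interpolated to the new patch grid.
pixel_values = torch.randn(1, 3, 384, 384)

with torch.no_grad():
    outputs = model(pixel_values, interpolate_pos_encoding=True)

# With a 16-pixel patch size: 1 CLS token + (384 // 16) ** 2 = 577 positions.
print(outputs.last_hidden_state.shape)
```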