
adding use case of interpolate_pos = False and respective testing to all models
Manuel Sanchez Hernandez committed Aug 4, 2024
1 parent 0cd8a8f commit db5a20a
Showing 12 changed files with 132 additions and 608 deletions.
13 changes: 10 additions & 3 deletions src/transformers/models/altclip/modeling_altclip.py
@@ -751,7 +751,7 @@ def forward(
attention_mask: Optional[torch.Tensor] = None,
causal_attention_mask: Optional[torch.Tensor] = None,
output_attentions: Optional[bool] = False,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel"""

bsz, tgt_len, embed_dim = hidden_states.size()
@@ -840,6 +840,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return hidden_states


+# Copied from transformers.models.clip.modeling_clip.CLIPEncoderLayer with CLIP->AltCLIP
class AltCLIPEncoderLayer(nn.Module):
def __init__(self, config: AltCLIPConfig):
super().__init__()
@@ -890,6 +891,7 @@ def forward(
return outputs


+# Copied from transformers.models.clip.modeling_clip.CLIPEncoder with CLIP->AltCLIP
class AltCLIPEncoder(nn.Module):
"""
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
@@ -1047,6 +1049,10 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
+        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
+            )
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
@@ -1117,6 +1123,7 @@ def _init_weights(self, module):
module.weight.data[module.padding_idx].zero_()


+# Copied from transformers.models.clip.modeling_clip.CLIPVisionTransformer with CLIPVisionTransformer->AltCLIPVisionTransformer,CLIPVisionConfig->AltCLIPVisionConfig,CLIPVisionEmbeddings->AltCLIPVisionEmbeddings,CLIPEncoder->AltCLIPEncoder,CLIP_VISION_INPUTS_DOCSTRING->ALTCLIP_VISION_INPUTS_DOCSTRING
class AltCLIPVisionTransformer(nn.Module):
def __init__(self, config: AltCLIPVisionConfig):
super().__init__()
@@ -1508,12 +1515,12 @@ def __init__(self, config: AltCLIPConfig):
super().__init__(config)

if not isinstance(config.vision_config, AltCLIPVisionConfig):
-            raise TypeError(
+            raise ValueError(
"config.vision_config is expected to be of type AltCLIPVisionConfig but is of type"
f" {type(config.vision_config)}."
)
if not isinstance(config.text_config, AltCLIPTextConfig):
-            raise TypeError(
+            raise ValueError(
"config.text_config is expected to be of type AltCLIPTextConfig but is of type"
f" {type(config.text_config)}."
)
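
The core change repeated in each vision model below is the new guard in the embeddings' forward: when interpolate_pos_encoding is False, an input whose resolution differs from config.image_size is rejected with a ValueError up front instead of failing later with a shape mismatch. A minimal sketch of the intended behaviour, assuming the top-level AltCLIPVisionModel forward exposes the same interpolate_pos_encoding flag as the embeddings layer above and that image_size defaults to 224:

import torch
from transformers import AltCLIPVisionConfig, AltCLIPVisionModel

config = AltCLIPVisionConfig()               # image_size assumed to default to 224
model = AltCLIPVisionModel(config)

pixel_values = torch.randn(1, 3, 256, 256)   # resolution deliberately differs from image_size

# Without interpolation, the new check raises immediately.
try:
    model(pixel_values=pixel_values, interpolate_pos_encoding=False)
except ValueError as err:
    print(err)  # Input image size (256*256) doesn't match model (224*224).

# With interpolation, the position embeddings are resized to fit the 256x256 input.
outputs = model(pixel_values=pixel_values, interpolate_pos_encoding=True)
print(outputs.last_hidden_state.shape)
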
4 changes: 4 additions & 0 deletions src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -314,6 +314,10 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
+        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
+            )
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
8 changes: 6 additions & 2 deletions src/transformers/models/chinese_clip/modeling_chinese_clip.py
@@ -224,6 +224,10 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
+        if not interpolate_pos_encoding and (height != self.image_size or width != self.image_size):
+            raise ValueError(
+                f"Input image size ({height}*{width}) doesn't match model" f" ({self.image_size}*{self.image_size})."
+            )
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
@@ -1385,13 +1389,13 @@ def __init__(self, config: ChineseCLIPConfig):
super().__init__(config)

if not isinstance(config.text_config, ChineseCLIPTextConfig):
-            raise TypeError(
+            raise ValueError(
"config.text_config is expected to be of type ChineseCLIPTextConfig but is of type"
f" {type(config.text_config)}."
)

if not isinstance(config.vision_config, ChineseCLIPVisionConfig):
-            raise TypeError(
+            raise ValueError(
"config.vision_config is expected to be of type ChineseCLIPVisionConfig but is of type"
f" {type(config.vision_config)}."
)
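
The commit message mentions adding respective tests for all models, but the test files themselves are not part of the excerpt above. A hypothetical sketch of what such a check could look like, assuming a helper that receives an already-built vision model whose config image_size differs from the test resolution:

import pytest
import torch


def check_interpolate_pos_encoding_guard(model, mismatched_size=256):
    # Hypothetical helper, not the commit's actual test.
    # A resolution other than config.image_size must be rejected while interpolation is off.
    pixel_values = torch.randn(1, 3, mismatched_size, mismatched_size)
    with pytest.raises(ValueError, match="doesn't match model"):
        model(pixel_values=pixel_values, interpolate_pos_encoding=False)

    # The same input is accepted once interpolation of the position embeddings is requested.
    outputs = model(pixel_values=pixel_values, interpolate_pos_encoding=True)
    assert outputs.last_hidden_state.shape[0] == pixel_values.shape[0]
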

