Commit
fixes clip interpolate
nileshkokane01 committed May 13, 2024
1 parent 0f8fefd commit 6ed7b47
Showing 14 changed files with 540 additions and 34 deletions.
52 changes: 48 additions & 4 deletions src/transformers/models/altclip/modeling_altclip.py
@@ -1010,15 +1010,52 @@ def __init__(self, config: AltCLIPVisionConfig):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows interpolating the pre-trained position encodings so that the model can be used on
higher-resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
position_embeddings = self.position_embedding.weight.unsqueeze(0)
num_patches = embeddings.shape[1] - 1
num_positions = position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return position_embeddings
class_pos_embed = position_embeddings[:, 0]
patch_pos_embed = position_embeddings[:, 1:]
dim = embeddings.shape[-1]
height = height // self.config.patch_size
width = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
height, width = height + 0.1, width + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
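
A minimal, self-contained sketch of the same interpolation trick (illustrative only, not part of this commit): it assumes a square patch grid, leaves the class-token embedding untouched, and passes an explicit target size to nn.functional.interpolate instead of the scale_factor arithmetic used above.

import math
import torch
import torch.nn.functional as F

def resize_patch_positions(pos: torch.Tensor, new_h: int, new_w: int) -> torch.Tensor:
    # pos: (num_positions, dim) patch position embeddings laid out on a square grid
    num_positions, dim = pos.shape
    grid = int(math.sqrt(num_positions))
    pos = pos.reshape(1, grid, grid, dim).permute(0, 3, 1, 2)  # (1, dim, grid, grid)
    pos = F.interpolate(pos, size=(new_h, new_w), mode="bicubic", align_corners=False)
    return pos.permute(0, 2, 3, 1).reshape(new_h * new_w, dim)

# A 224px model with 14px patches has a 16x16 grid; a 336px input needs 24x24 positions.
new_pos = resize_patch_positions(torch.randn(16 * 16, 1024), 24, 24)
assert new_pos.shape == (24 * 24, 1024)
# The class-token position embedding is not resized; the model code above concatenates it back.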

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings


@@ -1099,6 +1136,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
interpolate_pos_encoding: Optional[bool] = False,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
@@ -1113,7 +1151,7 @@
if pixel_values is None:
raise ValueError("You have to specify pixel_values")

hidden_states = self.embeddings(pixel_values)
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)

encoder_outputs = self.encoder(
@@ -1158,6 +1196,7 @@ def forward(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
@@ -1188,6 +1227,7 @@ def forward(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -1548,6 +1588,7 @@ def get_image_features(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
@@ -1580,6 +1621,7 @@ def get_image_features(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -1600,6 +1642,7 @@ def forward(
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, AltCLIPOutput]:
r"""
@@ -1644,6 +1687,7 @@ def forward(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
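
Illustrative usage of the new flag on AltCLIP (not part of the diff; assumes the BAAI/AltCLIP checkpoint and that the processor accepts do_resize/do_center_crop overrides, so the larger resolution actually reaches the model):

import torch
from PIL import Image
from transformers import AltCLIPModel, AltCLIPProcessor

model = AltCLIPModel.from_pretrained("BAAI/AltCLIP")
processor = AltCLIPProcessor.from_pretrained("BAAI/AltCLIP")

image = Image.new("RGB", (336, 336))  # larger than the 224px pretraining resolution
inputs = processor(images=image, return_tensors="pt", do_resize=False, do_center_crop=False)
with torch.no_grad():
    image_features = model.get_image_features(**inputs, interpolate_pos_encoding=True)
print(image_features.shape)  # (1, projection_dim)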

69 changes: 59 additions & 10 deletions src/transformers/models/bridgetower/modeling_bridgetower.py
@@ -276,15 +276,52 @@ def __init__(self, config: BridgeTowerVisionConfig):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows interpolating the pre-trained position encodings so that the model can be used on
higher-resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
position_embeddings = self.position_embedding.weight.unsqueeze(0)
num_patches = embeddings.shape[1] - 1
num_positions = position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return position_embeddings
class_pos_embed = position_embeddings[:, 0]
patch_pos_embed = position_embeddings[:, 1:]
dim = embeddings.shape[-1]
height = height // self.config.patch_size
width = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
height, width = height + 0.1, width + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings


@@ -302,8 +339,13 @@ def __init__(self, config):
[nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)]
)

def forward(self, pixel_values: torch.Tensor, attention_mask):
hidden_states = self.embeddings(pixel_values)
def forward(
self,
pixel_values: torch.Tensor,
attention_mask,
interpolate_pos_encoding: bool = False,
):
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding)
hidden_states = self.ln_pre(hidden_states)
# NLD -> LND
hidden_states = hidden_states.permute(1, 0, 2)
@@ -324,8 +366,12 @@ def forward(self, pixel_values: torch.Tensor, attention_mask):
hidden_states = torch.stack(hidden_states_stack, dim=0)
return hidden_states

def forward_pre(self, pixel_values: torch.Tensor):
hidden_states = self.embeddings(pixel_values)
def forward_pre(
self,
pixel_values: torch.Tensor,
interpolate_pos_encoding: bool = False,
):
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.ln_pre(hidden_states)
# NLD -> LND
hidden_states = hidden_states.permute(1, 0, 2)
@@ -1015,8 +1061,8 @@ def __init__(self, config):
def dtype(self):
return self.visual.embeddings.patch_embedding.weight.dtype

def forward(self, image, image_mask=None):
return self.visual(image.type(self.dtype), image_mask)
def forward(self, image, image_mask=None, interpolate_pos_encoding=False):
return self.visual(image.type(self.dtype), image_mask, interpolate_pos_encoding)


class BridgeTowerTextModel(BridgeTowerPreTrainedModel):
@@ -1280,6 +1326,7 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
labels: Optional[torch.LongTensor] = None,
interpolate_pos_encoding: bool = False,
) -> Union[Tuple[torch.Tensor], BridgeTowerModelOutput]:
r"""
output_hidden_states (`bool`, *optional*):
@@ -1352,7 +1399,9 @@ def forward(
all_hidden_states_text += (text_embeds,)

if image_embeds is None:
image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype))
image_embeds = self.vision_model.visual.forward_pre(
pixel_values.type(self.vision_model.dtype), interpolate_pos_encoding=interpolate_pos_encoding
)
else:
# Permute as BridgeTowerResidualAttention has batch_first=True
image_embeds = image_embeds.permute(1, 0, 2)
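
Illustrative usage on BridgeTower (not part of the diff; assumes the BridgeTower/bridgetower-base checkpoint). Note that with the processor's default resizing the interpolation branch is effectively a no-op; to feed a genuinely larger image you would also configure the image processor not to resize it.

import torch
from PIL import Image
from transformers import BridgeTowerModel, BridgeTowerProcessor

model = BridgeTowerModel.from_pretrained("BridgeTower/bridgetower-base")
processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")

image = Image.new("RGB", (384, 384))
inputs = processor(image, "an example caption", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs, interpolate_pos_encoding=True)
print(outputs.pooler_output.shape)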
56 changes: 52 additions & 4 deletions src/transformers/models/chinese_clip/modeling_chinese_clip.py
@@ -189,15 +189,52 @@ def __init__(self, config: ChineseCLIPVisionConfig):
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows interpolating the pre-trained position encodings so that the model can be used on
higher-resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
position_embeddings = self.position_embedding.weight.unsqueeze(0)
num_patches = embeddings.shape[1] - 1
num_positions = position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return position_embeddings
class_pos_embed = position_embeddings[:, 0]
patch_pos_embed = position_embeddings[:, 1:]
dim = embeddings.shape[-1]
height = height // self.config.patch_size
width = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
height, width = height + 0.1, width + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]:
raise ValueError("Width or height does not match with the interpolated position embeddings")
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)

def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor:
batch_size, _, height, width = pixel_values.shape
target_dtype = self.patch_embedding.weight.dtype
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
embeddings = embeddings + self.position_embedding(self.position_ids)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
embeddings = embeddings + self.position_embedding(self.position_ids)
return embeddings


@@ -799,6 +836,8 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -814,6 +853,8 @@ def _init_weights(self, module):
output_hidden_states (`bool`, *optional*):
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
more detail.
interpolate_pos_encoding (`bool`, *optional*):
Whether to interpolate the pre-trained position encodings.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""
@@ -1053,6 +1094,7 @@ def forward(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
@@ -1067,7 +1109,7 @@
if pixel_values is None:
raise ValueError("You have to specify pixel_values")

hidden_states = self.embeddings(pixel_values)
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding)
hidden_states = self.pre_layrnorm(hidden_states)

encoder_outputs = self.encoder(
@@ -1300,6 +1342,7 @@ def forward(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
@@ -1330,6 +1373,7 @@
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -1426,6 +1470,7 @@ def get_image_features(
pixel_values: Optional[torch.FloatTensor] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> torch.FloatTensor:
r"""
@@ -1462,6 +1507,7 @@
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)

@@ -1482,6 +1528,7 @@ def forward(
return_loss: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
interpolate_pos_encoding: bool = False,
return_dict: Optional[bool] = None,
) -> Union[Tuple, ChineseCLIPOutput]:
r"""
@@ -1517,6 +1564,7 @@ def forward(
pixel_values=pixel_values,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
interpolate_pos_encoding=interpolate_pos_encoding,
return_dict=return_dict,
)
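
Illustrative usage on Chinese-CLIP (not part of the diff; assumes the OFA-Sys/chinese-clip-vit-base-patch16 checkpoint and that the processor accepts do_resize/do_center_crop overrides):

import torch
from PIL import Image
from transformers import ChineseCLIPModel, ChineseCLIPProcessor

model = ChineseCLIPModel.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")
processor = ChineseCLIPProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16")

image = Image.new("RGB", (448, 448))  # 448px with 16px patches gives a 28x28 grid instead of 14x14
inputs = processor(images=image, return_tensors="pt", do_resize=False, do_center_crop=False)
with torch.no_grad():
    image_features = model.get_image_features(**inputs, interpolate_pos_encoding=True)
print(image_features.shape)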
