-
Notifications
You must be signed in to change notification settings - Fork 27.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
fixes clip interpolate #30783
Closed
Closed
fixes clip interpolate #30783
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1010,15 +1010,52 @@ def __init__(self, config: AltCLIPVisionConfig): | |
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) | ||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) | ||
|
||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: | ||
batch_size = pixel_values.shape[0] | ||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: | ||
""" | ||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher | ||
resolution images. | ||
|
||
Source: | ||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 | ||
""" | ||
position_embeddings = self.position_embedding.weight.unsqueeze(0) | ||
num_patches = embeddings.shape[1] - 1 | ||
num_positions = position_embeddings.shape[1] - 1 | ||
if num_patches == num_positions and height == width: | ||
return position_embeddings | ||
class_pos_embed = position_embeddings[:, 0] | ||
patch_pos_embed = position_embeddings[:, 1:] | ||
dim = embeddings.shape[-1] | ||
height = height // self.config.patch_size | ||
width = width // self.config.patch_size | ||
# we add a small number to avoid floating point error in the interpolation | ||
# see discussion at https://github.com/facebookresearch/dino/issues/8 | ||
height, width = height + 0.1, width + 0.1 | ||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) | ||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) | ||
patch_pos_embed = nn.functional.interpolate( | ||
patch_pos_embed, | ||
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)), | ||
mode="bicubic", | ||
align_corners=False, | ||
) | ||
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: | ||
raise ValueError("Width or height does not match with the interpolated position embeddings") | ||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) | ||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) | ||
|
||
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: | ||
batch_size, _, height, width = pixel_values.shape | ||
target_dtype = self.patch_embedding.weight.dtype | ||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] | ||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2) | ||
|
||
class_embeds = self.class_embedding.expand(batch_size, 1, -1) | ||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1) | ||
embeddings = embeddings + self.position_embedding(self.position_ids) | ||
if interpolate_pos_encoding: | ||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) | ||
else: | ||
embeddings = embeddings + self.position_embedding(self.position_ids) | ||
return embeddings | ||
|
||
|
||
|
@@ -1099,6 +1136,7 @@ def forward( | |
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
interpolate_pos_encoding: Optional[bool] = False, | ||
) -> Union[Tuple, BaseModelOutputWithPooling]: | ||
r""" | ||
Returns: | ||
|
@@ -1113,7 +1151,7 @@ def forward( | |
if pixel_values is None: | ||
raise ValueError("You have to specify pixel_values") | ||
|
||
hidden_states = self.embeddings(pixel_values) | ||
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) | ||
hidden_states = self.pre_layrnorm(hidden_states) | ||
|
||
encoder_outputs = self.encoder( | ||
|
@@ -1158,6 +1196,7 @@ def forward( | |
pixel_values: Optional[torch.FloatTensor] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
interpolate_pos_encoding: bool = False, | ||
return_dict: Optional[bool] = None, | ||
) -> Union[Tuple, BaseModelOutputWithPooling]: | ||
r""" | ||
|
@@ -1188,6 +1227,7 @@ def forward( | |
pixel_values=pixel_values, | ||
output_attentions=output_attentions, | ||
output_hidden_states=output_hidden_states, | ||
interpolate_pos_encoding=interpolate_pos_encoding, | ||
return_dict=return_dict, | ||
) | ||
|
||
|
@@ -1548,6 +1588,7 @@ def get_image_features( | |
pixel_values: Optional[torch.FloatTensor] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
interpolate_pos_encoding: bool = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
return_dict: Optional[bool] = None, | ||
) -> torch.FloatTensor: | ||
r""" | ||
|
@@ -1580,6 +1621,7 @@ def get_image_features( | |
pixel_values=pixel_values, | ||
output_attentions=output_attentions, | ||
output_hidden_states=output_hidden_states, | ||
interpolate_pos_encoding=interpolate_pos_encoding, | ||
return_dict=return_dict, | ||
) | ||
|
||
|
@@ -1600,6 +1642,7 @@ def forward( | |
return_loss: Optional[bool] = None, | ||
output_attentions: Optional[bool] = None, | ||
output_hidden_states: Optional[bool] = None, | ||
interpolate_pos_encoding: bool = False, | ||
nileshkokane01 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return_dict: Optional[bool] = None, | ||
) -> Union[Tuple, AltCLIPOutput]: | ||
r""" | ||
|
@@ -1644,6 +1687,7 @@ def forward( | |
pixel_values=pixel_values, | ||
output_attentions=output_attentions, | ||
output_hidden_states=output_hidden_states, | ||
interpolate_pos_encoding=interpolate_pos_encoding, | ||
return_dict=return_dict, | ||
) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -276,15 +276,52 @@ def __init__(self, config: BridgeTowerVisionConfig): | |
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) | ||
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False) | ||
|
||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: | ||
batch_size = pixel_values.shape[0] | ||
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: | ||
""" | ||
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher | ||
resolution images. | ||
|
||
Source: | ||
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 | ||
""" | ||
position_embeddings = self.position_embedding.weight.unsqueeze(0) | ||
num_patches = embeddings.shape[1] - 1 | ||
num_positions = position_embeddings.shape[1] - 1 | ||
if num_patches == num_positions and height == width: | ||
return position_embeddings | ||
class_pos_embed = position_embeddings[:, 0] | ||
patch_pos_embed = position_embeddings[:, 1:] | ||
dim = embeddings.shape[-1] | ||
height = height // self.config.patch_size | ||
width = width // self.config.patch_size | ||
# we add a small number to avoid floating point error in the interpolation | ||
# see discussion at https://github.com/facebookresearch/dino/issues/8 | ||
height, width = height + 0.1, width + 0.1 | ||
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) | ||
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) | ||
patch_pos_embed = nn.functional.interpolate( | ||
patch_pos_embed, | ||
scale_factor=(height / math.sqrt(num_positions), width / math.sqrt(num_positions)), | ||
mode="bicubic", | ||
align_corners=False, | ||
) | ||
if int(height) != patch_pos_embed.shape[-2] or int(width) != patch_pos_embed.shape[-1]: | ||
raise ValueError("Width or height does not match with the interpolated position embeddings") | ||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) | ||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) | ||
|
||
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding=False) -> torch.Tensor: | ||
batch_size, _, height, width = pixel_values.shape | ||
target_dtype = self.patch_embedding.weight.dtype | ||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid] | ||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2) | ||
|
||
class_embeds = self.class_embedding.expand(batch_size, 1, -1) | ||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1) | ||
embeddings = embeddings + self.position_embedding(self.position_ids) | ||
if interpolate_pos_encoding: | ||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) | ||
else: | ||
embeddings = embeddings + self.position_embedding(self.position_ids) | ||
return embeddings | ||
|
||
|
||
|
@@ -302,8 +339,13 @@ def __init__(self, config): | |
[nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) for _ in range(config.num_hidden_layers)] | ||
) | ||
|
||
def forward(self, pixel_values: torch.Tensor, attention_mask): | ||
hidden_states = self.embeddings(pixel_values) | ||
def forward( | ||
self, | ||
pixel_values: torch.Tensor, | ||
attention_mask, | ||
interpolate_pos_encoding: bool = False, | ||
): | ||
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding) | ||
hidden_states = self.ln_pre(hidden_states) | ||
# NLD -> LND | ||
hidden_states = hidden_states.permute(1, 0, 2) | ||
|
@@ -324,8 +366,12 @@ def forward(self, pixel_values: torch.Tensor, attention_mask): | |
hidden_states = torch.stack(hidden_states_stack, dim=0) | ||
return hidden_states | ||
|
||
def forward_pre(self, pixel_values: torch.Tensor): | ||
hidden_states = self.embeddings(pixel_values) | ||
def forward_pre( | ||
self, | ||
pixel_values: torch.Tensor, | ||
interpolate_pos_encoding: bool = False, | ||
): | ||
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) | ||
hidden_states = self.ln_pre(hidden_states) | ||
# NLD -> LND | ||
hidden_states = hidden_states.permute(1, 0, 2) | ||
|
@@ -1015,8 +1061,8 @@ def __init__(self, config): | |
def dtype(self): | ||
return self.visual.embeddings.patch_embedding.weight.dtype | ||
|
||
def forward(self, image, image_mask=None): | ||
return self.visual(image.type(self.dtype), image_mask) | ||
def forward(self, image, image_mask=None, interpolate_pos_encoding=False): | ||
return self.visual(image.type(self.dtype), image_mask, interpolate_pos_encoding) | ||
|
||
|
||
class BridgeTowerTextModel(BridgeTowerPreTrainedModel): | ||
|
@@ -1280,6 +1326,7 @@ def forward( | |
output_hidden_states: Optional[bool] = None, | ||
return_dict: Optional[bool] = None, | ||
labels: Optional[torch.LongTensor] = None, | ||
interpolate_pos_encoding: bool = False, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. BRIDGETOWER_INPUTS_DOCSTRING should be updated |
||
) -> Union[Tuple[torch.Tensor], BridgeTowerModelOutput]: | ||
r""" | ||
output_hidden_states (`bool`, *optional*): | ||
|
@@ -1352,7 +1399,9 @@ def forward( | |
all_hidden_states_text += (text_embeds,) | ||
|
||
if image_embeds is None: | ||
image_embeds = self.vision_model.visual.forward_pre(pixel_values.type(self.vision_model.dtype)) | ||
image_embeds = self.vision_model.visual.forward_pre( | ||
pixel_values.type(self.vision_model.dtype), interpolate_pos_encoding=interpolate_pos_encoding | ||
) | ||
else: | ||
# Permute as BridgeTowerResidualAttention has batch_first=True | ||
image_embeds = image_embeds.permute(1, 0, 2) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The value should be True or False, but not None