Skip to content

Commit

Permalink
[DINOv2] Update pooler output (#25392)
Browse files Browse the repository at this point in the history
Update pooler output
  • Loading branch information
NielsRogge authored Aug 10, 2023
1 parent d0c1aeb commit b175fc3
Showing 1 changed file with 5 additions and 22 deletions.
27 changes: 5 additions & 22 deletions src/transformers/models/dinov2/modeling_dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,15 +583,14 @@ def _set_gradient_checkpointing(self, module: Dinov2Encoder, value: bool = False
DINOV2_START_DOCSTRING,
)
class Dinov2Model(Dinov2PreTrainedModel):
def __init__(self, config: Dinov2Config, add_pooling_layer: bool = True):
def __init__(self, config: Dinov2Config):
super().__init__(config)
self.config = config

self.embeddings = Dinov2Embeddings(config)
self.encoder = Dinov2Encoder(config)

self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.pooler = Dinov2Pooler(config) if add_pooling_layer else None

# Initialize weights and apply final processing
self.post_init()
Expand Down Expand Up @@ -651,10 +650,10 @@ def forward(
)
sequence_output = encoder_outputs[0]
sequence_output = self.layernorm(sequence_output)
pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
pooled_output = sequence_output[:, 0, :]

if not return_dict:
head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,)
head_outputs = (sequence_output, pooled_output)
return head_outputs + encoder_outputs[1:]

return BaseModelOutputWithPooling(
Expand All @@ -665,22 +664,6 @@ def forward(
)


# Copied from transformers.models.vit.modeling_vit.ViTPooler with ViT->Dinov2
class Dinov2Pooler(nn.Module):
def __init__(self, config: Dinov2Config):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.activation = nn.Tanh()

def forward(self, hidden_states):
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output


@add_start_docstrings(
"""
Dinov2 Model transformer with an image classification head on top (a linear layer on top of the final hidden state
Expand All @@ -693,7 +676,7 @@ def __init__(self, config: Dinov2Config) -> None:
super().__init__(config)

self.num_labels = config.num_labels
self.dinov2 = Dinov2Model(config, add_pooling_layer=False)
self.dinov2 = Dinov2Model(config)

# Classifier head
self.classifier = (
Expand Down Expand Up @@ -770,7 +753,7 @@ def forward(
loss = loss_fct(logits, labels)

if not return_dict:
output = (logits,) + outputs[1:]
output = (logits,) + outputs[2:]
return ((loss,) + output) if loss is not None else output

return ImageClassifierOutput(
Expand Down

0 comments on commit b175fc3

Please sign in to comment.