[MAINTENANCE] Refactor and clean up. #4008

Merged
Changes from 2 commits
62 changes: 40 additions & 22 deletions ludwig/encoders/image/base.py
@@ -383,31 +383,49 @@ def __init__(
         self._input_shape = (in_channels, img_height, img_width)

         if use_pretrained and not saved_weights_in_checkpoint:
-            transformer = ViTModel.from_pretrained(pretrained_model)
+            if output_attentions:
+                transformer = ViTModel.from_pretrained(
+                    pretrained_model_name_or_path=pretrained_model,
+                    attn_implementation="eager",
+                )
+            else:
+                transformer = ViTModel.from_pretrained(pretrained_model_name_or_path=pretrained_model)
         else:
-            config = ViTConfig(
-                image_size=img_height,
-                num_channels=in_channels,
-                patch_size=patch_size,
-                hidden_size=hidden_size,
-                num_hidden_layers=num_hidden_layers,
-                num_attention_heads=num_attention_heads,
-                intermediate_size=intermediate_size,
-                hidden_act=hidden_act,
-                hidden_dropout_prob=hidden_dropout_prob,
-                attention_probs_dropout_prob=attention_probs_dropout_prob,
-                initializer_range=initializer_range,
-                layer_norm_eps=layer_norm_eps,
-                gradient_checkpointing=gradient_checkpointing,
-            )
+            if output_attentions:
+                config = ViTConfig(
+                    image_size=img_height,
+                    num_channels=in_channels,
+                    patch_size=patch_size,
+                    hidden_size=hidden_size,
+                    num_hidden_layers=num_hidden_layers,
+                    num_attention_heads=num_attention_heads,
+                    intermediate_size=intermediate_size,
+                    hidden_act=hidden_act,
+                    hidden_dropout_prob=hidden_dropout_prob,
+                    attention_probs_dropout_prob=attention_probs_dropout_prob,
+                    initializer_range=initializer_range,
+                    layer_norm_eps=layer_norm_eps,
+                    gradient_checkpointing=gradient_checkpointing,
+                    attn_implementation="eager",
+                )
+            else:
+                config = ViTConfig(
+                    image_size=img_height,
+                    num_channels=in_channels,
+                    patch_size=patch_size,
+                    hidden_size=hidden_size,
+                    num_hidden_layers=num_hidden_layers,
+                    num_attention_heads=num_attention_heads,
+                    intermediate_size=intermediate_size,
+                    hidden_act=hidden_act,
+                    hidden_dropout_prob=hidden_dropout_prob,
+                    attention_probs_dropout_prob=attention_probs_dropout_prob,
+                    initializer_range=initializer_range,
+                    layer_norm_eps=layer_norm_eps,
+                    gradient_checkpointing=gradient_checkpointing,
+                )
             transformer = ViTModel(config)

-        if output_attentions:
-            config_dict: dict = transformer.config.to_dict()
-            updated_config: ViTConfig = ViTConfig(**config_dict)
-            updated_config._attn_implementation = "eager"
-            transformer = ViTModel(updated_config)
-
         self.transformer = FreezeModule(transformer, frozen=not trainable)

         self._output_shape = (transformer.config.hidden_size,)

Review thread on the new ViTConfig branch:

arnavgarg1 (Contributor): Can we just create a dictionary mapping, then optionally add "attn_implementation" if output_attentions and pass the dictionary as **kwargs into the Config? Will reduce boilerplate :)

Collaborator (Author): @arnavgarg1 Done -- sorry about missing the obvious! Thanks!
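For reference, the suggestion above amounts to something like the following sketch. It reuses the constructor parameters already visible in the diff (img_height, in_channels, output_attentions, pretrained_model, and so on) and only illustrates the kwargs-dict pattern; it is not the exact code that landed in the follow-up commit.

from transformers import ViTConfig, ViTModel

# Build the architecture kwargs once; the names below are the encoder's
# __init__ parameters shown in the diff above.
config_kwargs = dict(
    image_size=img_height,
    num_channels=in_channels,
    patch_size=patch_size,
    hidden_size=hidden_size,
    num_hidden_layers=num_hidden_layers,
    num_attention_heads=num_attention_heads,
    intermediate_size=intermediate_size,
    hidden_act=hidden_act,
    hidden_dropout_prob=hidden_dropout_prob,
    attention_probs_dropout_prob=attention_probs_dropout_prob,
    initializer_range=initializer_range,
    layer_norm_eps=layer_norm_eps,
    gradient_checkpointing=gradient_checkpointing,
)
# Only force the eager attention implementation when attention maps are requested.
if output_attentions:
    config_kwargs["attn_implementation"] = "eager"

if use_pretrained and not saved_weights_in_checkpoint:
    # from_pretrained accepts the same optional attn_implementation override.
    pretrained_kwargs = {"attn_implementation": "eager"} if output_attentions else {}
    transformer = ViTModel.from_pretrained(
        pretrained_model_name_or_path=pretrained_model, **pretrained_kwargs
    )
else:
    transformer = ViTModel(ViTConfig(**config_kwargs))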
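Background on why the change passes attn_implementation="eager" only when output_attentions is set: the fused attention paths in recent transformers releases do not materialize per-head attention probabilities, whereas the eager implementation does, so attentions can only be returned reliably in eager mode. A minimal, self-contained sketch, assuming a transformers version that accepts the attn_implementation config kwarg (as the diff itself does):

import torch
from transformers import ViTConfig, ViTModel

# Randomly initialized ViT; no pretrained weights are downloaded.
config = ViTConfig(image_size=224, num_channels=3, attn_implementation="eager")
model = ViTModel(config).eval()

pixel_values = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    outputs = model(pixel_values, output_attentions=True)

# One attention tensor per layer, each of shape (batch, heads, tokens, tokens),
# e.g. (1, 12, 197, 197) for the default 224x224 image with 16x16 patches.
print(len(outputs.attentions), outputs.attentions[0].shape)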