[DETR] Remove timm hardcoded logic in modeling files #29038
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
bc8d155 to 0a38d16
4d524f2 to 33c89b9
@@ -141,23 +161,6 @@ def backward(context, grad_output):
return grad_value, None, None, grad_sampling_loc, grad_attn_weight, None

if is_scipy_available():
Moves these above the `MultiScaleDeformableAttentionFunction` definition, which better matches library patterns.
raise ValueError("You can't specify both `backbone_kwargs` and `backbone_config`.")

if not use_timm_backbone:
if use_timm_backbone and backbone_kwargs is None:
These replicate the defaults used to load a timm backbone in the modeling file. This PR makes it possible to configure the loaded timm backbone using the standard backbone API; the defaults here are for backwards compatibility.
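A minimal sketch of the defaulting described in this comment, assuming the values visible in the modeling hunks of this PR (`out_indices` of `(1, 2, 3, 4)`, `in_chans` taken from `num_channels`, `output_stride` of 16 when dilation is on). The helper name `default_timm_backbone_kwargs` is hypothetical and is not the config's actual code:

```python
from typing import Any, Dict, Optional


def default_timm_backbone_kwargs(
    use_timm_backbone: bool,
    backbone_kwargs: Optional[Dict[str, Any]],
    dilation: bool,
    num_channels: int,
) -> Optional[Dict[str, Any]]:
    """Hypothetical helper: fill in the legacy defaults only when nothing was supplied."""
    if not use_timm_backbone or backbone_kwargs is not None:
        # Either timm isn't used, or the caller configured the backbone explicitly.
        return backbone_kwargs

    # Values that were previously hard-coded in the modeling file.
    kwargs: Dict[str, Any] = {
        "out_indices": [1, 2, 3, 4],
        "in_chans": num_channels,
    }
    if dilation:
        kwargs["output_stride"] = 16
    return kwargs
```

With this shape, an existing config that never set `backbone_kwargs` still resolves to the old behaviour, while an explicit `backbone_kwargs` is passed through untouched.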
@@ -354,17 +354,20 @@ def __init__(self, config):

self.config = config

# For backwards compatibility we have to use the timm library directly instead of the AutoBackbone API
if config.use_timm_backbone:
Unfortunately, we can't remove the timm logic here and use `load_backbone` instead. When using `load_backbone`, a timm model is loaded as a `TimmBackbone` class, so the loaded weight names differ from those produced by the `create_model` call here. For backwards compatibility (being able to load existing checkpoints) we need to leave this as-is.
Instead, to be compatible with the backbone API and remove the hard-coding, we allow the backbone behaviour to be specified through `backbone_kwargs`.
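To illustrate why differing weight names break checkpoint loading, here is a small sketch using plain sets of key names. The prefixes (`model`, `model.wrapper`) and layer names are hypothetical stand-ins chosen for illustration, not the real checkpoint keys:

```python
from typing import Dict, List, Set


def prefixed(keys: List[str], prefix: str) -> Set[str]:
    """Build state-dict style names by nesting layer names under a module prefix."""
    return {f"{prefix}.{key}" for key in keys}


def unmatched_keys(checkpoint: Set[str], model: Set[str]) -> Dict[str, List[str]]:
    """Report names the model expects but the checkpoint lacks, and vice versa."""
    return {
        "missing": sorted(model - checkpoint),
        "unexpected": sorted(checkpoint - model),
    }


layer_keys = ["conv1.weight", "bn1.weight"]

# Checkpoint saved when the timm model was attached directly (hypothetical prefix).
checkpoint_keys = prefixed(layer_keys, "model")
# Same layers nested one level deeper inside a wrapper class (hypothetical prefix).
wrapped_keys = prefixed(layer_keys, "model.wrapper")

report = unmatched_keys(checkpoint_keys, wrapped_keys)
```

Every weight ends up both "missing" and "unexpected" even though the tensors are identical, which is the incompatibility the comment describes.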
Ok, makes sense!
dbc1355 to 51eb6d3
Ok I think that makes sense! No problem for me, thanks for working on this @amyeroberts
kwargs = {}
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
num_channels = kwargs.pop("in_chans", config.num_channels)
I would maybe add a few comments here to explain what's happening for posterity
kwargs = {}
kwargs = getattr(config, "backbone_kwargs", {})
kwargs = {} if kwargs is None else kwargs.copy()
out_indices = kwargs.pop("out_indices", (2, 3, 4) if config.num_feature_levels > 1 else (4,))
num_channels = kwargs.pop("in_chans", config.num_channels)
if config.dilation:
kwargs["output_stride"] = 16
kwargs["output_stride"] = kwargs.get("output_stride", 16)
Same here
LGTM +1 on the comments for posterity! 🤗 sorry for being slow here
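As a rough, commented version of the kwargs handling discussed in this thread, the sketch below wraps the lines from the hunks in a standalone function. The function name `split_backbone_kwargs` and the `SimpleNamespace` stand-in for the model config are assumptions for illustration; the comments are a reading of the diff, not the wording that was eventually merged:

```python
from types import SimpleNamespace
from typing import Any, Dict, Tuple


def split_backbone_kwargs(config: Any) -> Tuple[Any, int, Dict[str, Any]]:
    """Hypothetical helper separating explicit arguments from pass-through kwargs."""
    # Older configs may not define `backbone_kwargs`, or may store None.
    kwargs = getattr(config, "backbone_kwargs", {})
    # Copy before popping so the config object itself is never mutated.
    kwargs = {} if kwargs is None else kwargs.copy()
    # Pull out the values passed explicitly, falling back to the old hard-coded defaults.
    out_indices = kwargs.pop("out_indices", (1, 2, 3, 4))
    num_channels = kwargs.pop("in_chans", config.num_channels)
    if config.dilation:
        # Respect a user-supplied stride; otherwise keep the previous value of 16.
        kwargs["output_stride"] = kwargs.get("output_stride", 16)
    return out_indices, num_channels, kwargs


config = SimpleNamespace(
    backbone_kwargs={"out_indices": [2, 3, 4], "in_chans": 3},
    num_channels=3,
    dilation=True,
)
out_indices, num_channels, extra_kwargs = split_backbone_kwargs(config)
```

The copy is the important design choice here: popping from the original dictionary would silently change the config, so a second model built from the same config would see different settings.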
Co-authored-by: Arthur <[email protected]>
51eb6d3 to 80b32cc
* Enable instantiating model with pretrained backbone weights
* Clarify pretrained import
* Use load_backbone instead
* Add backbone_kwargs to config
* Fix up
* Add tests
* Tidy up
* Enable instantiating model with pretrained backbone weights
* Update tests so backbone checkpoint isn't passed in
* Clarify pretrained import
* Update configs - docs and validation check
* Update src/transformers/utils/backbone_utils.py Co-authored-by: Arthur <[email protected]>
* Clarify exception message
* Update config init in tests
* Add test for when use_timm_backbone=True
* Use load_backbone instead
* Add use_timm_backbone to the model configs
* Add backbone_kwargs to config
* Pass kwargs to constructors
* Draft
* Fix tests
* Add back timm - weight naming
* More tidying up
* Whoops
* Tidy up
* Handle when kwargs are none
* Update tests
* Revert test changes
* Deformable detr test - don't use default
* Don't mutate; correct model attributes
* Add some clarifying comments
* nit - grammar is hard
---------
Co-authored-by: Arthur <[email protected]>
What does this PR do?
Certain models use timm's `create_model` to load their backbone. In future, all models should use `load_backbone` to create the backbone, removing the need for the conditional timm logic. Removing this logic from existing models isn't possible: the backbone would then be loaded as a `TimmBackbone` class, which changes the backbone's weight names, so existing checkpoints wouldn't be compatible.
This PR makes it possible to configure the loaded timm backbone completely through the model config, removing the hard-coded values in the modeling files. So, for users, it's the same as if `load_backbone` was being used.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.