[fix] Some params are added to the optimizer twice when using MLM transformer heads #965

Open
Wants to merge 9 commits into base: main
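
Context for the fix (not part of the PR): below is a minimal, self-contained sketch of the failure mode the title describes. The TinyTrunk and TinyMLMHead modules are hypothetical stand-ins, not mmf classes; only standard PyTorch is used. When an MLM head ties its decoder weight to the trunk's embedding, collecting parameters per child module registers the same tensor in two optimizer param groups.

# Sketch of the failure mode (hypothetical modules, not mmf code): the MLM
# head's decoder shares its weight with the trunk's embedding, so building
# one param group per child module registers that tensor twice.
import torch
from torch import nn


class TinyTrunk(nn.Module):
    def __init__(self, vocab=10, dim=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab, dim)


class TinyMLMHead(nn.Module):
    def __init__(self, trunk, vocab=10, dim=4):
        super().__init__()
        self.transform = nn.Linear(dim, dim)
        self.decoder = nn.Linear(dim, vocab, bias=False)
        self.decoder.weight = trunk.embedding.weight  # weight tying with the trunk


trunk = TinyTrunk()
head = TinyMLMHead(trunk)

# Naive collection: one group per child module, no de-duplication.
groups = [
    {"params": list(trunk.parameters()), "lr": 1e-4},
    {"params": list(head.parameters()), "lr": 1e-3},
]

try:
    torch.optim.Adam(groups)
except ValueError as err:
    # PyTorch rejects the duplicated tensor:
    # "some parameters appear in more than one parameter group"
    print(err)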
57 changes: 36 additions & 21 deletions mmf/models/transformers/base.py
@@ -103,25 +103,13 @@ def build(self):
 
     def get_optimizer_parameters(self, config):
         lr = config.optimizer.params.lr
+
+        trunk_param_set = set()
         param_list = []
         parameters = []
-        head_configs = self.config.get("heads", [])
 
         for name, module in self.named_children():
-            # Heads can have different learning rates. This is handled here
-            if name == "heads":
-                # Parameters in the head which have a separate learning
-                # rate, are added as a separate param group
-                for head_config, head in zip(head_configs, self.heads):
-                    parameters, param_list = self.set_lr_for_parameters(
-                        config=head_config,
-                        module_name="{} head".format(head_config.get("type", "MLP")),
-                        base_lr=lr,
-                        module=head,
-                        parameters=parameters,
-                        param_list=param_list,
-                    )
-            elif name == "encoders":
+            if name == "encoders":
                 for key in module:
                     for modality in self.config.modalities:
                         if key == modality.key:
@@ -134,29 +122,56 @@ def get_optimizer_parameters(self, config):
                                 parameters=parameters,
                                 param_list=param_list,
                             )
-            else:
+            if name != "heads":
                 # For other modules in trunk, add to same param group
                 param_list += list(module.named_parameters())
-
+            trunk_param_set.update(list(module.parameters()))
+
+        head_configs = self.config.get("heads", [])
+        # Heads can have different learning rates. This is handled here
+        if len(head_configs) > 0:
+            # Parameters in the head which have a separate learning
+            # rate, are added as a separate param group
+            for head_config, head in zip(head_configs, self.heads):
+                parameters, param_list = self.set_lr_for_parameters(
+                    config=head_config,
+                    module_name="{} head".format(head_config.get("type", "MLP")),
+                    base_lr=lr,
+                    module=head,
+                    parameters=parameters,
+                    param_list=param_list,
+                    excluded_params=trunk_param_set,
+                )
         parameters += get_bert_configured_parameters(param_list)
 
         return parameters
 
     def set_lr_for_parameters(
-        self, config, module_name, base_lr, module, parameters, param_list
+        self,
+        config,
+        module_name,
+        base_lr,
+        module,
+        parameters,
+        param_list,
+        excluded_params=None,
     ):
         lr_multiplier = config.get("lr_multiplier", 1.0)
+        module_param = list(module.named_parameters())
+        if excluded_params is not None:
+            module_param = [
+                tup for tup in module_param if tup[1] not in excluded_params
+            ]
         if lr_multiplier != 1.0:
             logger.info(
                 f"Setting learning rate of {module_name} to be {base_lr} * {lr_multiplier}."
             )  # noqa
             parameters += get_bert_configured_parameters(
-                module, base_lr * lr_multiplier
+                module_param, base_lr * lr_multiplier
             )
         else:
             # Parameters for the modules with same learning rate as
             # trunk, add to same param group
-            param_list += list(module.named_parameters())
+            param_list += module_param
         return parameters, param_list
 
     def build_encoders(self):
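
For illustration only, a condensed sketch of the exclusion approach the diff above takes, using hypothetical toy modules (an nn.Embedding trunk and an nn.ModuleDict head) rather than mmf's trunk and heads: trunk parameters are collected in a set, and any head parameter already owned by the trunk is skipped before the head's separately-scheduled param group is built. nn.Parameter hashes by identity, so plain set membership is enough.

# Condensed sketch of the de-duplication idea (toy modules, not mmf code):
# keep a set of trunk parameters and drop any head parameter already in it.
import torch
from torch import nn

embedding = nn.Embedding(10, 4)                # stands in for the trunk
head = nn.ModuleDict(
    {"transform": nn.Linear(4, 4), "decoder": nn.Linear(4, 10, bias=False)}
)
head["decoder"].weight = embedding.weight      # MLM-style weight tying

trunk_param_set = set(embedding.parameters())  # Parameters hash by identity

# Analogue of the excluded_params filtering in set_lr_for_parameters above.
head_params = [p for p in head.parameters() if p not in trunk_param_set]

optimizer = torch.optim.Adam(
    [
        {"params": list(embedding.parameters()), "lr": 1e-4},
        {"params": head_params, "lr": 1e-4 * 10},  # head with its own lr_multiplier
    ]
)
# The tied decoder weight lives only in the trunk group: no ValueError,
# and it is updated exactly once per optimizer step.
print(len(head_params))  # 2: the head's transform weight and bias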