Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Editing issue with pickle def with lambda function #23869

Merged
merged 4 commits into from
May 30, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 16 additions & 1 deletion src/transformers/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,21 @@
logger = logging.get_logger(__name__)


def get_constant_lambda(_=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def get_constant_lambda(_=None):
def _get_constant_lambda(current_step):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

"""
Return 1, independent from args.

Args:
_ ( *optional*, defaults to None):
Placeholder to argument, used to consistency with args in LambdaLR

Return:
1 : int - constant lambda for constant scheduler
"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need for a documentation for a private function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delete it.


return 1


def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
"""
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
Expand All @@ -46,7 +61,7 @@ def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
"""

return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
return LambdaLR(optimizer, get_constant_lambda, last_epoch=last_epoch)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return LambdaLR(optimizer, get_constant_lambda, last_epoch=last_epoch)
return LambdaLR(optimizer, _get_constant_lambda, last_epoch=last_epoch)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done refactoring based on function name



def get_reduce_on_plateau_schedule(optimizer: Optimizer):
Expand Down