pt: add explicit decay_rate for lr #3445

Merged: 7 commits, Mar 20, 2024
47 changes: 34 additions & 13 deletions deepmd/pt/utils/learning_rate.py
@@ -3,14 +3,36 @@


 class LearningRateExp:
-    def __init__(self, start_lr, stop_lr, decay_steps, stop_steps, **kwargs):
-        """Construct an exponential-decayed learning rate.
+    def __init__(
+        self,
+        start_lr,
+        stop_lr,
+        decay_steps,
+        stop_steps,
+        decay_rate=None,
+        min_lr=None,
+        **kwargs,
+    ):
+        """
+        Construct an exponential-decayed learning rate.

-        Args:
-        - start_lr: Initial learning rate.
-        - stop_lr: Learning rate at the last step.
-        - decay_steps: Decay learning rate every N steps.
-        - stop_steps: When is the last step.
+        Parameters
+        ----------
+        start_lr
+            The learning rate at the start of the training.
+        stop_lr
+            The desired learning rate at the end of the training.
+        decay_steps
+            The learning rate is decaying every this number of training steps.
+        stop_steps
+            The total training steps for learning rate scheduler.
+        decay_rate
+            The decay rate for the learning rate.
+            If provided, the decay rate will be set instead of
+            calculating it through interpolation between start_lr and stop_lr.
+        min_lr
+            The minimum learning rate to be used when decay_rate is applied.
+            If the learning rate decays below min_lr, min_lr will be used instead.
         """
         self.start_lr = start_lr
         default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1
@@ -20,16 +42,15 @@
         self.decay_rate = np.exp(
             np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps)
         )
-        if "decay_rate" in kwargs:
-            self.decay_rate = kwargs["decay_rate"]
-        if "min_lr" in kwargs:
-            self.min_lr = kwargs["min_lr"]
+        if decay_rate is not None:
+            self.decay_rate = decay_rate
+            self.min_lr = min_lr
         else:
-            self.min_lr = 3e-10
+            self.min_lr = None

     def value(self, step):
         """Get the learning rate at the given step."""
         step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps)
-        if step_lr < self.min_lr:
+        if self.min_lr is not None and step_lr < self.min_lr:
             step_lr = self.min_lr
         return step_lr
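For reference, a minimal sketch of how the explicit decay_rate path is meant to reproduce the interpolated schedule, mirroring the decay_rate_pt test added below (the import follows the file path above; the numeric values are made up for illustration and are not defaults):

    import numpy as np

    from deepmd.pt.utils.learning_rate import LearningRateExp

    # Implicit path: decay_rate is interpolated from start_lr/stop_lr over stop_steps.
    lr_auto = LearningRateExp(1e-3, 1e-8, 5000, 1000000)

    # Explicit path: pass the same rate directly; stop_lr is then effectively ignored
    # and min_lr becomes the floor for the decayed value.
    rate = np.exp(np.log(1e-8 / 1e-3) / (1000000 / 5000))  # ~0.944
    lr_explicit = LearningRateExp(1e-3, 1.0, 5000, 1000000, decay_rate=rate, min_lr=1e-8)

    # Away from the min_lr clamp the two schedules agree ...
    assert np.allclose(lr_auto.value(50000), lr_explicit.value(50000))
    # ... and far past the crossover the explicit schedule is clamped at min_lr.
    assert lr_explicit.value(10**9) == 1e-8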
1 change: 1 addition & 0 deletions deepmd/tf/utils/learning_rate.py
@@ -41,6 +41,7 @@ def __init__(
         stop_lr: float = 5e-8,
         decay_steps: int = 5000,
         decay_rate: float = 0.95,
+        **kwargs,
     ) -> None:
         """Constructor."""
         self.cd = {}
23 changes: 23 additions & 0 deletions deepmd/utils/argcheck.py
@@ -1522,11 +1522,34 @@
     doc_decay_steps = (
         "The learning rate is decaying every this number of training steps."
     )
+    doc_decay_rate = (
+        "The decay rate for the learning rate. "
+        "If this is provided, it will be used directly as the decay rate for learning rate "
+        "instead of calculating it through interpolation between start_lr and stop_lr."
+    )
+    doc_min_lr = (
+        "The minimum learning rate to be used when decay_rate is applied. "
+        "If the learning rate decays below min_lr, min_lr will be used instead."
+    )

     args = [
         Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr),
         Argument("stop_lr", float, optional=True, default=1e-8, doc=doc_stop_lr),
         Argument("decay_steps", int, optional=True, default=5000, doc=doc_decay_steps),
+        Argument(
+            "decay_rate",
+            float,
+            optional=True,
+            default=None,
+            doc=doc_only_pt_supported + doc_decay_rate,
+        ),
+        Argument(
+            "min_lr",
+            float,
+            optional=True,
+            default=None,
+            doc=doc_only_pt_supported + doc_min_lr,
+        ),
     ]
     return args
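Downstream, these arguments land in the learning_rate section of the training input. A hypothetical snippet using the new PyTorch-only keys could look like the following, written here as a Python dict for illustration (the "exp" type and the values are assumptions for the example, not taken from this PR):

    # Hypothetical "learning_rate" section of a training input for the pt backend.
    learning_rate = {
        "type": "exp",       # assumed scheduler variant these arguments belong to
        "start_lr": 1e-3,
        "decay_steps": 5000,
        "decay_rate": 0.95,  # used directly; no interpolation between start_lr and stop_lr
        "min_lr": 1e-8,      # floor applied once the decayed learning rate drops below it
    }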
33 changes: 33 additions & 0 deletions source/tests/pt/test_lr.py
@@ -27,6 +27,7 @@ def test_consistency(self):
         self.decay_step = decay_step
         self.stop_step = stop_step
         self.judge_it()
+        self.decay_rate_pt()

     def judge_it(self):
         base_lr = learning_rate.LearningRateExp(
@@ -54,6 +55,38 @@ def judge_it(self):
         self.assertTrue(np.allclose(base_vals, my_vals))
         tf.reset_default_graph()

+    def decay_rate_pt(self):
+        my_lr = LearningRateExp(
+            self.start_lr, self.stop_lr, self.decay_step, self.stop_step
+        )
+
+        default_ds = 100 if self.stop_step // 10 > 100 else self.stop_step // 100 + 1
+        if self.decay_step >= self.stop_step:
+            self.decay_step = default_ds
+        decay_rate = np.exp(
+            np.log(self.stop_lr / self.start_lr) / (self.stop_step / self.decay_step)
+        )
+        min_lr = self.stop_lr
+        my_lr_decay = LearningRateExp(
+            self.start_lr,
+            1.0,
+            self.decay_step,
+            self.stop_step,
+            decay_rate=decay_rate,
+            min_lr=min_lr,
+        )
+        my_vals = [
+            my_lr.value(step_id)
+            for step_id in range(self.stop_step)
+            if step_id % self.decay_step != 0
+        ]
+        my_vals_decay = [
+            my_lr_decay.value(step_id)
+            for step_id in range(self.stop_step)
+            if step_id % self.decay_step != 0
+        ]
+        self.assertTrue(np.allclose(my_vals_decay, my_vals))
+

 if __name__ == "__main__":
     unittest.main()
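If needed, the new branch can be exercised on its own with something like pytest source/tests/pt/test_lr.py, assuming a development install with both the TensorFlow and PyTorch backends available.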