pt: add explicit decay_rate for lr #3445

Merged (7 commits, Mar 20, 2024)
Changes from 6 commits
44 changes: 31 additions & 13 deletions deepmd/pt/utils/learning_rate.py
@@ -3,14 +3,35 @@


 class LearningRateExp:
-    def __init__(self, start_lr, stop_lr, decay_steps, stop_steps, **kwargs):
-        """Construct an exponential-decayed learning rate.
+    def __init__(
+        self,
+        start_lr,
+        stop_lr,
+        decay_steps,
+        stop_steps,
+        decay_rate=None,
+        **kwargs,
+    ):
+        """
+        Construct an exponential-decayed learning rate.

-        Args:
-        - start_lr: Initial learning rate.
-        - stop_lr: Learning rate at the last step.
-        - decay_steps: Decay learning rate every N steps.
-        - stop_steps: When is the last step.
+        Parameters
+        ----------
+        start_lr
+            The learning rate at the start of the training.
+        stop_lr
+            The desired learning rate at the end of the training.
+            When decay_rate is explicitly set, this value will serve as
+            the minimum learning rate during training. In other words,
+            if the learning rate decays below stop_lr, stop_lr will be applied instead.
+        decay_steps
+            The learning rate is decaying every this number of training steps.
+        stop_steps
+            The total training steps for learning rate scheduler.
+        decay_rate
+            The decay rate for the learning rate.
+            If provided, the decay rate will be set instead of
+            calculating it through interpolation between start_lr and stop_lr.
         """
         self.start_lr = start_lr
         default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1
@@ -20,12 +41,9 @@ def __init__(self, start_lr, stop_lr, decay_steps, stop_steps, **kwargs):
         self.decay_rate = np.exp(
             np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps)
         )
-        if "decay_rate" in kwargs:
-            self.decay_rate = kwargs["decay_rate"]
-        if "min_lr" in kwargs:
-            self.min_lr = kwargs["min_lr"]
-        else:
-            self.min_lr = 3e-10
+        if decay_rate is not None:
+            self.decay_rate = decay_rate
+        self.min_lr = stop_lr

     def value(self, step):
         """Get the learning rate at the given step."""
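For reference, a minimal usage sketch of the changed constructor (not part of the PR; the 1,000,000-step run and the 0.95 rate below are arbitrary illustration values, while 1e-3 / 1e-8 / 5000 are the argcheck defaults). With the default call the decay rate is interpolated from start_lr and stop_lr; with an explicit decay_rate the given value is used directly and stop_lr becomes the floor enforced through min_lr.

from deepmd.pt.utils.learning_rate import LearningRateExp

# Default: decay_rate is derived so the schedule reaches stop_lr at stop_steps.
lr_interp = LearningRateExp(
    start_lr=1e-3, stop_lr=1e-8, decay_steps=5000, stop_steps=1_000_000
)

# Explicit decay_rate: used as-is; stop_lr (1e-8) now acts as the lower bound.
lr_explicit = LearningRateExp(
    start_lr=1e-3, stop_lr=1e-8, decay_steps=5000, stop_steps=1_000_000, decay_rate=0.95
)

for step in (0, 100_000, 1_000_000):
    print(step, lr_interp.value(step), lr_explicit.value(step))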
19 changes: 18 additions & 1 deletion deepmd/utils/argcheck.py
@@ -1517,15 +1517,32 @@ def linear_ener_model_args() -> Argument:
 # --- Learning rate configurations: --- #
 def learning_rate_exp():
     doc_start_lr = "The learning rate at the start of the training."
-    doc_stop_lr = "The desired learning rate at the end of the training."
+    doc_stop_lr = (
+        "The desired learning rate at the end of the training. "
+        f"When decay_rate {doc_only_pt_supported}is explicitly set, "
+        "this value will serve as the minimum learning rate during training. "
+        "In other words, if the learning rate decays below stop_lr, stop_lr will be applied instead."
+    )
     doc_decay_steps = (
         "The learning rate is decaying every this number of training steps."
     )
+    doc_decay_rate = (
+        "The decay rate for the learning rate. "
+        "If this is provided, it will be used directly as the decay rate for learning rate "
+        "instead of calculating it through interpolation between start_lr and stop_lr."
+    )

     args = [
         Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr),
         Argument("stop_lr", float, optional=True, default=1e-8, doc=doc_stop_lr),
         Argument("decay_steps", int, optional=True, default=5000, doc=doc_decay_steps),
+        Argument(
+            "decay_rate",
+            float,
+            optional=True,
+            default=None,
+            doc=doc_only_pt_supported + doc_decay_rate,
+        ),
     ]
     return args

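With the new argument registered, the learning_rate section of a PyTorch-backend training input could look roughly like the sketch below. This is not taken from the PR: the numeric values are placeholders and the "exp" type name is assumed from the existing exponential-decay schema.

# Sketch of a "learning_rate" block using the new key; values are placeholders.
learning_rate = {
    "type": "exp",          # assumed exponential-decay variant name
    "start_lr": 1e-3,
    "stop_lr": 1e-8,        # with decay_rate set, this is also the lr floor
    "decay_steps": 5000,
    "decay_rate": 0.98,     # PyTorch backend only; bypasses the interpolation
}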
31 changes: 31 additions & 0 deletions source/tests/pt/test_lr.py
@@ -27,6 +27,7 @@ def test_consistency(self):
                 self.decay_step = decay_step
                 self.stop_step = stop_step
                 self.judge_it()
+                self.decay_rate_pt()

     def judge_it(self):
         base_lr = learning_rate.LearningRateExp(
@@ -54,6 +55,36 @@ def judge_it(self):
         self.assertTrue(np.allclose(base_vals, my_vals))
         tf.reset_default_graph()

+    def decay_rate_pt(self):
+        my_lr = LearningRateExp(
+            self.start_lr, self.stop_lr, self.decay_step, self.stop_step
+        )
+
+        default_ds = 100 if self.stop_step // 10 > 100 else self.stop_step // 100 + 1
+        if self.decay_step >= self.stop_step:
+            self.decay_step = default_ds
+        decay_rate = np.exp(
+            np.log(self.stop_lr / self.start_lr) / (self.stop_step / self.decay_step)
+        )
+        my_lr_decay = LearningRateExp(
+            self.start_lr,
+            1e-10,
+            self.decay_step,
+            self.stop_step,
+            decay_rate=decay_rate,
+        )
+        my_vals = [
+            my_lr.value(step_id)
+            for step_id in range(self.stop_step)
+            if step_id % self.decay_step != 0
+        ]
+        my_vals_decay = [
+            my_lr_decay.value(step_id)
+            for step_id in range(self.stop_step)
+            if step_id % self.decay_step != 0
+        ]
+        self.assertTrue(np.allclose(my_vals_decay, my_vals))
+

 if __name__ == "__main__":
     unittest.main()
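The equivalence this test checks follows directly from the constructor: the default decay rate is exp(log(stop_lr / start_lr) / (stop_steps / decay_steps)), so passing that same value back in as an explicit decay_rate (with a stop_lr low enough, 1e-10, that the floor never triggers) reproduces the default schedule. A back-of-envelope check of that formula, using the argcheck defaults and an assumed 1,000,000-step run:

import numpy as np

start_lr, stop_lr, decay_steps, stop_steps = 1e-3, 1e-8, 5000, 1_000_000
decay_rate = np.exp(np.log(stop_lr / start_lr) / (stop_steps / decay_steps))
print(decay_rate)  # ~0.944: the rate the constructor would derive on its own
# After stop_steps / decay_steps = 200 decays, the schedule lands on stop_lr:
print(start_lr * decay_rate ** (stop_steps / decay_steps))  # ~1e-8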