From 2e953e6f46e1fc99f110f06e2c3066d668551110 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 11 Mar 2024 15:46:49 +0800 Subject: [PATCH 1/4] pt: add explicit decay_rate and min_lr for lr --- deepmd/pt/utils/learning_rate.py | 47 +++++++++++++++++++++++--------- deepmd/tf/utils/learning_rate.py | 1 + deepmd/utils/argcheck.py | 23 ++++++++++++++++ source/tests/pt/test_lr.py | 33 ++++++++++++++++++++++ 4 files changed, 91 insertions(+), 13 deletions(-) diff --git a/deepmd/pt/utils/learning_rate.py b/deepmd/pt/utils/learning_rate.py index eca3c6ad87..4bdcc11fe7 100644 --- a/deepmd/pt/utils/learning_rate.py +++ b/deepmd/pt/utils/learning_rate.py @@ -3,14 +3,36 @@ class LearningRateExp: - def __init__(self, start_lr, stop_lr, decay_steps, stop_steps, **kwargs): - """Construct an exponential-decayed learning rate. + def __init__( + self, + start_lr, + stop_lr, + decay_steps, + stop_steps, + decay_rate=None, + min_lr=None, + **kwargs, + ): + """ + Construct an exponential-decayed learning rate. - Args: - - start_lr: Initial learning rate. - - stop_lr: Learning rate at the last step. - - decay_steps: Decay learning rate every N steps. - - stop_steps: When is the last step. + Parameters + ---------- + start_lr + The learning rate at the start of the training. + stop_lr + The desired learning rate at the end of the training. + decay_steps + The learning rate is decaying every this number of training steps. + stop_steps + The total training steps for learning rate scheduler. + decay_rate + The decay rate for the learning rate. + If provided, the decay rate will be set instead of + calculating it through interpolation between start_lr and stop_lr. + min_lr + The minimum learning rate to be used when decay_rate is applied. + If the learning rate decays below min_lr, min_lr will be used instead. """ self.start_lr = start_lr default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1 @@ -20,16 +42,15 @@ def __init__(self, start_lr, stop_lr, decay_steps, stop_steps, **kwargs): self.decay_rate = np.exp( np.log(stop_lr / self.start_lr) / (stop_steps / self.decay_steps) ) - if "decay_rate" in kwargs: - self.decay_rate = kwargs["decay_rate"] - if "min_lr" in kwargs: - self.min_lr = kwargs["min_lr"] + if decay_rate is not None: + self.decay_rate = decay_rate + self.min_lr = min_lr else: - self.min_lr = 3e-10 + self.min_lr = None def value(self, step): """Get the learning rate at the given step.""" step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps) - if step_lr < self.min_lr: + if self.min_lr is not None and step_lr < self.min_lr: step_lr = self.min_lr return step_lr diff --git a/deepmd/tf/utils/learning_rate.py b/deepmd/tf/utils/learning_rate.py index 519bf20bd0..d2c6faffa4 100644 --- a/deepmd/tf/utils/learning_rate.py +++ b/deepmd/tf/utils/learning_rate.py @@ -41,6 +41,7 @@ def __init__( stop_lr: float = 5e-8, decay_steps: int = 5000, decay_rate: float = 0.95, + **kwargs, ) -> None: """Constructor.""" self.cd = {} diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index e822e18d50..0f2382fb74 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1522,11 +1522,34 @@ def learning_rate_exp(): doc_decay_steps = ( "The learning rate is decaying every this number of training steps." ) + doc_decay_rate = ( + "The decay rate for the learning rate. 
" + "If this is provided, it will be used directly as the decay rate for learning rate " + "instead of calculating it through interpolation between start_lr and stop_lr." + ) + doc_min_lr = ( + "The minimum learning rate to be used when decay_rate is applied. " + "If the learning rate decays below min_lr, min_lr will be used instead." + ) args = [ Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr), Argument("stop_lr", float, optional=True, default=1e-8, doc=doc_stop_lr), Argument("decay_steps", int, optional=True, default=5000, doc=doc_decay_steps), + Argument( + "decay_rate", + float, + optional=True, + default=None, + doc=doc_only_pt_supported + doc_decay_rate, + ), + Argument( + "min_lr", + float, + optional=True, + default=None, + doc=doc_only_pt_supported + doc_min_lr, + ), ] return args diff --git a/source/tests/pt/test_lr.py b/source/tests/pt/test_lr.py index ca1ec7e490..cd246ac0d6 100644 --- a/source/tests/pt/test_lr.py +++ b/source/tests/pt/test_lr.py @@ -27,6 +27,7 @@ def test_consistency(self): self.decay_step = decay_step self.stop_step = stop_step self.judge_it() + self.decay_rate_pt() def judge_it(self): base_lr = learning_rate.LearningRateExp( @@ -54,6 +55,38 @@ def judge_it(self): self.assertTrue(np.allclose(base_vals, my_vals)) tf.reset_default_graph() + def decay_rate_pt(self): + my_lr = LearningRateExp( + self.start_lr, self.stop_lr, self.decay_step, self.stop_step + ) + + default_ds = 100 if self.stop_step // 10 > 100 else self.stop_step // 100 + 1 + if self.decay_step >= self.stop_step: + self.decay_step = default_ds + decay_rate = np.exp( + np.log(self.stop_lr / self.start_lr) / (self.stop_step / self.decay_step) + ) + min_lr = self.stop_lr + my_lr_decay = LearningRateExp( + self.start_lr, + 1.0, + self.decay_step, + self.stop_step, + decay_rate=decay_rate, + min_lr=min_lr, + ) + my_vals = [ + my_lr.value(step_id) + for step_id in range(self.stop_step) + if step_id % self.decay_step != 0 + ] + my_vals_decay = [ + my_lr_decay.value(step_id) + for step_id in range(self.stop_step) + if step_id % self.decay_step != 0 + ] + self.assertTrue(np.allclose(my_vals_decay, my_vals)) + if __name__ == "__main__": unittest.main() From 9d76a240202a241e4d7ca4bbdf616a40fa444cff Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Mar 2024 11:47:46 +0800 Subject: [PATCH 2/4] fix comments --- deepmd/pt/utils/learning_rate.py | 9 ++++----- deepmd/tf/utils/learning_rate.py | 1 - deepmd/utils/argcheck.py | 18 ++++++------------ source/tests/pt/test_lr.py | 4 +--- 4 files changed, 11 insertions(+), 21 deletions(-) diff --git a/deepmd/pt/utils/learning_rate.py b/deepmd/pt/utils/learning_rate.py index 4bdcc11fe7..1e432f466a 100644 --- a/deepmd/pt/utils/learning_rate.py +++ b/deepmd/pt/utils/learning_rate.py @@ -10,7 +10,6 @@ def __init__( decay_steps, stop_steps, decay_rate=None, - min_lr=None, **kwargs, ): """ @@ -22,6 +21,9 @@ def __init__( The learning rate at the start of the training. stop_lr The desired learning rate at the end of the training. + When decay_rate is explicitly set, this value will serve as + the minimum learning rate during training. In other words, + if the learning rate decays below stop_lr, stop_lr will be applied instead. decay_steps The learning rate is decaying every this number of training steps. stop_steps @@ -30,9 +32,6 @@ def __init__( The decay rate for the learning rate. If provided, the decay rate will be set instead of calculating it through interpolation between start_lr and stop_lr. 
- min_lr - The minimum learning rate to be used when decay_rate is applied. - If the learning rate decays below min_lr, min_lr will be used instead. """ self.start_lr = start_lr default_ds = 100 if stop_steps // 10 > 100 else stop_steps // 100 + 1 @@ -44,7 +43,7 @@ def __init__( ) if decay_rate is not None: self.decay_rate = decay_rate - self.min_lr = min_lr + self.min_lr = stop_lr else: self.min_lr = None diff --git a/deepmd/tf/utils/learning_rate.py b/deepmd/tf/utils/learning_rate.py index d2c6faffa4..519bf20bd0 100644 --- a/deepmd/tf/utils/learning_rate.py +++ b/deepmd/tf/utils/learning_rate.py @@ -41,7 +41,6 @@ def __init__( stop_lr: float = 5e-8, decay_steps: int = 5000, decay_rate: float = 0.95, - **kwargs, ) -> None: """Constructor.""" self.cd = {} diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 0f2382fb74..35642f6e3f 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -1518,7 +1518,12 @@ def linear_ener_model_args() -> Argument: # --- Learning rate configurations: --- # def learning_rate_exp(): doc_start_lr = "The learning rate at the start of the training." - doc_stop_lr = "The desired learning rate at the end of the training." + doc_stop_lr = ( + "The desired learning rate at the end of the training. " + f"When decay_rate {doc_only_pt_supported}is explicitly set, " + "this value will serve as the minimum learning rate during training. " + "In other words, if the learning rate decays below stop_lr, stop_lr will be applied instead." + ) doc_decay_steps = ( "The learning rate is decaying every this number of training steps." ) @@ -1527,10 +1532,6 @@ def learning_rate_exp(): "If this is provided, it will be used directly as the decay rate for learning rate " "instead of calculating it through interpolation between start_lr and stop_lr." ) - doc_min_lr = ( - "The minimum learning rate to be used when decay_rate is applied. " - "If the learning rate decays below min_lr, min_lr will be used instead." 
- ) args = [ Argument("start_lr", float, optional=True, default=1e-3, doc=doc_start_lr), @@ -1543,13 +1544,6 @@ def learning_rate_exp(): default=None, doc=doc_only_pt_supported + doc_decay_rate, ), - Argument( - "min_lr", - float, - optional=True, - default=None, - doc=doc_only_pt_supported + doc_min_lr, - ), ] return args diff --git a/source/tests/pt/test_lr.py b/source/tests/pt/test_lr.py index cd246ac0d6..8a446451f9 100644 --- a/source/tests/pt/test_lr.py +++ b/source/tests/pt/test_lr.py @@ -66,14 +66,12 @@ def decay_rate_pt(self): decay_rate = np.exp( np.log(self.stop_lr / self.start_lr) / (self.stop_step / self.decay_step) ) - min_lr = self.stop_lr my_lr_decay = LearningRateExp( self.start_lr, - 1.0, + 1e-10, self.decay_step, self.stop_step, decay_rate=decay_rate, - min_lr=min_lr, ) my_vals = [ my_lr.value(step_id) From 0c3c8d7ab346172f9b15bc419284f0451da922bd Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Wed, 13 Mar 2024 11:56:24 +0800 Subject: [PATCH 3/4] Update learning_rate.py --- deepmd/pt/utils/learning_rate.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/deepmd/pt/utils/learning_rate.py b/deepmd/pt/utils/learning_rate.py index 1e432f466a..94c657abd4 100644 --- a/deepmd/pt/utils/learning_rate.py +++ b/deepmd/pt/utils/learning_rate.py @@ -43,13 +43,11 @@ def __init__( ) if decay_rate is not None: self.decay_rate = decay_rate - self.min_lr = stop_lr - else: - self.min_lr = None + self.min_lr = stop_lr def value(self, step): """Get the learning rate at the given step.""" step_lr = self.start_lr * np.power(self.decay_rate, step // self.decay_steps) - if self.min_lr is not None and step_lr < self.min_lr: + if step_lr < self.min_lr: step_lr = self.min_lr return step_lr From dd4a5934fd1a0fcec57345e3e6e57457abaf85fe Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Tue, 19 Mar 2024 17:57:18 +0800 Subject: [PATCH 4/4] Add truncation test when lr reaches stop_lr. --- source/tests/pt/test_lr.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/source/tests/pt/test_lr.py b/source/tests/pt/test_lr.py index 8a446451f9..9fbde599bb 100644 --- a/source/tests/pt/test_lr.py +++ b/source/tests/pt/test_lr.py @@ -73,6 +73,14 @@ def decay_rate_pt(self): self.stop_step, decay_rate=decay_rate, ) + min_lr = 1e-5 + my_lr_decay_trunc = LearningRateExp( + self.start_lr, + min_lr, + self.decay_step, + self.stop_step, + decay_rate=decay_rate, + ) my_vals = [ my_lr.value(step_id) for step_id in range(self.stop_step) @@ -83,7 +91,15 @@ def decay_rate_pt(self): for step_id in range(self.stop_step) if step_id % self.decay_step != 0 ] + my_vals_decay_trunc = [ + my_lr_decay_trunc.value(step_id) + for step_id in range(self.stop_step) + if step_id % self.decay_step != 0 + ] self.assertTrue(np.allclose(my_vals_decay, my_vals)) + self.assertTrue( + np.allclose(my_vals_decay_trunc, np.clip(my_vals, a_min=min_lr, a_max=None)) + ) if __name__ == "__main__":
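
For reference, the net effect of this series on the PyTorch scheduler can be sketched with plain NumPy. This is a simplified stand-alone rendition of what LearningRateExp.value() computes after the changes above (it omits the decay_steps fallback applied when decay_steps >= stop_steps); the function name and the example numbers are illustrative only, not part of the deepmd-kit API:

    import numpy as np

    def exp_decay_lr(step, start_lr, stop_lr, decay_steps, stop_steps, decay_rate=None):
        # Without an explicit decay_rate, interpolate it so the learning rate
        # reaches stop_lr after stop_steps training steps.
        if decay_rate is None:
            decay_rate = np.exp(np.log(stop_lr / start_lr) / (stop_steps / decay_steps))
        lr = start_lr * np.power(decay_rate, step // decay_steps)
        # After this series, stop_lr doubles as the minimum learning rate:
        # once the exponential decay drops below it, stop_lr is returned instead.
        return max(lr, stop_lr)

    # An explicit decay_rate decays quickly, then the schedule is truncated at stop_lr,
    # which is what the new truncation test asserts via np.clip().
    print(exp_decay_lr(50_000, 1e-3, 1e-5, 5_000, 100_000, decay_rate=0.5))  # 1e-05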