From 642a19330d70e53cba845c9e10916b76332beec8 Mon Sep 17 00:00:00 2001
From: Rick Staa
Date: Mon, 12 Feb 2024 17:17:31 +0100
Subject: [PATCH] fix(tf2): correct step-based learning rate decay (#407)

This commit addresses an issue with the step-based learning rate decay
mechanism when `lr_decay_ref` is set to 'step'. Previously, the learning
rate was decaying too rapidly due to a bug in the decay logic. This fix
ensures that the learning rate decays at the correct pace as per the
step-based decay configuration.
---
 .../algos/tf2/common/get_lr_scheduler.py     |  4 +-
 stable_learning_control/algos/tf2/lac/lac.py | 94 ++++++++++++++-----
 stable_learning_control/algos/tf2/sac/sac.py | 90 +++++++++++++-----
 3 files changed, 141 insertions(+), 47 deletions(-)

diff --git a/stable_learning_control/algos/tf2/common/get_lr_scheduler.py b/stable_learning_control/algos/tf2/common/get_lr_scheduler.py
index 94a840d6..d38ad980 100644
--- a/stable_learning_control/algos/tf2/common/get_lr_scheduler.py
+++ b/stable_learning_control/algos/tf2/common/get_lr_scheduler.py
@@ -11,8 +11,8 @@ def get_lr_scheduler(decaying_lr_type, lr_start, lr_final, steps):
     """Creates a learning rate scheduler.
 
     Args:
-        decaying_lr_type (str): The learning rate decay type that is used (
-            options are: ``linear`` and ``exponential`` and ``constant``).
+        decaying_lr_type (str): The learning rate decay type that is used (options are:
+            ``linear``, ``exponential`` and ``constant``).
         lr_start (float): Initial learning rate.
         lr_final (float): Final learning rate.
         steps (int, optional): Number of steps/epochs used in the training. This
diff --git a/stable_learning_control/algos/tf2/lac/lac.py b/stable_learning_control/algos/tf2/lac/lac.py
index 21a2902e..3be40cd8 100644
--- a/stable_learning_control/algos/tf2/lac/lac.py
+++ b/stable_learning_control/algos/tf2/lac/lac.py
@@ -952,6 +952,7 @@ def lac(
             - replay_buffer (union[:class:`~stable_learning_control.algos.common.buffers.ReplayBuffer`, :class:`~stable_learning_control.algos.common.buffers.FiniteHorizonReplayBuffer`]):
                 The replay buffer used during training.
     """  # noqa: E501, D301
+    update_after = max(1, update_after)  # You cannot update before the first step.
     validate_args(**locals())
 
     # Retrieve hyperparameters while filtering out the logger_kwargs.
@@ -1083,11 +1084,41 @@ def lac(
         device,
     )
 
+    # Parse learning rate decay type.
+    valid_lr_decay_options = ["step", "epoch"]
+    lr_decay_ref = lr_decay_ref.lower()
+    if lr_decay_ref not in valid_lr_decay_options:
+        options = [f"'{option}'" for option in valid_lr_decay_options]
+        logger.log(
+            f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
+            "which is not a valid option. Valid options are "
+            f"{', '.join(options)}. The learning rate decay reference "
+            "variable has been set to 'epoch'.",
+            type="warning",
+        )
+        lr_decay_ref = "epoch"
+
+    # Calculate the number of learning rate scheduler steps.
+    if lr_decay_ref == "step":
+        # NOTE: Decay applied at policy update to improve performance.
+        lr_decay_steps = (total_steps - update_after) / update_every
+    else:
+        lr_decay_steps = epochs
+
     # Create learning rate schedulers.
     # NOTE: Alpha and labda currently use the same scheduler as the actor.
-    lr_decay_ref_var = total_steps if lr_decay_ref.lower() == "steps" else epochs
-    lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var)
-    lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_ref_var)
+    lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
+    lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)
+
+    # Create step based learning rate schedulers.
+    # NOTE: Used to estimate the learning rate at each step.
+    if lr_decay_ref == "step":
+        lr_a_step_scheduler = get_lr_scheduler(
+            lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
+        )
+        lr_c_step_scheduler = get_lr_scheduler(
+            lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
+        )
 
     # Restore policy if supplied.
     if start_policy is not None:
@@ -1165,6 +1196,7 @@ def lac(
         "Entropy",
     ]
     if use_tensorboard:
+        # NOTE: TensorBoard counts from 0.
         logger.log_to_tb(
             "Lr_a",
             policy._pi_optimizer.lr.numpy(),
@@ -1227,6 +1259,7 @@ def lac(
         # NOTE: Improved compared to Han et al. 2020. Previously, updates were based on
         # memory size, which only changed at terminal states.
        if (t + 1) >= update_after and ((t + 1) - update_after) % update_every == 0:
+            n_update = ((t + 1) - update_after) // update_every
             for _ in range(steps_per_update):
                 batch = replay_buffer.sample_batch(batch_size)
                 update_diagnostics = policy.update(data=batch)
@@ -1235,10 +1268,10 @@ def lac(
             # Step based learning rate decay.
             if lr_decay_ref.lower() == "step":
                 lr_a_now = max(
-                    lr_a_scheduler(t + 1), lr_a_final
+                    lr_a_scheduler(n_update + 1), lr_a_final
                 )  # Make sure lr is bounded above final lr.
                 lr_c_now = max(
-                    lr_c_scheduler(t + 1), lr_c_final
+                    lr_c_scheduler(n_update + 1), lr_c_final
                 )  # Make sure lr is bounded above final lr.
                 policy.set_learning_rates(
                     lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
@@ -1246,7 +1279,9 @@ def lac(
 
             # SGD batch tb logging.
             if use_tensorboard and not tb_low_log_freq:
-                logger.log_to_tb(keys=diag_tb_log_list, global_step=t)
+                logger.log_to_tb(
+                    keys=diag_tb_log_list, global_step=t
+                )  # NOTE: TensorBoard counts from 0.
 
         # End of epoch handling (Save model, test performance and log data)
         if (t + 1) % steps_per_epoch == 0:
@@ -1265,21 +1300,22 @@ def lac(
                     extend=True,
                 )
 
-            # Epoch based learning rate decay.
-            if lr_decay_ref.lower() != "step":
-                lr_a_now = max(
-                    lr_a_scheduler(epoch), lr_a_final
-                )  # Make sure lr is bounded above final.
-                lr_c_now = max(
-                    lr_c_scheduler(epoch), lr_c_final
-                )  # Make sure lr is bounded above final.
-                policy.set_learning_rates(
-                    lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
-                )
+            # Retrieve current learning rates.
+            if lr_decay_ref == "step":
+                progress = max((t + 1) - update_after, 0) / update_every
+                lr_actor = lr_a_step_scheduler(progress)
+                lr_critic = lr_c_step_scheduler(progress)
+                lr_alpha = lr_a_step_scheduler(progress)
+                lr_labda = lr_a_step_scheduler(progress)
+            else:
+                lr_actor = policy._pi_optimizer.lr.numpy()
+                lr_critic = policy._c_optimizer.lr.numpy()
+                lr_alpha = policy._log_alpha_optimizer.lr.numpy()
+                lr_labda = policy._log_labda_optimizer.lr.numpy()
 
             # Log info about epoch.
             logger.log_tabular("Epoch", epoch)
-            logger.log_tabular("TotalEnvInteracts", t)
+            logger.log_tabular("TotalEnvInteracts", t + 1)
             logger.log_tabular(
                 "EpRet",
                 with_min_and_max=True,
@@ -1299,25 +1335,25 @@ def lac(
             )
             logger.log_tabular(
                 "Lr_a",
-                policy._pi_optimizer.lr.numpy(),
+                lr_actor,
                 tb_write=use_tensorboard,
                 tb_prefix="LearningRates",
             )
             logger.log_tabular(
                 "Lr_c",
-                policy._c_optimizer.lr.numpy(),
+                lr_critic,
                 tb_write=use_tensorboard,
                 tb_prefix="LearningRates",
             )
             logger.log_tabular(
                 "Lr_alpha",
-                policy._log_alpha_optimizer.lr.numpy(),
+                lr_alpha,
                 tb_write=use_tensorboard,
                 tb_prefix="LearningRates",
             )
             logger.log_tabular(
                 "Lr_labda",
-                policy._log_labda_optimizer.lr.numpy(),
+                lr_labda,
                 tb_write=use_tensorboard,
                 tb_prefix="LearningRates",
             )
@@ -1360,7 +1396,19 @@ def lac(
                 tb_write=(use_tensorboard and tb_low_log_freq),
             )
             logger.log_tabular("Time", time.time() - start_time)
-            logger.dump_tabular(global_step=t)
+            logger.dump_tabular(global_step=t)  # NOTE: TensorBoard counts from 0.
+
+            # Epoch based learning rate decay.
+            if lr_decay_ref.lower() != "step":
+                lr_a_now = max(
+                    lr_a_scheduler(epoch), lr_a_final
+                )  # Make sure lr is bounded above final.
+                lr_c_now = max(
+                    lr_c_scheduler(epoch), lr_c_final
+                )  # Make sure lr is bounded above final.
+                policy.set_learning_rates(
+                    lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now, lr_labda=lr_a_now
+                )
 
     # Export model to 'SavedModel'
     if export:
diff --git a/stable_learning_control/algos/tf2/sac/sac.py b/stable_learning_control/algos/tf2/sac/sac.py
index 0f82ecf6..00e9cdc8 100644
--- a/stable_learning_control/algos/tf2/sac/sac.py
+++ b/stable_learning_control/algos/tf2/sac/sac.py
@@ -821,6 +821,7 @@ def sac(
             - replay_buffer (union[:class:`~stable_learning_control.algos.common.buffers.ReplayBuffer`, :class:`~stable_learning_control.algos.common.buffers.FiniteHorizonReplayBuffer`]):
                 The replay buffer used during training.
     """  # noqa: E501, D301
+    update_after = max(1, update_after)  # You cannot update before the first step.
     validate_args(**locals())
 
     # Retrieve hyperparameters while filtering out the logger_kwargs.
@@ -950,11 +951,41 @@ def sac(
         device,
     )
 
+    # Parse learning rate decay type.
+    valid_lr_decay_options = ["step", "epoch"]
+    lr_decay_ref = lr_decay_ref.lower()
+    if lr_decay_ref not in valid_lr_decay_options:
+        options = [f"'{option}'" for option in valid_lr_decay_options]
+        logger.log(
+            f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
+            "which is not a valid option. Valid options are "
+            f"{', '.join(options)}. The learning rate decay reference "
+            "variable has been set to 'epoch'.",
+            type="warning",
+        )
+        lr_decay_ref = "epoch"
+
+    # Calculate the number of learning rate scheduler steps.
+    if lr_decay_ref == "step":
+        # NOTE: Decay applied at policy update to improve performance.
+        lr_decay_steps = (total_steps - update_after) / update_every
+    else:
+        lr_decay_steps = epochs
+
     # Create learning rate schedulers.
     # NOTE: Alpha currently uses the same scheduler as the actor.
-    lr_decay_ref_var = total_steps if lr_decay_ref.lower() == "steps" else epochs
-    lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var)
-    lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_ref_var)
+    lr_a_scheduler = get_lr_scheduler(lr_decay_type, lr_a, lr_a_final, lr_decay_steps)
+    lr_c_scheduler = get_lr_scheduler(lr_decay_type, lr_c, lr_c_final, lr_decay_steps)
+
+    # Create step based learning rate schedulers.
+    # NOTE: Used to estimate the learning rate at each step.
+    if lr_decay_ref == "step":
+        lr_a_step_scheduler = get_lr_scheduler(
+            lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
+        )
+        lr_c_step_scheduler = get_lr_scheduler(
+            lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
+        )
 
     # Restore policy if supplied.
     if start_policy is not None:
@@ -1010,6 +1041,7 @@ def sac(
     # Setup diagnostics tb_write dict and store initial learning rates.
     diag_tb_log_list = ["LossQ", "LossPi", "Alpha", "LossAlpha", "Entropy"]
     if use_tensorboard:
+        # NOTE: TensorBoard counts from 0.
         logger.log_to_tb(
             "Lr_a",
             policy._pi_optimizer.lr.numpy(),
@@ -1062,6 +1094,7 @@ def sac(
         # NOTE: Improved compared to Han et al. 2020. Previously, updates were based on
         # memory size, which only changed at terminal states.
         if (t + 1) >= update_after and ((t + 1) - update_after) % update_every == 0:
+            n_update = ((t + 1) - update_after) // update_every
             for _ in range(steps_per_update):
                 batch = replay_buffer.sample_batch(batch_size)
                 update_diagnostics = policy.update(data=batch)
@@ -1070,10 +1103,10 @@ def sac(
             # Step based learning rate decay.
             if lr_decay_ref.lower() == "step":
                 lr_a_now = max(
-                    lr_a_scheduler(t + 1), lr_a_final
+                    lr_a_scheduler(n_update + 1), lr_a_final
                 )  # Make sure lr is bounded above final lr.
                 lr_c_now = max(
-                    lr_c_scheduler(t + 1), lr_c_final
+                    lr_c_scheduler(n_update + 1), lr_c_final
                 )  # Make sure lr is bounded above final lr.
                 policy.set_learning_rates(
                     lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now
@@ -1081,7 +1114,9 @@ def sac(
 
             # SGD batch tb logging.
             if use_tensorboard and not tb_low_log_freq:
-                logger.log_to_tb(keys=diag_tb_log_list, global_step=t)
+                logger.log_to_tb(
+                    keys=diag_tb_log_list, global_step=t
+                )  # NOTE: TensorBoard counts from 0.
 
         # End of epoch handling (Save model, test performance and log data)
         if (t + 1) % steps_per_epoch == 0:
@@ -1100,21 +1135,20 @@ def sac(
                     extend=True,
                 )
 
-            # Epoch based learning rate decay.
-            if lr_decay_ref.lower() != "step":
-                lr_a_now = max(
-                    lr_a_scheduler(epoch), lr_a_final
-                )  # Make sure lr is bounded above final.
-                lr_c_now = max(
-                    lr_c_scheduler(epoch), lr_c_final
-                )  # Make sure lr is bounded above final.
-                policy.set_learning_rates(
-                    lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now
-                )
+            # Retrieve current learning rates.
+            if lr_decay_ref == "step":
+                progress = max((t + 1) - update_after, 0) / update_every
+                lr_actor = lr_a_step_scheduler(progress)
+                lr_critic = lr_c_step_scheduler(progress)
+                lr_alpha = lr_a_step_scheduler(progress)
+            else:
+                lr_actor = policy._pi_optimizer.lr.numpy()
+                lr_critic = policy._c_optimizer.lr.numpy()
+                lr_alpha = policy._log_alpha_optimizer.lr.numpy()
 
             # Log info about epoch.
             logger.log_tabular("Epoch", epoch)
-            logger.log_tabular("TotalEnvInteracts", t)
+            logger.log_tabular("TotalEnvInteracts", t + 1)
             logger.log_tabular(
                 "EpRet",
                 with_min_and_max=True,
@@ -1134,19 +1168,19 @@ def sac(
             )
             logger.log_tabular(
                 "Lr_a",
-                policy._pi_optimizer.lr.numpy(),
+                lr_actor,
                 tb_write=use_tensorboard,
                 tb_prefix="LearningRates",
             )
             logger.log_tabular(
                 "Lr_c",
-                policy._c_optimizer.lr.numpy(),
+                lr_critic,
                 tb_write=use_tensorboard,
                 tb_prefix="LearningRates",
             )
             logger.log_tabular(
                 "Lr_alpha",
-                policy._log_alpha_optimizer.lr.numpy(),
+                lr_alpha,
                 tb_write=use_tensorboard,
                 tb_prefix="LearningRates",
             )
@@ -1180,7 +1214,19 @@ def sac(
                 tb_write=(use_tensorboard and tb_low_log_freq),
             )
             logger.log_tabular("Time", time.time() - start_time)
-            logger.dump_tabular(global_step=t)
+            logger.dump_tabular(global_step=t)  # NOTE: TensorBoard counts from 0.
+
+            # Epoch based learning rate decay.
+            if lr_decay_ref.lower() != "step":
+                lr_a_now = max(
+                    lr_a_scheduler(epoch), lr_a_final
+                )  # Make sure lr is bounded above final.
+                lr_c_now = max(
+                    lr_c_scheduler(epoch), lr_c_final
+                )  # Make sure lr is bounded above final.
+                policy.set_learning_rates(
+                    lr_a=lr_a_now, lr_c=lr_c_now, lr_alpha=lr_a_now
+                )
 
     # Export model to 'SavedModel'
     if export:
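
The pacing problem this patch fixes can be reproduced outside the training loop. The standalone sketch below is illustrative only and is not part of the patch: make_linear_scheduler and the constants are assumptions standing in for get_lr_scheduler's ``linear`` option and for typical run settings. It mirrors the patched computation of lr_decay_steps and n_update.

# Illustrative sketch (assumption: a linear scheduler comparable to
# get_lr_scheduler's "linear" option; the helper name and constants are made up).
def make_linear_scheduler(lr_start, lr_final, steps):
    """Return a callable that maps a decay step to a learning rate."""

    def scheduler(step):
        frac = min(step / steps, 1.0) if steps > 0 else 1.0
        return lr_start + frac * (lr_final - lr_start)

    return scheduler


total_steps, update_after, update_every = 100_000, 1_000, 100
lr_a, lr_a_final = 1e-4, 1e-10

# Before the fix the scheduler was sized for `epochs` but indexed with the
# environment step `t`, so the learning rate hit `lr_a_final` almost
# immediately. After the fix it is sized for the number of policy-update
# cycles and stepped once per cycle (`n_update`), reaching `lr_a_final`
# only at the end of training.
lr_decay_steps = (total_steps - update_after) / update_every
lr_a_scheduler = make_linear_scheduler(lr_a, lr_a_final, lr_decay_steps)

for t in range(total_steps):
    if (t + 1) >= update_after and ((t + 1) - update_after) % update_every == 0:
        n_update = ((t + 1) - update_after) // update_every
        lr_a_now = max(lr_a_scheduler(n_update + 1), lr_a_final)  # Never below final lr.

print(f"Actor learning rate after the last update: {lr_a_now:.2e}")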