From a642541e40d8c86724556fd9902bf774a84af202 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 19 Feb 2024 17:01:08 +0100 Subject: [PATCH] feat(pytorch): add alpha/lambda learning rate customization This commit enhances user control over the training process by allowing direct customization of the alpha/lambda learning rates and their decay rates. Users can now fine-tune these parameters to better suit their specific training requirements. --- .../algos/pytorch/lac/lac.py | 241 ++++++++++++++---- .../algos/pytorch/sac/sac.py | 179 ++++++++++--- 2 files changed, 342 insertions(+), 78 deletions(-) diff --git a/stable_learning_control/algos/pytorch/lac/lac.py b/stable_learning_control/algos/pytorch/lac/lac.py index 6977f67c..8d784d60 100644 --- a/stable_learning_control/algos/pytorch/lac/lac.py +++ b/stable_learning_control/algos/pytorch/lac/lac.py @@ -79,6 +79,10 @@ "AverageLossPi", "AverageEntropy", ] +VALID_DECAY_TYPES = ["linear", "exponential", "constant"] +VALID_DECAY_REFERENCES = ["step", "epoch"] +DEFAULT_DECAY_TYPE = "linear" +DEFAULT_DECAY_REFERENCE = "epoch" class LAC(nn.Module): @@ -110,6 +114,8 @@ def __init__( adaptive_temperature=True, lr_a=1e-4, lr_c=3e-4, + lr_alpha=1e-4, + lr_labda=3e-4, device="cpu", ): """Initialise the LAC algorithm. @@ -197,6 +203,10 @@ def __init__( ``1e-4``. lr_c (float, optional): Learning rate used for the (lyapunov) critic. Defaults to ``1e-4``. + lr_alpha (float, optional): Learning rate used for the entropy temperature. + Defaults to ``1e-4``. + lr_labda (float, optional): Learning rate used for the Lyapunov Lagrange + multiplier. Defaults to ``3e-4``. device (str, optional): The device the networks are placed on (options: ``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``. 
@@ -258,8 +268,8 @@ def __init__( self._alpha3 = alpha3 self._lr_a = lr_a if self._adaptive_temperature: - self._lr_alpha = lr_a - self._lr_lag = lr_a + self._lr_alpha = lr_alpha + self._lr_lag = lr_labda self._lr_c = lr_c if not isinstance(target_entropy, (float, int)): self._target_entropy = heuristic_target_entropy(env.action_space) @@ -870,10 +880,18 @@ def lac( adaptive_temperature=True, lr_a=1e-4, lr_c=3e-4, + lr_alpha=1e-4, + lr_labda=3e-4, lr_a_final=1e-10, lr_c_final=1e-10, - lr_decay_type="linear", - lr_decay_ref="epoch", + lr_alpha_final=1e-10, + lr_labda_final=1e-10, + lr_decay_type=DEFAULT_DECAY_TYPE, + lr_a_decay_type=None, + lr_c_decay_type=None, + lr_alpha_decay_type=None, + lr_labda_decay_type=None, + lr_decay_ref=DEFAULT_DECAY_REFERENCE, batch_size=256, replay_size=int(1e6), horizon_length=0, @@ -988,13 +1006,33 @@ def lac( ``1e-4``. lr_c (float, optional): Learning rate used for the (lyapunov) critic. Defaults to ``1e-4``. + lr_alpha (float, optional): Learning rate used for the entropy temperature. + Defaults to ``1e-4``. + lr_labda (float, optional): Learning rate used for the Lyapunov Lagrange + multiplier. Defaults to ``3e-4``. lr_a_final(float, optional): The final actor learning rate that is achieved at the end of the training. Defaults to ``1e-10``. lr_c_final(float, optional): The final critic learning rate that is achieved at the end of the training. Defaults to ``1e-10``. - lr_decay_type (str, optional): The learning rate decay type that is used ( - options are: ``linear`` and ``exponential`` and ``constant``). Defaults to - ``linear``. + lr_alpha_final(float, optional): The final alpha learning rate that is + achieved at the end of the training. Defaults to ``1e-10``. + lr_labda_final(float, optional): The final labda learning rate that is + achieved at the end of the training. Defaults to ``1e-10``. 
+ lr_decay_type (str, optional): The learning rate decay type that is used (options + are: ``linear`` and ``exponential`` and ``constant``). Defaults to + ``linear``. Can be overridden by the specific learning rate decay types. + lr_a_decay_type (str, optional): The learning rate decay type that is used for + the actor learning rate (options are: ``linear`` and ``exponential`` and + ``constant``). If not specified, the general learning rate decay type is used. + lr_c_decay_type (str, optional): The learning rate decay type that is used for + the critic learning rate (options are: ``linear`` and ``exponential`` and + ``constant``). If not specified, the general learning rate decay type is used. + lr_alpha_decay_type (str, optional): The learning rate decay type that is used + for the alpha learning rate (options are: ``linear`` and ``exponential`` + and ``constant``). If not specified, the general learning rate decay type is used. + lr_labda_decay_type (str, optional): The learning rate decay type that is used + for the labda learning rate (options are: ``linear`` and ``exponential`` + and ``constant``). If not specified, the general learning rate decay type is used. lr_decay_ref (str, optional): The reference variable that is used for decaying the learning rate (options: ``epoch`` and ``step``). Defaults to ``epoch``. batch_size (int, optional): Minibatch size for SGD. Defaults to ``256``. @@ -1134,20 +1172,22 @@ def lac( # torch.backends.cudnn.benchmark = False # Disable for reproducibility. 
policy = LAC( - env, - actor_critic, - ac_kwargs, - opt_type, - alpha, - alpha3, - labda, - gamma, - polyak, - target_entropy, - adaptive_temperature, - lr_a, - lr_c, - device, + env=env, + actor_critic=actor_critic, + ac_kwargs=ac_kwargs, + opt_type=opt_type, + alpha=alpha, + alpha3=alpha3, + labda=labda, + gamma=gamma, + polyak=polyak, + target_entropy=target_entropy, + adaptive_temperature=adaptive_temperature, + lr_a=lr_a, + lr_c=lr_c, + lr_alpha=lr_alpha, + lr_labda=lr_labda, + device=device, ) # Restore policy if supplied. @@ -1199,19 +1239,51 @@ def lac( logger.log("Network structure:\n", type="info") logger.log(policy.ac, end="\n\n") - # Parse learning rate decay type. - valid_lr_decay_options = ["step", "epoch"] + # Parse learning rate decay reference. lr_decay_ref = lr_decay_ref.lower() - if lr_decay_ref not in valid_lr_decay_options: - options = [f"'{option}'" for option in valid_lr_decay_options] + if lr_decay_ref not in VALID_DECAY_REFERENCES: + options = [f"'{option}'" for option in VALID_DECAY_REFERENCES] logger.log( f"The learning rate decay reference variable was set to '{lr_decay_ref}', " "which is not a valid option. Valid options are " f"{', '.join(options)}. The learning rate decay reference " - "variable has been set to 'epoch'.", + f"variable has been set to '{DEFAULT_DECAY_REFERENCE}'.", type="warning", ) - lr_decay_ref = "epoch" + lr_decay_ref = DEFAULT_DECAY_REFERENCE + + # Parse learning rate decay types. + lr_decay_type = lr_decay_type.lower() + if lr_decay_type not in VALID_DECAY_TYPES: + options = [f"'{option}'" for option in VALID_DECAY_TYPES] + logger.log( + f"The learning rate decay type was set to '{lr_decay_type}', which is not " + "a valid option. Valid options are " + f"{', '.join(options)}. 
The learning rate decay type has been set to " + f"'{DEFAULT_DECAY_TYPE}'.", + type="warning", + ) + lr_decay_type = DEFAULT_DECAY_TYPE + decay_types = { + "actor": lr_a_decay_type.lower() if lr_a_decay_type else None, + "critic": lr_c_decay_type.lower() if lr_c_decay_type else None, + "alpha": lr_alpha_decay_type.lower() if lr_alpha_decay_type else None, + "labda": lr_labda_decay_type.lower() if lr_labda_decay_type else None, + } + for name, decay_type in decay_types.items(): + if decay_type is None: + decay_types[name] = lr_decay_type + else: + if decay_type not in VALID_DECAY_TYPES: + logger.log( + f"Invalid {name} learning rate decay type: '{decay_type}'. Using " + f"global learning rate decay type: '{lr_decay_type}' instead.", + type="warning", + ) + decay_types[name] = lr_decay_type + lr_a_decay_type, lr_c_decay_type, lr_alpha_decay_type, lr_labda_decay_type = ( + decay_types.values() + ) # Calculate the number of learning rate scheduler steps. if lr_decay_ref == "step": @@ -1223,25 +1295,34 @@ def lac( lr_decay_steps = epochs # Setup learning rate schedulers. 
+ lr_a_init, lr_c_init, lr_alpha_init, lr_labda_init = lr_a, lr_c, lr_alpha, lr_labda opt_schedulers = { "pi": get_lr_scheduler( - policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps + policy._pi_optimizer, + lr_a_decay_type, + lr_a_init, + lr_a_final, + lr_decay_steps, ), "c": get_lr_scheduler( - policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_steps + policy._c_optimizer, + lr_c_decay_type, + lr_c_init, + lr_c_final, + lr_decay_steps, ), "alpha": get_lr_scheduler( policy._log_alpha_optimizer, - lr_decay_type, - lr_a, - lr_a_final, + lr_alpha_decay_type, + lr_alpha_init, + lr_alpha_final, lr_decay_steps, ), - "lambda": get_lr_scheduler( + "labda": get_lr_scheduler( policy._log_labda_optimizer, - lr_decay_type, - lr_a, - lr_a_final, + lr_labda_decay_type, + lr_labda_init, + lr_labda_final, lr_decay_steps, ), } @@ -1351,7 +1432,7 @@ def lac( for scheduler in opt_schedulers.values(): scheduler.step() policy.bound_lr( - lr_a_final, lr_c_final, lr_a_final, lr_a_final + lr_a_final, lr_c_final, lr_alpha_final, lr_labda_final ) # Make sure lr is bounded above the final lr. # SGD batch tb logging. @@ -1382,7 +1463,7 @@ def lac( # NOTE: Estimate since 'step' decay is applied at policy update. 
lr_actor = estimate_step_learning_rate( opt_schedulers["pi"], - lr_a, + lr_a_init, lr_a_final, update_after, total_steps, @@ -1390,7 +1471,7 @@ def lac( ) lr_critic = estimate_step_learning_rate( opt_schedulers["c"], - lr_c, + lr_c_init, lr_c_final, update_after, total_steps, @@ -1398,16 +1479,16 @@ def lac( ) lr_alpha = estimate_step_learning_rate( opt_schedulers["alpha"], - lr_a, - lr_a_final, + lr_alpha_init, + lr_alpha_final, update_after, total_steps, t + 1, ) lr_labda = estimate_step_learning_rate( - opt_schedulers["lambda"], - lr_a, - lr_a_final, + opt_schedulers["labda"], + lr_labda_init, + lr_labda_final, update_after, total_steps, t + 1, @@ -1508,7 +1589,7 @@ def lac( for scheduler in opt_schedulers.values(): scheduler.step() policy.bound_lr( - lr_a_final, lr_c_final, lr_a_final, lr_a_final + lr_a_final, lr_c_final, lr_alpha_final, lr_labda_final ) # Make sure lr is bounded above the final lr. # Export model to 'TorchScript' @@ -1684,6 +1765,18 @@ def lac( parser.add_argument( "--lr_c", type=float, default=3e-4, help="critic learning rate (default: 1e-4)" ) + parser.add_argument( + "--lr_alpha", + type=float, + default=1e-4, + help="entropy temperature learning rate (default: 1e-4)", + ) + parser.add_argument( + "--lr_labda", + type=float, + default=3e-4, + help="lyapunov Lagrance multiplier learning rate (default: 3e-4)", + ) parser.add_argument( "--lr_a_final", type=float, @@ -1696,12 +1789,62 @@ def lac( default=1e-10, help="the finalcritic learning rate (default: 1e-10)", ) + parser.add_argument( + "--lr_alpha_final", + type=float, + default=1e-10, + help="the final entropy temperature learning rate (default: 1e-10)", + ) + parser.add_argument( + "--lr_labda_final", + type=float, + default=1e-10, + help="the final lyapunov Lagrance multiplier learning rate (default: 1e-10)", + ) parser.add_argument( "--lr_decay_type", type=str, default="linear", help="the learning rate decay type (default: linear)", ) + parser.add_argument( + "--lr_a_decay_type", + 
type=str, + default=None, + help=( + "the learning rate decay type that is used for the actor learning rate. " + "If not specified, the general learning rate decay type is used." + ), + ) + parser.add_argument( + "--lr_c_decay_type", + type=str, + default=None, + help=( + "the learning rate decay type that is used for the critic learning rate. " + "If not specified, the general learning rate decay type is used." + ), + ) + parser.add_argument( + "--lr_alpha_decay_type", + type=str, + default=None, + help=( + "the learning rate decay type that is used for the entropy temperature " + "learning rate. If not specified, the general learning rate decay type is " + "used." + ), + ) + parser.add_argument( + "--lr_labda_decay_type", + type=str, + default=None, + help=( + "the learning rate decay type that is used for the lyapunov Lagrance " + "multiplier learning rate. If not specified, the general learning rate " + "decay type is used." + ), + ) parser.add_argument( "--lr_decay_ref", type=str, @@ -1914,9 +2057,17 @@ def lac( adaptive_temperature=args.adaptive_temperature, lr_a=args.lr_a, lr_c=args.lr_c, + lr_alpha=args.lr_alpha, + lr_labda=args.lr_labda, lr_a_final=args.lr_a_final, lr_c_final=args.lr_c_final, + lr_alpha_final=args.lr_alpha_final, + lr_labda_final=args.lr_labda_final, lr_decay_type=args.lr_decay_type, + lr_a_decay_type=args.lr_a_decay_type, + lr_c_decay_type=args.lr_c_decay_type, + lr_alpha_decay_type=args.lr_alpha_decay_type, + lr_labda_decay_type=args.lr_labda_decay_type, lr_decay_ref=args.lr_decay_ref, batch_size=args.batch_size, replay_size=args.replay_size, diff --git a/stable_learning_control/algos/pytorch/sac/sac.py b/stable_learning_control/algos/pytorch/sac/sac.py index 7e347907..97b7aa29 100644 --- a/stable_learning_control/algos/pytorch/sac/sac.py +++ b/stable_learning_control/algos/pytorch/sac/sac.py @@ -72,6 +72,10 @@ "AverageLossPi", "AverageEntropy", ] +VALID_DECAY_TYPES = ["linear", "exponential", "constant"] +VALID_DECAY_REFERENCES = ["step", 
"epoch"] +DEFAULT_DECAY_TYPE = "linear" +DEFAULT_DECAY_REFERENCE = "epoch" class SAC(nn.Module): @@ -100,6 +104,7 @@ def __init__( adaptive_temperature=True, lr_a=1e-4, lr_c=3e-4, + lr_alpha=1e-4, device="cpu", ): """Initialise the SAC algorithm. @@ -182,6 +187,8 @@ def __init__( ``1e-4``. lr_c (float, optional): Learning rate used for the (Soft) critic. Defaults to ``1e-4``. + lr_alpha (float, optional): Learning rate used for the entropy temperature. + Defaults to ``1e-4``. device (str, optional): The device the networks are placed on (options: ``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``. """ # noqa: E501, D301 @@ -237,7 +244,7 @@ def __init__( self._gamma = gamma self._lr_a = lr_a if self._adaptive_temperature: - self._lr_alpha = lr_a + self._lr_alpha = lr_alpha self._lr_c = lr_c if not isinstance(target_entropy, (float, int)): self._target_entropy = heuristic_target_entropy(env.action_space) @@ -746,10 +753,15 @@ def sac( adaptive_temperature=True, lr_a=1e-4, lr_c=3e-4, + lr_alpha=1e-4, lr_a_final=1e-10, lr_c_final=1e-10, - lr_decay_type="linear", - lr_decay_ref="epoch", + lr_alpha_final=1e-10, + lr_decay_type=DEFAULT_DECAY_TYPE, + lr_a_decay_type=None, + lr_c_decay_type=None, + lr_alpha_decay_type=None, + lr_decay_ref=DEFAULT_DECAY_REFERENCE, batch_size=256, replay_size=int(1e6), seed=None, @@ -859,13 +871,26 @@ def sac( ``1e-4``. lr_c (float, optional): Learning rate used for the (soft) critic. Defaults to ``1e-4``. + lr_alpha (float, optional): Learning rate used for the entropy temperature. + Defaults to ``1e-4``. lr_a_final(float, optional): The final actor learning rate that is achieved at the end of the training. Defaults to ``1e-10``. lr_c_final(float, optional): The final critic learning rate that is achieved at the end of the training. Defaults to ``1e-10``. 
- lr_decay_type (str, optional): The learning rate decay type that is used ( - options are: ``linear`` and ``exponential`` and ``constant``). Defaults to - ``linear``. + lr_alpha_final(float, optional): The final alpha learning rate that is + achieved at the end of the training. Defaults to ``1e-10``. + lr_decay_type (str, optional): The learning rate decay type that is used (options + are: ``linear`` and ``exponential`` and ``constant``). Defaults to + ``linear``. Can be overridden by the specific learning rate decay types. + lr_a_decay_type (str, optional): The learning rate decay type that is used for + the actor learning rate (options are: ``linear`` and ``exponential`` and + ``constant``). If not specified, the general learning rate decay type is used. + lr_c_decay_type (str, optional): The learning rate decay type that is used for + the critic learning rate (options are: ``linear`` and ``exponential`` and + ``constant``). If not specified, the general learning rate decay type is used. + lr_alpha_decay_type (str, optional): The learning rate decay type that is used + for the alpha learning rate (options are: ``linear`` and ``exponential`` + and ``constant``). If not specified, the general learning rate decay type is used. lr_decay_ref (str, optional): The reference variable that is used for decaying the learning rate (options: ``epoch`` and ``step``). Defaults to ``epoch``. batch_size (int, optional): Minibatch size for SGD. Defaults to ``256``. @@ -1002,18 +1027,19 @@ def sac( # torch.backends.cudnn.benchmark = False # Disable for reproducibility. 
policy = SAC( - env, - actor_critic, - ac_kwargs, - opt_type, - alpha, - gamma, - polyak, - target_entropy, - adaptive_temperature, - lr_a, - lr_c, - device, + env=env, + actor_critic=actor_critic, + ac_kwargs=ac_kwargs, + opt_type=opt_type, + alpha=alpha, + gamma=gamma, + polyak=polyak, + target_entropy=target_entropy, + adaptive_temperature=adaptive_temperature, + lr_a=lr_a, + lr_c=lr_c, + lr_alpha=lr_alpha, + device=device, ) # Restore policy if supplied. @@ -1050,19 +1076,48 @@ def sac( logger.log("Network structure:\n", type="info") logger.log(policy.ac, end="\n\n") - # Parse learning rate decay type. - valid_lr_decay_options = ["step", "epoch"] + # Parse learning rate decay reference. lr_decay_ref = lr_decay_ref.lower() - if lr_decay_ref not in valid_lr_decay_options: - options = [f"'{option}'" for option in valid_lr_decay_options] + if lr_decay_ref not in VALID_DECAY_REFERENCES: + options = [f"'{option}'" for option in VALID_DECAY_REFERENCES] logger.log( f"The learning rate decay reference variable was set to '{lr_decay_ref}', " "which is not a valid option. Valid options are " f"{', '.join(options)}. The learning rate decay reference " - "variable has been set to 'epoch'.", + f"variable has been set to '{DEFAULT_DECAY_REFERENCE}'.", type="warning", ) - lr_decay_ref = "epoch" + lr_decay_ref = DEFAULT_DECAY_REFERENCE + + # Parse learning rate decay types. + lr_decay_type = lr_decay_type.lower() + if lr_decay_type not in VALID_DECAY_TYPES: + options = [f"'{option}'" for option in VALID_DECAY_TYPES] + logger.log( + f"The learning rate decay type was set to '{lr_decay_type}', which is not " + "a valid option. Valid options are " + f"{', '.join(options)}. 
The learning rate decay type has been set to " + f"'{DEFAULT_DECAY_TYPE}'.", + type="warning", + ) + lr_decay_type = DEFAULT_DECAY_TYPE + decay_types = { + "actor": lr_a_decay_type.lower() if lr_a_decay_type else None, + "critic": lr_c_decay_type.lower() if lr_c_decay_type else None, + "alpha": lr_alpha_decay_type.lower() if lr_alpha_decay_type else None, + } + for name, decay_type in decay_types.items(): + if decay_type is None: + decay_types[name] = lr_decay_type + else: + if decay_type not in VALID_DECAY_TYPES: + logger.log( + f"Invalid {name} learning rate decay type: '{decay_type}'. Using " + f"global learning rate decay type: '{lr_decay_type}' instead.", + type="warning", + ) + decay_types[name] = lr_decay_type + lr_a_decay_type, lr_c_decay_type, lr_alpha_decay_type = decay_types.values() # Calculate the number of learning rate scheduler steps. if lr_decay_ref == "step": @@ -1074,15 +1129,28 @@ def sac( lr_decay_steps = epochs # Setup learning rate schedulers. + lr_a_init, lr_c_init, lr_alpha_init = lr_a, lr_c, lr_alpha opt_schedulers = { "pi": get_lr_scheduler( - policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps + policy._pi_optimizer, + lr_a_decay_type, + lr_a_init, + lr_a_final, + lr_decay_steps, ), "c": get_lr_scheduler( - policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_steps + policy._c_optimizer, + lr_c_decay_type, + lr_c_init, + lr_c_final, + lr_decay_steps, ), "alpha": get_lr_scheduler( - policy._log_alpha_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps + policy._log_alpha_optimizer, + lr_alpha_decay_type, + lr_alpha_init, + lr_alpha_final, + lr_decay_steps, ), } @@ -1173,7 +1241,7 @@ def sac( for scheduler in opt_schedulers.values(): scheduler.step() policy.bound_lr( - lr_a_final, lr_c_final, lr_a_final + lr_a_final, lr_c_final, lr_alpha_final ) # Make sure lr is bounded above the final lr. # SGD batch tb logging. 
@@ -1204,7 +1272,7 @@ def sac( # NOTE: Estimate since 'step' decay is applied at policy update. lr_actor = estimate_step_learning_rate( opt_schedulers["pi"], - lr_a, + lr_a_init, lr_a_final, update_after, total_steps, @@ -1212,7 +1280,7 @@ def sac( ) lr_critic = estimate_step_learning_rate( opt_schedulers["c"], - lr_c, + lr_c_init, lr_c_final, update_after, total_steps, @@ -1220,8 +1288,8 @@ def sac( ) lr_alpha = estimate_step_learning_rate( opt_schedulers["alpha"], - lr_a, - lr_a_final, + lr_alpha_init, + lr_alpha_final, update_after, total_steps, t + 1, @@ -1306,7 +1374,7 @@ def sac( for scheduler in opt_schedulers.values(): scheduler.step() policy.bound_lr( - lr_a_final, lr_c_final, lr_a_final + lr_a_final, lr_c_final, lr_alpha_final ) # Make sure lr is bounded above the final lr. # Export model to 'TorchScript' @@ -1475,6 +1543,12 @@ def sac( parser.add_argument( "--lr_c", type=float, default=3e-4, help="critic learning rate (default: 1e-4)" ) + parser.add_argument( + "--lr_alpha", + type=float, + default=1e-4, + help="entropy temperature learning rate (default: 1e-4)", + ) parser.add_argument( "--lr_a_final", type=float, @@ -1487,12 +1561,46 @@ def sac( default=1e-10, help="the finalcritic learning rate (default: 1e-10)", ) + parser.add_argument( + "--lr_alpha_final", + type=float, + default=1e-10, + help="the final entropy temperature learning rate (default: 1e-10)", + ) parser.add_argument( "--lr_decay_type", type=str, default="linear", help="the learning rate decay type (default: linear)", ) + parser.add_argument( + "--lr_a_decay_type", + type=str, + default=None, + help=( + "the learning rate decay type that is used for the actor learning rate. " + "If not specified, the general learning rate decay type is used." + ), + ) + parser.add_argument( + "--lr_c_decay_type", + type=str, + default=None, + help=( + "the learning rate decay type that is used for the critic learning rate. " + "If not specified, the general learning rate decay type is used." 
+ ), + ) + parser.add_argument( + "--lr_alpha_decay_type", + type=str, + default=None, + help=( + "the learning rate decay type that is used for the entropy temperature " + "learning rate. If not specified, the general learning rate decay type is " + "used." + ), + ) parser.add_argument( "--lr_decay_ref", type=str, @@ -1695,9 +1803,14 @@ def sac( adaptive_temperature=args.adaptive_temperature, lr_a=args.lr_a, lr_c=args.lr_c, + lr_alpha=args.lr_alpha, lr_a_final=args.lr_a_final, lr_c_final=args.lr_c_final, + lr_alpha_final=args.lr_alpha_final, lr_decay_type=args.lr_decay_type, + lr_a_decay_type=args.lr_a_decay_type, + lr_c_decay_type=args.lr_c_decay_type, + lr_alpha_decay_type=args.lr_alpha_decay_type, lr_decay_ref=args.lr_decay_ref, batch_size=args.batch_size, replay_size=args.replay_size,