feat(pytorch): add alpha/lambda learning rate customization
This commit enhances user control over the training process by allowing
direct customization of the alpha/lambda learning rates, their final values,
and their decay types. Users can now fine-tune these parameters to better
suit their specific training requirements.
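As a hedged illustration of the new interface (not part of the commit itself), the training entry point can now be called with the extra keyword arguments added below. The `gymnasium` import and the environment-factory first argument are assumptions about the caller; the `lr_*` keywords match the updated `lac()` signature in this diff.

```python
# Illustrative sketch only; the environment-factory argument is assumed.
import gymnasium as gym

from stable_learning_control.algos.pytorch.lac.lac import lac

lac(
    lambda: gym.make("Pendulum-v1"),    # assumed environment factory
    lr_a=1e-4,                          # actor learning rate (existing argument)
    lr_c=3e-4,                          # critic learning rate (existing argument)
    lr_alpha=2e-4,                      # new: entropy temperature learning rate
    lr_labda=2e-4,                      # new: Lyapunov Lagrange multiplier learning rate
    lr_alpha_final=1e-8,                # new: final alpha learning rate after decay
    lr_labda_final=1e-8,                # new: final labda learning rate after decay
    lr_alpha_decay_type="exponential",  # new: per-parameter decay type override
    lr_labda_decay_type="constant",     # new: keep the labda learning rate fixed
)
```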
rickstaa committed Feb 19, 2024
1 parent 2b3693e commit a642541
Showing 2 changed files with 342 additions and 78 deletions.
241 changes: 196 additions & 45 deletions stable_learning_control/algos/pytorch/lac/lac.py
@@ -79,6 +79,10 @@
"AverageLossPi",
"AverageEntropy",
]
VALID_DECAY_TYPES = ["linear", "exponential", "constant"]
VALID_DECAY_REFERENCES = ["step", "epoch"]
DEFAULT_DECAY_TYPE = "linear"
DEFAULT_DECAY_REFERENCE = "epoch"


class LAC(nn.Module):
@@ -110,6 +114,8 @@ def __init__(
adaptive_temperature=True,
lr_a=1e-4,
lr_c=3e-4,
lr_alpha=1e-4,
lr_labda=3e-4,
device="cpu",
):
"""Initialise the LAC algorithm.
@@ -197,6 +203,10 @@ def __init__(
``1e-4``.
lr_c (float, optional): Learning rate used for the (Lyapunov) critic.
Defaults to ``3e-4``.
lr_alpha (float, optional): Learning rate used for the entropy temperature.
Defaults to ``1e-4``.
lr_labda (float, optional): Learning rate used for the Lyapunov Lagrange
multiplier. Defaults to ``3e-4``.
device (str, optional): The device the networks are placed on (options:
``cpu``, ``gpu``, ``gpu:0``, ``gpu:1``, etc.). Defaults to ``cpu``.
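For readers of the class API above, a minimal construction sketch; the environment choice and the assumption that the remaining constructor arguments keep their defaults are illustrative, while `lr_alpha` and `lr_labda` are the arguments introduced here.

```python
# Illustrative sketch only; not part of the commit.
import gymnasium as gym

from stable_learning_control.algos.pytorch.lac.lac import LAC

env = gym.make("Pendulum-v1")  # assumed continuous-action environment
policy = LAC(
    env,
    lr_a=1e-4,       # actor learning rate
    lr_c=3e-4,       # Lyapunov critic learning rate
    lr_alpha=2e-4,   # entropy temperature learning rate (new argument)
    lr_labda=2e-4,   # Lyapunov Lagrange multiplier learning rate (new argument)
    device="cpu",
)
```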
@@ -258,8 +268,8 @@ def __init__(
self._alpha3 = alpha3
self._lr_a = lr_a
if self._adaptive_temperature:
self._lr_alpha = lr_a
self._lr_lag = lr_a
self._lr_alpha = lr_alpha
self._lr_lag = lr_labda
self._lr_c = lr_c
if not isinstance(target_entropy, (float, int)):
self._target_entropy = heuristic_target_entropy(env.action_space)
@@ -870,10 +880,18 @@ def lac(
adaptive_temperature=True,
lr_a=1e-4,
lr_c=3e-4,
lr_alpha=1e-4,
lr_labda=3e-4,
lr_a_final=1e-10,
lr_c_final=1e-10,
lr_decay_type="linear",
lr_decay_ref="epoch",
lr_alpha_final=1e-10,
lr_labda_final=1e-10,
lr_decay_type=DEFAULT_DECAY_TYPE,
lr_a_decay_type=DEFAULT_DECAY_TYPE,
lr_c_decay_type=DEFAULT_DECAY_TYPE,
lr_alpha_decay_type=DEFAULT_DECAY_TYPE,
lr_labda_decay_type=DEFAULT_DECAY_TYPE,
lr_decay_ref=DEFAULT_DECAY_REFERENCE,
batch_size=256,
replay_size=int(1e6),
horizon_length=0,
@@ -988,13 +1006,33 @@ def lac(
``1e-4``.
lr_c (float, optional): Learning rate used for the (Lyapunov) critic.
Defaults to ``3e-4``.
lr_alpha (float, optional): Learning rate used for the entropy temperature.
Defaults to ``1e-4``.
lr_labda (float, optional): Learning rate used for the Lyapunov Lagrange
multiplier. Defaults to ``3e-4``.
lr_a_final (float, optional): The final actor learning rate that is achieved
at the end of the training. Defaults to ``1e-10``.
lr_c_final (float, optional): The final critic learning rate that is achieved
at the end of the training. Defaults to ``1e-10``.
lr_decay_type (str, optional): The learning rate decay type that is used (
options are: ``linear`` and ``exponential`` and ``constant``). Defaults to
``linear``.
lr_alpha_final (float, optional): The final alpha learning rate that is
achieved at the end of the training. Defaults to ``1e-10``.
lr_labda_final (float, optional): The final labda learning rate that is
achieved at the end of the training. Defaults to ``1e-10``.
lr_decay_type (str, optional): The learning rate decay type that is used (options
are: ``linear``, ``exponential``, and ``constant``). Defaults to ``linear``. Can
be overridden by the parameter-specific decay types below.
lr_a_decay_type (str, optional): The learning rate decay type that is used for
the actor learning rate (options are: ``linear``, ``exponential``, and
``constant``). If not specified, the general learning rate decay type is used.
lr_c_decay_type (str, optional): The learning rate decay type that is used for
the critic learning rate (options are: ``linear``, ``exponential``, and
``constant``). If not specified, the general learning rate decay type is used.
lr_alpha_decay_type (str, optional): The learning rate decay type that is used
for the alpha learning rate (options are: ``linear``, ``exponential``, and
``constant``). If not specified, the general learning rate decay type is used.
lr_labda_decay_type (str, optional): The learning rate decay type that is used
for the labda learning rate (options are: ``linear``, ``exponential``, and
``constant``). If not specified, the general learning rate decay type is used.
lr_decay_ref (str, optional): The reference variable that is used for decaying
the learning rate (options: ``epoch`` and ``step``). Defaults to ``epoch``.
batch_size (int, optional): Minibatch size for SGD. Defaults to ``256``.
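The override semantics documented above (per-parameter decay types falling back to the global `lr_decay_type`) can be illustrated with a hedged example call; `env_fn` and the chosen values are assumptions, not part of the commit.

```python
# Illustrative sketch only: a global exponential decay, overridden for labda.
# The actor, critic, and alpha schedulers fall back to "exponential" because
# their specific decay types are left unset.
lac(
    env_fn,                          # assumed environment factory from the caller
    lr_decay_type="exponential",     # global decay type
    lr_labda_decay_type="constant",  # per-parameter override for labda only
    lr_decay_ref="step",             # decay per environment step instead of per epoch
)
```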
@@ -1134,20 +1172,22 @@ def lac(
# torch.backends.cudnn.benchmark = False # Disable for reproducibility.

policy = LAC(
env,
actor_critic,
ac_kwargs,
opt_type,
alpha,
alpha3,
labda,
gamma,
polyak,
target_entropy,
adaptive_temperature,
lr_a,
lr_c,
device,
env=env,
actor_critic=actor_critic,
ac_kwargs=ac_kwargs,
opt_type=opt_type,
alpha=alpha,
alpha3=alpha3,
labda=labda,
gamma=gamma,
polyak=polyak,
target_entropy=target_entropy,
adaptive_temperature=adaptive_temperature,
lr_a=lr_a,
lr_c=lr_c,
lr_alpha=lr_alpha,
lr_labda=lr_labda,
device=device,
)

# Restore policy if supplied.
@@ -1199,19 +1239,51 @@ def lac(
logger.log("Network structure:\n", type="info")
logger.log(policy.ac, end="\n\n")

# Parse learning rate decay type.
valid_lr_decay_options = ["step", "epoch"]
# Parse learning rate decay reference.
lr_decay_ref = lr_decay_ref.lower()
if lr_decay_ref not in valid_lr_decay_options:
options = [f"'{option}'" for option in valid_lr_decay_options]
if lr_decay_ref not in VALID_DECAY_REFERENCES:
options = [f"'{option}'" for option in VALID_DECAY_REFERENCES]
logger.log(
f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
"which is not a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay reference "
"variable has been set to 'epoch'.",
f"variable has been set to '{DEFAULT_DECAY_REFERENCE}'.",
type="warning",
)
lr_decay_ref = "epoch"
lr_decay_ref = DEFAULT_DECAY_REFERENCE

# Parse learning rate decay types.
lr_decay_type = lr_decay_type.lower()
if lr_decay_type not in VALID_DECAY_TYPES:
options = [f"'{option}'" for option in VALID_DECAY_TYPES]
logger.log(
f"The learning rate decay type was set to '{lr_decay_type}', which is not "
"a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay type has been set to "
f"'{DEFAULT_DECAY_TYPE}'.",
type="warning",
)
lr_decay_type = DEFAULT_DECAY_TYPE
decay_types = {
"actor": lr_a_decay_type.lower() if lr_a_decay_type else None,
"critic": lr_c_decay_type.lower() if lr_c_decay_type else None,
"alpha": lr_alpha_decay_type.lower() if lr_alpha_decay_type else None,
"labda": lr_labda_decay_type.lower() if lr_labda_decay_type else None,
}
for name, decay_type in decay_types.items():
if decay_type is None:
decay_types[name] = lr_decay_type
else:
if decay_type not in VALID_DECAY_TYPES:
logger.log(
f"Invalid {name} learning rate decay type: '{decay_type}'. Using "
f"global learning rate decay type: '{lr_decay_type}' instead.",
type="warning",
)
decay_types[name] = lr_decay_type
lr_a_decay_type, lr_c_decay_type, lr_alpha_decay_type, lr_labda_decay_type = (
decay_types.values()
)

# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
@@ -1223,25 +1295,34 @@
lr_decay_steps = epochs

# Setup learning rate schedulers.
lr_a_init, lr_c_init, lr_alpha_init, lr_labda_init = lr_a, lr_c, lr_alpha, lr_labda
opt_schedulers = {
"pi": get_lr_scheduler(
policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps
policy._pi_optimizer,
lr_a_decay_type,
lr_a_init,
lr_a_final,
lr_decay_steps,
),
"c": get_lr_scheduler(
policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_steps
policy._c_optimizer,
lr_c_decay_type,
lr_c_init,
lr_c_final,
lr_decay_steps,
),
"alpha": get_lr_scheduler(
policy._log_alpha_optimizer,
lr_decay_type,
lr_a,
lr_a_final,
lr_alpha_decay_type,
lr_alpha_init,
lr_alpha_final,
lr_decay_steps,
),
"lambda": get_lr_scheduler(
"labda": get_lr_scheduler(
policy._log_labda_optimizer,
lr_decay_type,
lr_a,
lr_a_final,
lr_labda_decay_type,
lr_labda_init,
lr_labda_final,
lr_decay_steps,
),
}
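For intuition only, a minimal sketch of what a linear decay from an initial to a final learning rate could look like with a stock PyTorch scheduler; this is an assumption-laden illustration of the decay behaviour configured above, not the implementation of `get_lr_scheduler` in stable_learning_control.

```python
import torch

def linear_decay_scheduler(optimizer, lr_init, lr_final, decay_steps):
    """Illustrative linear interpolation from lr_init to lr_final over decay_steps."""

    def lr_multiplier(step):
        # Fraction of the decay horizon completed, clipped to [0, 1].
        frac = min(step / max(decay_steps, 1), 1.0)
        return (lr_init + frac * (lr_final - lr_init)) / lr_init

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_multiplier)

# Hypothetical usage on a toy optimizer.
params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.Adam(params, lr=1e-4)
scheduler = linear_decay_scheduler(opt, lr_init=1e-4, lr_final=1e-10, decay_steps=100)
for _ in range(100):
    opt.step()
    scheduler.step()
```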
@@ -1351,7 +1432,7 @@ def lac(
for scheduler in opt_schedulers.values():
scheduler.step()
policy.bound_lr(
lr_a_final, lr_c_final, lr_a_final, lr_a_final
lr_a_final, lr_c_final, lr_alpha_final, lr_labda_final
) # Make sure lr is bounded above the final lr.

# SGD batch tb logging.
@@ -1382,32 +1463,32 @@ def lac(
# NOTE: Estimate since 'step' decay is applied at policy update.
lr_actor = estimate_step_learning_rate(
opt_schedulers["pi"],
lr_a,
lr_a_init,
lr_a_final,
update_after,
total_steps,
t + 1,
)
lr_critic = estimate_step_learning_rate(
opt_schedulers["c"],
lr_c,
lr_c_init,
lr_c_final,
update_after,
total_steps,
t + 1,
)
lr_alpha = estimate_step_learning_rate(
opt_schedulers["alpha"],
lr_a,
lr_a_final,
lr_alpha_init,
lr_alpha_final,
update_after,
total_steps,
t + 1,
)
lr_labda = estimate_step_learning_rate(
opt_schedulers["lambda"],
lr_a,
lr_a_final,
opt_schedulers["labda"],
lr_labda_init,
lr_labda_final,
update_after,
total_steps,
t + 1,
@@ -1508,7 +1589,7 @@ def lac(
for scheduler in opt_schedulers.values():
scheduler.step()
policy.bound_lr(
lr_a_final, lr_c_final, lr_a_final, lr_a_final
lr_a_final, lr_c_final, lr_alpha_final, lr_labda_final
) # Make sure lr is bounded above the final lr.

# Export model to 'TorchScript'
@@ -1684,6 +1765,18 @@ def lac(
parser.add_argument(
"--lr_c", type=float, default=3e-4, help="critic learning rate (default: 1e-4)"
)
parser.add_argument(
"--lr_alpha",
type=float,
default=1e-4,
help="entropy temperature learning rate (default: 1e-4)",
)
parser.add_argument(
"--lr_labda",
type=float,
default=3e-4,
help="lyapunov Lagrance multiplier learning rate (default: 3e-4)",
)
parser.add_argument(
"--lr_a_final",
type=float,
@@ -1696,12 +1789,62 @@
default=1e-10,
help="the finalcritic learning rate (default: 1e-10)",
)
parser.add_argument(
"--lr_alpha_final",
type=float,
default=1e-10,
help="the final entropy temperature learning rate (default: 1e-10)",
)
parser.add_argument(
"--lr_labda_final",
type=float,
default=1e-10,
help="the final lyapunov Lagrance multiplier learning rate (default: 1e-10)",
)
parser.add_argument(
"--lr_decay_type",
type=str,
default="linear",
help="the learning rate decay type (default: linear)",
)
parser.add_argument(
"--lr_a_decay_type",
type=str,
default=None,
help=(
"the learning rate decay type that is used for the actor learning rate. "
"If not specified, the general learning rate decay type is used."
),
)
parser.add_argument(
"--lr_c_decay_type",
type=str,
default=None,
help=(
"the learning rate decay type that is used for the critic learning rate. "
"If not specified, the general learning rate decay type is used."
),
)
parser.add_argument(
"--lr_alpha_decay_type",
type=str,
default=None,
help=(
"the learning rate decay type that is used for the entropy temperature "
"learning rate. If not specified, the general learning rate decay type is "
"used."
),
)
parser.add_argument(
"--lr_labda_decay_type",
type=str,
default=None,
help=(
"the learning rate decay type that is used for the lyapunov Lagrance "
"multiplier learning rate. If not specified, the general learning rate "
"decay type is used."
),
)
parser.add_argument(
"--lr_decay_ref",
type=str,
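A hedged example of how the new flags could be exercised through the parser defined above; the explicit argument list is illustrative and assumes the remaining parser arguments all have defaults.

```python
# Illustrative sketch only; in normal use these values come from the command line.
args = parser.parse_args(
    [
        "--lr_alpha", "2e-4",
        "--lr_labda", "2e-4",
        "--lr_alpha_final", "1e-8",
        "--lr_labda_final", "1e-8",
        "--lr_decay_type", "exponential",
        "--lr_labda_decay_type", "constant",
    ]
)
```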
@@ -1914,9 +2057,17 @@ def lac(
adaptive_temperature=args.adaptive_temperature,
lr_a=args.lr_a,
lr_c=args.lr_c,
lr_alpha=args.lr_alpha,
lr_labda=args.lr_labda,
lr_a_final=args.lr_a_final,
lr_c_final=args.lr_c_final,
lr_alpha_final=args.lr_alpha_final,
lr_labda_final=args.lr_labda_final,
lr_decay_type=args.lr_decay_type,
lr_a_decay_type=args.lr_a_decay_type,
lr_c_decay_type=args.lr_c_decay_type,
lr_alpha_decay_type=args.lr_alpha_decay_type,
lr_labda_decay_type=args.lr_labda_decay_type,
lr_decay_ref=args.lr_decay_ref,
batch_size=args.batch_size,
replay_size=args.replay_size,
