fix(pytorch): correct step-based learning rate decay #405

Merged
merged 1 commit into from Feb 11, 2024
68 changes: 59 additions & 9 deletions stable_learning_control/algos/pytorch/common/get_lr_scheduler.py
@@ -22,22 +22,24 @@
return gamma
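
For context, a minimal sketch (illustrative numbers, not part of the diff) of how such an exponential decay factor behaves; the closed form for `gamma` below is an assumption, consistent with how `estimate_step_learning_rate` applies the factor later in this file:

lr_init, lr_final, steps = 1e-3, 1e-5, 100
gamma = (lr_final / lr_init) ** (1.0 / steps)  # assumed closed form for gamma
assert abs(lr_init * gamma**steps - lr_final) < 1e-12  # lr(N) == lr_final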


def calc_linear_decay_rate(lr_init, lr_final, steps):
r"""Returns the linear decay factor (G) needed to achieve a given final learning
rate at a certain step. This decay factor can for example be used with a
:class:`torch.optim.lr_scheduler.LambdaLR` scheduler. Keep in mind that this
function assumes the following formula for the learning rate decay.
def get_linear_decay_rate(lr_init, lr_final, steps):
r"""Returns a linear decay factor (G) that enables a learning rate to transition
from an initial value (`lr_init`) at step 0 to a final value (`lr_final`) at a
specified step (N). This decay factor is compatible with the
:class:`torch.optim.lr_scheduler.LambdaLR` scheduler and assumes the following
learning rate schedule:

.. math::
lr_{terminal} = lr_{init} \cdot (1.0 - G \cdot step)

Args:
lr_init (float): The initial learning rate.
lr_final (float): The final learning rate you want to achieve.
steps (int): The step/epoch at which you want to achieve this learning rate.
steps (int): The number of steps/epochs over which the learning rate should
decay. This is equal to epochs - 1.

Returns:
decimal.Decimal: Linear learning rate decay factor (G)
decimal.Decimal: Linear learning rate decay factor (G).
""" # noqa: W605
return -(
((Decimal(lr_final) / Decimal(lr_init)) - Decimal(1.0)) / Decimal(max(steps, 1))
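
A quick sanity check (illustrative numbers, not part of the diff) that this factor reproduces the final learning rate after `steps` steps:

from decimal import Decimal

lr_init, lr_final, steps = 1e-3, 1e-5, 100
g = -((Decimal(lr_final) / Decimal(lr_init) - Decimal(1)) / Decimal(max(steps, 1)))
lr_n = Decimal(lr_init) * (Decimal(1) - g * Decimal(steps))  # lr at step N
assert abs(float(lr_n) - lr_final) < 1e-12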
@@ -53,7 +55,7 @@
(options are: ``linear`` and ``exponential`` and ``constant``).
lr_start (float): Initial learning rate.
lr_final (float): Final learning rate.
steps (int, optional): Number of steps/epochs used in the training. This
steps (int, optional): Number of steps/epochs used in the training. This
includes the starting step.

Returns:
@@ -83,7 +85,7 @@
return np.longdouble(
Decimal(1.0)
- (
calc_linear_decay_rate(lr_start, lr_final, (steps - 1.0))
get_linear_decay_rate(lr_start, lr_final, (steps - 1.0))
* Decimal(step)
)
)
@@ -96,3 +98,51 @@
return torch.optim.lr_scheduler.LambdaLR(
optimizer, lr_lambda=lambda step: np.longdouble(1.0)
) # Return a constant function.
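
For reference, a minimal sketch (toy optimizer, illustrative numbers, not from the diff) of the `LambdaLR` mechanics relied on above: the multiplier returned by `lr_lambda` rescales the optimizer's base learning rate on every scheduler step:

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-3)
g = (1.0 - 1e-5 / 1e-3) / 100  # linear decay factor G for 100 steps
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda step: 1.0 - g * step)
for _ in range(100):
    opt.step()
    sched.step()
print(opt.param_groups[0]["lr"])  # ~1e-5 after 100 scheduler steps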


def estimate_step_learning_rate(
lr_scheduler, lr_start, lr_final, update_after, total_steps, step
):
"""Estimates the learning rate at a given step.

This function estimates the learning rate for a specific training step. It differs
from the `get_last_lr` method of the learning rate scheduler, which returns the
learning rate at the last scheduler step, not necessarily the current training step.

Args:
lr_scheduler (torch.optim.lr_scheduler): The learning rate scheduler.
lr_start (float): The initial learning rate.
lr_final (float): The final learning rate.
update_after (int): The step number after which the learning rate should start
decreasing.
total_steps (int): The total number of steps/epochs in the training process.
Excludes the initial step.
step (int): The current step number. Excludes the initial step.

Returns:
float: The learning rate at the given step.
"""
if step < update_after:
return lr_start
else:
adjusted_step = step - update_after
adjusted_total_steps = total_steps - update_after
if isinstance(lr_scheduler, torch.optim.lr_scheduler.LambdaLR):
decay_rate = get_linear_decay_rate(lr_start, lr_final, adjusted_total_steps)
lr = float(
Decimal(lr_start) * (Decimal(1.0) - decay_rate * Decimal(adjusted_step))
)
elif isinstance(lr_scheduler, torch.optim.lr_scheduler.ExponentialLR):
decay_rate = get_exponential_decay_rate(
lr_start, lr_final, adjusted_total_steps
)
lr = float(
Decimal(lr_start) * (Decimal(decay_rate) ** Decimal(adjusted_step))
)
else:
supported_schedulers = ["LambdaLR", "ExponentialLR"]
raise ValueError(
f"The learning rate scheduler is not supported for this function. "
f"Supported schedulers are: {', '.join(supported_schedulers)}"
)
return max(lr, lr_final)
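
A usage sketch with hypothetical values: before `update_after` the initial rate is returned unchanged; afterwards it is interpolated (linearly for `LambdaLR`, geometrically for `ExponentialLR`) and clipped from below at `lr_final`:

import torch

from stable_learning_control.algos.pytorch.common.get_lr_scheduler import (
    estimate_step_learning_rate,
)

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.Adam([param], lr=1e-4)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lambda s: 1.0)
lr_mid = estimate_step_learning_rate(
    sched, lr_start=1e-4, lr_final=1e-9, update_after=1_000,
    total_steps=50_000, step=25_000,
)
print(lr_mid)  # ~5.1e-05, roughly half of lr_start on the linear schedule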
147 changes: 107 additions & 40 deletions stable_learning_control/algos/pytorch/lac/lac.py
@@ -36,6 +36,7 @@
)
from stable_learning_control.algos.pytorch.common.get_lr_scheduler import (
get_lr_scheduler,
estimate_step_learning_rate,
)
from stable_learning_control.algos.pytorch.common.helpers import (
count_vars,
@@ -1111,7 +1112,7 @@ def lac(
actor_critic = LyapunovActorCritic if actor_critic is None else actor_critic

# Ensure the environment is correctly seeded.
# NOTE: Done here since we donote:n't want to seed on every env.reset() call.
# NOTE: Done here since we don't want to seed on every env.reset() call.
if seed is not None:
env.np_random, _ = seeding.np_random(seed)
env.action_space.seed(seed)
@@ -1197,29 +1198,51 @@ def lac(
logger.log("Network structure:\n", type="info")
logger.log(policy.ac, end="\n\n")

# Create learning rate schedulers.
opt_schedulers = []
lr_decay_ref_var = total_steps if lr_decay_ref.lower() == "steps" else epochs
pi_opt_scheduler = get_lr_scheduler(
policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var
)
opt_schedulers.append(pi_opt_scheduler)
alpha_opt_scheduler = get_lr_scheduler(
policy._log_alpha_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_ref_var
)
opt_schedulers.append(alpha_opt_scheduler)
c_opt_scheduler = get_lr_scheduler(
policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_ref_var
)
opt_schedulers.append(c_opt_scheduler)
labda_opt_scheduler = get_lr_scheduler(
policy._log_labda_optimizer,
lr_decay_type,
lr_a,
lr_a_final,
lr_decay_ref_var,
)
opt_schedulers.append(labda_opt_scheduler)
# Parse learning rate decay type.
valid_lr_decay_options = ["step", "epoch"]
lr_decay_ref = lr_decay_ref.lower()
if lr_decay_ref not in valid_lr_decay_options:
options = [f"'{option}'" for option in valid_lr_decay_options]
logger.log(
f"The learning rate decay reference variable was set to '{lr_decay_ref}', "
"which is not a valid option. Valid options are "
f"{', '.join(options)}. The learning rate decay reference "
"variable has been set to 'epoch'.",
type="warning",
)
lr_decay_ref = "epoch"

# Calculate the number of learning rate scheduler steps.
if lr_decay_ref == "step":
# NOTE: Decay applied at policy update to improve performance.
lr_decay_steps = (total_steps - update_after) / update_every
else:
lr_decay_steps = epochs
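# Worked example (hypothetical numbers): with total_steps=100_000,
# update_after=1_000 and update_every=50, the schedulers take
# (100_000 - 1_000) / 50 = 1_980 decay steps, one per policy update;
# with lr_decay_ref='epoch' they instead step once per epoch.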

# Setup learning rate schedulers.
# NOTE: +1 since the scheduler counts the initial learning rate as its first
# step, so the final rate is reached after exactly lr_decay_steps steps.
opt_schedulers = {
"pi": get_lr_scheduler(
policy._pi_optimizer, lr_decay_type, lr_a, lr_a_final, lr_decay_steps + 1
),
"c": get_lr_scheduler(
policy._c_optimizer, lr_decay_type, lr_c, lr_c_final, lr_decay_steps + 1
),
"alpha": get_lr_scheduler(
policy._log_alpha_optimizer,
lr_decay_type,
lr_a,
lr_a_final,
lr_decay_steps + 1,
),
"lambda": get_lr_scheduler(
policy._log_labda_optimizer,
lr_decay_type,
lr_a,
lr_a_final,
lr_decay_steps + 1,
),
}

logger.setup_pytorch_saver(policy)

@@ -1253,6 +1276,7 @@ def lac(
"Entropy",
]
if use_tensorboard:
# NOTE: TensorBoard counts from 0.
logger.log_to_tb(
"Lr_a",
policy._pi_optimizer.param_groups[0]["lr"],
@@ -1321,16 +1345,18 @@ def lac(
logger.store(**update_diagnostics) # Log diagnostics.

# Step based learning rate decay.
if lr_decay_ref.lower() == "step":
for scheduler in opt_schedulers:
if lr_decay_ref == "step":
for scheduler in opt_schedulers.values():
scheduler.step()
policy.bound_lr(
lr_a_final, lr_c_final, lr_a_final, lr_a_final
) # Make sure lr is bounded above the final lr.

# SGD batch tb logging.
if use_tensorboard and not tb_low_log_freq:
logger.log_to_tb(keys=diag_tb_log_list, global_step=t)
logger.log_to_tb(
keys=diag_tb_log_list, global_step=t
) # NOTE: TensorBoard counts from 0.

# End of epoch handling (Save model, test performance and log data)
if (t + 1) % steps_per_epoch == 0:
@@ -1349,17 +1375,50 @@
extend=True,
)

# Epoch based learning rate decay.
if lr_decay_ref.lower() != "step":
for scheduler in opt_schedulers:
scheduler.step()
policy.bound_lr(
lr_a_final, lr_c_final, lr_a_final, lr_a_final
) # Make sure lr is bounded above the final lr.
# Retrieve current learning rates.
if lr_decay_ref == "step":
# NOTE: Estimate since 'step' decay is applied at policy update.
lr_actor = estimate_step_learning_rate(
opt_schedulers["pi"],
lr_a,
lr_a_final,
update_after,
total_steps,
t + 1,
)
lr_critic = estimate_step_learning_rate(
opt_schedulers["c"],
lr_c,
lr_c_final,
update_after,
total_steps,
t + 1,
)
lr_alpha = estimate_step_learning_rate(
opt_schedulers["alpha"],
lr_a,
lr_a_final,
update_after,
total_steps,
t + 1,
)
lr_labda = estimate_step_learning_rate(
opt_schedulers["lambda"],
lr_a,
lr_a_final,
update_after,
total_steps,
t + 1,
)
else:
lr_actor = policy._pi_optimizer.param_groups[0]["lr"]
lr_critic = policy._c_optimizer.param_groups[0]["lr"]
lr_alpha = policy._log_alpha_optimizer.param_groups[0]["lr"]
lr_labda = policy._log_labda_optimizer.param_groups[0]["lr"]

# Log info about epoch.
logger.log_tabular("Epoch", epoch)
logger.log_tabular("TotalEnvInteracts", t)
logger.log_tabular("TotalEnvInteracts", t + 1)
logger.log_tabular(
"EpRet",
with_min_and_max=True,
@@ -1379,25 +1438,25 @@
)
logger.log_tabular(
"Lr_a",
policy._pi_optimizer.param_groups[0]["lr"],
lr_actor,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_c",
policy._c_optimizer.param_groups[0]["lr"],
lr_critic,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_alpha",
policy._log_alpha_optimizer.param_groups[0]["lr"],
lr_alpha,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
logger.log_tabular(
"Lr_labda",
policy._log_labda_optimizer.param_groups[0]["lr"],
lr_labda,
tb_write=use_tensorboard,
tb_prefix="LearningRates",
)
@@ -1440,7 +1499,15 @@ def lac(
tb_write=(use_tensorboard and tb_low_log_freq),
)
logger.log_tabular("Time", time.time() - start_time)
logger.dump_tabular(global_step=t)
logger.dump_tabular(global_step=t) # NOTE: TensorBoard counts from 0.

# Epoch based learning rate decay.
if lr_decay_ref != "step":
for scheduler in opt_schedulers.values():
scheduler.step()
policy.bound_lr(
lr_a_final, lr_c_final, lr_a_final, lr_a_final
) # Make sure lr is bounded above the final lr.

# Export model to 'TorchScript'
if export: