diff --git a/howto/register_external_algorithm.md b/howto/register_external_algorithm.md index a5e9a455..d1062dab 100644 --- a/howto/register_external_algorithm.md +++ b/howto/register_external_algorithm.md @@ -550,36 +550,36 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]): # Global variables last_train = 0 train_step = 0 - start_step = ( + start_iter = ( # + 1 because the checkpoint is at the end of the update step # (when resuming from a checkpoint, the update at the checkpoint # is ended and you have to start with the next one) - (state["update"] // fabric.world_size) + 1 + (state["iter_num"] // fabric.world_size) + 1 if cfg.checkpoint.resume_from else 1 ) - policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0 + policy_step = state["iter_num"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) - num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) + total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1 if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size # Warning for log and checkpoint every - if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the metrics will be logged at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) - if cfg.checkpoint.every % policy_steps_per_update != 0: + if cfg.checkpoint.every % policy_steps_per_iter != 0: warnings.warn( f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the checkpoint will be saved at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." 
) # Get the first environment observation and start the optimization @@ -590,9 +590,9 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]): next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) step_data[k] = next_obs[k][np.newaxis] - for update in range(start_step, num_updates + 1): + for iter_num in range(start_iter, total_iters + 1): for _ in range(0, cfg.algo.rollout_steps): - policy_step += policy_steps_per_update + policy_step += policy_steps_per_iter # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment @@ -653,7 +653,7 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]): train(fabric, agent, optimizer, local_data, aggregator, cfg) # Log metrics - if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + if policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters or cfg.dry_run: # Sync distributed metrics if aggregator and not aggregator.disabled: metrics_dict = aggregator.compute() @@ -686,13 +686,13 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]): if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or cfg.dry_run - or update == num_updates + or iter_num == total_iters ): last_checkpoint = policy_step state = { "agent": agent.state_dict(), "optimizer": optimizer.state_dict(), - "update_step": update, + "iter_num": iter_num * world_size, } ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt") fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state) diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index 714d797f..04d3c09d 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -548,36 +548,36 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): # Global variables last_train = 0 train_step = 0 - start_step = ( + start_iter = ( # + 1 because the checkpoint is at the end of the update step # (when resuming from a checkpoint, the update at the checkpoint # is ended and you have to start with the next one) - (state["update"] // fabric.world_size) + 1 + (state["iter_num"] // fabric.world_size) + 1 if cfg.checkpoint.resume_from else 1 ) - policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0 + policy_step = state["iter_num"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) - num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) + total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1 if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size # Warning for log and checkpoint every - if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + 
f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the metrics will be logged at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) - if cfg.checkpoint.every % policy_steps_per_update != 0: + if cfg.checkpoint.every % policy_steps_per_iter != 0: warnings.warn( f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the checkpoint will be saved at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) # Get the first environment observation and start the optimization @@ -588,9 +588,9 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:]) step_data[k] = next_obs[k][np.newaxis] - for update in range(start_step, num_updates + 1): + for iter_num in range(start_iter, total_iters + 1): for _ in range(0, cfg.algo.rollout_steps): - policy_step += policy_steps_per_update + policy_step += policy_steps_per_iter # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment @@ -651,7 +651,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): train(fabric, agent, optimizer, local_data, aggregator, cfg) # Log metrics - if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + if policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters or cfg.dry_run: # Sync distributed metrics if aggregator and not aggregator.disabled: metrics_dict = aggregator.compute() @@ -684,13 +684,13 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or cfg.dry_run - or update == num_updates + or iter_num == total_iters ): last_checkpoint = policy_step state = { "agent": agent.state_dict(), "optimizer": optimizer.state_dict(), - "update_step": update, + "iter_num": iter_num * world_size, } ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt") fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state) diff --git a/sheeprl/algos/a2c/a2c.py b/sheeprl/algos/a2c/a2c.py index e16b6932..2002fbe1 100644 --- a/sheeprl/algos/a2c/a2c.py +++ b/sheeprl/algos/a2c/a2c.py @@ -200,23 +200,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): train_step = 0 policy_step = 0 last_checkpoint = 0 - policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) - num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 + policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size) + total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1 # Warning for log and checkpoint every - if cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_every % policy_steps_per_iter != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the metrics will be logged at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." 
) - if cfg.checkpoint.every % policy_steps_per_update != 0: + if cfg.checkpoint.every % policy_steps_per_iter != 0: warnings.warn( f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the checkpoint will be saved at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) # Get the first environment observation and start the optimization @@ -225,10 +225,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): for k in obs_keys: step_data[k] = next_obs[k][np.newaxis] - for update in range(1, num_updates + 1): + for iter_num in range(1, total_iters + 1): with torch.inference_mode(): for _ in range(0, cfg.algo.rollout_steps): - policy_step += policy_steps_per_update + policy_step += policy_steps_per_iter # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment @@ -325,7 +325,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): train(fabric, agent, optimizer, local_data, aggregator, cfg) # Log metrics - if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run: + if policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters or cfg.dry_run: # Sync distributed metrics if aggregator and not aggregator.disabled: metrics_dict = aggregator.compute() @@ -358,13 +358,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if ( (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or cfg.dry_run - or (update == num_updates and cfg.checkpoint.save_last) + or (iter_num == total_iters and cfg.checkpoint.save_last) ): last_checkpoint = policy_step state = { "agent": agent.state_dict(), "optimizer": optimizer.state_dict(), - "update_step": update, + "iter_num": iter_num * world_size, } ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt") fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state) diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 6dcc6ac0..c9fdfbf5 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -496,24 +496,24 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Global variables train_step = 0 last_train = 0 - start_step = ( + start_iter = ( # + 1 because the checkpoint is at the end of the update step # (when resuming from a checkpoint, the update at the checkpoint # is ended and you have to start with the next one) - (state["update"] // world_size) + 1 + (state["iter_num"] // world_size) + 1 if cfg.checkpoint.resume_from else 1 ) - policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 + policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - policy_steps_per_update = int(cfg.env.num_envs * world_size) - num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 - learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 - prefill_steps = learning_starts + start_step + policy_steps_per_iter = int(cfg.env.num_envs * world_size) + total_iters = int(cfg.algo.total_steps 
// policy_steps_per_iter) if not cfg.dry_run else 1 + learning_starts = (cfg.algo.learning_starts // policy_steps_per_iter) if not cfg.dry_run else 0 + prefill_steps = learning_starts + start_iter if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - learning_starts += start_step + learning_starts += start_iter # Create Ratio class ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) @@ -521,19 +521,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ratio.load_state_dict(state["ratio"]) # Warning for log and checkpoint every - if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the metrics will be logged at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) - if cfg.checkpoint.every % policy_steps_per_update != 0: + if cfg.checkpoint.every % policy_steps_per_iter != 0: warnings.warn( f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the checkpoint will be saved at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) # Get the first environment observation and start the optimization @@ -551,8 +551,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states() cumulative_per_rank_gradient_steps = 0 - for update in range(start_step, num_updates + 1): - policy_step += policy_steps_per_update + for iter_num in range(start_iter, total_iters + 1): + policy_step += policy_steps_per_iter with torch.inference_mode(): # Measure environment interaction time: this considers both the model forward @@ -560,7 +560,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment if ( - update <= learning_starts + iter_num <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.wrapper._target_.lower() ): @@ -643,8 +643,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states(reset_envs=dones_idxes) # Train the agent - if update >= learning_starts: - ratio_steps = policy_step - prefill_steps + policy_steps_per_update + if iter_num >= learning_starts: + ratio_steps = policy_step - prefill_steps + policy_steps_per_iter per_rank_gradient_steps = ratio(ratio_steps / world_size) if per_rank_gradient_steps > 0: with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): @@ -676,7 +676,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step)) # Log metrics - if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): + if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): # Sync distributed metrics if aggregator and not aggregator.disabled: metrics_dict = aggregator.compute() @@ -712,7 +712,7 @@ def main(fabric: Fabric, cfg: 
Dict[str, Any]): # Checkpoint Model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or ( - update == num_updates and cfg.checkpoint.save_last + iter_num == total_iters and cfg.checkpoint.save_last ): last_checkpoint = policy_step state = { @@ -723,7 +723,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "actor_optimizer": actor_optimizer.state_dict(), "critic_optimizer": critic_optimizer.state_dict(), "ratio": ratio.state_dict(), - "update": update * world_size, + "iter_num": iter_num * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 8d195268..f9e251ab 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -519,24 +519,24 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Global variables train_step = 0 last_train = 0 - start_step = ( + start_iter = ( # + 1 because the checkpoint is at the end of the update step # (when resuming from a checkpoint, the update at the checkpoint # is ended and you have to start with the next one) - (state["update"] // world_size) + 1 + (state["iter_num"] // world_size) + 1 if cfg.checkpoint.resume_from else 1 ) - policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 + policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - policy_steps_per_update = int(cfg.env.num_envs * world_size) - num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 - learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - prefill_steps = learning_starts + start_step + policy_steps_per_iter = int(cfg.env.num_envs * world_size) + total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1 + learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0 + prefill_steps = learning_starts + start_iter if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - learning_starts += start_step + learning_starts += start_iter # Create Ratio class ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) @@ -544,19 +544,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ratio.load_state_dict(state["ratio"]) # Warning for log and checkpoint every - if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the metrics will be logged at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." 
) - if cfg.checkpoint.every % policy_steps_per_update != 0: + if cfg.checkpoint.every % policy_steps_per_iter != 0: warnings.warn( f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the checkpoint will be saved at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) # Get the first environment observation and start the optimization @@ -576,8 +576,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states() cumulative_per_rank_gradient_steps = 0 - for update in range(start_step, num_updates + 1): - policy_step += policy_steps_per_update + for iter_num in range(start_iter, total_iters + 1): + policy_step += policy_steps_per_iter with torch.inference_mode(): # Measure environment interaction time: this considers both the model forward @@ -585,7 +585,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment if ( - update <= learning_starts + iter_num <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.wrapper._target_.lower() ): @@ -671,8 +671,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states(dones_idxes) # Train the agent - if update >= learning_starts: - ratio_steps = policy_step - prefill_steps + policy_steps_per_update + if iter_num >= learning_starts: + ratio_steps = policy_step - prefill_steps + policy_steps_per_iter per_rank_gradient_steps = ratio(ratio_steps / world_size) if per_rank_gradient_steps > 0: local_data = rb.sample_tensors( @@ -710,7 +710,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): train_step += world_size # Log metrics - if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): + if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): # Sync distributed metrics if aggregator and not aggregator.disabled: metrics_dict = aggregator.compute() @@ -746,7 +746,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Checkpoint Model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or ( - update == num_updates and cfg.checkpoint.save_last + iter_num == total_iters and cfg.checkpoint.save_last ): last_checkpoint = policy_step state = { @@ -758,7 +758,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "actor_optimizer": actor_optimizer.state_dict(), "critic_optimizer": critic_optimizer.state_dict(), "ratio": ratio.state_dict(), - "update": update * world_size, + "iter_num": iter_num * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index eb30adeb..51b48351 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -491,24 +491,24 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Global variables train_step = 0 last_train = 0 - start_step = ( + start_iter = ( # + 1 because the checkpoint is at the end of the update step # (when resuming from a checkpoint, the update at the checkpoint # is ended and you have to start with the next one) - (state["update"] // fabric.world_size) + 1 + (state["iter_num"] // 
fabric.world_size) + 1 if cfg.checkpoint.resume_from else 1 ) - policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 + policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) - num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 - learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - prefill_steps = learning_starts + start_step + policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size) + total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1 + learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0 + prefill_steps = learning_starts + start_iter if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size - learning_starts += start_step + learning_starts += start_iter # Create Ratio class ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) @@ -516,19 +516,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ratio.load_state_dict(state["ratio"]) # Warning for log and checkpoint every - if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the metrics will be logged at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) - if cfg.checkpoint.every % policy_steps_per_update != 0: + if cfg.checkpoint.every % policy_steps_per_iter != 0: warnings.warn( f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the checkpoint will be saved at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." 
) # Get the first environment observation and start the optimization @@ -543,8 +543,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states() cumulative_per_rank_gradient_steps = 0 - for update in range(start_step, num_updates + 1): - policy_step += policy_steps_per_update + for iter_num in range(start_iter, total_iters + 1): + policy_step += policy_steps_per_iter with torch.inference_mode(): # Measure environment interaction time: this considers both the model forward @@ -552,7 +552,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment if ( - update <= learning_starts + iter_num <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.wrapper._target_.lower() ): @@ -653,8 +653,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states(dones_idxes) # Train the agent - if update >= learning_starts: - ratio_steps = policy_step - prefill_steps + policy_steps_per_update + if iter_num >= learning_starts: + ratio_steps = policy_step - prefill_steps + policy_steps_per_iter per_rank_gradient_steps = ratio(ratio_steps / world_size) if per_rank_gradient_steps > 0: local_data = rb.sample_tensors( @@ -695,7 +695,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): train_step += world_size # Log metrics - if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): + if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): # Sync distributed metrics if aggregator and not aggregator.disabled: metrics_dict = aggregator.compute() @@ -731,7 +731,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Checkpoint Model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or ( - update == num_updates and cfg.checkpoint.save_last + iter_num == total_iters and cfg.checkpoint.save_last ): last_checkpoint = policy_step state = { @@ -744,7 +744,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "critic_optimizer": critic_optimizer.state_dict(), "moments": moments.state_dict(), "ratio": ratio.state_dict(), - "update": update * fabric.world_size, + "iter_num": iter_num * fabric.world_size, "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index d254b638..636ad453 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -247,24 +247,24 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Global variables last_train = 0 train_step = 0 - start_step = ( + start_iter = ( # + 1 because the checkpoint is at the end of the update step # (when resuming from a checkpoint, the update at the checkpoint # is ended and you have to start with the next one) - (state["update"] // fabric.world_size) + 1 + (state["iter_num"] // fabric.world_size) + 1 if cfg.checkpoint.resume_from else 1 ) - policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 + policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) - num_updates = cfg.algo.total_steps // policy_steps_per_update if not 
cfg.dry_run else 1 - learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - prefill_steps = learning_starts + start_step + policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size) + total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1 + learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0 + prefill_steps = learning_starts + start_iter if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size - learning_starts += start_step + learning_starts += start_iter # Create Ratio class ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) @@ -272,19 +272,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ratio.load_state_dict(state["ratio"]) # Warning for log and checkpoint every - if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the metrics will be logged at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) - if cfg.checkpoint.every % policy_steps_per_update != 0: + if cfg.checkpoint.every % policy_steps_per_iter != 0: warnings.warn( f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the checkpoint will be saved at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." 
) step_data = {} @@ -293,13 +293,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): per_rank_gradient_steps = 0 cumulative_per_rank_gradient_steps = 0 - for update in range(start_step, num_updates + 1): - policy_step += cfg.env.num_envs * fabric.world_size + for iter_num in range(start_iter, total_iters + 1): + policy_step += policy_steps_per_iter # Measure environment interaction time: this considers both the model forward # to get the action given the observation and the time taken into the environment with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): - if update <= learning_starts: + if iter_num <= learning_starts: actions = envs.action_space.sample() else: with torch.inference_mode(): @@ -345,8 +345,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs = next_obs # Train the agent - if update >= learning_starts: - ratio_steps = policy_step - prefill_steps + policy_steps_per_update + if iter_num >= learning_starts: + ratio_steps = policy_step - prefill_steps + policy_steps_per_iter per_rank_gradient_steps = ratio(ratio_steps / world_size) if per_rank_gradient_steps > 0: train( @@ -364,7 +364,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cumulative_per_rank_gradient_steps += per_rank_gradient_steps # Log metrics - if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): + if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): # Sync distributed metrics if aggregator and not aggregator.disabled: metrics_dict = aggregator.compute() @@ -400,7 +400,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Checkpoint model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or ( - update == num_updates and cfg.checkpoint.save_last + iter_num == total_iters and cfg.checkpoint.save_last ): last_checkpoint = policy_step state = { @@ -409,7 +409,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "actor_optimizer": actor_optimizer.state_dict(), "alpha_optimizer": alpha_optimizer.state_dict(), "ratio": ratio.state_dict(), - "update": update * fabric.world_size, + "iter_num": iter_num * fabric.world_size, "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size, "last_log": last_log, "last_checkpoint": last_checkpoint, diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index c362883b..3f66a606 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -520,24 +520,24 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Global variables train_step = 0 last_train = 0 - start_step = ( + start_iter = ( # + 1 because the checkpoint is at the end of the update step # (when resuming from a checkpoint, the update at the checkpoint # is ended and you have to start with the next one) - (state["update"] // world_size) + 1 + (state["iter_num"] // world_size) + 1 if cfg.checkpoint.resume_from else 1 ) - policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 + policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - policy_steps_per_update = int(cfg.env.num_envs * world_size) - num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 - learning_starts = (cfg.algo.learning_starts // 
policy_steps_per_update) if not cfg.dry_run else 0 - prefill_steps = learning_starts + start_step + policy_steps_per_iter = int(cfg.env.num_envs * world_size) + total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1 + learning_starts = (cfg.algo.learning_starts // policy_steps_per_iter) if not cfg.dry_run else 0 + prefill_steps = learning_starts + start_iter if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - learning_starts += start_step + learning_starts += start_iter # Create Ratio class ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) @@ -545,19 +545,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ratio.load_state_dict(state["ratio"]) # Warning for log and checkpoint every - if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the metrics will be logged at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) - if cfg.checkpoint.every % policy_steps_per_update != 0: + if cfg.checkpoint.every % policy_steps_per_iter != 0: warnings.warn( f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the checkpoint will be saved at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." 
) # Get the first environment observation and start the optimization @@ -575,8 +575,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states() cumulative_per_rank_gradient_steps = 0 - for update in range(start_step, num_updates + 1): - policy_step += policy_steps_per_update + for iter_num in range(start_iter, total_iters + 1): + policy_step += policy_steps_per_iter with torch.inference_mode(): # Measure environment interaction time: this considers both the model forward @@ -584,7 +584,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment if ( - update <= learning_starts + iter_num <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.wrapper._target_.lower() ): @@ -667,8 +667,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states(reset_envs=dones_idxes) # Train the agent - if update >= learning_starts: - ratio_steps = policy_step - prefill_steps + policy_steps_per_update + if iter_num >= learning_starts: + ratio_steps = policy_step - prefill_steps + policy_steps_per_iter per_rank_gradient_steps = ratio(ratio_steps / world_size) if per_rank_gradient_steps > 0: with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): @@ -710,7 +710,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ) # Log metrics - if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): + if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): # Sync distributed metrics if aggregator and not aggregator.disabled: metrics_dict = aggregator.compute() @@ -746,7 +746,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Checkpoint Model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or ( - update == num_updates and cfg.checkpoint.save_last + iter_num == total_iters and cfg.checkpoint.save_last ): last_checkpoint = policy_step state = { @@ -759,7 +759,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "critic_task_optimizer": critic_task_optimizer.state_dict(), "ensemble_optimizer": ensemble_optimizer.state_dict(), "ratio": ratio.state_dict(), - "update": update * world_size, + "iter_num": iter_num * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "critic_exploration": critic_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index d05aa22d..7767a9bb 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -191,24 +191,24 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Global variables train_step = 0 last_train = 0 - start_step = ( + start_iter = ( # + 1 because the checkpoint is at the end of the update step # (when resuming from a checkpoint, the update at the checkpoint # is ended and you have to start with the next one) - (state["update"] // world_size) + 1 + (state["iter_num"] // world_size) + 1 if resume_from_checkpoint else 1 ) - policy_step = state["update"] * cfg.env.num_envs if resume_from_checkpoint else 0 + policy_step = state["iter_num"] * cfg.env.num_envs if resume_from_checkpoint else 0 last_log = state["last_log"] if resume_from_checkpoint else 0 last_checkpoint = state["last_checkpoint"] if 
resume_from_checkpoint else 0 - policy_steps_per_update = int(cfg.env.num_envs * world_size) - num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 - learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0 - prefill_steps = learning_starts + start_step + policy_steps_per_iter = int(cfg.env.num_envs * world_size) + total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1 + learning_starts = (cfg.algo.learning_starts // policy_steps_per_iter) if not cfg.dry_run else 0 + prefill_steps = learning_starts + start_iter if resume_from_checkpoint: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - learning_starts += start_step + learning_starts += start_iter # Create Ratio class ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) @@ -216,19 +216,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ratio.load_state_dict(state["ratio"]) # Warning for log and checkpoint every - if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the metrics will be logged at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) - if cfg.checkpoint.every % policy_steps_per_update != 0: + if cfg.checkpoint.every % policy_steps_per_iter != 0: warnings.warn( f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the checkpoint will be saved at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." 
) # Get the first environment observation and start the optimization @@ -246,8 +246,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): player.init_states() cumulative_per_rank_gradient_steps = 0 - for update in range(start_step, num_updates + 1): - policy_step += policy_steps_per_update + for iter_num in range(start_iter, total_iters + 1): + policy_step += policy_steps_per_iter with torch.inference_mode(): # Measure environment interaction time: this considers both the model forward @@ -322,8 +322,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): player.init_states(reset_envs=dones_idxes) # Train the agent - if update >= learning_starts: - ratio_steps = policy_step - prefill_steps + policy_steps_per_update + if iter_num >= learning_starts: + ratio_steps = policy_step - prefill_steps + policy_steps_per_iter per_rank_gradient_steps = ratio(ratio_steps / world_size) if per_rank_gradient_steps > 0: if player.actor_type != "task": @@ -363,7 +363,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ) # Log metrics - if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): + if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): # Sync distributed metrics if aggregator and not aggregator.disabled: metrics_dict = aggregator.compute() @@ -399,7 +399,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Checkpoint Model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or ( - update == num_updates and cfg.checkpoint.save_last + iter_num == total_iters and cfg.checkpoint.save_last ): last_checkpoint = policy_step state = { @@ -410,7 +410,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), "ratio": ratio.state_dict(), - "update": update * world_size, + "iter_num": iter_num * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "last_log": last_log, diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index 7e42f5e1..72b2fb3b 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -655,24 +655,24 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Global variables train_step = 0 last_train = 0 - start_step = ( + start_iter = ( # + 1 because the checkpoint is at the end of the update step # (when resuming from a checkpoint, the update at the checkpoint # is ended and you have to start with the next one) - (state["update"] // world_size) + 1 + (state["iter_num"] // world_size) + 1 if cfg.checkpoint.resume_from else 1 ) - policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 + policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - policy_steps_per_update = int(cfg.env.num_envs * world_size) - num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 - learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - prefill_steps = 
learning_starts + start_step + policy_steps_per_iter = int(cfg.env.num_envs * world_size) + total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1 + learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0 + prefill_steps = learning_starts + start_iter if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - learning_starts += start_step + learning_starts += start_iter # Create Ratio class ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) @@ -680,19 +680,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ratio.load_state_dict(state["ratio"]) # Warning for log and checkpoint every - if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the metrics will be logged at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) - if cfg.checkpoint.every % policy_steps_per_update != 0: + if cfg.checkpoint.every % policy_steps_per_iter != 0: warnings.warn( f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the checkpoint will be saved at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) # Get the first environment observation and start the optimization @@ -712,8 +712,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states() cumulative_per_rank_gradient_steps = 0 - for update in range(start_step, num_updates + 1): - policy_step += policy_steps_per_update + for iter_num in range(start_iter, total_iters + 1): + policy_step += policy_steps_per_iter with torch.inference_mode(): # Measure environment interaction time: this considers both the model forward @@ -721,7 +721,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False): # Sample an action given the observation received by the environment if ( - update <= learning_starts + iter_num <= learning_starts and cfg.checkpoint.resume_from is None and "minedojo" not in cfg.env.wrapper._target_.lower() ): @@ -807,8 +807,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): player.init_states(dones_idxes) # Train the agent - if update >= learning_starts: - ratio_steps = policy_step - prefill_steps + policy_steps_per_update + if iter_num >= learning_starts: + ratio_steps = policy_step - prefill_steps + policy_steps_per_iter per_rank_gradient_steps = ratio(ratio_steps / world_size) if per_rank_gradient_steps > 0: local_data = rb.sample_tensors( @@ -859,7 +859,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): train_step += world_size # Log metrics - if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): + if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): # Sync distributed metrics if aggregator and not aggregator.disabled: metrics_dict = aggregator.compute() @@ -895,7 +895,7 @@ def main(fabric: Fabric, cfg: 
Dict[str, Any]): # Checkpoint Model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or ( - update == num_updates and cfg.checkpoint.save_last + iter_num == total_iters and cfg.checkpoint.save_last ): last_checkpoint = policy_step state = { @@ -909,7 +909,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): "critic_task_optimizer": critic_task_optimizer.state_dict(), "ensemble_optimizer": ensemble_optimizer.state_dict(), "ratio": ratio.state_dict(), - "update": update * world_size, + "iter_num": iter_num * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "critic_exploration": critic_exploration.state_dict(), diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 71109835..f5203987 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -209,24 +209,24 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Global variables train_step = 0 last_train = 0 - start_step = ( + start_iter = ( # + 1 because the checkpoint is at the end of the update step # (when resuming from a checkpoint, the update at the checkpoint # is ended and you have to start with the next one) - (state["update"] // world_size) + 1 + (state["iter_num"] // world_size) + 1 if resume_from_checkpoint else 1 ) - policy_step = state["update"] * cfg.env.num_envs if resume_from_checkpoint else 0 + policy_step = state["iter_num"] * cfg.env.num_envs if resume_from_checkpoint else 0 last_log = state["last_log"] if resume_from_checkpoint else 0 last_checkpoint = state["last_checkpoint"] if resume_from_checkpoint else 0 - policy_steps_per_update = int(cfg.env.num_envs * world_size) - num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 - learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - prefill_steps = learning_starts + start_step + policy_steps_per_iter = int(cfg.env.num_envs * world_size) + total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1 + learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0 + prefill_steps = learning_starts + start_iter if resume_from_checkpoint: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - learning_starts += start_step + learning_starts += start_iter # Create Ratio class ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) @@ -234,19 +234,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ratio.load_state_dict(state["ratio"]) # Warning for log and checkpoint every - if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the metrics will be logged at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." 
) - if cfg.checkpoint.every % policy_steps_per_update != 0: + if cfg.checkpoint.every % policy_steps_per_iter != 0: warnings.warn( f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the checkpoint will be saved at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) # Get the first environment observation and start the optimization @@ -266,8 +266,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): player.init_states() cumulative_per_rank_gradient_steps = 0 - for update in range(start_step, num_updates + 1): - policy_step += policy_steps_per_update + for iter_num in range(start_iter, total_iters + 1): + policy_step += policy_steps_per_iter with torch.inference_mode(): # Measure environment interaction time: this considers both the model forward @@ -345,8 +345,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): player.init_states(dones_idxes) # Train the agent - if update >= learning_starts: - ratio_steps = policy_step - prefill_steps + policy_steps_per_update + if iter_num >= learning_starts: + ratio_steps = policy_step - prefill_steps + policy_steps_per_iter per_rank_gradient_steps = ratio(ratio_steps / world_size) if per_rank_gradient_steps > 0: if player.actor_type != "task": @@ -390,7 +390,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): train_step += world_size # Log metrics - if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates): + if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters): # Sync distributed metrics if aggregator and not aggregator.disabled: metrics_dict = aggregator.compute() @@ -426,7 +426,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # Checkpoint Model if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or ( - update == num_updates and cfg.checkpoint.save_last + iter_num == total_iters and cfg.checkpoint.save_last ): last_checkpoint = policy_step state = { @@ -438,7 +438,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): "actor_task_optimizer": actor_task_optimizer.state_dict(), "critic_task_optimizer": critic_task_optimizer.state_dict(), "ratio": ratio.state_dict(), - "update": update * world_size, + "iter_num": iter_num * world_size, "batch_size": cfg.algo.per_rank_batch_size * world_size, "actor_exploration": actor_exploration.state_dict(), "last_log": last_log, diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 19d44487..9339bae9 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -732,24 +732,24 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Global variables train_step = 0 last_train = 0 - start_step = ( + start_iter = ( # + 1 because the checkpoint is at the end of the update step # (when resuming from a checkpoint, the update at the checkpoint # is ended and you have to start with the next one) - (state["update"] // fabric.world_size) + 1 + (state["iter_num"] // fabric.world_size) + 1 if cfg.checkpoint.resume_from else 1 ) - policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 + 
policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) - num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1 - learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 - prefill_steps = learning_starts + start_step + policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size) + total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1 + learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0 + prefill_steps = learning_starts + start_iter if cfg.checkpoint.resume_from: cfg.algo.per_rank_batch_size = state["batch_size"] // world_size - learning_starts += start_step + learning_starts += start_iter # Create Ratio class ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps) @@ -757,19 +757,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ratio.load_state_dict(state["ratio"]) # Warning for log and checkpoint every - if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0: + if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0: warnings.warn( f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the metrics will be logged at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." ) - if cfg.checkpoint.every % policy_steps_per_update != 0: + if cfg.checkpoint.every % policy_steps_per_iter != 0: warnings.warn( f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the " - f"policy_steps_per_update value ({policy_steps_per_update}), so " + f"policy_steps_per_iter value ({policy_steps_per_iter}), so " "the checkpoint will be saved at the nearest greater multiple of the " - "policy_steps_per_update value." + "policy_steps_per_iter value." 
         )
 
     # Get the first environment observation and start the optimization
@@ -784,8 +784,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     player.init_states()
 
     cumulative_per_rank_gradient_steps = 0
-    for update in range(start_step, num_updates + 1):
-        policy_step += policy_steps_per_update
+    for iter_num in range(start_iter, total_iters + 1):
+        policy_step += policy_steps_per_iter
 
         with torch.inference_mode():
             # Measure environment interaction time: this considers both the model forward
@@ -793,7 +793,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
                 # Sample an action given the observation received by the environment
                 if (
-                    update <= learning_starts
+                    iter_num <= learning_starts
                     and cfg.checkpoint.resume_from is None
                     and "minedojo" not in cfg.algo.actor.cls.lower()
                 ):
@@ -894,8 +894,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                 player.init_states(dones_idxes)
 
        # Train the agent
-        if update >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_update
+        if iter_num >= learning_starts:
+            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 local_data = rb.sample_tensors(
@@ -949,7 +949,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             train_step += world_size
 
         # Log metrics
-        if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
+        if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
             # Sync distributed metrics
             if aggregator and not aggregator.disabled:
                 metrics_dict = aggregator.compute()
@@ -985,7 +985,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 
         # Checkpoint Model
         if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or (
-            update == num_updates and cfg.checkpoint.save_last
+            iter_num == total_iters and cfg.checkpoint.save_last
         ):
             last_checkpoint = policy_step
             critics_exploration_state = {"critics_exploration": {}}
@@ -1007,7 +1007,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                 "critic_task_optimizer": critic_task_optimizer.state_dict(),
                 "ensemble_optimizer": ensemble_optimizer.state_dict(),
                 "ratio": ratio.state_dict(),
-                "update": update * world_size,
+                "iter_num": iter_num * world_size,
                 "batch_size": cfg.algo.per_rank_batch_size * world_size,
                 "actor_exploration": actor_exploration.state_dict(),
                 "actor_exploration_optimizer": actor_exploration_optimizer.state_dict(),
diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
index 35b7f08c..7370db43 100644
--- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
+++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -196,24 +196,24 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     # Global variables
     train_step = 0
     last_train = 0
-    start_step = (
+    start_iter = (
         # + 1 because the checkpoint is at the end of the update step
         # (when resuming from a checkpoint, the update at the checkpoint
         # is ended and you have to start with the next one)
-        (state["update"] // fabric.world_size) + 1
+        (state["iter_num"] // fabric.world_size) + 1
         if resume_from_checkpoint
         else 1
     )
-    policy_step = state["update"] * cfg.env.num_envs if resume_from_checkpoint else 0
+    policy_step = state["iter_num"] * cfg.env.num_envs if resume_from_checkpoint else 0
     last_log = state["last_log"] if resume_from_checkpoint else 0
     last_checkpoint = state["last_checkpoint"] if resume_from_checkpoint else 0
-    policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size)
-    num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
-    learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_step
+    policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
+    total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
+    learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
+    prefill_steps = learning_starts + start_iter
     if resume_from_checkpoint:
         cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
-        learning_starts += start_step
+        learning_starts += start_iter
 
     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)
@@ -221,19 +221,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
         ratio.load_state_dict(state["ratio"])
 
     # Warning for log and checkpoint every
-    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
+    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the metrics will be logged at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
-    if cfg.checkpoint.every % policy_steps_per_update != 0:
+    if cfg.checkpoint.every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the checkpoint will be saved at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
 
     # Get the first environment observation and start the optimization
@@ -248,8 +248,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     player.init_states()
 
     cumulative_per_rank_gradient_steps = 0
-    for update in range(start_step, num_updates + 1):
-        policy_step += policy_steps_per_update
+    for iter_num in range(start_iter, total_iters + 1):
+        policy_step += policy_steps_per_iter
 
         with torch.inference_mode():
             # Measure environment interaction time: this considers both the model forward
@@ -342,8 +342,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
                 player.init_states(dones_idxes)
 
         # Train the agent
-        if update >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_update
+        if iter_num >= learning_starts:
+            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 if player.actor_type != "task":
@@ -390,7 +390,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
             train_step += world_size
 
         # Log metrics
-        if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
+        if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
             # Sync distributed metrics
             if aggregator and not aggregator.disabled:
                 metrics_dict = aggregator.compute()
@@ -426,7 +426,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
 
         # Checkpoint Model
         if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or (
-            update == num_updates and cfg.checkpoint.save_last
+            iter_num == total_iters and cfg.checkpoint.save_last
         ):
             last_checkpoint = policy_step
             state = {
@@ -438,7 +438,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
                 "actor_task_optimizer": actor_task_optimizer.state_dict(),
                 "critic_task_optimizer": critic_task_optimizer.state_dict(),
                 "ratio": ratio.state_dict(),
-                "update": update * world_size,
+                "iter_num": iter_num * world_size,
                 "batch_size": cfg.algo.per_rank_batch_size * world_size,
                 "actor_exploration": actor_exploration.state_dict(),
                 "last_log": last_log,
diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py
index 69ce9591..8dfe8812 100644
--- a/sheeprl/algos/ppo/ppo.py
+++ b/sheeprl/algos/ppo/ppo.py
@@ -214,43 +214,43 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     # Global variables
     last_train = 0
     train_step = 0
-    start_step = (
+    start_iter = (
         # + 1 because the checkpoint is at the end of the update step
         # (when resuming from a checkpoint, the update at the checkpoint
         # is ended and you have to start with the next one)
-        (state["update"] // fabric.world_size) + 1
+        (state["iter_num"] // fabric.world_size) + 1
         if cfg.checkpoint.resume_from
         else 1
     )
-    policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
+    policy_step = state["iter_num"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
     last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
     last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
-    policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
-    num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
+    policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
+    total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
 
     # Warning for log and checkpoint every
-    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
+    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the metrics will be logged at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
-    if cfg.checkpoint.every % policy_steps_per_update != 0:
+    if cfg.checkpoint.every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
            "the checkpoint will be saved at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
 
     # Linear learning rate scheduler
     if cfg.algo.anneal_lr:
         from torch.optim.lr_scheduler import PolynomialLR
 
-        scheduler = PolynomialLR(optimizer=optimizer, total_iters=num_updates, power=1.0)
+        scheduler = PolynomialLR(optimizer=optimizer, total_iters=total_iters, power=1.0)
         if cfg.checkpoint.resume_from:
             scheduler.load_state_dict(state["scheduler"])
@@ -262,10 +262,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:])
         step_data[k] = next_obs[k][np.newaxis]
 
-    for update in range(start_step, num_updates + 1):
+    for iter_num in range(start_iter, total_iters + 1):
         with torch.inference_mode():
             for _ in range(0, cfg.algo.rollout_steps):
-                policy_step += policy_steps_per_update
+                policy_step += policy_steps_per_iter
 
                 # Measure environment interaction time: this considers both the model forward
                 # to get the action given the observation and the time taken into the environment
@@ -382,7 +382,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             fabric.log("Info/ent_coef", cfg.algo.ent_coef, policy_step)
 
         # Log metrics
-        if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
+        if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
             # Sync distributed metrics
             if aggregator and not aggregator.disabled:
                 metrics_dict = aggregator.compute()
@@ -416,23 +416,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             scheduler.step()
         if cfg.algo.anneal_clip_coef:
             cfg.algo.clip_coef = polynomial_decay(
-                update, initial=initial_clip_coef, final=0.0, max_decay_steps=num_updates, power=1.0
+                iter_num, initial=initial_clip_coef, final=0.0, max_decay_steps=total_iters, power=1.0
             )
         if cfg.algo.anneal_ent_coef:
             cfg.algo.ent_coef = polynomial_decay(
-                update, initial=initial_ent_coef, final=0.0, max_decay_steps=num_updates, power=1.0
+                iter_num, initial=initial_ent_coef, final=0.0, max_decay_steps=total_iters, power=1.0
             )
 
         # Checkpoint model
         if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or (
-            update == num_updates and cfg.checkpoint.save_last
+            iter_num == total_iters and cfg.checkpoint.save_last
         ):
             last_checkpoint = policy_step
             state = {
                 "agent": agent.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None,
-                "update": update * world_size,
+                "iter_num": iter_num * world_size,
                 "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size,
                 "last_log": last_log,
                 "last_checkpoint": last_checkpoint,
diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py
index 423d20f3..d41f7911 100644
--- a/sheeprl/algos/ppo/ppo_decoupled.py
+++ b/sheeprl/algos/ppo/ppo_decoupled.py
@@ -141,49 +141,48 @@ def player(
     )
 
     # Global variables
-    start_step = (
+    start_iter = (
         # + 1 because the checkpoint is at the end of the update step
         # (when resuming from a checkpoint, the update at the checkpoint
         # is ended and you have to start with the next one)
-        state["update"] + 1
+        state["iter_num"] + 1
         if cfg.checkpoint.resume_from
         else 1
     )
-    policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
+    policy_step = state["iter_num"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
     last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
     last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
-    policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps)
-    num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
+    policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps)
+    total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
 
     # Warning for log and checkpoint every
-    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
+    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the metrics will be logged at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
-    if cfg.checkpoint.every % policy_steps_per_update != 0:
+    if cfg.checkpoint.every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the checkpoint will be saved at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
-    if policy_steps_per_update < world_collective.world_size - 1:
+    if policy_steps_per_iter < world_collective.world_size - 1:
         raise RuntimeError(
             "The number of trainers ({}) is greater than the available collected data ({}). ".format(
-                world_collective.world_size - 1, policy_steps_per_update
+                world_collective.world_size - 1, policy_steps_per_iter
             )
             + "Consider to lower the number of trainers at least to the size of available collected data"
         )
     chunks_sizes = [
-        len(chunk)
-        for chunk in torch.tensor_split(torch.arange(policy_steps_per_update), world_collective.world_size - 1)
+        len(chunk) for chunk in torch.tensor_split(torch.arange(policy_steps_per_iter), world_collective.world_size - 1)
     ]
 
-    # Broadcast num_updates to all the world
-    update_t = torch.as_tensor([num_updates], device=device, dtype=torch.float32)
+    # Broadcast total_iters to all the world
+    update_t = torch.as_tensor([total_iters], device=device, dtype=torch.float32)
     world_collective.broadcast(update_t, src=0)
 
     # Get the first environment observation and start the optimization
@@ -194,9 +193,9 @@ def player(
             next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:])
         step_data[k] = next_obs[k][np.newaxis]
 
-    params = {"update": start_step, "last_log": last_log, "last_checkpoint": last_checkpoint}
+    params = {"iter_num": start_iter, "last_log": last_log, "last_checkpoint": last_checkpoint}
     world_collective.scatter_object_list([None], [params] * world_collective.world_size, src=0)
-    for _ in range(start_step, num_updates + 1):
+    for _ in range(start_iter, total_iters + 1):
         for _ in range(0, cfg.algo.rollout_steps):
             policy_step += cfg.env.num_envs
 
@@ -425,15 +424,15 @@ def trainer(
     )
 
     # Receive maximum number of updates from the player
-    num_updates = torch.zeros(1, device=device)
-    world_collective.broadcast(num_updates, src=0)
-    num_updates = num_updates.item()
+    total_iters = torch.zeros(1, device=device)
+    world_collective.broadcast(total_iters, src=0)
+    total_iters = total_iters.item()
 
     # Linear learning rate scheduler
     if cfg.algo.anneal_lr:
         from torch.optim.lr_scheduler import PolynomialLR
 
-        scheduler = PolynomialLR(optimizer=optimizer, total_iters=num_updates, power=1.0)
+        scheduler = PolynomialLR(optimizer=optimizer, total_iters=total_iters, power=1.0)
         if cfg.checkpoint.resume_from:
             scheduler.load_state_dict(state["scheduler"])
@@ -446,12 +445,12 @@ def trainer(
     last_train = 0
     train_step = 0
 
-    policy_steps_per_update = cfg.env.num_envs * cfg.algo.rollout_steps
+    policy_steps_per_iter = cfg.env.num_envs * cfg.algo.rollout_steps
     params = [None]
     world_collective.scatter_object_list(params, [None for _ in range(world_collective.world_size)], src=0)
     params = params[0]
-    update = params["update"]
-    policy_step = update * policy_steps_per_update
+    iter_num = params["iter_num"]
+    policy_step = iter_num * policy_steps_per_iter
     last_log = params["last_log"]
     last_checkpoint = params["last_checkpoint"]
     initial_ent_coef = copy.deepcopy(cfg.algo.ent_coef)
@@ -468,7 +467,7 @@ def trainer(
                 "agent": agent.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None,
-                "update": update,
+                "iter_num": iter_num,
                 "batch_size": cfg.algo.per_rank_batch_size * (world_collective.world_size - 1),
                 "last_log": last_log,
                 "last_checkpoint": last_checkpoint,
@@ -589,12 +588,12 @@ def trainer(
 
         if cfg.algo.anneal_clip_coef:
             cfg.algo.clip_coef = polynomial_decay(
-                update, initial=initial_clip_coef, final=0.0, max_decay_steps=num_updates, power=1.0
+                iter_num, initial=initial_clip_coef, final=0.0, max_decay_steps=total_iters, power=1.0
             )
 
         if cfg.algo.anneal_ent_coef:
             cfg.algo.ent_coef = polynomial_decay(
-                update, initial=initial_ent_coef, final=0.0, max_decay_steps=num_updates, power=1.0
+                iter_num, initial=initial_ent_coef, final=0.0, max_decay_steps=total_iters, power=1.0
             )
 
        # Checkpoint model on rank-0: send it everything
@@ -604,7 +603,7 @@ def trainer(
                 "agent": agent.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None,
-                "update": update,
+                "iter_num": iter_num,
                 "batch_size": cfg.algo.per_rank_batch_size * (world_collective.world_size - 1),
                 "last_log": last_log,
                 "last_checkpoint": last_checkpoint,
@@ -617,7 +616,7 @@ def trainer(
                 ckpt_path=ckpt_path,
                 state=state,
             )
-        update += 1
+        iter_num += 1
         policy_step += cfg.env.num_envs * cfg.algo.rollout_steps
 
diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
index 4298cb36..8b84d128 100644
--- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
+++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -230,43 +230,43 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     # Global variables
     last_train = 0
     train_step = 0
-    start_step = (
+    start_iter = (
         # + 1 because the checkpoint is at the end of the update step
         # (when resuming from a checkpoint, the update at the checkpoint
         # is ended and you have to start with the next one)
-        (state["update"] // fabric.world_size) + 1
+        (state["iter_num"] // fabric.world_size) + 1
         if cfg.checkpoint.resume_from
         else 1
     )
-    policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
+    policy_step = state["iter_num"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
     last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
     last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
-    policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
-    num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
+    policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
+    total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
 
     # Warning for log and checkpoint every
-    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
+    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the metrics will be logged at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
-    if cfg.checkpoint.every % policy_steps_per_update != 0:
+    if cfg.checkpoint.every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the checkpoint will be saved at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
 
     # Linear learning rate scheduler
     if cfg.algo.anneal_lr:
         from torch.optim.lr_scheduler import PolynomialLR
 
-        scheduler = PolynomialLR(optimizer=optimizer, total_iters=num_updates, power=1.0)
+        scheduler = PolynomialLR(optimizer=optimizer, total_iters=total_iters, power=1.0)
         if cfg.checkpoint.resume_from:
             scheduler.load_state_dict(state["scheduler"])
@@ -284,10 +284,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     prev_actions = np.zeros((1, cfg.env.num_envs, sum(actions_dim)))
     torch_prev_actions = torch.zeros(1, cfg.env.num_envs, sum(actions_dim), device=device, dtype=torch.float32)
 
-    for update in range(start_step, num_updates + 1):
+    for iter_num in range(start_iter, total_iters + 1):
         with torch.inference_mode():
             for _ in range(0, cfg.algo.rollout_steps):
-                policy_step += policy_steps_per_update
+                policy_step += policy_steps_per_iter
 
                 # Measure environment interaction time: this considers both the model forward
                 # to get the action given the observation and the time taken into the environment
@@ -455,7 +455,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             fabric.log("Info/ent_coef", cfg.algo.ent_coef, policy_step)
 
         # Log metrics
-        if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
+        if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
             # Sync distributed metrics
             if aggregator and not aggregator.disabled:
                 metrics_dict = aggregator.compute()
@@ -489,23 +489,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             scheduler.step()
         if cfg.algo.anneal_clip_coef:
             cfg.algo.clip_coef = polynomial_decay(
-                update, initial=initial_clip_coef, final=0.0, max_decay_steps=num_updates, power=1.0
+                iter_num, initial=initial_clip_coef, final=0.0, max_decay_steps=total_iters, power=1.0
             )
         if cfg.algo.anneal_ent_coef:
             cfg.algo.ent_coef = polynomial_decay(
-                update, initial=initial_ent_coef, final=0.0, max_decay_steps=num_updates, power=1.0
+                iter_num, initial=initial_ent_coef, final=0.0, max_decay_steps=total_iters, power=1.0
             )
 
         # Checkpoint model
         if (
             cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every
-        ) or update == num_updates:
+        ) or iter_num == total_iters:
             last_checkpoint = policy_step
             ckpt_state = {
                 "agent": agent.state_dict(),
                 "optimizer": optimizer.state_dict(),
                 "scheduler": scheduler.state_dict() if cfg.algo.anneal_lr else None,
-                "update": update * world_size,
+                "iter_num": iter_num * world_size,
                 "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size,
                 "last_log": last_log,
                 "last_checkpoint": last_checkpoint,
diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py
index 06b76356..00a5a3a8 100644
--- a/sheeprl/algos/sac/sac.py
+++ b/sheeprl/algos/sac/sac.py
@@ -39,7 +39,7 @@ def train(
     aggregator: MetricAggregator | None,
     update: int,
     cfg: Dict[str, Any],
-    policy_steps_per_update: int,
+    policy_steps_per_iter: int,
     group: Optional[CollectibleGroup] = None,
 ):
     # Update the soft-critic
@@ -53,7 +53,7 @@ def train(
     qf_optimizer.step()
 
     # Update the target networks with EMA
-    if update % (cfg.algo.critic.target_network_frequency // policy_steps_per_update + 1) == 0:
+    if update % (cfg.algo.critic.target_network_frequency // policy_steps_per_iter + 1) == 0:
         agent.qfs_target_ema()
 
     # Update the actor
@@ -198,24 +198,24 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     # Global variables
     last_train = 0
     train_step = 0
-    start_step = (
+    start_iter = (
         # + 1 because the checkpoint is at the end of the update step
         # (when resuming from a checkpoint, the update at the checkpoint
         # is ended and you have to start with the next one)
-        (state["update"] // fabric.world_size) + 1
+        (state["iter_num"] // fabric.world_size) + 1
         if cfg.checkpoint.resume_from
         else 1
     )
-    policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0
+    policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0
     last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
     last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
-    policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size)
-    num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
-    learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_step
+    policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
+    total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
+    learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
+    prefill_steps = learning_starts + start_iter
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
-        learning_starts += start_step
+        learning_starts += start_iter
 
     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)
@@ -223,19 +223,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         ratio.load_state_dict(state["ratio"])
 
     # Warning for log and checkpoint every
-    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
+    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the metrics will be logged at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
-    if cfg.checkpoint.every % policy_steps_per_update != 0:
+    if cfg.checkpoint.every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the checkpoint will be saved at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
 
     step_data = {}
@@ -244,13 +244,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     per_rank_gradient_steps = 0
     cumulative_per_rank_gradient_steps = 0
-    for update in range(start_step, num_updates + 1):
-        policy_step += policy_steps_per_update
+    for iter_num in range(start_iter, total_iters + 1):
+        policy_step += policy_steps_per_iter
 
         # Measure environment interaction time: this considers both the model forward
         # to get the action given the observation and the time taken into the environment
         with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
-            if update <= learning_starts:
+            if iter_num <= learning_starts:
                 actions = envs.action_space.sample()
             else:
                 # Sample an action given the observation received by the environment
@@ -295,9 +295,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         obs = next_obs
 
         # Train the agent
-        if update >= learning_starts:
+        if iter_num >= learning_starts:
             per_rank_gradient_steps = (
-                ratio((policy_step - prefill_steps + policy_steps_per_update) / world_size)
+                ratio((policy_step - prefill_steps + policy_steps_per_iter) / world_size)
                 if not cfg.run_benchmarks
                 else 1
             )
@@ -347,15 +347,15 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                     alpha_optimizer,
                     batch,
                    aggregator,
-                    update,
+                    iter_num,
                     cfg,
-                    policy_steps_per_update,
+                    policy_steps_per_iter,
                 )
                 cumulative_per_rank_gradient_steps += 1
             train_step += world_size
 
         # Log metrics
-        if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
+        if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
             # Sync distributed metrics
             if aggregator and not aggregator.disabled:
                 metrics_dict = aggregator.compute()
@@ -391,7 +391,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 
         # Checkpoint model
         if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or (
-            update == num_updates and cfg.checkpoint.save_last
+            iter_num == total_iters and cfg.checkpoint.save_last
         ):
             last_checkpoint = policy_step
             state = {
@@ -400,7 +400,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                 "actor_optimizer": actor_optimizer.state_dict(),
                 "alpha_optimizer": alpha_optimizer.state_dict(),
                 "ratio": ratio.state_dict(),
-                "update": update * fabric.world_size,
+                "iter_num": iter_num * fabric.world_size,
                 "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size,
                 "last_log": last_log,
                 "last_checkpoint": last_checkpoint,
diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py
index 71740b07..6350bbec 100644
--- a/sheeprl/algos/sac/sac_decoupled.py
+++ b/sheeprl/algos/sac/sac_decoupled.py
@@ -131,22 +131,23 @@ def player(
     # Global variables
     first_info_sent = False
-    start_step = (
+    start_iter = (
         # + 1 because the checkpoint is at the end of the update step
         # (when resuming from a checkpoint, the update at the checkpoint
         # is ended and you have to start with the next one)
-        state["update"] + 1
+        state["iter_num"] + 1
         if cfg.checkpoint.resume_from
         else 1
     )
-    policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0
+    policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0
     last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
     last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
-    policy_steps_per_update = int(cfg.env.num_envs)
-    num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
-    learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0
+    policy_steps_per_iter = int(cfg.env.num_envs)
+    total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
+    learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
+    prefill_steps = learning_starts + start_iter
     if cfg.checkpoint.resume_from and not cfg.buffer.checkpoint:
-        learning_starts += start_step
+        learning_starts += start_iter
 
     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)
@@ -154,19 +155,19 @@ def player(
         ratio.load_state_dict(state["ratio"])
 
     # Warning for log and checkpoint every
-    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
+    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the metrics will be logged at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
-    if cfg.checkpoint.every % policy_steps_per_update != 0:
+    if cfg.checkpoint.every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the checkpoint will be saved at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
 
     step_data = {}
@@ -175,13 +176,13 @@ def player(
     per_rank_gradient_steps = 0
     cumulative_per_rank_gradient_steps = 0
-    for update in range(start_step, num_updates + 1):
+    for iter_num in range(start_iter, total_iters + 1):
         policy_step += cfg.env.num_envs
 
         # Measure environment interaction time: this considers both the model forward
         # to get the action given the observation and the time taken into the environment
         with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
-            if update <= learning_starts:
+            if iter_num <= learning_starts:
                 actions = envs.action_space.sample()
             else:
                 # Sample an action given the observation received by the environment
@@ -225,14 +226,15 @@ def player(
         obs = next_obs
 
         # Send data to the training agents
-        if update >= learning_starts:
-            per_rank_gradient_steps = ratio(policy_step / (fabric.world_size - 1))
+        if iter_num >= learning_starts:
+            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
+            per_rank_gradient_steps = ratio(ratio_steps / (fabric.world_size - 1))
             cumulative_per_rank_gradient_steps += per_rank_gradient_steps
             if per_rank_gradient_steps > 0:
                 # Send local info to the trainers
                 if not first_info_sent:
                     world_collective.broadcast_object_list(
-                        [{"update": update, "last_log": last_log, "last_checkpoint": last_checkpoint}], src=0
+                        [{"iter_num": iter_num, "last_log": last_log, "last_checkpoint": last_checkpoint}], src=0
                     )
                     first_info_sent = True
@@ -297,7 +299,7 @@ def player(
 
         # Checkpoint model
         if (
-            update >= learning_starts  # otherwise the processes end up deadlocked
+            iter_num >= learning_starts  # otherwise the processes end up deadlocked
             and cfg.checkpoint.every > 0
             and policy_step - last_checkpoint >= cfg.checkpoint.every
         ):
@@ -417,20 +419,20 @@ def trainer(
         aggregator: MetricAggregator = hydra.utils.instantiate(cfg.metric.aggregator, _convert_="all").to(device)
 
     # Receive data from player regarding the:
-    # * update
+    # * iter_num
     # * last_log
     # * last_checkpoint
     data = [None]
     world_collective.broadcast_object_list(data, src=0)
-    update = data[0]["update"]
+    iter_num = data[0]["iter_num"]
     last_log = data[0]["last_log"]
     last_checkpoint = data[0]["last_checkpoint"]
 
     # Start training
     train_step = 0
     last_train = 0
-    policy_steps_per_update = cfg.env.num_envs
-    policy_step = update * policy_steps_per_update
+    policy_steps_per_iter = cfg.env.num_envs
+    policy_step = iter_num * policy_steps_per_iter
     while True:
         # Wait for data
         data = [None]
@@ -444,7 +446,7 @@ def trainer(
                 "qf_optimizer": qf_optimizer.state_dict(),
                 "actor_optimizer": actor_optimizer.state_dict(),
                 "alpha_optimizer": alpha_optimizer.state_dict(),
-                "update": update,
+                "iter_num": iter_num,
                 "batch_size": cfg.algo.per_rank_batch_size * (world_collective.world_size - 1),
                 "last_log": last_log,
                 "last_checkpoint": last_checkpoint,
@@ -479,9 +481,9 @@ def trainer(
                     alpha_optimizer,
                     {k: data[k][batch_idxes] for k in data.keys()},
                     aggregator,
-                    update,
+                    iter_num,
                     cfg,
-                    policy_steps_per_update,
+                    policy_steps_per_iter,
                     group=optimization_pg,
                 )
             train_step += group_world_size
@@ -522,7 +524,7 @@ def trainer(
                 "qf_optimizer": qf_optimizer.state_dict(),
                 "actor_optimizer": actor_optimizer.state_dict(),
                 "alpha_optimizer": alpha_optimizer.state_dict(),
-                "update": update,
+                "iter_num": iter_num,
                 "batch_size": cfg.algo.per_rank_batch_size * (world_collective.world_size - 1),
                 "last_log": last_log,
                 "last_checkpoint": last_checkpoint,
@@ -537,8 +539,8 @@ def trainer(
             )
 
         # Update counters
-        update += 1
-        policy_step += policy_steps_per_update
+        iter_num += 1
+        policy_step += policy_steps_per_iter
 
 
 @register_algorithm(decoupled=True)
diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py
index 92bededf..b310d87c 100644
--- a/sheeprl/algos/sac_ae/sac_ae.py
+++ b/sheeprl/algos/sac_ae/sac_ae.py
@@ -270,24 +270,24 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     # Global variables
     last_train = 0
     train_step = 0
-    start_step = (
+    start_iter = (
         # + 1 because the checkpoint is at the end of the update step
         # (when resuming from a checkpoint, the update at the checkpoint
         # is ended and you have to start with the next one)
-        (state["update"] // fabric.world_size) + 1
+        (state["iter_num"] // fabric.world_size) + 1
         if cfg.checkpoint.resume_from
         else 1
     )
-    policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0
+    policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0
     last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
     last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
-    policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size)
-    num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
-    learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0
-    prefill_steps = learning_starts + start_step
+    policy_steps_per_iter = int(cfg.env.num_envs * fabric.world_size)
+    total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
+    learning_starts = cfg.algo.learning_starts // policy_steps_per_iter if not cfg.dry_run else 0
+    prefill_steps = learning_starts + start_iter
     if cfg.checkpoint.resume_from:
         cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size
-        learning_starts += start_step
+        learning_starts += start_iter
 
     # Create Ratio class
     ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)
@@ -295,19 +295,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         ratio.load_state_dict(state["ratio"])
 
     # Warning for log and checkpoint every
-    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
+    if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the metrics will be logged at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
-    if cfg.checkpoint.every % policy_steps_per_update != 0:
+    if cfg.checkpoint.every % policy_steps_per_iter != 0:
         warnings.warn(
             f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
-            f"policy_steps_per_update value ({policy_steps_per_update}), so "
+            f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
             "the checkpoint will be saved at the nearest greater multiple of the "
-            "policy_steps_per_update value."
+            "policy_steps_per_iter value."
         )
 
     # Get the first environment observation and start the optimization
@@ -319,13 +319,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     per_rank_gradient_steps = 0
     cumulative_per_rank_gradient_steps = 0
-    for update in range(start_step, num_updates + 1):
-        policy_step += cfg.env.num_envs * fabric.world_size
+    for iter_num in range(start_iter, total_iters + 1):
+        policy_step += policy_steps_per_iter
 
         # Measure environment interaction time: this considers both the model forward
         # to get the action given the observation and the time taken into the environment
         with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
-            if update <= learning_starts:
+            if iter_num <= learning_starts:
                 actions = envs.action_space.sample()
             else:
                 with torch.inference_mode():
@@ -373,8 +373,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
         obs = next_obs
 
         # Train the agent
-        if update >= learning_starts:
-            ratio_steps = policy_step - prefill_steps + policy_steps_per_update
+        if iter_num >= learning_starts:
+            ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
             per_rank_gradient_steps = ratio(ratio_steps / world_size)
             if per_rank_gradient_steps > 0:
                 # We sample one time to reduce the communications between processes
@@ -426,7 +426,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
             train_step += world_size
 
         # Log metrics
-        if cfg.metric.log_level and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
+        if cfg.metric.log_level and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
             # Sync distributed metrics
             if aggregator and not aggregator.disabled:
                 metrics_dict = aggregator.compute()
@@ -462,7 +462,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
 
         # Checkpoint model
        if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or (
-            update == num_updates and cfg.checkpoint.save_last
+            iter_num == total_iters and cfg.checkpoint.save_last
         ):
             last_checkpoint = policy_step
             state = {
@@ -475,7 +475,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
                 "encoder_optimizer": encoder_optimizer.state_dict(),
                 "decoder_optimizer": decoder_optimizer.state_dict(),
                 "ratio": ratio.state_dict(),
-                "update": update * fabric.world_size,
+                "iter_num": iter_num * fabric.world_size,
                 "batch_size": cfg.algo.per_rank_batch_size * fabric.world_size,
                 "last_log": last_log,
                 "last_checkpoint": last_checkpoint,