Feature/from update to iter #284

Merged 6 commits on May 10, 2024
howto/register_external_algorithm.md: 32 changes (16 additions & 16 deletions)
@@ -550,36 +550,36 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
# Global variables
last_train = 0
train_step = 0
start_step = (
start_iter = (
# + 1 because the checkpoint is at the end of the update step
# (when resuming from a checkpoint, the update at the checkpoint
# is ended and you have to start with the next one)
(state["update"] // fabric.world_size) + 1
(state["iter_num"] // fabric.world_size) + 1
if cfg.checkpoint.resume_from
else 1
)
policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
policy_step = state["iter_num"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
if cfg.checkpoint.resume_from:
cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size

# Warning for log and checkpoint every
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
warnings.warn(
f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the metrics will be logged at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)
if cfg.checkpoint.every % policy_steps_per_update != 0:
if cfg.checkpoint.every % policy_steps_per_iter != 0:
warnings.warn(
f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the checkpoint will be saved at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)

# Get the first environment observation and start the optimization
@@ -590,9 +590,9 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:])
step_data[k] = next_obs[k][np.newaxis]

for update in range(start_step, num_updates + 1):
for iter_num in range(start_iter, total_iters + 1):
for _ in range(0, cfg.algo.rollout_steps):
policy_step += policy_steps_per_update
policy_step += policy_steps_per_iter

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
@@ -653,7 +653,7 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
train(fabric, agent, optimizer, local_data, aggregator, cfg)

# Log metrics
if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
if policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters or cfg.dry_run:
# Sync distributed metrics
if aggregator and not aggregator.disabled:
metrics_dict = aggregator.compute()
@@ -686,13 +686,13 @@ def ext_sota_main(fabric: Fabric, cfg: Dict[str, Any]):
if (
(cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every)
or cfg.dry_run
or update == num_updates
or iter_num == total_iters
):
last_checkpoint = policy_step
state = {
"agent": agent.state_dict(),
"optimizer": optimizer.state_dict(),
"update_step": update,
"iter_num": iter_num * world_size,
}
ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt")
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)
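To make the renamed bookkeeping above easier to follow, here is a small self-contained sketch of the same arithmetic. It is not sheeprl code: the helper name and its plain-argument signature are made up for illustration, but the formulas mirror the how-to snippet (the checkpoint stores iter_num * world_size, hence the integer division and the + 1 when resuming).

# A minimal sketch of the renamed counters; the function and its signature are
# illustrative only, the formulas follow the how-to snippet above.
from typing import Optional, Tuple


def iteration_counters(
    num_envs: int,
    rollout_steps: int,
    total_steps: int,
    world_size: int,
    resumed_iter_num: Optional[int] = None,
) -> Tuple[int, int, int, int]:
    policy_steps_per_iter = int(num_envs * rollout_steps * world_size)
    total_iters = total_steps // policy_steps_per_iter
    if resumed_iter_num is not None:
        # The checkpoint saves iter_num * world_size, so divide it back and
        # restart from the next iteration (+ 1).
        start_iter = (resumed_iter_num // world_size) + 1
        policy_step = resumed_iter_num * num_envs * rollout_steps
    else:
        start_iter = 1
        policy_step = 0
    return policy_steps_per_iter, total_iters, start_iter, policy_step


# Example: 4 envs, 128 rollout steps, 2 ranks, 1_000_000 total policy steps.
print(iteration_counters(4, 128, 1_000_000, 2))         # fresh run
print(iteration_counters(4, 128, 1_000_000, 2, 1_200))  # resumed from a checkpoint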
howto/register_new_algorithm.md: 32 changes (16 additions & 16 deletions)
@@ -548,36 +548,36 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
# Global variables
last_train = 0
train_step = 0
start_step = (
start_iter = (
# + 1 because the checkpoint is at the end of the update step
# (when resuming from a checkpoint, the update at the checkpoint
# is ended and you have to start with the next one)
(state["update"] // fabric.world_size) + 1
(state["iter_num"] // fabric.world_size) + 1
if cfg.checkpoint.resume_from
else 1
)
policy_step = state["update"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
policy_step = state["iter_num"] * cfg.env.num_envs * cfg.algo.rollout_steps if cfg.checkpoint.resume_from else 0
last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1
if cfg.checkpoint.resume_from:
cfg.algo.per_rank_batch_size = state["batch_size"] // fabric.world_size

# Warning for log and checkpoint every
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
warnings.warn(
f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the metrics will be logged at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)
if cfg.checkpoint.every % policy_steps_per_update != 0:
if cfg.checkpoint.every % policy_steps_per_iter != 0:
warnings.warn(
f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the checkpoint will be saved at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)

# Get the first environment observation and start the optimization
@@ -588,9 +588,9 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
next_obs[k] = next_obs[k].reshape(cfg.env.num_envs, -1, *next_obs[k].shape[-2:])
step_data[k] = next_obs[k][np.newaxis]

for update in range(start_step, num_updates + 1):
for iter_num in range(start_iter, total_iters + 1):
for _ in range(0, cfg.algo.rollout_steps):
policy_step += policy_steps_per_update
policy_step += policy_steps_per_iter

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
@@ -651,7 +651,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
train(fabric, agent, optimizer, local_data, aggregator, cfg)

# Log metrics
if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
if policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters or cfg.dry_run:
# Sync distributed metrics
if aggregator and not aggregator.disabled:
metrics_dict = aggregator.compute()
@@ -684,13 +684,13 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]):
if (
(cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every)
or cfg.dry_run
or update == num_updates
or iter_num == total_iters
):
last_checkpoint = policy_step
state = {
"agent": agent.state_dict(),
"optimizer": optimizer.state_dict(),
"update_step": update,
"iter_num": iter_num * world_size,
}
ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt")
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)
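The warning branches above are a pure rename; the sketch below only illustrates what they check, using made-up numbers rather than a real config.

import warnings

# Made-up values: 4 envs, 128 rollout steps, 2 ranks -> 1024 policy steps per iteration.
policy_steps_per_iter = int(4 * 128 * 2)
log_every = 5_000  # not a multiple of 1024

if log_every % policy_steps_per_iter != 0:
    warnings.warn(
        f"The metric.log_every parameter ({log_every}) is not a multiple of the "
        f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
        "the metrics will be logged at the nearest greater multiple of the "
        "policy_steps_per_iter value."
    )

# policy_step only ever grows in multiples of policy_steps_per_iter, so the
# condition `policy_step - last_log >= log_every` first fires at 5120 here,
# i.e. the nearest greater multiple of policy_steps_per_iter.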
sheeprl/algos/a2c/a2c.py: 26 changes (13 additions & 13 deletions)
@@ -200,23 +200,23 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
train_step = 0
policy_step = 0
last_checkpoint = 0
policy_steps_per_update = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
policy_steps_per_iter = int(cfg.env.num_envs * cfg.algo.rollout_steps * world_size)
total_iters = cfg.algo.total_steps // policy_steps_per_iter if not cfg.dry_run else 1

# Warning for log and checkpoint every
if cfg.metric.log_every % policy_steps_per_update != 0:
if cfg.metric.log_every % policy_steps_per_iter != 0:
warnings.warn(
f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the metrics will be logged at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)
if cfg.checkpoint.every % policy_steps_per_update != 0:
if cfg.checkpoint.every % policy_steps_per_iter != 0:
warnings.warn(
f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the checkpoint will be saved at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)

# Get the first environment observation and start the optimization
@@ -225,10 +225,10 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
for k in obs_keys:
step_data[k] = next_obs[k][np.newaxis]

for update in range(1, num_updates + 1):
for iter_num in range(1, total_iters + 1):
with torch.inference_mode():
for _ in range(0, cfg.algo.rollout_steps):
policy_step += policy_steps_per_update
policy_step += policy_steps_per_iter

# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
@@ -325,7 +325,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
train(fabric, agent, optimizer, local_data, aggregator, cfg)

# Log metrics
if policy_step - last_log >= cfg.metric.log_every or update == num_updates or cfg.dry_run:
if policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters or cfg.dry_run:
# Sync distributed metrics
if aggregator and not aggregator.disabled:
metrics_dict = aggregator.compute()
@@ -358,13 +358,13 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
if (
(cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every)
or cfg.dry_run
or (update == num_updates and cfg.checkpoint.save_last)
or (iter_num == total_iters and cfg.checkpoint.save_last)
):
last_checkpoint = policy_step
state = {
"agent": agent.state_dict(),
"optimizer": optimizer.state_dict(),
"update_step": update,
"iter_num": iter_num * world_size,
}
ckpt_path = os.path.join(log_dir, f"checkpoint/ckpt_{policy_step}_{fabric.global_rank}.ckpt")
fabric.call("on_checkpoint_coupled", fabric=fabric, ckpt_path=ckpt_path, state=state)
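Beyond the rename, the checkpoint entry changes slightly: the state now stores the key "iter_num" with the value iter_num * world_size, which is exactly what the resume path divides by world_size, whereas the old snippet stored the raw per-rank update under "update_step". A short round-trip sketch with illustrative numbers:

# Saving (as in the diff): each rank stores a global iteration count.
world_size = 2
iter_num = 600                                   # per-rank iteration at checkpoint time
state = {"iter_num": iter_num * world_size}      # stored value: 1200

# Resuming (as shown in the how-to snippets above): divide back by world_size
# and continue from the next iteration.
start_iter = (state["iter_num"] // world_size) + 1   # 601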
sheeprl/algos/dreamer_v1/dreamer_v1.py: 44 changes (22 additions & 22 deletions)
@@ -496,44 +496,44 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Global variables
train_step = 0
last_train = 0
start_step = (
start_iter = (
# + 1 because the checkpoint is at the end of the update step
# (when resuming from a checkpoint, the update at the checkpoint
# is ended and you have to start with the next one)
(state["update"] // world_size) + 1
(state["iter_num"] // world_size) + 1
if cfg.checkpoint.resume_from
else 1
)
policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0
policy_step = state["iter_num"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0
last_log = state["last_log"] if cfg.checkpoint.resume_from else 0
last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0
policy_steps_per_update = int(cfg.env.num_envs * world_size)
num_updates = cfg.algo.total_steps // policy_steps_per_update if not cfg.dry_run else 1
learning_starts = (cfg.algo.learning_starts // policy_steps_per_update) if not cfg.dry_run else 0
prefill_steps = learning_starts + start_step
policy_steps_per_iter = int(cfg.env.num_envs * world_size)
total_iters = int(cfg.algo.total_steps // policy_steps_per_iter) if not cfg.dry_run else 1
learning_starts = (cfg.algo.learning_starts // policy_steps_per_iter) if not cfg.dry_run else 0
prefill_steps = learning_starts + start_iter
if cfg.checkpoint.resume_from:
cfg.algo.per_rank_batch_size = state["batch_size"] // world_size
learning_starts += start_step
learning_starts += start_iter

# Create Ratio class
ratio = Ratio(cfg.algo.replay_ratio, pretrain_steps=cfg.algo.per_rank_pretrain_steps)
if cfg.checkpoint.resume_from:
ratio.load_state_dict(state["ratio"])

# Warning for log and checkpoint every
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_update != 0:
if cfg.metric.log_level > 0 and cfg.metric.log_every % policy_steps_per_iter != 0:
warnings.warn(
f"The metric.log_every parameter ({cfg.metric.log_every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the metrics will be logged at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)
if cfg.checkpoint.every % policy_steps_per_update != 0:
if cfg.checkpoint.every % policy_steps_per_iter != 0:
warnings.warn(
f"The checkpoint.every parameter ({cfg.checkpoint.every}) is not a multiple of the "
f"policy_steps_per_update value ({policy_steps_per_update}), so "
f"policy_steps_per_iter value ({policy_steps_per_iter}), so "
"the checkpoint will be saved at the nearest greater multiple of the "
"policy_steps_per_update value."
"policy_steps_per_iter value."
)

# Get the first environment observation and start the optimization
@@ -551,16 +551,16 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
player.init_states()

cumulative_per_rank_gradient_steps = 0
for update in range(start_step, num_updates + 1):
policy_step += policy_steps_per_update
for iter_num in range(start_iter, total_iters + 1):
policy_step += policy_steps_per_iter

with torch.inference_mode():
# Measure environment interaction time: this considers both the model forward
# to get the action given the observation and the time taken into the environment
with timer("Time/env_interaction_time", SumMetric, sync_on_compute=False):
# Sample an action given the observation received by the environment
if (
update <= learning_starts
iter_num <= learning_starts
and cfg.checkpoint.resume_from is None
and "minedojo" not in cfg.env.wrapper._target_.lower()
):
@@ -643,8 +643,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
player.init_states(reset_envs=dones_idxes)

# Train the agent
if update >= learning_starts:
ratio_steps = policy_step - prefill_steps + policy_steps_per_update
if iter_num >= learning_starts:
ratio_steps = policy_step - prefill_steps + policy_steps_per_iter
per_rank_gradient_steps = ratio(ratio_steps / world_size)
if per_rank_gradient_steps > 0:
with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute):
@@ -676,7 +676,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
aggregator.update("Params/exploration_amount", actor._get_expl_amount(policy_step))

# Log metrics
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or update == num_updates):
if cfg.metric.log_level > 0 and (policy_step - last_log >= cfg.metric.log_every or iter_num == total_iters):
# Sync distributed metrics
if aggregator and not aggregator.disabled:
metrics_dict = aggregator.compute()
@@ -712,7 +712,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Checkpoint Model
if (cfg.checkpoint.every > 0 and policy_step - last_checkpoint >= cfg.checkpoint.every) or (
update == num_updates and cfg.checkpoint.save_last
iter_num == total_iters and cfg.checkpoint.save_last
):
last_checkpoint = policy_step
state = {
@@ -723,7 +723,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
"actor_optimizer": actor_optimizer.state_dict(),
"critic_optimizer": critic_optimizer.state_dict(),
"ratio": ratio.state_dict(),
"update": update * world_size,
"iter_num": iter_num * world_size,
"batch_size": cfg.algo.per_rank_batch_size * world_size,
"last_log": last_log,
"last_checkpoint": last_checkpoint,
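In the Dreamer-V1 script the renamed counter advances by num_envs * world_size per iteration (no rollout factor here), and training is gated on iter_num >= learning_starts. Below is a minimal loop skeleton with made-up numbers; it is not sheeprl code and omits the environment, the agent, and the actual Ratio scheduler.

# Made-up numbers: 4 envs, 2 ranks, 10_000 total policy steps,
# learning starts after 1_024 policy steps.
num_envs, world_size = 4, 2
policy_steps_per_iter = int(num_envs * world_size)        # 8
total_iters = 10_000 // policy_steps_per_iter             # 1250 iterations
learning_starts = 1_024 // policy_steps_per_iter          # 128 iterations
start_iter = 1                                            # fresh run
prefill_steps = learning_starts + start_iter

policy_step = 0
for iter_num in range(start_iter, total_iters + 1):
    policy_step += policy_steps_per_iter                  # environment interaction
    if iter_num >= learning_starts:
        # Same expression as in the diff: the step count handed to the
        # replay-ratio scheduler (ratio(...) in the real code).
        ratio_steps = policy_step - prefill_steps + policy_steps_per_iter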