Skip to content

Commit

Permalink
feat: added ratio state to checkpoint in sac decoupled
Browse files Browse the repository at this point in the history
  • Loading branch information
michele-milesi committed Mar 29, 2024
1 parent 7b143ed commit beff471
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sheeprl/algos/sac/sac_decoupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def player(
player_trainer_collective=player_trainer_collective,
ckpt_path=ckpt_path,
replay_buffer=rb if cfg.buffer.checkpoint else None,
ratio_state_dict=ratio.state_dict(),
)

world_collective.scatter_object_list([None], [None] + [-1] * (world_collective.world_size - 1), src=0)
Expand All @@ -322,6 +323,7 @@ def player(
player_trainer_collective=player_trainer_collective,
ckpt_path=ckpt_path,
replay_buffer=rb if cfg.buffer.checkpoint else None,
ratio_state_dict=ratio.state_dict(),
)

envs.close()
Expand Down
3 changes: 3 additions & 0 deletions sheeprl/utils/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,16 @@ def on_checkpoint_player(
player_trainer_collective: TorchCollective,
ckpt_path: str,
replay_buffer: Optional["ReplayBuffer"] = None,
ratio_state_dict: Dict[str, Any] | None = None,
):
state = [None]
player_trainer_collective.broadcast_object_list(state, src=1)
state = state[0]
if replay_buffer is not None:
rb_state = self._ckpt_rb(replay_buffer)
state["rb"] = replay_buffer
if ratio_state_dict is not None:
state["ratio"] = ratio_state_dict
fabric.save(ckpt_path, state)
if replay_buffer is not None:
self._experiment_consistent_rb(replay_buffer, rb_state)
Expand Down

0 comments on commit beff471

Please sign in to comment.