Feature/delete old ckpts (#171)
* Add first PPO numpy buffer implementation

* Add distribution cfg to agent

* No need for tensordict

* Add SAC numpy

* Improve sample_next_obs

* Add DV1 with numpy buffer

* Too many reshapes

* Add Sequential and EnvIndipendent np buffers

* Fewer reshapes

* Faster indexing + from_numpy parameter

* Dreamer-V2 numpy

* Fix buffer add

* Better indexing

* Fix indexes to sample

* Fix metrics when they are nan

* Fix reshape when bootstrapping + fix normalization

* Guard timer metrics

* np.intp for indexing

* Change dtype after creating the tensor

* Fix buf[key] after __getstate__ is called upon checkpoint

* Securely close fd on __getstate__()

* Add MemmapArray

* Add __len__ function

* Fix len

* Better array setter and __del__ now controls ownership

* Do not transfer ownership upon array setter

* Add properties

* Feature/episode buffer np (#121)

* feat: added episode buffer numpy

* fix: memmap episode buffer numpy

* fix: checkpoint when memmap=True EpisodeBufferNumpy

* fix: memmap episode buffer np

* tests: added tests for episode buffer np

* feat: update episode buffer, added MemmapArray

* Fix not use self._obs_keys

* Sample only if n > 0

* Fix shapes

* feat: added possibility to specify sequence length in sample() + added possibility to add data only to some env

* tests: update episode buffer numpy tests

* tests: added replay buffer np tests

* tests: added sequential replay buffer np tests

* fix: env independent replay buffer name

* fix: replay buffer + add tests

* Safely release buffer on Windows

* Safely deletes memmaps

* Del buffer

* Safer array setter

* Add Memmap.from_array

* Fix ReplayBuffer __set_item__

* fix: sac_np sample

* tests: update tests

* tests: update

* fix: sequential replay buffer sample clone

* Add tests + Fix MemmapArray on Windows

* Add tests to run only on Linux

* Fix tests

* Fix skip test on Windows

* Dreamer-V2 with EpisodeBuffer np

* Add user warning if file exists when creating a new MemmapArray

* feat: added dreamer v3 np

* Add docstrings + Fix array setter if shapes differ

* Fix tests

* Add docstring

* Docstrings

* fix: sample of env independent buffer

* Fix locked tensordict

* Add configs

* feat: update np algorithms with new specifications

* fix: mypy

* PokemonRed env from https://github.com/PWhiddy/PokemonRedExperiments/blob/master/baselines/red_gym_env.py

* Update dreamer_v3 with main

* Update dreamer_v2 with main

* Update dreamer_v1 with main

* Update ppo with main

* Update sac with main

* Amend numpy to torch dtype and back dicts

* feat: added np callback

* fix: np callback

* feat: add support functions in np checkpoint callback

* feat: added droq np

* feat: added ppo recurrent np

* feat: added sac-ae np

* Update dreamer algos with main

* feat: added p2e dv1 np

* feat: added p2e dv2 np

* feat: add p2e dv3 np

* feat: added ppo decoupled np

* feat: add sac decoupled

* np.tanh instead of torch.tanh

* feat: from tensordict to buffers np

* from td to np

* exclude mlflow from tests

* No more tensordict

* Updated howto

* Fix tests

* .cpu().numpy() just one time

* Removed old cfgs

* Convert all when hydra instantiating

* convert all on instantiate

* [skip-ci] Removed pokemon files

* fix: git merge related errors

* Fix get absolute path

* Amend dreamer-v3 pokemon config

* feat: added keep_last parameter

* docs: update

* Removed dict from config

---------

Co-authored-by: belerico <[email protected]>
Co-authored-by: belerico_t <[email protected]>
3 people authored Dec 19, 2023
1 parent b437994 commit 46e35cc
Showing 5 changed files with 20 additions and 1 deletion.
2 changes: 2 additions & 0 deletions howto/logs_and_checkpoints.md
@@ -169,13 +169,15 @@ By default the checkpointing is enabled with the following settings:
every: 100
resume_from: null
save_last: True
keep_last: 5
```

meaning that:

* `every` is the number of policy steps (number of steps played in the environment, e.g. if one has 2 processes with 4 environments per process then the policy steps are 2*4=8) between two consecutive checkpointing operations. For more info about the policy steps, check the [Work with Steps Tutorial](./work_with_steps.md).
* `resume_from` is the path of the checkpoint to resume from. If `null`, then the checkpointing is not resumed.
* `save_last` is a boolean flag that enables/disables the saving of the last checkpoint.
* `keep_last` is the number of checkpoints you want to keep during the experiment. If `null`, all the checkpoints are kept.
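The retention rule behind `keep_last` can be illustrated with a small standalone sketch: checkpoints are ordered by modification time and only the newest `keep_last` survive. The helper name and signature below are illustrative, not part of the sheeprl API.

```python
def surviving_checkpoints(ckpts, keep_last):
    """ckpts: list of (filename, mtime) tuples; returns the names kept.

    Mirrors the `keep_last` semantics: None keeps everything, otherwise
    only the `keep_last` most recently modified checkpoints survive.
    """
    if keep_last is None:
        return [name for name, _ in ckpts]  # keep all checkpoints
    ordered = sorted(ckpts, key=lambda p: p[1])  # oldest first by mtime
    return [name for name, _ in ordered[-keep_last:]]

ckpts = [("ckpt_100.ckpt", 1.0), ("ckpt_300.ckpt", 3.0), ("ckpt_200.ckpt", 2.0)]
print(surviving_checkpoints(ckpts, 2))  # ['ckpt_200.ckpt', 'ckpt_300.ckpt']
```

Note that ordering is by file modification time, not by the step number in the filename, so the two coincide only as long as checkpoints are written in step order.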

> **Note**
>
1 change: 1 addition & 0 deletions sheeprl/configs/checkpoint/default.yaml
@@ -1,3 +1,4 @@
every: 100
resume_from: null
save_last: True
keep_last: 5
1 change: 1 addition & 0 deletions sheeprl/configs/fabric/default.yaml
@@ -6,3 +6,4 @@ accelerator: "cpu"
precision: "32-true"
callbacks:
- _target_: sheeprl.utils.callback.CheckpointCallback
keep_last: "${checkpoint.keep_last}"
2 changes: 1 addition & 1 deletion sheeprl/configs/model_manager/dreamer_v2.yaml
@@ -18,4 +18,4 @@ models:
target_critic:
model_name: "${exp_name}_target_critic"
description: "DreamerV2 Target Critic used in ${env.id} Environment"
tags: {}
tags: {}
15 changes: 15 additions & 0 deletions sheeprl/utils/callback.py
@@ -1,5 +1,7 @@
from __future__ import annotations

import os
import pathlib
from typing import Any, Dict, Optional, Sequence, Union

from lightning.fabric import Fabric
@@ -22,6 +24,9 @@ class CheckpointCallback:
When the buffer is added to the state of the checkpoint, it is assumed that the episode is truncated.
"""

def __init__(self, keep_last: int | None = None) -> None:
self.keep_last = keep_last

def on_checkpoint_coupled(
self,
fabric: Fabric,
@@ -47,6 +52,8 @@ def on_checkpoint_coupled(
fabric.save(ckpt_path, state)
if replay_buffer is not None:
self._experiment_consistent_rb(replay_buffer, rb_state)
if self.keep_last:
self._delete_old_checkpoints(pathlib.Path(ckpt_path).parent)

def on_checkpoint_player(
self,
@@ -64,6 +71,8 @@ def on_checkpoint_player(
fabric.save(ckpt_path, state)
if replay_buffer is not None:
self._experiment_consistent_rb(replay_buffer, rb_state)
if self.keep_last:
self._delete_old_checkpoints(pathlib.Path(ckpt_path).parent)

def on_checkpoint_trainer(
self, fabric: Fabric, player_trainer_collective: TorchCollective, state: Dict[str, Any], ckpt_path: str
@@ -128,3 +137,9 @@ def _experiment_consistent_rb(
elif isinstance(rb, EpisodeBuffer):
# reinsert the open episodes to continue the training
rb._open_episodes = state

    def _delete_old_checkpoints(self, ckpt_folder: str | pathlib.Path) -> None:
        # Sort checkpoints from oldest to newest by modification time
        ckpts = sorted(pathlib.Path(ckpt_folder).glob("*.ckpt"), key=os.path.getmtime)
        if len(ckpts) > self.keep_last:
            # Delete everything except the `keep_last` most recent checkpoints
            for f in ckpts[: -self.keep_last]:
                f.unlink()
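A quick end-to-end check of the deletion logic added in this commit (re-created standalone here rather than importing sheeprl): write a few fake `.ckpt` files with staggered modification times, prune with `keep_last=2`, and confirm only the two newest remain.

```python
import os
import pathlib
import tempfile

def delete_old_checkpoints(ckpt_folder, keep_last):
    # Same idea as CheckpointCallback._delete_old_checkpoints:
    # sort by mtime (oldest first) and drop all but the newest `keep_last`.
    ckpts = sorted(pathlib.Path(ckpt_folder).glob("*.ckpt"), key=os.path.getmtime)
    if len(ckpts) > keep_last:
        for f in ckpts[: -keep_last]:
            f.unlink()

with tempfile.TemporaryDirectory() as d:
    for i, name in enumerate(["a.ckpt", "b.ckpt", "c.ckpt"]):
        p = pathlib.Path(d) / name
        p.write_text("x")
        os.utime(p, (i, i))  # force deterministic, increasing mtimes
    delete_old_checkpoints(d, keep_last=2)
    remaining = sorted(p.name for p in pathlib.Path(d).glob("*.ckpt"))
    print(remaining)  # ['b.ckpt', 'c.ckpt']
```

Since pruning is mtime-based, restoring files from a backup (which can reset modification times) may change which checkpoints are considered "old".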
