From f78e52aaff156aaf0e1c9b4c4eba243948191025 Mon Sep 17 00:00:00 2001
From: Michele Milesi
Date: Fri, 27 Oct 2023 11:56:11 +0200
Subject: [PATCH 01/12] feat: added evaluation for all the agents

---
 README.md                               |  10 ++
 howto/eval_your_agent.md                | 118 ++++++++++++++++++++++++
 sheeprl/__init__.py                     |  36 +++++---
 sheeprl/algos/dreamer_v1/dreamer_v1.py  |   1 -
 sheeprl/algos/dreamer_v1/evaluate.py    |  72 +++++++++++++++
 sheeprl/algos/dreamer_v2/evaluate.py    |  73 +++++++++++++++
 sheeprl/algos/dreamer_v3/evaluate.py    |  72 +++++++++++++++
 sheeprl/algos/droq/evaluate.py          |  75 +++++++++++++++
 sheeprl/algos/p2e_dv1/evaluate.py       |  73 +++++++++++++++
 sheeprl/algos/p2e_dv2/evaluate.py       |  74 +++++++++++++++
 sheeprl/algos/ppo/evaluate.py           |  65 +++++++++++++
 sheeprl/algos/ppo_recurrent/evaluate.py |  68 ++++++++++++++
 sheeprl/algos/sac/evaluate.py           |  66 +++++++++++++
 sheeprl/algos/sac_ae/evaluate.py        | 118 ++++++++++++++++++++++++
 sheeprl/cli.py                          |  59 +++++++++++-
 sheeprl/configs/eval_config.yaml        |  17 ++++
 sheeprl/configs/exp/dreamer_v1.yaml     |   2 +-
 sheeprl/utils/registry.py               |  24 ++++-
 sheeprl_eval.py                         |   4 +
 19 files changed, 1006 insertions(+), 21 deletions(-)
 create mode 100644 howto/eval_your_agent.md
 create mode 100644 sheeprl/algos/dreamer_v1/evaluate.py
 create mode 100644 sheeprl/algos/dreamer_v2/evaluate.py
 create mode 100644 sheeprl/algos/dreamer_v3/evaluate.py
 create mode 100644 sheeprl/algos/droq/evaluate.py
 create mode 100644 sheeprl/algos/p2e_dv1/evaluate.py
 create mode 100644 sheeprl/algos/p2e_dv2/evaluate.py
 create mode 100644 sheeprl/algos/ppo/evaluate.py
 create mode 100644 sheeprl/algos/ppo_recurrent/evaluate.py
 create mode 100644 sheeprl/algos/sac/evaluate.py
 create mode 100644 sheeprl/algos/sac_ae/evaluate.py
 create mode 100644 sheeprl/configs/eval_config.yaml
 create mode 100644 sheeprl_eval.py

diff --git a/README.md b/README.md
index 670651d1..86bb8425 100644
--- a/README.md
+++ b/README.md
@@ -223,6 +223,16 @@ python sheeprl.py fabric.accelerator=cpu fabric.strategy=ddp fabric.devices=2 ex
 
 You can check the available parameters for Lightning Fabric [here](https://lightning.ai/docs/fabric/stable/api/fabric_args.html).
 
+### Evaluate your Agents
+
+You can easily evaluate your trained agents from checkpoints: training configurations are retrieved automatically.
+
+```bash
+python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu env.capture_video=True
+```
+
+For more information, check the corresponding [howto](./howto/eval_your_agent.md).
+
 ## :book: Repository structure
 
 The repository is structured as follows:
diff --git a/howto/eval_your_agent.md b/howto/eval_your_agent.md
new file mode 100644
index 00000000..5222305d
--- /dev/null
+++ b/howto/eval_your_agent.md
@@ -0,0 +1,118 @@
+# Evaluate your Agents
+
+In this document we give the user some advice on evaluating their agents. To evaluate an agent, you simply need to run the `./sheeprl_eval.py` script, passing the path to the checkpoint of the agent you want to evaluate.
+
+```bash
+python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt
+```
+
+The agent and the configs used during training are loaded automatically. The user can modify only a few parameters for evaluation:
+1. `fabric.accelerator`: you can use whichever accelerator you want for evaluating the agent; you just need to specify it in the command. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu` evaluates the agent on the GPU.
+If you want to choose the GPU, then you need to define the `CUDA_VISIBLE_DEVICES` environment variable in the `.env` file or set it before running the script. For example, you can execute the following command to evaluate your agent on the GPU with index 2: `CUDA_VISIBLE_DEVICES="2" python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu`.
+2. `env.capture_video`: you can decide whether or not to capture a video of the episode during the evaluation. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt env.capture_video=True` captures a video of the evaluation.
+
+> **Note**
+>
+> You cannot modify the number of processes to spawn. The evaluation is run with only 1 process.
+
+## Log Directory
+By default, the evaluation logs are stored in the same folder as the experiment. Suppose you have trained a `PPO` agent in the `CartPole-v1` environment. The log directory is organized as follows:
+```tree
+logs
+└── runs
+    └── ppo
+        └── CartPole-v1
+            └── 2023-10-27_11-46-05_default_42
+                ├── .hydra
+                │   ├── config.yaml
+                │   ├── hydra.yaml
+                │   └── overrides.yaml
+                ├── cli.log
+                └── version_0
+                    ├── checkpoint
+                    │   ├── ckpt_1024_0.ckpt
+                    │   ├── ckpt_1536_0.ckpt
+                    │   └── ckpt_512_0.ckpt
+                    ├── events.out.tfevents.1698399966.72040.0
+                    ├── memmap_buffer
+                    │   └── rank_0
+                    │       ├── actions.memmap
+                    │       ├── actions.meta.pt
+                    │       ├── advantages.memmap
+                    │       ├── advantages.meta.pt
+                    │       ├── dones.memmap
+                    │       ├── dones.meta.pt
+                    │       ├── logprobs.memmap
+                    │       ├── logprobs.meta.pt
+                    │       ├── meta.pt
+                    │       ├── returns.memmap
+                    │       ├── returns.meta.pt
+                    │       ├── rewards.memmap
+                    │       ├── rewards.meta.pt
+                    │       ├── state.memmap
+                    │       ├── state.meta.pt
+                    │       ├── values.memmap
+                    │       └── values.meta.pt
+                    └── train_videos
+                        ├── rl-video-episode-0.mp4
+                        ├── rl-video-episode-1.mp4
+                        └── rl-video-episode-8.mp4
+```
+
+Here, `./logs/runs/ppo/CartPole-v1/2023-10-27_11-46-05_default_42` contains your experiment. The evaluation script will create a subfolder named `evaluation` in the `./logs/runs/ppo/CartPole-v1/2023-10-27_11-46-05_default_42/version_0` folder, which will contain all the evaluations of the agent.
+ +For example, if we run two evaluations, then the log directory of the experiment will be as follows: +```diff +logs +└── runs + ├── .hydra + │ ├── config.yaml + │ ├── hydra.yaml + │ └── overrides.yaml + ├── cli.log + └── ppo + └── CartPole-v1 + └── 2023-10-27_11-46-05_default_42 + ├── .hydra + │ ├── config.yaml + │ ├── hydra.yaml + │ └── overrides.yaml + ├── cli.log + └── version_0 + ├── checkpoint + │ ├── ckpt_1024_0.ckpt + │ ├── ckpt_1536_0.ckpt + │ └── ckpt_512_0.ckpt ++ ├── evaluation ++ │ ├── version_0 ++ │ │ ├── events.out.tfevents.1698400212.73839.0 ++ │ │ └── test_videos ++ │ │ └── rl-video-episode-0.mp4 ++ │ └── version_1 ++ │ ├── events.out.tfevents.1698400283.74353.0 ++ │ └── test_videos ++ │ └── rl-video-episode-0.mp4 + ├── events.out.tfevents.1698399966.72040.0 + ├── memmap_buffer + │ └── rank_0 + │ ├── actions.memmap + │ ├── actions.meta.pt + │ ├── advantages.memmap + │ ├── advantages.meta.pt + │ ├── dones.memmap + │ ├── dones.meta.pt + │ ├── logprobs.memmap + │ ├── logprobs.meta.pt + │ ├── meta.pt + │ ├── returns.memmap + │ ├── returns.meta.pt + │ ├── rewards.memmap + │ ├── rewards.meta.pt + │ ├── state.memmap + │ ├── state.meta.pt + │ ├── values.memmap + │ └── values.meta.pt + └── train_videos + ├── rl-video-episode-0.mp4 + ├── rl-video-episode-1.mp4 + └── rl-video-episode-8.mp4 +``` \ No newline at end of file diff --git a/sheeprl/__init__.py b/sheeprl/__init__.py index 9d676f1e..b1fb7fe3 100644 --- a/sheeprl/__init__.py +++ b/sheeprl/__init__.py @@ -11,22 +11,32 @@ if not _IS_TORCH_GREATER_EQUAL_2_0: raise ModuleNotFoundError(_IS_TORCH_GREATER_EQUAL_2_0) -# Needed because MineRL 0.4.4 is not compatible with the latest version of numpy import numpy as np -from sheeprl.algos.dreamer_v1 import dreamer_v1 as dreamer_v1 -from sheeprl.algos.dreamer_v2 import dreamer_v2 as dreamer_v2 -from sheeprl.algos.dreamer_v3 import dreamer_v3 as dreamer_v3 -from sheeprl.algos.droq import droq as droq -from sheeprl.algos.p2e_dv1 import p2e_dv1 as p2e_dv1 -from sheeprl.algos.p2e_dv2 import p2e_dv2 as p2e_dv2 -from sheeprl.algos.ppo import ppo as ppo -from sheeprl.algos.ppo import ppo_decoupled as ppo_decoupled -from sheeprl.algos.ppo_recurrent import ppo_recurrent as ppo_recurrent -from sheeprl.algos.sac import sac as sac -from sheeprl.algos.sac import sac_decoupled as sac_decoupled -from sheeprl.algos.sac_ae import sac_ae as sac_ae +from sheeprl.algos.dreamer_v1 import dreamer_v1 # noqa: F401 +from sheeprl.algos.dreamer_v1 import evaluate as dreamer_v1_evaluate # noqa: F401 +from sheeprl.algos.dreamer_v2 import dreamer_v2 # noqa: F401 +from sheeprl.algos.dreamer_v2 import evaluate as dreamer_v2_evaluate # noqa: F401 +from sheeprl.algos.dreamer_v3 import dreamer_v3 # noqa: F401 +from sheeprl.algos.dreamer_v3 import evaluate as dreamer_v3_evaluate # noqa: F401 +from sheeprl.algos.droq import droq # noqa: F401 +from sheeprl.algos.droq import evaluate as droq_evaluate # noqa: F401 +from sheeprl.algos.p2e_dv1 import evaluate as p2e_dv1_evaluate # noqa: F401 +from sheeprl.algos.p2e_dv1 import p2e_dv1 # noqa: F401 +from sheeprl.algos.p2e_dv2 import evaluate as p2e_dv2_evaluate # noqa: F401 +from sheeprl.algos.p2e_dv2 import p2e_dv2 # noqa: F401 +from sheeprl.algos.ppo import evaluate as ppo_evaluate # noqa: F401 +from sheeprl.algos.ppo import ppo # noqa: F401 +from sheeprl.algos.ppo import ppo_decoupled # noqa: F401 +from sheeprl.algos.ppo_recurrent import evaluate as ppo_recurrent_evaluate # noqa: F401 +from sheeprl.algos.ppo_recurrent import ppo_recurrent # noqa: F401 +from sheeprl.algos.sac 
import evaluate as sac_evaluate # noqa: F401 +from sheeprl.algos.sac import sac # noqa: F401 +from sheeprl.algos.sac import sac_decoupled # noqa: F401 +from sheeprl.algos.sac_ae import evaluate as sac_ae_evaluate # noqa: F401 +from sheeprl.algos.sac_ae import sac_ae # noqa: F401 +# Needed because MineRL 0.4.4 is not compatible with the latest version of numpy np.float = np.float32 np.int = np.int64 np.bool = bool diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 5bcc15ff..76fa2adc 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -454,7 +454,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions_dim = ( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) - # observation_shape = observation_space["rgb"].shape clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py new file mode 100644 index 00000000..00a47ee0 --- /dev/null +++ b/sheeprl/algos/dreamer_v1/evaluate.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.dreamer_v1.agent import PlayerDV1, build_models +from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(name="dreamer_v1") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + action_space = env.action_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) + ) + # Create the actor and critic models + world_model, actor, _ = build_models( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["world_model"], + state["actor"], + ) + player = PlayerDV1( + world_model.encoder.module, + world_model.rssm.recurrent_model.module, + world_model.rssm.representation_model.module, + actor.module, + actions_dim, + cfg.algo.player.expl_amount, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric.device, + ) + + 
test(player, fabric, cfg, log_dir, sample_actions=False) diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py new file mode 100644 index 00000000..9f71b246 --- /dev/null +++ b/sheeprl/algos/dreamer_v2/evaluate.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.dreamer_v2.agent import PlayerDV2, build_models +from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(name="dreamer_v2") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + action_space = env.action_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) + ) + # Create the actor and critic models + world_model, actor, _, _ = build_models( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["world_model"], + state["actor"], + ) + player = PlayerDV2( + world_model.encoder.module, + world_model.rssm.recurrent_model.module, + world_model.rssm.representation_model.module, + actor.module, + actions_dim, + cfg.algo.player.expl_amount, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric.device, + discrete_size=cfg.algo.world_model.discrete_size, + ) + + test(player, fabric, cfg, log_dir, sample_actions=False) diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py new file mode 100644 index 00000000..ff327268 --- /dev/null +++ b/sheeprl/algos/dreamer_v3/evaluate.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_models +from sheeprl.algos.dreamer_v3.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(name="dreamer_v3") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, 
cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + action_space = env.action_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) + ) + # Create the actor and critic models + world_model, actor, _, _ = build_models( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["world_model"], + state["actor"], + ) + player = PlayerDV3( + world_model.encoder.module, + world_model.rssm, + actor.module, + actions_dim, + cfg.algo.player.expl_amount, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric.device, + discrete_size=cfg.algo.world_model.discrete_size, + ) + + test(player, fabric, cfg, log_dir, sample_actions=True) diff --git a/sheeprl/algos/droq/evaluate.py b/sheeprl/algos/droq/evaluate.py new file mode 100644 index 00000000..f23ff60d --- /dev/null +++ b/sheeprl/algos/droq/evaluate.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from math import prod +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.droq.agent import DROQAgent, DROQCritic +from sheeprl.algos.sac.agent import SACActor +from sheeprl.algos.sac.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(name="droq") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + action_space = env.action_space + observation_space = env.observation_space + if not isinstance(action_space, gym.spaces.Box): + raise ValueError("Only continuous action space is supported for the DroQ agent") + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if len(cfg.mlp_keys.encoder) == 0: + raise RuntimeError("You should specify at least one MLP key for the encoder: `mlp_keys.encoder=[state]`") + for k in cfg.mlp_keys.encoder: + if len(observation_space[k].shape) > 1: + raise ValueError( + "Only environments with vector-only observations are supported by the DroQ agent. 
" + f"Provided environment: {cfg.env.id}" + ) + if cfg.metric.log_level > 0: + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + act_dim = prod(action_space.shape) + obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) + actor = SACActor( + observation_dim=obs_dim, + action_dim=act_dim, + distribution_cfg=cfg.distribution, + hidden_size=cfg.algo.actor.hidden_size, + action_low=action_space.low, + action_high=action_space.high, + ) + critics = [ + DROQCritic( + observation_dim=obs_dim + act_dim, + hidden_size=cfg.algo.critic.hidden_size, + num_critics=1, + dropout=cfg.algo.critic.dropout, + ) + for _ in range(cfg.algo.critic.n) + ] + target_entropy = -act_dim + agent = DROQAgent( + actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device + ) + agent.load_state_dict(state["agent"]) + + test(agent.actor, fabric, cfg, log_dir) diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py new file mode 100644 index 00000000..df4c53fd --- /dev/null +++ b/sheeprl/algos/p2e_dv1/evaluate.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.dreamer_v1.agent import PlayerDV1 +from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.algos.p2e_dv1.agent import build_models +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(name="p2e_dv1") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + action_space = env.action_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) + ) + # Create the actor and critic models + world_model, actor_task, _, _, _ = build_models( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["world_model"], + state["actor_task"], + ) + player = PlayerDV1( + world_model.encoder.module, + world_model.rssm.recurrent_model.module, + world_model.rssm.representation_model.module, + actor_task.module, + actions_dim, + cfg.algo.player.expl_amount, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric.device, + ) + + test(player, fabric, cfg, log_dir, sample_actions=False) diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py new file 
mode 100644 index 00000000..4a2937e4 --- /dev/null +++ b/sheeprl/algos/p2e_dv2/evaluate.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.dreamer_v2.agent import PlayerDV2 +from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.algos.p2e_dv2.agent import build_models +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(name="p2e_dv2") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + action_space = env.action_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) + ) + # Create the actor and critic models + world_model, actor_task, _, _, _, _, _ = build_models( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["world_model"], + state["actor_task"], + ) + player = PlayerDV2( + world_model.encoder.module, + world_model.rssm.recurrent_model.module, + world_model.rssm.representation_model.module, + actor_task.module, + actions_dim, + cfg.algo.player.expl_amount, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric.device, + discrete_size=cfg.algo.world_model.discrete_size, + ) + + test(player, fabric, cfg, log_dir, sample_actions=False) diff --git a/sheeprl/algos/ppo/evaluate.py b/sheeprl/algos/ppo/evaluate.py new file mode 100644 index 00000000..39f5e8f8 --- /dev/null +++ b/sheeprl/algos/ppo/evaluate.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.ppo.agent import PPOAgent +from sheeprl.algos.ppo.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(name="ppo") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + + if 
not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder + cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(env.action_space, gym.spaces.Box) + is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + env.action_space.shape + if is_continuous + else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + ) + # Create the actor and critic models + agent = PPOAgent( + actions_dim=actions_dim, + obs_space=observation_space, + encoder_cfg=cfg.algo.encoder, + actor_cfg=cfg.algo.actor, + critic_cfg=cfg.algo.critic, + cnn_keys=cfg.cnn_keys.encoder, + mlp_keys=cfg.mlp_keys.encoder, + screen_size=cfg.env.screen_size, + distribution_cfg=cfg.distribution, + is_continuous=is_continuous, + ) + agent.load_state_dict(state["agent"]) + + test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/ppo_recurrent/evaluate.py b/sheeprl/algos/ppo_recurrent/evaluate.py new file mode 100644 index 00000000..707d9bb2 --- /dev/null +++ b/sheeprl/algos/ppo_recurrent/evaluate.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent +from sheeprl.algos.ppo_recurrent.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(name="ppo_recurrent") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder + cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(env.action_space, gym.spaces.Box) + is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + env.action_space.shape + if is_continuous + else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + ) + # Create the actor and critic models + agent = RecurrentPPOAgent( + actions_dim=actions_dim, + obs_space=observation_space, + encoder_cfg=cfg.algo.encoder, + rnn_cfg=cfg.algo.rnn, + actor_cfg=cfg.algo.actor, + critic_cfg=cfg.algo.critic, + cnn_keys=cfg.cnn_keys.encoder, + mlp_keys=cfg.mlp_keys.encoder, + is_continuous=is_continuous, + distribution_cfg=cfg.distribution, + num_envs=cfg.env.num_envs, + 
screen_size=cfg.env.screen_size, + device=fabric.device, + ) + agent.load_state_dict(state["agent"]) + + test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/sac/evaluate.py b/sheeprl/algos/sac/evaluate.py new file mode 100644 index 00000000..4f1b2cad --- /dev/null +++ b/sheeprl/algos/sac/evaluate.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from math import prod +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic +from sheeprl.algos.sac.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(name="sac") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + action_space = env.action_space + observation_space = env.observation_space + if not isinstance(action_space, gym.spaces.Box): + raise ValueError("Only continuous action space is supported for the SAC agent") + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if len(cfg.mlp_keys.encoder) == 0: + raise RuntimeError("You should specify at least one MLP key for the encoder: `mlp_keys.encoder=[state]`") + for k in cfg.mlp_keys.encoder: + if len(observation_space[k].shape) > 1: + raise ValueError( + "Only environments with vector-only observations are supported by the SAC agent. 
" + f"Provided environment: {cfg.env.id}" + ) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + act_dim = prod(action_space.shape) + obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) + actor = SACActor( + observation_dim=obs_dim, + action_dim=act_dim, + distribution_cfg=cfg.distribution, + hidden_size=cfg.algo.actor.hidden_size, + action_low=action_space.low, + action_high=action_space.high, + ) + critics = [ + SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) + for _ in range(cfg.algo.critic.n) + ] + target_entropy = -act_dim + agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) + agent.load_state_dict(state["agent"]) + + test(agent.actor, fabric, cfg, log_dir) diff --git a/sheeprl/algos/sac_ae/evaluate.py b/sheeprl/algos/sac_ae/evaluate.py new file mode 100644 index 00000000..a3b05ed3 --- /dev/null +++ b/sheeprl/algos/sac_ae/evaluate.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import copy +from math import prod +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.sac_ae.agent import ( + CNNEncoder, + MLPEncoder, + SACAEAgent, + SACAEContinuousActor, + SACAECritic, + SACAEQFunction, +) +from sheeprl.algos.sac_ae.utils import test_sac_ae +from sheeprl.models.models import MultiEncoder +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(name="sac_ae") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + + observation_space = env.observation_space + action_space = env.action_space + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + act_dim = prod(action_space.shape) + target_entropy = -act_dim + + # Define the encoder and decoder and setup them with fabric. 
+ # Then we will set the critic encoder and actor decoder as the unwrapped encoder module: + # we do not need it wrapped with the strategy inside actor and critic + cnn_channels = [prod(observation_space[k].shape[:-2]) for k in cfg.cnn_keys.encoder] + mlp_dims = [observation_space[k].shape[0] for k in cfg.mlp_keys.encoder] + cnn_encoder = ( + CNNEncoder( + in_channels=sum(cnn_channels), + features_dim=cfg.algo.encoder.features_dim, + keys=cfg.cnn_keys.encoder, + screen_size=cfg.env.screen_size, + cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier, + ) + if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0 + else None + ) + mlp_encoder = ( + MLPEncoder( + sum(mlp_dims), + cfg.mlp_keys.encoder, + cfg.algo.encoder.dense_units, + cfg.algo.encoder.mlp_layers, + eval(cfg.algo.encoder.dense_act), + cfg.algo.encoder.layer_norm, + ) + if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0 + else None + ) + encoder = MultiEncoder(cnn_encoder, mlp_encoder) + + # Setup actor and critic. Those will initialize with orthogonal weights + # both the actor and critic + actor = SACAEContinuousActor( + encoder=copy.deepcopy(encoder), + action_dim=act_dim, + distribution_cfg=cfg.distribution, + hidden_size=cfg.algo.actor.hidden_size, + action_low=action_space.low, + action_high=action_space.high, + ) + qfs = [ + SACAEQFunction( + input_dim=encoder.output_dim, action_dim=act_dim, hidden_size=cfg.algo.critic.hidden_size, output_dim=1 + ) + for _ in range(cfg.algo.critic.n) + ] + critic = SACAECritic(encoder=encoder, qfs=qfs) + + # The agent will tied convolutional and linear weights between the encoder actor and critic + agent = SACAEAgent( + actor, + critic, + target_entropy, + alpha=cfg.algo.alpha.alpha, + tau=cfg.algo.tau, + encoder_tau=cfg.algo.encoder.tau, + device=fabric.device, + ) + agent.load_state_dict(state["agent"]) + encoder.load_state_dict(state["encoder"]) + + test_sac_ae(agent.actor, fabric, cfg, log_dir) diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 64302a00..2782a1df 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -1,6 +1,7 @@ import datetime import importlib import os +import pathlib import time import warnings @@ -13,7 +14,7 @@ from sheeprl.utils.callback import CheckpointCallback from sheeprl.utils.metric import MetricAggregator -from sheeprl.utils.registry import tasks +from sheeprl.utils.registry import evaluation_registry, tasks from sheeprl.utils.timer import timer from sheeprl.utils.utils import dotdict, print_config @@ -100,6 +101,57 @@ def run_algorithm(cfg: DictConfig): fabric.launch(command, cfg) +def eval_algorithm(cfg: DictConfig): + """Run the algorithm specified in the configuration. + + Args: + cfg (DictConfig): the loaded configuration. 
+ """ + if cfg.checkpoint_path is None: + raise ValueError("You must specify the evaluation checkpoint path") + cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)) + capture_video = cfg.env.capture_video + + fabric = Fabric(**cfg.fabric, devices=1) + + state = fabric.load(cfg.checkpoint_path) + ckpt_path = pathlib.Path(cfg.checkpoint_path) + cfg = dotdict(OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml")) + cfg.run_name = str( + os.path.join( + os.path.basename(ckpt_path.parent.parent.parent), os.path.basename(ckpt_path.parent.parent), "evaluation" + ) + ) + cfg.checkpoint_path = str(ckpt_path) + cfg.env.num_envs = 1 + cfg.env.capture_video = capture_video + + # Given the algorithm's name, retrieve the module where + # 'cfg.algo.name'.py is contained; from there retrieve the + # `register_algorithm`-decorated entrypoint; + # the entrypoint will be launched by Fabric with `fabric.launch(entrypoint)` + module = None + entrypoint = None + algo_name = cfg.algo.name.replace("_decoupled", "") + for _module, _algos in evaluation_registry.items(): + for _algo in _algos: + if algo_name == _algo["name"]: + module = _module + entrypoint = _algo["entrypoint"] + break + if module is None: + raise RuntimeError(f"Given the algorithm named `{algo_name}`, no module has been found to be imported.") + if entrypoint is None: + raise RuntimeError( + f"Given the module and algorithm named `{module}` and `{algo_name}` respectively, " + "no entrypoint has been found to be imported." + ) + task = importlib.import_module(f"{module}.evaluate") + command = task.__dict__[entrypoint] + + fabric.launch(command, cfg, state) + + def check_configs(cfg: DictConfig): """Check the validity of the configuration. @@ -113,3 +165,8 @@ def run(cfg: DictConfig): """SheepRL zero-code command line utility.""" check_configs(cfg) run_algorithm(cfg) + + +@hydra.main(version_base="1.13", config_path="configs", config_name="eval_config") +def evaluation(cfg: DictConfig): + eval_algorithm(cfg) diff --git a/sheeprl/configs/eval_config.yaml b/sheeprl/configs/eval_config.yaml new file mode 100644 index 00000000..9e7594fa --- /dev/null +++ b/sheeprl/configs/eval_config.yaml @@ -0,0 +1,17 @@ +# @package _global_ + +# specify here default training configuration +defaults: + - _self_ + +hydra: + run: + dir: logs/runs/ + +fabric: + accelerator: cpu + +env: + capture_video: True + +checkpoint_path: ??? 
\ No newline at end of file diff --git a/sheeprl/configs/exp/dreamer_v1.yaml b/sheeprl/configs/exp/dreamer_v1.yaml index 436200d3..c2e07d16 100644 --- a/sheeprl/configs/exp/dreamer_v1.yaml +++ b/sheeprl/configs/exp/dreamer_v1.yaml @@ -2,7 +2,7 @@ defaults: - override /algo: dreamer_v1 - - override /env: dmc + - override /env: atari - _self_ # Experiment diff --git a/sheeprl/utils/registry.py b/sheeprl/utils/registry.py index a914e2a2..2faf7707 100644 --- a/sheeprl/utils/registry.py +++ b/sheeprl/utils/registry.py @@ -7,23 +7,30 @@ # where `module` and `algorithm` are respectively taken from sheeprl/algos/{module}/{algorithm}.py, # while `entrypoint` is the decorated function tasks: Dict[str, List[Dict[str, Any]]] = {} +evaluation_registry: Dict[str, List[Dict[str, Any]]] = {} -def _register(fn: Callable[..., Any], decoupled: bool = False) -> Callable[..., Any]: +def _register( + fn: Callable[..., Any], decoupled: bool = False, type: str = "algorithm", name: str | None = None +) -> Callable[..., Any]: # lookup containing module if fn.__module__ == "__main__": return fn entrypoint = fn.__name__ module_split = fn.__module__.split(".") - algorithm = module_split[-1] + algorithm = module_split[-1] if name is None else name module = ".".join(module_split[:-1]) - algos = tasks.get(module, None) + if type == "algorithm": + registry = tasks + else: + registry = evaluation_registry + algos = registry.get(module, None) if algos is None: - tasks[module] = [{"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}] + registry[module] = [{"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}] else: if algorithm in algos: raise ValueError(f"The algorithm `{algorithm}` has already been registered!") - tasks[module].append({"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}) + registry[module].append({"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}) # add the decorated function to __all__ in algorithm mod = sys.modules[fn.__module__] @@ -39,3 +46,10 @@ def inner_decorator(fn): return _register(fn, decoupled=decoupled) return inner_decorator + + +def register_evaluation(name: str | None = None): + def inner_decorator(fn): + return _register(fn, type="evaluation", name=name) + + return inner_decorator diff --git a/sheeprl_eval.py b/sheeprl_eval.py new file mode 100644 index 00000000..12ce6693 --- /dev/null +++ b/sheeprl_eval.py @@ -0,0 +1,4 @@ +from sheeprl.cli import evaluation + +if __name__ == "__main__": + evaluation() From 2a039cc5142d5223b7245ea95a0e744e517fa920 Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Fri, 27 Oct 2023 12:03:21 +0200 Subject: [PATCH 02/12] feat: added check configs evaluation function --- sheeprl/cli.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 2782a1df..11abb382 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -107,8 +107,6 @@ def eval_algorithm(cfg: DictConfig): Args: cfg (DictConfig): the loaded configuration. 
""" - if cfg.checkpoint_path is None: - raise ValueError("You must specify the evaluation checkpoint path") cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)) capture_video = cfg.env.capture_video @@ -160,6 +158,11 @@ def check_configs(cfg: DictConfig): """ +def check_configs_evaluation(cfg: DictConfig): + if cfg.checkpoint_path is None: + raise ValueError("You must specify the evaluation checkpoint path") + + @hydra.main(version_base="1.13", config_path="configs", config_name="config") def run(cfg: DictConfig): """SheepRL zero-code command line utility.""" @@ -169,4 +172,5 @@ def run(cfg: DictConfig): @hydra.main(version_base="1.13", config_path="configs", config_name="eval_config") def evaluation(cfg: DictConfig): + check_configs_evaluation(cfg) eval_algorithm(cfg) From 3d34da3ff1e7bc7feb181a9a9351b54b98f9e4bf Mon Sep 17 00:00:00 2001 From: belerico_t Date: Mon, 30 Oct 2023 14:51:33 +0100 Subject: [PATCH 03/12] Checks on evaluate function --- sheeprl/__init__.py | 23 ++++---- sheeprl/algos/dreamer_v1/evaluate.py | 2 +- sheeprl/algos/dreamer_v2/evaluate.py | 2 +- sheeprl/algos/dreamer_v3/evaluate.py | 2 +- sheeprl/algos/droq/evaluate.py | 2 +- sheeprl/algos/p2e_dv1/evaluate.py | 2 +- sheeprl/algos/p2e_dv2/evaluate.py | 2 +- sheeprl/algos/ppo/evaluate.py | 2 +- sheeprl/utils/registry.py | 78 ++++++++++++++++++++++------ 9 files changed, 81 insertions(+), 34 deletions(-) diff --git a/sheeprl/__init__.py b/sheeprl/__init__.py index b1fb7fe3..bb3b8ae0 100644 --- a/sheeprl/__init__.py +++ b/sheeprl/__init__.py @@ -13,29 +13,32 @@ import numpy as np +# fmt: off from sheeprl.algos.dreamer_v1 import dreamer_v1 # noqa: F401 -from sheeprl.algos.dreamer_v1 import evaluate as dreamer_v1_evaluate # noqa: F401 from sheeprl.algos.dreamer_v2 import dreamer_v2 # noqa: F401 -from sheeprl.algos.dreamer_v2 import evaluate as dreamer_v2_evaluate # noqa: F401 from sheeprl.algos.dreamer_v3 import dreamer_v3 # noqa: F401 -from sheeprl.algos.dreamer_v3 import evaluate as dreamer_v3_evaluate # noqa: F401 from sheeprl.algos.droq import droq # noqa: F401 -from sheeprl.algos.droq import evaluate as droq_evaluate # noqa: F401 -from sheeprl.algos.p2e_dv1 import evaluate as p2e_dv1_evaluate # noqa: F401 from sheeprl.algos.p2e_dv1 import p2e_dv1 # noqa: F401 -from sheeprl.algos.p2e_dv2 import evaluate as p2e_dv2_evaluate # noqa: F401 from sheeprl.algos.p2e_dv2 import p2e_dv2 # noqa: F401 -from sheeprl.algos.ppo import evaluate as ppo_evaluate # noqa: F401 from sheeprl.algos.ppo import ppo # noqa: F401 from sheeprl.algos.ppo import ppo_decoupled # noqa: F401 -from sheeprl.algos.ppo_recurrent import evaluate as ppo_recurrent_evaluate # noqa: F401 from sheeprl.algos.ppo_recurrent import ppo_recurrent # noqa: F401 -from sheeprl.algos.sac import evaluate as sac_evaluate # noqa: F401 from sheeprl.algos.sac import sac # noqa: F401 from sheeprl.algos.sac import sac_decoupled # noqa: F401 -from sheeprl.algos.sac_ae import evaluate as sac_ae_evaluate # noqa: F401 from sheeprl.algos.sac_ae import sac_ae # noqa: F401 +from sheeprl.algos.dreamer_v1 import evaluate as dreamer_v1_evaluate # noqa: F401, isort:skip +from sheeprl.algos.dreamer_v2 import evaluate as dreamer_v2_evaluate # noqa: F401, isort:skip +from sheeprl.algos.dreamer_v3 import evaluate as dreamer_v3_evaluate # noqa: F401, isort:skip +from sheeprl.algos.droq import evaluate as droq_evaluate # noqa: F401, isort:skip +from sheeprl.algos.p2e_dv1 import evaluate as p2e_dv1_evaluate # noqa: F401, isort:skip +from sheeprl.algos.p2e_dv2 import 
evaluate as p2e_dv2_evaluate # noqa: F401, isort:skip +from sheeprl.algos.ppo import evaluate as ppo_evaluate # noqa: F401, isort:skip +from sheeprl.algos.ppo_recurrent import evaluate as ppo_recurrent_evaluate # noqa: F401, isort:skip +from sheeprl.algos.sac import evaluate as sac_evaluate # noqa: F401, isort:skip +from sheeprl.algos.sac_ae import evaluate as sac_ae_evaluate # noqa: F401, isort:skip +# fmt: on + # Needed because MineRL 0.4.4 is not compatible with the latest version of numpy np.float = np.float32 np.int = np.int64 diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py index 00a47ee0..18c75004 100644 --- a/sheeprl/algos/dreamer_v1/evaluate.py +++ b/sheeprl/algos/dreamer_v1/evaluate.py @@ -12,7 +12,7 @@ from sheeprl.utils.registry import register_evaluation -@register_evaluation(name="dreamer_v1") +@register_evaluation(algorithms="dreamer_v1") def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): logger = create_tensorboard_logger(fabric, cfg) if logger and fabric.is_global_zero: diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py index 9f71b246..3640b1c9 100644 --- a/sheeprl/algos/dreamer_v2/evaluate.py +++ b/sheeprl/algos/dreamer_v2/evaluate.py @@ -12,7 +12,7 @@ from sheeprl.utils.registry import register_evaluation -@register_evaluation(name="dreamer_v2") +@register_evaluation(algorithms="dreamer_v2") def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): logger = create_tensorboard_logger(fabric, cfg) if logger and fabric.is_global_zero: diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py index ff327268..eeef7915 100644 --- a/sheeprl/algos/dreamer_v3/evaluate.py +++ b/sheeprl/algos/dreamer_v3/evaluate.py @@ -12,7 +12,7 @@ from sheeprl.utils.registry import register_evaluation -@register_evaluation(name="dreamer_v3") +@register_evaluation(algorithms="dreamer_v3") def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): logger = create_tensorboard_logger(fabric, cfg) if logger and fabric.is_global_zero: diff --git a/sheeprl/algos/droq/evaluate.py b/sheeprl/algos/droq/evaluate.py index f23ff60d..c241ab9f 100644 --- a/sheeprl/algos/droq/evaluate.py +++ b/sheeprl/algos/droq/evaluate.py @@ -14,7 +14,7 @@ from sheeprl.utils.registry import register_evaluation -@register_evaluation(name="droq") +@register_evaluation(algorithms="droq") def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): logger = create_tensorboard_logger(fabric, cfg) if logger and fabric.is_global_zero: diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py index df4c53fd..0dd7abb2 100644 --- a/sheeprl/algos/p2e_dv1/evaluate.py +++ b/sheeprl/algos/p2e_dv1/evaluate.py @@ -13,7 +13,7 @@ from sheeprl.utils.registry import register_evaluation -@register_evaluation(name="p2e_dv1") +@register_evaluation(algorithms="p2e_dv1") def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): logger = create_tensorboard_logger(fabric, cfg) if logger and fabric.is_global_zero: diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py index 4a2937e4..8a6b1f99 100644 --- a/sheeprl/algos/p2e_dv2/evaluate.py +++ b/sheeprl/algos/p2e_dv2/evaluate.py @@ -13,7 +13,7 @@ from sheeprl.utils.registry import register_evaluation -@register_evaluation(name="p2e_dv2") +@register_evaluation(algorithms="p2e_dv2") def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): logger = 
create_tensorboard_logger(fabric, cfg) if logger and fabric.is_global_zero: diff --git a/sheeprl/algos/ppo/evaluate.py b/sheeprl/algos/ppo/evaluate.py index 39f5e8f8..69dbf314 100644 --- a/sheeprl/algos/ppo/evaluate.py +++ b/sheeprl/algos/ppo/evaluate.py @@ -12,7 +12,7 @@ from sheeprl.utils.registry import register_evaluation -@register_evaluation(name="ppo") +@register_evaluation(algorithms=["ppo", "ppo_decoupled"]) def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): logger = create_tensorboard_logger(fabric, cfg) if logger and fabric.is_global_zero: diff --git a/sheeprl/utils/registry.py b/sheeprl/utils/registry.py index 2faf7707..ab28fd79 100644 --- a/sheeprl/utils/registry.py +++ b/sheeprl/utils/registry.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys from typing import Any, Callable, Dict, List @@ -6,31 +8,73 @@ # tasks[module] = [..., {"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}] # where `module` and `algorithm` are respectively taken from sheeprl/algos/{module}/{algorithm}.py, # while `entrypoint` is the decorated function -tasks: Dict[str, List[Dict[str, Any]]] = {} +algorithm_registry: Dict[str, List[Dict[str, Any]]] = {} evaluation_registry: Dict[str, List[Dict[str, Any]]] = {} -def _register( - fn: Callable[..., Any], decoupled: bool = False, type: str = "algorithm", name: str | None = None -) -> Callable[..., Any]: +def _register_algorithm(fn: Callable[..., Any], decoupled: bool = False) -> Callable[..., Any]: # lookup containing module if fn.__module__ == "__main__": return fn entrypoint = fn.__name__ module_split = fn.__module__.split(".") - algorithm = module_split[-1] if name is None else name + algorithm = module_split[-1] module = ".".join(module_split[:-1]) - if type == "algorithm": - registry = tasks + registered_algos = algorithm_registry.get(module, None) + if registered_algos is None: + algorithm_registry[module] = [{"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}] else: - registry = evaluation_registry - algos = registry.get(module, None) - if algos is None: - registry[module] = [{"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}] + algorithm_registry[module].append({"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}) + + # add the decorated function to __all__ in algorithm + mod = sys.modules[fn.__module__] + if hasattr(mod, "__all__"): + mod.__all__.append(entrypoint) + else: + mod.__all__ = [entrypoint] + return fn + + +def _register_evaluation(fn: Callable[..., Any], algorithms: str | List[str]) -> Callable[..., Any]: + # lookup containing module + if fn.__module__ == "__main__": + return fn + entrypoint = fn.__name__ + module_split = fn.__module__.split(".") + module = ".".join(module_split[:-1]) + if isinstance(algorithms, str): + algorithms = [algorithms] + # Check that the algorithms which we want to register an evaluation function for + # have been registered as algorithms + registered_algos = algorithm_registry.get(module, None) + if registered_algos is None: + raise ValueError( + f"The evaluation function `{module+'.'+entrypoint}` for the algorithms named `{', '.join(algorithms)}` " + "is going to be registered, but no algorithm has been registered!" 
+ ) + registered_algo_names = {algo["name"] for algo in registered_algos} + if len(set(algorithms) - registered_algo_names) > 0: + raise ValueError( + f"You are trying to register the evaluation function `{module+'.'+entrypoint}` " + f"for algorithms which have not been registered for the module `{module}`!\n" + f"Registered algorithms: {', '.join(registered_algo_names)}\n" + f"Specified algorithms: {', '.join(algorithms)}" + ) + registered_evals = evaluation_registry.get(module, None) + if registered_evals is None: + evaluation_registry[module] = [] + for algorithm in algorithms: + evaluation_registry[module].append({"name": algorithm, "entrypoint": entrypoint}) else: - if algorithm in algos: - raise ValueError(f"The algorithm `{algorithm}` has already been registered!") - registry[module].append({"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}) + for registered_eval in registered_evals: + if registered_eval["name"] in algorithms: + raise ValueError( + f"Cannot register the evaluate function `{module+'.'+entrypoint}` " + f"for the algorithm `{registered_eval['name']}`: " + f"the evaluation function `{module+'.'+registered_eval['entrypoint']}` has already " + f"been registered for the algorithm named `{registered_eval['name']}` in the module `{module}`!" + ) + evaluation_registry[module].extend([{"name": algorithm, "entrypoint": entrypoint} for algorithm in algorithms]) # add the decorated function to __all__ in algorithm mod = sys.modules[fn.__module__] @@ -43,13 +87,13 @@ def _register( def register_algorithm(decoupled: bool = False): def inner_decorator(fn): - return _register(fn, decoupled=decoupled) + return _register_algorithm(fn, decoupled=decoupled) return inner_decorator -def register_evaluation(name: str | None = None): +def register_evaluation(algorithms: str | List[str]): def inner_decorator(fn): - return _register(fn, type="evaluation", name=name) + return _register_evaluation(fn, algorithms=algorithms) return inner_decorator From a7b57466b5de0949bcc70867fb14255faf2de782 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Mon, 30 Oct 2023 15:03:50 +0100 Subject: [PATCH 04/12] Create fabric object with single device and node --- sheeprl/cli.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 6d5e3c9e..b90205bb 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -104,13 +104,8 @@ def eval_algorithm(cfg: DictConfig): """ cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)) capture_video = cfg.env.capture_video - - # TODO: change the number of devices when FSDP will be supported - fabric = Fabric(**cfg.fabric, devices=1) - - # Load the checkpoint - state = fabric.load(cfg.checkpoint_path) ckpt_path = pathlib.Path(cfg.checkpoint_path) + accelerator = cfg.fabric.get("accelerator", "auto") # Load the configuration cfg = dotdict(OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml")) @@ -123,6 +118,17 @@ def eval_algorithm(cfg: DictConfig): cfg.env.num_envs = 1 cfg.env.capture_video = capture_video + # TODO: change the number of devices when FSDP will be supported + cfg.fabric.pop("devices", None) + cfg.fabric.pop("strategy", None) + cfg.fabric.pop("num_nodes", None) + cfg.fabric.pop("callbacks", None) + cfg.fabric.pop("accelerator", None) + fabric = Fabric(**cfg.fabric, accelerator=accelerator, devices=1, num_nodes=1) + + # Load the checkpoint + state = fabric.load(cfg.checkpoint_path) + # Given the algorithm's name, retrieve the module where # 
'cfg.algo.name'.py is contained; from there retrieve the # `register_algorithm`-decorated entrypoint; From a33c0127ddd4c75fdcaf14a112ef5bfea245f7a1 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Mon, 30 Oct 2023 15:11:02 +0100 Subject: [PATCH 05/12] Fix algo name in evaluate --- sheeprl/cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sheeprl/cli.py b/sheeprl/cli.py index b90205bb..c67e5abe 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -135,7 +135,7 @@ def eval_algorithm(cfg: DictConfig): # the entrypoint will be launched by Fabric with `fabric.launch(entrypoint)` module = None entrypoint = None - algo_name = cfg.algo.name.replace("_decoupled", "") + algo_name = cfg.algo.name for _module, _algos in evaluation_registry.items(): for _algo in _algos: if algo_name == _algo["name"]: From a421863acf8092d389094a406020396107618545 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Mon, 30 Oct 2023 15:11:07 +0100 Subject: [PATCH 06/12] Update docs --- howto/eval_your_agent.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/howto/eval_your_agent.md b/howto/eval_your_agent.md index 5222305d..51ac0bf2 100644 --- a/howto/eval_your_agent.md +++ b/howto/eval_your_agent.md @@ -7,9 +7,14 @@ python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt ``` The agent and the configs used during the traning are loaded automatically. The user can modify only few parameters for evaluation: + 1. `fabric.accelerator`: you can use the accelerator you want for evaluating the agent, you just need specify it in the command. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu` for evaluating the agent on the gpu. If you want to choose the GPU, then you need to define the `CUDA_VISIBLE_DEVICES` environment variable in the `.env` file or set it before running the script. For example, you can execute the following command to evaluate your agent on the GPU with index 2: `CUDA_VISIBLE_DEVICES="2" python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu`. 2. `env.capture_video`: you can decide to caputre the video of the episode during the evaluation or not. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt env.capture_video=Ture` for capturing the video of the evaluation. +All the other parameters are loaded from the config file used during the training. Moreover, the following parameters are automatically set during the evaluation: + +* `cfg.env.num_envs`: the number of environments used during the evaluation is set to 1 + > **Note** > > You cannot modify the number of processes to spawn. The evaluation is made with only 1 process. 
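As an illustration of the registration pattern introduced above, an evaluation entrypoint for a custom algorithm would follow the same shape as the PPO one. The sketch below is hypothetical (the algorithm name `my_awesome_algo` and its module layout are assumptions for illustration, not part of this patch); it only shows what `register_evaluation` expects, namely an `evaluate` function living in the same `sheeprl/algos/<module>` package as algorithms that have already been registered:

```python
# Hypothetical sketch: sheeprl/algos/my_awesome_algo/evaluate.py
# Assumes both algorithm names below have already been registered with
# @register_algorithm, otherwise _register_evaluation raises a ValueError at import time.
from typing import Any, Dict

from lightning.fabric import Fabric

from sheeprl.utils.registry import register_evaluation


@register_evaluation(algorithms=["my_awesome_algo", "my_awesome_algo_decoupled"])
def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
    # Rebuild the agent from the checkpoint `state`, create a single test
    # environment and run the evaluation episodes, logging the cumulative reward.
    ...
```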
From 88779afc3ff7f6472121e49f1d594dcd441af8fc Mon Sep 17 00:00:00 2001 From: belerico_t Date: Mon, 30 Oct 2023 15:59:08 +0100 Subject: [PATCH 07/12] Instantiate fabric from hydra --- sheeprl/cli.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sheeprl/cli.py b/sheeprl/cli.py index c67e5abe..542235a7 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -124,7 +124,9 @@ def eval_algorithm(cfg: DictConfig): cfg.fabric.pop("num_nodes", None) cfg.fabric.pop("callbacks", None) cfg.fabric.pop("accelerator", None) - fabric = Fabric(**cfg.fabric, accelerator=accelerator, devices=1, num_nodes=1) + fabric: Fabric = hydra.utils.instantiate( + cfg.fabric, accelerator=accelerator, devices=1, num_nodes=1, _convert_="all" + ) # Load the checkpoint state = fabric.load(cfg.checkpoint_path) From 0836e3db4375d4c437b3a7e06ad60bdf97354a41 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 31 Oct 2023 10:34:46 +0100 Subject: [PATCH 08/12] Fix CLI for evaluation --- sheeprl/cli.py | 26 +++++++++++++++++++------- sheeprl/configs/eval_config.yaml | 5 ++--- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 542235a7..644de5c5 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -4,6 +4,7 @@ import pathlib import time import warnings +from pathlib import Path from typing import Any, Dict import hydra @@ -107,13 +108,6 @@ def eval_algorithm(cfg: DictConfig): ckpt_path = pathlib.Path(cfg.checkpoint_path) accelerator = cfg.fabric.get("accelerator", "auto") - # Load the configuration - cfg = dotdict(OmegaConf.load(ckpt_path.parent.parent.parent / ".hydra" / "config.yaml")) - cfg.run_name = str( - os.path.join( - os.path.basename(ckpt_path.parent.parent.parent), os.path.basename(ckpt_path.parent.parent), "evaluation" - ) - ) cfg.checkpoint_path = str(ckpt_path) cfg.env.num_envs = 1 cfg.env.capture_video = capture_video @@ -231,5 +225,23 @@ def run(cfg: DictConfig): @hydra.main(version_base="1.13", config_path="configs", config_name="eval_config") def evaluation(cfg: DictConfig): + # Load the checkpoint configuration + checkpoint_path = Path(cfg.checkpoint_path) + ckpt_cfg = OmegaConf.load(checkpoint_path.parent.parent.parent / ".hydra" / "config.yaml") + + # Merge the two configs + from omegaconf import open_dict + + with open_dict(cfg): + cfg.merge_with(ckpt_cfg) + cfg.run_name = str( + os.path.join( + os.path.basename(checkpoint_path.parent.parent.parent), + os.path.basename(checkpoint_path.parent.parent), + "evaluation", + ) + ) + + # Check the validity of the configuration and run the evaluation check_configs_evaluation(cfg) eval_algorithm(cfg) diff --git a/sheeprl/configs/eval_config.yaml b/sheeprl/configs/eval_config.yaml index 9e7594fa..388c91b5 100644 --- a/sheeprl/configs/eval_config.yaml +++ b/sheeprl/configs/eval_config.yaml @@ -5,8 +5,7 @@ defaults: - _self_ hydra: - run: - dir: logs/runs/ + output_subdir: null fabric: accelerator: cpu @@ -14,4 +13,4 @@ fabric: env: capture_video: True -checkpoint_path: ??? \ No newline at end of file +checkpoint_path: ??? 
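Since the change above introduces a merge between the checkpoint configuration and the evaluation overrides, a standalone OmegaConf note may help (made-up keys, not SheepRL code): in `base.merge_with(other)` the values coming from `other` take precedence on conflicts, which is why the direction of the merge matters.

```python
# Minimal OmegaConf illustration (assumed keys, not SheepRL code):
# in `base.merge_with(other)`, values coming from `other` win on conflicts.
from omegaconf import OmegaConf

ckpt_cfg = OmegaConf.create({"env": {"num_envs": 4, "capture_video": False}, "seed": 42})
eval_cfg = OmegaConf.create({"env": {"num_envs": 1, "capture_video": True}})

# Merging the evaluation overrides INTO the checkpoint config keeps every
# training-time value (e.g. `seed`) while the evaluation overrides take precedence.
ckpt_cfg.merge_with(eval_cfg)
assert ckpt_cfg.env.num_envs == 1 and ckpt_cfg.env.capture_video and ckpt_cfg.seed == 42
```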
From 673dadf40ca6d50df13686dd2cb1929589c1e790 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 31 Oct 2023 10:39:46 +0100 Subject: [PATCH 09/12] Fix merge direction between configs --- sheeprl/cli.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 644de5c5..7e6ff5c9 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -233,8 +233,8 @@ def evaluation(cfg: DictConfig): from omegaconf import open_dict with open_dict(cfg): - cfg.merge_with(ckpt_cfg) - cfg.run_name = str( + ckpt_cfg.merge_with(cfg) + ckpt_cfg.run_name = str( os.path.join( os.path.basename(checkpoint_path.parent.parent.parent), os.path.basename(checkpoint_path.parent.parent), @@ -243,5 +243,5 @@ def evaluation(cfg: DictConfig): ) # Check the validity of the configuration and run the evaluation - check_configs_evaluation(cfg) - eval_algorithm(cfg) + check_configs_evaluation(ckpt_cfg) + eval_algorithm(ckpt_cfg) From e52ca08458594e143b955d316d0ba06086d10f13 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 31 Oct 2023 10:55:19 +0100 Subject: [PATCH 10/12] Do not let user modify eval_cfg unless where he/she can --- howto/eval_your_agent.md | 4 ++-- sheeprl/cli.py | 16 +++++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/howto/eval_your_agent.md b/howto/eval_your_agent.md index 51ac0bf2..1261e6d3 100644 --- a/howto/eval_your_agent.md +++ b/howto/eval_your_agent.md @@ -8,8 +8,8 @@ python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt The agent and the configs used during the traning are loaded automatically. The user can modify only few parameters for evaluation: -1. `fabric.accelerator`: you can use the accelerator you want for evaluating the agent, you just need specify it in the command. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu` for evaluating the agent on the gpu. If you want to choose the GPU, then you need to define the `CUDA_VISIBLE_DEVICES` environment variable in the `.env` file or set it before running the script. For example, you can execute the following command to evaluate your agent on the GPU with index 2: `CUDA_VISIBLE_DEVICES="2" python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu`. -2. `env.capture_video`: you can decide to caputre the video of the episode during the evaluation or not. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt env.capture_video=Ture` for capturing the video of the evaluation. +1. `fabric` related ones: you can use the accelerator you want for evaluating the agent, you just need specify it in the command. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu` for evaluating the agent on the gpu. If you want to choose the GPU, then you need to define the `CUDA_VISIBLE_DEVICES` environment variable in the `.env` file or set it before running the script. For example, you can execute the following command to evaluate your agent on the GPU with index 2: `CUDA_VISIBLE_DEVICES="2" python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu`. By default the number of devices and nodes is set to 1, while the precision and the plugins are set to the ones set in the checkpoint config. +2. `env.capture_video`: you can decide to capture the video of the episode during the evaluation or not. 
For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt env.capture_video=Ture` for capturing the video of the evaluation. All the other parameters are loaded from the config file used during the training. Moreover, the following parameters are automatically set during the evaluation: diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 7e6ff5c9..bc4300ce 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -1,7 +1,6 @@ import datetime import importlib import os -import pathlib import time import warnings from pathlib import Path @@ -104,14 +103,8 @@ def eval_algorithm(cfg: DictConfig): cfg (DictConfig): the loaded configuration. """ cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)) - capture_video = cfg.env.capture_video - ckpt_path = pathlib.Path(cfg.checkpoint_path) accelerator = cfg.fabric.get("accelerator", "auto") - cfg.checkpoint_path = str(ckpt_path) - cfg.env.num_envs = 1 - cfg.env.capture_video = capture_video - # TODO: change the number of devices when FSDP will be supported cfg.fabric.pop("devices", None) cfg.fabric.pop("strategy", None) @@ -233,7 +226,16 @@ def evaluation(cfg: DictConfig): from omegaconf import open_dict with open_dict(cfg): + # Remove env related parameters + capture_video = getattr(cfg.env, "capture_video", True) + cfg.env = {"capture_video": capture_video, "num_envs": 1} + cfg.exp = {} + cfg.algo = {} + + # Merge configs ckpt_cfg.merge_with(cfg) + + # Update values after merge ckpt_cfg.run_name = str( os.path.join( os.path.basename(checkpoint_path.parent.parent.parent), From 98fdd37dbaf3776b0280022a30ec839362aca436 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 31 Oct 2023 11:02:11 +0100 Subject: [PATCH 11/12] Prevent creation of output folder when evaluating --- sheeprl/configs/eval_config.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/sheeprl/configs/eval_config.yaml b/sheeprl/configs/eval_config.yaml index 388c91b5..2eeba61a 100644 --- a/sheeprl/configs/eval_config.yaml +++ b/sheeprl/configs/eval_config.yaml @@ -3,9 +3,13 @@ # specify here default training configuration defaults: - _self_ + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled hydra: output_subdir: null + run: + dir: . fabric: accelerator: cpu From 05587d4fdab51750913ca48bce524fcaf7e980e5 Mon Sep 17 00:00:00 2001 From: belerico_t Date: Tue, 31 Oct 2023 12:30:48 +0100 Subject: [PATCH 12/12] Set fabric defaults on evaluation --- howto/eval_your_agent.md | 5 +++-- sheeprl/cli.py | 14 +++++++------- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/howto/eval_your_agent.md b/howto/eval_your_agent.md index 1261e6d3..6228face 100644 --- a/howto/eval_your_agent.md +++ b/howto/eval_your_agent.md @@ -11,9 +11,10 @@ The agent and the configs used during the traning are loaded automatically. The 1. `fabric` related ones: you can use the accelerator you want for evaluating the agent, you just need specify it in the command. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu` for evaluating the agent on the gpu. If you want to choose the GPU, then you need to define the `CUDA_VISIBLE_DEVICES` environment variable in the `.env` file or set it before running the script. For example, you can execute the following command to evaluate your agent on the GPU with index 2: `CUDA_VISIBLE_DEVICES="2" python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu`. 
By default the number of devices and nodes is set to 1, while the precision and the plugins are set to the ones set in the checkpoint config.
 2. `env.capture_video`: you can decide to capture the video of the episode during the evaluation or not. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt env.capture_video=True` for capturing the video of the evaluation.
 
-All the other parameters are loaded from the config file used during the training. Moreover, the following parameters are automatically set during the evaluation:
+All the other parameters are loaded from the checkpoint config file used during the training. Moreover, the following parameters are automatically set during the evaluation:
 
-* `cfg.env.num_envs`: the number of environments used during the evaluation is set to 1
+* `cfg.env.num_envs`, i.e. the number of environments used during the evaluation, is set to 1
+* `cfg.fabric.devices` and `cfg.fabric.num_nodes` are set to 1
 
 > **Note**
 >
diff --git a/sheeprl/cli.py b/sheeprl/cli.py
index bc4300ce..a8f067ef 100644
--- a/sheeprl/cli.py
+++ b/sheeprl/cli.py
@@ -103,14 +103,9 @@ def eval_algorithm(cfg: DictConfig):
         cfg (DictConfig): the loaded configuration.
     """
     cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True))
-    accelerator = cfg.fabric.get("accelerator", "auto")
 
     # TODO: change the number of devices when FSDP will be supported
-    cfg.fabric.pop("devices", None)
-    cfg.fabric.pop("strategy", None)
-    cfg.fabric.pop("num_nodes", None)
-    cfg.fabric.pop("callbacks", None)
-    cfg.fabric.pop("accelerator", None)
+    accelerator = cfg.fabric.get("accelerator", "auto")
     fabric: Fabric = hydra.utils.instantiate(
         cfg.fabric, accelerator=accelerator, devices=1, num_nodes=1, _convert_="all"
     )
@@ -226,11 +221,16 @@ def evaluation(cfg: DictConfig):
     from omegaconf import open_dict
 
     with open_dict(cfg):
-        # Remove env related parameters
         capture_video = getattr(cfg.env, "capture_video", True)
         cfg.env = {"capture_video": capture_video, "num_envs": 1}
         cfg.exp = {}
         cfg.algo = {}
+        cfg.fabric = {
+            "devices": 1,
+            "num_nodes": 1,
+            "strategy": "auto",
+            "accelerator": getattr(cfg.fabric, "accelerator", "auto"),
+        }
 
         # Merge configs
         ckpt_cfg.merge_with(cfg)
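As a final side note on how the `Fabric` object gets built for evaluation: `hydra.utils.instantiate` constructs the object described by the `_target_` field of a config node, and keyword arguments passed to it override the corresponding config values. The snippet below is a standalone sketch, not SheepRL code; the `_target_` and the other fields are assumptions for illustration, since at evaluation time they come from the merged checkpoint config.

```python
# Standalone sketch of instantiating Fabric from a Hydra-style config node.
# The config contents are assumptions; keyword args to `instantiate` override them,
# which is how evaluation forces a single device on a single node.
import hydra
from lightning.fabric import Fabric
from omegaconf import OmegaConf

fabric_cfg = OmegaConf.create(
    {
        "_target_": "lightning.fabric.Fabric",
        "accelerator": "cpu",
        "precision": "32-true",
    }
)

fabric: Fabric = hydra.utils.instantiate(fabric_cfg, devices=1, num_nodes=1, _convert_="all")
```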