diff --git a/README.md b/README.md index 670651d1..86bb8425 100644 --- a/README.md +++ b/README.md @@ -223,6 +223,16 @@ python sheeprl.py fabric.accelerator=cpu fabric.strategy=ddp fabric.devices=2 ex You can check the available parameters for Lightning Fabric [here](https://lightning.ai/docs/fabric/stable/api/fabric_args.html). +### Evaluate your Agents + +You can easily evaluate your trained agents from checkpoints: training configurations are retrieved automatically. + +```bash +python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu env.capture_video=True +``` + +For more information, check the corresponding [howto](./howto/eval_your_agent.md). + ## :book: Repository structure The repository is structured as follows: diff --git a/howto/eval_your_agent.md b/howto/eval_your_agent.md new file mode 100644 index 00000000..6228face --- /dev/null +++ b/howto/eval_your_agent.md @@ -0,0 +1,124 @@ +# Evaluate your Agents + +In this document we give some advice on how to evaluate your agents. To evaluate an agent, simply run the `./sheeprl_eval.py` script, passing the path to the checkpoint of the agent you want to evaluate. + +```bash +python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt +``` + +The agent and the configs used during training are loaded automatically. Only a few parameters can be modified for evaluation: + +1. `fabric`-related ones: you can choose the accelerator used to evaluate the agent by specifying it in the command. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu` evaluates the agent on the GPU. If you want to select a specific GPU, define the `CUDA_VISIBLE_DEVICES` environment variable in the `.env` file or set it before running the script. For example, you can run the following command to evaluate your agent on the GPU with index 2: `CUDA_VISIBLE_DEVICES="2" python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu`. By default the number of devices and nodes is set to 1, while the precision and the plugins are taken from the checkpoint config. +2. `env.capture_video`: you can decide whether or not to capture a video of the evaluation episode. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt env.capture_video=True` captures a video of the evaluation. + +All the other parameters are loaded from the config file saved during training. Moreover, the following parameters are automatically set during the evaluation: + +* `cfg.env.num_envs`, i.e. the number of environments used during the evaluation, is set to 1 +* `cfg.fabric.devices` and `cfg.fabric.num_nodes` are set to 1 + +> **Note** +> +> You cannot modify the number of processes to spawn: the evaluation always runs with a single process. + +## Log Directory +By default the evaluation logs are stored in the same folder as the experiment. Suppose you have trained a `PPO` agent on the `CartPole-v1` environment.
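+For instance, such a run could have been launched with a command along these lines (an illustrative sketch only: the exact experiment and environment overrides depend on your configuration):
+
+```bash
+python sheeprl.py exp=ppo env=gym env.id=CartPole-v1
+```
+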
The log directory is organized as follows: +```tree +logs +└── runs + └── ppo + └── CartPole-v1 + └── 2023-10-27_11-46-05_default_42 + ├── .hydra + │ ├── config.yaml + │ ├── hydra.yaml + │ └── overrides.yaml + ├── cli.log + └── version_0 + ├── checkpoint + │ ├── ckpt_1024_0.ckpt + │ ├── ckpt_1536_0.ckpt + │ └── ckpt_512_0.ckpt + ├── events.out.tfevents.1698399966.72040.0 + ├── memmap_buffer + │ └── rank_0 + │ ├── actions.memmap + │ ├── actions.meta.pt + │ ├── advantages.memmap + │ ├── advantages.meta.pt + │ ├── dones.memmap + │ ├── dones.meta.pt + │ ├── logprobs.memmap + │ ├── logprobs.meta.pt + │ ├── meta.pt + │ ├── returns.memmap + │ ├── returns.meta.pt + │ ├── rewards.memmap + │ ├── rewards.meta.pt + │ ├── state.memmap + │ ├── state.meta.pt + │ ├── values.memmap + │ └── values.meta.pt + └── train_videos + ├── rl-video-episode-0.mp4 + ├── rl-video-episode-1.mp4 + └── rl-video-episode-8.mp4 +``` + +Where `./logs/runs/ppo/2023-10-27_11-46-05_default_42` contains your experiment. The evaluation script will create a subfolder, named `evaluation`, in the `./logs/runs/ppo/2023-10-27_11-46-05_default_42/version_0` folder, which will contain all the evaluations of the agents. + +For example, if we run two evaluations, then the log directory of the experiment will be as follows: +```diff +logs +└── runs + ├── .hydra + │ ├── config.yaml + │ ├── hydra.yaml + │ └── overrides.yaml + ├── cli.log + └── ppo + └── CartPole-v1 + └── 2023-10-27_11-46-05_default_42 + ├── .hydra + │ ├── config.yaml + │ ├── hydra.yaml + │ └── overrides.yaml + ├── cli.log + └── version_0 + ├── checkpoint + │ ├── ckpt_1024_0.ckpt + │ ├── ckpt_1536_0.ckpt + │ └── ckpt_512_0.ckpt ++ ├── evaluation ++ │ ├── version_0 ++ │ │ ├── events.out.tfevents.1698400212.73839.0 ++ │ │ └── test_videos ++ │ │ └── rl-video-episode-0.mp4 ++ │ └── version_1 ++ │ ├── events.out.tfevents.1698400283.74353.0 ++ │ └── test_videos ++ │ └── rl-video-episode-0.mp4 + ├── events.out.tfevents.1698399966.72040.0 + ├── memmap_buffer + │ └── rank_0 + │ ├── actions.memmap + │ ├── actions.meta.pt + │ ├── advantages.memmap + │ ├── advantages.meta.pt + │ ├── dones.memmap + │ ├── dones.meta.pt + │ ├── logprobs.memmap + │ ├── logprobs.meta.pt + │ ├── meta.pt + │ ├── returns.memmap + │ ├── returns.meta.pt + │ ├── rewards.memmap + │ ├── rewards.meta.pt + │ ├── state.memmap + │ ├── state.meta.pt + │ ├── values.memmap + │ └── values.meta.pt + └── train_videos + ├── rl-video-episode-0.mp4 + ├── rl-video-episode-1.mp4 + └── rl-video-episode-8.mp4 +``` \ No newline at end of file diff --git a/sheeprl/__init__.py b/sheeprl/__init__.py index 9d676f1e..bb3b8ae0 100644 --- a/sheeprl/__init__.py +++ b/sheeprl/__init__.py @@ -11,22 +11,35 @@ if not _IS_TORCH_GREATER_EQUAL_2_0: raise ModuleNotFoundError(_IS_TORCH_GREATER_EQUAL_2_0) -# Needed because MineRL 0.4.4 is not compatible with the latest version of numpy import numpy as np -from sheeprl.algos.dreamer_v1 import dreamer_v1 as dreamer_v1 -from sheeprl.algos.dreamer_v2 import dreamer_v2 as dreamer_v2 -from sheeprl.algos.dreamer_v3 import dreamer_v3 as dreamer_v3 -from sheeprl.algos.droq import droq as droq -from sheeprl.algos.p2e_dv1 import p2e_dv1 as p2e_dv1 -from sheeprl.algos.p2e_dv2 import p2e_dv2 as p2e_dv2 -from sheeprl.algos.ppo import ppo as ppo -from sheeprl.algos.ppo import ppo_decoupled as ppo_decoupled -from sheeprl.algos.ppo_recurrent import ppo_recurrent as ppo_recurrent -from sheeprl.algos.sac import sac as sac -from sheeprl.algos.sac import sac_decoupled as sac_decoupled -from sheeprl.algos.sac_ae import sac_ae 
as sac_ae +# fmt: off +from sheeprl.algos.dreamer_v1 import dreamer_v1 # noqa: F401 +from sheeprl.algos.dreamer_v2 import dreamer_v2 # noqa: F401 +from sheeprl.algos.dreamer_v3 import dreamer_v3 # noqa: F401 +from sheeprl.algos.droq import droq # noqa: F401 +from sheeprl.algos.p2e_dv1 import p2e_dv1 # noqa: F401 +from sheeprl.algos.p2e_dv2 import p2e_dv2 # noqa: F401 +from sheeprl.algos.ppo import ppo # noqa: F401 +from sheeprl.algos.ppo import ppo_decoupled # noqa: F401 +from sheeprl.algos.ppo_recurrent import ppo_recurrent # noqa: F401 +from sheeprl.algos.sac import sac # noqa: F401 +from sheeprl.algos.sac import sac_decoupled # noqa: F401 +from sheeprl.algos.sac_ae import sac_ae # noqa: F401 + +from sheeprl.algos.dreamer_v1 import evaluate as dreamer_v1_evaluate # noqa: F401, isort:skip +from sheeprl.algos.dreamer_v2 import evaluate as dreamer_v2_evaluate # noqa: F401, isort:skip +from sheeprl.algos.dreamer_v3 import evaluate as dreamer_v3_evaluate # noqa: F401, isort:skip +from sheeprl.algos.droq import evaluate as droq_evaluate # noqa: F401, isort:skip +from sheeprl.algos.p2e_dv1 import evaluate as p2e_dv1_evaluate # noqa: F401, isort:skip +from sheeprl.algos.p2e_dv2 import evaluate as p2e_dv2_evaluate # noqa: F401, isort:skip +from sheeprl.algos.ppo import evaluate as ppo_evaluate # noqa: F401, isort:skip +from sheeprl.algos.ppo_recurrent import evaluate as ppo_recurrent_evaluate # noqa: F401, isort:skip +from sheeprl.algos.sac import evaluate as sac_evaluate # noqa: F401, isort:skip +from sheeprl.algos.sac_ae import evaluate as sac_ae_evaluate # noqa: F401, isort:skip +# fmt: on +# Needed because MineRL 0.4.4 is not compatible with the latest version of numpy np.float = np.float32 np.int = np.int64 np.bool = bool diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index c5617898..2cb9abf1 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -445,7 +445,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actions_dim = ( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) - # observation_shape = observation_space["rgb"].shape clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r if not isinstance(observation_space, gym.spaces.Dict): raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py new file mode 100644 index 00000000..18c75004 --- /dev/null +++ b/sheeprl/algos/dreamer_v1/evaluate.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.dreamer_v1.agent import PlayerDV1, build_models +from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(algorithms="dreamer_v1") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + 
action_space = env.action_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) + ) + # Create the actor and critic models + world_model, actor, _ = build_models( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["world_model"], + state["actor"], + ) + player = PlayerDV1( + world_model.encoder.module, + world_model.rssm.recurrent_model.module, + world_model.rssm.representation_model.module, + actor.module, + actions_dim, + cfg.algo.player.expl_amount, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric.device, + ) + + test(player, fabric, cfg, log_dir, sample_actions=False) diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py new file mode 100644 index 00000000..3640b1c9 --- /dev/null +++ b/sheeprl/algos/dreamer_v2/evaluate.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.dreamer_v2.agent import PlayerDV2, build_models +from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(algorithms="dreamer_v2") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + action_space = env.action_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) + ) + # Create the actor and critic models + world_model, actor, _, _ = build_models( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["world_model"], + state["actor"], + ) + player = PlayerDV2( + world_model.encoder.module, + 
world_model.rssm.recurrent_model.module, + world_model.rssm.representation_model.module, + actor.module, + actions_dim, + cfg.algo.player.expl_amount, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric.device, + discrete_size=cfg.algo.world_model.discrete_size, + ) + + test(player, fabric, cfg, log_dir, sample_actions=False) diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py new file mode 100644 index 00000000..eeef7915 --- /dev/null +++ b/sheeprl/algos/dreamer_v3/evaluate.py @@ -0,0 +1,72 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_models +from sheeprl.algos.dreamer_v3.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(algorithms="dreamer_v3") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + action_space = env.action_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) + ) + # Create the actor and critic models + world_model, actor, _, _ = build_models( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["world_model"], + state["actor"], + ) + player = PlayerDV3( + world_model.encoder.module, + world_model.rssm, + actor.module, + actions_dim, + cfg.algo.player.expl_amount, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric.device, + discrete_size=cfg.algo.world_model.discrete_size, + ) + + test(player, fabric, cfg, log_dir, sample_actions=True) diff --git a/sheeprl/algos/droq/evaluate.py b/sheeprl/algos/droq/evaluate.py new file mode 100644 index 00000000..c241ab9f --- /dev/null +++ b/sheeprl/algos/droq/evaluate.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +from math import prod +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.droq.agent import DROQAgent, DROQCritic +from sheeprl.algos.sac.agent import SACActor +from sheeprl.algos.sac.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + 
+@register_evaluation(algorithms="droq") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + action_space = env.action_space + observation_space = env.observation_space + if not isinstance(action_space, gym.spaces.Box): + raise ValueError("Only continuous action space is supported for the DroQ agent") + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if len(cfg.mlp_keys.encoder) == 0: + raise RuntimeError("You should specify at least one MLP key for the encoder: `mlp_keys.encoder=[state]`") + for k in cfg.mlp_keys.encoder: + if len(observation_space[k].shape) > 1: + raise ValueError( + "Only environments with vector-only observations are supported by the DroQ agent. " + f"Provided environment: {cfg.env.id}" + ) + if cfg.metric.log_level > 0: + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + act_dim = prod(action_space.shape) + obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) + actor = SACActor( + observation_dim=obs_dim, + action_dim=act_dim, + distribution_cfg=cfg.distribution, + hidden_size=cfg.algo.actor.hidden_size, + action_low=action_space.low, + action_high=action_space.high, + ) + critics = [ + DROQCritic( + observation_dim=obs_dim + act_dim, + hidden_size=cfg.algo.critic.hidden_size, + num_critics=1, + dropout=cfg.algo.critic.dropout, + ) + for _ in range(cfg.algo.critic.n) + ] + target_entropy = -act_dim + agent = DROQAgent( + actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device + ) + agent.load_state_dict(state["agent"]) + + test(agent.actor, fabric, cfg, log_dir) diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py new file mode 100644 index 00000000..0dd7abb2 --- /dev/null +++ b/sheeprl/algos/p2e_dv1/evaluate.py @@ -0,0 +1,73 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.dreamer_v1.agent import PlayerDV1 +from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.algos.p2e_dv1.agent import build_models +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(algorithms="p2e_dv1") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + action_space = env.action_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or 
`mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) + ) + # Create the actor and critic models + world_model, actor_task, _, _, _ = build_models( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["world_model"], + state["actor_task"], + ) + player = PlayerDV1( + world_model.encoder.module, + world_model.rssm.recurrent_model.module, + world_model.rssm.representation_model.module, + actor_task.module, + actions_dim, + cfg.algo.player.expl_amount, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric.device, + ) + + test(player, fabric, cfg, log_dir, sample_actions=False) diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py new file mode 100644 index 00000000..8a6b1f99 --- /dev/null +++ b/sheeprl/algos/p2e_dv2/evaluate.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.dreamer_v2.agent import PlayerDV2 +from sheeprl.algos.dreamer_v2.utils import test +from sheeprl.algos.p2e_dv2.agent import build_models +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(algorithms="p2e_dv2") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + action_space = env.action_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(action_space, gym.spaces.Box) + is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) + ) + # Create the actor and critic models + world_model, actor_task, _, _, _, _, _ = build_models( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["world_model"], + state["actor_task"], + ) + player = PlayerDV2( + world_model.encoder.module, + world_model.rssm.recurrent_model.module, + world_model.rssm.representation_model.module, + actor_task.module, + actions_dim, + cfg.algo.player.expl_amount, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric.device, + 
discrete_size=cfg.algo.world_model.discrete_size, + ) + + test(player, fabric, cfg, log_dir, sample_actions=False) diff --git a/sheeprl/algos/ppo/evaluate.py b/sheeprl/algos/ppo/evaluate.py new file mode 100644 index 00000000..69dbf314 --- /dev/null +++ b/sheeprl/algos/ppo/evaluate.py @@ -0,0 +1,65 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.ppo.agent import PPOAgent +from sheeprl.algos.ppo.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(algorithms=["ppo", "ppo_decoupled"]) +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder + cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(env.action_space, gym.spaces.Box) + is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + env.action_space.shape + if is_continuous + else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + ) + # Create the actor and critic models + agent = PPOAgent( + actions_dim=actions_dim, + obs_space=observation_space, + encoder_cfg=cfg.algo.encoder, + actor_cfg=cfg.algo.actor, + critic_cfg=cfg.algo.critic, + cnn_keys=cfg.cnn_keys.encoder, + mlp_keys=cfg.mlp_keys.encoder, + screen_size=cfg.env.screen_size, + distribution_cfg=cfg.distribution, + is_continuous=is_continuous, + ) + agent.load_state_dict(state["agent"]) + + test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/ppo_recurrent/evaluate.py b/sheeprl/algos/ppo_recurrent/evaluate.py new file mode 100644 index 00000000..b70c1e24 --- /dev/null +++ b/sheeprl/algos/ppo_recurrent/evaluate.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent +from sheeprl.algos.ppo_recurrent.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(algorithms="ppo_recurrent") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + observation_space = env.observation_space + + if not 
isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder + cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + is_continuous = isinstance(env.action_space, gym.spaces.Box) + is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) + actions_dim = ( + env.action_space.shape + if is_continuous + else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) + ) + # Create the actor and critic models + agent = RecurrentPPOAgent( + actions_dim=actions_dim, + obs_space=observation_space, + encoder_cfg=cfg.algo.encoder, + rnn_cfg=cfg.algo.rnn, + actor_cfg=cfg.algo.actor, + critic_cfg=cfg.algo.critic, + cnn_keys=cfg.cnn_keys.encoder, + mlp_keys=cfg.mlp_keys.encoder, + is_continuous=is_continuous, + distribution_cfg=cfg.distribution, + num_envs=cfg.env.num_envs, + screen_size=cfg.env.screen_size, + device=fabric.device, + ) + agent.load_state_dict(state["agent"]) + + test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/sac/evaluate.py b/sheeprl/algos/sac/evaluate.py new file mode 100644 index 00000000..e952e60f --- /dev/null +++ b/sheeprl/algos/sac/evaluate.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from math import prod +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic +from sheeprl.algos.sac.utils import test +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(algorithms=["sac", "sac_decoupled"]) +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + action_space = env.action_space + observation_space = env.observation_space + if not isinstance(action_space, gym.spaces.Box): + raise ValueError("Only continuous action space is supported for the SAC agent") + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if len(cfg.mlp_keys.encoder) == 0: + raise RuntimeError("You should specify at least one MLP key for the encoder: `mlp_keys.encoder=[state]`") + for k in cfg.mlp_keys.encoder: + if len(observation_space[k].shape) > 1: + raise ValueError( + "Only environments with vector-only observations are supported by the SAC agent. 
" + f"Provided environment: {cfg.env.id}" + ) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + act_dim = prod(action_space.shape) + obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) + actor = SACActor( + observation_dim=obs_dim, + action_dim=act_dim, + distribution_cfg=cfg.distribution, + hidden_size=cfg.algo.actor.hidden_size, + action_low=action_space.low, + action_high=action_space.high, + ) + critics = [ + SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) + for _ in range(cfg.algo.critic.n) + ] + target_entropy = -act_dim + agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) + agent.load_state_dict(state["agent"]) + + test(agent.actor, fabric, cfg, log_dir) diff --git a/sheeprl/algos/sac_ae/evaluate.py b/sheeprl/algos/sac_ae/evaluate.py new file mode 100644 index 00000000..30a538cc --- /dev/null +++ b/sheeprl/algos/sac_ae/evaluate.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import copy +from math import prod +from typing import Any, Dict + +import gymnasium as gym +from lightning import Fabric + +from sheeprl.algos.sac_ae.agent import ( + CNNEncoder, + MLPEncoder, + SACAEAgent, + SACAEContinuousActor, + SACAECritic, + SACAEQFunction, +) +from sheeprl.algos.sac_ae.utils import test_sac_ae +from sheeprl.models.models import MultiEncoder +from sheeprl.utils.env import make_env +from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir +from sheeprl.utils.registry import register_evaluation + + +@register_evaluation(algorithms="sac_ae") +def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): + logger = create_tensorboard_logger(fabric, cfg) + if logger and fabric.is_global_zero: + fabric._loggers = [logger] + fabric.logger.log_hyperparams(cfg) + log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name) + + env = make_env( + cfg, + cfg.seed, + 0, + log_dir, + "test", + vector_env_idx=0, + )() + + observation_space = env.observation_space + action_space = env.action_space + if not isinstance(observation_space, gym.spaces.Dict): + raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}") + if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []: + raise RuntimeError( + "You should specify at least one CNN keys or MLP keys from the cli: " + "`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`" + ) + fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) + fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) + + act_dim = prod(action_space.shape) + target_entropy = -act_dim + + # Define the encoder and decoder and setup them with fabric. 
+ # Then we will set the critic encoder and actor decoder as the unwrapped encoder module: + # we do not need it wrapped with the strategy inside actor and critic + cnn_channels = [prod(observation_space[k].shape[:-2]) for k in cfg.cnn_keys.encoder] + mlp_dims = [observation_space[k].shape[0] for k in cfg.mlp_keys.encoder] + cnn_encoder = ( + CNNEncoder( + in_channels=sum(cnn_channels), + features_dim=cfg.algo.encoder.features_dim, + keys=cfg.cnn_keys.encoder, + screen_size=cfg.env.screen_size, + cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier, + ) + if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0 + else None + ) + mlp_encoder = ( + MLPEncoder( + sum(mlp_dims), + cfg.mlp_keys.encoder, + cfg.algo.encoder.dense_units, + cfg.algo.encoder.mlp_layers, + eval(cfg.algo.encoder.dense_act), + cfg.algo.encoder.layer_norm, + ) + if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0 + else None + ) + encoder = MultiEncoder(cnn_encoder, mlp_encoder) + + # Setup actor and critic. Those will initialize with orthogonal weights + # both the actor and critic + actor = SACAEContinuousActor( + encoder=copy.deepcopy(encoder), + action_dim=act_dim, + distribution_cfg=cfg.distribution, + hidden_size=cfg.algo.actor.hidden_size, + action_low=action_space.low, + action_high=action_space.high, + ) + qfs = [ + SACAEQFunction( + input_dim=encoder.output_dim, action_dim=act_dim, hidden_size=cfg.algo.critic.hidden_size, output_dim=1 + ) + for _ in range(cfg.algo.critic.n) + ] + critic = SACAECritic(encoder=encoder, qfs=qfs) + + # The agent will tied convolutional and linear weights between the encoder actor and critic + agent = SACAEAgent( + actor, + critic, + target_entropy, + alpha=cfg.algo.alpha.alpha, + tau=cfg.algo.tau, + encoder_tau=cfg.algo.encoder.tau, + device=fabric.device, + ) + agent.load_state_dict(state["agent"]) + encoder.load_state_dict(state["encoder"]) + + test_sac_ae(agent.actor, fabric, cfg, log_dir) diff --git a/sheeprl/available_agents.py b/sheeprl/available_agents.py index d279e147..8c815b52 100644 --- a/sheeprl/available_agents.py +++ b/sheeprl/available_agents.py @@ -2,7 +2,7 @@ from rich.console import Console from rich.table import Table - from sheeprl.utils.registry import tasks + from sheeprl.utils.registry import algorithm_registry table = Table(title="SheepRL Agents") table.add_column("Module") @@ -10,7 +10,7 @@ table.add_column("Entrypoint") table.add_column("Decoupled") - for module, implementations in tasks.items(): + for module, implementations in algorithm_registry.items(): for algo in implementations: table.add_row(module, algo["name"], algo["entrypoint"], str(algo["decoupled"])) diff --git a/sheeprl/cli.py b/sheeprl/cli.py index 251ff4ab..f274d3f7 100644 --- a/sheeprl/cli.py +++ b/sheeprl/cli.py @@ -4,6 +4,7 @@ import pathlib import time import warnings +from pathlib import Path from typing import Any, Dict import hydra @@ -13,7 +14,7 @@ from omegaconf import DictConfig, OmegaConf from sheeprl.utils.metric import MetricAggregator -from sheeprl.utils.registry import tasks +from sheeprl.utils.registry import algorithm_registry, evaluation_registry from sheeprl.utils.timer import timer from sheeprl.utils.utils import dotdict, print_config @@ -52,7 +53,7 @@ def run_algorithm(cfg: Dict[str, Any]): decoupled = False entrypoint = None algo_name = cfg.algo.name - for _module, _algos in tasks.items(): + for _module, _algos in algorithm_registry.items(): for _algo in _algos: if algo_name == _algo["name"]: module = _module @@ -115,6 
+116,48 @@ def run_algorithm(cfg: Dict[str, Any]): fabric.launch(command, cfg) +def eval_algorithm(cfg: DictConfig): + """Run the algorithm specified in the configuration. + + Args: + cfg (DictConfig): the loaded configuration. + """ + cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)) + + # TODO: change the number of devices when FSDP will be supported + accelerator = cfg.fabric.get("accelerator", "auto") + fabric: Fabric = hydra.utils.instantiate( + cfg.fabric, accelerator=accelerator, devices=1, num_nodes=1, _convert_="all" + ) + + # Load the checkpoint + state = fabric.load(cfg.checkpoint_path) + + # Given the algorithm's name, retrieve the module where + # 'cfg.algo.name'.py is contained; from there retrieve the + # `register_algorithm`-decorated entrypoint; + # the entrypoint will be launched by Fabric with `fabric.launch(entrypoint)` + module = None + entrypoint = None + algo_name = cfg.algo.name + for _module, _algos in evaluation_registry.items(): + for _algo in _algos: + if algo_name == _algo["name"]: + module = _module + entrypoint = _algo["entrypoint"] + break + if module is None: + raise RuntimeError(f"Given the algorithm named `{algo_name}`, no module has been found to be imported.") + if entrypoint is None: + raise RuntimeError( + f"Given the module and algorithm named `{module}` and `{algo_name}` respectively, " + "no entrypoint has been found to be imported." + ) + task = importlib.import_module(f"{module}.evaluate") + command = task.__dict__[entrypoint] + fabric.launch(command, cfg, state) + + def check_configs(cfg: Dict[str, Any]): """Check the validity of the configuration. @@ -123,7 +166,7 @@ def check_configs(cfg: Dict[str, Any]): """ decoupled = False algo_name = cfg.algo.name - for _, _algos in tasks.items(): + for _, _algos in algorithm_registry.items(): for _algo in _algos: if algo_name == _algo["name"]: decoupled = _algo["decoupled"] @@ -174,13 +217,55 @@ def check_configs(cfg: Dict[str, Any]): ) +def check_configs_evaluation(cfg: DictConfig): + if cfg.checkpoint_path is None: + raise ValueError("You must specify the evaluation checkpoint path") + + @hydra.main(version_base="1.13", config_path="configs", config_name="config") def run(cfg: DictConfig): """SheepRL zero-code command line utility.""" - if cfg.metric.log_level > 0: - print_config(cfg) + print_config(cfg) cfg = dotdict(OmegaConf.to_container(cfg, resolve=True, throw_on_missing=True)) if cfg.checkpoint.resume_from: cfg = resume_from_checkpoint(cfg) check_configs(cfg) run_algorithm(cfg) + + +@hydra.main(version_base="1.13", config_path="configs", config_name="eval_config") +def evaluation(cfg: DictConfig): + # Load the checkpoint configuration + checkpoint_path = Path(cfg.checkpoint_path) + ckpt_cfg = OmegaConf.load(checkpoint_path.parent.parent.parent / ".hydra" / "config.yaml") + + # Merge the two configs + from omegaconf import open_dict + + with open_dict(cfg): + capture_video = getattr(cfg.env, "capture_video", True) + cfg.env = {"capture_video": capture_video, "num_envs": 1} + cfg.exp = {} + cfg.algo = {} + cfg.fabric = { + "devices": 1, + "num_nodes": 1, + "strategy": "auto", + "accelerator": getattr(cfg.fabric, "accelerator", "auto"), + } + + # Merge configs + ckpt_cfg.merge_with(cfg) + + # Update values after merge + ckpt_cfg.run_name = str( + os.path.join( + os.path.basename(checkpoint_path.parent.parent.parent), + os.path.basename(checkpoint_path.parent.parent), + "evaluation", + ) + ) + + # Check the validity of the configuration and run the evaluation + 
check_configs_evaluation(ckpt_cfg) + eval_algorithm(ckpt_cfg) diff --git a/sheeprl/configs/eval_config.yaml b/sheeprl/configs/eval_config.yaml new file mode 100644 index 00000000..2eeba61a --- /dev/null +++ b/sheeprl/configs/eval_config.yaml @@ -0,0 +1,20 @@ +# @package _global_ + +# specify here default training configuration +defaults: + - _self_ + - override hydra/hydra_logging: disabled + - override hydra/job_logging: disabled + +hydra: + output_subdir: null + run: + dir: . + +fabric: + accelerator: cpu + +env: + capture_video: True + +checkpoint_path: ??? diff --git a/sheeprl/configs/exp/dreamer_v1.yaml b/sheeprl/configs/exp/dreamer_v1.yaml index 436200d3..c2e07d16 100644 --- a/sheeprl/configs/exp/dreamer_v1.yaml +++ b/sheeprl/configs/exp/dreamer_v1.yaml @@ -2,7 +2,7 @@ defaults: - override /algo: dreamer_v1 - - override /env: dmc + - override /env: atari - _self_ # Experiment diff --git a/sheeprl/utils/registry.py b/sheeprl/utils/registry.py index a914e2a2..ab28fd79 100644 --- a/sheeprl/utils/registry.py +++ b/sheeprl/utils/registry.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import sys from typing import Any, Callable, Dict, List @@ -6,10 +8,11 @@ # tasks[module] = [..., {"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}] # where `module` and `algorithm` are respectively taken from sheeprl/algos/{module}/{algorithm}.py, # while `entrypoint` is the decorated function -tasks: Dict[str, List[Dict[str, Any]]] = {} +algorithm_registry: Dict[str, List[Dict[str, Any]]] = {} +evaluation_registry: Dict[str, List[Dict[str, Any]]] = {} -def _register(fn: Callable[..., Any], decoupled: bool = False) -> Callable[..., Any]: +def _register_algorithm(fn: Callable[..., Any], decoupled: bool = False) -> Callable[..., Any]: # lookup containing module if fn.__module__ == "__main__": return fn @@ -17,13 +20,61 @@ def _register(fn: Callable[..., Any], decoupled: bool = False) -> Callable[..., module_split = fn.__module__.split(".") algorithm = module_split[-1] module = ".".join(module_split[:-1]) - algos = tasks.get(module, None) - if algos is None: - tasks[module] = [{"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}] + registered_algos = algorithm_registry.get(module, None) + if registered_algos is None: + algorithm_registry[module] = [{"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}] else: - if algorithm in algos: - raise ValueError(f"The algorithm `{algorithm}` has already been registered!") - tasks[module].append({"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}) + algorithm_registry[module].append({"name": algorithm, "entrypoint": entrypoint, "decoupled": decoupled}) + + # add the decorated function to __all__ in algorithm + mod = sys.modules[fn.__module__] + if hasattr(mod, "__all__"): + mod.__all__.append(entrypoint) + else: + mod.__all__ = [entrypoint] + return fn + + +def _register_evaluation(fn: Callable[..., Any], algorithms: str | List[str]) -> Callable[..., Any]: + # lookup containing module + if fn.__module__ == "__main__": + return fn + entrypoint = fn.__name__ + module_split = fn.__module__.split(".") + module = ".".join(module_split[:-1]) + if isinstance(algorithms, str): + algorithms = [algorithms] + # Check that the algorithms which we want to register an evaluation function for + # have been registered as algorithms + registered_algos = algorithm_registry.get(module, None) + if registered_algos is None: + raise ValueError( + f"The evaluation function `{module+'.'+entrypoint}` for the 
algorithms named `{', '.join(algorithms)}` " + "is going to be registered, but no algorithm has been registered!" + ) + registered_algo_names = {algo["name"] for algo in registered_algos} + if len(set(algorithms) - registered_algo_names) > 0: + raise ValueError( + f"You are trying to register the evaluation function `{module+'.'+entrypoint}` " + f"for algorithms which have not been registered for the module `{module}`!\n" + f"Registered algorithms: {', '.join(registered_algo_names)}\n" + f"Specified algorithms: {', '.join(algorithms)}" + ) + registered_evals = evaluation_registry.get(module, None) + if registered_evals is None: + evaluation_registry[module] = [] + for algorithm in algorithms: + evaluation_registry[module].append({"name": algorithm, "entrypoint": entrypoint}) + else: + for registered_eval in registered_evals: + if registered_eval["name"] in algorithms: + raise ValueError( + f"Cannot register the evaluate function `{module+'.'+entrypoint}` " + f"for the algorithm `{registered_eval['name']}`: " + f"the evaluation function `{module+'.'+registered_eval['entrypoint']}` has already " + f"been registered for the algorithm named `{registered_eval['name']}` in the module `{module}`!" + ) + evaluation_registry[module].extend([{"name": algorithm, "entrypoint": entrypoint} for algorithm in algorithms]) # add the decorated function to __all__ in algorithm mod = sys.modules[fn.__module__] @@ -36,6 +87,13 @@ def _register(fn: Callable[..., Any], decoupled: bool = False) -> Callable[..., def register_algorithm(decoupled: bool = False): def inner_decorator(fn): - return _register(fn, decoupled=decoupled) + return _register_algorithm(fn, decoupled=decoupled) + + return inner_decorator + + +def register_evaluation(algorithms: str | List[str]): + def inner_decorator(fn): + return _register_evaluation(fn, algorithms=algorithms) return inner_decorator diff --git a/sheeprl_eval.py b/sheeprl_eval.py new file mode 100644 index 00000000..12ce6693 --- /dev/null +++ b/sheeprl_eval.py @@ -0,0 +1,4 @@ +from sheeprl.cli import evaluation + +if __name__ == "__main__": + evaluation()
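For reference, the new `register_evaluation` hook is meant to be used the same way for every algorithm: the evaluation entrypoint lives in `sheeprl/algos/<module>/evaluate.py`, has the signature `evaluate(fabric, cfg, state)` where `state` is the checkpoint loaded through `fabric.load(cfg.checkpoint_path)`, and can only be registered for an algorithm that was previously registered with `register_algorithm`. A minimal sketch for a hypothetical custom algorithm named `my_algo` (the `my_algo` names are illustrative and not part of this changeset) could look like the following:

```python
# sheeprl/algos/my_algo/evaluate.py -- hypothetical file, for illustration only.
# Assumes `my_algo` has already been registered with @register_algorithm in
# sheeprl/algos/my_algo/my_algo.py; otherwise register_evaluation raises a ValueError.
from __future__ import annotations

from typing import Any, Dict

from lightning import Fabric

from sheeprl.utils.env import make_env
from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir
from sheeprl.utils.registry import register_evaluation


@register_evaluation(algorithms="my_algo")
def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
    # `state` is the dictionary loaded from the checkpoint by `eval_algorithm`.
    logger = create_tensorboard_logger(fabric, cfg)
    if logger and fabric.is_global_zero:
        fabric._loggers = [logger]
        fabric.logger.log_hyperparams(cfg)
    log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name)

    # Create a single test environment, rebuild the agent from `state`
    # (e.g. from `state["agent"]`), and run the algorithm's test utility on it.
    env = make_env(cfg, cfg.seed, 0, log_dir, "test", vector_env_idx=0)()
    ...
```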