Feature/evaluate agents #139

Merged
merged 15 commits on Oct 31, 2023
10 changes: 10 additions & 0 deletions README.md
@@ -223,6 +223,16 @@ python sheeprl.py fabric.accelerator=cpu fabric.strategy=ddp fabric.devices=2 ex

You can check the available parameters for Lightning Fabric [here](https://lightning.ai/docs/fabric/stable/api/fabric_args.html).

### Evaluate your Agents

You can easily evaluate your trained agents from checkpoints: training configurations are retrieved automatically.

```bash
python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu env.capture_video=True
```

For more information, check the corresponding [howto](./howto/eval_your_agent.md).

## :book: Repository structure

The repository is structured as follows:
124 changes: 124 additions & 0 deletions howto/eval_your_agent.md
@@ -0,0 +1,124 @@
# Evaluate your Agents

In this document we give the user some advice on how to evaluate their agents. To evaluate an agent, simply run the `./sheeprl_eval.py` script, passing the path to the checkpoint of the agent you want to evaluate.

```bash
python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt
```

The agent and the configs used during the training are loaded automatically. The user can modify only a few parameters for evaluation:

1. `fabric`-related ones: you can use whichever accelerator you want for evaluating the agent, you just need to specify it on the command line. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu` evaluates the agent on the GPU. If you want to choose a specific GPU, define the `CUDA_VISIBLE_DEVICES` environment variable in the `.env` file or set it before running the script. For example, you can execute the following command to evaluate your agent on the GPU with index 2: `CUDA_VISIBLE_DEVICES="2" python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt fabric.accelerator=gpu`. By default the number of devices and nodes is set to 1, while the precision and the plugins are taken from the checkpoint config.
2. `env.capture_video`: you can decide whether or not to capture a video of the episode during the evaluation. For instance, `python sheeprl_eval.py checkpoint_path=/path/to/checkpoint.ckpt env.capture_video=True` captures a video of the evaluation. A combined example is shown right after this list.
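
Both overrides can be combined in a single command; a minimal illustration (the checkpoint path is a placeholder):

```bash
# Evaluate on the GPU with index 2 and record a video of the evaluation episode.
CUDA_VISIBLE_DEVICES="2" python sheeprl_eval.py \
    checkpoint_path=/path/to/checkpoint.ckpt \
    fabric.accelerator=gpu \
    env.capture_video=True
```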

All the other parameters are loaded from the checkpoint config file used during the training. Moreover, the following parameters are automatically set during the evaluation (a sketch of these overrides is given after the note below):

* `cfg.env.num_envs`, i.e. the number of environments used during the evaluation, is set to 1
* `cfg.fabric.devices` and `cfg.fabric.num_nodes` are set to 1

> **Note**
>
> You cannot modify the number of processes to spawn: the evaluation is run with a single process.
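
As a reference only, here is a minimal sketch (not the actual implementation) of how these evaluation-time overrides could be applied to the Hydra config saved with the experiment; the experiment path is illustrative:

```python
from omegaconf import OmegaConf

# Load the training config that Hydra saved inside the experiment folder (illustrative path).
cfg = OmegaConf.load(
    "logs/runs/ppo/CartPole-v1/2023-10-27_11-46-05_default_42/.hydra/config.yaml"
)

# Overrides applied automatically at evaluation time (see the list above).
cfg.env.num_envs = 1
cfg.fabric.devices = 1
cfg.fabric.num_nodes = 1

print(OmegaConf.to_yaml(cfg.fabric))
```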

## Log Directory
By default the evaluation logs are stored in the same folder as the experiment. Suppose you have trained a `PPO` agent in the `CartPole-v1` environment. The log directory is organized as follows:
```tree
logs
└── runs
└── ppo
└── CartPole-v1
└── 2023-10-27_11-46-05_default_42
├── .hydra
│ ├── config.yaml
│ ├── hydra.yaml
│ └── overrides.yaml
├── cli.log
└── version_0
├── checkpoint
│ ├── ckpt_1024_0.ckpt
│ ├── ckpt_1536_0.ckpt
│ └── ckpt_512_0.ckpt
├── events.out.tfevents.1698399966.72040.0
├── memmap_buffer
│ └── rank_0
│ ├── actions.memmap
│ ├── actions.meta.pt
│ ├── advantages.memmap
│ ├── advantages.meta.pt
│ ├── dones.memmap
│ ├── dones.meta.pt
│ ├── logprobs.memmap
│ ├── logprobs.meta.pt
│ ├── meta.pt
│ ├── returns.memmap
│ ├── returns.meta.pt
│ ├── rewards.memmap
│ ├── rewards.meta.pt
│ ├── state.memmap
│ ├── state.meta.pt
│ ├── values.memmap
│ └── values.meta.pt
└── train_videos
├── rl-video-episode-0.mp4
├── rl-video-episode-1.mp4
└── rl-video-episode-8.mp4
```

Where `./logs/runs/ppo/CartPole-v1/2023-10-27_11-46-05_default_42` contains your experiment. The evaluation script will create a subfolder named `evaluation` in the `./logs/runs/ppo/CartPole-v1/2023-10-27_11-46-05_default_42/version_0` folder, which will contain all the evaluations of the agent.

For example, if we run two evaluations, then the log directory of the experiment will be as follows:
```diff
logs
└── runs
├── .hydra
│ ├── config.yaml
│ ├── hydra.yaml
│ └── overrides.yaml
├── cli.log
└── ppo
└── CartPole-v1
└── 2023-10-27_11-46-05_default_42
├── .hydra
│ ├── config.yaml
│ ├── hydra.yaml
│ └── overrides.yaml
├── cli.log
└── version_0
├── checkpoint
│ ├── ckpt_1024_0.ckpt
│ ├── ckpt_1536_0.ckpt
│ └── ckpt_512_0.ckpt
+ ├── evaluation
+ │ ├── version_0
+ │ │ ├── events.out.tfevents.1698400212.73839.0
+ │ │ └── test_videos
+ │ │ └── rl-video-episode-0.mp4
+ │ └── version_1
+ │ ├── events.out.tfevents.1698400283.74353.0
+ │ └── test_videos
+ │ └── rl-video-episode-0.mp4
├── events.out.tfevents.1698399966.72040.0
├── memmap_buffer
│ └── rank_0
│ ├── actions.memmap
│ ├── actions.meta.pt
│ ├── advantages.memmap
│ ├── advantages.meta.pt
│ ├── dones.memmap
│ ├── dones.meta.pt
│ ├── logprobs.memmap
│ ├── logprobs.meta.pt
│ ├── meta.pt
│ ├── returns.memmap
│ ├── returns.meta.pt
│ ├── rewards.memmap
│ ├── rewards.meta.pt
│ ├── state.memmap
│ ├── state.meta.pt
│ ├── values.memmap
│ └── values.meta.pt
└── train_videos
├── rl-video-episode-0.mp4
├── rl-video-episode-1.mp4
└── rl-video-episode-8.mp4
```
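
Once an evaluation has been run, the logged metrics can be inspected, for instance, with TensorBoard (assuming it is installed; the path below matches the illustrative experiment above):

```bash
# Point TensorBoard at the evaluation subfolder of the experiment.
tensorboard --logdir logs/runs/ppo/CartPole-v1/2023-10-27_11-46-05_default_42/version_0/evaluation
```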
39 changes: 26 additions & 13 deletions sheeprl/__init__.py
@@ -11,22 +11,35 @@
if not _IS_TORCH_GREATER_EQUAL_2_0:
raise ModuleNotFoundError(_IS_TORCH_GREATER_EQUAL_2_0)

# Needed because MineRL 0.4.4 is not compatible with the latest version of numpy
import numpy as np

from sheeprl.algos.dreamer_v1 import dreamer_v1 as dreamer_v1
from sheeprl.algos.dreamer_v2 import dreamer_v2 as dreamer_v2
from sheeprl.algos.dreamer_v3 import dreamer_v3 as dreamer_v3
from sheeprl.algos.droq import droq as droq
from sheeprl.algos.p2e_dv1 import p2e_dv1 as p2e_dv1
from sheeprl.algos.p2e_dv2 import p2e_dv2 as p2e_dv2
from sheeprl.algos.ppo import ppo as ppo
from sheeprl.algos.ppo import ppo_decoupled as ppo_decoupled
from sheeprl.algos.ppo_recurrent import ppo_recurrent as ppo_recurrent
from sheeprl.algos.sac import sac as sac
from sheeprl.algos.sac import sac_decoupled as sac_decoupled
from sheeprl.algos.sac_ae import sac_ae as sac_ae
# fmt: off
from sheeprl.algos.dreamer_v1 import dreamer_v1 # noqa: F401
from sheeprl.algos.dreamer_v2 import dreamer_v2 # noqa: F401
from sheeprl.algos.dreamer_v3 import dreamer_v3 # noqa: F401
from sheeprl.algos.droq import droq # noqa: F401
from sheeprl.algos.p2e_dv1 import p2e_dv1 # noqa: F401
from sheeprl.algos.p2e_dv2 import p2e_dv2 # noqa: F401
from sheeprl.algos.ppo import ppo # noqa: F401
from sheeprl.algos.ppo import ppo_decoupled # noqa: F401
from sheeprl.algos.ppo_recurrent import ppo_recurrent # noqa: F401
from sheeprl.algos.sac import sac # noqa: F401
from sheeprl.algos.sac import sac_decoupled # noqa: F401
from sheeprl.algos.sac_ae import sac_ae # noqa: F401

from sheeprl.algos.dreamer_v1 import evaluate as dreamer_v1_evaluate # noqa: F401, isort:skip
from sheeprl.algos.dreamer_v2 import evaluate as dreamer_v2_evaluate # noqa: F401, isort:skip
from sheeprl.algos.dreamer_v3 import evaluate as dreamer_v3_evaluate # noqa: F401, isort:skip
from sheeprl.algos.droq import evaluate as droq_evaluate # noqa: F401, isort:skip
from sheeprl.algos.p2e_dv1 import evaluate as p2e_dv1_evaluate # noqa: F401, isort:skip
from sheeprl.algos.p2e_dv2 import evaluate as p2e_dv2_evaluate # noqa: F401, isort:skip
from sheeprl.algos.ppo import evaluate as ppo_evaluate # noqa: F401, isort:skip
from sheeprl.algos.ppo_recurrent import evaluate as ppo_recurrent_evaluate # noqa: F401, isort:skip
from sheeprl.algos.sac import evaluate as sac_evaluate # noqa: F401, isort:skip
from sheeprl.algos.sac_ae import evaluate as sac_ae_evaluate # noqa: F401, isort:skip
# fmt: on

# Needed because MineRL 0.4.4 is not compatible with the latest version of numpy
np.float = np.float32
np.int = np.int64
np.bool = bool
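
The `evaluate` modules are imported above so that their `@register_evaluation(...)` decorators run and each algorithm's evaluation entry point gets registered. As a hedged illustration only (the real registry lives in `sheeprl.utils.registry` and may differ), such a decorator can be sketched as:

```python
from typing import Any, Callable, Dict

# Hypothetical registry mapping an algorithm name to its evaluation function.
evaluation_registry: Dict[str, Callable[..., Any]] = {}


def register_evaluation(algorithms: str) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Decorator that registers an evaluation function under the given algorithm name."""

    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        evaluation_registry[algorithms] = fn  # e.g. "dreamer_v1" -> evaluate
        return fn

    return decorator
```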
1 change: 0 additions & 1 deletion sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -445,7 +445,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
actions_dim = (
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# observation_shape = observation_space["rgb"].shape
clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r
if not isinstance(observation_space, gym.spaces.Dict):
raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
72 changes: 72 additions & 0 deletions sheeprl/algos/dreamer_v1/evaluate.py
@@ -0,0 +1,72 @@
from __future__ import annotations

from typing import Any, Dict

import gymnasium as gym
from lightning import Fabric

from sheeprl.algos.dreamer_v1.agent import PlayerDV1, build_models
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir
from sheeprl.utils.registry import register_evaluation


@register_evaluation(algorithms="dreamer_v1")
def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
logger = create_tensorboard_logger(fabric, cfg)
if logger and fabric.is_global_zero:
fabric._loggers = [logger]
fabric.logger.log_hyperparams(cfg)
log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name)

env = make_env(
cfg,
cfg.seed,
0,
log_dir,
"test",
vector_env_idx=0,
)()
observation_space = env.observation_space
action_space = env.action_space

if not isinstance(observation_space, gym.spaces.Dict):
raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []:
raise RuntimeError(
"You should specify at least one CNN keys or MLP keys from the cli: "
"`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`"
)
fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder)
fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder)

is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
actions_dim = (
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Build the world model and the actor (the critic is not needed for evaluation)
world_model, actor, _ = build_models(
fabric,
actions_dim,
is_continuous,
cfg,
observation_space,
state["world_model"],
state["actor"],
)
player = PlayerDV1(
world_model.encoder.module,
world_model.rssm.recurrent_model.module,
world_model.rssm.representation_model.module,
actor.module,
actions_dim,
cfg.algo.player.expl_amount,
cfg.env.num_envs,
cfg.algo.world_model.stochastic_size,
cfg.algo.world_model.recurrent_model.recurrent_state_size,
fabric.device,
)

test(player, fabric, cfg, log_dir, sample_actions=False)
73 changes: 73 additions & 0 deletions sheeprl/algos/dreamer_v2/evaluate.py
@@ -0,0 +1,73 @@
from __future__ import annotations

from typing import Any, Dict

import gymnasium as gym
from lightning import Fabric

from sheeprl.algos.dreamer_v2.agent import PlayerDV2, build_models
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import create_tensorboard_logger, get_log_dir
from sheeprl.utils.registry import register_evaluation


@register_evaluation(algorithms="dreamer_v2")
def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
logger = create_tensorboard_logger(fabric, cfg)
if logger and fabric.is_global_zero:
fabric._loggers = [logger]
fabric.logger.log_hyperparams(cfg)
log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name)

env = make_env(
cfg,
cfg.seed,
0,
log_dir,
"test",
vector_env_idx=0,
)()
observation_space = env.observation_space
action_space = env.action_space

if not isinstance(observation_space, gym.spaces.Dict):
raise RuntimeError(f"Unexpected observation type, should be of type Dict, got: {observation_space}")
if cfg.cnn_keys.encoder == [] and cfg.mlp_keys.encoder == []:
raise RuntimeError(
"You should specify at least one CNN keys or MLP keys from the cli: "
"`cnn_keys.encoder=[rgb]` or `mlp_keys.encoder=[state]`"
)
fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder)
fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder)

is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
actions_dim = (
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Build the world model and the actor (the critics are not needed for evaluation)
world_model, actor, _, _ = build_models(
fabric,
actions_dim,
is_continuous,
cfg,
observation_space,
state["world_model"],
state["actor"],
)
player = PlayerDV2(
world_model.encoder.module,
world_model.rssm.recurrent_model.module,
world_model.rssm.representation_model.module,
actor.module,
actions_dim,
cfg.algo.player.expl_amount,
cfg.env.num_envs,
cfg.algo.world_model.stochastic_size,
cfg.algo.world_model.recurrent_model.recurrent_state_size,
fabric.device,
discrete_size=cfg.algo.world_model.discrete_size,
)

test(player, fabric, cfg, log_dir, sample_actions=False)