feat: added super mario bros environment (#204)
* feat: added super mario bros environment

* fix: dependencies

* docs: update

* docs: update readme

* fix: remove configs

* fix: super mario wrapper
michele-milesi authored Feb 11, 2024
1 parent 8045b2e commit 7de395f
Showing 9 changed files with 250 additions and 8 deletions.
18 changes: 11 additions & 7 deletions README.md
@@ -211,6 +211,7 @@ The environments supported by sheeprl are:
| MineDojo | `pip install -e .[minedojo]` | [how_to/minedojo](./howto/learn_in_minedojo.md) | :heavy_check_mark: |
| DIAMBRA | `pip install -e .[diambra]` | [how_to/diambra](./howto/learn_in_diambra.md) | :heavy_check_mark: |
| Crafter | `pip install -e .[crafter]` | https://github.com/danijar/crafter | :heavy_check_mark: |
| Super Mario Bros | `pip install -e .[supermario]` | https://github.com/Kautenja/gym-super-mario-bros/tree/master | :heavy_check_mark: |


## Why
@@ -262,18 +263,21 @@ source .venv/bin/activate
# if you do not wish to install extras such as mujoco or atari, do
pip install "sheeprl @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
# or, to install with atari and mujoco environment support, do
pip install "sheeprl[atari,mujoco,dev] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
pip install "sheeprl[atari,mujoco,dev] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
# or, to install with box2d environment support, do
pip install swig
pip install "sheeprl[box2d] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
pip install "sheeprl[box2d] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
# or, to install with minedojo environment support, do
pip install "sheeprl[minedojo,dev] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
pip install "sheeprl[minedojo,dev] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
# or, to install with minerl environment support, do
pip install "sheeprl[minerl,dev] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
pip install "sheeprl[minerl,dev] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
# or, to install with diambra environment support, do
pip install "sheeprl[diambra,dev] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
pip install "sheeprl[diambra,dev] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
# or, to install with super mario bros environment support, do
pip install "sheeprl[supermario,dev] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
# or, to install all extras, do
pip install "sheeprl[atari,mujoco,miedojo,dev,test] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
pip install swig
pip install "sheeprl[box2d,atari,mujoco,minerl,supermario,dev,test] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
```

</details>
@@ -300,7 +304,7 @@ pip install "sheeprl[atari,box2d,mujoco,dev,test] @ git+https://github.com/Eclectic-Sheep/sheeprl.git"
</details>

<details>
<summary>MineRL and MineDojo</summary>

> [!NOTE]
>
2 changes: 1 addition & 1 deletion howto/learn_in_atari.md
@@ -24,5 +24,5 @@ The list of selectable algorithms is given below:
Once you have chosen the algorithm you want to train, you can start training; for instance, launch the PPO agent with:

```bash
python sheeprl.py exp=ppo env=atari env.id=PongNoFrameskip-v4 algo.cnn_keys.encoder=[rgb] fabric.accelerator=cpu fabric.strategy=ddp fabric.devices=2 algo.mlp_keys.encoder=[]
```
55 changes: 55 additions & 0 deletions howto/learn_in_supermario.md
@@ -0,0 +1,55 @@
## Install Super Mario Bros environment
First, install the Super Mario Bros environment with:

```bash
pip install .[supermario]
```

For more information: https://github.com/Kautenja/gym-super-mario-bros/tree/master

## Environment Config
The default configurations of the Super Mario Bros environment are in the `./sheeprl/configs/env/super_mario_bros.yaml` file.

```yaml
defaults:
  - default
  - _self_

# Override from `default` config
id: SuperMarioBros-v0
frame_stack: 1
sync_env: False
action_repeat: 1

# Wrapper to be instantiated
wrapper:
  _target_: sheeprl.envs.super_mario_bros.SuperMarioBrosWrapper
  id: ${env.id}
  action_space: simple # or complex or right_only
  render_mode: rgb_array
```
The parameters under the `wrapper` key are explained below:
- `id`: The id of the environment; check [here](https://github.com/Kautenja/gym-super-mario-bros/tree/master) for the environments that can be instantiated.
- `action_space`: The set of actions available to the agent (always discrete). The possible options are `simple`, `right_only`, or `complex`; see [here](https://github.com/Kautenja/gym-super-mario-bros/blob/bcb8f10c3e3676118a7364a68f5c0eb287116d7a/gym_super_mario_bros/actions.py) for the differences between them.
- `render_mode`: either `rgb_array` or `human`.
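
As a quick sanity check, the wrapper can also be used on its own. Below is a minimal sketch (assuming the `supermario` extra is installed); the `rgb` observation key and the five-element `step` return match the `SuperMarioBrosWrapper` implementation shipped in `sheeprl/envs/super_mario_bros.py`:

```python
# Minimal usage sketch (assumes `pip install .[supermario]`): observations are
# dicts with an "rgb" key and `step` follows the 5-tuple gymnasium API.
import numpy as np

from sheeprl.envs.super_mario_bros import SuperMarioBrosWrapper

env = SuperMarioBrosWrapper(id="SuperMarioBros-v0", action_space="simple", render_mode="rgb_array")
obs, info = env.reset()
print(obs["rgb"].shape)  # the raw NES frame, typically (240, 256, 3)

action = env.action_space.sample()  # a discrete action index
obs, reward, done, truncated, info = env.step(action)
frame: np.ndarray = env.render()  # RGB frame, since render_mode="rgb_array"
```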


## Train your agent

Not all algorithms work with image observations, so check the first table in the [README](../README.md) and select a suitable one.
The list of selectable algorithms is given below:
* `dreamer_v1`
* `dreamer_v2`
* `dreamer_v3`
* `p2e_dv1`
* `p2e_dv2`
* `p2e_dv3`
* `ppo`
* `ppo_decoupled`
* `sac_ae`

Once you have chosen the algorithm you want to train, you can start training; for instance, launch the PPO agent with:

```bash
python sheeprl.py exp=ppo env=super_mario_bros env.id=SuperMarioBros-v0 algo.cnn_keys.encoder=[rgb] fabric.accelerator=cpu fabric.strategy=ddp fabric.devices=2 algo.mlp_keys.encoder=[]
```
1 change: 1 addition & 0 deletions pyproject.toml
@@ -83,6 +83,7 @@ minerl = ["setuptools==66.0.0", "minerl==0.4.4"]
diambra = ["diambra==0.0.16", "diambra-arena==2.2.2"]
crafter = ["crafter==1.8.1"]
mlflow = ["mlflow==2.8.0"]
supermario = ["gym-super-mario-bros==7.4.0", "gym<0.26"]

[tool.ruff]
line-length = 120
16 changes: 16 additions & 0 deletions sheeprl/configs/env/super_mario_bros.yaml
@@ -0,0 +1,16 @@
defaults:
  - default
  - _self_

# Override from `default` config
id: SuperMarioBros-v0
frame_stack: 1
sync_env: False
action_repeat: 1

# Wrapper to be instantiated
wrapper:
  _target_: sheeprl.envs.super_mario_bros.SuperMarioBrosWrapper
  id: ${env.id}
  action_space: simple # or complex or right_only
  render_mode: rgb_array
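
Since this config is just a `_target_` plus keyword arguments, Hydra can build the wrapper directly. The following is a hypothetical standalone sketch (not part of this commit; a literal id stands in for the `${env.id}` interpolation):

```python
# Hypothetical standalone sketch (not part of this commit): Hydra imports the
# class named by `_target_` and passes the remaining keys as keyword arguments.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "_target_": "sheeprl.envs.super_mario_bros.SuperMarioBrosWrapper",
        "id": "SuperMarioBros-v0",  # stands in for the ${env.id} interpolation
        "action_space": "simple",  # or "right_only" / "complex"
        "render_mode": "rgb_array",
    }
)
env = instantiate(cfg)  # equivalent to SuperMarioBrosWrapper(id="SuperMarioBros-v0", ...)
```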
45 changes: 45 additions & 0 deletions sheeprl/configs/exp/dreamer_v3_super_mario_bros.yaml
@@ -0,0 +1,45 @@
# @package _global_

defaults:
  - dreamer_v3
  - override /env: super_mario_bros
  - _self_

# Experiment
seed: 5

# Checkpoint
checkpoint:
  every: 10000

# Buffer
buffer:
  size: 100000
  checkpoint: True
  memmap: True

# Algorithm
algo:
  total_steps: 1000000
  cnn_keys:
    encoder:
      - rgb
  mlp_keys:
    encoder: []
  learning_starts: 16384
  train_every: 4
  dense_units: 512
  mlp_layers: 2
  world_model:
    encoder:
      cnn_channels_multiplier: 32
    recurrent_model:
      recurrent_state_size: 512
    transition_model:
      hidden_size: 512
    representation_model:
      hidden_size: 512

# Metric
metric:
  log_every: 5000
51 changes: 51 additions & 0 deletions sheeprl/configs/exp/ppo_super_mario_bros.yaml
@@ -0,0 +1,51 @@
# @package _global_

defaults:
  - override /algo: ppo
  - override /env: super_mario_bros
  - override /model_manager: ppo
  - _self_

# Environment
env:
  num_envs: 8

# Algorithm
algo:
  # total_steps: 524288
  # total_steps: 262144
  total_steps: 1048576
  # total_steps: 2097152
  max_grad_norm: 0.5
  per_rank_batch_size: 256
  rollout_steps: 2048
  dense_units: 64
  encoder:
    cnn_features_dim: 512
  cnn_keys:
    encoder: [rgb]

# Buffer
buffer:
  share_data: False
  size: ${algo.rollout_steps}

# Checkpoint
checkpoint:
  every: 50000

metric:
  aggregator:
    metrics:
      Loss/value_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: ${metric.sync_on_compute}
      Loss/policy_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: ${metric.sync_on_compute}
      Loss/entropy_loss:
        _target_: torchmetrics.MeanMetric
        sync_on_compute: ${metric.sync_on_compute}

fabric:
  accelerator: cuda
69 changes: 69 additions & 0 deletions sheeprl/envs/super_mario_bros.py
@@ -0,0 +1,69 @@
from __future__ import annotations

from sheeprl.utils.imports import _IS_SUPER_MARIO_BROS_AVAILABLE

if not _IS_SUPER_MARIO_BROS_AVAILABLE:
raise ModuleNotFoundError(_IS_SUPER_MARIO_BROS_AVAILABLE)


from typing import Any, Dict, SupportsFloat, Tuple

import gym_super_mario_bros as gsmb
import gymnasium as gym
import numpy as np
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT, RIGHT_ONLY, SIMPLE_MOVEMENT
from gymnasium.core import RenderFrame
from nes_py.wrappers import JoypadSpace

ACTIONS_SPACE_MAP = {"simple": SIMPLE_MOVEMENT, "right_only": RIGHT_ONLY, "complex": COMPLEX_MOVEMENT}


class JoypadSpaceCustomReset(JoypadSpace):
    # Bypass JoypadSpace's own reset so that `seed` and `options` are forwarded
    # directly to the wrapped env.
    def reset(self, seed: int | None = None, options: Dict[str, Any] | None = None):
        return self.env.reset(seed=seed, options=options)


class SuperMarioBrosWrapper(gym.Wrapper):
    def __init__(self, id: str, action_space: str = "simple", render_mode: str = "rgb_array"):
        env = gsmb.make(id)
        env = JoypadSpaceCustomReset(env, ACTIONS_SPACE_MAP[action_space])
        super().__init__(env)

        self._render_mode = render_mode
        # Expose the raw RGB frame under the "rgb" key, as expected by sheeprl's
        # cnn_keys machinery.
        self.observation_space = gym.spaces.Dict(
            {
                "rgb": gym.spaces.Box(
                    env.observation_space.low,
                    env.observation_space.high,
                    env.observation_space.shape,
                    env.observation_space.dtype,
                )
            }
        )
        self.action_space = gym.spaces.Discrete(env.action_space.n)

    @property
    def render_mode(self) -> str:
        return self._render_mode

    @render_mode.setter
    def render_mode(self, render_mode: str):
        self._render_mode = render_mode

    def step(self, action: np.ndarray | int) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]:
        if isinstance(action, np.ndarray):
            action = action.squeeze().item()
        # The underlying gym (<0.26) env returns a 4-tuple; convert it to the
        # 5-tuple gymnasium API with `truncated=False`.
        obs, reward, done, info = self.env.step(action)
        converted_obs = {"rgb": obs.copy()}
        return converted_obs, reward, done, False, info

    def render(self) -> RenderFrame | list[RenderFrame] | None:
        rendered_frame: np.ndarray | None = self.env.render(mode=self.render_mode)
        if self.render_mode == "rgb_array" and rendered_frame is not None:
            return rendered_frame.copy()
        return

    def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None) -> Tuple[Any, Dict[str, Any]]:
        obs = self.env.reset(seed=seed, options=options)
        converted_obs = {"rgb": obs.copy()}
        return converted_obs, {}
1 change: 1 addition & 0 deletions sheeprl/utils/imports.py
@@ -12,5 +12,6 @@
_IS_MINEDOJO_AVAILABLE = RequirementCache("minedojo")
_IS_MINERL_0_4_4_AVAILABLE = RequirementCache("minerl==0.4.4")
_IS_MLFLOW_AVAILABLE = RequirementCache("mlflow>=2.8", "mlflow")
_IS_SUPER_MARIO_BROS_AVAILABLE = RequirementCache("gym-super-mario-bros==7.4.0")
_IS_TORCH_GREATER_EQUAL_2_0 = RequirementCache("torch>=2.0")
_IS_WINDOWS = platform.system() == "Windows"
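
For reference, the new flag is consumed like the other availability flags. This guard mirrors the one at the top of `sheeprl/envs/super_mario_bros.py` in this commit (assuming, as elsewhere in sheeprl, that `RequirementCache` is falsy when the pinned package is missing and stringifies to an install hint):

```python
# Guard mirroring the top of sheeprl/envs/super_mario_bros.py in this commit:
# the cache is falsy when gym-super-mario-bros==7.4.0 is not installed, and
# its string form is a readable error message.
from sheeprl.utils.imports import _IS_SUPER_MARIO_BROS_AVAILABLE

if not _IS_SUPER_MARIO_BROS_AVAILABLE:
    raise ModuleNotFoundError(_IS_SUPER_MARIO_BROS_AVAILABLE)
```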
