-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: added super mario bros environment (#204)
* feat: added super mario bros environment * fix: dependencies * docs: update * docs: update readme * fix: remove configs * fix: super mario wrapper
- Loading branch information
1 parent
8045b2e
commit 7de395f
Showing
9 changed files
with
250 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
## Install Atari environments | ||
First, we should install the Super Mario Bros environment with: | ||
|
||
```bash | ||
pip install .[supermario] | ||
``` | ||
|
||
For more information: https://github.com/Kautenja/gym-super-mario-bros/tree/master | ||
|
||
## Environment Config | ||
The default configurations of the Super Mario Bros environment are in the `./sheeprl/configs/env/super_mario_bros.yaml` file. | ||
|
||
```yaml | ||
defaults: | ||
- default | ||
- _self_ | ||
|
||
# Override from `default` config | ||
id: SuperMarioBros-v0 | ||
frame_stack: 1 | ||
sync_env: False | ||
action_repeat: 1 | ||
|
||
# Wrapper to be instantiated | ||
wrapper: | ||
_target_: sheeprl.envs.super_mario_bros.SuperMarioBrosWrapper | ||
id: ${env.id} | ||
action_space: simple # or complex or right_only | ||
render_mode: rgb_array | ||
``` | ||
The parameters under the `wrapper` key are explained below: | ||
- `id`: The id of the environment, check [here](https://github.com/Kautenja/gym-super-mario-bros/tree/master) which environments can be instantiated. | ||
- `action_space`: The actions that can be performed by the agent (always discrete actions). The possible options are: `simple`, `right_only`, or `complex`. Check [here](https://github.com/Kautenja/gym-super-mario-bros/blob/bcb8f10c3e3676118a7364a68f5c0eb287116d7a/gym_super_mario_bros/actions.py) the differences between them. | ||
- `render_mode`: one between `rgb_array` or `human`. | ||
|
||
|
||
## Train your agent | ||
|
||
It is important to remember that not all the algorithms can work with images, so it is necessary to check the first table in the [README](../README.md) and select a proper algorithm. | ||
The list of selectable algorithms is given below: | ||
* `dreamer_v1` | ||
* `dreamer_v2` | ||
* `dreamer_v3` | ||
* `p2e_dv1` | ||
* `p2e_dv2` | ||
* `p2e_dv3` | ||
* `ppo` | ||
* `ppo_decoupled` | ||
* `sac_ae` | ||
|
||
Once you have chosen the algorithm you want to train, you can start the train, for instance, of the ppo agent by running: | ||
|
||
```bash | ||
python sheeprl.py exp=ppo env=super_mario_bros env.id=SuperMarioBros-v0 algo.cnn_keys.encoder=[rgb] fabric.accelerator=cpu fabric.strategy=ddp fabric.devices=2 algo.mlp_keys.encoder=[] | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
defaults: | ||
- default | ||
- _self_ | ||
|
||
# Override from `default` config | ||
id: SuperMarioBros-v0 | ||
frame_stack: 1 | ||
sync_env: False | ||
action_repeat: 1 | ||
|
||
# Wrapper to be instantiated | ||
wrapper: | ||
_target_: sheeprl.envs.super_mario_bros.SuperMarioBrosWrapper | ||
id: ${env.id} | ||
action_space: simple # or complex or right_only | ||
render_mode: rgb_array |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
# @package _global_ | ||
|
||
defaults: | ||
- dreamer_v3 | ||
- override /env: super_mario_bros | ||
- _self_ | ||
|
||
# Experiment | ||
seed: 5 | ||
|
||
# Checkpoint | ||
checkpoint: | ||
every: 10000 | ||
|
||
# Buffer | ||
buffer: | ||
size: 100000 | ||
checkpoint: True | ||
memmap: True | ||
|
||
# Algorithm | ||
algo: | ||
total_steps: 1000000 | ||
cnn_keys: | ||
encoder: | ||
- rgb | ||
mlp_keys: | ||
encoder: [] | ||
learning_starts: 16384 | ||
train_every: 4 | ||
dense_units: 512 | ||
mlp_layers: 2 | ||
world_model: | ||
encoder: | ||
cnn_channels_multiplier: 32 | ||
recurrent_model: | ||
recurrent_state_size: 512 | ||
transition_model: | ||
hidden_size: 512 | ||
representation_model: | ||
hidden_size: 512 | ||
|
||
# Metric | ||
metric: | ||
log_every: 5000 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
# @package _global_ | ||
|
||
defaults: | ||
- override /algo: ppo | ||
- override /env: super_mario_bros | ||
- override /model_manager: ppo | ||
- _self_ | ||
|
||
# Environment | ||
env: | ||
num_envs: 8 | ||
|
||
# Algorithm | ||
algo: | ||
# total_steps: 524288 | ||
# total_steps: 262144 | ||
total_steps: 1048576 | ||
# total_steps: 2097152 | ||
max_grad_norm: 0.5 | ||
per_rank_batch_size: 256 | ||
rollout_steps: 2048 | ||
dense_units: 64 | ||
encoder: | ||
cnn_features_dim: 512 | ||
cnn_keys: | ||
encoder: [rgb] | ||
|
||
# Buffer | ||
buffer: | ||
share_data: False | ||
size: ${algo.rollout_steps} | ||
|
||
# Checkpoint | ||
checkpoint: | ||
every: 50000 | ||
|
||
metric: | ||
aggregator: | ||
metrics: | ||
Loss/value_loss: | ||
_target_: torchmetrics.MeanMetric | ||
sync_on_compute: ${metric.sync_on_compute} | ||
Loss/policy_loss: | ||
_target_: torchmetrics.MeanMetric | ||
sync_on_compute: ${metric.sync_on_compute} | ||
Loss/entropy_loss: | ||
_target_: torchmetrics.MeanMetric | ||
sync_on_compute: ${metric.sync_on_compute} | ||
|
||
fabric: | ||
accelerator: cuda |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from __future__ import annotations | ||
|
||
from sheeprl.utils.imports import _IS_SUPER_MARIO_BROS_AVAILABLE | ||
|
||
if not _IS_SUPER_MARIO_BROS_AVAILABLE: | ||
raise ModuleNotFoundError(_IS_SUPER_MARIO_BROS_AVAILABLE) | ||
|
||
|
||
from typing import Any, Dict, SupportsFloat, Tuple | ||
|
||
import gym_super_mario_bros as gsmb | ||
import gymnasium as gym | ||
import numpy as np | ||
from gym_super_mario_bros.actions import COMPLEX_MOVEMENT, RIGHT_ONLY, SIMPLE_MOVEMENT | ||
from gymnasium.core import RenderFrame | ||
from nes_py.wrappers import JoypadSpace | ||
|
||
ACTIONS_SPACE_MAP = {"simple": SIMPLE_MOVEMENT, "right_only": RIGHT_ONLY, "complex": COMPLEX_MOVEMENT} | ||
|
||
|
||
class JoypadSpaceCustomReset(JoypadSpace): | ||
def reset(self, seed: int | None = None, options: Dict[str, Any] | None = None): | ||
return self.env.reset(seed=seed, options=options) | ||
|
||
|
||
class SuperMarioBrosWrapper(gym.Wrapper): | ||
def __init__(self, id: str, action_space: str = "simple", render_mode: str = "rgb_array"): | ||
env = gsmb.make(id) | ||
env = JoypadSpaceCustomReset(env, ACTIONS_SPACE_MAP[action_space]) | ||
super().__init__(env) | ||
|
||
self._render_mode = render_mode | ||
self.observation_space = gym.spaces.Dict( | ||
{ | ||
"rgb": gym.spaces.Box( | ||
env.observation_space.low, | ||
env.observation_space.high, | ||
env.observation_space.shape, | ||
env.observation_space.dtype, | ||
) | ||
} | ||
) | ||
self.action_space = gym.spaces.Discrete(env.action_space.n) | ||
|
||
@property | ||
def render_mode(self) -> str: | ||
return self._render_mode | ||
|
||
@render_mode.setter | ||
def render_mode(self, render_mode: str): | ||
self._render_mode = render_mode | ||
|
||
def step(self, action: np.ndarray | int) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]: | ||
if isinstance(action, np.ndarray): | ||
action = action.squeeze().item() | ||
obs, reward, done, info = self.env.step(action) | ||
converted_obs = {"rgb": obs.copy()} | ||
return converted_obs, reward, done, False, info | ||
|
||
def render(self) -> RenderFrame | list[RenderFrame] | None: | ||
rendered_frame: np.ndarray | None = self.env.render(mode=self.render_mode) | ||
if self.render_mode == "rgb_array" and rendered_frame is not None: | ||
return rendered_frame.copy() | ||
return | ||
|
||
def reset(self, *, seed: int | None = None, options: Dict[str, Any] | None = None) -> Tuple[Any, Dict[str, Any]]: | ||
obs = self.env.reset(seed=seed, options=options) | ||
converted_obs = {"rgb": obs.copy()} | ||
return converted_obs, {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters