fix: grayscale capture video + diambra async vector env + dependencies
michele-milesi authored Oct 9, 2023
1 parent ee1ac39 commit bd98b41
Showing 6 changed files with 64 additions and 28 deletions.
4 changes: 3 additions & 1 deletion howto/learn_in_diambra.md
@@ -111,9 +111,11 @@ diambra run -s=4 python sheeprl.py exp=custom_exp env.num_envs=4
>
> When you set the `action_repeat` CLI argument to a value greater than one (i.e., the `repeat_action` DIAMBRA wrapper), the `step_ratio` DIAMBRA setting is automatically set to $1$, as required by DIAMBRA.
>
> You can increase the performance of the DIAMBRA engine with the `env.wrapper.increase_performance` parameter. When set to `True`, the engine is faster, but the recorded video will have the dimensions specified by the `env.screen_size` parameter.
>
> **Important**
>
> You **must** set the **`sync_env`** cli argument to **`True`**.
> If you want to use the `AsyncVectorEnv` ([https://gymnasium.farama.org/api/vector/#async-vector-env](https://gymnasium.farama.org/api/vector/#async-vector-env)), you **must** set the **`env.wrapper.diambra_settings.splash_screen`** CLI argument to **`False`**. Moreover, you must set the number of containers to `env.num_envs + 1` (i.e., you must set the `-s` CLI argument as specified above).
## Headless machines

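For illustration (not part of this commit; `custom_exp` is a placeholder experiment name): with four asynchronous environments, the `AsyncVectorEnv` note added above corresponds to running `diambra run -s=5 python sheeprl.py exp=custom_exp env.num_envs=4 env.sync_env=False env.wrapper.diambra_settings.splash_screen=False`, i.e. `env.num_envs + 1 = 5` containers with the splash screen disabled.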
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -66,7 +66,7 @@ atari = [
]
minedojo = ["minedojo==0.1", "importlib_resources==5.12.0"]
minerl = ["setuptools==66.0.0", "minerl==0.4.4"]
diambra = ["diambra==0.0.16", "diambra-arena==2.2.1"]
diambra = ["diambra==0.0.16", "diambra-arena==2.2.2"]
crafter = ["crafter==1.8.1"]

[tool.ruff]
6 changes: 4 additions & 2 deletions sheeprl/configs/env/diambra.yaml
@@ -4,8 +4,8 @@ defaults:

# Override from `default` config
id: doapp
frame_stack: 4
sync_env: True
frame_stack: 1
sync_env: False
action_repeat: 1

# Wrapper to be instantiated
@@ -18,13 +18,15 @@ wrapper:
repeat_action: ${env.action_repeat}
rank: null
log_level: 0
increase_performance: True
diambra_settings:
role: diambra.arena.Roles.P1
step_ratio: 6
difficulty: 4
continue_game: 0.0
show_final: False
outfits: 1
splash_screen: False
diambra_wrappers:
stack_actions: 1
no_op_max: 0
6 changes: 5 additions & 1 deletion sheeprl/envs/diambra.py
@@ -31,6 +31,7 @@ def __init__(
diambra_wrappers: Dict[str, Any] = {},
render_mode: str = "rgb_array",
log_level: int = 0,
increase_performance: bool = True,
) -> None:
super().__init__()

@@ -72,9 +73,12 @@ def __init__(
**{
"flatten": True,
"repeat_action": repeat_action,
"frame_shape": screen_size + (int(grayscale),),
},
)
if increase_performance:
settings.frame_shape = screen_size + (int(grayscale),)
else:
wrappers.frame_shape = screen_size + (int(grayscale),)
self._env = diambra.arena.make(id, settings, wrappers, rank=rank, render_mode=render_mode, log_level=log_level)

# Observation and action space
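For context, here is a minimal sketch of the two resize paths selected by the new `increase_performance` flag (illustrative only, not part of this commit: the `EnvironmentSettings`/`WrappersSettings` names are assumed from the `diambra-arena` 2.x API, `doapp` is the game id from the default config, and the script must be launched through `diambra run`):

```python
import diambra.arena
from diambra.arena import EnvironmentSettings, WrappersSettings  # assumed 2.x API

screen_size, grayscale = (64, 64), True
settings = EnvironmentSettings()
wrappers = WrappersSettings(flatten=True)

increase_performance = True
if increase_performance:
    # Resize/grayscale inside the DIAMBRA engine: faster, and the rendered
    # frames (hence any recorded video) already have `screen_size` dimensions.
    settings.frame_shape = screen_size + (int(grayscale),)
else:
    # Resize/grayscale in the Python-side wrappers: slower, but the engine
    # keeps emitting full-resolution frames.
    wrappers.frame_shape = screen_size + (int(grayscale),)

env = diambra.arena.make("doapp", settings, wrappers, render_mode="rgb_array")
```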
55 changes: 33 additions & 22 deletions sheeprl/envs/wrappers.py
@@ -1,11 +1,11 @@
import copy
import time
from collections import deque
from typing import Any, Callable, Dict, Optional, Sequence, SupportsFloat, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Sequence, SupportsFloat, Tuple, Union

import gymnasium as gym
import numpy as np
from gymnasium.core import Env
from gymnasium.core import Env, RenderFrame


class MaskVelocityWrapper(gym.ObservationWrapper):
@@ -48,23 +48,22 @@ def __init__(self, env: gym.Env, amount: int = 1):
super().__init__(env)
if amount <= 0:
raise ValueError("`amount` should be a positive integer")
self._env = env
self._amount = amount

@property
def action_repeat(self) -> int:
return self._amount

def __getattr__(self, name):
return getattr(self._env, name)
return getattr(self.env, name)

def step(self, action):
done = False
truncated = False
current_step = 0
total_reward = 0.0
while current_step < self._amount and not (done or truncated):
obs, reward, done, truncated, info = self._env.step(action)
obs, reward, done, truncated, info = self.env.step(action)
total_reward += reward
current_step += 1
return obs, total_reward, done, truncated, info
@@ -131,19 +130,18 @@ def __init__(self, env: Env, num_stack: int, cnn_keys: Sequence[str], dilation:
raise RuntimeError(
f"Expected an observation space of type gym.spaces.Dict, got: {type(env.observation_space)}"
)
self._env = env
self._num_stack = num_stack
self._cnn_keys = []
self._dilation = dilation
self.observation_space = copy.deepcopy(self._env.observation_space)
for k, v in self._env.observation_space.spaces.items():
self.observation_space = copy.deepcopy(self.env.observation_space)
for k, v in self.env.observation_space.spaces.items():
if cnn_keys and len(v.shape) == 3:
self._cnn_keys.append(k)
self.observation_space[k] = gym.spaces.Box(
np.repeat(self._env.observation_space[k].low[None, ...], num_stack, axis=0),
np.repeat(self._env.observation_space[k].high[None, ...], num_stack, axis=0),
(self._num_stack, *self._env.observation_space[k].shape),
self._env.observation_space[k].dtype,
np.repeat(self.env.observation_space[k].low[None, ...], num_stack, axis=0),
np.repeat(self.env.observation_space[k].high[None, ...], num_stack, axis=0),
(self._num_stack, *self.env.observation_space[k].shape),
self.env.observation_space[k].dtype,
)

if self._cnn_keys is None or len(self._cnn_keys) == 0:
Expand All @@ -156,7 +154,7 @@ def _get_obs(self, key):
return np.stack(list(frames_subset), axis=0)

def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]:
obs, reward, done, truncated, infos = self._env.step(action)
obs, reward, done, truncated, infos = self.env.step(action)
for k in self._cnn_keys:
self._frames[k].append(obs[k])
if (
@@ -174,7 +172,7 @@ def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, A
def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None, **kwargs
) -> Tuple[Any, Dict[str, Any]]:
obs, infos = self._env.reset(seed=seed, **kwargs)
obs, infos = self.env.reset(seed=seed, **kwargs)
[self._frames[k].clear() for k in self._cnn_keys]
for k in self._cnn_keys:
[self._frames[k].append(obs[k]) for _ in range(self._num_stack * self._dilation)]
@@ -200,25 +198,24 @@ class RewardAsObservationWrapper(gym.Wrapper):

def __init__(self, env: Env) -> None:
super().__init__(env)
self._env = env
reward_range = (
self._env.reward_range or (-np.inf, np.inf) if hasattr(self._env, "reward_range") else (-np.inf, np.inf)
self.env.reward_range or (-np.inf, np.inf) if hasattr(self.env, "reward_range") else (-np.inf, np.inf)
)
# The reward is assumed to be a scalar
if isinstance(self._env.observation_space, gym.spaces.Dict):
if isinstance(self.env.observation_space, gym.spaces.Dict):
self.observation_space = gym.spaces.Dict(
{
"reward": gym.spaces.Box(*reward_range, (1,), np.float32),
**{k: v for k, v in self._env.observation_space.items()},
**{k: v for k, v in self.env.observation_space.items()},
}
)
else:
self.observation_space = gym.spaces.Dict(
{"obs": self._env.observation_space, "reward": gym.spaces.Box(*reward_range, (1,), np.float32)}
{"obs": self.env.observation_space, "reward": gym.spaces.Box(*reward_range, (1,), np.float32)}
)

def __getattr__(self, name):
return getattr(self._env, name)
return getattr(self.env, name)

def _convert_obs(self, obs: Any, reward: Union[float, np.ndarray]) -> Dict[str, Any]:
reward_obs = (np.array(reward) if not isinstance(reward, np.ndarray) else reward).reshape(-1)
@@ -232,11 +229,25 @@ def _convert_obs(self, obs: Any, reward: Union[float, np.ndarray]) -> Dict[str,
return obs

def step(self, action: Any) -> Tuple[Any, SupportsFloat, bool, bool, Dict[str, Any]]:
obs, reward, done, truncated, infos = self._env.step(action)
obs, reward, done, truncated, infos = self.env.step(action)
return self._convert_obs(obs, copy.deepcopy(reward)), reward, done, truncated, infos

def reset(
self, *, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None
) -> Tuple[Any, Dict[str, Any]]:
obs, infos = self._env.reset(seed=seed, options=options)
obs, infos = self.env.reset(seed=seed, options=options)
return self._convert_obs(obs, 0), infos


class GrayscaleRenderWrapper(gym.Wrapper):
def __init__(self, env: Env):
super().__init__(env)

def render(self) -> Optional[Union[RenderFrame, List[RenderFrame]]]:
frame = super().render()
if isinstance(frame, np.ndarray):
if len(frame.shape) == 2:
frame = frame[..., np.newaxis]
if len(frame.shape) == 3 and frame.shape[-1] == 1:
frame = frame.repeat(3, axis=-1)
return frame
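A self-contained check of what the new wrapper does (illustrative; the dummy environment below exists only for this sketch):

```python
import gymnasium as gym
import numpy as np

from sheeprl.envs.wrappers import GrayscaleRenderWrapper


class _GrayEnv(gym.Env):
    """Dummy env whose render() returns a single-channel (H, W) frame."""

    observation_space = gym.spaces.Box(0.0, 1.0, (1,), np.float32)
    action_space = gym.spaces.Discrete(2)
    render_mode = "rgb_array"

    def reset(self, *, seed=None, options=None):
        return np.zeros(1, np.float32), {}

    def step(self, action):
        return np.zeros(1, np.float32), 0.0, False, False, {}

    def render(self):
        return np.zeros((64, 64), dtype=np.uint8)


env = GrayscaleRenderWrapper(_GrayEnv())
frame = env.render()
# The (64, 64) grayscale frame is expanded and repeated to (64, 64, 3),
# which is the shape RecordVideoV0 expects when encoding the captured video.
assert frame.shape == (64, 64, 3)
```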
19 changes: 18 additions & 1 deletion sheeprl/utils/env.py
@@ -7,7 +7,13 @@
import hydra
import numpy as np

from sheeprl.envs.wrappers import ActionRepeat, FrameStack, MaskVelocityWrapper, RewardAsObservationWrapper
from sheeprl.envs.wrappers import (
ActionRepeat,
FrameStack,
GrayscaleRenderWrapper,
MaskVelocityWrapper,
RewardAsObservationWrapper,
)
from sheeprl.utils.imports import _IS_DIAMBRA_ARENA_AVAILABLE, _IS_DIAMBRA_AVAILABLE, _IS_DMC_AVAILABLE

if _IS_DIAMBRA_ARENA_AVAILABLE and _IS_DIAMBRA_AVAILABLE:
@@ -49,6 +55,15 @@ def thunk() -> gym.Env:
except Exception:
env_spec = ""

if "diambra" in cfg.env.wrapper._target_ and not cfg.env.sync_env:
if cfg.env.wrapper.diambra_settings.pop("splash_screen", True):
warnings.warn(
"You must set the `splash_screen` setting to `False` when using the `AsyncVectorEnv` "
"in `DIAMBRA` environments. The specified `splash_screen` setting is ignored and set "
"to `False`."
)
cfg.env.wrapper.diambra_settings.splash_screen = False

instantiate_kwargs = {}
if "seed" in cfg.env.wrapper:
instantiate_kwargs["seed"] = seed
Expand Down Expand Up @@ -177,6 +192,8 @@ def transform_obs(obs: Dict[str, Any]):
env = gym.wrappers.TimeLimit(env, max_episode_steps=cfg.env.max_episode_steps)
env = gym.wrappers.RecordEpisodeStatistics(env)
if cfg.env.capture_video and rank == 0 and vector_env_idx == 0 and run_name is not None:
if cfg.env.grayscale:
env = GrayscaleRenderWrapper(env)
env = gym.experimental.wrappers.RecordVideoV0(
env, os.path.join(run_name, prefix + "_videos" if prefix else "videos"), disable_logger=True
)
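To see the new splash-screen guard in isolation, here is a self-contained sketch with a hand-built config (illustrative only: the `_target_` string is a placeholder, and in SheepRL the `cfg` object comes from Hydra rather than being built by hand):

```python
import warnings

from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "env": {
            "sync_env": False,
            "wrapper": {
                "_target_": "sheeprl.envs.diambra.DiambraWrapper",  # placeholder; the guard only checks for "diambra"
                "diambra_settings": {"splash_screen": True},
            },
        }
    }
)

# Same guard as above: with async vector envs, a truthy `splash_screen` is
# popped, a warning is emitted, and the setting is forced to False.
if "diambra" in cfg.env.wrapper._target_ and not cfg.env.sync_env:
    if cfg.env.wrapper.diambra_settings.pop("splash_screen", True):
        warnings.warn("`splash_screen` is ignored and set to `False` with `AsyncVectorEnv`.")
    cfg.env.wrapper.diambra_settings.splash_screen = False

assert cfg.env.wrapper.diambra_settings.splash_screen is False
```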
