Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a stochastic action wrapper along with its test #355

Merged
merged 13 commits into from
May 31, 2023
26 changes: 25 additions & 1 deletion minigrid/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import gymnasium as gym
import numpy as np
from gymnasium import logger, spaces
from gymnasium.core import ObservationWrapper, ObsType, Wrapper
from gymnasium.core import ActionWrapper, ObservationWrapper, ObsType, Wrapper

from minigrid.core.constants import COLOR_TO_IDX, OBJECT_TO_IDX, STATE_TO_IDX
from minigrid.core.world_object import Goal
Expand Down Expand Up @@ -764,3 +764,27 @@ def observation(self, obs):
obs["image"] = grid

return obs


class StochasticActionWrapper(ActionWrapper):
"""
Add stochasticity to the actions

If a random action is provided, it is returned with probability `1 - prob`.
Else, a random action is sampled from the action space.
"""

def __init__(self, env=None, prob=0.9, random_action=None):
super().__init__(env)
self.prob = prob
self.random_action = random_action

def action(self, action):
""" """
if np.random.uniform() < self.prob:
return action
else:
if self.random_action is None:
return self.np_random.integers(0, high=6)
else:
return self.random_action
18 changes: 18 additions & 0 deletions tests/test_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ReseedWrapper,
RGBImgObsWrapper,
RGBImgPartialObsWrapper,
StochasticActionWrapper,
SymbolicObsWrapper,
ViewSizeWrapper,
)
Expand Down Expand Up @@ -329,6 +330,23 @@ def test_symbolic_obs_wrapper(env_id):
env.close()


@pytest.mark.parametrize("env_id", ["MiniGrid-Empty-16x16-v0"])
def test_stochastic_action_wrapper(env_id):
env = gym.make(env_id)
env = StochasticActionWrapper(env, prob=0.2)
_, _ = env.reset()
for _ in range(20):
_, _, _, _, _ = env.step(0)
env.close()

env = gym.make(env_id)
env = StochasticActionWrapper(env, prob=0.2, random_action=1)
_, _ = env.reset()
for _ in range(20):
_, _, _, _, _ = env.step(0)
env.close()


def test_dict_observation_space_doesnt_clash_with_one_hot():
env = gym.make("MiniGrid-Empty-5x5-v0")
env = OneHotPartialObsWrapper(env)
Expand Down