Skip to content

Commit

Permalink
Re-worked multi-discrete values
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Aug 28, 2024
1 parent 916f4b7 commit 0e4a6ef
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 28 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Version 0.5.0

- Changed default for steer discretization steps (7)
- (breaking) Use multi-discrete for item types

# Version 0.4.5

- `center_path_distance` is now relative (to indicate left/right of the path)
Expand Down
7 changes: 4 additions & 3 deletions src/pystk2_gymnasium/stk_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ def __init__(
space["items_position"] = spaces.Box(
-float("inf"), float("inf"), shape=(self.state_items, 3), dtype=np.float32
)
space["items_type"] = spaces.Box(
0, max_enum_value(pystk2.Item), dtype=np.int64, shape=(self.state_items,)
n_item_types = max_enum_value(pystk2.Item)
space["items_type"] = spaces.MultiDiscrete(
[n_item_types for _ in range(self.state_items)]
)
space["karts_position"] = spaces.Box(
-float("inf"), float("inf"), shape=(self.state_karts, 3)
Expand Down Expand Up @@ -133,7 +134,7 @@ class STKDiscreteAction(STKAction):

class DiscreteActionsWrapper(ActionObservationWrapper):
# Wraps the actions
def __init__(self, env: gym.Env, *, acceleration_steps=5, steer_steps=10, **kwargs):
def __init__(self, env: gym.Env, *, acceleration_steps=5, steer_steps=7, **kwargs):
super().__init__(env, **kwargs)

self._action_space = copy.deepcopy(env.action_space)
Expand Down
71 changes: 46 additions & 25 deletions src/pystk2_gymnasium/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
This module contains generic wrappers
"""
from typing import Any, Callable, Dict, SupportsFloat, Tuple
from typing import Any, Callable, Dict, List, SupportsFloat, Tuple

import gymnasium as gym
from gymnasium import spaces
Expand Down Expand Up @@ -30,7 +30,7 @@ def __init__(self, space: gym.Space):
# Flatten the observation space
self.continuous_keys = []
self.shapes = []
self.discrete_keys = []
self.discrete_keys: List[str] = []
self.indices = [0]

continuous_size = 0
Expand All @@ -47,6 +47,10 @@ def __init__(self, space: gym.Space):
if isinstance(value, spaces.Discrete):
self.discrete_keys.append(key)
counts.append(value.n)
elif isinstance(value, spaces.MultiDiscrete):
self.discrete_keys.append(key)
for n in value.nvec:
counts.append(n)
elif isinstance(value, spaces.Box):
self.continuous_keys.append(key)
self.shapes.append(value.shape)
Expand Down Expand Up @@ -80,6 +84,17 @@ def __init__(self, space: gym.Space):
}
)

def discrete(self, observation):
"""Concatenates discrete and multi-discrete keys"""
r = []
for key in self.discrete_keys:
value = observation[key]
if isinstance(value, int):
r.append(value)
else:
r.extend(value)
return r


class FlattenerWrapper(ActionObservationWrapper):
"""Flattens actions and observations."""
Expand All @@ -100,9 +115,7 @@ def __init__(self, env: gym.Env):

def observation(self, observation):
new_obs = {
"discrete": np.array(
[observation[key] for key in self.observation_flattener.discrete_keys]
),
"discrete": np.array(self.observation_flattener.discrete(observation)),
"continuous": np.concatenate(
[
observation[key].flatten()
Expand All @@ -114,9 +127,7 @@ def observation(self, observation):
if self.has_action:
# Transforms from nested action to a flattened
obs_action = observation["action"]
discrete = np.array(
[obs_action[key] for key in self.action_flattener.discrete_keys]
)
discrete = np.array(self.action_flattener.discrete(obs_action))
if self.action_flattener.only_discrete:
new_obs["action"] = discrete
else:
Expand Down Expand Up @@ -214,6 +225,7 @@ def __init__(
self,
env: gym.Env,
*,
keep_original=False,
wrapper_factories: Dict[str, Callable[[gym.Env], Wrapper]],
):
"""Initialize an adapter that use distinct wrappers
Expand All @@ -222,6 +234,7 @@ def __init__(
corresponds to a different agent.
:param env: The base environment
:param keep_original: Keep original space
:param wrapper_factories: Return a wrapper for every key in the
observation/action spaces dictionary. Supported wrappers are
`ActionObservationWrapper`, `ObservationWrapper`, and `ActionWrapper`.
Expand Down Expand Up @@ -257,23 +270,28 @@ def __init__(
wrappers.append(wrapper)
wrapper = wrapper.env

# Change the observation space
self._action_space = spaces.Dict(
{
key: self.wrappers[key][0].action_space
if len(self.wrappers[key]) > 0
else self.mono_envs[key].action_space
for key in self.keys
}
)
self._observation_space = spaces.Dict(
{
key: self.wrappers[key][0].observation_space
if len(self.wrappers[key]) > 0
else self.mono_envs[key].observation_space
for key in self.keys
}
)
# Change the action/observation space
observation_space = {
key: self.wrappers[key][0].observation_space
if len(self.wrappers[key]) > 0
else self.mono_envs[key].observation_space
for key in self.keys
}
action_space = {
key: self.wrappers[key][0].action_space
if len(self.wrappers[key]) > 0
else self.mono_envs[key].action_space
for key in self.keys
}

self.keep_original = keep_original
if keep_original:
for key, mono_env in self.mono_envs.items():
observation_space[f"original/{key}"] = mono_env.observation_space

# Set the action/observation space
self._action_space = spaces.Dict(action_space)
self._observation_space = spaces.Dict(observation_space)

def action(self, actions: WrapperActType) -> ActType:
new_action = {}
Expand All @@ -290,6 +308,9 @@ def observation(self, observations: ObsType) -> WrapperObsType:
new_observation = {}
for key in self.keys:
observation = observations[key]
if self.keep_original:
new_observation[f"original/{key}"] = observation

for wrapper in reversed(self.wrappers[key]):
if isinstance(
wrapper, (gym.ObservationWrapper, ActionObservationWrapper)
Expand Down

0 comments on commit 0e4a6ef

Please sign in to comment.