From 32dee9139763f005fbcda7116bb584d1121a442d Mon Sep 17 00:00:00 2001 From: Benjamin Piwowarski Date: Tue, 12 Nov 2024 14:10:01 +0100 Subject: [PATCH] enh: boost start support --- pyproject.toml | 2 +- src/pystk2_gymnasium/envs.py | 32 +++++++++++++++++++++++++++ src/pystk2_gymnasium/pystk_process.py | 2 +- src/pystk2_gymnasium/utils.py | 6 ++++- 4 files changed, 39 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index f8480b3..aa76847 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,5 +23,5 @@ format-jinja = "{% if distance == 0 %}{{ base }}{% else %}{{ base }}+{{ distance [tool.poetry.dependencies] python = "^3.8" -PySuperTuxKart2 = ">=0.3.8" +PySuperTuxKart2 = ">=0.4.0" gymnasium = ">0.29.0" diff --git a/src/pystk2_gymnasium/envs.py b/src/pystk2_gymnasium/envs.py index ef1e736..5c168ce 100644 --- a/src/pystk2_gymnasium/envs.py +++ b/src/pystk2_gymnasium/envs.py @@ -1,3 +1,4 @@ +from enum import Enum import logging import functools from typing import Any, ClassVar, Dict, List, Optional, Tuple, TypedDict @@ -37,10 +38,38 @@ def kart_action_space(): ) +class Phase(Enum): + """A phase in PySTK (subset of STK phases)""" + + # 'Ready' is displayed + READY_PHASE = 0 + + # 'Set' is displayed + SET_PHASE = 1 + + # 'Go' is displayed, but this is already race phase + GO_PHASE = 2 + + # Other phases + RACE_PHASE = 3 + + @staticmethod + def from_stk(source: pystk2.WorldState.Phase): + if (source is None) or (source == pystk2.WorldState.Phase.READY_PHASE): + return Phase.READY_PHASE + if source == pystk2.WorldState.Phase.SET_PHASE: + return Phase.SET_PHASE + if source == pystk2.WorldState.Phase.GO_PHASE: + return Phase.GO_PHASE + return Phase.RACE_PHASE + + @functools.lru_cache def kart_observation_space(use_ai: bool): space = spaces.Dict( { + "aux_ticks": spaces.Box(0.0, float("inf"), dtype=np.float32, shape=(1,)), + "phase": spaces.Discrete(max_enum_value(Phase)), "powerup": spaces.Discrete(max_enum_value(pystk2.Powerup)), # Last attachment... is no attachment "attachment": spaces.Discrete(max_enum_value(pystk2.Attachment)), @@ -328,6 +357,9 @@ def sort_closest(positions, *lists): return { **obs, + # World properties + "phase": Phase.from_stk(self.world.phase).value, + "aux_ticks": np.array([self.world.aux_ticks], dtype=np.float32), # Kart properties "powerup": kart.powerup.num, "attachment": kart.attachment.type.value, diff --git a/src/pystk2_gymnasium/pystk_process.py b/src/pystk2_gymnasium/pystk_process.py index 7935df4..fd75b6d 100644 --- a/src/pystk2_gymnasium/pystk_process.py +++ b/src/pystk2_gymnasium/pystk_process.py @@ -70,7 +70,7 @@ def warmup_race(self, config) -> pystk2.Track: while True: self.race.step() self.world.update() - if self.world.phase == pystk2.WorldState.Phase.GO_PHASE: + if self.world.phase == pystk2.WorldState.Phase.READY_PHASE: break track.update() diff --git a/src/pystk2_gymnasium/utils.py b/src/pystk2_gymnasium/utils.py index 50bd016..dee9a03 100644 --- a/src/pystk2_gymnasium/utils.py +++ b/src/pystk2_gymnasium/utils.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Type import numpy as np import gymnasium.spaces as spaces @@ -48,7 +49,10 @@ def rotate(v: np.array, q: np.array): def max_enum_value(EnumType: Type): """Returns the maximum enum value in a given enum type""" - return max([v.value for v in EnumType.Type.__members__.values()]) + 1 + if not issubclass(EnumType, Enum): + EnumType = EnumType.Type + + return max([v.value for v in EnumType.__members__.values()]) + 1 class Discretizer: