Skip to content

Commit

Permalink
enh: boost start support
Browse files Browse the repository at this point in the history
  • Loading branch information
bpiwowar committed Nov 12, 2024
1 parent 4626817 commit 32dee91
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ format-jinja = "{% if distance == 0 %}{{ base }}{% else %}{{ base }}+{{ distance

[tool.poetry.dependencies]
python = "^3.8"
PySuperTuxKart2 = ">=0.3.8"
PySuperTuxKart2 = ">=0.4.0"
gymnasium = ">0.29.0"
32 changes: 32 additions & 0 deletions src/pystk2_gymnasium/envs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
import logging
import functools
from typing import Any, ClassVar, Dict, List, Optional, Tuple, TypedDict
Expand Down Expand Up @@ -37,10 +38,38 @@ def kart_action_space():
)


class Phase(Enum):
"""A phase in PySTK (subset of STK phases)"""

# 'Ready' is displayed
READY_PHASE = 0

# 'Set' is displayed
SET_PHASE = 1

# 'Go' is displayed, but this is already race phase
GO_PHASE = 2

# Other phases
RACE_PHASE = 3

@staticmethod
def from_stk(source: pystk2.WorldState.Phase):
if (source is None) or (source == pystk2.WorldState.Phase.READY_PHASE):
return Phase.READY_PHASE
if source == pystk2.WorldState.Phase.SET_PHASE:
return Phase.SET_PHASE
if source == pystk2.WorldState.Phase.GO_PHASE:
return Phase.GO_PHASE
return Phase.RACE_PHASE


@functools.lru_cache
def kart_observation_space(use_ai: bool):
space = spaces.Dict(
{
"aux_ticks": spaces.Box(0.0, float("inf"), dtype=np.float32, shape=(1,)),
"phase": spaces.Discrete(max_enum_value(Phase)),
"powerup": spaces.Discrete(max_enum_value(pystk2.Powerup)),
# Last attachment... is no attachment
"attachment": spaces.Discrete(max_enum_value(pystk2.Attachment)),
Expand Down Expand Up @@ -328,6 +357,9 @@ def sort_closest(positions, *lists):

return {
**obs,
# World properties
"phase": Phase.from_stk(self.world.phase).value,
"aux_ticks": np.array([self.world.aux_ticks], dtype=np.float32),
# Kart properties
"powerup": kart.powerup.num,
"attachment": kart.attachment.type.value,
Expand Down
2 changes: 1 addition & 1 deletion src/pystk2_gymnasium/pystk_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def warmup_race(self, config) -> pystk2.Track:
while True:
self.race.step()
self.world.update()
if self.world.phase == pystk2.WorldState.Phase.GO_PHASE:
if self.world.phase == pystk2.WorldState.Phase.READY_PHASE:
break

track.update()
Expand Down
6 changes: 5 additions & 1 deletion src/pystk2_gymnasium/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum
from typing import Type
import numpy as np
import gymnasium.spaces as spaces
Expand Down Expand Up @@ -48,7 +49,10 @@ def rotate(v: np.array, q: np.array):

def max_enum_value(EnumType: Type):
"""Returns the maximum enum value in a given enum type"""
return max([v.value for v in EnumType.Type.__members__.values()]) + 1
if not issubclass(EnumType, Enum):
EnumType = EnumType.Type

return max([v.value for v in EnumType.__members__.values()]) + 1


class Discretizer:
Expand Down

0 comments on commit 32dee91

Please sign in to comment.