Skip to content

Commit

Permalink
[gym_jiminy/common] Support string representation of enums in pipelin…
Browse files Browse the repository at this point in the history
…e config.
  • Loading branch information
duburcqa committed Nov 27, 2024
1 parent d0a291b commit ee105c2
Showing 1 changed file with 46 additions and 27 deletions.
73 changes: 46 additions & 27 deletions python/gym_jiminy/common/gym_jiminy/common/utils/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,12 @@
import gymnasium as gym

import jiminy_py.core as jiminy
import pinocchio as pin
from jiminy_py.dynamics import State, Trajectory

from ..bases import (InterfaceJiminyEnv,
from ..quantities import EnergyGenerationMode
from ..bases import (QuantityEvalMode,
InterfaceJiminyEnv,
InterfaceBlock,
InterfaceQuantity,
BaseControllerBlock,
Expand All @@ -41,6 +44,13 @@
from ..envs import BaseJiminyEnv


ENUM_TYPES = (EnergyGenerationMode,
QuantityEvalMode,
pin.KinematicLevel)
ENUM_NAME_TO_MODULE_MAP = {enum_type.__name__: enum_type.__module__.split(".")
for enum_type in ENUM_TYPES}


class CompositionConfig(TypedDict, total=False):
"""Store information required for instantiating a given composition, which
comprises reward components or a termination condition at the time being.
Expand Down Expand Up @@ -236,6 +246,34 @@ def build_pipeline(env_config: EnvConfig,
an exception if required but not provided.
Optional: `None` by default.
"""
# Define helper to replace enums string by its corresponding object value
def sanitize_enum_string(kwargs: Dict[str, Any]) -> None:
"""Replace in-place enum string representation with their object
counterpart.
:param kwargs: Nested dictionary of options.
"""
for key, value in kwargs.items():
if isinstance(value, dict):
sanitize_enum_string(value)
continue

if not isinstance(value, str):
continue

if value == "none":
kwargs[key] = None
continue

value_path = value.split(".")
enum_type = value_path[-2] if len(value_path) > 1 else None
if enum_type in ENUM_NAME_TO_MODULE_MAP.keys():
for path_ in ENUM_NAME_TO_MODULE_MAP[enum_type][::-1]:
if path_ not in value_path:
value_path.insert(0, path_)
kwargs[key] = locate(".".join(value_path))
continue

# Define helper to sanitize composition configuration
def sanitize_composition_config(composition_config: CompositionConfig,
is_reward: bool) -> None:
Expand All @@ -259,6 +297,9 @@ def sanitize_composition_config(composition_config: CompositionConfig,
# Get its constructor keyword-arguments
kwargs = composition_config.get("kwargs", {})

# Special treatment for "none" and enum string
sanitize_enum_string(kwargs)

# Special handling for `MixtureReward`
if is_reward and issubclass(cls, MixtureReward):
for component_config in kwargs["components"]:
Expand Down Expand Up @@ -300,10 +341,8 @@ def build_composition(
# Get its constructor keyword-arguments
kwargs = composition_config.get("kwargs", {}).copy()

# Special treatment for "none"
for key, value in kwargs.items():
if isinstance(value, str) and value == "none":
kwargs[key] = None
# Special treatment for "none" and enum string
sanitize_enum_string(kwargs)

# Special handling for `MixtureReward`
if is_reward and issubclass(cls, MixtureReward):
Expand Down Expand Up @@ -521,29 +560,9 @@ def build_controller_observer_layer(
block_kwargs = block_config.get("kwargs", {})
wrapper_kwargs = wrapper_config.get("kwargs", {})

# Special treatment for some values
# Special treatment for "none" and enum string
for kwargs in (block_kwargs, wrapper_kwargs):
for key, value in kwargs.items():
if not isinstance(value, str):
continue

if value == "none":
kwargs[key] = None
continue

value_path = value.split(".")
enum_type = value_path[-2] if len(value_path) > 1 else None
if enum_type in ("QuantityEvalMode", "KinematicLevel"):
module_path: Sequence[str]
if enum_type == "QuantityEvalMode":
module_path = ("gym_jiminy", "common", "bases")
else:
module_path = ("pinocchio",)
for path_ in module_path[::-1]:
if path_ not in value_path:
value_path.insert(0, path_)
kwargs[key] = locate(".".join(value_path))
continue
sanitize_enum_string(kwargs)

# Special treatment for "quantity" arg of `QuantityObserver` blocks
if block_cls_ is not None and issubclass(block_cls_, QuantityObserver):
Expand Down

0 comments on commit ee105c2

Please sign in to comment.