Skip to content

Commit

Permalink
[gym_jiminy/common] Add 'AdaptLayoutObservation' that generalizes 'Fi…
Browse files Browse the repository at this point in the history
…lterObservation'. (#835)

* [gym_jiminy/common] Add 'AdaptLayoutObservation' that generalizes 'FilterObservation'.
* [gym_jiminy/common] Make sure that the initial pd state is within bounds.
* [misc] Move from 'toml' to 'tomlkit' to fix heterogeneous array support.
* [misc] Update documentation.
* [misc] Fix typing.
  • Loading branch information
duburcqa authored Nov 25, 2024
1 parent 8311b57 commit 17f4302
Show file tree
Hide file tree
Showing 15 changed files with 631 additions and 228 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def compute_reward(self,
# Similarly, `gym.Env` must be last to make sure all the other initialization
# methods are called first.
class InterfaceJiminyEnv(
InterfaceObserver[Obs, EngineObsType],
InterfaceObserver[Obs, EngineObsType], # type: ignore[type-var]
InterfaceController[Act, np.ndarray],
gym.Env[Obs, Act],
Generic[Obs, Act]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
EncoderSensor, array_copyto)

from ..bases import BaseObs, InterfaceJiminyEnv, BaseControllerBlock
from ..utils import fill
from ..utils import zeros, fill


# Name of the n-th position derivative
Expand Down Expand Up @@ -392,6 +392,9 @@ def __init__(self,
# Initialize the controller
super().__init__(name, env, update_ratio)

# Make sure that the state is within bounds
self._command_state[:2] = zeros(self.state_space)

# References to command acceleration for fast access
self._command_acceleration = self._command_state[2]

Expand Down
8 changes: 4 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def _get_measurements_space(self) -> spaces.Dict:
joint_type = jiminy.get_joint_type(joint)
if joint_type == jiminy.JointModelType.ROTARY_UNBOUNDED:
sensor_position_lower = - np.pi
sensor_position_upper = np.pi
sensor_position_upper = + np.pi
else:
try:
motor = self.robot.motors[sensor.motor_index]
Expand Down Expand Up @@ -1243,7 +1243,7 @@ def evaluate(self,
Optional: `None` by default. If not specified, then a
strongly random seed will be generated by gym.
:param horizon: Horizon of the simulation, namely maximum number of
steps before termination. `None` to disable.
env steps before termination. `None` to disable.
Optional: Disabled by default.
:param enable_stats: Whether to print high-level statistics after the
simulation.
Expand Down Expand Up @@ -1274,13 +1274,13 @@ def evaluate(self,
self._initialize_seed(seed)

# Initialize the simulation
obs, info = self.derived.reset()
env = self.derived
obs, info = env.reset()
action, reward, terminated, truncated = None, None, False, False

# Run the simulation
info_episode = [info]
try:
env = self.derived
while horizon is None or self.num_steps < horizon:
action = policy_fn(
obs, action, reward, terminated, truncated, info)
Expand Down
3 changes: 2 additions & 1 deletion python/gym_jiminy/common/gym_jiminy/common/utils/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,8 @@ def swing_from_vector(
components of quaternions (x, y, z, w) and columns are the N
independent orientations.
"""
# pylint: disable=possibly-used-before-assignment

# Extract individual tilt components
v_x, v_y, v_z = v_a

Expand All @@ -1047,7 +1049,6 @@ def swing_from_vector(
for i, q_i in enumerate(q.T):
swing_from_vector((v_x[i], v_y[i], v_z[i]), q_i)
else:
# pylint: disable=possibly-used-before-assignment
eps_thr = np.sqrt(TWIST_SWING_SINGULAR_THR)
eps_x = -TWIST_SWING_SINGULAR_THR < v_x < TWIST_SWING_SINGULAR_THR
eps_y = -TWIST_SWING_SINGULAR_THR < v_y < TWIST_SWING_SINGULAR_THR
Expand Down
23 changes: 18 additions & 5 deletions python/gym_jiminy/common/gym_jiminy/common/utils/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
TypedDict, Literal, overload, cast)

import h5py
import toml
import tomlkit
import numpy as np
import gymnasium as gym

Expand Down Expand Up @@ -224,11 +224,14 @@ def build_pipeline(env_config: EnvConfig,
:param env_config:
Configuration of the environment, as a dict of type `EnvConfig`.
:param layers_config:
Configuration of the blocks, as a list. The list is ordered from the
lowest level layer to the highest, each element corresponding to the
configuration of an individual layer, as a dict of type `LayerConfig`.
:param root_path: Optional path used as root for loading reference
trajectories from relative path if any. It will raise
an exception if required but not provided.
Optional: `None` by default.
"""
# Define helper to sanitize composition configuration
def sanitize_composition_config(composition_config: CompositionConfig,
Expand Down Expand Up @@ -554,13 +557,23 @@ def load_pipeline(fullpath: Union[str, pathlib.Path]
:param fullpath: Fullpath of the configuration file.
"""
# Extract root path from configuration file
fullpath = pathlib.Path(fullpath)
root_path, file_ext = fullpath.parent, fullpath.suffix

# Load configuration file
with open(fullpath, 'r') as f:
if file_ext == '.json':
return build_pipeline(**json.load(f), root_path=root_path)
if file_ext == '.toml':
return build_pipeline(**toml.load(f), root_path=root_path)
# Parse JSON configuration file
all_config = json.load(f)
elif file_ext == '.toml':
# Parse TOML configuration file
all_config = tomlkit.load(f).unwrap()
else:
raise ValueError(f"File extension '{file_ext}' not supported.")

# Build pipeline
return build_pipeline(**all_config, root_path=root_path)
raise ValueError("Only json and toml formats are supported.")


Expand Down
19 changes: 15 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
from collections.abc import Mapping, MutableMapping, Sequence, MutableSequence
from typing import (
Any, Dict, Optional, Union, Sequence as SequenceT, Tuple, Literal,
Mapping as MappingT, Iterable, SupportsFloat, TypeVar, Type, Callable,
no_type_check, cast)
Mapping as MappingT, SupportsFloat, TypeVar, Type, Callable, no_type_check,
overload)

import numba as nb
import numpy as np
Expand All @@ -30,7 +30,7 @@


StructNested = Union[MappingT[str, 'StructNested[ValueT]'],
Iterable['StructNested[ValueT]'],
SequenceT['StructNested[ValueT]'],
ValueT]
FieldNested = StructNested[str]
DataNested = StructNested[np.ndarray]
Expand Down Expand Up @@ -210,13 +210,24 @@ def copyto(dst: DataNested, src: DataNested) -> None:
array_copyto(data, value)


@overload
def copy(data: DataNestedT) -> DataNestedT:
...


@overload
def copy(data: gym.Space[DataNestedT]) -> gym.Space[DataNestedT]:
...


def copy(data: Union[DataNestedT, gym.Space[DataNestedT]]
) -> Union[DataNestedT, gym.Space[DataNestedT]]:
"""Shallow copy recursively 'data' from `gym.Space`, so that only leaves
are still references.
:param data: Hierarchical data structure to copy without allocation.
"""
return cast(DataNestedT, tree.unflatten_as(data, tree.flatten(data)))
return tree.unflatten_as(data, tree.flatten(data))


@no_type_check
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# pylint: disable=missing-module-docstring

from .observation_filter import FilterObservation
from .observation_layout import AdaptLayoutObservation, FilterObservation
from .observation_stack import StackObservation
from .normalize import NormalizeAction, NormalizeObservation
from .flatten import FlattenAction, FlattenObservation


__all__ = [
'AdaptLayoutObservation',
'FilterObservation',
'StackObservation',
'NormalizeObservation',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def __init__(self,
"""
# Find most appropriate dtype if not specified
if dtype is None:
obs_flat = tuple(
value.dtype for value in tree.flatten(env.observation))
if env.observation:
dtype = reduce(np.promote_types, obs_flat)
dtype_all = [
value.dtype for value in tree.flatten(env.observation)]
dtype = reduce(np.promote_types, dtype_all)
else:
dtype = np.float64

Expand Down

This file was deleted.

Loading

0 comments on commit 17f4302

Please sign in to comment.