From c32b4ee6fe80afcd9992ff2937c91e745de3285f Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Sun, 16 Jul 2023 20:38:54 +0200 Subject: [PATCH] feat: add velocity randomize and exclude reference arguments (#215) * feat: add velocity randomize and exclude reference arguments This commit adds the ability to randomize the velocity reference in the Mujoco environment. It also adds the `exclude_reference_from_observation` argument which can be used to remove the reference from the observation. * refactor: fix several environment bugs --- .../envs/biological/oscillator/README.md | 7 +- .../envs/biological/oscillator/oscillator.py | 44 +++++++--- .../oscillator_complicated/README.md | 7 +- .../oscillator_complicated.py | 48 +++++++---- .../classic_control/cartpole_cost/README.md | 7 +- .../cartpole_cost/cartpole_cost.py | 25 ++++-- stable_gym/envs/mujoco/ant_cost/README.md | 3 +- stable_gym/envs/mujoco/ant_cost/__init__.py | 10 +-- stable_gym/envs/mujoco/ant_cost/ant_cost.py | 86 ++++++++++++------- .../envs/mujoco/half_cheetah_cost/README.md | 6 +- .../envs/mujoco/half_cheetah_cost/__init__.py | 10 +-- .../half_cheetah_cost/half_cheetah_cost.py | 86 ++++++++++++------- stable_gym/envs/mujoco/hopper_cost/README.md | 4 +- .../envs/mujoco/hopper_cost/__init__.py | 4 + .../envs/mujoco/hopper_cost/hopper_cost.py | 84 ++++++++++++++++++ .../envs/mujoco/humanoid_cost/README.md | 1 + .../envs/mujoco/humanoid_cost/__init__.py | 4 + .../mujoco/humanoid_cost/humanoid_cost.py | 84 ++++++++++++++++++ stable_gym/envs/mujoco/swimmer_cost/README.md | 1 + .../envs/mujoco/swimmer_cost/__init__.py | 4 + .../envs/mujoco/swimmer_cost/swimmer_cost.py | 84 ++++++++++++++++++ .../envs/mujoco/walker2d_cost/README.md | 2 + .../envs/mujoco/walker2d_cost/__init__.py | 5 ++ .../mujoco/walker2d_cost/walker2d_cost.py | 84 ++++++++++++++++++ 24 files changed, 574 insertions(+), 126 deletions(-) diff --git a/stable_gym/envs/biological/oscillator/README.md b/stable_gym/envs/biological/oscillator/README.md index 98d9d9bb..a5489a3b 100644 --- a/stable_gym/envs/biological/oscillator/README.md +++ b/stable_gym/envs/biological/oscillator/README.md @@ -12,12 +12,15 @@ By default, the environment returns the following observation: * $p_1$ - The lacI (repressor) protein concentration (Inhibits transcription of tetR gene). * $p_2$ - The tetR (repressor) protein concentration (Inhibits transcription of CI gene). * $p_3$ - The CI (repressor) protein concentration (Inhibits transcription of lacI gene). -* $r$ - The reference we want to follow. -An extra variable will be returned if the `exclude_reference_error_from_observation` flag is set to `False`: +Extra variables will be returned if the `exclude_reference_from_observation` and `exclude_reference_error_from_observation` +flag are set to `False`: +* $r$ - The reference we want to follow. * $r_{error}$ - The error between the state of interest (i.e. $p_1$) and the reference. +The reference can not be excluded when `reference_type` is set to `periodic`. + ## Action space * $u_1$ - Relative intensity of the light signal that induces the Lacl mRNA gene expression. diff --git a/stable_gym/envs/biological/oscillator/oscillator.py b/stable_gym/envs/biological/oscillator/oscillator.py index 1d341bd2..717d8902 100644 --- a/stable_gym/envs/biological/oscillator/oscillator.py +++ b/stable_gym/envs/biological/oscillator/oscillator.py @@ -127,7 +127,8 @@ def __init__( reference_frequency=200, reference_constraint_position=20.0, clip_action=True, - exclude_reference_error_from_observation=True, # NOTE: False in Han et al. 2018. # noqa: E501 + exclude_reference_from_observation=False, + exclude_reference_error_from_observation=False, ): """Initialise a new Oscillator environment instance. @@ -147,12 +148,16 @@ def __init__( dict. clip_action (str, optional): Whether the actions should be clipped if they are greater than the set action limit. Defaults to ``True``. + exclude_reference_from_observation (bool, optional): Whether the reference + should be excluded from the observation. Defaults to ``False``. Can only + be set to ``True`` if ``reference_type`` is ``constant``. exclude_reference_error_from_observation (bool, optional): Whether the error - should be excluded from the observation. Defaults to ``True``. + should be excluded from the observation. Defaults to ``False``. """ super().__init__() # Setup disturber. self._action_clip_warning = False self._clip_action = clip_action + self._exclude_reference_from_observation = exclude_reference_from_observation self._exclude_reference_error_from_observation = ( exclude_reference_error_from_observation ) @@ -162,6 +167,14 @@ def __init__( raise ValueError( "The reference type must be either 'constant' or 'periodic'." ) + assert ( + reference_type.lower() == "periodic" + or reference_type.lower() == "constant" + and not exclude_reference_from_observation + ), ( + "The reference can only be excluded from the observation if the reference " + "type is constant." + ) self.reference_type = reference_type self.t = 0.0 @@ -203,11 +216,14 @@ def __init__( self.delta5 = 0.0 # p2 noise. self.delta6 = 0.0 # p3 noise. - obs_low = np.array([0, 0, 0, 0, 0, 0, 0], dtype=np.float32) - obs_high = np.array([100, 100, 100, 100, 100, 100, 100], dtype=np.float32) + obs_low = np.array([0, 0, 0, 0, 0, 0], dtype=np.float32) + obs_high = np.array([100, 100, 100, 100, 100, 100], dtype=np.float32) + if not self._exclude_reference_from_observation: + obs_low = np.append(obs_low, 0.0).astype(np.float32) + obs_high = np.append(obs_high, 100.0).astype(np.float32) if not self._exclude_reference_error_from_observation: - obs_low = np.append(obs_low, np.float32(-100.0)) - obs_high = np.append(obs_high, np.float32(100.0)) + obs_low = np.append(obs_low, -100.0).astype(np.float32) + obs_high = np.append(obs_high, 100.0).astype(np.float32) self.action_space = spaces.Box( low=np.array([-5.0, -5.0, -5.0], dtype=np.float32), high=np.array([5.0, 5.0, 5.0], dtype=np.float32), @@ -352,9 +368,11 @@ def step(self, action): terminated = bool(cost > self.cost_range.high or cost < self.cost_range.low) # Create observation. - obs = np.array([m1, m2, m3, p1, p2, p3, r1], dtype=np.float32) + obs = np.array([m1, m2, m3, p1, p2, p3], dtype=np.float32) + if not self._exclude_reference_from_observation: + obs = np.append(obs, r1).astype(np.float32) if not self._exclude_reference_error_from_observation: - obs = np.append(obs, np.float32(p1 - r1)) + obs = np.append(obs, p1 - r1).astype(np.float32) # Return state, cost, terminated, truncated and info_dict return ( @@ -419,7 +437,7 @@ def reset( np.append( low, np.zeros( - 1 if self._exclude_reference_error_from_observation else 2, + self.observation_space.shape[0] - low.shape[0], dtype=np.float32, ), ) @@ -429,7 +447,7 @@ def reset( np.append( high, np.zeros( - 1 if self._exclude_reference_error_from_observation else 2, + self.observation_space.shape[0] - low.shape[0], dtype=np.float32, ), ) @@ -448,9 +466,11 @@ def reset( self.t = 0.0 m1, m2, m3, p1, p2, p3 = self.state r1 = self.reference(self.t) - obs = np.array([m1, m2, m3, p1, p2, p3, r1], dtype=np.float32) + obs = np.array([m1, m2, m3, p1, p2, p3], dtype=np.float32) + if not self._exclude_reference_from_observation: + obs = np.append(obs, r1).astype(np.float32) if not self._exclude_reference_error_from_observation: - obs = np.append(obs, np.float32(p1 - r1)) + obs = np.append(obs, p1 - r1).astype(np.float32) # Return initial observation and info_dict. return obs, dict( diff --git a/stable_gym/envs/biological/oscillator_complicated/README.md b/stable_gym/envs/biological/oscillator_complicated/README.md index f47acd4b..c6aa3ffd 100644 --- a/stable_gym/envs/biological/oscillator_complicated/README.md +++ b/stable_gym/envs/biological/oscillator_complicated/README.md @@ -14,12 +14,15 @@ By default, the environment returns the following observation: * $p_2$ - The tetR (repressor) protein concentration (Inhibits transcription of CI gene). * $p_3$ - The CI (repressor) protein concentration (Inhibits transcription of extra protein gene). * $p_4$ - Extra protein concentration (Inhibits transcription of lacI gene). -* $r$ - The reference we want to follow. -An extra variable will be returned if the `exclude_reference_error_from_observation` flag is set to `False`: +Extra variables will be returned if the `exclude_reference_from_observation` and `exclude_reference_error_from_observation` +flag are set to `False`: +* $r$ - The reference we want to follow. * $r_{error}$ - The error between the state of interest (i.e. $p_1$) and the reference. +The reference can not be excluded when `reference_type` is set to `periodic`. + ## Action space * $u_1$ - Relative intensity of the light signal that induces the Lacl mRNA gene expression. diff --git a/stable_gym/envs/biological/oscillator_complicated/oscillator_complicated.py b/stable_gym/envs/biological/oscillator_complicated/oscillator_complicated.py index 99593483..9d523d1c 100644 --- a/stable_gym/envs/biological/oscillator_complicated/oscillator_complicated.py +++ b/stable_gym/envs/biological/oscillator_complicated/oscillator_complicated.py @@ -143,7 +143,8 @@ def __init__( reference_frequency=200, reference_constraint_position=20.0, clip_action=True, - exclude_reference_error_from_observation=True, # NOTE: False in Han et al. 2018. # noqa: E501 + exclude_reference_from_observation=False, + exclude_reference_error_from_observation=False, ): """Initialise a new OscillatorComplicated environment instance. @@ -163,12 +164,16 @@ def __init__( dict. clip_action (str, optional): Whether the actions should be clipped if they are greater than the set action limit. Defaults to ``True``. + exclude_reference_from_observation (bool, optional): Whether the reference + should be excluded from the observation. Defaults to ``False``. Can only + be set to ``True`` if ``reference_type`` is ``constant``. exclude_reference_error_from_observation (bool, optional): Whether the error - should be excluded from the observation. Defaults to ``True``. + should be excluded from the observation. Defaults to ``False``. """ super().__init__() # Setup disturber. self._action_clip_warning = False self._clip_action = clip_action + self._exclude_reference_from_observation = exclude_reference_from_observation self._exclude_reference_error_from_observation = ( exclude_reference_error_from_observation ) @@ -178,6 +183,14 @@ def __init__( raise ValueError( "The reference type must be either 'constant' or 'periodic'." ) + assert ( + reference_type.lower() == "periodic" + or reference_type.lower() == "constant" + and not exclude_reference_from_observation + ), ( + "The reference can only be excluded from the observation if the reference " + "type is constant." + ) self.reference_type = reference_type self.t = 0.0 @@ -227,13 +240,14 @@ def __init__( self.delta7 = 0.0 # p3 noise. self.delta8 = 0.0 # p4 noise. - obs_low = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32) - obs_high = np.array( - [100, 100, 100, 100, 100, 100, 100, 100, 100], dtype=np.float32 - ) + obs_low = np.array([0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32) + obs_high = np.array([100, 100, 100, 100, 100, 100, 100, 100], dtype=np.float32) + if not self._exclude_reference_from_observation: + obs_low = np.append(obs_low, 0.0).astype(np.float32) + obs_high = np.append(obs_high, 100.0).astype(np.float32) if not self._exclude_reference_error_from_observation: - obs_low = np.append(obs_low, np.float32(-100)) - obs_high = np.append(obs_high, np.float32(100)) + obs_low = np.append(obs_low, -100).astype(np.float32) + obs_high = np.append(obs_high, 100).astype(np.float32) self.action_space = spaces.Box( low=np.array([-5.0, -5.0, -5.0, -5.0], dtype=np.float32), high=np.array([5.0, 5.0, 5.0, 5.0], dtype=np.float32), @@ -251,7 +265,7 @@ def __init__( self.steps_beyond_done = None # Reference target and constraint positions. - self.reference_target_pos = reference_target_position # Reference target. + self.reference_target_pos = reference_target_position self.reference_amplitude = reference_amplitude self.reference_frequency = reference_frequency self.reference_constraint_pos = ( @@ -407,9 +421,11 @@ def step(self, action): terminated = bool(cost > self.cost_range.high or cost < self.cost_range.low) # Create observation. - obs = np.array([m1, m2, m3, m4, p1, p2, p3, p4, r1], dtype=np.float32) + obs = np.array([m1, m2, m3, m4, p1, p2, p3, p4], dtype=np.float32) + if not self._exclude_reference_from_observation: + obs = np.append(obs, r1).astype(np.float32) if not self._exclude_reference_error_from_observation: - obs = np.append(obs, np.float32(p1 - r1)) + obs = np.append(obs, p1 - r1).astype(np.float32) # Return state, cost, terminated, truncated and info_dict return ( @@ -474,7 +490,7 @@ def reset( np.append( low, np.zeros( - 1 if self._exclude_reference_error_from_observation else 2, + self.observation_space.shape[0] - low.shape[0], dtype=np.float32, ), ) @@ -484,7 +500,7 @@ def reset( np.append( high, np.zeros( - 1 if self._exclude_reference_error_from_observation else 2, + self.observation_space.shape[0] - low.shape[0], dtype=np.float32, ), ) @@ -503,9 +519,11 @@ def reset( self.t = 0.0 m1, m2, m3, m4, p1, p2, p3, p4 = self.state r1 = self.reference(self.t) - obs = np.array([m1, m2, m3, m4, p1, p2, p3, p4, r1], dtype=np.float32) + obs = np.array([m1, m2, m3, m4, p1, p2, p3, p4], dtype=np.float32) + if not self._exclude_reference_from_observation: + obs = np.append(obs, r1).astype(np.float32) if not self._exclude_reference_error_from_observation: - obs = np.append(obs, np.float32(p1 - r1)) + obs = np.append(obs, p1 - r1).astype(np.float32) # Return initial observation and info_dict. return obs, dict( diff --git a/stable_gym/envs/classic_control/cartpole_cost/README.md b/stable_gym/envs/classic_control/cartpole_cost/README.md index dc073c33..a499a01c 100644 --- a/stable_gym/envs/classic_control/cartpole_cost/README.md +++ b/stable_gym/envs/classic_control/cartpole_cost/README.md @@ -38,12 +38,15 @@ By default, for the reference tracking task the environment returns the followin * $x_{dot}$ - Cart Velocity. * $w$ - Pole angle. * $w_{dot}$ - Pole angle velocity. -* $x_{ref}$ - The cart position reference. -An extra variable will be returned if the `exclude_reference_error_from_observation` flag is set to `False`: +Extra variables will be returned if the `exclude_reference_from_observation` and `exclude_reference_error_from_observation` +flag are set to `False`: +* $x_{ref}$ - The cart position reference. * $x_{ref\_error}$ - The reference tracking error. +The reference can not be excluded when `reference_type` is set to `periodic`. + ## Action space * **u1:** The x-force applied on the cart. diff --git a/stable_gym/envs/classic_control/cartpole_cost/cartpole_cost.py b/stable_gym/envs/classic_control/cartpole_cost/cartpole_cost.py index 03b3808e..ee0d4fbb 100644 --- a/stable_gym/envs/classic_control/cartpole_cost/cartpole_cost.py +++ b/stable_gym/envs/classic_control/cartpole_cost/cartpole_cost.py @@ -163,6 +163,7 @@ def __init__( reference_constraint_position=4.0, max_cost=100.0, clip_action=True, + exclude_reference_from_observation=False, exclude_reference_error_from_observation=True, ): """Initialise a new CartPoleCost environment instance. @@ -188,6 +189,10 @@ def __init__( terminated. Defaults to ``100.0``. clip_action (str, optional): Whether the actions should be clipped if they are greater than the set action limit. Defaults to ``True``. + exclude_reference_from_observation (bool, optional): Whether the reference + should be excluded from the observation. Defaults to ``False``. Can only + be set to ``True`` if ``reference_type`` is ``constant``. Only + used when ``task_type`` is ``reference_tracking``. exclude_reference_error_from_observation (bool, optional): Whether the error should be excluded from the observation. Defaults to ``True``. Only used when ``task_type`` is ``reference_tracking``. @@ -204,6 +209,14 @@ def __init__( raise ValueError( "Invalid reference type. Options are 'constant' and 'periodic'." ) + assert ( + reference_type.lower() == "periodic" + or reference_type.lower() == "constant" + and not exclude_reference_from_observation + ), ( + "The reference can only be excluded from the observation if the reference " + "type is constant." + ) # NOTE: Compared to the original I store the initial values for the reset # function and replace the `self.total_mass` and `self.polemass_length` with @@ -232,6 +245,7 @@ def __init__( # Create observation space bounds. # Angle limit set to 2 * theta_threshold_radians so failing observation # is still within bounds. + self._exclude_reference_from_observation = exclude_reference_from_observation self._exclude_reference_error_from_observation = ( exclude_reference_error_from_observation ) @@ -246,13 +260,10 @@ def __init__( ) # NOTE: When reference tracking add two extra observation states. if task_type.lower() == "reference_tracking": - high = np.append( - high, - np.repeat( - self.x_threshold * 2, - 1 if self._exclude_reference_error_from_observation else 2, - ), - ).astype(np.float32) + if not self._exclude_reference_from_observation: + high = np.append(high, self.x_threshold * 2).astype(np.float32) + if not self._exclude_reference_error_from_observation: + high = np.append(high, self.x_threshold * 2).astype(np.float32) self.action_space = spaces.Box( low=-self.force_mag, high=self.force_mag, shape=(1,), dtype=np.float32 diff --git a/stable_gym/envs/mujoco/ant_cost/README.md b/stable_gym/envs/mujoco/ant_cost/README.md index 14f89928..9d341918 100644 --- a/stable_gym/envs/mujoco/ant_cost/README.md +++ b/stable_gym/envs/mujoco/ant_cost/README.md @@ -9,8 +9,7 @@ An actuated 8-jointed ant. This environment corresponds to the [Ant-v4](https:// * The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the Ant's forward velocity and a reference value (error). -* The reference velocity was added to the observation space. -* Two **optional** variables were added to the observation space These are the ant's forward velocity and the error (difference between the ant's forward velocity and the reference velocity). These variables can be enabled using the `exclude_velocity_from_observation` and `exclude_reference_error_from_observation` environment arguments. +* Three **optional** variables were added to the observation space; The reference velocity, the reference error (i.e. the difference between the ant's forward velocity and the reference) and the ant's forward velocity. These variables can be enabled using the `exclude_reference_from_observation`, `exclude_reference_error_from_observation` and `exclude_velocity_from_observation` environment arguments. The rest of the environment is the same as the original Ant environment. Below, the modified cost is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [gymnasium library](https://gymnasium.farama.org/environments/mujoco/ant/). diff --git a/stable_gym/envs/mujoco/ant_cost/__init__.py b/stable_gym/envs/mujoco/ant_cost/__init__.py index 915bbd7f..c66818c7 100644 --- a/stable_gym/envs/mujoco/ant_cost/__init__.py +++ b/stable_gym/envs/mujoco/ant_cost/__init__.py @@ -4,12 +4,10 @@ - The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the Ant's forward velocity and a reference value (error). -- The reference velocity was added to the observation space. -- Two **optional** variables were added to the observation space. These are - the ant's forward velocity and the error (difference between the ant's - forward velocity and the reference velocity). These variables can be enabled using - the ``exclude_velocity_from_observation`` and ``exclude_reference_error_from_observation`` - environment arguments. +- Three **optional** variables were added to the observation space; The reference velocity, the reference error + (i.e. the difference between the ant's forward velocity and the reference) and the ant's forward velocity. + These variables can be enabled using the ``exclude_reference_from_observation``, ``exclude_reference_error_from_observation`` + and ``exclude_velocity_from_observation`` environment arguments. .. _`Han et al. 2020`: https://arxiv.org/abs/2004.14288 """ # noqa: E501 diff --git a/stable_gym/envs/mujoco/ant_cost/ant_cost.py b/stable_gym/envs/mujoco/ant_cost/ant_cost.py index 1d107750..2f777f2a 100644 --- a/stable_gym/envs/mujoco/ant_cost/ant_cost.py +++ b/stable_gym/envs/mujoco/ant_cost/ant_cost.py @@ -30,12 +30,10 @@ class AntCost(AntEnv, utils.EzPickle): is replaced with a cost. This cost is the squared difference between the Ant's forward velocity and a reference value (error). Additionally, also a control cost and health penalty can be included in the cost. - - The reference velocity was added to the observation space. - - Two **optional** variables were added to the observation space. These are - the ant's forward velocity and the error (difference between the ant's - forward velocity and the reference velocity). These variables can be enabled using - the ``exclude_velocity_from_observation`` and ``exclude_reference_error_from_observation`` - environment arguments. + - Three **optional** variables were added to the observation space; The reference velocity, the reference error + (i.e. the difference between the ant's forward velocity and the reference) and the ant's forward velocity. + These variables can be enabled using the ``exclude_reference_from_observation``, + ``exclude_reference_error_from_observation`` and ``exclude_velocity_from_observation`` environment arguments. The rest of the environment is the same as the original Ant environment. Below, the modified cost is described. For more information about the environment @@ -68,8 +66,8 @@ class AntCost(AntEnv, utils.EzPickle): def __init__( self, reference_forward_velocity=1.0, - randomize_reference_forward_velocity=True, - randomize_reference_forward_velocity_range=(0.5, 1.5), + randomise_reference_forward_velocity=False, + randomise_reference_forward_velocity_range=(0.5, 1.5), forward_velocity_weight=1.0, include_ctrl_cost=False, include_health_penalty=True, @@ -82,8 +80,9 @@ def __init__( contact_force_range=(-1.0, 1.0), reset_noise_scale=0.1, exclude_current_positions_from_observation=True, - exclude_x_velocity_from_observation=False, + exclude_reference_from_observation=False, # NOTE: True in Han et al. 2018. # noqa: E501 exclude_reference_error_from_observation=True, + exclude_x_velocity_from_observation=False, # NOTE: True in Han et al. 2018. # noqa: E501 **kwargs, ): """Initialise a new AntCost environment instance. @@ -91,9 +90,9 @@ def __init__( Args: reference_forward_velocity (float, optional): The forward velocity that the agent should try to track. Defaults to ``1.0``. - randomize_reference_forward_velocity (bool, optional): Whether to randomize + randomise_reference_forward_velocity (bool, optional): Whether to randomize the reference forward velocity. Defaults to ``False``. - randomize_reference_forward_velocity_range (tuple, optional): The range of + randomise_reference_forward_velocity_range (tuple, optional): The range of the random reference forward velocity. Defaults to ``(0.5, 1.5)``. forward_velocity_weight (float, optional): The weight used to scale the forward velocity error. Defaults to ``1.0``. @@ -121,27 +120,41 @@ def __init__( the x- and y-coordinates of the front tip from observations. Excluding the position can serve as an inductive bias to induce position-agnostic behaviour in policies. Defaults to ``True``. - exclude_x_velocity_from_observation (bool, optional): Whether to omit the - x- component of the velocity from observations. Defaults to ``False``. + exclude_reference_from_observation (bool, optional): Whether the reference + should be excluded from the observation. Defaults to ``False``. Can only + be set to ``True`` if ``randomise_reference_forward_velocity`` is set to + ``False``. exclude_reference_error_from_observation (bool, optional): Whether the error should be excluded from the observation. Defaults to ``True``. + exclude_x_velocity_from_observation (bool, optional): Whether to omit the + x- component of the velocity from observations. Defaults to ``False``. **kwargs: Extra keyword arguments to pass to the :class:`~gymnasium.envs.mujoco.ant_v4.AntEnv` class. """ self.reference_forward_velocity = reference_forward_velocity - self.randomize_reference_forward_velocity = randomize_reference_forward_velocity - self.randomize_reference_forward_velocity_range = ( - randomize_reference_forward_velocity_range + self.randomise_reference_forward_velocity = randomise_reference_forward_velocity + self.randomise_reference_forward_velocity_range = ( + randomise_reference_forward_velocity_range ) self._forward_velocity_weight = forward_velocity_weight self._include_ctrl_cost = include_ctrl_cost self._use_contact_forces = use_contact_forces self._include_health_penalty = include_health_penalty self._health_penalty_size = health_penalty_size - self._exclude_x_velocity_from_observation = exclude_x_velocity_from_observation + self._exclude_reference_from_observation = exclude_reference_from_observation self._exclude_reference_error_from_observation = ( exclude_reference_error_from_observation ) + self._exclude_x_velocity_from_observation = exclude_x_velocity_from_observation + + # Validate input arguments. + assert ( + not randomise_reference_forward_velocity + or not exclude_reference_from_observation + ), ( + "The reference can only be excluded from the observation if the forward " + "velocity is not randomised." + ) self.state = None @@ -161,14 +174,15 @@ def __init__( # Extend observation space if necessary. low = self.observation_space.low high = self.observation_space.high - low = np.append(low, -np.inf) # NOTE: Added to include reference. - high = np.append(high, np.inf) # NOTE: Added to include reference. - if not self._exclude_x_velocity_from_observation: + if not self._exclude_reference_from_observation: low = np.append(low, -np.inf) high = np.append(high, np.inf) if not self._exclude_reference_error_from_observation: low = np.append(low, -np.inf) high = np.append(high, np.inf) + if not self._exclude_x_velocity_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) self.observation_space = gym.spaces.Box(low, high, dtype=np.float32) # Reinitialize the EzPickle class. @@ -177,6 +191,8 @@ def __init__( utils.EzPickle.__init__( self, reference_forward_velocity, + randomise_reference_forward_velocity, + randomise_reference_forward_velocity_range, forward_velocity_weight, include_ctrl_cost, include_health_penalty, @@ -189,6 +205,9 @@ def __init__( contact_force_range, reset_noise_scale, exclude_current_positions_from_observation, + exclude_reference_from_observation, + exclude_reference_error_from_observation, + exclude_x_velocity_from_observation, **kwargs, ) @@ -259,13 +278,14 @@ def step(self, action): cost, cost_info = self.cost(info["x_velocity"], -info["reward_ctrl"]) # Add reference, x velocity and reference error to observation. - obs = np.append(obs, np.float32(self.reference_forward_velocity)) - if not self._exclude_x_velocity_from_observation: - obs = np.append(obs, np.float32(info["x_velocity"])) + if not self._exclude_reference_from_observation: + obs = np.append(obs, self.reference_forward_velocity).astype(np.float32) if not self._exclude_reference_error_from_observation: obs = np.append( - obs, np.float32(info["x_velocity"] - self.reference_forward_velocity) - ) + obs, info["x_velocity"] - self.reference_forward_velocity + ).astype(np.float32) + if not self._exclude_x_velocity_from_observation: + obs = np.append(obs, info["x_velocity"]).astype(np.float32) # Update info. info.update(cost_info) @@ -291,20 +311,20 @@ def reset(self, seed=None, options=None): obs, info = super().reset(seed=seed, options=options) # Randomize the reference forward velocity if requested. - if self.randomize_reference_forward_velocity: + if self.randomise_reference_forward_velocity: self.reference_forward_velocity = self.np_random.uniform( - *self.randomize_reference_forward_velocity_range + *self.randomise_reference_forward_velocity_range ) # Add reference, x velocity and reference error to observation. - obs = np.append(obs, np.float32(self.reference_forward_velocity)) - if not self._exclude_x_velocity_from_observation: - obs = np.append(obs, np.float32(0.0)) + if not self._exclude_reference_from_observation: + obs = np.append(obs, self.reference_forward_velocity).astype(np.float32) if not self._exclude_reference_error_from_observation: - obs = np.append( - obs, - np.float32(0.0 - self.reference_forward_velocity), + obs = np.append(obs, 0.0 - self.reference_forward_velocity).astype( + np.float32 ) + if not self._exclude_x_velocity_from_observation: + obs = np.append(obs, 0.0).astype(np.float32) self.state = obs diff --git a/stable_gym/envs/mujoco/half_cheetah_cost/README.md b/stable_gym/envs/mujoco/half_cheetah_cost/README.md index 9a945e25..3dc1d41f 100644 --- a/stable_gym/envs/mujoco/half_cheetah_cost/README.md +++ b/stable_gym/envs/mujoco/half_cheetah_cost/README.md @@ -7,10 +7,8 @@ An actuated 8-jointed half cheetah. This environment corresponds to the [HalfCheetah-v4](https://gymnasium.farama.org/environments/mujoco/half_cheetah) environment included in the [gymnasium package](https://gymnasium.farama.org/). It is different in the fact that: -* The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared - difference between the HalfCheetah's forward velocity and a reference value (error). -- The reference velocity was added to the observation space. -- Two **optional** variables were added to the observation space. These are the cheetah's forward velocity and the error (difference between the cheetah's forward velocity and the reference velocity). These variables can be enabled using the `exclude_velocity_from_observation` and `exclude_reference_error_from_observation` environment arguments. +* The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the HalfCheetah's forward velocity and a reference value (error). +* Three **optional** variables were added to the observation space; The reference velocity, the reference error (i.e. the difference between the cheetah's forward velocity and the reference) and the cheetah's forward velocity. These variables can be enabled using the `exclude_reference_from_observation`, `exclude_reference_error_from_observation` and `exclude_velocity_from_observation` environment arguments. The rest of the environment is the same as the original HalfCheetah environment. Below, the modified cost is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [gymnasium library](https://gymnasium.farama.org/environments/mujoco/half_cheetah/). diff --git a/stable_gym/envs/mujoco/half_cheetah_cost/__init__.py b/stable_gym/envs/mujoco/half_cheetah_cost/__init__.py index 3f2596b1..978fc055 100644 --- a/stable_gym/envs/mujoco/half_cheetah_cost/__init__.py +++ b/stable_gym/envs/mujoco/half_cheetah_cost/__init__.py @@ -4,12 +4,10 @@ - The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the HalfCheetah's forward velocity and a reference value (error). -- The reference velocity was added to the observation space. -- Two **optional** variables were added to the observation space. These are - the cheetah's forward velocity and the error (difference between the cheetah's - forward velocity and the reference velocity). These variables can be enabled using - the ``exclude_velocity_from_observation`` and ``exclude_reference_error_from_observation`` - environment arguments. +- Three **optional** variables were added to the observation space; The reference velocity, the reference error + (i.e. the difference between the cheetah's forward velocity and the reference) and the cheetah's forward velocity. + These variables can be enabled using the ``exclude_reference_from_observation``, + ``exclude_reference_error_from_observation`` and ``exclude_velocity_from_observation`` environment arguments. .. _`Han et al. 2020`: https://arxiv.org/abs/2004.14288 """ # noqa: E501 diff --git a/stable_gym/envs/mujoco/half_cheetah_cost/half_cheetah_cost.py b/stable_gym/envs/mujoco/half_cheetah_cost/half_cheetah_cost.py index 5299a336..fe5844f9 100644 --- a/stable_gym/envs/mujoco/half_cheetah_cost/half_cheetah_cost.py +++ b/stable_gym/envs/mujoco/half_cheetah_cost/half_cheetah_cost.py @@ -30,12 +30,10 @@ class HalfCheetahCost(HalfCheetahEnv, utils.EzPickle): is replaced with a cost. This cost is the squared difference between the HalfCheetah's forward velocity and a reference value (error). Additionally, also a control cost can be included in the cost. - - The reference velocity was added to the observation space. - - Two **optional** variables were added to the observation space. These are - the cheetah's forward velocity and the error (difference between the cheetah's - forward velocity and the reference velocity). These variables can be enabled using - the ``exclude_velocity_from_observation`` and ``exclude_reference_error_from_observation`` - environment arguments. + - Three **optional** variables were added to the observation space; The reference velocity, the reference error + (i.e. the difference between the cheetah's forward velocity and the reference) and the cheetah's forward velocity. + These variables can be enabled using the ``exclude_reference_from_observation``, + ``exclude_reference_error_from_observation`` and ``exclude_velocity_from_observation`` environment arguments. The rest of the environment is the same as the original HalfCheetah environment. Below, the modified cost is described. For more information about the environment @@ -68,15 +66,16 @@ class HalfCheetahCost(HalfCheetahEnv, utils.EzPickle): def __init__( self, reference_forward_velocity=1.0, - randomize_reference_forward_velocity=True, - randomize_reference_forward_velocity_range=(0.5, 1.5), + randomise_reference_forward_velocity=False, + randomise_reference_forward_velocity_range=(0.5, 1.5), forward_velocity_weight=1.0, include_ctrl_cost=False, ctrl_cost_weight=1e-4, # NOTE: Lower than original because we use different cost. # noqa: E501 reset_noise_scale=0.1, exclude_current_positions_from_observation=True, - exclude_x_velocity_from_observation=False, + exclude_reference_from_observation=False, # NOTE: True in Han et al. 2018. # noqa: E501 exclude_reference_error_from_observation=True, + exclude_x_velocity_from_observation=False, # NOTE: True in Han et al. 2018. # noqa: E501 **kwargs, ): """Initialise a new HalfCheetahCost environment instance. @@ -84,9 +83,9 @@ def __init__( Args: reference_forward_velocity (float, optional): The forward velocity that the agent should try to track. Defaults to ``1.0``. - randomize_reference_forward_velocity (bool, optional): Whether to randomize + randomise_reference_forward_velocity (bool, optional): Whether to randomize the reference forward velocity. Defaults to ``False``. - randomize_reference_forward_velocity_range (tuple, optional): The range of + randomise_reference_forward_velocity_range (tuple, optional): The range of the random reference forward velocity. Defaults to ``(0.5, 1.5)``. forward_velocity_weight (float, optional): The weight used to scale the forward velocity error. Defaults to ``1.0``. @@ -101,24 +100,38 @@ def __init__( the x- and y-coordinates of the front tip from observations. Excluding the position can serve as an inductive bias to induce position-agnostic behaviour in policies. Defaults to ``True``. - exclude_x_velocity_from_observation (bool, optional): Whether to omit the - x- component of the velocity from observations. Defaults to ``False``. + exclude_reference_from_observation (bool, optional): Whether the reference + should be excluded from the observation. Defaults to ``False``. Can only + be set to ``True`` if ``randomise_reference_forward_velocity`` is set to + ``False``. exclude_reference_error_from_observation (bool, optional): Whether the error should be excluded from the observation. Defaults to ``True``. + exclude_x_velocity_from_observation (bool, optional): Whether to omit the + x- component of the velocity from observations. Defaults to ``False``. **kwargs: Extra keyword arguments to pass to the :class:`~gymnasium.envs.mujoco.half_cheetah_v4.HalfCheetahEnv` class. """ self.reference_forward_velocity = reference_forward_velocity - self.randomize_reference_forward_velocity = randomize_reference_forward_velocity - self.randomize_reference_forward_velocity_range = ( - randomize_reference_forward_velocity_range + self.randomise_reference_forward_velocity = randomise_reference_forward_velocity + self.randomise_reference_forward_velocity_range = ( + randomise_reference_forward_velocity_range ) self._forward_velocity_weight = forward_velocity_weight self._include_ctrl_cost = include_ctrl_cost - self._exclude_x_velocity_from_observation = exclude_x_velocity_from_observation + self._exclude_reference_from_observation = exclude_reference_from_observation self._exclude_reference_error_from_observation = ( exclude_reference_error_from_observation ) + self._exclude_x_velocity_from_observation = exclude_x_velocity_from_observation + + # Validate input arguments. + assert ( + not randomise_reference_forward_velocity + or not exclude_reference_from_observation + ), ( + "The reference can only be excluded from the observation if the forward " + "velocity is not randomised." + ) self.state = None @@ -133,14 +146,15 @@ def __init__( # Extend observation space if necessary. low = self.observation_space.low high = self.observation_space.high - low = np.append(low, -np.inf) # NOTE: Added to include reference. - high = np.append(high, np.inf) # NOTE: Added to include reference. - if not self._exclude_x_velocity_from_observation: + if not self._exclude_reference_from_observation: low = np.append(low, -np.inf) high = np.append(high, np.inf) if not self._exclude_reference_error_from_observation: low = np.append(low, -np.inf) high = np.append(high, np.inf) + if not self._exclude_x_velocity_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) self.observation_space = gym.spaces.Box(low, high, dtype=np.float32) # Reinitialize the EzPickle class. @@ -149,11 +163,16 @@ def __init__( utils.EzPickle.__init__( self, reference_forward_velocity, + randomise_reference_forward_velocity, + randomise_reference_forward_velocity_range, forward_velocity_weight, include_ctrl_cost, ctrl_cost_weight, reset_noise_scale, exclude_current_positions_from_observation, + exclude_reference_from_observation, + exclude_reference_error_from_observation, + exclude_x_velocity_from_observation, **kwargs, ) @@ -207,13 +226,14 @@ def step(self, action): cost, cost_info = self.cost(info["x_velocity"], -info["reward_ctrl"]) # Add reference, x velocity and reference error to observation. - obs = np.append(obs, np.float32(self.reference_forward_velocity)) - if not self._exclude_x_velocity_from_observation: - obs = np.append(obs, np.float32(info["x_velocity"])) + if not self._exclude_reference_from_observation: + obs = np.append(obs, self.reference_forward_velocity).astype(np.float32) if not self._exclude_reference_error_from_observation: obs = np.append( - obs, np.float32(info["x_velocity"] - self.reference_forward_velocity) - ) + obs, info["x_velocity"] - self.reference_forward_velocity + ).astype(np.float32) + if not self._exclude_x_velocity_from_observation: + obs = np.append(obs, info["x_velocity"]).astype(np.float32) # Update info. del info["reward_run"], info["reward_ctrl"] @@ -240,20 +260,20 @@ def reset(self, seed=None, options=None): obs, info = super().reset(seed=seed, options=options) # Randomize the reference forward velocity if requested. - if self.randomize_reference_forward_velocity: + if self.randomise_reference_forward_velocity: self.reference_forward_velocity = self.np_random.uniform( - *self.randomize_reference_forward_velocity_range + *self.randomise_reference_forward_velocity_range ) # Add reference, x velocity and reference error to observation. - obs = np.append(obs, np.float32(self.reference_forward_velocity)) - if not self._exclude_x_velocity_from_observation: - obs = np.append(obs, np.float32(0.0)) + if not self._exclude_reference_from_observation: + obs = np.append(obs, self.reference_forward_velocity).astype(np.float32) if not self._exclude_reference_error_from_observation: - obs = np.append( - obs, - np.float32(0.0 - self.reference_forward_velocity), + obs = np.append(obs, 0.0 - self.reference_forward_velocity).astype( + np.float32 ) + if not self._exclude_x_velocity_from_observation: + obs = np.append(obs, 0.0).astype(np.float32) self.state = obs diff --git a/stable_gym/envs/mujoco/hopper_cost/README.md b/stable_gym/envs/mujoco/hopper_cost/README.md index 537e9811..93a6fa81 100644 --- a/stable_gym/envs/mujoco/hopper_cost/README.md +++ b/stable_gym/envs/mujoco/hopper_cost/README.md @@ -7,8 +7,8 @@ An actuated 3-jointed hopper. This environment corresponds to the [Hopper-v4](https://gymnasium.farama.org/environments/mujoco/hopper) environment included in the [gymnasium package](https://gymnasium.farama.org/). It is different in the fact that: -* The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared - difference between the Hopper's forward velocity and a reference value (error). +* The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the Hopper's forward velocity and a reference value (error). +* Three **optional** variables were added to the observation space; The reference velocity, the reference error (i.e. the difference between the hopper's forward velocity and the reference) and the hopper's forward velocity. These variables can be enabled using the `exclude_reference_from_observation`, `exclude_reference_error_from_observation` and `exclude_velocity_from_observation` environment arguments. The rest of the environment is the same as the original Hopper environment. Below, the modified cost is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [gymnasium library](https://gymnasium.farama.org/environments/mujoco/hopper/). diff --git a/stable_gym/envs/mujoco/hopper_cost/__init__.py b/stable_gym/envs/mujoco/hopper_cost/__init__.py index 369b0e10..2b537d11 100644 --- a/stable_gym/envs/mujoco/hopper_cost/__init__.py +++ b/stable_gym/envs/mujoco/hopper_cost/__init__.py @@ -4,6 +4,10 @@ - The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the Hopper's forward velocity and a reference value (error). +- Three **optional** variables were added to the observation space; The reference velocity, the reference error + (i.e. the difference between the hopper's forward velocity and the reference) and the hopper's forward velocity. + These variables can be enabled using the ``exclude_reference_from_observation``, + ``exclude_reference_error_from_observation`` and ``exclude_velocity_from_observation`` environment arguments. .. _`Han et al. 2020`: https://arxiv.org/abs/2004.14288 """ # noqa: E501 diff --git a/stable_gym/envs/mujoco/hopper_cost/hopper_cost.py b/stable_gym/envs/mujoco/hopper_cost/hopper_cost.py index df51619e..285c6942 100644 --- a/stable_gym/envs/mujoco/hopper_cost/hopper_cost.py +++ b/stable_gym/envs/mujoco/hopper_cost/hopper_cost.py @@ -29,6 +29,10 @@ class HopperCost(HopperEnv, utils.EzPickle): is replaced with a cost. This cost is the squared difference between the Hopper's forward velocity and a reference value (error). Additionally, also a control cost and health penalty can be included in the cost. + - Three **optional** variables were added to the observation space; The reference velocity, the reference error + (i.e. the difference between the hopper's forward velocity and the reference) and the hopper's forward velocity. + These variables can be enabled using the ``exclude_reference_from_observation``, + ``exclude_reference_error_from_observation`` and ``exclude_velocity_from_observation`` environment arguments. The rest of the environment is the same as the original Hopper environment. Below, the modified cost is described. For more information about the environment @@ -61,6 +65,8 @@ class HopperCost(HopperEnv, utils.EzPickle): def __init__( self, reference_forward_velocity=1.0, + randomise_reference_forward_velocity=False, + randomise_reference_forward_velocity_range=(0.5, 1.5), forward_velocity_weight=1.0, include_ctrl_cost=False, include_health_penalty=True, @@ -72,6 +78,9 @@ def __init__( healthy_angle_range=(-0.2, 0.2), reset_noise_scale=5e-3, exclude_current_positions_from_observation=True, + exclude_reference_from_observation=False, + exclude_reference_error_from_observation=True, + exclude_x_velocity_from_observation=False, **kwargs, ): """Initialise a new HopperCost environment instance. @@ -79,6 +88,10 @@ def __init__( Args: reference_forward_velocity (float, optional): The forward velocity that the agent should try to track. Defaults to ``1.0``. + randomise_reference_forward_velocity (bool, optional): Whether to randomize + the reference forward velocity. Defaults to ``False``. + randomise_reference_forward_velocity_range (tuple, optional): The range of + the random reference forward velocity. Defaults to ``(0.5, 1.5)``. forward_velocity_weight (float, optional): The weight used to scale the forward velocity error. Defaults to ``1.0``. include_ctrl_cost (bool, optional): Whether you also want to penalize the @@ -103,14 +116,40 @@ def __init__( the x- and y-coordinates of the front tip from observations. Excluding the position can serve as an inductive bias to induce position-agnostic behaviour in policies. Defaults to ``True``. + exclude_reference_from_observation (bool, optional): Whether the reference + should be excluded from the observation. Defaults to ``False``. Can only + be set to ``True`` if ``randomise_reference_forward_velocity`` is set to + ``False``. + exclude_reference_error_from_observation (bool, optional): Whether the error + should be excluded from the observation. Defaults to ``True``. + exclude_x_velocity_from_observation (bool, optional): Whether to omit the + x- component of the velocity from observations. Defaults to ``False``. **kwargs: Extra keyword arguments to pass to the :class:`~gymnasium.envs.mujoco.hopper_v4.HopperEnv` class. """ self.reference_forward_velocity = reference_forward_velocity + self.randomise_reference_forward_velocity = randomise_reference_forward_velocity + self.randomise_reference_forward_velocity_range = ( + randomise_reference_forward_velocity_range + ) self._forward_velocity_weight = forward_velocity_weight self._include_ctrl_cost = include_ctrl_cost self._include_health_penalty = include_health_penalty self._health_penalty_size = health_penalty_size + self._exclude_reference_from_observation = exclude_reference_from_observation + self._exclude_reference_error_from_observation = ( + exclude_reference_error_from_observation + ) + self._exclude_x_velocity_from_observation = exclude_x_velocity_from_observation + + # Validate input arguments. + assert ( + not randomise_reference_forward_velocity + or not exclude_reference_from_observation + ), ( + "The reference can only be excluded from the observation if the forward " + "velocity is not randomised." + ) self.state = None @@ -126,12 +165,28 @@ def __init__( **kwargs, ) + # Extend observation space if necessary. + low = self.observation_space.low + high = self.observation_space.high + if not self._exclude_reference_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) + if not self._exclude_reference_error_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) + if not self._exclude_x_velocity_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) + self.observation_space = gym.spaces.Box(low, high, dtype=np.float32) + # Reinitialize the EzPickle class. # NOTE: Done to ensure the args of the HopperCost class are also pickled. # NOTE: Ensure that all args are passed to the EzPickle class! utils.EzPickle.__init__( self, reference_forward_velocity, + randomise_reference_forward_velocity, + randomise_reference_forward_velocity_range, forward_velocity_weight, include_ctrl_cost, include_health_penalty, @@ -143,6 +198,9 @@ def __init__( healthy_angle_range, reset_noise_scale, exclude_current_positions_from_observation, + exclude_reference_from_observation, + exclude_reference_error_from_observation, + exclude_x_velocity_from_observation, **kwargs, ) @@ -208,6 +266,16 @@ def step(self, action): ctrl_cost = super().control_cost(action) cost, cost_info = self.cost(info["x_velocity"], ctrl_cost) + # Add reference, x velocity and reference error to observation. + if not self._exclude_reference_from_observation: + obs = np.append(obs, self.reference_forward_velocity).astype(np.float32) + if not self._exclude_reference_error_from_observation: + obs = np.append( + obs, info["x_velocity"] - self.reference_forward_velocity + ).astype(np.float32) + if not self._exclude_x_velocity_from_observation: + obs = np.append(obs, info["x_velocity"]).astype(np.float32) + # Update info. info.update(cost_info) @@ -231,6 +299,22 @@ def reset(self, seed=None, options=None): """ obs, info = super().reset(seed=seed, options=options) + # Randomize the reference forward velocity if requested. + if self.randomise_reference_forward_velocity: + self.reference_forward_velocity = self.np_random.uniform( + *self.randomise_reference_forward_velocity_range + ) + + # Add reference, x velocity and reference error to observation. + if not self._exclude_reference_from_observation: + obs = np.append(obs, self.reference_forward_velocity).astype(np.float32) + if not self._exclude_reference_error_from_observation: + obs = np.append(obs, 0.0 - self.reference_forward_velocity).astype( + np.float32 + ) + if not self._exclude_x_velocity_from_observation: + obs = np.append(obs, 0.0).astype(np.float32) + self.state = obs return obs, info diff --git a/stable_gym/envs/mujoco/humanoid_cost/README.md b/stable_gym/envs/mujoco/humanoid_cost/README.md index aaea0c15..0e593171 100644 --- a/stable_gym/envs/mujoco/humanoid_cost/README.md +++ b/stable_gym/envs/mujoco/humanoid_cost/README.md @@ -9,6 +9,7 @@ An actuated 17-jointed humanoid. This environment corresponds to the [Humanoid-v * The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the Humanoid's forward velocity and a reference value (error). +* Three **optional** variables were added to the observation space; The reference velocity, the reference error (i.e. the difference between the humanoid's forward velocity and the reference) and the humanoid's forward velocity. These variables can be enabled using the `exclude_reference_from_observation`, `exclude_reference_error_from_observation` and `exclude_velocity_from_observation` environment arguments. The rest of the environment is the same as the original Humanoid environment. Below, the modified cost is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [gymnasium library](https://gymnasium.farama.org/environments/mujoco/humanoid/). diff --git a/stable_gym/envs/mujoco/humanoid_cost/__init__.py b/stable_gym/envs/mujoco/humanoid_cost/__init__.py index c19f4c71..1c645dc1 100644 --- a/stable_gym/envs/mujoco/humanoid_cost/__init__.py +++ b/stable_gym/envs/mujoco/humanoid_cost/__init__.py @@ -4,6 +4,10 @@ - The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the Humanoid's forward velocity and a reference value (error). +- Three **optional** variables were added to the observation space; The reference velocity, the reference error + (i.e. the difference between the humanoid's forward velocity and the reference) and the humanoid's forward velocity. + These variables can be enabled using the ``exclude_reference_from_observation``, + ``exclude_reference_error_from_observation`` and ``exclude_velocity_from_observation`` environment arguments. .. _`Han et al. 2020`: https://arxiv.org/abs/2004.14288 """ # noqa: E501 diff --git a/stable_gym/envs/mujoco/humanoid_cost/humanoid_cost.py b/stable_gym/envs/mujoco/humanoid_cost/humanoid_cost.py index a010cae3..85e1968a 100644 --- a/stable_gym/envs/mujoco/humanoid_cost/humanoid_cost.py +++ b/stable_gym/envs/mujoco/humanoid_cost/humanoid_cost.py @@ -30,6 +30,10 @@ class HumanoidCost(HumanoidEnv, utils.EzPickle): is replaced with a cost. This cost is the squared difference between the Humanoid's forward velocity and a reference value (error). Additionally, also a control cost and health penalty can be included in the cost. + - Three **optional** variables were added to the observation space; The reference velocity, the reference error + (i.e. the difference between the humanoid's forward velocity and the reference) and the humanoid's forward velocity. + These variables can be enabled using the ``exclude_reference_from_observation``, + ``exclude_reference_error_from_observation`` and ``exclude_velocity_from_observation`` environment arguments. The rest of the environment is the same as the original Humanoid environment. Below, the modified cost is described. For more information about the environment @@ -62,6 +66,8 @@ class HumanoidCost(HumanoidEnv, utils.EzPickle): def __init__( self, reference_forward_velocity=1.0, + randomise_reference_forward_velocity=False, + randomise_reference_forward_velocity_range=(0.5, 1.5), forward_velocity_weight=1.0, include_ctrl_cost=False, include_health_penalty=True, @@ -71,6 +77,9 @@ def __init__( healthy_z_range=(1.0, 2.0), reset_noise_scale=1e-2, exclude_current_positions_from_observation=True, + exclude_reference_from_observation=False, + exclude_reference_error_from_observation=True, + exclude_x_velocity_from_observation=False, **kwargs, ): """Initialise a new HumanoidCost environment instance. @@ -78,6 +87,10 @@ def __init__( Args: reference_forward_velocity (float, optional): The forward velocity that the agent should try to track. Defaults to ``1.0``. + randomise_reference_forward_velocity (bool, optional): Whether to randomize + the reference forward velocity. Defaults to ``False``. + randomise_reference_forward_velocity_range (tuple, optional): The range of + the random reference forward velocity. Defaults to ``(0.5, 1.5)``. forward_velocity_weight (float, optional): The weight used to scale the forward velocity error. Defaults to ``1.0``. include_ctrl_cost (bool, optional): Whether you also want to penalize the @@ -98,14 +111,40 @@ def __init__( the x- and y-coordinates of the front tip from observations. Excluding the position can serve as an inductive bias to induce position-agnostic behaviour in policies. Defaults to ``True``. + exclude_reference_from_observation (bool, optional): Whether the reference + should be excluded from the observation. Defaults to ``False``. Can only + be set to ``True`` if ``randomise_reference_forward_velocity`` is set to + ``False``. + exclude_reference_error_from_observation (bool, optional): Whether the error + should be excluded from the observation. Defaults to ``True``. + exclude_x_velocity_from_observation (bool, optional): Whether to omit the + x- component of the velocity from observations. Defaults to ``False``. **kwargs: Extra keyword arguments to pass to the :class:`~gymnasium.envs.mujoco.humanoid_v4.HumanoidEnv` class. """ self.reference_forward_velocity = reference_forward_velocity + self.randomise_reference_forward_velocity = randomise_reference_forward_velocity + self.randomise_reference_forward_velocity_range = ( + randomise_reference_forward_velocity_range + ) self._forward_velocity_weight = forward_velocity_weight self._include_ctrl_cost = include_ctrl_cost self._include_health_penalty = include_health_penalty self._health_penalty_size = health_penalty_size + self._exclude_reference_from_observation = exclude_reference_from_observation + self._exclude_reference_error_from_observation = ( + exclude_reference_error_from_observation + ) + self._exclude_x_velocity_from_observation = exclude_x_velocity_from_observation + + # Validate input arguments. + assert ( + not randomise_reference_forward_velocity + or not exclude_reference_from_observation + ), ( + "The reference can only be excluded from the observation if the forward " + "velocity is not randomised." + ) self.state = None @@ -119,12 +158,28 @@ def __init__( **kwargs, ) + # Extend observation space if necessary. + low = self.observation_space.low + high = self.observation_space.high + if not self._exclude_reference_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) + if not self._exclude_reference_error_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) + if not self._exclude_x_velocity_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) + self.observation_space = gym.spaces.Box(low, high, dtype=np.float32) + # Reinitialize the EzPickle class. # NOTE: Done to ensure the args of the HumanoidCost class are also pickled. # NOTE: Ensure that all args are passed to the EzPickle class! utils.EzPickle.__init__( self, reference_forward_velocity, + randomise_reference_forward_velocity, + randomise_reference_forward_velocity_range, forward_velocity_weight, include_ctrl_cost, include_health_penalty, @@ -134,6 +189,9 @@ def __init__( healthy_z_range, reset_noise_scale, exclude_current_positions_from_observation, + exclude_reference_from_observation, + exclude_reference_error_from_observation, + exclude_x_velocity_from_observation, **kwargs, ) @@ -198,6 +256,16 @@ def step(self, action): cost, cost_info = self.cost(info["x_velocity"], -info["reward_quadctrl"]) + # Add reference, x velocity and reference error to observation. + if not self._exclude_reference_from_observation: + obs = np.append(obs, self.reference_forward_velocity).astype(np.float32) + if not self._exclude_reference_error_from_observation: + obs = np.append( + obs, info["x_velocity"] - self.reference_forward_velocity + ).astype(np.float32) + if not self._exclude_x_velocity_from_observation: + obs = np.append(obs, info["x_velocity"]).astype(np.float32) + # Update info. del ( info["reward_linvel"], @@ -227,6 +295,22 @@ def reset(self, seed=None, options=None): """ obs, info = super().reset(seed=seed, options=options) + # Randomize the reference forward velocity if requested. + if self.randomise_reference_forward_velocity: + self.reference_forward_velocity = self.np_random.uniform( + *self.randomise_reference_forward_velocity_range + ) + + # Add reference, x velocity and reference error to observation. + if not self._exclude_reference_from_observation: + obs = np.append(obs, self.reference_forward_velocity).astype(np.float32) + if not self._exclude_reference_error_from_observation: + obs = np.append(obs, 0.0 - self.reference_forward_velocity).astype( + np.float32 + ) + if not self._exclude_x_velocity_from_observation: + obs = np.append(obs, 0.0).astype(np.float32) + self.state = obs return obs, info diff --git a/stable_gym/envs/mujoco/swimmer_cost/README.md b/stable_gym/envs/mujoco/swimmer_cost/README.md index e7bc07c3..5f5e5389 100644 --- a/stable_gym/envs/mujoco/swimmer_cost/README.md +++ b/stable_gym/envs/mujoco/swimmer_cost/README.md @@ -8,6 +8,7 @@ An actuated 2-jointed swimmer. This environment corresponds to the [Swimmer-v4](https://gymnasium.farama.org/environments/mujoco/swimmer) environment included in the [gymnasium package](https://gymnasium.farama.org/). It is different in the fact that: * The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the Swimmer's forward velocity and a reference value (error). +- Three **optional** variables were added to the observation space; The reference velocity, the reference error (i.e. the difference between the swimmer's forward velocity and the reference) and the swimmer's forward velocity. These variables can be enabled using the `exclude_reference_from_observation`, `exclude_reference_error_from_observation` and `exclude_velocity_from_observation` environment arguments. The rest of the environment is the same as the original Swimmer environment. Below, the modified cost is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [gymnasium library](https://gymnasium.farama.org/environments/mujoco/swimmer/). diff --git a/stable_gym/envs/mujoco/swimmer_cost/__init__.py b/stable_gym/envs/mujoco/swimmer_cost/__init__.py index f322eac5..13bab385 100644 --- a/stable_gym/envs/mujoco/swimmer_cost/__init__.py +++ b/stable_gym/envs/mujoco/swimmer_cost/__init__.py @@ -4,6 +4,10 @@ - The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the swimmer's forward velocity and a reference value (error). +- Three **optional** variables were added to the observation space; The reference velocity, the reference error + (i.e. the difference between the swimmer's forward velocity and the reference) and the swimmer's forward velocity. + These variables can be enabled using the ``exclude_reference_from_observation``, + ``exclude_reference_error_from_observation`` and ``exclude_velocity_from_observation`` environment arguments. .. _`Han et al. 2020`: https://arxiv.org/abs/2004.14288 """ # noqa: E501 diff --git a/stable_gym/envs/mujoco/swimmer_cost/swimmer_cost.py b/stable_gym/envs/mujoco/swimmer_cost/swimmer_cost.py index f6fbb4aa..f6af98c8 100644 --- a/stable_gym/envs/mujoco/swimmer_cost/swimmer_cost.py +++ b/stable_gym/envs/mujoco/swimmer_cost/swimmer_cost.py @@ -29,6 +29,10 @@ class SwimmerCost(SwimmerEnv, utils.EzPickle): is replaced with a cost. This cost is the squared difference between the swimmer's forward velocity and a reference value (error). Additionally, also a control cost can be included in the cost. + - Three **optional** variables were added to the observation space; The reference velocity, the reference error + (i.e. the difference between the swimmer's forward velocity and the reference) and the swimmer's forward velocity. + These variables can be enabled using the ``exclude_reference_from_observation``, + ``exclude_reference_error_from_observation`` and ``exclude_velocity_from_observation`` environment arguments. The rest of the environment is the same as the original Swimmer environment. Below, the modified cost is described. For more information about the environment @@ -61,11 +65,16 @@ class SwimmerCost(SwimmerEnv, utils.EzPickle): def __init__( self, reference_forward_velocity=1.0, + randomise_reference_forward_velocity=False, + randomise_reference_forward_velocity_range=(0.5, 1.5), forward_velocity_weight=1.0, include_ctrl_cost=False, ctrl_cost_weight=1e-4, reset_noise_scale=0.1, exclude_current_positions_from_observation=True, + exclude_reference_from_observation=False, # NOTE: True in Han et al. 2018. # noqa: E501 + exclude_reference_error_from_observation=True, + exclude_x_velocity_from_observation=False, # NOTE: True in Han et al. 2018. # noqa: E501 **kwargs, ): """Initialise a new SwimmerCost environment instance. @@ -73,6 +82,10 @@ def __init__( Args: reference_forward_velocity (float, optional): The forward velocity that the agent should try to track. Defaults to ``1.0``. + randomise_reference_forward_velocity (bool, optional): Whether to randomize + the reference forward velocity. Defaults to ``False``. + randomise_reference_forward_velocity_range (tuple, optional): The range of + the random reference forward velocity. Defaults to ``(0.5, 1.5)``. forward_velocity_weight (float, optional): The weight used to scale the forward velocity error. Defaults to ``1.0``. include_ctrl_cost (bool, optional): Whether you also want to penalize the @@ -85,12 +98,38 @@ def __init__( the x- and y-coordinates of the front tip from observations. Excluding the position can serve as an inductive bias to induce position-agnostic behaviour in policies. Defaults to ``True``. + exclude_reference_from_observation (bool, optional): Whether the reference + should be excluded from the observation. Defaults to ``False``. Can only + be set to ``True`` if ``randomise_reference_forward_velocity`` is set to + ``False``. + exclude_reference_error_from_observation (bool, optional): Whether the error + should be excluded from the observation. Defaults to ``True``. + exclude_x_velocity_from_observation (bool, optional): Whether to omit the + x- component of the velocity from observations. Defaults to ``False``. **kwargs: Extra keyword arguments to pass to the :class:`~gymnasium.envs.mujoco.swimmer_v4.SwimmerEnv` class. """ self.reference_forward_velocity = reference_forward_velocity + self.randomise_reference_forward_velocity = randomise_reference_forward_velocity + self.randomise_reference_forward_velocity_range = ( + randomise_reference_forward_velocity_range + ) self._forward_velocity_weight = forward_velocity_weight self._include_ctrl_cost = include_ctrl_cost + self._exclude_reference_from_observation = exclude_reference_from_observation + self._exclude_reference_error_from_observation = ( + exclude_reference_error_from_observation + ) + self._exclude_x_velocity_from_observation = exclude_x_velocity_from_observation + + # Validate input arguments. + assert ( + not randomise_reference_forward_velocity + or not exclude_reference_from_observation + ), ( + "The reference can only be excluded from the observation if the forward " + "velocity is not randomised." + ) self.state = None @@ -102,17 +141,36 @@ def __init__( **kwargs, ) + # Extend observation space if necessary. + low = self.observation_space.low + high = self.observation_space.high + if not self._exclude_reference_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) + if not self._exclude_reference_error_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) + if not self._exclude_x_velocity_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) + self.observation_space = gym.spaces.Box(low, high, dtype=np.float32) + # Reinitialize the EzPickle class. # NOTE: Done to ensure the args of the SwimmerCost class are also pickled. # NOTE: Ensure that all args are passed to the EzPickle class! utils.EzPickle.__init__( self, reference_forward_velocity, + randomise_reference_forward_velocity, + randomise_reference_forward_velocity_range, forward_velocity_weight, include_ctrl_cost, ctrl_cost_weight, reset_noise_scale, exclude_current_positions_from_observation, + exclude_reference_from_observation, + exclude_reference_error_from_observation, + exclude_x_velocity_from_observation, **kwargs, ) @@ -165,6 +223,16 @@ def step(self, action): cost, cost_info = self.cost(info["x_velocity"], -info["reward_ctrl"]) + # Add reference, x velocity and reference error to observation. + if not self._exclude_reference_from_observation: + obs = np.append(obs, self.reference_forward_velocity).astype(np.float32) + if not self._exclude_reference_error_from_observation: + obs = np.append( + obs, info["x_velocity"] - self.reference_forward_velocity + ).astype(np.float32) + if not self._exclude_x_velocity_from_observation: + obs = np.append(obs, info["x_velocity"]).astype(np.float32) + # Update info. del info["reward_fwd"], info["reward_ctrl"], info["forward_reward"] info.update(cost_info) @@ -189,6 +257,22 @@ def reset(self, seed=None, options=None): """ obs, info = super().reset(seed=seed, options=options) + # Randomize the reference forward velocity if requested. + if self.randomise_reference_forward_velocity: + self.reference_forward_velocity = self.np_random.uniform( + *self.randomise_reference_forward_velocity_range + ) + + # Add reference, x velocity and reference error to observation. + if not self._exclude_reference_from_observation: + obs = np.append(obs, self.reference_forward_velocity).astype(np.float32) + if not self._exclude_reference_error_from_observation: + obs = np.append(obs, 0.0 - self.reference_forward_velocity).astype( + np.float32 + ) + if not self._exclude_x_velocity_from_observation: + obs = np.append(obs, 0.0).astype(np.float32) + self.state = obs return obs, info diff --git a/stable_gym/envs/mujoco/walker2d_cost/README.md b/stable_gym/envs/mujoco/walker2d_cost/README.md index 78bafb95..3a6e03c6 100644 --- a/stable_gym/envs/mujoco/walker2d_cost/README.md +++ b/stable_gym/envs/mujoco/walker2d_cost/README.md @@ -9,6 +9,8 @@ An actuated 8-jointed 2D walker. This environment corresponds to the [Walker2d-v * The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the Walker2d's forward velocity and a reference value (error). +* The reference velocity was added to the observation space. +* Three **optional** variables were added to the observation space; The reference velocity, the reference error (i.e. the difference between the walker2d's forward velocity and the reference) and the walker2d's forward velocity. These variables can be enabled using the `exclude_reference_from_observation`, `exclude_reference_error_from_observation` and `exclude_velocity_from_observation` environment arguments. The rest of the environment is the same as the original Walker2d environment. Below, the modified cost is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [gymnasium library](https://gymnasium.farama.org/environments/mujoco/walker2d/). diff --git a/stable_gym/envs/mujoco/walker2d_cost/__init__.py b/stable_gym/envs/mujoco/walker2d_cost/__init__.py index d6001c68..d254fc81 100644 --- a/stable_gym/envs/mujoco/walker2d_cost/__init__.py +++ b/stable_gym/envs/mujoco/walker2d_cost/__init__.py @@ -4,6 +4,11 @@ - The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the Walker2d's forward velocity and a reference value (error). +- The reference velocity was added to the observation space. +- Three **optional** variables were added to the observation space; The reference velocity, the reference error + (i.e. the difference between the walker2d's forward velocity and the reference) and the walker2d's forward velocity. + These variables can be enabled using the ``exclude_reference_from_observation``, + ``exclude_reference_error_from_observation`` and ``exclude_velocity_from_observation`` environment arguments. .. _`Han et al. 2020`: https://arxiv.org/abs/2004.14288 """ # noqa: E501 diff --git a/stable_gym/envs/mujoco/walker2d_cost/walker2d_cost.py b/stable_gym/envs/mujoco/walker2d_cost/walker2d_cost.py index 602bd6c3..ff6dcbb9 100644 --- a/stable_gym/envs/mujoco/walker2d_cost/walker2d_cost.py +++ b/stable_gym/envs/mujoco/walker2d_cost/walker2d_cost.py @@ -29,6 +29,10 @@ class Walker2dCost(Walker2dEnv, utils.EzPickle): is replaced with a cost. This cost is the squared difference between the Walker2d's forward velocity and a reference value (error). Additionally, also a control cost and health penalty can be included in the cost. + - Three **optional** variables were added to the observation space; The reference velocity, the reference error + (i.e. the difference between the walker2d's forward velocity and the reference) and the walker2d's forward velocity. + These variables can be enabled using the ``exclude_reference_from_observation``, + ``exclude_reference_error_from_observation`` and ``exclude_velocity_from_observation`` environment arguments. The rest of the environment is the same as the original Walker2d environment. Below, the modified cost is described. For more information about the environment @@ -61,6 +65,8 @@ class Walker2dCost(Walker2dEnv, utils.EzPickle): def __init__( self, reference_forward_velocity=1.0, + randomise_reference_forward_velocity=False, + randomise_reference_forward_velocity_range=(0.5, 1.5), forward_velocity_weight=1.0, include_ctrl_cost=False, include_health_penalty=True, @@ -71,6 +77,9 @@ def __init__( healthy_angle_range=(-1.0, 1.0), reset_noise_scale=5e-3, exclude_current_positions_from_observation=True, + exclude_reference_from_observation=False, + exclude_reference_error_from_observation=True, + exclude_x_velocity_from_observation=False, **kwargs, ): """Initialise a new Walker2dCost environment instance. @@ -78,6 +87,10 @@ def __init__( Args: reference_forward_velocity (float, optional): The forward velocity that the agent should try to track. Defaults to ``1.0``. + randomise_reference_forward_velocity (bool, optional): Whether to randomize + the reference forward velocity. Defaults to ``False``. + randomise_reference_forward_velocity_range (tuple, optional): The range of + the random reference forward velocity. Defaults to ``(0.5, 1.5)``. forward_velocity_weight (float, optional): The weight used to scale the forward velocity error. Defaults to ``1.0``. include_ctrl_cost (bool, optional): Whether you also want to penalize the @@ -100,14 +113,40 @@ def __init__( the x- and y-coordinates of the front tip from observations. Excluding the position can serve as an inductive bias to induce position-agnostic behaviour in policies. Defaults to ``True``. + exclude_reference_from_observation (bool, optional): Whether the reference + should be excluded from the observation. Defaults to ``False``. Can only + be set to ``True`` if ``randomise_reference_forward_velocity`` is set to + ``False``. + exclude_reference_error_from_observation (bool, optional): Whether the error + should be excluded from the observation. Defaults to ``True``. + exclude_x_velocity_from_observation (bool, optional): Whether to omit the + x- component of the velocity from observations. Defaults to ``False``. **kwargs: Extra keyword arguments to pass to the :class:`~gymnasium.envs.mujoco.walker2d_v4.Walker2dEnv` class. """ self.reference_forward_velocity = reference_forward_velocity + self.randomise_reference_forward_velocity = randomise_reference_forward_velocity + self.randomise_reference_forward_velocity_range = ( + randomise_reference_forward_velocity_range + ) self._forward_velocity_weight = forward_velocity_weight self._include_ctrl_cost = include_ctrl_cost self._include_health_penalty = include_health_penalty self._health_penalty_size = health_penalty_size + self._exclude_reference_from_observation = exclude_reference_from_observation + self._exclude_reference_error_from_observation = ( + exclude_reference_error_from_observation + ) + self._exclude_x_velocity_from_observation = exclude_x_velocity_from_observation + + # Validate input arguments. + assert ( + not randomise_reference_forward_velocity + or not exclude_reference_from_observation + ), ( + "The reference can only be excluded from the observation if the forward " + "velocity is not randomised." + ) self.state = None @@ -122,12 +161,28 @@ def __init__( **kwargs, ) + # Extend observation space if necessary. + low = self.observation_space.low + high = self.observation_space.high + if not self._exclude_reference_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) + if not self._exclude_reference_error_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) + if not self._exclude_x_velocity_from_observation: + low = np.append(low, -np.inf) + high = np.append(high, np.inf) + self.observation_space = gym.spaces.Box(low, high, dtype=np.float32) + # Reinitialize the EzPickle class. # NOTE: Done to ensure the args of the Walker2dCost class are also pickled. # NOTE: Ensure that all args are passed to the EzPickle class! utils.EzPickle.__init__( self, reference_forward_velocity, + randomise_reference_forward_velocity, + randomise_reference_forward_velocity_range, forward_velocity_weight, include_ctrl_cost, include_health_penalty, @@ -138,6 +193,9 @@ def __init__( healthy_angle_range, reset_noise_scale, exclude_current_positions_from_observation, + exclude_reference_from_observation, + exclude_reference_error_from_observation, + exclude_x_velocity_from_observation, **kwargs, ) @@ -203,6 +261,16 @@ def step(self, action): ctrl_cost = super().control_cost(action) cost, cost_info = self.cost(info["x_velocity"], ctrl_cost) + # Add reference, x velocity and reference error to observation. + if not self._exclude_reference_from_observation: + obs = np.append(obs, self.reference_forward_velocity).astype(np.float32) + if not self._exclude_reference_error_from_observation: + obs = np.append( + obs, info["x_velocity"] - self.reference_forward_velocity + ).astype(np.float32) + if not self._exclude_x_velocity_from_observation: + obs = np.append(obs, info["x_velocity"]).astype(np.float32) + # Update info. info.update(cost_info) @@ -226,6 +294,22 @@ def reset(self, seed=None, options=None): """ obs, info = super().reset(seed=seed, options=options) + # Randomize the reference forward velocity if requested. + if self.randomise_reference_forward_velocity: + self.reference_forward_velocity = self.np_random.uniform( + *self.randomise_reference_forward_velocity_range + ) + + # Add reference, x velocity and reference error to observation. + if not self._exclude_reference_from_observation: + obs = np.append(obs, self.reference_forward_velocity).astype(np.float32) + if not self._exclude_reference_error_from_observation: + obs = np.append(obs, 0.0 - self.reference_forward_velocity).astype( + np.float32 + ) + if not self._exclude_x_velocity_from_observation: + obs = np.append(obs, 0.0).astype(np.float32) + self.state = obs return obs, info