Skip to content

Commit

Permalink
feat: add velocity randomize and exclude reference arguments (#215)
Browse files Browse the repository at this point in the history
* feat: add velocity randomize and exclude reference arguments

This commit adds the ability to randomize the velocity reference in the
Mujoco environment. It also adds the `exclude_reference_from_observation`
argument which can be used to remove the reference from the observation.

* refactor: fix several environment bugs
  • Loading branch information
rickstaa authored Jul 16, 2023
1 parent ca7685c commit c32b4ee
Show file tree
Hide file tree
Showing 24 changed files with 574 additions and 126 deletions.
7 changes: 5 additions & 2 deletions stable_gym/envs/biological/oscillator/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@ By default, the environment returns the following observation:
* $p_1$ - The lacI (repressor) protein concentration (Inhibits transcription of tetR gene).
* $p_2$ - The tetR (repressor) protein concentration (Inhibits transcription of CI gene).
* $p_3$ - The CI (repressor) protein concentration (Inhibits transcription of lacI gene).
* $r$ - The reference we want to follow.

An extra variable will be returned if the `exclude_reference_error_from_observation` flag is set to `False`:
Extra variables will be returned if the `exclude_reference_from_observation` and `exclude_reference_error_from_observation`
flags are set to `False`:

* $r$ - The reference we want to follow.
* $r_{error}$ - The error between the state of interest (i.e. $p_1$) and the reference.

The reference cannot be excluded when `reference_type` is set to `periodic`.

## Action space

* $u_1$ - Relative intensity of the light signal that induces the lacI mRNA gene expression.
Expand Down
44 changes: 32 additions & 12 deletions stable_gym/envs/biological/oscillator/oscillator.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def __init__(
reference_frequency=200,
reference_constraint_position=20.0,
clip_action=True,
exclude_reference_error_from_observation=True, # NOTE: False in Han et al. 2018. # noqa: E501
exclude_reference_from_observation=False,
exclude_reference_error_from_observation=False,
):
"""Initialise a new Oscillator environment instance.
Expand All @@ -147,12 +148,16 @@ def __init__(
dict.
clip_action (str, optional): Whether the actions should be clipped if
they are greater than the set action limit. Defaults to ``True``.
exclude_reference_from_observation (bool, optional): Whether the reference
should be excluded from the observation. Defaults to ``False``. Can only
be set to ``True`` if ``reference_type`` is ``constant``.
exclude_reference_error_from_observation (bool, optional): Whether the error
should be excluded from the observation. Defaults to ``True``.
should be excluded from the observation. Defaults to ``False``.
"""
super().__init__() # Setup disturber.
self._action_clip_warning = False
self._clip_action = clip_action
self._exclude_reference_from_observation = exclude_reference_from_observation
self._exclude_reference_error_from_observation = (
exclude_reference_error_from_observation
)
Expand All @@ -162,6 +167,14 @@ def __init__(
raise ValueError(
"The reference type must be either 'constant' or 'periodic'."
)
assert (
reference_type.lower() == "periodic"
or reference_type.lower() == "constant"
and not exclude_reference_from_observation
), (
"The reference can only be excluded from the observation if the reference "
"type is constant."
)

self.reference_type = reference_type
self.t = 0.0
Expand Down Expand Up @@ -203,11 +216,14 @@ def __init__(
self.delta5 = 0.0 # p2 noise.
self.delta6 = 0.0 # p3 noise.

obs_low = np.array([0, 0, 0, 0, 0, 0, 0], dtype=np.float32)
obs_high = np.array([100, 100, 100, 100, 100, 100, 100], dtype=np.float32)
obs_low = np.array([0, 0, 0, 0, 0, 0], dtype=np.float32)
obs_high = np.array([100, 100, 100, 100, 100, 100], dtype=np.float32)
if not self._exclude_reference_from_observation:
obs_low = np.append(obs_low, 0.0).astype(np.float32)
obs_high = np.append(obs_high, 100.0).astype(np.float32)
if not self._exclude_reference_error_from_observation:
obs_low = np.append(obs_low, np.float32(-100.0))
obs_high = np.append(obs_high, np.float32(100.0))
obs_low = np.append(obs_low, -100.0).astype(np.float32)
obs_high = np.append(obs_high, 100.0).astype(np.float32)
self.action_space = spaces.Box(
low=np.array([-5.0, -5.0, -5.0], dtype=np.float32),
high=np.array([5.0, 5.0, 5.0], dtype=np.float32),
Expand Down Expand Up @@ -352,9 +368,11 @@ def step(self, action):
terminated = bool(cost > self.cost_range.high or cost < self.cost_range.low)

# Create observation.
obs = np.array([m1, m2, m3, p1, p2, p3, r1], dtype=np.float32)
obs = np.array([m1, m2, m3, p1, p2, p3], dtype=np.float32)
if not self._exclude_reference_from_observation:
obs = np.append(obs, r1).astype(np.float32)
if not self._exclude_reference_error_from_observation:
obs = np.append(obs, np.float32(p1 - r1))
obs = np.append(obs, p1 - r1).astype(np.float32)

# Return state, cost, terminated, truncated and info_dict
return (
Expand Down Expand Up @@ -419,7 +437,7 @@ def reset(
np.append(
low,
np.zeros(
1 if self._exclude_reference_error_from_observation else 2,
self.observation_space.shape[0] - low.shape[0],
dtype=np.float32,
),
)
Expand All @@ -429,7 +447,7 @@ def reset(
np.append(
high,
np.zeros(
1 if self._exclude_reference_error_from_observation else 2,
self.observation_space.shape[0] - low.shape[0],
dtype=np.float32,
),
)
Expand All @@ -448,9 +466,11 @@ def reset(
self.t = 0.0
m1, m2, m3, p1, p2, p3 = self.state
r1 = self.reference(self.t)
obs = np.array([m1, m2, m3, p1, p2, p3, r1], dtype=np.float32)
obs = np.array([m1, m2, m3, p1, p2, p3], dtype=np.float32)
if not self._exclude_reference_from_observation:
obs = np.append(obs, r1).astype(np.float32)
if not self._exclude_reference_error_from_observation:
obs = np.append(obs, np.float32(p1 - r1))
obs = np.append(obs, p1 - r1).astype(np.float32)

# Return initial observation and info_dict.
return obs, dict(
Expand Down
7 changes: 5 additions & 2 deletions stable_gym/envs/biological/oscillator_complicated/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,15 @@ By default, the environment returns the following observation:
* $p_2$ - The tetR (repressor) protein concentration (Inhibits transcription of CI gene).
* $p_3$ - The CI (repressor) protein concentration (Inhibits transcription of extra protein gene).
* $p_4$ - Extra protein concentration (Inhibits transcription of lacI gene).
* $r$ - The reference we want to follow.

An extra variable will be returned if the `exclude_reference_error_from_observation` flag is set to `False`:
Extra variables will be returned if the `exclude_reference_from_observation` and `exclude_reference_error_from_observation`
flags are set to `False`:

* $r$ - The reference we want to follow.
* $r_{error}$ - The error between the state of interest (i.e. $p_1$) and the reference.

The reference cannot be excluded when `reference_type` is set to `periodic`.

## Action space

* $u_1$ - Relative intensity of the light signal that induces the lacI mRNA gene expression.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@ def __init__(
reference_frequency=200,
reference_constraint_position=20.0,
clip_action=True,
exclude_reference_error_from_observation=True, # NOTE: False in Han et al. 2018. # noqa: E501
exclude_reference_from_observation=False,
exclude_reference_error_from_observation=False,
):
"""Initialise a new OscillatorComplicated environment instance.
Expand All @@ -163,12 +164,16 @@ def __init__(
dict.
clip_action (str, optional): Whether the actions should be clipped if
they are greater than the set action limit. Defaults to ``True``.
exclude_reference_from_observation (bool, optional): Whether the reference
should be excluded from the observation. Defaults to ``False``. Can only
be set to ``True`` if ``reference_type`` is ``constant``.
exclude_reference_error_from_observation (bool, optional): Whether the error
should be excluded from the observation. Defaults to ``True``.
should be excluded from the observation. Defaults to ``False``.
"""
super().__init__() # Setup disturber.
self._action_clip_warning = False
self._clip_action = clip_action
self._exclude_reference_from_observation = exclude_reference_from_observation
self._exclude_reference_error_from_observation = (
exclude_reference_error_from_observation
)
Expand All @@ -178,6 +183,14 @@ def __init__(
raise ValueError(
"The reference type must be either 'constant' or 'periodic'."
)
assert (
reference_type.lower() == "periodic"
or reference_type.lower() == "constant"
and not exclude_reference_from_observation
), (
"The reference can only be excluded from the observation if the reference "
"type is constant."
)

self.reference_type = reference_type
self.t = 0.0
Expand Down Expand Up @@ -227,13 +240,14 @@ def __init__(
self.delta7 = 0.0 # p3 noise.
self.delta8 = 0.0 # p4 noise.

obs_low = np.array([0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32)
obs_high = np.array(
[100, 100, 100, 100, 100, 100, 100, 100, 100], dtype=np.float32
)
obs_low = np.array([0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32)
obs_high = np.array([100, 100, 100, 100, 100, 100, 100, 100], dtype=np.float32)
if not self._exclude_reference_from_observation:
obs_low = np.append(obs_low, 0.0).astype(np.float32)
obs_high = np.append(obs_high, 100.0).astype(np.float32)
if not self._exclude_reference_error_from_observation:
obs_low = np.append(obs_low, np.float32(-100))
obs_high = np.append(obs_high, np.float32(100))
obs_low = np.append(obs_low, -100).astype(np.float32)
obs_high = np.append(obs_high, 100).astype(np.float32)
self.action_space = spaces.Box(
low=np.array([-5.0, -5.0, -5.0, -5.0], dtype=np.float32),
high=np.array([5.0, 5.0, 5.0, 5.0], dtype=np.float32),
Expand All @@ -251,7 +265,7 @@ def __init__(
self.steps_beyond_done = None

# Reference target and constraint positions.
self.reference_target_pos = reference_target_position # Reference target.
self.reference_target_pos = reference_target_position
self.reference_amplitude = reference_amplitude
self.reference_frequency = reference_frequency
self.reference_constraint_pos = (
Expand Down Expand Up @@ -407,9 +421,11 @@ def step(self, action):
terminated = bool(cost > self.cost_range.high or cost < self.cost_range.low)

# Create observation.
obs = np.array([m1, m2, m3, m4, p1, p2, p3, p4, r1], dtype=np.float32)
obs = np.array([m1, m2, m3, m4, p1, p2, p3, p4], dtype=np.float32)
if not self._exclude_reference_from_observation:
obs = np.append(obs, r1).astype(np.float32)
if not self._exclude_reference_error_from_observation:
obs = np.append(obs, np.float32(p1 - r1))
obs = np.append(obs, p1 - r1).astype(np.float32)

# Return state, cost, terminated, truncated and info_dict
return (
Expand Down Expand Up @@ -474,7 +490,7 @@ def reset(
np.append(
low,
np.zeros(
1 if self._exclude_reference_error_from_observation else 2,
self.observation_space.shape[0] - low.shape[0],
dtype=np.float32,
),
)
Expand All @@ -484,7 +500,7 @@ def reset(
np.append(
high,
np.zeros(
1 if self._exclude_reference_error_from_observation else 2,
self.observation_space.shape[0] - low.shape[0],
dtype=np.float32,
),
)
Expand All @@ -503,9 +519,11 @@ def reset(
self.t = 0.0
m1, m2, m3, m4, p1, p2, p3, p4 = self.state
r1 = self.reference(self.t)
obs = np.array([m1, m2, m3, m4, p1, p2, p3, p4, r1], dtype=np.float32)
obs = np.array([m1, m2, m3, m4, p1, p2, p3, p4], dtype=np.float32)
if not self._exclude_reference_from_observation:
obs = np.append(obs, r1).astype(np.float32)
if not self._exclude_reference_error_from_observation:
obs = np.append(obs, np.float32(p1 - r1))
obs = np.append(obs, p1 - r1).astype(np.float32)

# Return initial observation and info_dict.
return obs, dict(
Expand Down
7 changes: 5 additions & 2 deletions stable_gym/envs/classic_control/cartpole_cost/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,15 @@ By default, for the reference tracking task the environment returns the followin
* $x_{dot}$ - Cart Velocity.
* $w$ - Pole angle.
* $w_{dot}$ - Pole angle velocity.
* $x_{ref}$ - The cart position reference.

An extra variable will be returned if the `exclude_reference_error_from_observation` flag is set to `False`:
Extra variables will be returned if the `exclude_reference_from_observation` and `exclude_reference_error_from_observation`
flags are set to `False`:

* $x_{ref}$ - The cart position reference.
* $x_{ref\_error}$ - The reference tracking error.

The reference cannot be excluded when `reference_type` is set to `periodic`.

## Action space

* **u1:** The x-force applied on the cart.
Expand Down
25 changes: 18 additions & 7 deletions stable_gym/envs/classic_control/cartpole_cost/cartpole_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ def __init__(
reference_constraint_position=4.0,
max_cost=100.0,
clip_action=True,
exclude_reference_from_observation=False,
exclude_reference_error_from_observation=True,
):
"""Initialise a new CartPoleCost environment instance.
Expand All @@ -188,6 +189,10 @@ def __init__(
terminated. Defaults to ``100.0``.
clip_action (str, optional): Whether the actions should be clipped if
they are greater than the set action limit. Defaults to ``True``.
exclude_reference_from_observation (bool, optional): Whether the reference
should be excluded from the observation. Defaults to ``False``. Can only
be set to ``True`` if ``reference_type`` is ``constant``. Only
used when ``task_type`` is ``reference_tracking``.
exclude_reference_error_from_observation (bool, optional): Whether the error
should be excluded from the observation. Defaults to ``True``. Only
used when ``task_type`` is ``reference_tracking``.
Expand All @@ -204,6 +209,14 @@ def __init__(
raise ValueError(
"Invalid reference type. Options are 'constant' and 'periodic'."
)
assert (
reference_type.lower() == "periodic"
or reference_type.lower() == "constant"
and not exclude_reference_from_observation
), (
"The reference can only be excluded from the observation if the reference "
"type is constant."
)

# NOTE: Compared to the original I store the initial values for the reset
# function and replace the `self.total_mass` and `self.polemass_length` with
Expand Down Expand Up @@ -232,6 +245,7 @@ def __init__(
# Create observation space bounds.
# Angle limit set to 2 * theta_threshold_radians so failing observation
# is still within bounds.
self._exclude_reference_from_observation = exclude_reference_from_observation
self._exclude_reference_error_from_observation = (
exclude_reference_error_from_observation
)
Expand All @@ -246,13 +260,10 @@ def __init__(
)
# NOTE: When reference tracking add two extra observation states.
if task_type.lower() == "reference_tracking":
high = np.append(
high,
np.repeat(
self.x_threshold * 2,
1 if self._exclude_reference_error_from_observation else 2,
),
).astype(np.float32)
if not self._exclude_reference_from_observation:
high = np.append(high, self.x_threshold * 2).astype(np.float32)
if not self._exclude_reference_error_from_observation:
high = np.append(high, self.x_threshold * 2).astype(np.float32)

self.action_space = spaces.Box(
low=-self.force_mag, high=self.force_mag, shape=(1,), dtype=np.float32
Expand Down
3 changes: 1 addition & 2 deletions stable_gym/envs/mujoco/ant_cost/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ An actuated 8-jointed ant. This environment corresponds to the [Ant-v4](https://

* The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared
difference between the Ant's forward velocity and a reference value (error).
* The reference velocity was added to the observation space.
* Two **optional** variables were added to the observation space These are the ant's forward velocity and the error (difference between the ant's forward velocity and the reference velocity). These variables can be enabled using the `exclude_velocity_from_observation` and `exclude_reference_error_from_observation` environment arguments.
* Three **optional** variables were added to the observation space: the reference velocity, the reference error (i.e. the difference between the ant's forward velocity and the reference) and the ant's forward velocity. These variables can be included or excluded using the `exclude_reference_from_observation`, `exclude_reference_error_from_observation` and `exclude_velocity_from_observation` environment arguments.

The rest of the environment is the same as the original Ant environment. Below, the modified cost is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [gymnasium library](https://gymnasium.farama.org/environments/mujoco/ant/).

Expand Down
10 changes: 4 additions & 6 deletions stable_gym/envs/mujoco/ant_cost/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
- The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost.
This cost is the squared difference between the Ant's forward velocity and a reference value (error).
- The reference velocity was added to the observation space.
- Two **optional** variables were added to the observation space. These are
the ant's forward velocity and the error (difference between the ant's
forward velocity and the reference velocity). These variables can be enabled using
the ``exclude_velocity_from_observation`` and ``exclude_reference_error_from_observation``
environment arguments.
- Three **optional** variables were added to the observation space: the reference velocity, the reference error
(i.e. the difference between the ant's forward velocity and the reference) and the ant's forward velocity.
These variables can be included or excluded using the ``exclude_reference_from_observation``, ``exclude_reference_error_from_observation``
and ``exclude_velocity_from_observation`` environment arguments.
.. _`Han et al. 2020`: https://arxiv.org/abs/2004.14288
""" # noqa: E501
Expand Down
Loading

0 comments on commit c32b4ee

Please sign in to comment.