Skip to content

Commit

Permalink
feat(oscillator): simplify oscillator environments (#231)
Browse files Browse the repository at this point in the history
This commit removes the 'reference_type' argument since the same can
already be achieved by setting the 'reference_frequency' to `0`.
  • Loading branch information
rickstaa authored Jul 21, 2023
1 parent 5905366 commit da9b589
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 76 deletions.
6 changes: 1 addition & 5 deletions stable_gym/envs/biological/oscillator/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,10 @@ By default, the environment returns the following observation:
* $p_1$ - The lacI (repressor) protein concentration (Inhibits transcription of tetR gene).
* $p_2$ - The tetR (repressor) protein concentration (Inhibits transcription of CI gene).
* $p_3$ - The CI (repressor) protein concentration (Inhibits transcription of lacI gene).

Extra variables will be returned if the `exclude_reference_from_observation` and `exclude_reference_error_from_observation`
flag are set to `False`:

* $r$ - The reference we want to follow.
* $r_{error}$ - The error between the state of interest (i.e. $p_1$) and the reference.

The reference can not be excluded when `reference_type` is set to `periodic`.
The last two variables can be excluded from the observation space by setting the `exclude_reference_from_observation` and `exclude_reference_error_from_observation` environment arguments to `True`. Please note that the environment needs the reference or the reference error to be included in the observation space to function correctly. If both are excluded, the environment will raise an error.

## Action space

Expand Down
54 changes: 21 additions & 33 deletions stable_gym/envs/biological/oscillator/oscillator.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ class Oscillator(gym.Env, OscillatorDisturber):
+-----+-----------------------------------------------+-------------------+-------------------+
| 6 | The reference we want to follow | 0 | 100 |
+-----+-----------------------------------------------+-------------------+-------------------+
| 7 || **Optional** - The error between the current | -100 | 100 |
| (7) || **Optional** - The error between the current | -100 | 100 |
| || value of protein 1 and the reference | | |
+-----+-----------------------------------------------+-------------------+-------------------+
Expand Down Expand Up @@ -121,7 +121,6 @@ class Oscillator(gym.Env, OscillatorDisturber):
def __init__(
self,
render_mode=None,
reference_type="periodic",
reference_target_position=8.0,
reference_amplitude=7.0,
reference_frequency=(1 / 200), # NOTE: Han et al. 2020 uses a period of 200.
Expand All @@ -136,24 +135,18 @@ def __init__(
Args:
render_mode (str, optional): The render mode you want to use. Defaults to
``None``. Not used in this environment.
reference_type (str, optional): The type of reference you want to use
(``constant`` or ``periodic``), by default ``periodic``.
reference_target_position: The reference target position, by default
``8.0`` (i.e. the mean of the reference signal).
reference_amplitude: The reference amplitude, by default ``7.0``. Only used
if ``reference_type`` is ``periodic``.
reference_frequency: The reference frequency, by default ``200``. Only used
if ``reference_type`` is ``periodic``.
reference_phase_shift: The reference phase shift, by default ``0.0``. Only
used if ``reference_type`` is ``periodic``.
reference_amplitude: The reference amplitude, by default ``7.0``.
reference_frequency: The reference frequency, by default ``1/200`` (i.e. a period of ``200``).
reference_phase_shift: The reference phase shift, by default ``0.0``.
reference_constraint_position: The reference constraint position, by
default ``20.0``. Not used in the environment but used for the info
dict.
clip_action (bool, optional): Whether the actions should be clipped if
they are greater than the set action limit. Defaults to ``True``.
exclude_reference_from_observation (bool, optional): Whether the reference
should be excluded from the observation. Defaults to ``False``. Can only
be set to ``True`` if ``reference_type`` is ``constant``.
should be excluded from the observation. Defaults to ``False``.
exclude_reference_error_from_observation (bool, optional): Whether the error
should be excluded from the observation. Defaults to ``True``.
"""
Expand All @@ -166,20 +159,17 @@ def __init__(
)

# Validate input arguments.
if reference_type.lower() not in ["constant", "periodic"]:
raise ValueError(
"The reference type must be either 'constant' or 'periodic'."
)
assert (
reference_type.lower() == "constant"
or reference_type.lower() == "periodic"
and not exclude_reference_from_observation
assert not (
exclude_reference_from_observation
and exclude_reference_error_from_observation
), (
"The reference can only be excluded from the observation if the reference "
"type is constant."
"The agent needs to observe either the reference or the reference error "
"for it to be able to learn."
)
assert (
reference_frequency >= 0
), "The reference frequency must be greater than or equal to zero."

self.reference_type = reference_type
self.t = 0.0
self.dt = 1.0
self._init_state = np.array(
Expand Down Expand Up @@ -252,10 +242,6 @@ def __init__(
self.__class__.instances += 1
self.instance_id = self.__class__.instances
logger.debug(f"Oscillator instance '{self.instance_id}' created.")
logger.debug(
f"Oscillator instance '{self.instance_id}' uses a '{reference_type}' "
"reference."
)

def step(self, action):
"""Take step into the environment.
Expand Down Expand Up @@ -504,12 +490,9 @@ def reference(self, t):
- :math:`\phi` is the phase of the signal.
- :math:`C` is the offset of the signal.
"""
if self.reference_type == "periodic":
return self.reference_target_pos + self.reference_amplitude * np.sin(
((2 * np.pi) * self.reference_frequency * t) - self.phase_shift
)
else:
return self.reference_target_pos
return self.reference_target_pos + self.reference_amplitude * np.sin(
((2 * np.pi) * self.reference_frequency * t) - self.phase_shift
)

def render(self, mode="human"):
"""Render one frame of the environment.
Expand All @@ -536,6 +519,11 @@ def tau(self):
"""
return self.dt

@property
def physics_time(self):
    """Return the physics time.

    Alias for :attr:`.t`.
    """
    return self.t


if __name__ == "__main__":
print("Setting up 'Oscillator' environment.")
Expand Down
6 changes: 1 addition & 5 deletions stable_gym/envs/biological/oscillator_complicated/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,10 @@ By default, the environment returns the following observation:
* $p_2$ - The tetR (repressor) protein concentration (Inhibits transcription of CI gene).
* $p_3$ - The CI (repressor) protein concentration (Inhibits transcription of extra protein gene).
* $p_4$ - Extra protein concentration (Inhibits transcription of lacI gene).

Extra variables will be returned if the `exclude_reference_from_observation` and `exclude_reference_error_from_observation`
flag are set to `False`:

* $r$ - The reference we want to follow.
* $r_{error}$ - The error between the state of interest (i.e. $p_1$) and the reference.

The reference can not be excluded when `reference_type` is set to `periodic`.
The last two variables can be excluded from the observation space by setting the `exclude_reference_from_observation` and `exclude_reference_error_from_observation` environment arguments to `True`. Please note that the environment needs the reference or the reference error to be included in the observation space to function correctly. If both are excluded, the environment will raise an error.

## Action space

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class is based on the :class:`~stable_gym.envs.biological.oscillator.oscillator.
+-----+-------------------------------------------------+-------------------+-------------------+
| 8 | The reference we want to follow | 0 | 100 |
+-----+-------------------------------------------------+-------------------+-------------------+
| 9 || **Optional** - The error between the current | -100 | 100 |
| (9) || **Optional** - The error between the current | -100 | 100 |
| || value of protein 1 and the reference | | |
+-----+-------------------------------------------------+-------------------+-------------------+
Expand Down Expand Up @@ -137,7 +137,6 @@ class is based on the :class:`~stable_gym.envs.biological.oscillator.oscillator.
def __init__(
self,
render_mode=None,
reference_type="periodic",
reference_target_position=8.0,
reference_amplitude=7.0,
reference_frequency=(1 / 200), # NOTE: Han et al. 2020 uses a period of 200.
Expand All @@ -152,24 +151,18 @@ def __init__(
Args:
render_mode (str, optional): The render mode you want to use. Defaults to
``None``. Not used in this environment.
reference_type (str, optional): The type of reference you want to use
(``constant`` or ``periodic``), by default ``periodic``.
reference_target_position: The reference target position, by default
``8.0`` (i.e. the mean of the reference signal).
reference_amplitude: The reference amplitude, by default ``7.0``. Only used
if ``reference_type`` is ``periodic``.
reference_frequency: The reference frequency, by default ``200``. Only used
if ``reference_type`` is ``periodic``.
reference_phase_shift: The reference phase shift, by default ``0.0``. Only
used if ``reference_type`` is ``periodic``.
reference_amplitude: The reference amplitude, by default ``7.0``.
reference_frequency: The reference frequency, by default ``1/200`` (i.e. a period of ``200``).
reference_phase_shift: The reference phase shift, by default ``0.0``.
reference_constraint_position: The reference constraint position, by
default ``20.0``. Not used in the environment but used for the info
dict.
clip_action (bool, optional): Whether the actions should be clipped if
they are greater than the set action limit. Defaults to ``True``.
exclude_reference_from_observation (bool, optional): Whether the reference
should be excluded from the observation. Defaults to ``False``. Can only
be set to ``True`` if ``reference_type`` is ``constant``.
should be excluded from the observation. Defaults to ``False``.
exclude_reference_error_from_observation (bool, optional): Whether the error
should be excluded from the observation. Defaults to ``True``.
"""
Expand All @@ -182,20 +175,17 @@ def __init__(
)

# Validate input arguments.
if reference_type.lower() not in ["constant", "periodic"]:
raise ValueError(
"The reference type must be either 'constant' or 'periodic'."
)
assert (
reference_type.lower() == "constant"
or reference_type.lower() == "periodic"
and not exclude_reference_from_observation
assert not (
exclude_reference_from_observation
and exclude_reference_error_from_observation
), (
"The reference can only be excluded from the observation if the reference "
"type is constant."
"The agent needs to observe either the reference or the reference error "
"for it to be able to learn."
)
assert (
reference_frequency >= 0
), "The reference frequency must be greater than or equal to zero."

self.reference_type = reference_type
self.t = 0.0
self.dt = 1.0
self._init_state = np.array(
Expand Down Expand Up @@ -276,10 +266,6 @@ def __init__(
self.__class__.instances += 1
self.instance_id = self.__class__.instances
logger.debug(f"Oscillator instance '{self.instance_id}' created.")
logger.debug(
f"Oscillator instance '{self.instance_id}' uses a '{reference_type}' "
"reference."
)

def step(self, action):
"""Take step into the environment.
Expand Down Expand Up @@ -557,12 +543,9 @@ def reference(self, t):
- :math:`\phi` is the phase of the signal.
- :math:`C` is the offset of the signal.
"""
if self.reference_type == "periodic":
return self.reference_target_pos + self.reference_amplitude * np.sin(
((2 * np.pi) * self.reference_frequency * t) - self.phase_shift
)
else:
return self.reference_target_pos
return self.reference_target_pos + self.reference_amplitude * np.sin(
((2 * np.pi) * self.reference_frequency * t) - self.phase_shift
)

def render(self, mode="human"):
"""Render one frame of the environment.
Expand All @@ -589,6 +572,11 @@ def tau(self):
"""
return self.dt

@property
def physics_time(self):
    """Return the physics time.

    Alias for :attr:`.t`.
    """
    return self.t


if __name__ == "__main__":
print("Setting up 'OscillatorComplicated' environment.")
Expand Down

0 comments on commit da9b589

Please sign in to comment.