From da9b589fb1950a8cee19b934b5bf6f1571016dc3 Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Fri, 21 Jul 2023 14:40:36 +0200 Subject: [PATCH] feat(oscillator): simplify oscillator environments (#231) This commit removes the 'reference_type' argument since the same can already be achieved by setting the 'reference_frequency' to `0`. --- .../envs/biological/oscillator/README.md | 6 +-- .../envs/biological/oscillator/oscillator.py | 54 ++++++++----------- .../oscillator_complicated/README.md | 6 +-- .../oscillator_complicated.py | 54 ++++++++----------- 4 files changed, 44 insertions(+), 76 deletions(-) diff --git a/stable_gym/envs/biological/oscillator/README.md b/stable_gym/envs/biological/oscillator/README.md index a5489a3b..a41d4020 100644 --- a/stable_gym/envs/biological/oscillator/README.md +++ b/stable_gym/envs/biological/oscillator/README.md @@ -12,14 +12,10 @@ By default, the environment returns the following observation: * $p_1$ - The lacI (repressor) protein concentration (Inhibits transcription of tetR gene). * $p_2$ - The tetR (repressor) protein concentration (Inhibits transcription of CI gene). * $p_3$ - The CI (repressor) protein concentration (Inhibits transcription of lacI gene). - -Extra variables will be returned if the `exclude_reference_from_observation` and `exclude_reference_error_from_observation` -flag are set to `False`: - * $r$ - The reference we want to follow. * $r_{error}$ - The error between the state of interest (i.e. $p_1$) and the reference. -The reference can not be excluded when `reference_type` is set to `periodic`. +The last two variables can be excluded from the observation space by setting the `exclude_reference_from_observation` and `exclude_reference_error_from_observation` environment arguments to `True`. Please note that the environment needs the reference or the reference error to be included in the observation space to function correctly. If both are excluded, the environment will raise an error. 
## Action space diff --git a/stable_gym/envs/biological/oscillator/oscillator.py b/stable_gym/envs/biological/oscillator/oscillator.py index 248f468f..136bffc9 100644 --- a/stable_gym/envs/biological/oscillator/oscillator.py +++ b/stable_gym/envs/biological/oscillator/oscillator.py @@ -62,7 +62,7 @@ class Oscillator(gym.Env, OscillatorDisturber): +-----+-----------------------------------------------+-------------------+-------------------+ | 6 | The reference we want to follow | 0 | 100 | +-----+-----------------------------------------------+-------------------+-------------------+ - | 7 || **Optional** - The error between the current | -100 | 100 | + | (7) || **Optional** - The error between the current | -100 | 100 | | || value of protein 1 and the reference | | | +-----+-----------------------------------------------+-------------------+-------------------+ @@ -121,7 +121,6 @@ class Oscillator(gym.Env, OscillatorDisturber): def __init__( self, render_mode=None, - reference_type="periodic", reference_target_position=8.0, reference_amplitude=7.0, reference_frequency=(1 / 200), # NOTE: Han et al. 2020 uses a period of 200. @@ -136,24 +135,18 @@ def __init__( Args: render_mode (str, optional): The render mode you want to use. Defaults to ``None``. Not used in this environment. - reference_type (str, optional): The type of reference you want to use - (``constant`` or ``periodic``), by default ``periodic``. reference_target_position: The reference target position, by default ``8.0`` (i.e. the mean of the reference signal). - reference_amplitude: The reference amplitude, by default ``7.0``. Only used - if ``reference_type`` is ``periodic``. - reference_frequency: The reference frequency, by default ``200``. Only used - if ``reference_type`` is ``periodic``. - reference_phase_shift: The reference phase shift, by default ``0.0``. Only - used if ``reference_type`` is ``periodic``. + reference_amplitude: The reference amplitude, by default ``7.0``. 
+ reference_frequency: The reference frequency, by default ``1/200``. + reference_phase_shift: The reference phase shift, by default ``0.0``. reference_constraint_position: The reference constraint position, by default ``20.0``. Not used in the environment but used for the info dict. clip_action (str, optional): Whether the actions should be clipped if they are greater than the set action limit. Defaults to ``True``. exclude_reference_from_observation (bool, optional): Whether the reference - should be excluded from the observation. Defaults to ``False``. Can only - be set to ``True`` if ``reference_type`` is ``constant``. + should be excluded from the observation. Defaults to ``False``. exclude_reference_error_from_observation (bool, optional): Whether the error should be excluded from the observation. Defaults to ``True``. """ @@ -166,20 +159,17 @@ def __init__( ) # Validate input arguments. - if reference_type.lower() not in ["constant", "periodic"]: - raise ValueError( - "The reference type must be either 'constant' or 'periodic'." - ) - assert ( - reference_type.lower() == "constant" - or reference_type.lower() == "periodic" - and not exclude_reference_from_observation + assert not ( + exclude_reference_from_observation + and exclude_reference_error_from_observation ), ( - "The reference can only be excluded from the observation if the reference " - "type is constant." + "The agent needs to observe either the reference or the reference error " + "for it to be able to learn." ) + assert ( + reference_frequency >= 0 + ), "The reference frequency must be greater than or equal to zero." - self.reference_type = reference_type self.t = 0.0 self.dt = 1.0 self._init_state = np.array( @@ -252,10 +242,6 @@ def __init__( self.__class__.instances += 1 self.instance_id = self.__class__.instances logger.debug(f"Oscillator instance '{self.instance_id}' created.") - logger.debug( - f"Oscillator instance '{self.instance_id}' uses a '{reference_type}' " - "reference." 
- ) def step(self, action): """Take step into the environment. @@ -504,12 +490,9 @@ def reference(self, t): - :math:`\phi` is the phase of the signal. - :math:`C` is the offset of the signal. """ - if self.reference_type == "periodic": - return self.reference_target_pos + self.reference_amplitude * np.sin( - ((2 * np.pi) * self.reference_frequency * t) - self.phase_shift - ) - else: - return self.reference_target_pos + return self.reference_target_pos + self.reference_amplitude * np.sin( + ((2 * np.pi) * self.reference_frequency * t) - self.phase_shift + ) def render(self, mode="human"): """Render one frame of the environment. @@ -536,6 +519,11 @@ def tau(self): """ return self.dt + @property + def physics_time(self): + """Returns the physics time. Alias for :attr:`.t`.""" + return self.t + if __name__ == "__main__": print("Setting up 'Oscillator' environment.") diff --git a/stable_gym/envs/biological/oscillator_complicated/README.md b/stable_gym/envs/biological/oscillator_complicated/README.md index c6aa3ffd..624ddaf7 100644 --- a/stable_gym/envs/biological/oscillator_complicated/README.md +++ b/stable_gym/envs/biological/oscillator_complicated/README.md @@ -14,14 +14,10 @@ By default, the environment returns the following observation: * $p_2$ - The tetR (repressor) protein concentration (Inhibits transcription of CI gene). * $p_3$ - The CI (repressor) protein concentration (Inhibits transcription of extra protein gene). * $p_4$ - Extra protein concentration (Inhibits transcription of lacI gene). - -Extra variables will be returned if the `exclude_reference_from_observation` and `exclude_reference_error_from_observation` -flag are set to `False`: - * $r$ - The reference we want to follow. * $r_{error}$ - The error between the state of interest (i.e. $p_1$) and the reference. -The reference can not be excluded when `reference_type` is set to `periodic`. 
+The last two variables can be excluded from the observation space by setting the `exclude_reference_from_observation` and `exclude_reference_error_from_observation` environment arguments to `True`. Please note that the environment needs the reference or the reference error to be included in the observation space to function correctly. If both are excluded, the environment will raise an error. ## Action space diff --git a/stable_gym/envs/biological/oscillator_complicated/oscillator_complicated.py b/stable_gym/envs/biological/oscillator_complicated/oscillator_complicated.py index a6e8ff50..1d890b87 100644 --- a/stable_gym/envs/biological/oscillator_complicated/oscillator_complicated.py +++ b/stable_gym/envs/biological/oscillator_complicated/oscillator_complicated.py @@ -75,7 +75,7 @@ class is based on the :class:`~stable_gym.envs.biological.oscillator.oscillator. +-----+-------------------------------------------------+-------------------+-------------------+ | 8 | The reference we want to follow | 0 | 100 | +-----+-------------------------------------------------+-------------------+-------------------+ - | 9 || **Optional** - The error between the current | -100 | 100 | + | (9) || **Optional** - The error between the current | -100 | 100 | | || value of protein 1 and the reference | | | +-----+-------------------------------------------------+-------------------+-------------------+ @@ -137,7 +137,6 @@ class is based on the :class:`~stable_gym.envs.biological.oscillator.oscillator. def __init__( self, render_mode=None, - reference_type="periodic", reference_target_position=8.0, reference_amplitude=7.0, reference_frequency=(1 / 200), # NOTE: Han et al. 2020 uses a period of 200. @@ -152,24 +151,18 @@ def __init__( Args: render_mode (str, optional): The render mode you want to use. Defaults to ``None``. Not used in this environment. - reference_type (str, optional): The type of reference you want to use - (``constant`` or ``periodic``), by default ``periodic``. 
reference_target_position: The reference target position, by default ``8.0`` (i.e. the mean of the reference signal). - reference_amplitude: The reference amplitude, by default ``7.0``. Only used - if ``reference_type`` is ``periodic``. - reference_frequency: The reference frequency, by default ``200``. Only used - if ``reference_type`` is ``periodic``. - reference_phase_shift: The reference phase shift, by default ``0.0``. Only - used if ``reference_type`` is ``periodic``. + reference_amplitude: The reference amplitude, by default ``7.0``. + reference_frequency: The reference frequency, by default ``1/200``. + reference_phase_shift: The reference phase shift, by default ``0.0``. reference_constraint_position: The reference constraint position, by default ``20.0``. Not used in the environment but used for the info dict. clip_action (str, optional): Whether the actions should be clipped if they are greater than the set action limit. Defaults to ``True``. exclude_reference_from_observation (bool, optional): Whether the reference - should be excluded from the observation. Defaults to ``False``. Can only - be set to ``True`` if ``reference_type`` is ``constant``. + should be excluded from the observation. Defaults to ``False``. exclude_reference_error_from_observation (bool, optional): Whether the error should be excluded from the observation. Defaults to ``True``. """ @@ -182,20 +175,17 @@ def __init__( ) # Validate input arguments. - if reference_type.lower() not in ["constant", "periodic"]: - raise ValueError( - "The reference type must be either 'constant' or 'periodic'." - ) - assert ( - reference_type.lower() == "constant" - or reference_type.lower() == "periodic" - and not exclude_reference_from_observation + assert not ( + exclude_reference_from_observation + and exclude_reference_error_from_observation ), ( - "The reference can only be excluded from the observation if the reference " - "type is constant." 
+ "The agent needs to observe either the reference or the reference error " + "for it to be able to learn." ) + assert ( + reference_frequency >= 0 + ), "The reference frequency must be greater than or equal to zero." - self.reference_type = reference_type self.t = 0.0 self.dt = 1.0 self._init_state = np.array( @@ -276,10 +266,6 @@ def __init__( self.__class__.instances += 1 self.instance_id = self.__class__.instances logger.debug(f"Oscillator instance '{self.instance_id}' created.") - logger.debug( - f"Oscillator instance '{self.instance_id}' uses a '{reference_type}' " - "reference." - ) def step(self, action): """Take step into the environment. @@ -557,12 +543,9 @@ def reference(self, t): - :math:`\phi` is the phase of the signal. - :math:`C` is the offset of the signal. """ - if self.reference_type == "periodic": - return self.reference_target_pos + self.reference_amplitude * np.sin( - ((2 * np.pi) * self.reference_frequency * t) - self.phase_shift - ) - else: - return self.reference_target_pos + return self.reference_target_pos + self.reference_amplitude * np.sin( + ((2 * np.pi) * self.reference_frequency * t) - self.phase_shift + ) def render(self, mode="human"): """Render one frame of the environment. @@ -589,6 +572,11 @@ def tau(self): """ return self.dt + @property + def physics_time(self): + """Returns the physics time. Alias for :attr:`.t`.""" + return self.t + if __name__ == "__main__": print("Setting up 'OscillatorComplicated' environment.")