diff --git a/simzoo/common/disturber.py b/simzoo/common/disturber.py index 4c1d1f95..a6831d6a 100644 --- a/simzoo/common/disturber.py +++ b/simzoo/common/disturber.py @@ -38,7 +38,7 @@ def periodic_disturbance(time): DISTURBER_CFG = { # Disturbance type when no type has been given "default_type": "input_disturbance", - # Disturbance applied to environment variables + # Disturbance applied to *ENVIRONMENT* variables # NOTE: The values below are meant as an example the environment disturbance config # needs to be implemented inside the environment. "env_disturbance": { @@ -47,50 +47,50 @@ def periodic_disturbance(time): "variable": "c1", # The range of values you want to use for each disturbance iteration "variable_range": np.linspace(1.6, 3.0, num=5, dtype=np.float32), - # Label used in robustness plots. + # Label used in robustness plots "label": "r: %s", }, # Disturbance applied to the *INPUT* of the environment step function "input_disturbance": { - # The variant used when no variant is given by the user. + # The variant used when no variant is given by the user "default_variant": "impulse", # Impulse disturbance applied in the opposite direction of the action at a given - # timestep. + # timestep "impulse": { "description": "Impulse disturbance", - # The step at which you want to apply the impulse. + # The step at which you want to apply the impulse "impulse_instant": 100, - # The magnitudes you want to apply. + # The magnitudes you want to apply "magnitude_range": np.linspace(0.0, 3.0, num=5, dtype=np.float), - # Label used in robustness plots. + # Label used in robustness plots "label": "M: %s", }, # Similar to the impulse above but now the impulse force is continuously applied # against the action after the impulse instant has been reached. "constant_impulse": { "description": "Constant impulse disturbance", - # The step at which you want to apply the impulse. + # The step at which you want to apply the impulse "impulse_instant": 100, - # The magnitudes you want to apply. + # The magnitudes you want to apply "magnitude_range": np.linspace(80, 155, num=3, dtype=np.int), - # Label that can be used in plots. + # Label that can be used in plots "label": "M: %s", }, - # A periodic signal noise that is applied to the action at every time step. + # A periodic signal noise that is applied to the action at every time step "periodic": { "description": "Periodic noise disturbance", - # The magnitudes of the periodic signal. + # The magnitudes of the periodic signal "amplitude_range": np.linspace(10, 80, num=3, dtype=np.int), # The function that describes the signal # NOTE: A amplitude between 0-1 is recommended. "periodic_function": periodic_disturbance, - # Label used in robustness plots. + # Label used in robustness plots "label": "A: %s", }, # A random noise that is applied to the action at every timestep. "noise": { "description": "Random noise disturbance", - # The means and standards deviations of the random noise disturbances. + # The means and standards deviations of the random noise disturbance "noise_range": { "mean": np.linspace(80, 155, num=3, dtype=np.int), "std": np.linspace(1.0, 5.0, num=3, dtype=np.int), @@ -101,33 +101,57 @@ def periodic_disturbance(time): }, # Disturbance applied to the *OUTPUT* of the environment step function "output_disturbance": { - # The variant used when no variant is given by the user. + # The variant used when no variant is given by the user "default_variant": "noise", - # A periodic signal noise that is applied to the action at every time step. + # A periodic signal noise that is applied to the action at every time step "periodic": { "description": "Periodic noise disturbance", - # The magnitudes of the periodic signal. + # The magnitudes of the periodic signal "amplitude_range": np.linspace(10, 80, num=3, dtype=np.int), # The function that describes the signal # NOTE: A amplitude between 0-1 is recommended. "periodic_function": periodic_disturbance, - # Label used in robustness plots. + # Label used in robustness plots "label": "A: %s", }, - # A random noise that is applied to the action at every timestep. + # A random noise that is applied to the action at every timestep "noise": { "description": "Random noise disturbance", - # The means and standards deviations of the random noise disturbances. + # The means and standards deviations of the random noise disturbance "noise_range": { "mean": np.linspace(80, 155, num=3, dtype=np.int), "std": np.linspace(1.0, 5.0, num=3, dtype=np.int), }, - # Label used in robustness plots. + # Label used in robustness plots "label": "x̅:%s, σ:%s", }, }, - # Combined disturbances - "combined": {}, + # Disturbance applied to both the *INPUT* and *OUTPUT* of the environment step + # function + "combined": { + # A random noise that is applied to the action and output at every timestep + "noise": { + "description": "Random input and output noise disturbance", + "input": { + # The means and standards deviations of the random input noise + # disturbance + "noise_range": { + "mean": np.linspace(80, 155, num=3, dtype=np.int), + "std": np.linspace(1.0, 5.0, num=3, dtype=np.int), + }, + }, + "output": { + # The means and standards deviations of the random output noise + # disturbance + "noise_range": { + "mean": np.linspace(80, 155, num=3, dtype=np.int), + "std": np.linspace(1.0, 5.0, num=3, dtype=np.int), + }, + }, + # Label used in robustness plots. + "label": "x̅:(%s, %s), σ:(%s, %s)", + }, + }, } @@ -330,13 +354,12 @@ def _get_noise_disturbance(self, input): len(input), ) - def _validate_disturbance_variant_cfg(self): - """Validates the disturbance configuration dictionary to see if it contains the - right information to apply the requested disturbance *variant*. - """ - # Check if a disturbance range key is present + def _validate_disturbance_variant_keys(self, disturbance_cfg): + # TODO: DOCSTRING + + # Check if range key is present disturbance_range_keys = [ - key for key in self._disturbance_cfg.keys() if "_range" in key + key for key in disturbance_cfg.keys() if "_range" in key ] if len(disturbance_range_keys) > 1: raise ValueError( @@ -355,14 +378,15 @@ def _validate_disturbance_variant_cfg(self): "'disturber_cfg'." ) - # Check if the required keys are present for the requested disturbance variant + # Check if the required keys are present for the requested disturbance + # variant if ( self._disturbance_variant == "impulse" or self._disturbance_variant == "constant_impulse" ): assert all( [ - req_key in self._disturbance_cfg.keys() + req_key in disturbance_cfg.keys() for req_key in ["magnitude_range", "impulse_instant"] ] ), ( @@ -372,29 +396,60 @@ def _validate_disturbance_variant_cfg(self): elif self._disturbance_variant == "periodic": assert all( [ - req_key in self._disturbance_cfg.keys() + req_key in disturbance_cfg.keys() for req_key in ["amplitude_range", "periodic_function"] ] ), ( "The 'impulse' disturbance config is invalid. Please make sure it " "contains a 'amplitude_range' and 'periodic_function' key." ) - assert callable(self._disturbance_cfg["periodic_function"]), ( + assert callable(disturbance_cfg["periodic_function"]), ( "The 'impulse' disturbance config is invalid. Please make sure the " "'periodic_function' key contains a callable function." ) elif self._disturbance_variant == "noise": - assert "noise_range" in self._disturbance_cfg.keys(), ( + assert "noise_range" in disturbance_cfg.keys(), ( "The 'noise' disturbance config is invalid. Please make sure it " "contains a 'noise_range' key." ) - assert len(self._disturbance_cfg["noise_range"]["mean"]) == len( - self._disturbance_cfg["noise_range"]["std"] + assert len(disturbance_cfg["noise_range"]["mean"]) == len( + disturbance_cfg["noise_range"]["std"] ), ( "The 'noise' disturbance config is invalid. Please make sure the " " length of the 'mean' and 'std' keys are equal." ) + def _validate_disturbance_variant_cfg(self): + """Validates the disturbance configuration dictionary to see if it contains the + right information to apply the requested disturbance *variant*. + """ + if self._disturbance_type != "combined": + try: + self._validate_disturbance_variant_keys(self._disturbance_cfg) + except (AssertionError, ValueError) as e: + raise Exception( + f"The '{self._disturbance_variant}' disturbance config is invalid. " + "Please check the configuration and try again." + ) from e + else: + req_keys = ["input", "output"] + assert all( + [req_key in self._disturbance_cfg.keys() for req_key in req_keys] + ), ( + "The 'combined_disturbance' config is invalid. Please make sure it " + "contains a 'input' and 'output' key." + ) + for req_key in req_keys: + try: + self._validate_disturbance_variant_keys( + self._disturbance_cfg[req_key] + ) + except (AssertionError, ValueError) as e: + raise Exception( + "The 'combined_disturbance' disturbance config is invalid. " + "Please check the configuration and try again." + ) from e + def _validate_disturbance_cfg(self): """Validates the disturbance configuration dictionary to see if it contains the right information to apply the requested disturbance *type* and *variant*. @@ -402,11 +457,9 @@ def _validate_disturbance_cfg(self): if self._disturbance_type != "env_disturbance": self._validate_disturbance_variant_cfg() else: + req_keys = ["variable", "variable_range"] assert all( - [ - req_key in self._disturbance_cfg.keys() - for req_key in ["description", "variable", "variable_range"] - ] + [req_key in self._disturbance_cfg.keys() for req_key in req_keys] ), ( "The 'env_disturbance' config is invalid. Please make sure it contains " "a 'variable' and 'variable_range' key." @@ -630,12 +683,17 @@ def init_disturber( # noqa E901 # Validate disturbance type and/or variant input arguments disturbance_type_input = disturbance_type - disturbance_type = ( - disturbance_type.lower() + "_disturbance" - if "_disturbance" not in disturbance_type.lower() - else disturbance_type.lower() - ) - if disturbance_type not in self._disturber_cfg.keys(): + disturbance_type = [ + item + for item in [ + disturbance_type, + disturbance_type.lower(), + disturbance_type + "_disturbance", + disturbance_type.lower() + "_disturbance", + ] + if item in self._disturber_cfg.keys() + ][0] + if not disturbance_type: try: environment_name = self.unwrapped.spec.id except AttributeError: @@ -649,8 +707,10 @@ def init_disturber( # noqa E901 f"for the '{environment_name}' environment. Please specify a " f"valid disturbance type {valid_keys}." ), - "disturbance_Type", + "disturbance_type", ) + + # Validate disturbance variant input argument if disturbance_type != "env_disturbance": if disturbance_variant is None: if "default_variant" in self._disturber_cfg[disturbance_type].keys(): @@ -683,10 +743,15 @@ def init_disturber( # noqa E901 "disturbance_variant", ) else: - if ( - disturbance_variant - not in self._disturber_cfg[disturbance_type].keys() - ): + disturbance_variant = [ + item + for item in [ + disturbance_variant, + disturbance_variant.lower(), + ] + if item in self._disturber_cfg[disturbance_type].keys() + ][0] + if not disturbance_variant: raise ValueError( ( f"Disturber variant '{disturbance_variant}' is not " @@ -698,7 +763,6 @@ def init_disturber( # noqa E901 ), "disturbance_variant", ) - disturbance_variant = disturbance_variant.lower() # Set the disturber parameters if self._disturbance_type is not None: @@ -776,23 +840,35 @@ def disturbed_step(self, action, *args, **kwargs): "not yet been initialized using the 'init_disturber' method. Please " "initialize the disturber and try again." ) - - # Retrieve the disturbed step - if self._disturbance_type not in ["input_disturbance", "output_disturbance"]: + if self._disturbance_type not in [ + "input_disturbance", + "output_disturbance", + "combined", + ]: raise RuntimeError( "You are trying to retrieve a disturbed step while the disturbance " f"type is set to be '{self._disturbance_type}'. Please initialize the " "disturber with the 'input_disturbance' or 'output_disturbance' type " "if you want to use this feature." ) + + # Create time axis if not available if not self._has_time_vars: self.t += self.dt # Create time axis if not given by the environment + + # Retrieve the disturbed step if self._disturbance_type.split("_")[0] == "output": s, r, done, info = self.step(action, *args, **kwargs) s_dist = s + self._get_disturbance(s) return s_dist, r, done, info - else: + elif self._disturbance_type.split("_")[0] == "input": return self.step(action + self._get_disturbance(action), *args, **kwargs) + else: + s, r, done, info = self.step( + action + self._get_disturbance(action), *args, **kwargs + ) + s_dist = s + self._get_disturbance(s) + return s_dist, r, done, info def _apply_env_disturbance(self): """Function used to apply the next environment disturbance that is specified in