Skip to content

Commit

Permalink
🔀 Merges main branch
Browse files Browse the repository at this point in the history
  • Loading branch information
rickstaa committed Apr 12, 2021
1 parent acd28f5 commit 0f9e5cf
Showing 1 changed file with 132 additions and 56 deletions.
188 changes: 132 additions & 56 deletions simzoo/common/disturber.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def periodic_disturbance(time):
DISTURBER_CFG = {
# Disturbance type when no type has been given
"default_type": "input_disturbance",
# Disturbance applied to environment variables
# Disturbance applied to *ENVIRONMENT* variables
# NOTE: The values below are meant as an example the environment disturbance config
# needs to be implemented inside the environment.
"env_disturbance": {
Expand All @@ -47,50 +47,50 @@ def periodic_disturbance(time):
"variable": "c1",
# The range of values you want to use for each disturbance iteration
"variable_range": np.linspace(1.6, 3.0, num=5, dtype=np.float32),
# Label used in robustness plots.
# Label used in robustness plots
"label": "r: %s",
},
# Disturbance applied to the *INPUT* of the environment step function
"input_disturbance": {
# The variant used when no variant is given by the user.
# The variant used when no variant is given by the user
"default_variant": "impulse",
# Impulse disturbance applied in the opposite direction of the action at a given
# timestep.
# timestep
"impulse": {
"description": "Impulse disturbance",
# The step at which you want to apply the impulse.
# The step at which you want to apply the impulse
"impulse_instant": 100,
# The magnitudes you want to apply.
# The magnitudes you want to apply
"magnitude_range": np.linspace(0.0, 3.0, num=5, dtype=np.float),
# Label used in robustness plots.
# Label used in robustness plots
"label": "M: %s",
},
# Similar to the impulse above but now the impulse force is continuously applied
# against the action after the impulse instant has been reached.
"constant_impulse": {
"description": "Constant impulse disturbance",
# The step at which you want to apply the impulse.
# The step at which you want to apply the impulse
"impulse_instant": 100,
# The magnitudes you want to apply.
# The magnitudes you want to apply
"magnitude_range": np.linspace(80, 155, num=3, dtype=np.int),
# Label that can be used in plots.
# Label that can be used in plots
"label": "M: %s",
},
# A periodic signal noise that is applied to the action at every time step.
# A periodic signal noise that is applied to the action at every time step
"periodic": {
"description": "Periodic noise disturbance",
# The magnitudes of the periodic signal.
# The magnitudes of the periodic signal
"amplitude_range": np.linspace(10, 80, num=3, dtype=np.int),
# The function that describes the signal
# NOTE: A amplitude between 0-1 is recommended.
"periodic_function": periodic_disturbance,
# Label used in robustness plots.
# Label used in robustness plots
"label": "A: %s",
},
# A random noise that is applied to the action at every timestep.
"noise": {
"description": "Random noise disturbance",
# The means and standards deviations of the random noise disturbances.
# The means and standards deviations of the random noise disturbance
"noise_range": {
"mean": np.linspace(80, 155, num=3, dtype=np.int),
"std": np.linspace(1.0, 5.0, num=3, dtype=np.int),
Expand All @@ -101,33 +101,57 @@ def periodic_disturbance(time):
},
# Disturbance applied to the *OUTPUT* of the environment step function
"output_disturbance": {
# The variant used when no variant is given by the user.
# The variant used when no variant is given by the user
"default_variant": "noise",
# A periodic signal noise that is applied to the action at every time step.
# A periodic signal noise that is applied to the action at every time step
"periodic": {
"description": "Periodic noise disturbance",
# The magnitudes of the periodic signal.
# The magnitudes of the periodic signal
"amplitude_range": np.linspace(10, 80, num=3, dtype=np.int),
# The function that describes the signal
# NOTE: A amplitude between 0-1 is recommended.
"periodic_function": periodic_disturbance,
# Label used in robustness plots.
# Label used in robustness plots
"label": "A: %s",
},
# A random noise that is applied to the action at every timestep.
# A random noise that is applied to the action at every timestep
"noise": {
"description": "Random noise disturbance",
# The means and standards deviations of the random noise disturbances.
# The means and standards deviations of the random noise disturbance
"noise_range": {
"mean": np.linspace(80, 155, num=3, dtype=np.int),
"std": np.linspace(1.0, 5.0, num=3, dtype=np.int),
},
# Label used in robustness plots.
# Label used in robustness plots
"label": "x̅:%s, σ:%s",
},
},
# Combined disturbances
"combined": {},
# Disturbance applied to both the *INPUT* and *OUTPUT* of the environment step
# function
"combined": {
# A random noise that is applied to the action and output at every timestep
"noise": {
"description": "Random input and output noise disturbance",
"input": {
# The means and standards deviations of the random input noise
# disturbance
"noise_range": {
"mean": np.linspace(80, 155, num=3, dtype=np.int),
"std": np.linspace(1.0, 5.0, num=3, dtype=np.int),
},
},
"output": {
# The means and standards deviations of the random output noise
# disturbance
"noise_range": {
"mean": np.linspace(80, 155, num=3, dtype=np.int),
"std": np.linspace(1.0, 5.0, num=3, dtype=np.int),
},
},
# Label used in robustness plots.
"label": "x̅:(%s, %s), σ:(%s, %s)",
},
},
}


Expand Down Expand Up @@ -330,13 +354,12 @@ def _get_noise_disturbance(self, input):
len(input),
)

def _validate_disturbance_variant_cfg(self):
"""Validates the disturbance configuration dictionary to see if it contains the
right information to apply the requested disturbance *variant*.
"""
# Check if a disturbance range key is present
def _validate_disturbance_variant_keys(self, disturbance_cfg):
# TODO: DOCSTRING

# Check if range key is present
disturbance_range_keys = [
key for key in self._disturbance_cfg.keys() if "_range" in key
key for key in disturbance_cfg.keys() if "_range" in key
]
if len(disturbance_range_keys) > 1:
raise ValueError(
Expand All @@ -355,14 +378,15 @@ def _validate_disturbance_variant_cfg(self):
"'disturber_cfg'."
)

# Check if the required keys are present for the requested disturbance variant
# Check if the required keys are present for the requested disturbance
# variant
if (
self._disturbance_variant == "impulse"
or self._disturbance_variant == "constant_impulse"
):
assert all(
[
req_key in self._disturbance_cfg.keys()
req_key in disturbance_cfg.keys()
for req_key in ["magnitude_range", "impulse_instant"]
]
), (
Expand All @@ -372,41 +396,70 @@ def _validate_disturbance_variant_cfg(self):
elif self._disturbance_variant == "periodic":
assert all(
[
req_key in self._disturbance_cfg.keys()
req_key in disturbance_cfg.keys()
for req_key in ["amplitude_range", "periodic_function"]
]
), (
"The 'impulse' disturbance config is invalid. Please make sure it "
"contains a 'amplitude_range' and 'periodic_function' key."
)
assert callable(self._disturbance_cfg["periodic_function"]), (
assert callable(disturbance_cfg["periodic_function"]), (
"The 'impulse' disturbance config is invalid. Please make sure the "
"'periodic_function' key contains a callable function."
)
elif self._disturbance_variant == "noise":
assert "noise_range" in self._disturbance_cfg.keys(), (
assert "noise_range" in disturbance_cfg.keys(), (
"The 'noise' disturbance config is invalid. Please make sure it "
"contains a 'noise_range' key."
)
assert len(self._disturbance_cfg["noise_range"]["mean"]) == len(
self._disturbance_cfg["noise_range"]["std"]
assert len(disturbance_cfg["noise_range"]["mean"]) == len(
disturbance_cfg["noise_range"]["std"]
), (
"The 'noise' disturbance config is invalid. Please make sure the "
" length of the 'mean' and 'std' keys are equal."
)

def _validate_disturbance_variant_cfg(self):
"""Validates the disturbance configuration dictionary to see if it contains the
right information to apply the requested disturbance *variant*.
"""
if self._disturbance_type != "combined":
try:
self._validate_disturbance_variant_keys(self._disturbance_cfg)
except (AssertionError, ValueError) as e:
raise Exception(
f"The '{self._disturbance_variant}' disturbance config is invalid. "
"Please check the configuration and try again."
) from e
else:
req_keys = ["input", "output"]
assert all(
[req_key in self._disturbance_cfg.keys() for req_key in req_keys]
), (
"The 'combined_disturbance' config is invalid. Please make sure it "
"contains a 'input' and 'output' key."
)
for req_key in req_keys:
try:
self._validate_disturbance_variant_keys(
self._disturbance_cfg[req_key]
)
except (AssertionError, ValueError) as e:
raise Exception(
"The 'combined_disturbance' disturbance config is invalid. "
"Please check the configuration and try again."
) from e

def _validate_disturbance_cfg(self):
"""Validates the disturbance configuration dictionary to see if it contains the
right information to apply the requested disturbance *type* and *variant*.
"""
if self._disturbance_type != "env_disturbance":
self._validate_disturbance_variant_cfg()
else:
req_keys = ["variable", "variable_range"]
assert all(
[
req_key in self._disturbance_cfg.keys()
for req_key in ["description", "variable", "variable_range"]
]
[req_key in self._disturbance_cfg.keys() for req_key in req_keys]
), (
"The 'env_disturbance' config is invalid. Please make sure it contains "
"a 'variable' and 'variable_range' key."
Expand Down Expand Up @@ -630,12 +683,17 @@ def init_disturber( # noqa E901

# Validate disturbance type and/or variant input arguments
disturbance_type_input = disturbance_type
disturbance_type = (
disturbance_type.lower() + "_disturbance"
if "_disturbance" not in disturbance_type.lower()
else disturbance_type.lower()
)
if disturbance_type not in self._disturber_cfg.keys():
disturbance_type = [
item
for item in [
disturbance_type,
disturbance_type.lower(),
disturbance_type + "_disturbance",
disturbance_type.lower() + "_disturbance",
]
if item in self._disturber_cfg.keys()
][0]
if not disturbance_type:
try:
environment_name = self.unwrapped.spec.id
except AttributeError:
Expand All @@ -649,8 +707,10 @@ def init_disturber( # noqa E901
f"for the '{environment_name}' environment. Please specify a "
f"valid disturbance type {valid_keys}."
),
"disturbance_Type",
"disturbance_type",
)

# Validate disturbance variant input argument
if disturbance_type != "env_disturbance":
if disturbance_variant is None:
if "default_variant" in self._disturber_cfg[disturbance_type].keys():
Expand Down Expand Up @@ -683,10 +743,15 @@ def init_disturber( # noqa E901
"disturbance_variant",
)
else:
if (
disturbance_variant
not in self._disturber_cfg[disturbance_type].keys()
):
disturbance_variant = [
item
for item in [
disturbance_variant,
disturbance_variant.lower(),
]
if item in self._disturber_cfg[disturbance_type].keys()
][0]
if not disturbance_variant:
raise ValueError(
(
f"Disturber variant '{disturbance_variant}' is not "
Expand All @@ -698,7 +763,6 @@ def init_disturber( # noqa E901
),
"disturbance_variant",
)
disturbance_variant = disturbance_variant.lower()

# Set the disturber parameters
if self._disturbance_type is not None:
Expand Down Expand Up @@ -776,23 +840,35 @@ def disturbed_step(self, action, *args, **kwargs):
"not yet been initialized using the 'init_disturber' method. Please "
"initialize the disturber and try again."
)

# Retrieve the disturbed step
if self._disturbance_type not in ["input_disturbance", "output_disturbance"]:
if self._disturbance_type not in [
"input_disturbance",
"output_disturbance",
"combined",
]:
raise RuntimeError(
"You are trying to retrieve a disturbed step while the disturbance "
f"type is set to be '{self._disturbance_type}'. Please initialize the "
"disturber with the 'input_disturbance' or 'output_disturbance' type "
"if you want to use this feature."
)

# Create time axis if not available
if not self._has_time_vars:
self.t += self.dt # Create time axis if not given by the environment

# Retrieve the disturbed step
if self._disturbance_type.split("_")[0] == "output":
s, r, done, info = self.step(action, *args, **kwargs)
s_dist = s + self._get_disturbance(s)
return s_dist, r, done, info
else:
elif self._disturbance_type.split("_")[0] == "input":
return self.step(action + self._get_disturbance(action), *args, **kwargs)
else:
s, r, done, info = self.step(
action + self._get_disturbance(action), *args, **kwargs
)
s_dist = s + self._get_disturbance(s)
return s_dist, r, done, info

def _apply_env_disturbance(self):
"""Function used to apply the next environment disturbance that is specified in
Expand Down

0 comments on commit 0f9e5cf

Please sign in to comment.