Skip to content

Commit

Permalink
✨ Adds non-homogenious disturbance
Browse files Browse the repository at this point in the history
This commit allows the user to now also control the disturbance of each
action and observation individually. See
https://rickstaa.github.io/bayesian-learning-control/control/eval_robustness.html
for more information.
  • Loading branch information
rickstaa committed Apr 29, 2021
1 parent 544d5ea commit 81f348d
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 87 deletions.
34 changes: 14 additions & 20 deletions simzoo/common/disturbances.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def impulse_disturbance(
Args:
input_signal (numpy.ndarray): The signal to which the disturbance should be
applied. Used for determining the direction of the disturbance.
impulse_magnitude (float): The magnitude of the impulse disturbance.
impulse_magnitude (union[float, :obj:`np.ndarray`]): The magnitude of the
impulse disturbance.
impulse_instant (float): The time step at which you want to apply the impulse
disturbance.
impulse_type (str): The type of impulse disturbance you want to use. Options
Expand All @@ -38,39 +39,32 @@ def impulse_disturbance(
return dist_val


def periodic_disturbance(
input_signal, current_timestep, amplitude=1, frequency=10, phase_shift=0
):
def periodic_disturbance(current_timestep, amplitude=1, frequency=10, phase_shift=0):
"""Returns a periodic disturbance signal that has the same shape as the input signal.
Args:
input_signal (numpy.ndarray): The signal to which the disturbance should be
applied. Used for determining the direction of the disturbance.
current_timestep(int): The current time step.
amplitude (float, optional): The periodic signal amplitude. Defaults to ``1``.
frequency (float, optional): The periodic signal frequency. Defaults to ``10``.
phase_shift (float, optional): The periodic signal phase shift. Defaults to
``0``.
amplitude (union[float, np.ndarray), optional): The periodic signal amplitude.
Defaults to ``1``.
frequency (union[float, np.ndarray), optional): The periodic signal frequency.
Defaults to ``10``.
phase_shift (union[float, np.ndarray), optional): The periodic signal phase
shift. Defaults to ``0``.
Returns:
numpy.ndarray: The disturbance array.
"""
return (
amplitude
* np.sin(2 * np.pi * frequency * current_timestep + phase_shift)
* np.ones_like(input_signal)
)
return amplitude * np.sin(2 * np.pi * frequency * current_timestep + phase_shift)


def noise_disturbance(input_signal, mean, std):
def noise_disturbance(mean, std):
"""Returns a random noise specified mean and a standard deviation.
Args:
input_signal (numpy.ndarray): The signal to which the disturbance should be
applied. Used for determining the direction of the disturbance.
mean
mean (union[float, :obj:`np.ndarray`]): The mean value of the noise.
std (union[float, :obj:`np.ndarray`]): The standard deviation of the noise.
Returns:
numpy.ndarray: The disturbance array.
"""
return np.random.normal(mean, std, len(input_signal),)
return np.random.normal(mean, std)
109 changes: 74 additions & 35 deletions simzoo/common/disturber.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,17 @@

import gym
import numpy as np
from iteration_utilities import deepflatten

from .disturbances import impulse_disturbance, noise_disturbance, periodic_disturbance
from .helpers import (
abbreviate,
colorize,
flatten_list,
friendly_list,
get_flattened_keys,
get_flattened_values,
strip_underscores,
inject_value,
strip_underscores,
)

# Default Disturber configuration variable
Expand Down Expand Up @@ -269,7 +269,7 @@ def __init__(self, disturber_cfg=None):
self.disturbance_info = {}
self._disturbance_done_warned = False
self._disturber_cfg = (
{**DISTURBER_CFG, **disturber_cfg,}
{**DISTURBER_CFG, **disturber_cfg}
if disturber_cfg is not None
else DISTURBER_CFG
) # Allow users to overwrite the default config
Expand Down Expand Up @@ -299,7 +299,9 @@ def _initate_time_vars(self):
else:
self.dt = 1.0

def _get_disturbance(self, input_signal, disturbance_variant, disturbance_cfg):
def _get_disturbance( # noqa: C901
self, input_signal, disturbance_variant, disturbance_cfg
):
"""Retrieves the right disturbance using the disturbance type and variant that
were set using the :meth:`Disturber.init_disturber` method.
Expand Down Expand Up @@ -373,11 +375,20 @@ def _get_disturbance(self, input_signal, disturbance_variant, disturbance_cfg):
signal_kwargs["phase"] = disturbance_cfg["phase_range"][
self._disturbance_range_idx
]
return periodic_disturbance(input_signal, current_timestep, **signal_kwargs)
if not isinstance(signal_kwargs.values(), np.ndarray):
signal_kwargs = {
k: np.repeat(v, input_signal.shape)
for k, v in signal_kwargs.items()
}
return periodic_disturbance(current_timestep, **signal_kwargs)
elif disturbance_variant == "noise":
mean = disturbance_cfg["noise_range"]["mean"][self._disturbance_range_idx]
std = disturbance_cfg["noise_range"]["std"][self._disturbance_range_idx]
return noise_disturbance(input_signal, mean, std)
if not isinstance(mean, np.ndarray):
mean = np.repeat(mean, input_signal.shape)
if not isinstance(std, np.ndarray):
std = np.repeat(std, input_signal.shape)
return noise_disturbance(mean, std)
else:
raise NotImplementedError(
f"Disturbance variant '{self._disturbance_variant}' not yet "
Expand All @@ -400,7 +411,7 @@ def _set_disturber_type(self, disturbance_type=None):
disturbance_type_input = disturbance_type
disturbance_type = [
item
for item in list(set([disturbance_type, disturbance_type.lower(),]))
for item in list(set([disturbance_type, disturbance_type.lower()]))
if item in self._disturber_cfg.keys()
] # Catch some common human writing errors
disturbance_type = disturbance_type[0] if disturbance_type else None
Expand Down Expand Up @@ -489,7 +500,7 @@ def _set_disturber_variant(self, disturbance_variant):
disturbance_variant = [
item
for item in list(
set([disturbance_variant, disturbance_variant.lower(),])
set([disturbance_variant, disturbance_variant.lower()])
)
if item in self._disturber_cfg[self._disturbance_type].keys()
] # Catch some common human writing errors
Expand Down Expand Up @@ -620,9 +631,10 @@ def _validate_disturbance_cfg(self):
disturbance_range_keys = [
key for key in self._disturbance_cfg[req_key] if "_range" in key
]
disturbance_type = re.search("(input(?=_)|output(?=_))", req_key)[0]
try:
self._validate_disturbance_variant_cfg(
self._disturbance_cfg[req_key]
self._disturbance_cfg[req_key], disturbance_type
)
vals_key_lengths.extend(
[
Expand All @@ -645,7 +657,11 @@ def _validate_disturbance_cfg(self):
len(
set(
[
(len(item) if isinstance(item, list) else 1)
(
len(item)
if isinstance(item, (list, np.ndarray))
else 1
)
for item in vals_key_lengths
]
)
Expand All @@ -661,14 +677,18 @@ def _validate_disturbance_cfg(self):
)
else:
try:
self._validate_disturbance_variant_cfg(self._disturbance_cfg)
self._validate_disturbance_variant_cfg(
self._disturbance_cfg, self._disturbance_type
)
except (AssertionError, ValueError) as e:
raise Exception(
f"The '{self._disturbance_variant}' disturbance config is "
"invalid. Please check the configuration and try again."
) from e

def _validate_disturbance_variant_cfg(self, disturbance_cfg):
def _validate_disturbance_variant_cfg( # noqa: C901
self, disturbance_cfg, disturbance_type
):
"""Validates the disturbance variant configuration object to see if it is valid
for the disturbances that are currently implemented.
Expand Down Expand Up @@ -706,8 +726,7 @@ def _validate_disturbance_variant_cfg(self, disturbance_cfg):
"'disturber_cfg'."
)

# Check if the required keys are present for the requested disturbance
# variant
# Check if required keys are found and the range key has the right length
invalid_keys_string = (
f"The '{self._disturbance_variant}' disturbance config is invalid. "
"Please make sure it contains the "
Expand Down Expand Up @@ -760,6 +779,38 @@ def _validate_disturbance_variant_cfg(self, disturbance_cfg):
"disturbance is added automatically)."
)

# Check if the range keys have the right shape given the disturbance_type
disturbance_range = disturbance_cfg[disturbance_range_keys[0]]
disturbance_range_dict = (
{"var_key": disturbance_range}
if not isinstance(disturbance_range, dict)
else disturbance_range
)
for key, val in disturbance_range_dict.items():
if isinstance(val, np.ndarray) and val.ndim > 1:
req_length = (
self.action_space.shape[0]
if disturbance_type.lower() == "input"
else self.observation_space.shape[0]
)
if val.shape[1] != req_length:
space_name = (
"action space"
if disturbance_type.lower() == "input"
else "observation space"
)
key_and_range_string = (
f"'{key}' key in the '{disturbance_range_keys[0]}'"
if isinstance(disturbance_range, dict)
else f"'{disturbance_range_keys[0]}'"
)
raise ValueError(
f"The '{self._disturbance_variant}' disturbance config is "
"invalid. Please make sure that the length of the "
f"{key_and_range_string} (i.e. {val.shape[0]}) is equal to "
f"the {space_name} size (i.e. {req_length})."
)

def _parse_disturbance_cfg(self):
"""Parse the disturbance config to add determine the disturbance range and add
the initial disturbance (0.0) if it is not yet present.
Expand Down Expand Up @@ -908,21 +959,9 @@ def _get_plot_labels(self): # noqa: C901
"label" in self._disturbance_cfg.keys()
and len(set(self._disturbance_range_keys)) == 1
):
label_values = [
tuple(item) for item in list(np.hstack(np.dstack(label_values)))
]
label_values = zip(*deepflatten(zip(*label_values), depth=1))
else:
label_values = [
tuple(flatten_list(item))
for item in zip(
*[
zip(*item)
if len(item) > 1
else tuple(flatten_list(item))
for item in label_values
]
)
]
label_values = zip(*deepflatten(label_values, depth=1))

# Generate label literal if not supplied
if "label" not in self._disturbance_cfg.keys():
Expand All @@ -934,12 +973,12 @@ def _get_plot_labels(self): # noqa: C901
)
):
disturbance_range = self.disturbance_cfg[sub_var][range_key]
label_keys.append(
label_keys.extend(
get_flattened_keys(disturbance_range)
if isinstance(disturbance_range, dict)
else self._disturbance_range_keys[idx]
else [self._disturbance_range_keys[idx]]
)
label_abbreviations = abbreviate(list(flatten_list(label_keys)))
label_abbreviations = abbreviate(label_keys)
self._disturbance_cfg["label"] = (
":%s, ".join(label_abbreviations) + ":%s"
)
Expand Down Expand Up @@ -977,9 +1016,9 @@ def _get_plot_labels(self): # noqa: C901
raise Exception(
"Something went wrong while creating the 'plot_labels'. It "
"looks like the plot label that is specified in the disturber "
f"config requires {req_vars} values while only "
f"{available_vars} values could be from the disturbance "
"config. Please check your disturbance label and try again."
f"config requires {req_vars} values while {available_vars} values "
"could be retrieved from the disturbance config. Please check your "
"disturbance label and try again."
)
else:
disturbance_range = self.disturbance_cfg[self._disturbance_range_keys[0]]
Expand Down Expand Up @@ -1138,7 +1177,7 @@ def init_disturber( # noqa E901
)
)

def disturbed_step(self, action, *args, **kwargs):
def disturbed_step(self, action, *args, **kwargs): # noqa: C901
"""Takes a action inside the gym environment while applying the requested
disturbance.
Expand Down
37 changes: 17 additions & 20 deletions simzoo/common/helpers.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Functions that are used in multiple simzoo environments.
"""

import collections
import re

import numpy as np
from gym.utils import colorize as gym_colorize


Expand All @@ -30,22 +30,6 @@ def colorize(string, color, bold=False, highlight=False):
return string


def flatten_list(input_list):
"""Generator for flatting a nested list of lists or tuples.
Args:
input_list (list): The list you want to flatten.
Yields:
list: A flattened list
"""
for el in input_list:
if isinstance(el, collections.Iterable) and not isinstance(el, (str, bytes)):
yield from flatten_list(el)
else:
yield el


def get_flattened_values(input_obj):
"""Retrieves all the values that are present in a nested dictionary and appends them
to a list. Its like a recursive version of the :meth:`dict.values()` method.
Expand Down Expand Up @@ -224,7 +208,7 @@ def strip_underscores(text, position="all"):
return text


def inject_value(input_item, value, round_accuracy=2, order=False):
def inject_value(input_item, value, round_accuracy=2, order=False, axis=0):
"""Injects a value into a list or dictionary if it is not yet present.
Args:
Expand All @@ -234,6 +218,8 @@ def inject_value(input_item, value, round_accuracy=2, order=False):
is present. Defaults to 2.
order (bool, optional): Whether the list should be ordered when returned.
Defaults to ``false``.
axis (int, optional): The axis along which you want to inject the value. Only
used when the input is a numpy array. Defaults to ``0``.
Returns:
union[list,dict]: The list or dictionary that contains the value.
Expand All @@ -245,10 +231,21 @@ def inject_value(input_item, value, round_accuracy=2, order=False):
)
if isinstance(input_item, dict):
return {
k: order_op(
[value] + [item for item in v if round(item, round_accuracy) != value]
k: inject_value(
v, value=value, round_accuracy=round_accuracy, order=order, axis=axis
)
for k, v in input_item.items()
}
elif isinstance(input_item, np.ndarray) and input_item.ndim > 1:
transpose_matrix = np.eye(input_item.ndim, dtype=np.int16)
return np.transpose(
np.array(
[
order_op([value] + [it for it in item if it != value])
for item in np.transpose(input_item, transpose_matrix[axis])
]
),
transpose_matrix[axis],
)
else:
return order_op([value] + [item for item in input_item if item != value])
3 changes: 0 additions & 3 deletions simzoo/envs/biological/oscillator/oscillator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
(see https://www-nature-com.tudelft.idm.oclc.org/articles/35002125).
"""

import importlib
import sys

import gym
import matplotlib.pyplot as plt
import numpy as np
Expand Down
Loading

0 comments on commit 81f348d

Please sign in to comment.