Skip to content

Commit

Permalink
🎨 Cleans up code
Browse files Browse the repository at this point in the history
  • Loading branch information
rickstaa committed Apr 30, 2021
1 parent 09627e3 commit 3b39e4c
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 62 deletions.
108 changes: 62 additions & 46 deletions simzoo/common/disturber.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def _initate_time_vars(self):
self._has_time_vars = True
if not hasattr(self, "dt"):
if hasattr(self, "tau"):
self.dt = self.tau
self.dt = self.tau # Some environments use tau instead of dt
else:
self.dt = 1.0

Expand Down Expand Up @@ -337,6 +337,7 @@ def _get_disturbance( # noqa: C901
return np.zeros_like(input_signal)

# Retrieve the requested disturbance
# NOTE: New disturbances should be aded here.
current_timestep = self.t / self.dt
if "impulse" in disturbance_variant:
impulse_magnitude = disturbance_cfg["magnitude_range"][
Expand Down Expand Up @@ -375,19 +376,25 @@ def _get_disturbance( # noqa: C901
signal_kwargs["phase"] = disturbance_cfg["phase_range"][
self._disturbance_range_idx
]

# Make sure the signal kwarg has the right shape
if not isinstance(signal_kwargs.values(), np.ndarray):
signal_kwargs = {
k: np.repeat(v, input_signal.shape)
for k, v in signal_kwargs.items()
}

return periodic_disturbance(current_timestep, **signal_kwargs)
elif disturbance_variant == "noise":
mean = disturbance_cfg["noise_range"]["mean"][self._disturbance_range_idx]
std = disturbance_cfg["noise_range"]["std"][self._disturbance_range_idx]

# Make sure that the mean and std have the right shape
if not isinstance(mean, np.ndarray):
mean = np.repeat(mean, input_signal.shape)
if not isinstance(std, np.ndarray):
std = np.repeat(std, input_signal.shape)

return noise_disturbance(mean, std)
else:
raise NotImplementedError(
Expand All @@ -407,7 +414,7 @@ def _set_disturber_type(self, disturbance_type=None):
ValueError: Thrown when the disturbance type does not exist on the
disturber.
"""
if disturbance_type is not None: # If None
if disturbance_type is not None:
disturbance_type_input = disturbance_type
disturbance_type = [
item
Expand Down Expand Up @@ -451,7 +458,7 @@ def _set_disturber_type(self, disturbance_type=None):
)
)
disturbance_type = self._disturber_cfg["default_type"]
else:
else: # Thrown warning if type could not be retrieved
valid_type_keys = {
k for k in self._disturber_cfg.keys() if k not in ["default_type"]
}
Expand All @@ -465,7 +472,6 @@ def _set_disturber_type(self, disturbance_type=None):
raise ValueError(
d_type_info_msg, "disturbance_type",
)

self._disturbance_type = disturbance_type

def _set_disturber_variant(self, disturbance_variant):
Expand Down Expand Up @@ -550,6 +556,7 @@ def _set_disturber_variant(self, disturbance_variant):
"default_variant"
]
else:
# Thrown warning if disturbance variant could not be retrieved
valid_variant_keys = {
k
for k in self._disturber_cfg[self._disturbance_type].keys()
Expand Down Expand Up @@ -779,7 +786,7 @@ def _validate_disturbance_variant_cfg( # noqa: C901
"disturbance is added automatically)."
)

# Check if the range keys have the right shape given the disturbance_type
# Check if each observation/action has a disturbance range if it is a 2D array
disturbance_range = disturbance_cfg[disturbance_range_keys[0]]
disturbance_range_dict = (
{"var_key": disturbance_range}
Expand Down Expand Up @@ -828,46 +835,48 @@ def _parse_disturbance_cfg(self):
if "_range" in key
][0]
)

# Thrown warning if the disturbance variant is invalid
if not isinstance(
self._disturbance_cfg[sub_variant_key][
self._disturbance_range_keys[-1]
],
(dict, list, np.ndarray),
):
raise TypeError(
f"The '{sub_variant_key}' variable found in the "
"'disturber_cfg' has the wrong type. Please make sure it "
"contains a 'list' or a 'dictionary'."
)

# Add zero disturbance and retrieve disturbance range length
disturbance_sub_variant_cfg = inject_value(
self._disturbance_cfg[sub_variant_key][
self._disturbance_range_keys[-1]
],
value=0.0,
) # Add undisturbed state if not yet present
if isinstance(
self._disturbance_cfg[sub_variant_key][
self._disturbance_range_keys[-1]
],
dict,
):
disturbance_sub_variant_cfg = inject_value(
self._disturbance_cfg[sub_variant_key][
self._disturbance_range_keys[-1]
],
value=0.0,
) # Add undisturbed state if not yet present
self._disturbance_range_length = len(
list(disturbance_sub_variant_cfg.values())[0]
)
self.disturbance_cfg[sub_variant_key][
self._disturbance_range_keys[-1]
] = disturbance_sub_variant_cfg
elif isinstance(
self._disturbance_cfg[sub_variant_key][
self._disturbance_range_keys[-1]
],
(list, np.ndarray),
):
disturbance_sub_variant_cfg = inject_value(
self._disturbance_cfg[sub_variant_key][
self._disturbance_range_keys[-1]
],
value=0.0,
) # Add undisturbed state if not yet present
self._disturbance_range_length = len(disturbance_sub_variant_cfg)
self.disturbance_cfg[sub_variant_key][
self._disturbance_range_keys[-1]
] = disturbance_sub_variant_cfg
else:
raise TypeError(
f"The '{sub_variant_key}' variable found in the "
"'disturber_cfg' has the wrong type. Please make sure it "
"contains a 'list' or a 'dictionary'."
)

# Store disturbance subvariant config
self.disturbance_cfg[sub_variant_key][
self._disturbance_range_keys[-1]
] = disturbance_sub_variant_cfg
elif self._disturbance_type == "env":
variable = self._disturbance_cfg["variable"]
self._disturbance_range_keys.append(
Expand All @@ -883,29 +892,34 @@ def _parse_disturbance_cfg(self):
self._disturbance_range_keys.append(
[key for key in self._disturbance_cfg.keys() if "_range" in key][0]
)

# Thrown warning if the disturbance variant is invalid
if not isinstance(
self._disturbance_cfg[self._disturbance_range_keys[-1]],
(dict, list, np.ndarray),
):
raise TypeError(
f"The '{self._disturbance_range_keys[-1]}' variable found in "
"the 'disturber_cfg' has the wrong type. Please make sure it "
"contains a 'list' or a 'dictionary'."
)

# Add zero disturbance and retrieve disturbance range length
disturbance_cfg = inject_value(
self._disturbance_cfg[self._disturbance_range_keys[-1]], value=0.0
) # Add undisturbed state if not yet present
if isinstance(
self._disturbance_cfg[self._disturbance_range_keys[-1]], dict
):
disturbance_cfg = inject_value(
self._disturbance_cfg[self._disturbance_range_keys[-1]], value=0.0
) # Add undisturbed state if not yet present
self._disturbance_range_length = len(list(disturbance_cfg.values())[0])
self.disturbance_cfg[self._disturbance_range_keys[-1]] = disturbance_cfg
elif isinstance(
self._disturbance_cfg[self._disturbance_range_keys[-1]],
(list, np.ndarray),
):
disturbance_cfg = inject_value(
self._disturbance_cfg[self._disturbance_range_keys[-1]], value=0.0
) # Add undisturbed state if not yet present
self._disturbance_range_length = len(disturbance_cfg)
self.disturbance_cfg[self._disturbance_range_keys[-1]] = disturbance_cfg
else:
raise TypeError(
f"The '{self._disturbance_range_keys[-1]}' variable found in "
"the 'disturber_cfg' has the wrong type. Please make sure it "
"contains a 'list' or a 'dictionary'."
)

# Store disturbance config
self.disturbance_cfg[self._disturbance_range_keys[-1]] = disturbance_cfg

def _set_disturbance_cfg(self):
"""Sets the disturbance configuration based on the set 'disturbance_type` and/or
Expand All @@ -923,9 +937,9 @@ def _set_disturbance_cfg(self):
self._disturbance_variant
]

# Validate disturbance config, add initial (zero) disturbance and retrieve
# disturbance range length
self._validate_disturbance_cfg()

# Adds initial disturbance to the configuration and get disturbance range length
self._parse_disturbance_cfg()

def _get_plot_labels(self): # noqa: C901
Expand Down Expand Up @@ -1038,6 +1052,8 @@ def _set_disturbance_info(self):
self.disturbance_info["type"] = self._disturbance_type
self.disturbance_info["variant"] = self._disturbance_variant
self.disturbance_info["variables"] = {}

# Store disturbance range values and current value
if self._disturbance_type == "combined":
sub_vars = [
re.search("(input(?=_)|output(?=_))", key)[0]
Expand All @@ -1056,7 +1072,7 @@ def _set_disturbance_info(self):
):
disturbance_range = self.disturbance_cfg[sub_variant][range_key]
self.disturbance_info["variables"][var]["values"] = disturbance_range
if isinstance(self._disturbance_cfg[sub_variant][range_key], dict,):
if isinstance(self._disturbance_cfg[sub_variant][range_key], dict):
self.disturbance_info["variables"][var]["value"] = {
k: v[self._disturbance_range_idx]
for k, v in disturbance_range.items()
Expand Down
32 changes: 19 additions & 13 deletions simzoo/envs/classic_control/cart_pole_cost/cart_pole_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,20 +128,21 @@ def __init__(
self.length = self._length_init = 1.0
self.mass_cart = self._mass_cart_init = 1.0
self.mass_pole = self._mass_pole_init = 0.1
self.gravity = self._gravity_init = 10.0 # DEBUG: OpenAi uses 9.8
self.force_mag = 20 # NOTE: OpenAi uses 10.0
self.gravity = self._gravity_init = 9.8
# self.force_mag = 10 # NOTE: OpenAI values
self.force_mag = 20
self._kinematics_integrator = kinematics_integrator
self._init_state = np.array(
[0.1, 0.2, 0.3, 0.1]
) # Initial state when random is disabled
# self._init_state_range = {
# "low": [-0.2, -0.05, -0.05, -0.05], # NOTE: OpenAi uses -0.05
# "high": [0.2, 0.05, 0.05, 0.05], # NOTE: OpenAi uses 0.05
# } # Initial state range when random is enabled
# "low": [-0.2, -0.05, -0.05, -0.05],
# "high": [0.2, 0.05, 0.05, 0.05],
# } # NOTE: OpenAI values
self._init_state_range = {
"low": [-5, -0.2, -0.2, -0.2], # NOTE: OpenAi uses -0.05
"high": [5, 0.2, 0.2, 0.2], # NOTE: OpenAi uses 0.05
} # DEBUG: Openai uses above
"low": [-5, -0.2, -0.2, -0.2],
"high": [5, 0.2, 0.2, 0.2],
} # Initial state range when random is enabled

# Print environment information
print(
Expand Down Expand Up @@ -170,14 +171,19 @@ def __init__(
# Thresholds
# self.theta_threshold_radians = (
# 12 * 2 * math.pi / 360
# ) # Angle at which to fail the episode
self.theta_threshold_radians = 20 * 2 * math.pi / 360 # DEBUG: Openai uses 12
self.x_threshold = 10 # NOTE: OpenAi Uses 2.4
# ) # NOTE: OpenAi value
self.theta_threshold_radians = (
20 * 2 * math.pi / 360
) # Angle at which to fail the episode
# self.x_threshold = 2.4 # NOTE: OpenAi value
self.x_threshold = 10
self.y_threshold = (
5 # NOTE: Defines real world window height (not used as threshold)
)
self.max_v = 50 # NOTE: OpenAi uses np.finfo(np.float32).max
self.max_w = 50 # NOTE: OpenAi uses np.finfo(np.float32).max
# self.max_v = np.finfo(np.float32).max # NOTE: OpenAi value
# self.max_w = np.finfo(np.float32).max # NOTE: OpenAi value
self.max_v = 50
self.max_w = 50

# Set angle limit set to 2 * theta_threshold_radians so failing observation
# is still within bounds
Expand Down
6 changes: 3 additions & 3 deletions simzoo/envs/classic_control/ex3_ekf/ex3_ekf.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,11 +202,11 @@ def step(self, action):
x_1, x_2 = state
y_1 = np.sin(x_1) + self.np_random.normal(self.mean2, np.sqrt(self.cov2))
hat_y_1 = np.sin(hat_x_1 + self.dt * hat_x_2)

# Mimic the signal drop rate
# flag=1: received
# flag=0: dropout
(flag,) = self.np_random.binomial(1, 1 - self.missing_rate, 1)
# drop_rate = 1
# to construct cost
(flag) = self.np_random.binomial(1, 1 - self.missing_rate, 1)
if flag == 1:
hat_x_1 = hat_x_1 + self.dt * hat_x_2 + self.dt * u1 * (y_1 - hat_y_1)
hat_x_2 = (
Expand Down

0 comments on commit 3b39e4c

Please sign in to comment.