Skip to content

Commit

Permalink
fix(envs): fix some upstream deprecation issues
Browse files Browse the repository at this point in the history
This commit fixes some upstream deprecation issues that I missed in my
previous commits.
  • Loading branch information
rickstaa committed May 30, 2023
1 parent f4eba93 commit d2b2be5
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 15 deletions.
4 changes: 2 additions & 2 deletions examples/use_simzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
# TODO: [ ] Check ex3_ekf.
# TODO: [ ] Check cartpole.

ENV_NAME = "Oscillator-v1"
# ENV_NAME = "Ex3EKF-v1"
# ENV_NAME = "Oscillator-v1"
ENV_NAME = "Ex3EKF-v1"

if __name__ == "__main__":
env = gym.make(ENV_NAME, render_mode="human")
Expand Down
21 changes: 18 additions & 3 deletions simzoo/envs/classic_control/cart_pole_cost/cart_pole_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def step(self, action):
or cost < self.reward_range.low
)

# Calculate stopping cost
# Define stopping criteria
if terminated:
cost = 100.0

Expand All @@ -405,7 +405,7 @@ def step(self, action):
)
self.steps_beyond_done += 1

# Return state, cost, terminated and info_dict
# Return state, cost, done and info_dict
violation_of_constraint = bool(abs(x) > self.const_pos)
violation_of_x_threshold = bool(x < -self.x_threshold or x > self.x_threshold)
return (
Expand Down Expand Up @@ -435,6 +435,7 @@ def reset(self, random=True, seed=None):
Returns:
numpy.ndarray: Array containing the current observations.
info_dict (:obj:`dict`): Dictionary with additional information.
"""
if seed is not None:
self.seed(seed)
Expand All @@ -449,7 +450,21 @@ def reset(self, random=True, seed=None):
)
self.steps_beyond_done = None
self.t = 0.0
return np.array(self.state)

# Return state and info_dict
x, _, theta, _ = self.state
cost, ref = self.cost(x, theta)
violation_of_constraint = bool(abs(x) > self.const_pos)
violation_of_x_threshold = bool(x < -self.x_threshold or x > self.x_threshold)
return np.array(self.state), dict(
cons_pos=self.const_pos,
cons_theta=self.theta_threshold_radians,
target=self.target_pos,
violation_of_x_threshold=violation_of_x_threshold,
violation_of_constraint=violation_of_constraint,
reference=ref,
state_of_interest=theta,
)

def render(self, render_mode="human"):
"""Render one frame of the environment.
Expand Down
41 changes: 31 additions & 10 deletions simzoo/envs/classic_control/ex3_ekf/ex3_ekf.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,9 @@ def step(self, action):
state + self.np_random.multivariate_normal(self.mean1, self.cov1).flatten()
) # Add process noise
x_1, x_2 = state
y_1 = np.sin(x_1) + self.np_random.normal(self.mean2, np.sqrt(self.cov2))

# Retrieve reference
y_1 = self.reference(x_1, x_2)
hat_y_1 = np.sin(hat_x_1 + self.dt * hat_x_2)

# Mimic the signal drop rate
Expand Down Expand Up @@ -272,28 +274,47 @@ def step(self, action):
),
)

def reset(self):
def reset(self, seed=None):
"""Reset gym environment.
Args:
action (bool, optional): Whether we want to randomly initialize the
environment. By default True.
seed (int, optional): A random seed for the environment. By default
`None``.
Returns:
numpy.ndarray: Array containing the current observations.
info_dict (:obj:`dict`): Dictionary with additional information.
"""
if seed is not None:
self.seed(seed)

x_1 = self.np_random.uniform(-np.pi / 2, np.pi / 2)
x_2 = self.np_random.uniform(-np.pi / 2, np.pi / 2)
hat_x_1 = x_1 + self.np_random.uniform(-np.pi / 4, np.pi / 4)
hat_x_2 = x_2 + self.np_random.uniform(-np.pi / 4, np.pi / 4)
self.state = np.array([hat_x_1, hat_x_2, x_1, x_2])
self.output = np.sin(x_1) + self.np_random.normal(
self.mean2, np.sqrt(self.cov2)
)
# y_1 = self.output
# y_2 = np.sin(x_2) + self.np_random.normal(self.mean2, np.sqrt(self.cov2))

# Retrieve reference
y_1 = self.reference(x_1)

self.output = y_1
self.t = 0.0
return np.array([hat_x_1, hat_x_2, x_1, x_2])
return np.array([hat_x_1, hat_x_2, x_1, x_2]), dict(
reference=y_1,
state_of_interest=np.array([hat_x_1 - x_1, hat_x_2 - x_2]),
)

def reference(self, x, x_2=None):
    """Return a noisy measurement ``sin(x)`` of the supplied state value.

    The result is ``sin(x)`` plus one sample drawn from a normal distribution
    with mean ``self.mean2`` and standard deviation ``sqrt(self.cov2)``. Each
    call consumes one draw from ``self.np_random``.

    Args:
        x (float): The state value that is passed through the sine.
        x_2 (float, optional): Ignored. Accepted only because ``step`` calls
            this method with two positional arguments while ``reset`` calls it
            with one; without this parameter the ``step`` call raises a
            ``TypeError``. By default ``None``.

    Returns:
        float: ``sin(x)`` with additive Gaussian noise.
    """
    # NOTE(review): ``x_2`` is deliberately unused; the cleaner long-term fix is
    # to drop the extra argument at the call site in ``step``.
    noise_free = np.sin(x)
    noise = self.np_random.normal(self.mean2, np.sqrt(self.cov2))
    return noise_free + noise

def render(self, mode="human"):
"""Render one frame of the environment.
Expand Down

0 comments on commit d2b2be5

Please sign in to comment.