Skip to content

Commit

Permalink
feat: depricate step done return value (#98)
Browse files Browse the repository at this point in the history
This commit depricates the step `done` return value. This was done for consistency with gym v0.25.0
(see openai/gym@907b1b2).

BREAKING CHANGE: step now returns 5 values instead of 4.
  • Loading branch information
rickstaa authored May 30, 2023
1 parent 54e12b5 commit 4f435d6
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 22 deletions.
4 changes: 2 additions & 2 deletions examples/use_simzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
# TODO: [ ] Check ex3_ekf.
# TODO: [ ] Check cartpole.

# ENV_NAME = "Oscillator-v1"
ENV_NAME = "Ex3EKF-v1"
ENV_NAME = "Oscillator-v1"
# ENV_NAME = "Ex3EKF-v1"

if __name__ == "__main__":
env = gym.make(ENV_NAME, render_mode="human")
Expand Down
13 changes: 7 additions & 6 deletions simzoo/envs/biological/oscillator/oscillator.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,10 @@ def step(self, action):
- obs (:obj:`numpy.ndarray`): The current state
- cost (:obj:`numpy.float64`): The current cost.
- done (:obj:`bool`): Whether the episode was done.
- terminated (:obj:`bool`): Whether the episode was done.
- truncated (:obj:`bool`): Whether the episode was truncated. This value
is set by wrappers when for example a time limit is reached or the
agent goes out of bounds.
- info_dict (:obj:`dict`): Dictionary with additional information.
"""
# Clip action if needed
Expand Down Expand Up @@ -299,16 +302,14 @@ def step(self, action):
# cost = (abs(p1 - r1)) ** 0.2

# Define stopping criteria
if cost > self.reward_range.high or cost < self.reward_range.low:
done = True
else:
done = False
terminated = bool(cost > self.reward_range.high or cost < self.reward_range.low)

# Return state, cost, done and reference
return (
np.array([m1, m2, m3, p1, p2, p3, r1, p1 - r1]),
cost,
done,
terminated,
False,
dict(reference=r1, state_of_interest=p1 - r1),
)

Expand Down
20 changes: 12 additions & 8 deletions simzoo/envs/classic_control/cart_pole_cost/cart_pole_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,10 @@ def step(self, action):
- obs (:obj:`numpy.ndarray`): The current state
- cost (:obj:`numpy.float64`): The current cost.
- done (:obj:`bool`): Whether the episode was done.
- terminated (:obj:`bool`): Whether the episode was done.
- truncated (:obj:`bool`): Whether the episode was truncated. This value
is set by wrappers when for example a time limit is reached or the
agent goes out of bounds.
- info_dict (:obj:`dict`): Dictionary with additional information.
"""
# Clip action if needed
Expand Down Expand Up @@ -378,13 +381,15 @@ def step(self, action):
cost, ref = self.cost(x, theta)

# Define stopping criteria
if (
terminated = bool(
abs(x) > self.x_threshold
or abs(theta) > self.theta_threshold_radians
or cost > self.reward_range.high
or cost < self.reward_range.low
):
done = True
)

# Calculate stopping cost
if terminated:
cost = 100.0

# Throw warning if already done
Expand All @@ -399,16 +404,15 @@ def step(self, action):
"True' -- any further steps are undefined behavior."
)
self.steps_beyond_done += 1
else:
done = False

# Return state, cost, done and info_dict
# Return state, cost, terminated and info_dict
violation_of_constraint = bool(abs(x) > self.const_pos)
violation_of_x_threshold = bool(x < -self.x_threshold or x > self.x_threshold)
return (
self.state,
cost,
done,
terminated,
False,
dict(
cons_pos=self.const_pos,
cons_theta=self.theta_threshold_radians,
Expand Down
13 changes: 7 additions & 6 deletions simzoo/envs/classic_control/ex3_ekf/ex3_ekf.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ def step(self, action):
- obs (:obj:`numpy.ndarray`): The current state
- cost (:obj:`numpy.float64`): The current cost.
- done (:obj:`bool`): Whether the episode was done.
- terminated (:obj:`bool`): Whether the episode was done.
- truncated (:obj:`bool`): Whether the episode was truncated. This value
is set by wrappers when for example a time limit is reached or the
agent goes out of bounds.
- info_dict (:obj:`dict`): Dictionary with additional information.
"""
# Clip action if needed
Expand Down Expand Up @@ -250,10 +253,7 @@ def step(self, action):
# cost = np.abs(hat_x_1 - x_1)**1 + np.abs(hat_x_2 - x_2)**1

# Define stopping criteria
if cost > self.reward_range.high or cost < self.reward_range.low:
done = True
else:
done = False
terminated = bool(cost > self.reward_range.high or cost < self.reward_range.low)

# Update state
self.state = np.array([hat_x_1, hat_x_2, x_1, x_2])
Expand All @@ -264,7 +264,8 @@ def step(self, action):
return (
np.array([hat_x_1, hat_x_2, x_1, x_2]),
cost,
done,
terminated,
False,
dict(
reference=y_1,
state_of_interest=np.array([hat_x_1 - x_1, hat_x_2 - x_2]),
Expand Down

0 comments on commit 4f435d6

Please sign in to comment.