diff --git a/examples/use_simzoo.py b/examples/use_simzoo.py index 9b036410..cee28821 100644 --- a/examples/use_simzoo.py +++ b/examples/use_simzoo.py @@ -9,8 +9,8 @@ # TODO: [ ] Check ex3_ekf. # TODO: [ ] Check cartpole. -# ENV_NAME = "Oscillator-v1" -ENV_NAME = "Ex3EKF-v1" +ENV_NAME = "Oscillator-v1" +# ENV_NAME = "Ex3EKF-v1" if __name__ == "__main__": env = gym.make(ENV_NAME, render_mode="human") diff --git a/simzoo/envs/biological/oscillator/oscillator.py b/simzoo/envs/biological/oscillator/oscillator.py index 329fb02d..6105ded0 100644 --- a/simzoo/envs/biological/oscillator/oscillator.py +++ b/simzoo/envs/biological/oscillator/oscillator.py @@ -199,7 +199,10 @@ def step(self, action): - obs (:obj:`numpy.ndarray`): The current state - cost (:obj:`numpy.float64`): The current cost. - - done (:obj:`bool`): Whether the episode was done. + - terminated (:obj:`bool`): Whether the episode was done. + - truncated (:obj:`bool`): Whether the episode was truncated. This value + is set by wrappers when for example a time limit is reached or the + agent goes out of bounds. - info_dict (:obj:`dict`): Dictionary with additional information. """ # Clip action if needed @@ -299,16 +302,14 @@ def step(self, action): # cost = (abs(p1 - r1)) ** 0.2 # Define stopping criteria - if cost > self.reward_range.high or cost < self.reward_range.low: - done = True - else: - done = False + terminated = bool(cost > self.reward_range.high or cost < self.reward_range.low) # Return state, cost, done and reference return ( np.array([m1, m2, m3, p1, p2, p3, r1, p1 - r1]), cost, - done, + terminated, + False, dict(reference=r1, state_of_interest=p1 - r1), ) diff --git a/simzoo/envs/classic_control/cart_pole_cost/cart_pole_cost.py b/simzoo/envs/classic_control/cart_pole_cost/cart_pole_cost.py index ab64468c..5bdbddd4 100644 --- a/simzoo/envs/classic_control/cart_pole_cost/cart_pole_cost.py +++ b/simzoo/envs/classic_control/cart_pole_cost/cart_pole_cost.py @@ -310,7 +310,10 @@ def step(self, action): - obs (:obj:`numpy.ndarray`): The current state - cost (:obj:`numpy.float64`): The current cost. - - done (:obj:`bool`): Whether the episode was done. + - terminated (:obj:`bool`): Whether the episode was done. + - truncated (:obj:`bool`): Whether the episode was truncated. This value + is set by wrappers when for example a time limit is reached or the + agent goes out of bounds. - info_dict (:obj:`dict`): Dictionary with additional information. """ # Clip action if needed @@ -378,13 +381,15 @@ def step(self, action): cost, ref = self.cost(x, theta) # Define stopping criteria - if ( + terminated = bool( abs(x) > self.x_threshold or abs(theta) > self.theta_threshold_radians or cost > self.reward_range.high or cost < self.reward_range.low - ): - done = True + ) + + # Calculate stopping cost + if terminated: cost = 100.0 # Throw warning if already done @@ -399,16 +404,15 @@ def step(self, action): "True' -- any further steps are undefined behavior." ) self.steps_beyond_done += 1 - else: - done = False - # Return state, cost, done and info_dict + # Return state, cost, terminated and info_dict violation_of_constraint = bool(abs(x) > self.const_pos) violation_of_x_threshold = bool(x < -self.x_threshold or x > self.x_threshold) return ( self.state, cost, - done, + terminated, + False, dict( cons_pos=self.const_pos, cons_theta=self.theta_threshold_radians, diff --git a/simzoo/envs/classic_control/ex3_ekf/ex3_ekf.py b/simzoo/envs/classic_control/ex3_ekf/ex3_ekf.py index aa7b5891..874008bb 100644 --- a/simzoo/envs/classic_control/ex3_ekf/ex3_ekf.py +++ b/simzoo/envs/classic_control/ex3_ekf/ex3_ekf.py @@ -185,7 +185,10 @@ def step(self, action): - obs (:obj:`numpy.ndarray`): The current state - cost (:obj:`numpy.float64`): The current cost. - - done (:obj:`bool`): Whether the episode was done. + - terminated (:obj:`bool`): Whether the episode was done. + - truncated (:obj:`bool`): Whether the episode was truncated. This value + is set by wrappers when for example a time limit is reached or the + agent goes out of bounds. - info_dict (:obj:`dict`): Dictionary with additional information. """ # Clip action if needed @@ -250,10 +253,7 @@ def step(self, action): # cost = np.abs(hat_x_1 - x_1)**1 + np.abs(hat_x_2 - x_2)**1 # Define stopping criteria - if cost > self.reward_range.high or cost < self.reward_range.low: - done = True - else: - done = False + terminated = bool(cost > self.reward_range.high or cost < self.reward_range.low) # Update state self.state = np.array([hat_x_1, hat_x_2, x_1, x_2]) @@ -264,7 +264,8 @@ def step(self, action): return ( np.array([hat_x_1, hat_x_2, x_1, x_2]), cost, - done, + terminated, + False, dict( reference=y_1, state_of_interest=np.array([hat_x_1 - x_1, hat_x_2 - x_2]),