From 5ca9630591c230d506474769d63453add1021403 Mon Sep 17 00:00:00 2001 From: rickstaa Date: Wed, 26 Jul 2023 17:28:39 +0200 Subject: [PATCH] feat(classicalcontrol): add additional info to step/reset return This commit adds the `reference`, `state_of_interest` and `reference_error` keys to the info dictionary that is returned by the step and reset methods. --- .../classic_control/cartpole_cost/README.md | 6 ++++- .../cartpole_cost/cartpole_cost.py | 17 +++++++++++--- .../cartpole_tracking_cost.py | 23 +++++++++---------- 3 files changed, 30 insertions(+), 16 deletions(-) diff --git a/stable_gym/envs/classic_control/cartpole_cost/README.md b/stable_gym/envs/classic_control/cartpole_cost/README.md index 36cc2b20..032d186a 100644 --- a/stable_gym/envs/classic_control/cartpole_cost/README.md +++ b/stable_gym/envs/classic_control/cartpole_cost/README.md @@ -59,7 +59,11 @@ In addition to the observations, the cost and a termination and truncation boole [observation, cost, termination, truncation, info_dict] ``` -The info dictionary currently is empty. +The info dictionary contains the following keys: + +* **reference**: The set cart position and angle reference (i.e. the zero position and angle). +* **state\_of\_interest**: The state that should track the reference (SOI). +* **reference\_error**: The error between the reference and the SOI (i.e. the reference minus the SOI). ## How to use diff --git a/stable_gym/envs/classic_control/cartpole_cost/cartpole_cost.py b/stable_gym/envs/classic_control/cartpole_cost/cartpole_cost.py index a0f2dbf7..47c00738 100644 --- a/stable_gym/envs/classic_control/cartpole_cost/cartpole_cost.py +++ b/stable_gym/envs/classic_control/cartpole_cost/cartpole_cost.py @@ -396,6 +396,11 @@ def step(self, action): # Create observation and info dict. obs = np.array(self.state) + info_dict = dict( + reference=np.array([0.0, 0.0]), + state_of_interest=np.array([x, theta]), + reference_error=np.array([-x, -theta]), + ) # NOTE: The original returns an empty info dict. 
return ( @@ -403,7 +408,7 @@ def step(self, action): cost, terminated, False, - {}, + info_dict, ) def reset(self, seed=None, options=None, random=True): @@ -467,15 +472,21 @@ def reset(self, seed=None, options=None, random=True): self.steps_beyond_terminated = None self.t = 0.0 - # Create info dict and observation. + # Retrieve observation and info_dict. obs = np.array(self.state) + x, _, theta, _ = self.state + info_dict = dict( + reference=np.array([0.0, 0.0]), + state_of_interest=np.array([x, theta]), + reference_error=np.array([-x, -theta]), + ) # Render environment reset if requested. if self.render_mode == "human": self.render() # NOTE: The original returns an empty info dict. - return obs, {} + return obs, info_dict def render(self): """Render one frame of the environment.""" diff --git a/stable_gym/envs/classic_control/cartpole_tracking_cost/cartpole_tracking_cost.py b/stable_gym/envs/classic_control/cartpole_tracking_cost/cartpole_tracking_cost.py index 74e71c5c..a373f8af 100644 --- a/stable_gym/envs/classic_control/cartpole_tracking_cost/cartpole_tracking_cost.py +++ b/stable_gym/envs/classic_control/cartpole_tracking_cost/cartpole_tracking_cost.py @@ -348,11 +348,10 @@ def cost(self, x, theta): - cost (float): The current cost. - r_1 (float): The current position reference. - - r_2 (float): The cart_pole angle reference. """ # TODO: Fine-tune cost function. The current one is a initial test. 
- ref = [self.reference(self.t), 0.0] - ref_cost = np.square(x - ref[0]) + ref = self.reference(self.t) + ref_cost = np.square(x - ref) stab_cost = np.square(theta / self.theta_threshold_radians) cost = stab_cost + ref_cost @@ -465,15 +464,15 @@ def step(self, action): obs = np.append( np.array(self.state), np.array( - [ref[0]] + [ref] if self._exclude_reference_error_from_observation - else [ref[0], x - ref[0]] + else [ref, x - ref] ), ) info_dict = dict( - reference=ref[0], + reference=ref, state_of_interest=x, - reference_error=x - ref[0], + reference_error=x - ref, ) # NOTE: The original returns an empty info dict. @@ -546,20 +545,20 @@ def reset(self, seed=None, options=None, random=True): self.steps_beyond_terminated = None self.t = 0.0 - # Create info dict and observation. + # Retrieve observation and info_dict. x, _, theta, _ = self.state _, ref = self.cost(x, theta) info_dict = dict( - reference=ref[0], + reference=ref, state_of_interest=x, - reference_error=x - ref[0], + reference_error=x - ref, ) obs = np.append( np.array(self.state), np.array( - [ref[0]] + [ref] if self._exclude_reference_error_from_observation - else [ref[0], x - ref[0]] + else [ref, x - ref] ), )