From 974333037454fe9667691ec489ae716602e20fdb Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Tue, 25 Jul 2023 20:05:55 +0200 Subject: [PATCH] feat(quadxtrackingcost): add extra keys to environment step info dictionary (#263) This commit adds the `reference`, `state_of_interest` and `reference_error` keys to the step info dictionary. --- .../quadrotor/quadx_tracking_cost/README.md | 15 ++++++++++++ .../quadrotor/quadx_tracking_cost/__init__.py | 2 ++ .../quadx_tracking_cost.py | 24 +++++++++++++++++-- .../test_quadx_tracking_cost.ambr | 18 ++++++++++++++ 4 files changed, 57 insertions(+), 2 deletions(-) diff --git a/stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/README.md b/stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/README.md index 26cb821a..6a2800ef 100644 --- a/stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/README.md +++ b/stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/README.md @@ -11,6 +11,7 @@ An actuated multirotor unmanned aerial vehicle (UAV) in the Quad-X configuration * A health penalty has been added. This penalty is applied when the quadrotor moves outside the flight dome or crashes. The penalty equals the maximum episode steps minus the steps taken or a user-defined penalty. * The `max_duration_seconds` has been removed. Instead, the `max_episode_steps` parameter of the [gym.wrappers.TimeLimit](https://gymnasium.farama.org/api/wrappers/misc_wrappers/#gymnasium.wrappers.TimeLimit) wrapper is used to limit the episode duration. * The objective has been changed to track a periodic reference trajectory. +* The info dictionary has been extended with the reference, state of interest (i.e. the state to track) and reference error. The rest of the environment is the same as the original QuadXHover environment. Below, the modified cost and observation space is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [PyFlyt package documentation](https://jjshoots.github.io/PyFlyt/index.html). @@ -39,6 +40,20 @@ Where: The health penalty is optional and can be disabled using the `include_health_penalty` environment arguments. +## Environment step return + +In addition to the observations, the cost and a termination and truncation boolean, the environment also returns an info dictionary: + +```python +[observation, cost, termination, truncation, info_dict] +``` + +Compared to the original [QuadXHover-v1](https://jjshoots.github.io/PyFlyt/documentation/gym_envs/quadx_envs/quadx_hover_env.html) environment, the following keys were added to this info dictionary: + +* **reference**: The set cart position reference. +* **state\_of\_interest**: The state that should track the reference (SOI). +* **reference\_error**: The error between SOI and the reference. + ## How to use This environment is part of the [Stable Gym package](https://github.com/rickstaa/stable-gym). It is therefore registered as the `stable_gym:QuadXTrackingCost-v1` gymnasium environment when you import the Stable Gym package. If you want to use the environment in stand-alone mode, you can register it yourself. diff --git a/stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/__init__.py b/stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/__init__.py index df366e27..8e83a0ec 100644 --- a/stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/__init__.py +++ b/stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/__init__.py @@ -10,6 +10,8 @@ parameter of the :class:`gym.wrappers.TimeLimit` wrapper is used to limit the episode duration. - The objective has been changed to track a periodic reference trajectory. +- The info dictionary has been extended with the reference, state of interest + (i.e. the state to track) and reference error. The rest of the environment is the same as the original QuadXHover environment. For more information about the original environment, please refer the diff --git a/stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/quadx_tracking_cost.py b/stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/quadx_tracking_cost.py index 1edf992c..b942279a 100644 --- a/stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/quadx_tracking_cost.py +++ b/stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/quadx_tracking_cost.py @@ -1,11 +1,13 @@ """The QuadXTrackingCost gymnasium environment.""" +from pathlib import PurePath + import gymnasium as gym import matplotlib.pyplot as plt import numpy as np -from gymnasium import utils import PyFlyt -from pathlib import PurePath +from gymnasium import utils from PyFlyt.gym_envs.quadx_envs.quadx_hover_env import QuadXHoverEnv + from stable_gym import ENVS # noqa: F401 EPISODES = 10 # Number of env episodes to run when __main__ is called. @@ -32,6 +34,8 @@ class QuadXTrackingCost(QuadXHoverEnv, utils.EzPickle): parameter of the :class:`gym.wrappers.TimeLimit` wrapper is used to limit the episode duration. - The objective has been changed to track a periodic reference trajectory. + - The info dictionary has been extended with the reference, state of interest + (i.e. the state to track) and reference error. The rest of the environment is the same as the original QuadXHover environment. Please refer to the `original codebase `__, @@ -367,6 +371,14 @@ def step(self, action): self.state = obs + # Add reference, state_of_interest and reference_error to info. + info_dict = dict( + reference=ref, + state_of_interest=self.env.state(0)[-1], + reference_error=self.env.state(0)[-1] - ref, + ) + info.update(info_dict) + return obs, cost, terminated, truncated, info def reset(self, seed=None, options=None): @@ -411,6 +423,14 @@ def reset(self, seed=None, options=None): self.state = obs + # Add reference, state_of_interest and reference_error to info. + info_dict = dict( + reference=ref, + state_of_interest=self.env.state(0)[-1], + reference_error=self.env.state(0)[-1] - ref, + ) + info.update(info_dict) + return obs, info def visualize_reference(self): diff --git a/tests/__snapshots__/test_quadx_tracking_cost.ambr b/tests/__snapshots__/test_quadx_tracking_cost.ambr index 9bd421b7..9c44e585 100644 --- a/tests/__snapshots__/test_quadx_tracking_cost.ambr +++ b/tests/__snapshots__/test_quadx_tracking_cost.ambr @@ -79,6 +79,9 @@ 'collision': False, 'env_complete': False, 'out_of_bounds': False, + 'reference': array([ 0.1175374 , -0.99306846, 1.01177661]), + 'reference_error': array([-0.11730597, 0.99254164, -0.06975631]), + 'state_of_interest': array([ 2.31426066e-04, -5.26820864e-04, 9.42020303e-01]), }) # --- # name: TestQuadXTrackingCostEqual.test_snapshot.121 @@ -185,6 +188,9 @@ 'collision': False, 'env_complete': False, 'out_of_bounds': False, + 'reference': array([ 0.15643447, -0.98768834, 1.01569763]), + 'reference_error': array([-0.15504015, 0.98936314, -0.05594122]), + 'state_of_interest': array([0.00139432, 0.0016748 , 0.95975641]), }) # --- # name: TestQuadXTrackingCostEqual.test_snapshot.152 @@ -291,6 +297,9 @@ 'collision': False, 'env_complete': False, 'out_of_bounds': False, + 'reference': array([ 0.19509032, -0.98078528, 1.01961477]), + 'reference_error': array([-0.19229362, 0.98562446, -0.03984327]), + 'state_of_interest': array([0.0027967 , 0.00483918, 0.9797715 ]), }) # --- # name: TestQuadXTrackingCostEqual.test_snapshot.19 @@ -325,6 +334,9 @@ 'collision': False, 'env_complete': False, 'out_of_bounds': False, + 'reference': array([ 0., -1., 1.]), + 'reference_error': array([-2.72890000e-09, 1.00000001e+00, -3.18775526e-02]), + 'state_of_interest': array([-2.72890000e-09, 1.02375000e-08, 9.68122447e-01]), }) # --- # name: TestQuadXTrackingCostEqual.test_snapshot.28 @@ -431,6 +443,9 @@ 'collision': False, 'env_complete': False, 'out_of_bounds': False, + 'reference': array([ 0.03925982, -0.99922904, 1.00392683]), + 'reference_error': array([-0.03926442, 0.99922081, -0.05509524]), + 'state_of_interest': array([-4.60863660e-06, -8.22990120e-06, 9.48831587e-01]), }) # --- # name: TestQuadXTrackingCostEqual.test_snapshot.59 @@ -537,6 +552,9 @@ 'collision': False, 'env_complete': False, 'out_of_bounds': False, + 'reference': array([ 0.0784591 , -0.99691733, 1.00785269]), + 'reference_error': array([-0.07861498, 0.99650484, -0.06989215]), + 'state_of_interest': array([-1.55883875e-04, -4.12495247e-04, 9.37960537e-01]), }) # --- # name: TestQuadXTrackingCostEqual.test_snapshot.9