Skip to content

Commit

Permalink
feat(quadxtrackingcost): add extra keys to environment step info dict…
Browse files Browse the repository at this point in the history
…ionary (#263)

This commit adds the `reference`, `state_of_interest` and `reference_error` keys to the step info
dictionary.
  • Loading branch information
rickstaa authored Jul 25, 2023
1 parent 2c2b5bb commit 9743330
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 2 deletions.
15 changes: 15 additions & 0 deletions stable_gym/envs/robotics/quadrotor/quadx_tracking_cost/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ An actuated multirotor unmanned aerial vehicle (UAV) in the Quad-X configuration
* A health penalty has been added. This penalty is applied when the quadrotor moves outside the flight dome or crashes. The penalty equals the maximum episode steps minus the steps taken or a user-defined penalty.
* The `max_duration_seconds` has been removed. Instead, the `max_episode_steps` parameter of the [gym.wrappers.TimeLimit](https://gymnasium.farama.org/api/wrappers/misc_wrappers/#gymnasium.wrappers.TimeLimit) wrapper is used to limit the episode duration.
* The objective has been changed to track a periodic reference trajectory.
* The info dictionary has been extended with the reference, state of interest (i.e. the state to track) and reference error.

The rest of the environment is the same as the original QuadXHover environment. Below, the modified cost and observation space is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [PyFlyt package documentation](https://jjshoots.github.io/PyFlyt/index.html).

Expand Down Expand Up @@ -39,6 +40,20 @@ Where:

The health penalty is optional and can be disabled using the `include_health_penalty` environment arguments.

## Environment step return

In addition to the observations, the cost and a termination and truncation boolean, the environment also returns an info dictionary:

```python
[observation, cost, termination, truncation, info_dict]
```

Compared to the original [QuadXHover-v1](https://jjshoots.github.io/PyFlyt/documentation/gym_envs/quadx_envs/quadx_hover_env.html) environment, the following keys were added to this info dictionary:

* **reference**: The set cart position reference.
* **state\_of\_interest**: The state that should track the reference (SOI).
* **reference\_error**: The error between SOI and the reference.

## How to use

This environment is part of the [Stable Gym package](https://github.com/rickstaa/stable-gym). It is therefore registered as the `stable_gym:QuadXTrackingCost-v1` gymnasium environment when you import the Stable Gym package. If you want to use the environment in stand-alone mode, you can register it yourself.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
parameter of the :class:`gym.wrappers.TimeLimit` wrapper is used to limit
the episode duration.
- The objective has been changed to track a periodic reference trajectory.
- The info dictionary has been extended with the reference, state of interest
(i.e. the state to track) and reference error.
The rest of the environment is the same as the original QuadXHover environment. For more
information about the original environment, please refer the
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
"""The QuadXTrackingCost gymnasium environment."""
from pathlib import PurePath

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from gymnasium import utils
import PyFlyt
from pathlib import PurePath
from gymnasium import utils
from PyFlyt.gym_envs.quadx_envs.quadx_hover_env import QuadXHoverEnv

from stable_gym import ENVS # noqa: F401

EPISODES = 10 # Number of env episodes to run when __main__ is called.
Expand All @@ -32,6 +34,8 @@ class QuadXTrackingCost(QuadXHoverEnv, utils.EzPickle):
parameter of the :class:`gym.wrappers.TimeLimit` wrapper is used to limit
the episode duration.
- The objective has been changed to track a periodic reference trajectory.
- The info dictionary has been extended with the reference, state of interest
(i.e. the state to track) and reference error.
The rest of the environment is the same as the original QuadXHover environment.
Please refer to the `original codebase <https://github.com/jjshoots/PyFlyt>`__,
Expand Down Expand Up @@ -367,6 +371,14 @@ def step(self, action):

self.state = obs

# Add reference, state_of_interest and reference_error to info.
info_dict = dict(
reference=ref,
state_of_interest=self.env.state(0)[-1],
reference_error=self.env.state(0)[-1] - ref,
)
info.update(info_dict)

return obs, cost, terminated, truncated, info

def reset(self, seed=None, options=None):
Expand Down Expand Up @@ -411,6 +423,14 @@ def reset(self, seed=None, options=None):

self.state = obs

# Add reference, state_of_interest and reference_error to info.
info_dict = dict(
reference=ref,
state_of_interest=self.env.state(0)[-1],
reference_error=self.env.state(0)[-1] - ref,
)
info.update(info_dict)

return obs, info

def visualize_reference(self):
Expand Down
18 changes: 18 additions & 0 deletions tests/__snapshots__/test_quadx_tracking_cost.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@
'collision': False,
'env_complete': False,
'out_of_bounds': False,
'reference': array([ 0.1175374 , -0.99306846, 1.01177661]),
'reference_error': array([-0.11730597, 0.99254164, -0.06975631]),
'state_of_interest': array([ 2.31426066e-04, -5.26820864e-04, 9.42020303e-01]),
})
# ---
# name: TestQuadXTrackingCostEqual.test_snapshot.121
Expand Down Expand Up @@ -185,6 +188,9 @@
'collision': False,
'env_complete': False,
'out_of_bounds': False,
'reference': array([ 0.15643447, -0.98768834, 1.01569763]),
'reference_error': array([-0.15504015, 0.98936314, -0.05594122]),
'state_of_interest': array([0.00139432, 0.0016748 , 0.95975641]),
})
# ---
# name: TestQuadXTrackingCostEqual.test_snapshot.152
Expand Down Expand Up @@ -291,6 +297,9 @@
'collision': False,
'env_complete': False,
'out_of_bounds': False,
'reference': array([ 0.19509032, -0.98078528, 1.01961477]),
'reference_error': array([-0.19229362, 0.98562446, -0.03984327]),
'state_of_interest': array([0.0027967 , 0.00483918, 0.9797715 ]),
})
# ---
# name: TestQuadXTrackingCostEqual.test_snapshot.19
Expand Down Expand Up @@ -325,6 +334,9 @@
'collision': False,
'env_complete': False,
'out_of_bounds': False,
'reference': array([ 0., -1., 1.]),
'reference_error': array([-2.72890000e-09, 1.00000001e+00, -3.18775526e-02]),
'state_of_interest': array([-2.72890000e-09, 1.02375000e-08, 9.68122447e-01]),
})
# ---
# name: TestQuadXTrackingCostEqual.test_snapshot.28
Expand Down Expand Up @@ -431,6 +443,9 @@
'collision': False,
'env_complete': False,
'out_of_bounds': False,
'reference': array([ 0.03925982, -0.99922904, 1.00392683]),
'reference_error': array([-0.03926442, 0.99922081, -0.05509524]),
'state_of_interest': array([-4.60863660e-06, -8.22990120e-06, 9.48831587e-01]),
})
# ---
# name: TestQuadXTrackingCostEqual.test_snapshot.59
Expand Down Expand Up @@ -537,6 +552,9 @@
'collision': False,
'env_complete': False,
'out_of_bounds': False,
'reference': array([ 0.0784591 , -0.99691733, 1.00785269]),
'reference_error': array([-0.07861498, 0.99650484, -0.06989215]),
'state_of_interest': array([-1.55883875e-04, -4.12495247e-04, 9.37960537e-01]),
})
# ---
# name: TestQuadXTrackingCostEqual.test_snapshot.9
Expand Down

0 comments on commit 9743330

Please sign in to comment.