docs: improve 'CartPoleCost' documentation. (#193)
rickstaa authored Jul 11, 2023
1 parent 75c5356 commit 16e3236
Showing 5 changed files with 40 additions and 21 deletions.
1 change: 0 additions & 1 deletion examples/use_stable_gym.py
@@ -4,7 +4,6 @@
 import stable_gym  # noqa: F401

 ENV_NAME = "Oscillator-v1"
-# ENV_NAME = "Ex3EKF-v1"
 # ENV_NAME = "CartPoleCost-v1"
 # ENV_NAME = "SwimmerCost-v1"

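For context, `use_stable_gym.py` presumably builds the selected environment with `gym.make` and inspects or rolls it out, along these lines (a minimal sketch assuming the standard Gymnasium API; only the `ENV_NAME` block appears in this diff):

```python
import gymnasium as gym

import stable_gym  # noqa: F401 - importing registers the stable_gym environments.

ENV_NAME = "Oscillator-v1"  # Swap in any of the IDs listed above.

env = gym.make(ENV_NAME)
print(env.observation_space, env.action_space)
env.close()
```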
4 changes: 2 additions & 2 deletions stable_gym/envs/biological/oscillator/oscillator.py
@@ -371,8 +371,8 @@ def reset(
         Returns:
             (tuple): tuple containing:

-                - observations (:obj:`numpy.ndarray`): Array containing the current
-                  observations.
+                - observation (:obj:`numpy.ndarray`): Array containing the current
+                  observation.
                 - info (:obj:`dict`): Dictionary containing additional information.
         """
         super().reset(seed=seed)
4 changes: 2 additions & 2 deletions [file name not captured in this view]
@@ -426,8 +426,8 @@ def reset(
         Returns:
             (tuple): tuple containing:

-                - observations (:obj:`numpy.ndarray`): Array containing the current
-                  observations.
+                - observation (:obj:`numpy.ndarray`): Array containing the current
+                  observation.
                 - info (:obj:`dict`): Dictionary containing additional information.
         """
         super().reset(seed=seed)
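Both docstring fixes align the `reset` documentation with what the method actually returns: a single observation array plus an info dictionary. A minimal usage sketch (assuming the `Oscillator-v1` registration used in the example script above):

```python
import gymnasium as gym

import stable_gym  # noqa: F401

env = gym.make("Oscillator-v1")
observation, info = env.reset(seed=0)  # One observation array and one info dict.
print(type(observation), type(info))
env.close()
```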
21 changes: 16 additions & 5 deletions stable_gym/envs/classic_control/cartpole_cost/README.md
@@ -49,10 +49,12 @@ An episode is terminated when:
 * Cart Position is more than 10 m (center of the cart reaches the edge of the
   display).
 * Episode length is greater than 200.
-* The cost is greater than 100.
+* The cost is greater than a set threshold (100 by default). This threshold can be changed using the `max_cost` environment argument.

 ## Environment goals

+This environment has two task types (goals), which can be set using the `task_type` environment argument.
+
 ### stabilisation task

 The stabilisation task is similar to the original `CartPole-v1` environment. The pendulum starts upright, and the goal is to prevent it from falling over by increasing and reducing the cart's control force. This must be done while the cart does not violate set position constraints. These constraints are defined in the cost function.
@@ -63,12 +65,21 @@ Similar to the stabilisation task, now the cart must also track a cart position

 ## Cost function

-The cost function of this environment is designed in such a way that it tries to minimize the error of a set of states and a set of reference states. It contains two types of tasks:
+The cost function of this environment is designed to minimize the error between a set of states and a set of reference states. As stated above, this environment contains two types of tasks, and each task type has a different cost function:

-* A stabilisation task. In this task, the agent attempts to stabilize a given state (e.g. keep the pole angle and or cart position zero)
-* A reference tracking task. The agent tries to make a state track a given reference in this task.
+**Stabilisation task**
+
+$$
+cost = (x / x_{threshold})^2 + 20 (\theta / \theta_{threshold})^2
+$$
+
+**Reference tracking task**
+
+$$
+cost = ((x - x_{ref}) / x_{threshold})^2 + 20 (\theta / \theta_{threshold})^2
+$$

-The exact definition of these tasks can be found in the environment's `stable_gym.envs.classical_control.cartpole_cost.cartpole_cost.CartPoleCost.cost` method.
+The exact definition of these tasks can be found in the environment's `stable_gym.envs.classical_control.cartpole_cost.cartpole_cost.CartPoleCost.cost` method. In both tasks the cost lies between `0` and a set threshold value, and the cost is set to this maximum when the episode is terminated.

 ## Environment step return
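The two cost expressions above translate almost directly into code. Below is a minimal sketch of both task types (illustrative only; the threshold defaults are assumptions, and the authoritative implementation is the `CartPoleCost.cost` method referenced above):

```python
import math


def cartpole_cost(
    x: float,
    theta: float,
    task_type: str = "stabilisation",
    x_ref: float = 0.0,
    x_threshold: float = 10.0,  # Matches the 10 m termination bound above.
    theta_threshold: float = math.radians(20),  # Illustrative pole-angle bound.
    max_cost: float = 100.0,
) -> float:
    """Sketch of the CartPoleCost cost for both task types."""
    if task_type == "reference_tracking":
        position_error = (x - x_ref) / x_threshold
    else:  # "stabilisation"
        position_error = x / x_threshold
    cost = position_error**2 + 20.0 * (theta / theta_threshold) ** 2
    # Per the README, the cost is bounded between 0 and the set threshold.
    return min(cost, max_cost)
```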
31 changes: 20 additions & 11 deletions stable_gym/envs/classic_control/cartpole_cost/cartpole_cost.py
@@ -95,7 +95,8 @@ class CartPoleCost(gym.Env, CartPoleDisturber):
     Cost:
         A cost, computed using the :meth:`CartPoleCost.cost` method, is given for each
         simulation step including the terminal step. This cost is defined as an error
-        between a state variable and a reference value.
+        between a state variable and a reference value. The exact cost depends on the
+        task type. The cost is set to the maximum cost when the episode is terminated.

     Starting State:
         All observations are assigned a uniform random value in ``[-0.2..0.2]``.
@@ -106,7 +107,8 @@ class CartPoleCost(gym.Env, CartPoleDisturber):
         - Cart Position is more than 10 m (center of the cart reaches the edge of the
           display).
         - Episode length is greater than 200.
-        - The cost is greater than 100.
+        - The cost is greater than a threshold (100 by default). This threshold can
+          be changed using the ``max_cost`` environment argument.

     Solved Requirements:
         Considered solved when the average cost is less than or equal to 50 over
@@ -125,6 +127,7 @@ class CartPoleCost(gym.Env, CartPoleDisturber):
     Attributes:
         state (numpy.ndarray): Array containing the current state.
         t (float): Current time step.
+        tau (float): The time step size. Also available as ``self.dt``.
         target_pos (float): The target position.
         constraint_pos (float): The constraint position.
         kinematics_integrator (str): The kinematics integrator used to update the state.
@@ -135,7 +138,7 @@ class CartPoleCost(gym.Env, CartPoleDisturber):
             terminal state.
         max_v (float): The maximum velocity of the cart.
         max_w (float): The maximum angular velocity of the pole.
-        cost_range (gym.spaces.Box): The range of the cost.
+        max_cost (float): The maximum cost.

     .. _`Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problem`: https://ieeexplore.ieee.org/document/6313077
     .. _`Han et al. 2020`: https://arxiv.org/abs/2004.14288
@@ -155,6 +158,7 @@ def __init__(
         reference_type="constant",
         reference_target_position=0.0,
         reference_constraint_position=4.0,
+        max_cost=100.0,
         clip_action=True,
     ):
         """Constructs all the necessary attributes for the CartPoleCost instance.
@@ -172,6 +176,8 @@ def __init__(
                 when ``task_type`` is ``reference_tracking``.
             reference_constraint_position: The reference constraint position, by
                 default ``4.0``. Not used in the environment but used for the info dict.
+            max_cost (float, optional): The maximum cost allowed before the episode is
+                terminated. Defaults to ``100.0``.
             clip_action (bool, optional): Whether the actions should be clipped if
                 they are greater than the set action limit. Defaults to ``True``.
         """
@@ -209,6 +215,8 @@ def __init__(
         self.x_threshold = 10  # NOTE: original uses 2.4.
         self.max_v = 50  # NOTE: Original uses np.finfo(np.float32).max (i.e. inf).
         self.max_w = 50  # NOTE: Original uses np.finfo(np.float32).max (i.e. inf).
+        assert max_cost > 0, "The maximum cost must be greater than 0."
+        self.max_cost = max_cost

         # Create observation space bounds.
         # Angle limit set to 2 * theta_threshold_radians so failing observation
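Because `gym.make` forwards keyword arguments to the environment constructor, the new `max_cost` argument can be set directly at creation time (a usage sketch, assuming the `CartPoleCost-v1` registration from the README):

```python
import gymnasium as gym

import stable_gym  # noqa: F401

# Terminate episodes once the cost exceeds 50 instead of the default 100.
env = gym.make("CartPoleCost-v1", max_cost=50.0)
```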
@@ -236,9 +244,9 @@ def __init__(
         # Clip the reward.
         # NOTE: Original does not do this. Here this is done because we want to decrease
         # the cost.
-        self.cost_range = spaces.Box(
+        self._cost_range = spaces.Box(
             np.array([0.0], dtype=np.float32),
-            np.array([100], dtype=np.float32),
+            np.array([self.max_cost], dtype=np.float32),
             dtype=np.float32,
         )

@@ -449,13 +457,14 @@ def step(self, action):
         terminated = bool(
             abs(x) > self.x_threshold
             or abs(theta) > self.theta_threshold_radians
-            or cost > self.cost_range.high  # NOTE: Added compared to original.
-            or cost < self.cost_range.low  # NOTE: Added compared to original.
+            or cost > self._cost_range.high  # NOTE: Added compared to original.
+            or cost < self._cost_range.low  # NOTE: Added compared to original.
         )

         # Handle termination.
         if terminated:
-            cost = 100.0  # NOTE: Different cost compared to the original.
+            # Ensure cost is at max cost.
+            cost = self.max_cost  # NOTE: Different cost compared to the original.

         # Throw warning if already done.
         if self.steps_beyond_terminated is None:
@@ -518,8 +527,8 @@ def reset(self, seed=None, options=None, random=True):
         Returns:
             (tuple): tuple containing:

-                - observations (:obj:`numpy.ndarray`): Array containing the current
-                  observations.
+                - observation (:obj:`numpy.ndarray`): Array containing the current
+                  observation.
                 - info (:obj:`dict`): Dictionary containing additional information.
         """
         super().reset(seed=seed)
@@ -743,7 +752,7 @@ def dt(self):

 if __name__ == "__main__":
     print("Setting up CartPoleCost environment.")
-    env = gym.make("CartPoleCost", render_mode="human")
+    env = gym.make("CartPoleCost", render_mode="human", max_cost=10.0)

     # Take T steps in the environment.
     T = 1000
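Note that the demo now passes a low `max_cost` so the cost-based termination triggers quickly. The rest of the `__main__` block is collapsed in this view; a self-contained rollout along the same lines might look as follows (a sketch, assuming the `CartPoleCost-v1` registration and the standard Gymnasium 5-tuple step return, with the cost in place of the reward):

```python
import gymnasium as gym

import stable_gym  # noqa: F401

env = gym.make("CartPoleCost-v1", max_cost=10.0)
obs, info = env.reset(seed=42)
for _ in range(1000):
    action = env.action_space.sample()  # Random actions for the demonstration.
    obs, cost, terminated, truncated, info = env.step(action)
    if terminated or truncated:
        obs, info = env.reset()
env.close()
```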
