feat: add SwimmerCost environment (#180)
* feat: add SwimmerCost environment

This commit adds the SwimmerCost environment. This environment was based on the
[swimmer environment](https://gymnasium.farama.org/environments/mujoco/swimmer/) of
the [gymnasium library](https://gymnasium.farama.org). Compared to that environment,
only the reward was changed so that a reference speed was tracked.

* test: add 'SwimmerCost' tests

* docs: improve docs

* test: rename 'Swimmer' test

* docs: fix documentation error

* docs: improve environments documentation
rickstaa authored Jul 9, 2023
1 parent 277c1f9 commit f9eb341
Showing 21 changed files with 352 additions and 47 deletions.
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -38,6 +38,7 @@
]
autoapi_dirs = ["../../stable_gym"]
myst_heading_anchors = 2 # Add anchors to headings.
myst_enable_extensions = ["dollarmath", "html_image"]

# Extensions settings.
autodoc_member_order = "bysource"
21 changes: 20 additions & 1 deletion docs/source/envs/envs.rst
@@ -21,14 +21,33 @@ Gym environments that are based on Biological systems.
Classic control environments
----------------------------

Environments for classical control theory problems.
Environments that are based on classical control problems or `classical control`_
environments found in the :gymnasium:`gymnasium <>` library.

.. _`classical control`: https://gymnasium.farama.org/environments/classic_control

.. toctree::
:maxdepth: 1

./classic_control/ex3_ekf.rst
./classic_control/cartpole_cost.rst


.. _`classical control gymnasium environments`: https://gymnasium.farama.org/environments/classic_control

Mujoco environments
-------------------

Environments that are based on `Mujoco`_ or `Mujoco gymnasium`_ environments.

.. toctree::
:maxdepth: 1

./mujoco/swimmer_cost.rst

.. _`Mujoco`: https://mujoco.org/
.. _`mujoco gymnasium`: https://gymnasium.farama.org/environments/mujoco

Robotics environment
--------------------

2 changes: 2 additions & 0 deletions docs/source/envs/mujoco/swimmer_cost.rst
@@ -0,0 +1,2 @@
.. include:: ../../../../stable_gym/envs/mujoco/swimmer_cost/README.md
:parser: myst_parser.sphinx_
1 change: 1 addition & 0 deletions examples/use_stable_gym.py
@@ -6,6 +6,7 @@
ENV_NAME = "Oscillator-v1"
# ENV_NAME = "Ex3EKF-v1"
# ENV_NAME = "CartPoleCost-v1"
# ENV_NAME = "SwimmerCost-v1"

if __name__ == "__main__":
env = gym.make(ENV_NAME, render_mode="human")
1 change: 1 addition & 0 deletions pyproject.toml
@@ -29,6 +29,7 @@ classifiers = [
dependencies = [
"gymnasium>=0.28.1",
"gymnasium[classic_control]>=0.28.1",
"gymnasium[mujoco]>=0.28.1",
"matplotlib>=3.7.1",
"iteration_utilities>=0.11.0"
]
9 changes: 7 additions & 2 deletions stable_gym/__init__.py
@@ -3,10 +3,10 @@
from gymnasium.envs.registration import register

# Make module version available.
from .version import __version__
from .version import __version_tuple__
from .version import __version__, __version_tuple__

# Available environments.
# TODO: Update reward thresholds.
ENVS = {
"Oscillator-v1": {
"module": "stable_gym.envs.biological.oscillator.oscillator:Oscillator",
@@ -28,6 +28,11 @@
"max_step": 250,
"reward_threshold": 300,
},
"SwimmerCost-v1": {
"module": "stable_gym.envs.mujoco.swimmer_cost.swimmer_cost:SwimmerCost",
"max_step": 250,
"reward_threshold": 300,
},
}

for env, val in ENVS.items():
6 changes: 3 additions & 3 deletions stable_gym/envs/biological/oscillator/README.md
@@ -31,9 +31,9 @@ The agent's goal in the oscillator environment is to act in such a way that one

The Oscillator environment uses the squared difference between the reference and the state of interest as the cost function:

```python
cost = np.square(p1 - r1)
```
$$
cost = (p_1 - r_1)^2
$$

## Environment step return

8 changes: 4 additions & 4 deletions stable_gym/envs/biological/oscillator/oscillator.py
@@ -232,13 +232,13 @@ def step(self, action):
Returns:
(tuple): tuple containing:
- obs (:obj:`numpy.ndarray`): The current state
- cost (:obj:`numpy.float64`): The current cost.
- terminated (:obj:`bool`): Whether the episode was done.
- obs (:obj:`np.ndarray`): Environment observation.
- cost (:obj:`float`): Cost of the action.
- terminated (:obj:`bool`): Whether the episode is terminated.
- truncated (:obj:`bool`): Whether the episode was truncated. This value
is set by wrappers when for example a time limit is reached or the
agent goes out of bounds.
- info_dict (:obj:`dict`): Dictionary with additional information.
- info (:obj:`dict`): Additional information about the environment.
"""
# Clip action if needed.
if self._clip_action:
8 changes: 5 additions & 3 deletions stable_gym/envs/classic_control/__init__.py
@@ -1,5 +1,7 @@
"""Stable Gym gymnasium environments that are based on classical control theory
problems.
"""
"""Stable Gym gymnasium environments based on classical control problems or
`classical control`_ environments found in the :gymnasium:`gymnasium <>` library.
.. _`classical control`: https://gymnasium.farama.org/environments/classic_control
""" # noqa: E501
from stable_gym.envs.classic_control.cartpole_cost.cartpole_cost import CartPoleCost
from stable_gym.envs.classic_control.ex3_ekf.ex3_ekf import Ex3EKF
6 changes: 4 additions & 2 deletions stable_gym/envs/classic_control/cartpole_cost/README.md
@@ -1,6 +1,8 @@
# CartPoleCost gymnasium environment

![cart\_pole](https://github.com/rickstaa/stable-gym/assets/17570430/eb3d4f34-1429-4597-a51f-16aea0e7def2)
<div align="center">
<img src="https://github.com/rickstaa/stable-gym/assets/17570430/eb3d4f34-1429-4597-a51f-16aea0e7def2" alt="cartpole" width="400px">
</div>

<!--alex ignore joint-->

@@ -66,7 +68,7 @@ The cost function of this environment is designed in such a way that it tries to
* A stabilisation task. In this task, the agent attempts to stabilize a given state (e.g. keep the pole angle and or cart position zero)
* A reference tracking task. The agent tries to make a state track a given reference in this task.

The exact definition of these tasks can be found in the environment `cost()` method.
The exact definition of these tasks can be found in the environment's `stable_gym.envs.classic_control.cartpole_cost.cartpole_cost.CartPoleCost.cost` method.

## Environment step return

17 changes: 4 additions & 13 deletions stable_gym/envs/classic_control/cartpole_cost/cartpole_cost.py
@@ -33,15 +33,6 @@ class CartPoleCost(gym.Env, CartPoleDisturber):
Stable Learning Control package (SLC). For more information see
`the SLC documentation <https://rickstaa.dev/stable-learning-control/utils/tester.html#robustness-eval-utility>`_.
Description:
This environment was based on the cart-pole environment described by Barto,
Sutton, and Anderson in
`Neuronlike Adaptive Elements That Can Solve Difficult Learning Control Problem`_.
A pole is attached by an un-actuated joint to a cart, which moves along a
frictionless track. The pendulum is placed upright on the cart and the goal
is to balance the pole while optionally tracking a certain reference signal by
applying forces in the left and right direction on the cart.
Source:
This environment corresponds to the version that is included in the Farama
Foundation gymnasium package. It is different from this version in the fact
@@ -389,13 +380,13 @@ def step(self, action):
Returns:
(tuple): tuple containing:
- obs (:obj:`numpy.ndarray`): The current state
- cost (:obj:`numpy.float64`): The current cost.
- terminated (:obj:`bool`): Whether the episode was done.
- obs (:obj:`np.ndarray`): Environment observation.
- cost (:obj:`float`): Cost of the action.
- terminated (:obj:`bool`): Whether the episode is terminated.
- truncated (:obj:`bool`): Whether the episode was truncated. This value
is set by wrappers when for example a time limit is reached or the
agent goes out of bounds.
- info_dict (:obj:`dict`): Dictionary with additional information.
- info (:obj:`dict`): Additional information about the environment.
"""
# Clip action if needed.
# NOTE: This is not done in the original environment.
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
gymnasium==0.28.1
gymnasium[classic_control]==0.28.1
matplotlib==3.7.0
8 changes: 4 additions & 4 deletions stable_gym/envs/classic_control/ex3_ekf/ex3_ekf.py
@@ -148,13 +148,13 @@ def step(self, action):
Returns:
(tuple): tuple containing:
- obs (:obj:`numpy.ndarray`): The current state
- cost (:obj:`numpy.float64`): The current cost.
- terminated (:obj:`bool`): Whether the episode was done.
- obs (:obj:`np.ndarray`): Environment observation.
- cost (:obj:`float`): Cost of the action.
- terminated (:obj:`bool`): Whether the episode is terminated.
- truncated (:obj:`bool`): Whether the episode was truncated. This value
is set by wrappers when for example a time limit is reached or the
agent goes out of bounds.
- info_dict (:obj:`dict`): Dictionary with additional information.
- info (:obj:`dict`): Additional information about the environment.
"""
# Clip action if needed.
if self._clipped_action:
1 change: 1 addition & 0 deletions stable_gym/envs/classic_control/ex3_ekf/requirements.txt
@@ -1,2 +1,3 @@
gymnasium==0.28.1
gymnasium[classic_control]==0.28.1
matplotlib==3.7.0
5 changes: 5 additions & 0 deletions stable_gym/envs/mujoco/__init__.py
@@ -0,0 +1,5 @@
"""Stable Gym gymnasium environments that are based on `Mujoco`_ or `Mujoco gymnasium`_ environments.
.. _`Mujoco`: https://mujoco.org
.. _`Mujoco gymnasium`: https://gymnasium.farama.org/environments/mujoco
"""
24 changes: 24 additions & 0 deletions stable_gym/envs/mujoco/swimmer_cost/README.md
@@ -0,0 +1,24 @@
# SwimmerCost gymnasium environment

<div align="center">
<img src="https://github.com/rickstaa/stable-gym/assets/17570430/dccd73b4-c97e-46ce-ba0d-4a1328c0aefe" alt="swimmer" width="200px">
</div>
<br>

An actuated two-jointed swimmer. This environment corresponds to the [Swimmer-v4](https://gymnasium.farama.org/environments/mujoco/swimmer) environment included in the [gymnasium package](https://gymnasium.farama.org/). It differs from that environment in one respect:

* The objective was changed to a speed-tracking task. To do this, the reward is replaced with a cost. This cost is the squared difference between the swimmer's forward speed and a reference value (error).

The rest of the environment is the same as the original Swimmer environment. Below, the modified cost is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [gymnasium library](https://gymnasium.farama.org/environments/mujoco/swimmer/).

## Cost function

The cost function of this environment penalizes the error between the swimmer's forward speed and a reference value, so that minimizing the cost makes the swimmer track the reference speed. The cost function is defined as:

$$
cost = w_{forward} \times (x_{speed} - x_{reference\_speed})^2 + w_{ctrl} \times c_{ctrl}
$$
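
The formula above can be sketched in plain Python as follows. This is a minimal illustration and not the environment's actual implementation: the function name, the argument names, and the default weights of `1.0` are assumptions made for the example, and the control cost $c_{ctrl}$ is passed in as a precomputed number because its definition is not given here.

```python
def swimmer_cost(x_speed, reference_speed, ctrl_cost,
                 forward_weight=1.0, ctrl_weight=1.0):
    """Speed-tracking cost: weighted squared speed error plus weighted control cost.

    All argument names and default weights are illustrative assumptions.
    """
    # Error between the swimmer's forward speed and the reference speed.
    speed_error = x_speed - reference_speed
    return forward_weight * speed_error**2 + ctrl_weight * ctrl_cost


# Swimmer moving at 1.5 while the reference is 1.0, with no control cost.
print(swimmer_cost(x_speed=1.5, reference_speed=1.0, ctrl_cost=0.0))  # 0.25
```

The cost is zero only when the forward speed equals the reference and the control cost is zero, so both overshooting and undershooting the reference are penalized equally.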

## How to use

This environment is part of the [Stable Gym package](https://github.com/rickstaa/stable-gym). It is therefore registered as the `stable_gym:SwimmerCost-v1` gymnasium environment when you import the Stable Gym package. If you want to use the environment in stand-alone mode, you can register it yourself.
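
As a rough sketch of how the registered environment might be used, the helper below runs one episode with random actions and sums the returned cost. It assumes only the standard gymnasium `reset`/`step` interface; the helper name `rollout` and its default arguments are hypothetical, and the commented usage lines assume that the `stable_gym` package and the MuJoCo dependencies are installed.

```python
def rollout(env, max_steps=250, seed=0):
    """Run one episode with random actions and return the accumulated cost."""
    obs, info = env.reset(seed=seed)
    total_cost = 0.0
    for _ in range(max_steps):
        # Sample a random action from the environment's action space.
        action = env.action_space.sample()
        obs, cost, terminated, truncated, info = env.step(action)
        total_cost += float(cost)
        if terminated or truncated:
            break
    return total_cost


# Example usage (requires stable_gym and gymnasium[mujoco]):
#
#   import gymnasium as gym
#   env = gym.make("stable_gym:SwimmerCost-v1")
#   print(rollout(env))
#   env.close()
```

Because the environment returns a cost rather than a reward, lower accumulated values indicate better speed tracking.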
9 changes: 9 additions & 0 deletions stable_gym/envs/mujoco/swimmer_cost/__init__.py
@@ -0,0 +1,9 @@
"""Modified version of the swimmer Mujoco environment in v0.28.1 of the
`gymnasium library <https://gymnasium.farama.org/environments/mujoco/swimmer>`_.
This modification was first described by `Han et al. 2020 <https://arxiv.org/abs/2004.14288>`_.
In this modified version:
- The objective was changed to a speed-tracking task. To do this, the reward is replaced with a cost.
This cost is the squared difference between the swimmer's forward speed and a reference value (error).
""" # noqa: E501
from stable_gym.envs.mujoco.swimmer_cost.swimmer_cost import SwimmerCost
3 changes: 3 additions & 0 deletions stable_gym/envs/mujoco/swimmer_cost/requirements.txt
@@ -0,0 +1,3 @@
gymnasium==0.28.1
gymnasium[mujoco]==0.28.1
matplotlib==3.7.0