-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* feat: add 'HalfCheetahCost' environment This commit adds a modified version of the [HalfCeetah environment](https://gymnasium.farama.org/environments/mujoco/half_cheetah/) found in the [gymnasium](https://gymnasium.farama.org) package. In this modified version, the cost was replaced by a reward. This cost is the squared difference between HalfCheetah's forward velocity and a reference value (error). * feat: add 'AntCost' environment This commit adds a modified version of the [Ant environment](https://gymnasium.farama.org/environments/mujoco/ant/) found in the [gymnasium](https://gymnasium.farama.org) package. In this modified version, the cost was replaced by a reward. This cost is the squared difference between Ant's forward velocity and a reference value (error). * test: add 'AntCost' tests
- Loading branch information
Showing
9 changed files
with
284 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
# AntCost gymnasium environment | ||
|
||
<div align="center"> | ||
<img src="https://github.com/rickstaa/stable-gym/assets/17570430/c9f6d7f9-586e-4236-91d3-fa2d0ce4aadc" alt="Ant Cost" width="200px"> | ||
</div> | ||
</br> | ||
|
||
An actuated 8-jointed ant. This environment corresponds to the [Ant-v4](https://gymnasium.farama.org/environments/mujoco/ant) environment included in the [gymnasium package](https://gymnasium.farama.org/). It is different in the fact that: | ||
|
||
* The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared | ||
difference between the Ant's forward velocity and a reference value (error). | ||
|
||
The rest of the environment is the same as the original Ant environment. Below, the modified cost is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [gymnasium library](https://gymnasium.farama.org/environments/mujoco/ant/). | ||
|
||
## Cost function | ||
|
||
The cost function of this environment is designed in such a way that it tries to minimize the error between the Ant's forward velocity and a reference value. The cost function is defined as: | ||
|
||
$$ | ||
cost = w\_{forward} \times (x\_{velocity} - x\_{reference\_x\_velocity})^2 + w\_{ctrl} \times c\_{ctrl} | ||
$$ | ||
|
||
## How to use | ||
|
||
This environment is part of the [Stable Gym package](https://github.com/rickstaa/stable-gym). It is therefore registered as the `stable_gym:AntCost-v1` gymnasium environment when you import the Stable Gym package. If you want to use the environment in stand-alone mode, you can register it yourself. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
"""Modified version of the Ant Mujoco environment in v0.28.1 of the | ||
`gymnasium library <https://gymnasium.farama.org/environments/mujoco/ant>`_. | ||
This modification was first described by `Han et al. 2020 <https://arxiv.org/abs/2004.14288>`_. | ||
In this modified version: | ||
- The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. | ||
This cost is the squared difference between the Ant's forward velocity and a reference value (error). | ||
""" # noqa: E501 | ||
from stable_gym.envs.mujoco.ant_cost.ant_cost import AntCost |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
"""The AntCost gymnasium environment.""" | ||
|
||
import gymnasium as gym | ||
import matplotlib.pyplot as plt | ||
import numpy as np | ||
from gymnasium.envs.mujoco.ant_v4 import AntEnv | ||
|
||
import stable_gym # NOTE: Required to register environments. # noqa: F401 | ||
|
||
RANDOM_STEP = True # Use random action in __main__. Zero action otherwise. | ||
|
||
|
||
# TODO: Add solving criteria after training. | ||
class AntCost(AntEnv): | ||
"""Custom Ant gymnasium environment. | ||
.. note:: | ||
Can also be used in a vectorized manner. See the | ||
:gymnasium:`gym.vector <api/vector>` documentation. | ||
Source: | ||
This is a modified version of the Ant Mujoco environment in v0.28.1 of the | ||
:gymnasium:`gymnasium library <environments/mujoco/ant>`. This modification | ||
was first described by `Han et al. 2020 <https://arxiv.org/abs/2004.14288>`_. | ||
Compared to the original Ant environment in this modified version: | ||
- The objective was changed to a velocity-tracking task. To do this, the reward | ||
is replaced with a cost. This cost is the squared difference between the | ||
Ant's forward velocity and a reference value (error). | ||
The rest of the environment is the same as the original Ant environment. | ||
Below, the modified cost is described. For more information about the environment | ||
(e.g. observation space, action space, episode termination, etc.), please refer | ||
to the :gymnasium:`gymnasium library <environments/mujoco/ant>`. | ||
Modified cost: | ||
.. math:: | ||
cost = w_{forward} \\times (x_{velocity} - x_{reference\_x\_velocity})^2 + w_{ctrl} \\times c_{ctrl} | ||
Solved Requirements: | ||
Considered solved when the average cost is less than or equal to 50 over | ||
100 consecutive trials. | ||
How to use: | ||
.. code-block:: python | ||
import stable_gym | ||
import gymnasium as gym | ||
env = gym.make("AntCost-v1") | ||
Attributes: | ||
reference_forward_velocity (float): The forward velocity that the agent should try | ||
to track. | ||
include_ctrl_cost (bool): Whether you also want to penalize the Ant if it | ||
takes actions that are too large. | ||
forward_velocity_weight (float): The weight used to scale the forward velocity error. | ||
""" # noqa: E501, W605 | ||
|
||
def __init__( | ||
self, | ||
reference_forward_velocity=1.0, | ||
include_ctrl_cost=True, | ||
forward_velocity_weight=1.0, | ||
ctrl_cost_weight=None, | ||
**kwargs, | ||
): | ||
"""Constructs all the necessary attributes for the AntCost instance. | ||
Args: | ||
reference_forward_velocity (float, optional): The forward velocity that the | ||
agent should try to track. Defaults to ``1.0``. | ||
include_ctrl_cost (bool, optional): Whether you also want to penalize the | ||
Ant if it takes actions that are too large. Defaults to ``True``. | ||
forward_velocity_weight (float, optional): The weight used to scale the | ||
forward velocity error. Defaults to ``1.0``. | ||
ctrl_cost_weight (_type_, optional): The weight used to scale the control | ||
cost. Defaults to ``None`` meaning that the default value of the | ||
:attr:`~gymnasium.envs.mujoco.ant_v4.AntEnv.ctrl_cost_weight` | ||
attribute is used. | ||
""" # noqa: E501 | ||
super().__init__(**kwargs) | ||
self.reference_forward_velocity = reference_forward_velocity | ||
self.include_ctrl_cost = include_ctrl_cost | ||
self._ctrl_cost_weight = ( | ||
ctrl_cost_weight if ctrl_cost_weight else self._ctrl_cost_weight | ||
) | ||
self.forward_velocity_weight = forward_velocity_weight | ||
self.state = None | ||
|
||
def step(self, action): | ||
"""Take step into the environment. | ||
.. note:: | ||
This method overrides the | ||
:meth:`~gymnasium.envs.mujoco.ant_v4.AntEnv.step` method | ||
such that the new cost function is used. | ||
Args: | ||
action (np.ndarray): Action to take in the environment. | ||
Returns: | ||
(tuple): tuple containing: | ||
- obs (:obj:`np.ndarray`): Environment observation. | ||
- cost (:obj:`float`): Cost of the action. | ||
- terminated (:obj`bool`): Whether the episode is terminated. | ||
- truncated (:obj:`bool`): Whether the episode was truncated. This value | ||
is set by wrappers when for example a time limit is reached or the | ||
agent goes out of bounds. | ||
- info (:obj`dict`): Additional information about the environment. | ||
""" | ||
obs, _, terminated, truncated, info = super().step(action) | ||
self.state = obs | ||
cost, cost_info = self.cost(info["x_velocity"], -info["reward_ctrl"]) | ||
|
||
# Update info. | ||
info["reward_fwd"] = cost_info["reward_fwd"] | ||
info["forward_reward"] = cost_info["reward_fwd"] | ||
|
||
return obs, cost, terminated, truncated, info | ||
|
||
def cost(self, x_velocity, ctrl_cost): | ||
"""Compute the cost of the action. | ||
Args: | ||
x_velocity (float): The Ant's x velocity. | ||
ctrl_cost (float): The control cost. | ||
Returns: | ||
(tuple): tuple containing: | ||
- cost (float): The cost of the action. | ||
- info (:obj:`dict`): Additional information about the cost. | ||
""" | ||
reward_fwd = self.forward_velocity_weight * np.square( | ||
x_velocity - self.reference_forward_velocity | ||
) | ||
cost = reward_fwd | ||
if self.include_ctrl_cost: | ||
cost += ctrl_cost | ||
return cost, {"reward_fwd": reward_fwd, "reward_ctrl": ctrl_cost} | ||
|
||
@property | ||
def ctrl_cost_weight(self): | ||
"""Property that returns the control cost weight.""" | ||
return self._ctrl_cost_weight | ||
|
||
@ctrl_cost_weight.setter | ||
def ctrl_cost_weight(self, value): | ||
"""Setter for the control cost weight.""" | ||
self._ctrl_cost_weight = value | ||
|
||
|
||
if __name__ == "__main__": | ||
print("Setting up AntCost environment.") | ||
env = gym.make("AntCost", render_mode="human") | ||
|
||
# Take T steps in the environment. | ||
T = 1000 | ||
path = [] | ||
t1 = [] | ||
s = env.reset( | ||
options={ | ||
"low": [-2, -0.2, -0.2, -0.2], | ||
"high": [2, 0.2, 0.2, 0.2], | ||
} | ||
) | ||
print(f"Taking {T} steps in the AntCost environment.") | ||
for i in range(int(T / env.dt)): | ||
action = ( | ||
env.action_space.sample() | ||
if RANDOM_STEP | ||
else np.zeros(env.action_space.shape) | ||
) | ||
s, r, terminated, truncated, info = env.step(action) | ||
if terminated: | ||
env.reset() | ||
path.append(s) | ||
t1.append(i * env.dt) | ||
print("Finished AntCost environment simulation.") | ||
|
||
# Plot results. | ||
print("Plot results.") | ||
fig = plt.figure(figsize=(9, 6)) | ||
ax = fig.add_subplot(111) | ||
ax.plot(t1, np.array(path)[:, 0], color="orange", label="x") | ||
ax.plot(t1, np.array(path)[:, 1], color="magenta", label="x_dot") | ||
ax.plot(t1, np.array(path)[:, 2], color="sienna", label="theta") | ||
ax.plot(t1, np.array(path)[:, 3], color="blue", label=" theat_dot1") | ||
|
||
handles, labels = ax.get_legend_handles_labels() | ||
ax.legend(handles, labels, loc=2, fancybox=False, shadow=False) | ||
plt.ioff() | ||
plt.show() | ||
|
||
print("done") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
gymnasium==0.28.1 | ||
gymnasium[mujoco]==0.28.1 | ||
matplotlib==3.7.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
"""Test if the AntCost environment still behaves like the original Ant | ||
environment when the same environment parameters are used. | ||
""" | ||
import gymnasium as gym | ||
import numpy as np | ||
from gymnasium.logger import ERROR | ||
|
||
import stable_gym # noqa: F401 | ||
|
||
gym.logger.set_level(ERROR) | ||
|
||
|
||
class TestAntCostEqual: | ||
# Make original Ant environment. | ||
env = gym.make("Ant") | ||
# Make AntCost environment. | ||
env_cost = gym.make("AntCost") | ||
|
||
def test_equal_reset(self): | ||
"""Test if reset behaves the same.""" | ||
# Perform reset and check if observations are equal. | ||
observation, _ = self.env.reset(seed=42) | ||
observation_cost, _ = self.env_cost.reset(seed=42) | ||
assert np.allclose( | ||
observation, observation_cost | ||
), f"{observation} != {observation_cost}" | ||
|
||
def test_equal_steps(self): | ||
"""Test if steps behave the same.""" | ||
# Perform several steps and check if observations are equal. | ||
self.env.reset(seed=42), self.env_cost.reset(seed=42) | ||
for _ in range(10): | ||
self.env.action_space.seed(42) | ||
action = self.env.action_space.sample() | ||
observation, _, _, _, _ = self.env.step(action) | ||
observation_cost, _, _, _, _ = self.env_cost.step(action) | ||
assert np.allclose( | ||
observation, observation_cost | ||
), f"{observation} != {observation_cost}" |