add ant cost (#184)
* feat: add 'HalfCheetahCost' environment

This commit adds a modified version of the
[HalfCheetah environment](https://gymnasium.farama.org/environments/mujoco/half_cheetah/)
found in the [gymnasium](https://gymnasium.farama.org) package. In this modified
version, the reward was replaced by a cost. This cost is the squared difference
between HalfCheetah's forward velocity and a reference value (error).

* feat: add 'AntCost' environment

This commit adds a modified version of the
[Ant environment](https://gymnasium.farama.org/environments/mujoco/ant/)
found in the [gymnasium](https://gymnasium.farama.org) package. In this modified
version, the reward was replaced by a cost. This cost is the squared difference
between Ant's forward velocity and a reference value (error).

* test: add 'AntCost' tests
rickstaa authored Jul 10, 2023
1 parent a7b3ba9 commit 87f06c7
Showing 9 changed files with 284 additions and 2 deletions.
5 changes: 5 additions & 0 deletions stable_gym/__init__.py
@@ -38,6 +38,11 @@
        "max_step": 200,
        "reward_threshold": 300,
    },
    "AntCost-v1": {
        "module": "stable_gym.envs.mujoco.ant_cost.ant_cost:AntCost",
        "max_step": 1000,
        "reward_threshold": 300,
    },
}

for env, val in ENVS.items():
2 changes: 1 addition & 1 deletion stable_gym/envs/biological/oscillator/README.md
@@ -32,7 +32,7 @@ The agent's goal in the oscillator environment is to act in such a way that one
The Oscillator environment uses the squared difference between the reference and the state of interest as the cost function:

$$
cost = (p_1 - r_1)^2
cost = (p\_1 - r\_1)^2
$$

## Environment step return
25 changes: 25 additions & 0 deletions stable_gym/envs/mujoco/ant_cost/README.md
@@ -0,0 +1,25 @@
# AntCost gymnasium environment

<div align="center">
<img src="https://github.com/rickstaa/stable-gym/assets/17570430/c9f6d7f9-586e-4236-91d3-fa2d0ce4aadc" alt="Ant Cost" width="200px">
</div>
</br>

An actuated 8-jointed ant. This environment corresponds to the [Ant-v4](https://gymnasium.farama.org/environments/mujoco/ant) environment included in the [gymnasium package](https://gymnasium.farama.org/). It differs in that:

* The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared
difference between the Ant's forward velocity and a reference value (error).

The rest of the environment is the same as the original Ant environment. Below, the modified cost is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [gymnasium library](https://gymnasium.farama.org/environments/mujoco/ant/).

## Cost function

The cost function of this environment penalizes the error between the Ant's forward velocity and a reference value, so minimizing the cost drives the Ant to track the reference velocity. The cost function is defined as:

$$
cost = w\_{forward} \times (x\_{velocity} - x\_{reference\_x\_velocity})^2 + w\_{ctrl} \times c\_{ctrl}
$$
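
As a rough illustration of the formula above, the following dependency-free sketch computes the cost for a single step. The helper name `velocity_tracking_cost` and its argument names are assumptions made for this example and are not part of the package. It also assumes that the control term passed in is already weighted, i.e. that it equals the control weight multiplied by the control cost.

```python
def velocity_tracking_cost(
    x_velocity,
    reference_velocity=1.0,
    forward_velocity_weight=1.0,
    weighted_ctrl_cost=0.0,
    include_ctrl_cost=True,
):
    """Hypothetical helper that mirrors the cost formula above."""
    # Weighted squared error between the forward and the reference velocity.
    velocity_cost = forward_velocity_weight * (x_velocity - reference_velocity) ** 2
    if include_ctrl_cost:
        # Assumption: the control cost was already scaled by its weight.
        return velocity_cost + weighted_ctrl_cost
    return velocity_cost


# Moving at 0.5 while the reference is 1.0, with a weighted control cost of 0.25.
print(velocity_tracking_cost(0.5, weighted_ctrl_cost=0.25))
```

The cost is zero only when the Ant moves exactly at the reference velocity and no control cost is added; any deviation, faster or slower, is penalized quadratically.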

## How to use

This environment is part of the [Stable Gym package](https://github.com/rickstaa/stable-gym). It is therefore registered as the `stable_gym:AntCost-v1` gymnasium environment when you import the Stable Gym package. If you want to use the environment in stand-alone mode, you can register it yourself.
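
For stand-alone registration, the sketch below shows one way a registry entry like the one in `stable_gym/__init__.py` could be translated into arguments for `gymnasium.register`. The helper name `build_register_kwargs` is hypothetical, and mapping `max_step` to `max_episode_steps` is an assumption about how these values are meant to be used. The actual registration call is left as a comment so that the sketch runs without gymnasium installed.

```python
# Registry entry copied from the ENVS dictionary in stable_gym/__init__.py.
ANT_COST_ENTRY = {
    "module": "stable_gym.envs.mujoco.ant_cost.ant_cost:AntCost",
    "max_step": 1000,
    "reward_threshold": 300,
}


def build_register_kwargs(env_id, entry):
    """Hypothetical helper: map a registry entry to gymnasium.register kwargs."""
    return {
        "id": env_id,
        "entry_point": entry["module"],
        # Assumption: 'max_step' is the episode length limit.
        "max_episode_steps": entry["max_step"],
        "reward_threshold": entry["reward_threshold"],
    }


kwargs = build_register_kwargs("AntCost-v1", ANT_COST_ENTRY)
print(kwargs)

# With gymnasium and stable-gym installed, the environment could then be
# registered and created like this:
#
#     import gymnasium as gym
#     gym.register(**kwargs)
#     env = gym.make("AntCost-v1")
```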
9 changes: 9 additions & 0 deletions stable_gym/envs/mujoco/ant_cost/__init__.py
@@ -0,0 +1,9 @@
"""Modified version of the Ant Mujoco environment in v0.28.1 of the
`gymnasium library <https://gymnasium.farama.org/environments/mujoco/ant>`_.
This modification was first described by `Han et al. 2020 <https://arxiv.org/abs/2004.14288>`_.
In this modified version:

-   The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost.
    This cost is the squared difference between the Ant's forward velocity and a reference value (error).
"""  # noqa: E501
from stable_gym.envs.mujoco.ant_cost.ant_cost import AntCost
197 changes: 197 additions & 0 deletions stable_gym/envs/mujoco/ant_cost/ant_cost.py
@@ -0,0 +1,197 @@
"""The AntCost gymnasium environment."""

import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
from gymnasium.envs.mujoco.ant_v4 import AntEnv

import stable_gym # NOTE: Required to register environments. # noqa: F401

RANDOM_STEP = True # Use random action in __main__. Zero action otherwise.


# TODO: Add solving criteria after training.
class AntCost(AntEnv):
    """Custom Ant gymnasium environment.

    .. note::
        Can also be used in a vectorized manner. See the
        :gymnasium:`gym.vector <api/vector>` documentation.

    Source:
        This is a modified version of the Ant Mujoco environment in v0.28.1 of the
        :gymnasium:`gymnasium library <environments/mujoco/ant>`. This modification
        was first described by `Han et al. 2020 <https://arxiv.org/abs/2004.14288>`_.
        Compared to the original Ant environment in this modified version:

        -   The objective was changed to a velocity-tracking task. To do this, the reward
            is replaced with a cost. This cost is the squared difference between the
            Ant's forward velocity and a reference value (error).

        The rest of the environment is the same as the original Ant environment.
        Below, the modified cost is described. For more information about the environment
        (e.g. observation space, action space, episode termination, etc.), please refer
        to the :gymnasium:`gymnasium library <environments/mujoco/ant>`.

    Modified cost:
        .. math::

            cost = w_{forward} \\times (x_{velocity} - x_{reference\_x\_velocity})^2 + w_{ctrl} \\times c_{ctrl}

    Solved Requirements:
        Considered solved when the average cost is less than or equal to 50 over
        100 consecutive trials.

    How to use:
        .. code-block:: python

            import stable_gym
            import gymnasium as gym
            env = gym.make("AntCost-v1")

    Attributes:
        reference_forward_velocity (float): The forward velocity that the agent should try
            to track.
        include_ctrl_cost (bool): Whether you also want to penalize the Ant if it
            takes actions that are too large.
        forward_velocity_weight (float): The weight used to scale the forward velocity error.
    """  # noqa: E501, W605

    def __init__(
        self,
        reference_forward_velocity=1.0,
        include_ctrl_cost=True,
        forward_velocity_weight=1.0,
        ctrl_cost_weight=None,
        **kwargs,
    ):
        """Constructs all the necessary attributes for the AntCost instance.

        Args:
            reference_forward_velocity (float, optional): The forward velocity that the
                agent should try to track. Defaults to ``1.0``.
            include_ctrl_cost (bool, optional): Whether you also want to penalize the
                Ant if it takes actions that are too large. Defaults to ``True``.
            forward_velocity_weight (float, optional): The weight used to scale the
                forward velocity error. Defaults to ``1.0``.
            ctrl_cost_weight (float, optional): The weight used to scale the control
                cost. Defaults to ``None`` meaning that the default value of the
                :attr:`~gymnasium.envs.mujoco.ant_v4.AntEnv.ctrl_cost_weight`
                attribute is used.
        """  # noqa: E501
        super().__init__(**kwargs)
        self.reference_forward_velocity = reference_forward_velocity
        self.include_ctrl_cost = include_ctrl_cost
        # Only override the parent's weight when a weight was given. Checking for
        # ``None`` (instead of truthiness) keeps an explicit weight of 0.0 intact.
        if ctrl_cost_weight is not None:
            self._ctrl_cost_weight = ctrl_cost_weight
        self.forward_velocity_weight = forward_velocity_weight
        self.state = None

    def step(self, action):
        """Take step into the environment.

        .. note::
            This method overrides the
            :meth:`~gymnasium.envs.mujoco.ant_v4.AntEnv.step` method
            such that the new cost function is used.

        Args:
            action (np.ndarray): Action to take in the environment.

        Returns:
            (tuple): tuple containing:

                -   obs (:obj:`np.ndarray`): Environment observation.
                -   cost (:obj:`float`): Cost of the action.
                -   terminated (:obj:`bool`): Whether the episode is terminated.
                -   truncated (:obj:`bool`): Whether the episode was truncated. This value
                    is set by wrappers when for example a time limit is reached or the
                    agent goes out of bounds.
                -   info (:obj:`dict`): Additional information about the environment.
        """
        obs, _, terminated, truncated, info = super().step(action)
        self.state = obs
        # The parent reports the control cost as a negative reward, so flip its sign.
        cost, cost_info = self.cost(info["x_velocity"], -info["reward_ctrl"])

        # Update info.
        info["reward_fwd"] = cost_info["reward_fwd"]
        info["forward_reward"] = cost_info["reward_fwd"]

        return obs, cost, terminated, truncated, info

    def cost(self, x_velocity, ctrl_cost):
        """Compute the cost of the action.

        Args:
            x_velocity (float): The Ant's x velocity.
            ctrl_cost (float): The control cost.

        Returns:
            (tuple): tuple containing:

                -   cost (float): The cost of the action.
                -   info (:obj:`dict`): Additional information about the cost.
        """
        reward_fwd = self.forward_velocity_weight * np.square(
            x_velocity - self.reference_forward_velocity
        )
        cost = reward_fwd
        if self.include_ctrl_cost:
            cost += ctrl_cost
        return cost, {"reward_fwd": reward_fwd, "reward_ctrl": ctrl_cost}

    @property
    def ctrl_cost_weight(self):
        """Property that returns the control cost weight."""
        return self._ctrl_cost_weight

    @ctrl_cost_weight.setter
    def ctrl_cost_weight(self, value):
        """Setter for the control cost weight."""
        self._ctrl_cost_weight = value


if __name__ == "__main__":
    print("Setting up AntCost environment.")
    env = gym.make("AntCost", render_mode="human")

    # Simulate T seconds in the environment (T / env.dt steps).
    T = 1000
    path = []
    t1 = []
    # NOTE: reset returns an (observation, info) tuple.
    s, _ = env.reset(
        options={
            "low": [-2, -0.2, -0.2, -0.2],
            "high": [2, 0.2, 0.2, 0.2],
        }
    )
    print(f"Simulating {T} seconds in the AntCost environment.")
    for i in range(int(T / env.dt)):
        action = (
            env.action_space.sample()
            if RANDOM_STEP
            else np.zeros(env.action_space.shape)
        )
        s, r, terminated, truncated, info = env.step(action)
        # Also reset when the episode was truncated (e.g. by the time limit).
        if terminated or truncated:
            env.reset()
        path.append(s)
        t1.append(i * env.dt)
    print("Finished AntCost environment simulation.")

    # Plot results.
    # NOTE: These labels do not match the Ant observation layout. See the gymnasium
    # Ant documentation for the meaning of each observation index.
    print("Plot results.")
    fig = plt.figure(figsize=(9, 6))
    ax = fig.add_subplot(111)
    ax.plot(t1, np.array(path)[:, 0], color="orange", label="x")
    ax.plot(t1, np.array(path)[:, 1], color="magenta", label="x_dot")
    ax.plot(t1, np.array(path)[:, 2], color="sienna", label="theta")
    ax.plot(t1, np.array(path)[:, 3], color="blue", label="theta_dot")

    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles, labels, loc=2, fancybox=False, shadow=False)
    plt.ioff()
    plt.show()

    print("done")
3 changes: 3 additions & 0 deletions stable_gym/envs/mujoco/ant_cost/requirements.txt
@@ -0,0 +1,3 @@
gymnasium==0.28.1
gymnasium[mujoco]==0.28.1
matplotlib==3.7.0
4 changes: 4 additions & 0 deletions stable_gym/envs/mujoco/half_cheetah_cost/README.md
@@ -1,7 +1,11 @@
# HalfCheetahCost gymnasium environment

<div align="center">
<img src="https://github.com/rickstaa/stable-gym/assets/17570430/44360980-3ad1-40e9-863e-3417ed3aa4c8" alt="Half Cheetah Cost" width="200px">
</div>
</br>
2 changes: 1 addition & 1 deletion stable_gym/envs/mujoco/swimmer_cost/README.md
@@ -1,7 +1,7 @@
# SwimmerCost gymnasium environment

<div align="center">
<img src="https://github.com/rickstaa/stable-gym/assets/17570430/dccd73b4-c97e-46ce-ba0d-4a1328c0aefe" alt="swimmer" width="200px">
<img src="https://github.com/rickstaa/stable-gym/assets/17570430/dccd73b4-c97e-46ce-ba0d-4a1328c0aefe" alt="Swimmer" width="200px">
</div>
</br>

39 changes: 39 additions & 0 deletions tests/test_ant_cost.py
@@ -0,0 +1,39 @@
"""Test if the AntCost environment still behaves like the original Ant
environment when the same environment parameters are used.
"""
import gymnasium as gym
import numpy as np
from gymnasium.logger import ERROR

import stable_gym  # noqa: F401

gym.logger.set_level(ERROR)


class TestAntCostEqual:
    # Make original Ant environment.
    env = gym.make("Ant")
    # Make AntCost environment.
    env_cost = gym.make("AntCost")

    def test_equal_reset(self):
        """Test if reset behaves the same."""
        # Perform reset and check if observations are equal.
        observation, _ = self.env.reset(seed=42)
        observation_cost, _ = self.env_cost.reset(seed=42)
        assert np.allclose(
            observation, observation_cost
        ), f"{observation} != {observation_cost}"

    def test_equal_steps(self):
        """Test if steps behave the same."""
        # Perform several steps and check if observations are equal.
        self.env.reset(seed=42)
        self.env_cost.reset(seed=42)
        # Seed once so that the loop samples a reproducible sequence of actions
        # instead of the same action on every step.
        self.env.action_space.seed(42)
        for _ in range(10):
            action = self.env.action_space.sample()
            observation, _, _, _, _ = self.env.step(action)
            observation_cost, _, _, _, _ = self.env_cost.step(action)
            assert np.allclose(
                observation, observation_cost
            ), f"{observation} != {observation_cost}"
