From 87f06c7e7b25e327188ff660ceb71491524377ab Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Mon, 10 Jul 2023 11:03:39 +0200 Subject: [PATCH] add ant cost (#184) * feat: add 'HalfCheetahCost' environment This commit adds a modified version of the [HalfCeetah environment](https://gymnasium.farama.org/environments/mujoco/half_cheetah/) found in the [gymnasium](https://gymnasium.farama.org) package. In this modified version, the cost was replaced by a reward. This cost is the squared difference between HalfCheetah's forward velocity and a reference value (error). * feat: add 'AntCost' environment This commit adds a modified version of the [Ant environment](https://gymnasium.farama.org/environments/mujoco/ant/) found in the [gymnasium](https://gymnasium.farama.org) package. In this modified version, the cost was replaced by a reward. This cost is the squared difference between Ant's forward velocity and a reference value (error). * test: add 'AntCost' tests --- stable_gym/__init__.py | 5 + .../envs/biological/oscillator/README.md | 2 +- stable_gym/envs/mujoco/ant_cost/README.md | 25 +++ stable_gym/envs/mujoco/ant_cost/__init__.py | 9 + stable_gym/envs/mujoco/ant_cost/ant_cost.py | 197 ++++++++++++++++++ .../envs/mujoco/ant_cost/requirements.txt | 3 + .../envs/mujoco/half_cheetah_cost/README.md | 4 + stable_gym/envs/mujoco/swimmer_cost/README.md | 2 +- tests/test_ant_cost.py | 39 ++++ 9 files changed, 284 insertions(+), 2 deletions(-) create mode 100644 stable_gym/envs/mujoco/ant_cost/README.md create mode 100644 stable_gym/envs/mujoco/ant_cost/__init__.py create mode 100644 stable_gym/envs/mujoco/ant_cost/ant_cost.py create mode 100644 stable_gym/envs/mujoco/ant_cost/requirements.txt create mode 100644 tests/test_ant_cost.py diff --git a/stable_gym/__init__.py b/stable_gym/__init__.py index 39fb7431..621d1023 100644 --- a/stable_gym/__init__.py +++ b/stable_gym/__init__.py @@ -38,6 +38,11 @@ "max_step": 200, "reward_threshold": 300, }, + "AntCost-v1": { + "module": "stable_gym.envs.mujoco.ant_cost.ant_cost:AntCost", + "max_step": 1000, + "reward_threshold": 300, + }, } for env, val in ENVS.items(): diff --git a/stable_gym/envs/biological/oscillator/README.md b/stable_gym/envs/biological/oscillator/README.md index bb54345e..0bb670f2 100644 --- a/stable_gym/envs/biological/oscillator/README.md +++ b/stable_gym/envs/biological/oscillator/README.md @@ -32,7 +32,7 @@ The agent's goal in the oscillator environment is to act in such a way that one The Oscillator environment uses the absolute difference between the reference and the state of interest as the cost function: $$ -cost = (p_1 - r_1)^2 +cost = (p\_1 - r\_1)^2 $$ ## Environment step return diff --git a/stable_gym/envs/mujoco/ant_cost/README.md b/stable_gym/envs/mujoco/ant_cost/README.md new file mode 100644 index 00000000..ab76810b --- /dev/null +++ b/stable_gym/envs/mujoco/ant_cost/README.md @@ -0,0 +1,25 @@ +# AntCost gymnasium environment + +
+ Ant Cost +
+
+ +An actuated 8-jointed ant. This environment corresponds to the [Ant-v4](https://gymnasium.farama.org/environments/mujoco/ant) environment included in the [gymnasium package](https://gymnasium.farama.org/). It is different in the fact that: + +* The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. This cost is the squared + difference between the Ant's forward velocity and a reference value (error). + +The rest of the environment is the same as the original Ant environment. Below, the modified cost is described. For more information about the environment (e.g. observation space, action space, episode termination, etc.), please refer to the [gymnasium library](https://gymnasium.farama.org/environments/mujoco/ant/). + +## Cost function + +The cost function of this environment is designed in such a way that it tries to minimize the error between the Ant's forward velocity and a reference value. The cost function is defined as: + +$$ +cost = w\_{forward} \times (x\_{velocity} - x\_{reference\_x\_velocity})^2 + w\_{ctrl} \times c\_{ctrl} +$$ + +## How to use + +This environment is part of the [Stable Gym package](https://github.com/rickstaa/stable-gym). It is therefore registered as the `stable_gym:AntCost-v1` gymnasium environment when you import the Stable Gym package. If you want to use the environment in stand-alone mode, you can register it yourself. diff --git a/stable_gym/envs/mujoco/ant_cost/__init__.py b/stable_gym/envs/mujoco/ant_cost/__init__.py new file mode 100644 index 00000000..4ef3ec6e --- /dev/null +++ b/stable_gym/envs/mujoco/ant_cost/__init__.py @@ -0,0 +1,9 @@ +"""Modified version of the Ant Mujoco environment in v0.28.1 of the +`gymnasium library `_. +This modification was first described by `Han et al. 2020 `_. +In this modified version: + +- The objective was changed to a velocity-tracking task. To do this, the reward is replaced with a cost. + This cost is the squared difference between the Ant's forward velocity and a reference value (error). +""" # noqa: E501 +from stable_gym.envs.mujoco.ant_cost.ant_cost import AntCost diff --git a/stable_gym/envs/mujoco/ant_cost/ant_cost.py b/stable_gym/envs/mujoco/ant_cost/ant_cost.py new file mode 100644 index 00000000..0bc17dc8 --- /dev/null +++ b/stable_gym/envs/mujoco/ant_cost/ant_cost.py @@ -0,0 +1,197 @@ +"""The AntCost gymnasium environment.""" + +import gymnasium as gym +import matplotlib.pyplot as plt +import numpy as np +from gymnasium.envs.mujoco.ant_v4 import AntEnv + +import stable_gym # NOTE: Required to register environments. # noqa: F401 + +RANDOM_STEP = True # Use random action in __main__. Zero action otherwise. + + +# TODO: Add solving criteria after training. +class AntCost(AntEnv): + """Custom Ant gymnasium environment. + + .. note:: + Can also be used in a vectorized manner. See the + :gymnasium:`gym.vector ` documentation. + + Source: + This is a modified version of the Ant Mujoco environment in v0.28.1 of the + :gymnasium:`gymnasium library `. This modification + was first described by `Han et al. 2020 `_. + Compared to the original Ant environment in this modified version: + + - The objective was changed to a velocity-tracking task. To do this, the reward + is replaced with a cost. This cost is the squared difference between the + Ant's forward velocity and a reference value (error). + + The rest of the environment is the same as the original Ant environment. + Below, the modified cost is described. For more information about the environment + (e.g. observation space, action space, episode termination, etc.), please refer + to the :gymnasium:`gymnasium library `. + + Modified cost: + .. math:: + + cost = w_{forward} \\times (x_{velocity} - x_{reference\_x\_velocity})^2 + w_{ctrl} \\times c_{ctrl} + + Solved Requirements: + Considered solved when the average cost is less than or equal to 50 over + 100 consecutive trials. + + How to use: + .. code-block:: python + + import stable_gym + import gymnasium as gym + env = gym.make("AntCost-v1") + + Attributes: + reference_forward_velocity (float): The forward velocity that the agent should try + to track. + include_ctrl_cost (bool): Whether you also want to penalize the Ant if it + takes actions that are too large. + forward_velocity_weight (float): The weight used to scale the forward velocity error. + """ # noqa: E501, W605 + + def __init__( + self, + reference_forward_velocity=1.0, + include_ctrl_cost=True, + forward_velocity_weight=1.0, + ctrl_cost_weight=None, + **kwargs, + ): + """Constructs all the necessary attributes for the AntCost instance. + + Args: + reference_forward_velocity (float, optional): The forward velocity that the + agent should try to track. Defaults to ``1.0``. + include_ctrl_cost (bool, optional): Whether you also want to penalize the + Ant if it takes actions that are too large. Defaults to ``True``. + forward_velocity_weight (float, optional): The weight used to scale the + forward velocity error. Defaults to ``1.0``. + ctrl_cost_weight (_type_, optional): The weight used to scale the control + cost. Defaults to ``None`` meaning that the default value of the + :attr:`~gymnasium.envs.mujoco.ant_v4.AntEnv.ctrl_cost_weight` + attribute is used. + """ # noqa: E501 + super().__init__(**kwargs) + self.reference_forward_velocity = reference_forward_velocity + self.include_ctrl_cost = include_ctrl_cost + self._ctrl_cost_weight = ( + ctrl_cost_weight if ctrl_cost_weight else self._ctrl_cost_weight + ) + self.forward_velocity_weight = forward_velocity_weight + self.state = None + + def step(self, action): + """Take step into the environment. + + .. note:: + This method overrides the + :meth:`~gymnasium.envs.mujoco.ant_v4.AntEnv.step` method + such that the new cost function is used. + + Args: + action (np.ndarray): Action to take in the environment. + + Returns: + (tuple): tuple containing: + + - obs (:obj:`np.ndarray`): Environment observation. + - cost (:obj:`float`): Cost of the action. + - terminated (:obj`bool`): Whether the episode is terminated. + - truncated (:obj:`bool`): Whether the episode was truncated. This value + is set by wrappers when for example a time limit is reached or the + agent goes out of bounds. + - info (:obj`dict`): Additional information about the environment. + """ + obs, _, terminated, truncated, info = super().step(action) + self.state = obs + cost, cost_info = self.cost(info["x_velocity"], -info["reward_ctrl"]) + + # Update info. + info["reward_fwd"] = cost_info["reward_fwd"] + info["forward_reward"] = cost_info["reward_fwd"] + + return obs, cost, terminated, truncated, info + + def cost(self, x_velocity, ctrl_cost): + """Compute the cost of the action. + + Args: + x_velocity (float): The Ant's x velocity. + ctrl_cost (float): The control cost. + + Returns: + (tuple): tuple containing: + + - cost (float): The cost of the action. + - info (:obj:`dict`): Additional information about the cost. + """ + reward_fwd = self.forward_velocity_weight * np.square( + x_velocity - self.reference_forward_velocity + ) + cost = reward_fwd + if self.include_ctrl_cost: + cost += ctrl_cost + return cost, {"reward_fwd": reward_fwd, "reward_ctrl": ctrl_cost} + + @property + def ctrl_cost_weight(self): + """Property that returns the control cost weight.""" + return self._ctrl_cost_weight + + @ctrl_cost_weight.setter + def ctrl_cost_weight(self, value): + """Setter for the control cost weight.""" + self._ctrl_cost_weight = value + + +if __name__ == "__main__": + print("Setting up AntCost environment.") + env = gym.make("AntCost", render_mode="human") + + # Take T steps in the environment. + T = 1000 + path = [] + t1 = [] + s = env.reset( + options={ + "low": [-2, -0.2, -0.2, -0.2], + "high": [2, 0.2, 0.2, 0.2], + } + ) + print(f"Taking {T} steps in the AntCost environment.") + for i in range(int(T / env.dt)): + action = ( + env.action_space.sample() + if RANDOM_STEP + else np.zeros(env.action_space.shape) + ) + s, r, terminated, truncated, info = env.step(action) + if terminated: + env.reset() + path.append(s) + t1.append(i * env.dt) + print("Finished AntCost environment simulation.") + + # Plot results. + print("Plot results.") + fig = plt.figure(figsize=(9, 6)) + ax = fig.add_subplot(111) + ax.plot(t1, np.array(path)[:, 0], color="orange", label="x") + ax.plot(t1, np.array(path)[:, 1], color="magenta", label="x_dot") + ax.plot(t1, np.array(path)[:, 2], color="sienna", label="theta") + ax.plot(t1, np.array(path)[:, 3], color="blue", label=" theat_dot1") + + handles, labels = ax.get_legend_handles_labels() + ax.legend(handles, labels, loc=2, fancybox=False, shadow=False) + plt.ioff() + plt.show() + + print("done") diff --git a/stable_gym/envs/mujoco/ant_cost/requirements.txt b/stable_gym/envs/mujoco/ant_cost/requirements.txt new file mode 100644 index 00000000..b9d95fcf --- /dev/null +++ b/stable_gym/envs/mujoco/ant_cost/requirements.txt @@ -0,0 +1,3 @@ +gymnasium==0.28.1 +gymnasium[mujoco]==0.28.1 +matplotlib==3.7.0 diff --git a/stable_gym/envs/mujoco/half_cheetah_cost/README.md b/stable_gym/envs/mujoco/half_cheetah_cost/README.md index 9e092beb..390282fa 100644 --- a/stable_gym/envs/mujoco/half_cheetah_cost/README.md +++ b/stable_gym/envs/mujoco/half_cheetah_cost/README.md @@ -1,7 +1,11 @@ # HalfCheetahCost gymnasium environment
+<<<<<<< HEAD + Half Cheetah Cost +======= Half Cheetah +>>>>>>> main

diff --git a/stable_gym/envs/mujoco/swimmer_cost/README.md b/stable_gym/envs/mujoco/swimmer_cost/README.md index bed09c55..132b362d 100644 --- a/stable_gym/envs/mujoco/swimmer_cost/README.md +++ b/stable_gym/envs/mujoco/swimmer_cost/README.md @@ -1,7 +1,7 @@ # SwimmerCost gymnasium environment
- swimmer + Swimmer

diff --git a/tests/test_ant_cost.py b/tests/test_ant_cost.py new file mode 100644 index 00000000..4f8294d5 --- /dev/null +++ b/tests/test_ant_cost.py @@ -0,0 +1,39 @@ +"""Test if the AntCost environment still behaves like the original Ant +environment when the same environment parameters are used. +""" +import gymnasium as gym +import numpy as np +from gymnasium.logger import ERROR + +import stable_gym # noqa: F401 + +gym.logger.set_level(ERROR) + + +class TestAntCostEqual: + # Make original Ant environment. + env = gym.make("Ant") + # Make AntCost environment. + env_cost = gym.make("AntCost") + + def test_equal_reset(self): + """Test if reset behaves the same.""" + # Perform reset and check if observations are equal. + observation, _ = self.env.reset(seed=42) + observation_cost, _ = self.env_cost.reset(seed=42) + assert np.allclose( + observation, observation_cost + ), f"{observation} != {observation_cost}" + + def test_equal_steps(self): + """Test if steps behave the same.""" + # Perform several steps and check if observations are equal. + self.env.reset(seed=42), self.env_cost.reset(seed=42) + for _ in range(10): + self.env.action_space.seed(42) + action = self.env.action_space.sample() + observation, _, _, _, _ = self.env.step(action) + observation_cost, _, _, _, _ = self.env_cost.step(action) + assert np.allclose( + observation, observation_cost + ), f"{observation} != {observation_cost}"