From 7d715bdcbe5fff797cde70a6a4ba621e64cd786e Mon Sep 17 00:00:00 2001 From: Rick Staa Date: Wed, 26 Jul 2023 17:37:08 +0200 Subject: [PATCH] feat: add reference info to step/reset return (#275) This commit adds the `reference`, `state_of_interest` and `reference_error` keys to the info dictionary that is returned by the environment step and reset methods. --- stable_gym/envs/mujoco/ant_cost/ant_cost.py | 31 +++++++++++++++++-- .../half_cheetah_cost/half_cheetah_cost.py | 25 +++++++++++++-- .../envs/mujoco/hopper_cost/hopper_cost.py | 25 +++++++++++++-- .../mujoco/humanoid_cost/humanoid_cost.py | 25 +++++++++++++-- .../envs/mujoco/swimmer_cost/swimmer_cost.py | 25 +++++++++++++-- .../mujoco/walker2d_cost/walker2d_cost.py | 25 +++++++++++++-- 6 files changed, 138 insertions(+), 18 deletions(-) diff --git a/stable_gym/envs/mujoco/ant_cost/ant_cost.py b/stable_gym/envs/mujoco/ant_cost/ant_cost.py index 266c4142..22069b6c 100644 --- a/stable_gym/envs/mujoco/ant_cost/ant_cost.py +++ b/stable_gym/envs/mujoco/ant_cost/ant_cost.py @@ -283,8 +283,6 @@ def step(self, action): """ obs, _, terminated, truncated, info = super().step(action) - self.state = obs - cost, cost_info = self.cost(info["x_velocity"], -info["reward_ctrl"]) # Add reference, x velocity and reference error to observation. @@ -295,8 +293,23 @@ def step(self, action): if not self._exclude_x_velocity_from_observation: obs = np.append(obs, info["x_velocity"]) - # Update info. + self.state = obs + + # Update info dictionary. + del ( + info["reward_forward"], + info["forward_reward"], + info["reward_ctrl"], + info["reward_survive"], + ) info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": info["x_velocity"], + "reference_error": info["x_velocity"] - self.reference_forward_velocity, + } + ) return obs, cost, terminated, truncated, info @@ -318,6 +331,8 @@ def reset(self, seed=None, options=None): """ obs, info = super().reset(seed=seed, options=options) + _, cost_info = self.cost(0.0, 0.0) + # Randomize the reference forward velocity if requested. if self._randomise_reference_forward_velocity: self.reference_forward_velocity = self.np_random.uniform( @@ -334,6 +349,16 @@ def reset(self, seed=None, options=None): self.state = obs + # Update info dictionary. + info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": 0.0, + "reference_error": 0.0 - self.reference_forward_velocity, + } + ) + return obs, info @property diff --git a/stable_gym/envs/mujoco/half_cheetah_cost/half_cheetah_cost.py b/stable_gym/envs/mujoco/half_cheetah_cost/half_cheetah_cost.py index 2ca31693..d3d19fcf 100644 --- a/stable_gym/envs/mujoco/half_cheetah_cost/half_cheetah_cost.py +++ b/stable_gym/envs/mujoco/half_cheetah_cost/half_cheetah_cost.py @@ -231,8 +231,6 @@ def step(self, action): """ obs, _, terminated, truncated, info = super().step(action) - self.state = obs - cost, cost_info = self.cost(info["x_velocity"], -info["reward_ctrl"]) # Add reference, x velocity and reference error to observation. @@ -243,9 +241,18 @@ def step(self, action): if not self._exclude_x_velocity_from_observation: obs = np.append(obs, info["x_velocity"]) - # Update info. + self.state = obs + + # Update info dictionary. del info["reward_run"], info["reward_ctrl"] info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": info["x_velocity"], + "reference_error": info["x_velocity"] - self.reference_forward_velocity, + } + ) return obs, cost, terminated, truncated, info @@ -267,6 +274,8 @@ def reset(self, seed=None, options=None): """ obs, info = super().reset(seed=seed, options=options) + _, cost_info = self.cost(0.0, 0.0) + # Randomize the reference forward velocity if requested. if self._randomise_reference_forward_velocity: self.reference_forward_velocity = self.np_random.uniform( @@ -283,6 +292,16 @@ def reset(self, seed=None, options=None): self.state = obs + # Update info dictionary. + info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": 0.0, + "reference_error": 0.0 - self.reference_forward_velocity, + } + ) + return obs, info @property diff --git a/stable_gym/envs/mujoco/hopper_cost/hopper_cost.py b/stable_gym/envs/mujoco/hopper_cost/hopper_cost.py index 3710680a..d380747d 100644 --- a/stable_gym/envs/mujoco/hopper_cost/hopper_cost.py +++ b/stable_gym/envs/mujoco/hopper_cost/hopper_cost.py @@ -271,8 +271,6 @@ def step(self, action): """ obs, _, terminated, truncated, info = super().step(action) - self.state = obs - ctrl_cost = super().control_cost(action) cost, cost_info = self.cost(info["x_velocity"], ctrl_cost) @@ -284,8 +282,17 @@ def step(self, action): if not self._exclude_x_velocity_from_observation: obs = np.append(obs, info["x_velocity"]) - # Update info. + self.state = obs + + # Update info dictionary. info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": info["x_velocity"], + "reference_error": info["x_velocity"] - self.reference_forward_velocity, + } + ) return obs, cost, terminated, truncated, info @@ -307,6 +314,8 @@ def reset(self, seed=None, options=None): """ obs, info = super().reset(seed=seed, options=options) + _, cost_info = self.cost(0.0, 0.0) + # Randomize the reference forward velocity if requested. if self._randomise_reference_forward_velocity: self.reference_forward_velocity = self.np_random.uniform( @@ -323,6 +332,16 @@ def reset(self, seed=None, options=None): self.state = obs + # Update info dictionary. + info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": 0.0, + "reference_error": 0.0 - self.reference_forward_velocity, + } + ) + return obs, info @property diff --git a/stable_gym/envs/mujoco/humanoid_cost/humanoid_cost.py b/stable_gym/envs/mujoco/humanoid_cost/humanoid_cost.py index 3d2354c5..66459b39 100644 --- a/stable_gym/envs/mujoco/humanoid_cost/humanoid_cost.py +++ b/stable_gym/envs/mujoco/humanoid_cost/humanoid_cost.py @@ -262,8 +262,6 @@ def step(self, action): """ obs, _, terminated, truncated, info = super().step(action) - self.state = obs - cost, cost_info = self.cost(info["x_velocity"], -info["reward_quadctrl"]) # Add reference, x velocity and reference error to observation. @@ -274,7 +272,9 @@ def step(self, action): if not self._exclude_x_velocity_from_observation: obs = np.append(obs, info["x_velocity"]) - # Update info. + self.state = obs + + # Update info dictionary. del ( info["reward_linvel"], info["reward_quadctrl"], @@ -282,6 +282,13 @@ def step(self, action): info["forward_reward"], ) info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": info["x_velocity"], + "reference_error": info["x_velocity"] - self.reference_forward_velocity, + } + ) return obs, cost, terminated, truncated, info @@ -303,6 +310,8 @@ def reset(self, seed=None, options=None): """ obs, info = super().reset(seed=seed, options=options) + _, cost_info = self.cost(0.0, 0.0) + # Randomize the reference forward velocity if requested. if self._randomise_reference_forward_velocity: self.reference_forward_velocity = self.np_random.uniform( @@ -319,6 +328,16 @@ def reset(self, seed=None, options=None): self.state = obs + # Update info dictionary. + info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": 0.0, + "reference_error": 0.0 - self.reference_forward_velocity, + } + ) + return obs, info @property diff --git a/stable_gym/envs/mujoco/swimmer_cost/swimmer_cost.py b/stable_gym/envs/mujoco/swimmer_cost/swimmer_cost.py index 585d154d..5bdbd5a5 100644 --- a/stable_gym/envs/mujoco/swimmer_cost/swimmer_cost.py +++ b/stable_gym/envs/mujoco/swimmer_cost/swimmer_cost.py @@ -229,8 +229,6 @@ def step(self, action): """ obs, _, terminated, truncated, info = super().step(action) - self.state = obs - cost, cost_info = self.cost(info["x_velocity"], -info["reward_ctrl"]) # Add reference, x velocity and reference error to observation. @@ -241,9 +239,18 @@ def step(self, action): if not self._exclude_x_velocity_from_observation: obs = np.append(obs, info["x_velocity"]) - # Update info. + self.state = obs + + # Update info dictionary. del info["reward_fwd"], info["reward_ctrl"], info["forward_reward"] info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": info["x_velocity"], + "reference_error": info["x_velocity"] - self.reference_forward_velocity, + } + ) return obs, cost, terminated, truncated, info @@ -265,6 +272,8 @@ def reset(self, seed=None, options=None): """ obs, info = super().reset(seed=seed, options=options) + _, cost_info = self.cost(0.0, 0.0) + # Randomize the reference forward velocity if requested. if self._randomise_reference_forward_velocity: self.reference_forward_velocity = self.np_random.uniform( @@ -281,6 +290,16 @@ def reset(self, seed=None, options=None): self.state = obs + # Update info dictionary. + info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": 0.0, + "reference_error": 0.0 - self.reference_forward_velocity, + } + ) + return obs, info @property diff --git a/stable_gym/envs/mujoco/walker2d_cost/walker2d_cost.py b/stable_gym/envs/mujoco/walker2d_cost/walker2d_cost.py index d5e12ea0..d6439754 100644 --- a/stable_gym/envs/mujoco/walker2d_cost/walker2d_cost.py +++ b/stable_gym/envs/mujoco/walker2d_cost/walker2d_cost.py @@ -266,8 +266,6 @@ def step(self, action): """ obs, _, terminated, truncated, info = super().step(action) - self.state = obs - ctrl_cost = super().control_cost(action) cost, cost_info = self.cost(info["x_velocity"], ctrl_cost) @@ -279,8 +277,17 @@ def step(self, action): if not self._exclude_x_velocity_from_observation: obs = np.append(obs, info["x_velocity"]) - # Update info. + self.state = obs + + # Update info dictionary. info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": info["x_velocity"], + "reference_error": info["x_velocity"] - self.reference_forward_velocity, + } + ) return obs, cost, terminated, truncated, info @@ -302,6 +309,8 @@ def reset(self, seed=None, options=None): """ obs, info = super().reset(seed=seed, options=options) + _, cost_info = self.cost(0.0, 0.0) + # Randomize the reference forward velocity if requested. if self._randomise_reference_forward_velocity: self.reference_forward_velocity = self.np_random.uniform( @@ -318,6 +327,16 @@ def reset(self, seed=None, options=None): self.state = obs + # Update info dictionary. + info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": 0.0, + "reference_error": 0.0 - self.reference_forward_velocity, + } + ) + return obs, info @property