From be010196f605ad2b65639de8a02a5215e0135307 Mon Sep 17 00:00:00 2001 From: rickstaa Date: Wed, 26 Jul 2023 17:38:56 +0200 Subject: [PATCH] feat(minitaur): add minitaur step/reset reference info return This commit adds the `reference`, `state_of_interest` and `reference_error` to the info dictionary that is returned by the environment's step and reset methods. --- .../minitaur_bullet_cost/minitaur_bullet_cost.py | 11 +++++++++-- tests/__snapshots__/test_minitaur_cost.ambr | 15 +++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/stable_gym/envs/robotics/minitaur/minitaur_bullet_cost/minitaur_bullet_cost.py b/stable_gym/envs/robotics/minitaur/minitaur_bullet_cost/minitaur_bullet_cost.py index 74b85bbb..e589ae20 100644 --- a/stable_gym/envs/robotics/minitaur/minitaur_bullet_cost/minitaur_bullet_cost.py +++ b/stable_gym/envs/robotics/minitaur/minitaur_bullet_cost/minitaur_bullet_cost.py @@ -349,7 +349,7 @@ def step(self, action): self.state = obs self.t = self.t + self.dt - # Retrieve original rew5ards and base velocity. + # Retrieve original rewards. # NOTE: Han et al. 2018 used the squared error for the drift reward. We use the # version found in the original Minitaur environment (i.e. absolute distance). objectives = super().get_objectives() @@ -359,9 +359,16 @@ def step(self, action): # Compute the cost and update the info dict. cost, cost_info = self.cost( - self.base_velocity, energy_reward, drift_cost, shake_cost + base_velocity, energy_reward, drift_cost, shake_cost ) info.update(cost_info) + info.update( + { + "reference": self.reference_forward_velocity, + "state_of_interest": base_velocity, + "reference_error": base_velocity - self.reference_forward_velocity, + } + ) # Add optional health penalty at the end of the episode if requested. if self._include_health_penalty: diff --git a/tests/__snapshots__/test_minitaur_cost.ambr b/tests/__snapshots__/test_minitaur_cost.ambr index 59dd61d8..e75bc398 100644 --- a/tests/__snapshots__/test_minitaur_cost.ambr +++ b/tests/__snapshots__/test_minitaur_cost.ambr @@ -17,6 +17,9 @@ 'cost_shake': 0.001545201927514, 'cost_velocity': 1.262878009055252, 'energy_cost': 0.208187038300126, + 'reference': 1.0, + 'reference_error': -1.123778451944712, + 'state_of_interest': -0.123778451944712, }) # --- # name: TestMinitaurBulletCostEqual.test_snapshot.102 @@ -136,6 +139,9 @@ 'cost_shake': 0.001216954533656, 'cost_velocity': 1.308016924484482, 'energy_cost': 0.411224652260461, + 'reference': 1.0, + 'reference_error': -1.143685675561464, + 'state_of_interest': -0.143685675561464, }) # --- # name: TestMinitaurBulletCostEqual.test_snapshot.137 @@ -258,6 +264,9 @@ 'cost_shake': 0.001062425280662, 'cost_velocity': 1.05929279162964, 'energy_cost': 0.411447438208018, + 'reference': 1.0, + 'reference_error': -1.029219506047976, + 'state_of_interest': -0.029219506047976, }) # --- # name: TestMinitaurBulletCostEqual.test_snapshot.172 @@ -380,6 +389,9 @@ 'cost_shake': 0.002593163820677, 'cost_velocity': 1.011816047403776, 'energy_cost': 0.457486833811053, + 'reference': 1.0, + 'reference_error': -1.005890673683664, + 'state_of_interest': -0.005890673683664, }) # --- # name: TestMinitaurBulletCostEqual.test_snapshot.21 @@ -536,6 +548,9 @@ 'cost_shake': 0.158954162148827, 'cost_velocity': 1.147753757228502, 'energy_cost': 0.201663962061444, + 'reference': 1.0, + 'reference_error': -1.071332701465097, + 'state_of_interest': -0.071332701465097, }) # --- # name: TestMinitaurBulletCostEqual.test_snapshot.67