Skip to content

Commit

Permalink
feat(minitaur): add minitaur step/reset reference info return
Browse files Browse the repository at this point in the history
This commit adds the `reference`, `state_of_interest` and `reference_error` keys to the info
dictionary that is returned by the environment's `step` and `reset` methods.
  • Loading branch information
rickstaa committed Jul 26, 2023
1 parent 7d715bd commit be01019
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def step(self, action):
self.state = obs
self.t = self.t + self.dt

# Retrieve original rewards and base velocity.
# Retrieve original rewards.
# NOTE: Han et al. 2018 used the squared error for the drift reward. We use the
# version found in the original Minitaur environment (i.e. absolute distance).
objectives = super().get_objectives()
Expand All @@ -359,9 +359,16 @@ def step(self, action):

# Compute the cost and update the info dict.
cost, cost_info = self.cost(
self.base_velocity, energy_reward, drift_cost, shake_cost
base_velocity, energy_reward, drift_cost, shake_cost
)
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": base_velocity,
"reference_error": base_velocity - self.reference_forward_velocity,
}
)

# Add optional health penalty at the end of the episode if requested.
if self._include_health_penalty:
Expand Down
15 changes: 15 additions & 0 deletions tests/__snapshots__/test_minitaur_cost.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
'cost_shake': 0.001545201927514,
'cost_velocity': 1.262878009055252,
'energy_cost': 0.208187038300126,
'reference': 1.0,
'reference_error': -1.123778451944712,
'state_of_interest': -0.123778451944712,
})
# ---
# name: TestMinitaurBulletCostEqual.test_snapshot.102
Expand Down Expand Up @@ -136,6 +139,9 @@
'cost_shake': 0.001216954533656,
'cost_velocity': 1.308016924484482,
'energy_cost': 0.411224652260461,
'reference': 1.0,
'reference_error': -1.143685675561464,
'state_of_interest': -0.143685675561464,
})
# ---
# name: TestMinitaurBulletCostEqual.test_snapshot.137
Expand Down Expand Up @@ -258,6 +264,9 @@
'cost_shake': 0.001062425280662,
'cost_velocity': 1.05929279162964,
'energy_cost': 0.411447438208018,
'reference': 1.0,
'reference_error': -1.029219506047976,
'state_of_interest': -0.029219506047976,
})
# ---
# name: TestMinitaurBulletCostEqual.test_snapshot.172
Expand Down Expand Up @@ -380,6 +389,9 @@
'cost_shake': 0.002593163820677,
'cost_velocity': 1.011816047403776,
'energy_cost': 0.457486833811053,
'reference': 1.0,
'reference_error': -1.005890673683664,
'state_of_interest': -0.005890673683664,
})
# ---
# name: TestMinitaurBulletCostEqual.test_snapshot.21
Expand Down Expand Up @@ -536,6 +548,9 @@
'cost_shake': 0.158954162148827,
'cost_velocity': 1.147753757228502,
'energy_cost': 0.201663962061444,
'reference': 1.0,
'reference_error': -1.071332701465097,
'state_of_interest': -0.071332701465097,
})
# ---
# name: TestMinitaurBulletCostEqual.test_snapshot.67
Expand Down

0 comments on commit be01019

Please sign in to comment.