Skip to content

Commit

Permalink
feat: add reference info to step/reset return (#275)
Browse files Browse the repository at this point in the history
This commit adds the `reference`, `state_of_interest`, and
`reference_error` keys to the info dictionary returned by the
environment's `step` and `reset` methods.
  • Loading branch information
rickstaa authored Jul 26, 2023
1 parent 021d846 commit 7d715bd
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 18 deletions.
31 changes: 28 additions & 3 deletions stable_gym/envs/mujoco/ant_cost/ant_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,6 @@ def step(self, action):
"""
obs, _, terminated, truncated, info = super().step(action)

self.state = obs

cost, cost_info = self.cost(info["x_velocity"], -info["reward_ctrl"])

# Add reference, x velocity and reference error to observation.
Expand All @@ -295,8 +293,23 @@ def step(self, action):
if not self._exclude_x_velocity_from_observation:
obs = np.append(obs, info["x_velocity"])

# Update info.
self.state = obs

# Update info dictionary.
del (
info["reward_forward"],
info["forward_reward"],
info["reward_ctrl"],
info["reward_survive"],
)
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": info["x_velocity"],
"reference_error": info["x_velocity"] - self.reference_forward_velocity,
}
)

return obs, cost, terminated, truncated, info

Expand All @@ -318,6 +331,8 @@ def reset(self, seed=None, options=None):
"""
obs, info = super().reset(seed=seed, options=options)

_, cost_info = self.cost(0.0, 0.0)

# Randomize the reference forward velocity if requested.
if self._randomise_reference_forward_velocity:
self.reference_forward_velocity = self.np_random.uniform(
Expand All @@ -334,6 +349,16 @@ def reset(self, seed=None, options=None):

self.state = obs

# Update info dictionary.
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": 0.0,
"reference_error": 0.0 - self.reference_forward_velocity,
}
)

return obs, info

@property
Expand Down
25 changes: 22 additions & 3 deletions stable_gym/envs/mujoco/half_cheetah_cost/half_cheetah_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,6 @@ def step(self, action):
"""
obs, _, terminated, truncated, info = super().step(action)

self.state = obs

cost, cost_info = self.cost(info["x_velocity"], -info["reward_ctrl"])

# Add reference, x velocity and reference error to observation.
Expand All @@ -243,9 +241,18 @@ def step(self, action):
if not self._exclude_x_velocity_from_observation:
obs = np.append(obs, info["x_velocity"])

# Update info.
self.state = obs

# Update info dictionary.
del info["reward_run"], info["reward_ctrl"]
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": info["x_velocity"],
"reference_error": info["x_velocity"] - self.reference_forward_velocity,
}
)

return obs, cost, terminated, truncated, info

Expand All @@ -267,6 +274,8 @@ def reset(self, seed=None, options=None):
"""
obs, info = super().reset(seed=seed, options=options)

_, cost_info = self.cost(0.0, 0.0)

# Randomize the reference forward velocity if requested.
if self._randomise_reference_forward_velocity:
self.reference_forward_velocity = self.np_random.uniform(
Expand All @@ -283,6 +292,16 @@ def reset(self, seed=None, options=None):

self.state = obs

# Update info dictionary.
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": 0.0,
"reference_error": 0.0 - self.reference_forward_velocity,
}
)

return obs, info

@property
Expand Down
25 changes: 22 additions & 3 deletions stable_gym/envs/mujoco/hopper_cost/hopper_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,8 +271,6 @@ def step(self, action):
"""
obs, _, terminated, truncated, info = super().step(action)

self.state = obs

ctrl_cost = super().control_cost(action)
cost, cost_info = self.cost(info["x_velocity"], ctrl_cost)

Expand All @@ -284,8 +282,17 @@ def step(self, action):
if not self._exclude_x_velocity_from_observation:
obs = np.append(obs, info["x_velocity"])

# Update info.
self.state = obs

# Update info dictionary.
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": info["x_velocity"],
"reference_error": info["x_velocity"] - self.reference_forward_velocity,
}
)

return obs, cost, terminated, truncated, info

Expand All @@ -307,6 +314,8 @@ def reset(self, seed=None, options=None):
"""
obs, info = super().reset(seed=seed, options=options)

_, cost_info = self.cost(0.0, 0.0)

# Randomize the reference forward velocity if requested.
if self._randomise_reference_forward_velocity:
self.reference_forward_velocity = self.np_random.uniform(
Expand All @@ -323,6 +332,16 @@ def reset(self, seed=None, options=None):

self.state = obs

# Update info dictionary.
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": 0.0,
"reference_error": 0.0 - self.reference_forward_velocity,
}
)

return obs, info

@property
Expand Down
25 changes: 22 additions & 3 deletions stable_gym/envs/mujoco/humanoid_cost/humanoid_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,8 +262,6 @@ def step(self, action):
"""
obs, _, terminated, truncated, info = super().step(action)

self.state = obs

cost, cost_info = self.cost(info["x_velocity"], -info["reward_quadctrl"])

# Add reference, x velocity and reference error to observation.
Expand All @@ -274,14 +272,23 @@ def step(self, action):
if not self._exclude_x_velocity_from_observation:
obs = np.append(obs, info["x_velocity"])

# Update info.
self.state = obs

# Update info dictionary.
del (
info["reward_linvel"],
info["reward_quadctrl"],
info["reward_alive"],
info["forward_reward"],
)
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": info["x_velocity"],
"reference_error": info["x_velocity"] - self.reference_forward_velocity,
}
)

return obs, cost, terminated, truncated, info

Expand All @@ -303,6 +310,8 @@ def reset(self, seed=None, options=None):
"""
obs, info = super().reset(seed=seed, options=options)

_, cost_info = self.cost(0.0, 0.0)

# Randomize the reference forward velocity if requested.
if self._randomise_reference_forward_velocity:
self.reference_forward_velocity = self.np_random.uniform(
Expand All @@ -319,6 +328,16 @@ def reset(self, seed=None, options=None):

self.state = obs

# Update info dictionary.
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": 0.0,
"reference_error": 0.0 - self.reference_forward_velocity,
}
)

return obs, info

@property
Expand Down
25 changes: 22 additions & 3 deletions stable_gym/envs/mujoco/swimmer_cost/swimmer_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,6 @@ def step(self, action):
"""
obs, _, terminated, truncated, info = super().step(action)

self.state = obs

cost, cost_info = self.cost(info["x_velocity"], -info["reward_ctrl"])

# Add reference, x velocity and reference error to observation.
Expand All @@ -241,9 +239,18 @@ def step(self, action):
if not self._exclude_x_velocity_from_observation:
obs = np.append(obs, info["x_velocity"])

# Update info.
self.state = obs

# Update info dictionary.
del info["reward_fwd"], info["reward_ctrl"], info["forward_reward"]
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": info["x_velocity"],
"reference_error": info["x_velocity"] - self.reference_forward_velocity,
}
)

return obs, cost, terminated, truncated, info

Expand All @@ -265,6 +272,8 @@ def reset(self, seed=None, options=None):
"""
obs, info = super().reset(seed=seed, options=options)

_, cost_info = self.cost(0.0, 0.0)

# Randomize the reference forward velocity if requested.
if self._randomise_reference_forward_velocity:
self.reference_forward_velocity = self.np_random.uniform(
Expand All @@ -281,6 +290,16 @@ def reset(self, seed=None, options=None):

self.state = obs

# Update info dictionary.
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": 0.0,
"reference_error": 0.0 - self.reference_forward_velocity,
}
)

return obs, info

@property
Expand Down
25 changes: 22 additions & 3 deletions stable_gym/envs/mujoco/walker2d_cost/walker2d_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,6 @@ def step(self, action):
"""
obs, _, terminated, truncated, info = super().step(action)

self.state = obs

ctrl_cost = super().control_cost(action)
cost, cost_info = self.cost(info["x_velocity"], ctrl_cost)

Expand All @@ -279,8 +277,17 @@ def step(self, action):
if not self._exclude_x_velocity_from_observation:
obs = np.append(obs, info["x_velocity"])

# Update info.
self.state = obs

# Update info dictionary.
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": info["x_velocity"],
"reference_error": info["x_velocity"] - self.reference_forward_velocity,
}
)

return obs, cost, terminated, truncated, info

Expand All @@ -302,6 +309,8 @@ def reset(self, seed=None, options=None):
"""
obs, info = super().reset(seed=seed, options=options)

_, cost_info = self.cost(0.0, 0.0)

# Randomize the reference forward velocity if requested.
if self._randomise_reference_forward_velocity:
self.reference_forward_velocity = self.np_random.uniform(
Expand All @@ -318,6 +327,16 @@ def reset(self, seed=None, options=None):

self.state = obs

# Update info dictionary.
info.update(cost_info)
info.update(
{
"reference": self.reference_forward_velocity,
"state_of_interest": 0.0,
"reference_error": 0.0 - self.reference_forward_velocity,
}
)

return obs, info

@property
Expand Down

0 comments on commit 7d715bd

Please sign in to comment.