-
Notifications
You must be signed in to change notification settings - Fork 8.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor mujoco envs to support dynamic arguments (#1304)
* Refactor gym envs to support dynamic arguments * Fix viewer setup lookat configuration * Add xml_file argument for mujoco envs * Move refactored mujoco envs to their own _v3.py files * Revert "Add xml_file argument for mujoco envs" This reverts commit 4a3a74c. * Revert "Fix viewer setup lookat configuration" This reverts commit 62b4bcf. * Revert "Refactor gym envs to support dynamic arguments" This reverts commit b2a439f. * Fix v3 SwimmerEnv info * Register v3 mujoco environments * Implement v2 to v3 conversion test * Add extra step info to the v3 environments * polish the new unit tests a little bit
- Loading branch information
1 parent
17abad3
commit 90a0564
Showing
9 changed files
with
889 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,146 @@ | ||
import numpy as np | ||
from gym import utils | ||
from gym.envs.mujoco import mujoco_env | ||
|
||
|
||
# Camera attribute overrides applied onto ``self.viewer.cam`` in
# ``AntEnv.viewer_setup`` below (scalar values via setattr, arrays in place).
DEFAULT_CAMERA_CONFIG = {
    'distance': 4.0,
}
|
||
|
||
class AntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """MuJoCo "ant" locomotion environment (v3).

    The per-step reward is forward (x) velocity plus a survival bonus,
    minus a control cost and a contact-force cost. The episode ends when
    the torso height leaves ``healthy_z_range`` (if
    ``terminate_when_unhealthy`` is set).
    """

    def __init__(self,
                 xml_file='ant.xml',
                 ctrl_cost_weight=0.5,
                 contact_cost_weight=5e-4,
                 healthy_reward=1.0,
                 terminate_when_unhealthy=True,
                 healthy_z_range=(0.2, 1.0),
                 contact_force_range=(-1.0, 1.0),
                 reset_noise_scale=0.1,
                 exclude_current_positions_from_observation=True):
        # Capture the constructor arguments for pickling support; this must
        # stay the first statement so no extra locals leak into the capture.
        utils.EzPickle.__init__(**locals())

        self._ctrl_cost_weight = ctrl_cost_weight
        self._contact_cost_weight = contact_cost_weight
        self._healthy_reward = healthy_reward
        self._terminate_when_unhealthy = terminate_when_unhealthy
        self._healthy_z_range = healthy_z_range
        self._contact_force_range = contact_force_range
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = (
            exclude_current_positions_from_observation)

        mujoco_env.MujocoEnv.__init__(self, xml_file, 5)

    @property
    def healthy_reward(self):
        """Survival bonus; always granted when early termination is off."""
        alive_or_lenient = self.is_healthy or self._terminate_when_unhealthy
        return float(alive_or_lenient) * self._healthy_reward

    def control_cost(self, action):
        """Quadratic penalty on actuation magnitude."""
        return self._ctrl_cost_weight * np.sum(np.square(action))

    @property
    def contact_forces(self):
        """External contact forces, clipped to ``contact_force_range``."""
        lower, upper = self._contact_force_range
        return np.clip(self.sim.data.cfrc_ext, lower, upper)

    @property
    def contact_cost(self):
        """Quadratic penalty on the clipped contact forces."""
        return self._contact_cost_weight * np.sum(
            np.square(self.contact_forces))

    @property
    def is_healthy(self):
        """True while the state is finite and torso height stays in range."""
        state = self.state_vector()
        min_z, max_z = self._healthy_z_range
        finite = np.isfinite(state).all()
        return finite and min_z <= state[2] <= max_z

    @property
    def done(self):
        """Episode-termination flag derived from the health check."""
        if not self._terminate_when_unhealthy:
            return False
        return not self.is_healthy

    def step(self, action):
        """Advance the simulation one control step.

        Returns the gym API tuple ``(observation, reward, done, info)``;
        ``info`` carries the individual reward terms and torso kinematics.
        """
        torso_xy_before = self.get_body_com("torso")[:2].copy()
        self.do_simulation(action, self.frame_skip)
        torso_xy_after = self.get_body_com("torso")[:2].copy()

        x_velocity, y_velocity = (torso_xy_after - torso_xy_before) / self.dt

        ctrl_cost = self.control_cost(action)
        contact_cost = self.contact_cost
        forward_reward = x_velocity
        healthy_reward = self.healthy_reward

        # Same grouping as rewards-minus-costs to keep float results stable.
        reward = (forward_reward + healthy_reward) - (ctrl_cost + contact_cost)

        observation = self._get_obs()
        info = {
            'reward_forward': forward_reward,
            'reward_ctrl': -ctrl_cost,
            'reward_contact': -contact_cost,
            'reward_survive': healthy_reward,

            'x_position': torso_xy_after[0],
            'y_position': torso_xy_after[1],
            'distance_from_origin': np.linalg.norm(torso_xy_after, ord=2),

            'x_velocity': x_velocity,
            'y_velocity': y_velocity,
            'forward_reward': forward_reward,
        }

        return observation, reward, self.done, info

    def _get_obs(self):
        """Concatenate joint positions, velocities and contact forces.

        The first two qpos entries (torso x/y world position) are
        optionally dropped so the policy cannot read absolute location.
        """
        position = self.sim.data.qpos.flat.copy()
        if self._exclude_current_positions_from_observation:
            position = position[2:]

        velocity = self.sim.data.qvel.flat.copy()
        contact_force = self.contact_forces.flat.copy()

        return np.concatenate((position, velocity, contact_force))

    def reset_model(self):
        """Reset to the initial state plus small uniform/Gaussian noise."""
        scale = self._reset_noise_scale

        qpos = self.init_qpos + self.np_random.uniform(
            low=-scale, high=scale, size=self.model.nq)
        qvel = self.init_qvel + scale * self.np_random.randn(self.model.nv)
        self.set_state(qpos, qvel)

        return self._get_obs()

    def viewer_setup(self):
        """Apply DEFAULT_CAMERA_CONFIG onto the viewer camera."""
        for attr, value in DEFAULT_CAMERA_CONFIG.items():
            if isinstance(value, np.ndarray):
                # Array-valued settings are written in place.
                getattr(self.viewer.cam, attr)[:] = value
            else:
                setattr(self.viewer.cam, attr, value)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import numpy as np | ||
from gym import utils | ||
from gym.envs.mujoco import mujoco_env | ||
|
||
|
||
# Camera attribute overrides applied onto ``self.viewer.cam`` in
# ``HalfCheetahEnv.viewer_setup`` below (scalars via setattr, arrays in place).
DEFAULT_CAMERA_CONFIG = {
    'distance': 4.0,
}
|
||
|
||
class HalfCheetahEnv(mujoco_env.MujocoEnv, utils.EzPickle):
    """MuJoCo "half cheetah" locomotion environment (v3).

    The per-step reward is weighted forward (x) velocity minus a control
    cost. Episodes never terminate early: ``done`` is always False.
    """

    def __init__(self,
                 xml_file='half_cheetah.xml',
                 forward_reward_weight=1.0,
                 ctrl_cost_weight=0.1,
                 reset_noise_scale=0.1,
                 exclude_current_positions_from_observation=True):
        # Capture the constructor arguments for pickling support; this must
        # stay the first statement so no extra locals leak into the capture.
        utils.EzPickle.__init__(**locals())

        self._forward_reward_weight = forward_reward_weight
        self._ctrl_cost_weight = ctrl_cost_weight
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = (
            exclude_current_positions_from_observation)

        mujoco_env.MujocoEnv.__init__(self, xml_file, 5)

    def control_cost(self, action):
        """Quadratic penalty on actuation magnitude."""
        return self._ctrl_cost_weight * np.sum(np.square(action))

    def step(self, action):
        """Advance the simulation one control step.

        Returns the gym API tuple ``(observation, reward, done, info)``;
        ``info`` carries the reward terms and root x position/velocity.
        """
        x_before = self.sim.data.qpos[0]
        self.do_simulation(action, self.frame_skip)
        x_after = self.sim.data.qpos[0]
        x_velocity = (x_after - x_before) / self.dt

        ctrl_cost = self.control_cost(action)
        forward_reward = self._forward_reward_weight * x_velocity

        observation = self._get_obs()
        reward = forward_reward - ctrl_cost
        info = {
            'x_position': x_after,
            'x_velocity': x_velocity,

            'reward_run': forward_reward,
            'reward_ctrl': -ctrl_cost
        }

        # No failure state exists for this task, so it never terminates.
        return observation, reward, False, info

    def _get_obs(self):
        """Concatenate joint positions and velocities.

        The first qpos entry (root x world position) is optionally dropped
        so the policy cannot read absolute location.
        """
        position = self.sim.data.qpos.flat.copy()
        if self._exclude_current_positions_from_observation:
            position = position[1:]

        velocity = self.sim.data.qvel.flat.copy()

        return np.concatenate((position, velocity)).ravel()

    def reset_model(self):
        """Reset to the initial state plus small uniform/Gaussian noise."""
        scale = self._reset_noise_scale

        qpos = self.init_qpos + self.np_random.uniform(
            low=-scale, high=scale, size=self.model.nq)
        qvel = self.init_qvel + scale * self.np_random.randn(self.model.nv)
        self.set_state(qpos, qvel)

        return self._get_obs()

    def viewer_setup(self):
        """Apply DEFAULT_CAMERA_CONFIG onto the viewer camera."""
        for attr, value in DEFAULT_CAMERA_CONFIG.items():
            if isinstance(value, np.ndarray):
                # Array-valued settings are written in place.
                getattr(self.viewer.cam, attr)[:] = value
            else:
                setattr(self.viewer.cam, attr, value)
Oops, something went wrong.