Skip to content

Commit

Permalink
Merge branch 'main' into fix/single-device-fabric
Browse files Browse the repository at this point in the history
  • Loading branch information
belerico authored Apr 2, 2024
2 parents 9e2e3f6 + 9f557c6 commit 4df76b0
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
22 changes: 22 additions & 0 deletions sheeprl/algos/a2c/a2c.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,28 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):

# Single environment step
obs, rewards, done, truncated, info = envs.step(real_actions.reshape(envs.action_space.shape))
truncated_envs = np.nonzero(truncated)[0]
if len(truncated_envs) > 0:
real_next_obs = {
k: torch.empty(
len(truncated_envs),
*observation_space[k].shape,
dtype=torch.float32,
device=device,
)
for k in obs_keys
}
for i, truncated_env in enumerate(truncated_envs):
for k, v in info["final_observation"][truncated_env].items():
torch_v = torch.as_tensor(v, dtype=torch.float32, device=device)
if k in cfg.algo.cnn_keys.encoder:
torch_v = torch_v.view(-1, *v.shape[-2:])
torch_v = torch_v / 255.0 - 0.5
real_next_obs[k][i] = torch_v
_, _, vals = player(real_next_obs)
rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape(
rewards[truncated_envs].shape
)

dones = np.logical_or(done, truncated)
dones = dones.reshape(cfg.env.num_envs, -1)
Expand Down
4 changes: 3 additions & 1 deletion sheeprl/algos/ppo/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
torch_v = torch_v / 255.0 - 0.5
real_next_obs[k][i] = torch_v
_, _, _, vals = player(real_next_obs)
rewards[truncated_envs] += vals.cpu().numpy().reshape(rewards[truncated_envs].shape)
rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape(
rewards[truncated_envs].shape
)
dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8)
rewards = rewards.reshape(cfg.env.num_envs, -1)

Expand Down
4 changes: 3 additions & 1 deletion sheeprl/algos/ppo/ppo_decoupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,9 @@ def player(
torch_v = torch_v / 255.0 - 0.5
real_next_obs[k][i] = torch_v
_, _, _, vals = agent(real_next_obs)
rewards[truncated_envs] += vals.cpu().numpy().reshape(rewards[truncated_envs].shape)
rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape(
rewards[truncated_envs].shape
)
dones = np.logical_or(dones, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8)
rewards = rewards.reshape(cfg.env.num_envs, -1)

Expand Down

0 comments on commit 4df76b0

Please sign in to comment.