Skip to content

Commit

Permalink
fix: sac coupled and decoupled (#205)
Browse files (browse the repository at this point in the history)
  • Loading branch information
michele-milesi authored Feb 12, 2024
1 parent 7de395f commit 920c7d9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion sheeprl/algos/sac/sac.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
if "final_observation" in infos:
for idx, final_obs in enumerate(infos["final_observation"]):
if final_obs is not None:
- real_next_obs[idx] = np.concatenate([v for v in final_obs.values()], axis=-1)
+ real_next_obs[idx] = np.concatenate(
+ [v for k, v in final_obs.items() if k in cfg.algo.mlp_keys.encoder], axis=-1
+ )

step_data["dones"] = dones[np.newaxis]
step_data["actions"] = actions[np.newaxis]
Expand Down
4 changes: 3 additions & 1 deletion sheeprl/algos/sac/sac_decoupled.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,9 @@ def player(
if "final_observation" in infos:
for idx, final_obs in enumerate(infos["final_observation"]):
if final_obs is not None:
- real_next_obs[idx] = np.concatenate([v for v in final_obs.values()], axis=-1)
+ real_next_obs[idx] = np.concatenate(
+ [v for k, v in final_obs.items() if k in cfg.algo.mlp_keys.encoder], axis=-1
+ )

step_data["dones"] = dones[np.newaxis]
step_data["actions"] = actions[np.newaxis]
Expand Down

0 comments on commit 920c7d9

Please sign in to comment.