Skip to content

Commit

Permalink
Fix player in notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
belerico committed Apr 8, 2024
1 parent 7711d19 commit d25c7b5
Showing 1 changed file with 3 additions and 14 deletions.
17 changes: 3 additions & 14 deletions notebooks/dreamer_v3_imagination.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
"from omegaconf import OmegaConf\n",
"from PIL import Image\n",
"\n",
"from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_agent\n",
"from sheeprl.algos.dreamer_v3.agent import build_agent\n",
"from sheeprl.data.buffers import SequentialReplayBuffer\n",
"from sheeprl.utils.env import make_env\n",
"from sheeprl.utils.utils import dotdict"
Expand Down Expand Up @@ -128,7 +128,7 @@
"actions_dim = tuple(\n",
" action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])\n",
")\n",
"world_model, actor, critic, critic_target = build_agent(\n",
"world_model, actor, critic, critic_target, player = build_agent(\n",
" fabric,\n",
" actions_dim,\n",
" is_continuous,\n",
Expand All @@ -138,17 +138,6 @@
" state[\"actor\"],\n",
" state[\"critic\"],\n",
" state[\"target_critic\"],\n",
")\n",
"player = PlayerDV3(\n",
" world_model.encoder.module,\n",
" world_model.rssm,\n",
" actor.module,\n",
" actions_dim,\n",
" cfg.env.num_envs,\n",
" cfg.algo.world_model.stochastic_size,\n",
" cfg.algo.world_model.recurrent_model.recurrent_state_size,\n",
" fabric.device,\n",
" cfg.algo.world_model.discrete_size,\n",
")"
]
},
Expand Down Expand Up @@ -230,7 +219,7 @@
" mask = {k: v for k, v in preprocessed_obs.items() if k.startswith(\"mask\")}\n",
" if len(mask) == 0:\n",
" mask = None\n",
" real_actions = actions = player.get_actions(preprocessed_obs, mask)\n",
" real_actions = actions = player.get_actions(preprocessed_obs, mask=mask)\n",
" actions = torch.cat(actions, -1).cpu().numpy()\n",
" if is_continuous:\n",
" real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()\n",
Expand Down

0 comments on commit d25c7b5

Please sign in to comment.