diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py
index a1492103..a654e135 100644
--- a/sheeprl/algos/dreamer_v1/agent.py
+++ b/sheeprl/algos/dreamer_v1/agent.py
@@ -276,14 +276,14 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
             self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs])

     def get_exploration_actions(
-        self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None, step: int = 0
+        self, obs: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None, step: int = 0
     ) -> Sequence[Tensor]:
         """Return the actions with a certain amount of noise for exploration.

         Args:
             obs (Tensor): the current observations.
-            sample_actions (bool): whether or not to sample the actions.
-                Default to True.
+            greedy (bool): whether or not to select the greedy actions (no sampling).
+                Default to False.
             mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed).
                 Defaults to None.
             step (int): the step of the training, used for the exploration amount.
@@ -292,7 +292,7 @@ def get_exploration_actions(
         Returns:
             The actions the agent has to perform (Sequence[Tensor]).
         """
-        actions = self.get_actions(obs, sample_actions=sample_actions, mask=mask)
+        actions = self.get_actions(obs, greedy=greedy, mask=mask)
         expl_actions = None
         if self.actor._expl_amount > 0:
             expl_actions = self.actor.add_exploration_noise(actions, step=step, mask=mask)
@@ -300,14 +300,14 @@ def get_exploration_actions(
         return expl_actions or actions

     def get_actions(
-        self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None
+        self, obs: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None
     ) -> Sequence[Tensor]:
         """Return the greedy actions.

         Args:
             obs (Tensor): the current observations.
-            sample_actions (bool): whether or not to sample the actions.
-                Default to True.
+            greedy (bool): whether or not to select the greedy actions (no sampling).
+                Default to False.
             mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed).
                 Defaults to None.

@@ -321,7 +321,7 @@ def get_actions(
         _, self.stochastic_state = compute_stochastic_state(
             self.representation_model(torch.cat((self.recurrent_state, embedded_obs), -1)),
         )
-        actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask)
+        actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), greedy, mask)
         self.actions = torch.cat(actions, -1)
         return actions

diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py
index 48845a17..ae16efac 100644
--- a/sheeprl/algos/dreamer_v1/evaluate.py
+++ b/sheeprl/algos/dreamer_v1/evaluate.py
@@ -54,4 +54,4 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
         state["actor"],
     )
     del _
-    test(player, fabric, cfg, log_dir, sample_actions=False)
+    test(player, fabric, cfg, log_dir, greedy=True)
diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py
index b7908d4c..0c580c62 100644
--- a/sheeprl/algos/dreamer_v2/agent.py
+++ b/sheeprl/algos/dreamer_v2/agent.py
@@ -503,7 +503,7 @@ def _get_expl_amount(self, step: int) -> Tensor:
         return max(amount, self._expl_min)

     def forward(
-        self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None
+        self, state: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None
     ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]:
         """
         Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions),
@@ -511,8 +511,8 @@ def forward(

         Args:
             state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size).
-            sample_actions (bool): whether or not to sample the actions.
-                Default to True.
+            greedy (bool): whether or not to select the greedy actions (no sampling).
+                Default to False.
             mask (Dict[str, Tensor], optional): the action mask (which actions can be selected).
                 Default to None.

@@ -536,7 +536,7 @@ def forward(
                 std = 2 * torch.sigmoid((std + self.init_std) / 2) + self.min_std
                 dist = TruncatedNormal(torch.tanh(mean), std, -1, 1)
                 actions_dist = Independent(dist, 1)
-            if sample_actions:
+            if not greedy:
                 actions = actions_dist.rsample()
             else:
                 sample = actions_dist.sample((100,))
@@ -549,7 +549,7 @@ def forward(
             actions: List[Tensor] = []
             for logits in pre_dist:
                 actions_dist.append(OneHotCategoricalStraightThrough(logits=logits))
-                if sample_actions:
+                if not greedy:
                     actions.append(actions_dist[-1].rsample())
                 else:
                     actions.append(actions_dist[-1].mode)
@@ -608,7 +608,7 @@ def __init__(
         )

     def forward(
-        self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None
+        self, state: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None
     ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]:
         """
         Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions),
@@ -616,8 +616,8 @@ def forward(

         Args:
             state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size).
-            sample_actions (bool): whether or not to sample the actions.
-                Default to True.
+            greedy (bool): whether or not to select the greedy actions (no sampling).
+                Default to False.
             mask (Dict[str, Tensor], optional): the action mask (which actions can be selected).
                 Default to None.

@@ -652,7 +652,7 @@ def forward(
                         elif sampled_action == 18:  # Destroy action
                             logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf
                 actions_dist.append(OneHotCategoricalStraightThrough(logits=logits))
-                if sample_actions:
+                if not greedy:
                     actions.append(actions_dist[-1].rsample())
                 else:
                     actions.append(actions_dist[-1].mode)
@@ -802,7 +802,7 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
     def get_actions(
         self,
         obs: Dict[str, Tensor],
-        sample_actions: bool = True,
+        greedy: bool = False,
         mask: Optional[Dict[str, Tensor]] = None,
     ) -> Sequence[Tensor]:
         """
@@ -810,8 +810,8 @@ def get_actions(

         Args:
             obs (Dict[str, Tensor]): the current observations.
-            sample_actions (bool): whether or not to sample the actions.
-                Default to True.
+            greedy (bool): whether or not to select the greedy actions (no sampling).
+                Default to False.
             mask (Dict[str, Tensor], optional): the action mask (which actions can be selected).
                 Default to None.

@@ -827,7 +827,7 @@ def get_actions(
         self.stochastic_state = stochastic_state.view(
             *stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size
         )
-        actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask)
+        actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), greedy, mask)
         self.actions = torch.cat(actions, -1)
         return actions

diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py
index 145b758d..8b5990b9 100644
--- a/sheeprl/algos/dreamer_v2/evaluate.py
+++ b/sheeprl/algos/dreamer_v2/evaluate.py
@@ -54,4 +54,4 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
         state["actor"],
     )
     del _
-    test(player, fabric, cfg, log_dir, sample_actions=False)
+    test(player, fabric, cfg, log_dir, greedy=True)
diff --git a/sheeprl/algos/dreamer_v2/utils.py b/sheeprl/algos/dreamer_v2/utils.py
index f3a42f39..ed9debb0 100644
--- a/sheeprl/algos/dreamer_v2/utils.py
+++ b/sheeprl/algos/dreamer_v2/utils.py
@@ -108,7 +108,7 @@ def test(
     cfg: Dict[str, Any],
     log_dir: str,
     test_name: str = "",
-    sample_actions: bool = False,
+    greedy: bool = True,
 ):
     """Test the model on the environment with the frozen model.

@@ -119,8 +119,8 @@ def test(
         log_dir (str): the logging directory.
         test_name (str): the name of the test.
             Default to "".
-        sample_actoins (bool): whether or not to sample actions.
-            Default to False.
+        greedy (bool): whether or not to select actions greedily (no sampling).
+            Default to True.
""" env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))() done = False @@ -140,7 +140,7 @@ def test( elif k in cfg.algo.mlp_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) real_actions = player.get_actions( - preprocessed_obs, sample_actions, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + preprocessed_obs, greedy, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} ) if player.actor.is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py index 212ead04..4381ae78 100644 --- a/sheeprl/algos/p2e_dv1/evaluate.py +++ b/sheeprl/algos/p2e_dv1/evaluate.py @@ -55,4 +55,4 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): actor_task_state=state["actor_task"], ) del _ - test(player, fabric, cfg, log_dir, sample_actions=False) + test(player, fabric, cfg, log_dir, greedy=True) diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py index 8f49205c..5dcb9b15 100644 --- a/sheeprl/algos/p2e_dv2/evaluate.py +++ b/sheeprl/algos/p2e_dv2/evaluate.py @@ -55,4 +55,4 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): actor_task_state=state["actor_task"], ) del _ - test(player, fabric, cfg, log_dir, sample_actions=False) + test(player, fabric, cfg, log_dir, greedy=True)