From sample_actions to greedy
belerico committed Apr 5, 2024
1 parent fe71391 commit 9736d3a
Showing 7 changed files with 29 additions and 29 deletions.
16 changes: 8 additions & 8 deletions sheeprl/algos/dreamer_v1/agent.py
@@ -276,14 +276,14 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs])

def get_exploration_actions(
self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None, step: int = 0
self, obs: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None, step: int = 0
) -> Sequence[Tensor]:
"""Return the actions with a certain amount of noise for exploration.
Args:
obs (Tensor): the current observations.
sample_actions (bool): whether or not to sample the actions.
Default to True.
greedy (bool): whether to take the greedy (most probable) actions instead of sampling them.
Default to False.
mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed).
Defaults to None.
step (int): the step of the training, used for the exploration amount.
@@ -292,22 +292,22 @@ def get_exploration_actions(
Returns:
The actions the agent has to perform (Sequence[Tensor]).
"""
actions = self.get_actions(obs, sample_actions=sample_actions, mask=mask)
actions = self.get_actions(obs, greedy=greedy, mask=mask)
expl_actions = None
if self.actor._expl_amount > 0:
expl_actions = self.actor.add_exploration_noise(actions, step=step, mask=mask)
self.actions = torch.cat(expl_actions, dim=-1)
return expl_actions or actions

def get_actions(
self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None
self, obs: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None
) -> Sequence[Tensor]:
"""Return the greedy actions.
Args:
obs (Tensor): the current observations.
sample_actions (bool): whether or not to sample the actions.
Default to True.
greedy (bool): whether to take the greedy (most probable) actions instead of sampling them.
Default to False.
mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed).
Defaults to None.
@@ -321,7 +321,7 @@ def get_actions(
_, self.stochastic_state = compute_stochastic_state(
self.representation_model(torch.cat((self.recurrent_state, embedded_obs), -1)),
)
actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask)
actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), greedy, mask)
self.actions = torch.cat(actions, -1)
return actions

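For reference, here is a minimal self-contained sketch (not taken from the repository) of the two behaviours the renamed flag switches between for a discrete action head: sampling from the policy distribution when greedy=False, and taking its most probable action when greedy=True. select_action and logits are illustrative names.

import torch
import torch.nn.functional as F
from torch.distributions import OneHotCategorical

def select_action(logits: torch.Tensor, greedy: bool = False) -> torch.Tensor:
    # Sampled one-hot action when greedy=False, argmax one-hot action when greedy=True.
    if not greedy:
        return OneHotCategorical(logits=logits).sample()
    return F.one_hot(logits.argmax(dim=-1), num_classes=logits.shape[-1]).float()

logits = torch.randn(1, 4)
print(select_action(logits))               # stochastic, as during data collection
print(select_action(logits, greedy=True))  # deterministic, as during evaluation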
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v1/evaluate.py
@@ -54,4 +54,4 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
state["actor"],
)
del _
test(player, fabric, cfg, log_dir, sample_actions=False)
test(player, fabric, cfg, log_dir, greedy=True)
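The evaluation entry points flip the keyword accordingly: sample_actions=False becomes greedy=True. For downstream callers that still pass the old keyword, a hypothetical shim (migrate_kwargs is not part of sheeprl) makes the inversion explicit:

def migrate_kwargs(**kwargs) -> dict:
    # Translate the removed sample_actions keyword into the new greedy keyword.
    if "sample_actions" in kwargs:
        kwargs["greedy"] = not kwargs.pop("sample_actions")
    return kwargs

assert migrate_kwargs(sample_actions=False) == {"greedy": True}   # evaluation
assert migrate_kwargs(sample_actions=True) == {"greedy": False}   # data collection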
26 changes: 13 additions & 13 deletions sheeprl/algos/dreamer_v2/agent.py
@@ -503,16 +503,16 @@ def _get_expl_amount(self, step: int) -> Tensor:
return max(amount, self._expl_min)

def forward(
self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None
self, state: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None
) -> Tuple[Sequence[Tensor], Sequence[Distribution]]:
"""
Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions),
where * means any number of dimensions including None.
Args:
state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size).
sample_actions (bool): whether or not to sample the actions.
Default to True.
greedy (bool): whether to take the greedy (most probable) actions instead of sampling them.
Default to False.
mask (Dict[str, Tensor], optional): the action mask (which actions can be selected).
Default to None.
@@ -536,7 +536,7 @@ def forward(
std = 2 * torch.sigmoid((std + self.init_std) / 2) + self.min_std
dist = TruncatedNormal(torch.tanh(mean), std, -1, 1)
actions_dist = Independent(dist, 1)
if sample_actions:
if not greedy:
actions = actions_dist.rsample()
else:
sample = actions_dist.sample((100,))
@@ -549,7 +549,7 @@ def forward(
actions: List[Tensor] = []
for logits in pre_dist:
actions_dist.append(OneHotCategoricalStraightThrough(logits=logits))
if sample_actions:
if not greedy:
actions.append(actions_dist[-1].rsample())
else:
actions.append(actions_dist[-1].mode)
@@ -608,16 +608,16 @@ def __init__(
)

def forward(
self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None
self, state: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None
) -> Tuple[Sequence[Tensor], Sequence[Distribution]]:
"""
Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions),
where * means any number of dimensions including None.
Args:
state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size).
sample_actions (bool): whether or not to sample the actions.
Default to True.
greedy (bool): whether to take the greedy (most probable) actions instead of sampling them.
Default to False.
mask (Dict[str, Tensor], optional): the action mask (which actions can be selected).
Default to None.
@@ -652,7 +652,7 @@ def forward(
elif sampled_action == 18: # Destroy action
logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf
actions_dist.append(OneHotCategoricalStraightThrough(logits=logits))
if sample_actions:
if not greedy:
actions.append(actions_dist[-1].rsample())
else:
actions.append(actions_dist[-1].mode)
@@ -802,16 +802,16 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None:
def get_actions(
self,
obs: Dict[str, Tensor],
sample_actions: bool = True,
greedy: bool = False,
mask: Optional[Dict[str, Tensor]] = None,
) -> Sequence[Tensor]:
"""
Return the greedy actions.
Args:
obs (Dict[str, Tensor]): the current observations.
sample_actions (bool): whether or not to sample the actions.
Default to True.
greedy (bool): whether to take the greedy (most probable) actions instead of sampling them.
Default to False.
mask (Dict[str, Tensor], optional): the action mask (which actions can be selected).
Default to None.
@@ -827,7 +827,7 @@ def get_actions(
self.stochastic_state = stochastic_state.view(
*stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size
)
actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask)
actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), greedy, mask)
self.actions = torch.cat(actions, -1)
return actions

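The continuous branch above appears to approximate the greedy action by drawing many candidates and keeping the most likely one; only the first line of that branch is visible in the diff, so the rest is a guess. A standalone sketch of the pattern, with torch's Normal standing in for sheeprl's TruncatedNormal:

import torch
from torch.distributions import Independent, Normal

def greedy_continuous_action(mean: torch.Tensor, std: torch.Tensor, n_samples: int = 100) -> torch.Tensor:
    # Approximate the mode by keeping the most likely of n_samples draws.
    dist = Independent(Normal(mean, std), 1)
    candidates = dist.sample((n_samples,))        # (n_samples, batch, action_dim)
    log_probs = dist.log_prob(candidates)         # (n_samples, batch)
    best = log_probs.argmax(dim=0)                # (batch,)
    return candidates[best, torch.arange(mean.shape[0])]

mean, std = torch.zeros(2, 3), torch.ones(2, 3)
print(greedy_continuous_action(mean, std).shape)  # torch.Size([2, 3])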
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v2/evaluate.py
@@ -54,4 +54,4 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
state["actor"],
)
del _
test(player, fabric, cfg, log_dir, sample_actions=False)
test(player, fabric, cfg, log_dir, greedy=True)
8 changes: 4 additions & 4 deletions sheeprl/algos/dreamer_v2/utils.py
@@ -108,7 +108,7 @@ def test(
cfg: Dict[str, Any],
log_dir: str,
test_name: str = "",
sample_actions: bool = False,
greedy: bool = True,
):
"""Test the model on the environment with the frozen model.
@@ -119,8 +119,8 @@ def test(
log_dir (str): the logging directory.
test_name (str): the name of the test.
Default to "".
sample_actoins (bool): whether or not to sample actions.
Default to False.
greedy (bool): whether to take the greedy (most probable) actions instead of sampling them.
Default to True.
"""
env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))()
done = False
@@ -140,7 +140,7 @@ def test(
elif k in cfg.algo.mlp_keys.encoder:
preprocessed_obs[k] = v[None, ...].to(device)
real_actions = player.get_actions(
preprocessed_obs, sample_actions, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
preprocessed_obs, greedy, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
)
if player.actor.is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
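test() forwards every observation key that starts with "mask" to get_actions as the action mask. A small sketch of that selection with placeholder tensors (key names other than mask_destroy are made up):

import torch

preprocessed_obs = {
    "rgb": torch.zeros(1, 3, 64, 64),
    "mask_action_type": torch.ones(1, 19, dtype=torch.bool),
    "mask_destroy": torch.ones(1, 1, dtype=torch.bool),
}
# Same dict comprehension as in test(): keep only the mask_* entries.
action_mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
print(sorted(action_mask))  # ['mask_action_type', 'mask_destroy']
# real_actions = player.get_actions(preprocessed_obs, greedy, action_mask)  # player is hypothetical here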
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/evaluate.py
@@ -55,4 +55,4 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
actor_task_state=state["actor_task"],
)
del _
test(player, fabric, cfg, log_dir, sample_actions=False)
test(player, fabric, cfg, log_dir, greedy=True)
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv2/evaluate.py
@@ -55,4 +55,4 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
actor_task_state=state["actor_task"],
)
del _
test(player, fabric, cfg, log_dir, sample_actions=False)
test(player, fabric, cfg, log_dir, greedy=True)
