From 32a37361e5046d1068cf75e61aae91d724a000b6 Mon Sep 17 00:00:00 2001 From: Federico Belotti Date: Mon, 8 Apr 2024 10:54:45 +0200 Subject: [PATCH] Fix/player build agent (#258) * Decoupled RSSM for DV3 agent * Initialize posterior with prior if is_first is True * Fix PlayerDV3 creation in evaluation * Fix representation_model * Fix compute first prior state with a zero posterior * DV3 replay ratio conversion * Removed expl parameters dependent on old per_Rank_gradient_steps * feat: update repeats computation * feat: update learning starts in config * fix: remove files * feat: update repeats * Let Dv3 compute bootstrap correctly * feat: added replay ratio and update exploration * Fix exploration actions computation on DV1 * Fix naming * Add replay-ratio to SAC * feat: added replay ratio to p2e algos * feat: update configs and utils of p2e algos * Add replay-ratio to SAC-AE * Add DrOQ replay ratio * Fix tests * Fix mispelled * Fix wrong attribute accesing * FIx naming and configs * feat: add terminated and truncated to dreamer, p2e and ppo algos * fix: dmc wrapper * feat: update algos to split terminated from truncated * fix: crafter and diambra wrappers * feat: replace done with truncated key in when the buffer is added to the checkpoint * Set Distribution.validate_args once at the beginning * Move validate_args in run method * Defined PPOPlayer * Defined RecurrentPPOPlayer * feat: added truncated/terminated to minedojo environment * Add SACPlayer * FIx evaluate.py for PPO and RecurrentPPO * FIx DrOQ build_agent * Delete unused training agent during evaluation * Fix SACPlayer * Fix DrOQ build_agent * Fix PPO decoupled creating single-device fabric * Adapt SAC decoupled to new build_agent * Fix typings + add get_actions method * Add SACAEPlayer * Add A2CPlayer from PPOPlayer * Fix get_single_device_fabric * Fix calling get_values on player instead of agent * Setup PLayerDV1 in build_agent * Fix typings * DV2 player * Remove one weight tie * Fix DV3 player in build_agent * Fix return player from build_agent * Set actor_type during evaluation * Build player in build_agent for P2E-DV3 * Update comments * Update dreamer-v3 cfg * Learnable initial recurrent state choice * Fix DecoupledRSSM to accept the learnable_initial_recurrent_state flag * Preserve input dtype after LayerNorm (https://github.com/pytorch/pytorch/issues/66707#issuecomment-2028904230) * Fix imports * Move hyperparams to rightful key inside world_model * sample_actions to greedy * From sample_actions to greedy * From sample_actions to greedy * unwrap_fabric before test in p2e * Fix player in notebook * Update how-tos --------- Co-authored-by: Michele Milesi --- howto/register_external_algorithm.md | 334 +++++++++++++++++- howto/register_new_algorithm.md | 348 ++++++++++++++++++- notebooks/dreamer_v3_imagination.ipynb | 17 +- sheeprl/algos/a2c/a2c.py | 16 +- sheeprl/algos/a2c/agent.py | 23 +- sheeprl/algos/a2c/evaluate.py | 3 +- sheeprl/algos/a2c/utils.py | 9 +- sheeprl/algos/dreamer_v1/agent.py | 83 +++-- sheeprl/algos/dreamer_v1/dreamer_v1.py | 15 +- sheeprl/algos/dreamer_v1/evaluate.py | 19 +- sheeprl/algos/dreamer_v2/agent.py | 92 +++-- sheeprl/algos/dreamer_v2/dreamer_v2.py | 16 +- sheeprl/algos/dreamer_v2/evaluate.py | 20 +- sheeprl/algos/dreamer_v2/utils.py | 8 +- sheeprl/algos/dreamer_v3/agent.py | 97 +++--- sheeprl/algos/dreamer_v3/dreamer_v3.py | 18 +- sheeprl/algos/dreamer_v3/evaluate.py | 20 +- sheeprl/algos/dreamer_v3/utils.py | 8 +- sheeprl/algos/droq/agent.py | 34 +- sheeprl/algos/droq/droq.py | 9 +- 
sheeprl/algos/droq/evaluate.py | 5 +- sheeprl/algos/p2e_dv1/agent.py | 21 +- sheeprl/algos/p2e_dv1/evaluate.py | 23 +- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 21 +- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 24 +- sheeprl/algos/p2e_dv2/agent.py | 27 +- sheeprl/algos/p2e_dv2/evaluate.py | 25 +- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 21 +- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 25 +- sheeprl/algos/p2e_dv3/agent.py | 18 +- sheeprl/algos/p2e_dv3/evaluate.py | 25 +- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 22 +- sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 35 +- sheeprl/algos/ppo/agent.py | 92 ++++- sheeprl/algos/ppo/evaluate.py | 3 +- sheeprl/algos/ppo/ppo.py | 10 +- sheeprl/algos/ppo/ppo_decoupled.py | 55 ++- sheeprl/algos/ppo/utils.py | 8 +- sheeprl/algos/ppo_recurrent/agent.py | 196 +++++++++-- sheeprl/algos/ppo_recurrent/evaluate.py | 3 +- sheeprl/algos/ppo_recurrent/ppo_recurrent.py | 14 +- sheeprl/algos/ppo_recurrent/utils.py | 8 +- sheeprl/algos/sac/agent.py | 88 ++++- sheeprl/algos/sac/evaluate.py | 5 +- sheeprl/algos/sac/sac.py | 25 +- sheeprl/algos/sac/sac_decoupled.py | 35 +- sheeprl/algos/sac/utils.py | 6 +- sheeprl/algos/sac_ae/agent.py | 99 +++++- sheeprl/algos/sac_ae/evaluate.py | 7 +- sheeprl/algos/sac_ae/sac_ae.py | 36 +- sheeprl/algos/sac_ae/utils.py | 6 +- sheeprl/configs/algo/ppo_recurrent.yaml | 12 +- sheeprl/utils/utils.py | 8 + 53 files changed, 1574 insertions(+), 623 deletions(-) diff --git a/howto/register_external_algorithm.md b/howto/register_external_algorithm.md index a4b94368..6d95973c 100644 --- a/howto/register_external_algorithm.md +++ b/howto/register_external_algorithm.md @@ -17,12 +17,13 @@ my_awesome_algo ## The agent -The agent is the core of the algorithm and it is defined in the `agent.py` file. It must contain at least single function called `build_agent` that returns a `torch.nn.Module` wrapped with Fabric: +The agent is the core of the algorithm and it is defined in the `agent.py` file. It must contain at least single function called `build_agent` that returns at least a tuple composed of two `torch.nn.Module` wrapped with Fabric; the first one is the agent used during the training phase, while the other one is the one used during the environment interaction: ```python from __future__ import annotations -from typing import Any, Dict, Sequence +import copy +from typing import Any, Dict, Sequence, Tuple import gymnasium from lightning import Fabric @@ -30,6 +31,8 @@ from lightning.fabric.wrappers import _FabricModule import torch from torch import Tensor +from sheeprl.utils.fabric import get_single_device_fabric + class SOTAAgent(torch.nn.Module): def __init__(self, ...): @@ -38,6 +41,16 @@ class SOTAAgent(torch.nn.Module): def forward(self, obs: Dict[str, torch.Tensor]) -> Tensor: ... +class SOTAAgentPlayer(torch.nn.Module): + def __init__(self, ...): + ... + + def forward(self, obs: Dict[str, torch.Tensor]) -> Tensor: + ... + + def get_actions(self, obs: Dict[str, torch.Tensor], greedy: bool = False) -> Tensor: + ... + def build_agent( fabric: Fabric, @@ -46,7 +59,7 @@ def build_agent( cfg: Dict[str, Any], observation_space: gymnasium.spaces.Dict, state: Dict[str, Any] | None = None, -) -> _FabricModule: +) -> Tuple[_FabricModule, _FabricModule]: # Define the agent here agent = SOTAAgent(...) 
@@ -55,12 +68,325 @@ def build_agent( if state: agent.load_state_dict(state) + # Setup player agent + player = copy.deepcopy(agent) + # Setup the agent with Fabric agent = fabric.setup_model(agent) - return agent + # Setup the player agent with a single-device Fabric + fabric_player = get_single_device_fabric(fabric) + player = fabric_player.setup_module(player) + + # Tie weights between the agent and the player + for agent_p, player_p in zip(agent.parameters(), player.parameters()): + player_p.data = agent_p.data + return agent, player ``` +The player agent is wrapped with a **single-device Fabric**, in this way we maintain the same precision and device of the main Fabric object, but with the player agent being able to interct with the environment skipping possible distributed synchronization points. + +If the agent is composed of multiple models, each one with its own forward method, it is advisable to wrap each one of them with the main Fabric object; the same happens for the player agent, where each of the models has to be consequently wrapped with the single-device Fabric obejct. Here we have the example of the **PPOAgent**: + +```python +from __future__ import annotations + +import copy +from math import prod +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import gymnasium +import torch +import torch.nn as nn +from lightning import Fabric +from torch import Tensor +from torch.distributions import Distribution, Independent, Normal, OneHotCategorical + +from sheeprl.models.models import MLP, MultiEncoder, NatureCNN +from sheeprl.utils.fabric import get_single_device_fabric + + +class CNNEncoder(nn.Module): + def __init__( + self, + in_channels: int, + features_dim: int, + screen_size: int, + keys: Sequence[str], + ) -> None: + super().__init__() + self.keys = keys + self.input_dim = (in_channels, screen_size, screen_size) + self.output_dim = features_dim + self.model = NatureCNN(in_channels=in_channels, features_dim=features_dim, screen_size=screen_size) + + def forward(self, obs: Dict[str, Tensor]) -> Tensor: + x = torch.cat([obs[k] for k in self.keys], dim=-3) + return self.model(x) + + +class MLPEncoder(nn.Module): + def __init__( + self, + input_dim: int, + features_dim: int | None, + keys: Sequence[str], + dense_units: int = 64, + mlp_layers: int = 2, + dense_act: nn.Module = nn.ReLU, + layer_norm: bool = False, + ) -> None: + super().__init__() + self.keys = keys + self.input_dim = input_dim + self.output_dim = features_dim if features_dim else dense_units + self.model = MLP( + input_dim, + features_dim, + [dense_units] * mlp_layers, + activation=dense_act, + norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None, + norm_args=[{"normalized_shape": dense_units} for _ in range(mlp_layers)] if layer_norm else None, + ) + + def forward(self, obs: Dict[str, Tensor]) -> Tensor: + x = torch.cat([obs[k] for k in self.keys], dim=-1) + return self.model(x) + + +class PPOActor(nn.Module): + def __init__(self, actor_backbone: torch.nn.Module, actor_heads: torch.nn.ModuleList, is_continuous: bool) -> None: + super().__init__() + self.actor_backbone = actor_backbone + self.actor_heads = actor_heads + self.is_continuous = is_continuous + + def forward(self, x: Tensor) -> List[Tensor]: + x = self.actor_backbone(x) + return [head(x) for head in self.actor_heads] + + +class PPOAgent(nn.Module): + def __init__( + self, + actions_dim: Sequence[int], + obs_space: gymnasium.spaces.Dict, + encoder_cfg: Dict[str, Any], + actor_cfg: Dict[str, Any], + critic_cfg: 
Dict[str, Any], + cnn_keys: Sequence[str], + mlp_keys: Sequence[str], + screen_size: int, + distribution_cfg: Dict[str, Any], + is_continuous: bool = False, + ): + super().__init__() + self.is_continuous = is_continuous + self.distribution_cfg = distribution_cfg + self.actions_dim = actions_dim + in_channels = sum([prod(obs_space[k].shape[:-2]) for k in cnn_keys]) + mlp_input_dim = sum([obs_space[k].shape[0] for k in mlp_keys]) + cnn_encoder = ( + CNNEncoder(in_channels, encoder_cfg.cnn_features_dim, screen_size, cnn_keys) + if cnn_keys is not None and len(cnn_keys) > 0 + else None + ) + mlp_encoder = ( + MLPEncoder( + mlp_input_dim, + encoder_cfg.mlp_features_dim, + mlp_keys, + encoder_cfg.dense_units, + encoder_cfg.mlp_layers, + eval(encoder_cfg.dense_act), + encoder_cfg.layer_norm, + ) + if mlp_keys is not None and len(mlp_keys) > 0 + else None + ) + self.feature_extractor = MultiEncoder(cnn_encoder, mlp_encoder) + features_dim = self.feature_extractor.output_dim + self.critic = MLP( + input_dims=features_dim, + output_dim=1, + hidden_sizes=[critic_cfg.dense_units] * critic_cfg.mlp_layers, + activation=eval(critic_cfg.dense_act), + norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None, + norm_args=( + [{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] + if critic_cfg.layer_norm + else None + ), + ) + actor_backbone = ( + MLP( + input_dims=features_dim, + output_dim=None, + hidden_sizes=[actor_cfg.dense_units] * actor_cfg.mlp_layers, + activation=eval(actor_cfg.dense_act), + flatten_dim=None, + norm_layer=[nn.LayerNorm] * actor_cfg.mlp_layers if actor_cfg.layer_norm else None, + norm_args=( + [{"normalized_shape": actor_cfg.dense_units} for _ in range(actor_cfg.mlp_layers)] + if actor_cfg.layer_norm + else None + ), + ) + if actor_cfg.mlp_layers > 0 + else nn.Identity() + ) + if is_continuous: + actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)]) + else: + actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim]) + self.actor = PPOActor(actor_backbone, actor_heads, is_continuous) + + def forward( + self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None + ) -> Tuple[Sequence[Tensor], Tensor, Tensor, Tensor]: + feat = self.feature_extractor(obs) + actor_out: List[Tensor] = self.actor(feat) + values = self.critic(feat) + if self.is_continuous: + mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1) + std = log_std.exp() + normal = Independent(Normal(mean, std), 1) + if actions is None: + actions = normal.sample() + else: + # always composed by a tuple of one element containing all the + # continuous actions + actions = actions[0] + log_prob = normal.log_prob(actions) + return tuple([actions]), log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1), values + else: + should_append = False + actions_logprobs: List[Tensor] = [] + actions_entropies: List[Tensor] = [] + actions_dist: List[Distribution] = [] + if actions is None: + should_append = True + actions: List[Tensor] = [] + for i, logits in enumerate(actor_out): + actions_dist.append(OneHotCategorical(logits=logits)) + actions_entropies.append(actions_dist[-1].entropy()) + if should_append: + actions.append(actions_dist[-1].sample()) + actions_logprobs.append(actions_dist[-1].log_prob(actions[i])) + return ( + tuple(actions), + torch.stack(actions_logprobs, dim=-1).sum(dim=-1, keepdim=True), + torch.stack(actions_entropies, dim=-1).sum(dim=-1, 
keepdim=True), + values, + ) + + +class PPOPlayer(nn.Module): + def __init__(self, feature_extractor: MultiEncoder, actor: PPOActor, critic: nn.Module) -> None: + super().__init__() + self.feature_extractor = feature_extractor + self.critic = critic + self.actor = actor + + def forward(self, obs: Dict[str, Tensor]) -> Tuple[Sequence[Tensor], Tensor, Tensor]: + feat = self.feature_extractor(obs) + values = self.critic(feat) + actor_out: List[Tensor] = self.actor(feat) + if self.actor.is_continuous: + mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1) + std = log_std.exp() + normal = Independent(Normal(mean, std), 1) + actions = normal.sample() + log_prob = normal.log_prob(actions) + return tuple([actions]), log_prob.unsqueeze(dim=-1), values + else: + actions_dist: List[Distribution] = [] + actions_logprobs: List[Tensor] = [] + actions: List[Tensor] = [] + for i, logits in enumerate(actor_out): + actions_dist.append(OneHotCategorical(logits=logits)) + actions.append(actions_dist[-1].sample()) + actions_logprobs.append(actions_dist[-1].log_prob(actions[i])) + return ( + tuple(actions), + torch.stack(actions_logprobs, dim=-1).sum(dim=-1, keepdim=True), + values, + ) + + def get_values(self, obs: Dict[str, Tensor]) -> Tensor: + feat = self.feature_extractor(obs) + return self.critic(feat) + + def get_actions(self, obs: Dict[str, Tensor], greedy: bool = False) -> Sequence[Tensor]: + feat = self.feature_extractor(obs) + actor_out: List[Tensor] = self.actor(feat) + if self.actor.is_continuous: + mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1) + if greedy: + actions = mean + else: + std = log_std.exp() + normal = Independent(Normal(mean, std), 1) + actions = normal.sample() + return tuple([actions]) + else: + actions: List[Tensor] = [] + actions_dist: List[Distribution] = [] + for logits in actor_out: + actions_dist.append(OneHotCategorical(logits=logits)) + if greedy: + actions.append(actions_dist[-1].mode) + else: + actions.append(actions_dist[-1].sample()) + return tuple(actions) + + +def build_agent( + fabric: Fabric, + actions_dim: Sequence[int], + is_continuous: bool, + cfg: Dict[str, Any], + obs_space: gymnasium.spaces.Dict, + agent_state: Optional[Dict[str, Tensor]] = None, +) -> Tuple[PPOAgent, PPOPlayer]: + agent = PPOAgent( + actions_dim=actions_dim, + obs_space=obs_space, + encoder_cfg=cfg.algo.encoder, + actor_cfg=cfg.algo.actor, + critic_cfg=cfg.algo.critic, + cnn_keys=cfg.algo.cnn_keys.encoder, + mlp_keys=cfg.algo.mlp_keys.encoder, + screen_size=cfg.env.screen_size, + distribution_cfg=cfg.distribution, + is_continuous=is_continuous, + ) + if agent_state: + agent.load_state_dict(agent_state) + + # Setup player agent + player = PPOPlayer(copy.deepcopy(agent.feature_extractor), copy.deepcopy(agent.actor), copy.deepcopy(agent.critic)) + + # Setup training agent + agent.feature_extractor = fabric.setup_module(agent.feature_extractor) + agent.critic = fabric.setup_module(agent.critic) + agent.actor = fabric.setup_module(agent.actor) + + # Setup player agent + fabric_player = get_single_device_fabric(fabric) + player.feature_extractor = fabric_player.setup_module(player.feature_extractor) + player.critic = fabric_player.setup_module(player.critic) + player.actor = fabric_player.setup_module(player.actor) + + # Tie weights between the agent and the player + for agent_p, player_p in zip(agent.feature_extractor.parameters(), player.feature_extractor.parameters()): + player_p.data = agent_p.data + for agent_p, player_p in zip(agent.actor.parameters(), 
player.actor.parameters()): + player_p.data = agent_p.data + for agent_p, player_p in zip(agent.critic.parameters(), player.critic.parameters()): + player_p.data = agent_p.data + return agent, player + ## Loss functions All the loss functions to be optimized by the agent during the training should be defined under the `loss.py` file, even though is not strictly necessary: diff --git a/howto/register_new_algorithm.md b/howto/register_new_algorithm.md index b3b5f6c8..362e21b5 100644 --- a/howto/register_new_algorithm.md +++ b/howto/register_new_algorithm.md @@ -16,12 +16,13 @@ algos ``` ## The agent -The agent is the core of the algorithm and it is defined in the `agent.py` file. It must contain at least single function called `build_agent` that returns a `torch.nn.Module` wrapped with Fabric: +The agent is the core of the algorithm and it is defined in the `agent.py` file. It must contain at least single function called `build_agent` that returns at least a tuple composed of two `torch.nn.Module` wrapped with Fabric; the first one is the agent used during the training phase, while the other one is the one used during the environment interaction: ```python from __future__ import annotations -from typing import Any, Dict, Sequence +import copy +from typing import Any, Dict, Sequence, Tuple import gymnasium from lightning import Fabric @@ -29,6 +30,8 @@ from lightning.fabric.wrappers import _FabricModule import torch from torch import Tensor +from sheeprl.utils.fabric import get_single_device_fabric + class SOTAAgent(torch.nn.Module): def __init__(self, ...): @@ -37,6 +40,16 @@ class SOTAAgent(torch.nn.Module): def forward(self, obs: Dict[str, torch.Tensor]) -> Tensor: ... +class SOTAAgentPlayer(torch.nn.Module): + def __init__(self, ...): + ... + + def forward(self, obs: Dict[str, torch.Tensor]) -> Tensor: + ... + + def get_actions(self, obs: Dict[str, torch.Tensor], greedy: bool = False) -> Tensor: + ... + def build_agent( fabric: Fabric, @@ -45,7 +58,7 @@ def build_agent( cfg: Dict[str, Any], observation_space: gymnasium.spaces.Dict, state: Dict[str, Any] | None = None, -) -> _FabricModule: +) -> Tuple[_FabricModule, _FabricModule]: # Define the agent here agent = SOTAAgent(...) @@ -54,10 +67,324 @@ def build_agent( if state: agent.load_state_dict(state) + # Setup player agent + player = copy.deepcopy(agent) + # Setup the agent with Fabric agent = fabric.setup_model(agent) - return agent + # Setup the player agent with a single-device Fabric + fabric_player = get_single_device_fabric(fabric) + player = fabric_player.setup_module(player) + + # Tie weights between the agent and the player + for agent_p, player_p in zip(agent.parameters(), player.parameters()): + player_p.data = agent_p.data + return agent, player +``` + +The player agent is wrapped with a **single-device Fabric**, in this way we maintain the same precision and device of the main Fabric object, but with the player agent being able to interct with the environment skipping possible distributed synchronization points. + +If the agent is composed of multiple models, each one with its own forward method, it is advisable to wrap each one of them with the main Fabric object; the same happens for the player agent, where each of the models has to be consequently wrapped with the single-device Fabric obejct. 
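+
+As a minimal, self-contained sketch (plain PyTorch, not SheepRL code), the snippet below shows why the tying loop `player_p.data = agent_p.data` keeps the player synchronized with the trained agent: after the assignment both modules point to the same underlying tensors, so every optimizer step performed on the agent is immediately visible to the player without any explicit copy.
+
+```python
+import copy
+
+import torch
+import torch.nn as nn
+
+agent = nn.Linear(4, 2)
+player = copy.deepcopy(agent)
+
+# Tie weights: the player's parameters now share storage with the agent's ones
+for agent_p, player_p in zip(agent.parameters(), player.parameters()):
+    player_p.data = agent_p.data
+
+optimizer = torch.optim.SGD(agent.parameters(), lr=0.1)
+loss = agent(torch.randn(8, 4)).sum()
+loss.backward()
+optimizer.step()
+
+# Same underlying storage, so the player sees the updated weights with no synchronization
+assert agent.weight.data_ptr() == player.weight.data_ptr()
+assert torch.equal(agent.weight, player.weight)
+```
+
+The same reasoning applies when the agent is split into several sub-modules: each pair of corresponding parameters has to be tied, as done for the `feature_extractor`, `actor` and `critic` of the PPO agent.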
Here we have the example of the **PPOAgent**: + +```python +from __future__ import annotations + +import copy +from math import prod +from typing import Any, Dict, List, Optional, Sequence, Tuple + +import gymnasium +import torch +import torch.nn as nn +from lightning import Fabric +from torch import Tensor +from torch.distributions import Distribution, Independent, Normal, OneHotCategorical + +from sheeprl.models.models import MLP, MultiEncoder, NatureCNN +from sheeprl.utils.fabric import get_single_device_fabric + + +class CNNEncoder(nn.Module): + def __init__( + self, + in_channels: int, + features_dim: int, + screen_size: int, + keys: Sequence[str], + ) -> None: + super().__init__() + self.keys = keys + self.input_dim = (in_channels, screen_size, screen_size) + self.output_dim = features_dim + self.model = NatureCNN(in_channels=in_channels, features_dim=features_dim, screen_size=screen_size) + + def forward(self, obs: Dict[str, Tensor]) -> Tensor: + x = torch.cat([obs[k] for k in self.keys], dim=-3) + return self.model(x) + + +class MLPEncoder(nn.Module): + def __init__( + self, + input_dim: int, + features_dim: int | None, + keys: Sequence[str], + dense_units: int = 64, + mlp_layers: int = 2, + dense_act: nn.Module = nn.ReLU, + layer_norm: bool = False, + ) -> None: + super().__init__() + self.keys = keys + self.input_dim = input_dim + self.output_dim = features_dim if features_dim else dense_units + self.model = MLP( + input_dim, + features_dim, + [dense_units] * mlp_layers, + activation=dense_act, + norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None, + norm_args=[{"normalized_shape": dense_units} for _ in range(mlp_layers)] if layer_norm else None, + ) + + def forward(self, obs: Dict[str, Tensor]) -> Tensor: + x = torch.cat([obs[k] for k in self.keys], dim=-1) + return self.model(x) + + +class PPOActor(nn.Module): + def __init__(self, actor_backbone: torch.nn.Module, actor_heads: torch.nn.ModuleList, is_continuous: bool) -> None: + super().__init__() + self.actor_backbone = actor_backbone + self.actor_heads = actor_heads + self.is_continuous = is_continuous + + def forward(self, x: Tensor) -> List[Tensor]: + x = self.actor_backbone(x) + return [head(x) for head in self.actor_heads] + + +class PPOAgent(nn.Module): + def __init__( + self, + actions_dim: Sequence[int], + obs_space: gymnasium.spaces.Dict, + encoder_cfg: Dict[str, Any], + actor_cfg: Dict[str, Any], + critic_cfg: Dict[str, Any], + cnn_keys: Sequence[str], + mlp_keys: Sequence[str], + screen_size: int, + distribution_cfg: Dict[str, Any], + is_continuous: bool = False, + ): + super().__init__() + self.is_continuous = is_continuous + self.distribution_cfg = distribution_cfg + self.actions_dim = actions_dim + in_channels = sum([prod(obs_space[k].shape[:-2]) for k in cnn_keys]) + mlp_input_dim = sum([obs_space[k].shape[0] for k in mlp_keys]) + cnn_encoder = ( + CNNEncoder(in_channels, encoder_cfg.cnn_features_dim, screen_size, cnn_keys) + if cnn_keys is not None and len(cnn_keys) > 0 + else None + ) + mlp_encoder = ( + MLPEncoder( + mlp_input_dim, + encoder_cfg.mlp_features_dim, + mlp_keys, + encoder_cfg.dense_units, + encoder_cfg.mlp_layers, + eval(encoder_cfg.dense_act), + encoder_cfg.layer_norm, + ) + if mlp_keys is not None and len(mlp_keys) > 0 + else None + ) + self.feature_extractor = MultiEncoder(cnn_encoder, mlp_encoder) + features_dim = self.feature_extractor.output_dim + self.critic = MLP( + input_dims=features_dim, + output_dim=1, + hidden_sizes=[critic_cfg.dense_units] * 
critic_cfg.mlp_layers, + activation=eval(critic_cfg.dense_act), + norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None, + norm_args=( + [{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)] + if critic_cfg.layer_norm + else None + ), + ) + actor_backbone = ( + MLP( + input_dims=features_dim, + output_dim=None, + hidden_sizes=[actor_cfg.dense_units] * actor_cfg.mlp_layers, + activation=eval(actor_cfg.dense_act), + flatten_dim=None, + norm_layer=[nn.LayerNorm] * actor_cfg.mlp_layers if actor_cfg.layer_norm else None, + norm_args=( + [{"normalized_shape": actor_cfg.dense_units} for _ in range(actor_cfg.mlp_layers)] + if actor_cfg.layer_norm + else None + ), + ) + if actor_cfg.mlp_layers > 0 + else nn.Identity() + ) + if is_continuous: + actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)]) + else: + actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim]) + self.actor = PPOActor(actor_backbone, actor_heads, is_continuous) + + def forward( + self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None + ) -> Tuple[Sequence[Tensor], Tensor, Tensor, Tensor]: + feat = self.feature_extractor(obs) + actor_out: List[Tensor] = self.actor(feat) + values = self.critic(feat) + if self.is_continuous: + mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1) + std = log_std.exp() + normal = Independent(Normal(mean, std), 1) + if actions is None: + actions = normal.sample() + else: + # always composed by a tuple of one element containing all the + # continuous actions + actions = actions[0] + log_prob = normal.log_prob(actions) + return tuple([actions]), log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1), values + else: + should_append = False + actions_logprobs: List[Tensor] = [] + actions_entropies: List[Tensor] = [] + actions_dist: List[Distribution] = [] + if actions is None: + should_append = True + actions: List[Tensor] = [] + for i, logits in enumerate(actor_out): + actions_dist.append(OneHotCategorical(logits=logits)) + actions_entropies.append(actions_dist[-1].entropy()) + if should_append: + actions.append(actions_dist[-1].sample()) + actions_logprobs.append(actions_dist[-1].log_prob(actions[i])) + return ( + tuple(actions), + torch.stack(actions_logprobs, dim=-1).sum(dim=-1, keepdim=True), + torch.stack(actions_entropies, dim=-1).sum(dim=-1, keepdim=True), + values, + ) + + +class PPOPlayer(nn.Module): + def __init__(self, feature_extractor: MultiEncoder, actor: PPOActor, critic: nn.Module) -> None: + super().__init__() + self.feature_extractor = feature_extractor + self.critic = critic + self.actor = actor + + def forward(self, obs: Dict[str, Tensor]) -> Tuple[Sequence[Tensor], Tensor, Tensor]: + feat = self.feature_extractor(obs) + values = self.critic(feat) + actor_out: List[Tensor] = self.actor(feat) + if self.actor.is_continuous: + mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1) + std = log_std.exp() + normal = Independent(Normal(mean, std), 1) + actions = normal.sample() + log_prob = normal.log_prob(actions) + return tuple([actions]), log_prob.unsqueeze(dim=-1), values + else: + actions_dist: List[Distribution] = [] + actions_logprobs: List[Tensor] = [] + actions: List[Tensor] = [] + for i, logits in enumerate(actor_out): + actions_dist.append(OneHotCategorical(logits=logits)) + actions.append(actions_dist[-1].sample()) + actions_logprobs.append(actions_dist[-1].log_prob(actions[i])) + return ( + 
tuple(actions), + torch.stack(actions_logprobs, dim=-1).sum(dim=-1, keepdim=True), + values, + ) + + def get_values(self, obs: Dict[str, Tensor]) -> Tensor: + feat = self.feature_extractor(obs) + return self.critic(feat) + + def get_actions(self, obs: Dict[str, Tensor], greedy: bool = False) -> Sequence[Tensor]: + feat = self.feature_extractor(obs) + actor_out: List[Tensor] = self.actor(feat) + if self.actor.is_continuous: + mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1) + if greedy: + actions = mean + else: + std = log_std.exp() + normal = Independent(Normal(mean, std), 1) + actions = normal.sample() + return tuple([actions]) + else: + actions: List[Tensor] = [] + actions_dist: List[Distribution] = [] + for logits in actor_out: + actions_dist.append(OneHotCategorical(logits=logits)) + if greedy: + actions.append(actions_dist[-1].mode) + else: + actions.append(actions_dist[-1].sample()) + return tuple(actions) + + +def build_agent( + fabric: Fabric, + actions_dim: Sequence[int], + is_continuous: bool, + cfg: Dict[str, Any], + obs_space: gymnasium.spaces.Dict, + agent_state: Optional[Dict[str, Tensor]] = None, +) -> Tuple[PPOAgent, PPOPlayer]: + agent = PPOAgent( + actions_dim=actions_dim, + obs_space=obs_space, + encoder_cfg=cfg.algo.encoder, + actor_cfg=cfg.algo.actor, + critic_cfg=cfg.algo.critic, + cnn_keys=cfg.algo.cnn_keys.encoder, + mlp_keys=cfg.algo.mlp_keys.encoder, + screen_size=cfg.env.screen_size, + distribution_cfg=cfg.distribution, + is_continuous=is_continuous, + ) + if agent_state: + agent.load_state_dict(agent_state) + + # Setup player agent + player = PPOPlayer(copy.deepcopy(agent.feature_extractor), copy.deepcopy(agent.actor), copy.deepcopy(agent.critic)) + + # Setup training agent + agent.feature_extractor = fabric.setup_module(agent.feature_extractor) + agent.critic = fabric.setup_module(agent.critic) + agent.actor = fabric.setup_module(agent.actor) + + # Setup player agent + fabric_player = get_single_device_fabric(fabric) + player.feature_extractor = fabric_player.setup_module(player.feature_extractor) + player.critic = fabric_player.setup_module(player.critic) + player.actor = fabric_player.setup_module(player.actor) + + # Tie weights between the agent and the player + for agent_p, player_p in zip(agent.feature_extractor.parameters(), player.feature_extractor.parameters()): + player_p.data = agent_p.data + for agent_p, player_p in zip(agent.actor.parameters(), player.actor.parameters()): + player_p.data = agent_p.data + for agent_p, player_p in zip(agent.critic.parameters(), player.critic.parameters()): + player_p.data = agent_p.data + return agent, player ``` ## Loss functions @@ -181,7 +508,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): # Given that the environment has been created with the `make_env` method, the agent # forward method must accept as input a dictionary like {"obs1_name": obs1, "obs2_name": obs2, ...}. # The agent should be able to process both image and vector-like observations. 
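+    # build_agent returns both the Fabric-wrapped training agent and the
+    # single-device player used to interact with the environment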
- agent = build_agent( + agent, player = build_agent( fabric, actions_dim, is_continuous, @@ -273,7 +600,7 @@ def sota_main(fabric: Fabric, cfg: Dict[str, Any]): torch_obs = { k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys } - actions = agent.module(torch_obs) + actions = player.get_actions(torch_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() else: @@ -392,7 +719,7 @@ import torch from lightning import Fabric from lightning.fabric.wrappers import _FabricModule -from sheeprl.algos.sota.agent import SOTAAgent +from sheeprl.algos.sota.agent import SOTAAgentPlayer from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE from sheeprl.utils.utils import unwrap_fabric @@ -401,7 +728,7 @@ if TYPE_CHECKING: @torch.no_grad() -def test(agent: SOTAAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): +def test(agent: SOTAAgentPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() agent.eval() done = False @@ -419,10 +746,11 @@ def test(agent: SOTAAgent, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): while not done: # Act greedly through the environment + actions = agent.get_actions(obs, greedy=True) if agent.is_continuous: - actions = torch.cat(agent.get_greedy_actions(obs), dim=-1) + actions = torch.cat(actions, dim=-1) else: - actions = torch.cat([act.argmax(dim=-1) for act in agent.get_greedy_actions(obs)], dim=-1) + actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1) # Single environment step o, reward, done, truncated, _ = env.step(actions.cpu().numpy().reshape(env.action_space.shape)) diff --git a/notebooks/dreamer_v3_imagination.ipynb b/notebooks/dreamer_v3_imagination.ipynb index c58531f2..3ae85a77 100644 --- a/notebooks/dreamer_v3_imagination.ipynb +++ b/notebooks/dreamer_v3_imagination.ipynb @@ -51,7 +51,7 @@ "from omegaconf import OmegaConf\n", "from PIL import Image\n", "\n", - "from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_agent\n", + "from sheeprl.algos.dreamer_v3.agent import build_agent\n", "from sheeprl.data.buffers import SequentialReplayBuffer\n", "from sheeprl.utils.env import make_env\n", "from sheeprl.utils.utils import dotdict" @@ -128,7 +128,7 @@ "actions_dim = tuple(\n", " action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])\n", ")\n", - "world_model, actor, critic, critic_target = build_agent(\n", + "world_model, actor, critic, critic_target, player = build_agent(\n", " fabric,\n", " actions_dim,\n", " is_continuous,\n", @@ -138,17 +138,6 @@ " state[\"actor\"],\n", " state[\"critic\"],\n", " state[\"target_critic\"],\n", - ")\n", - "player = PlayerDV3(\n", - " world_model.encoder.module,\n", - " world_model.rssm,\n", - " actor.module,\n", - " actions_dim,\n", - " cfg.env.num_envs,\n", - " cfg.algo.world_model.stochastic_size,\n", - " cfg.algo.world_model.recurrent_model.recurrent_state_size,\n", - " fabric.device,\n", - " cfg.algo.world_model.discrete_size,\n", ")" ] }, @@ -230,7 +219,7 @@ " mask = {k: v for k, v in preprocessed_obs.items() if k.startswith(\"mask\")}\n", " if len(mask) == 0:\n", " mask = None\n", - " real_actions = actions = player.get_actions(preprocessed_obs, mask)\n", + " real_actions = actions = player.get_actions(preprocessed_obs, mask=mask)\n", " actions = torch.cat(actions, -1).cpu().numpy()\n", " if is_continuous:\n", " real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()\n", diff --git 
a/sheeprl/algos/a2c/a2c.py b/sheeprl/algos/a2c/a2c.py index 8d45d0ed..6d0cdfd4 100644 --- a/sheeprl/algos/a2c/a2c.py +++ b/sheeprl/algos/a2c/a2c.py @@ -10,7 +10,7 @@ from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler from torchmetrics import SumMetric -from sheeprl.algos.a2c.agent import build_agent +from sheeprl.algos.a2c.agent import A2CAgent, build_agent from sheeprl.algos.a2c.loss import policy_loss, value_loss from sheeprl.algos.a2c.utils import test from sheeprl.data import ReplayBuffer @@ -24,7 +24,7 @@ def train( fabric: Fabric, - agent: torch.nn.Module, + agent: A2CAgent, optimizer: torch.optim.Optimizer, data: Dict[str, torch.Tensor], aggregator: MetricAggregator, @@ -67,7 +67,9 @@ def train( # is_accumulating is True for every i except for the last one is_accumulating = i < len(sampler) - 1 - with fabric.no_backward_sync(agent, enabled=is_accumulating): + with fabric.no_backward_sync(agent.feature_extractor, enabled=is_accumulating), fabric.no_backward_sync( + agent.actor, enabled=is_accumulating + ), fabric.no_backward_sync(agent.critic, enabled=is_accumulating): _, logprobs, values = agent(obs, torch.split(batch["actions"], agent.actions_dim, dim=-1)) # Policy loss @@ -262,10 +264,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_v = torch_v.view(-1, *v.shape[-2:]) torch_v = torch_v / 255.0 - 0.5 real_next_obs[k][i] = torch_v - _, _, vals = player(real_next_obs) - rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape( - rewards[truncated_envs].shape - ) + vals = player.get_values(real_next_obs).cpu().numpy() + rewards[truncated_envs] += cfg.algo.gamma * vals.reshape(rewards[truncated_envs].shape) dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) @@ -305,7 +305,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) with torch.inference_mode(): torch_obs = {k: torch.as_tensor(next_obs[k], dtype=torch.float32, device=device) for k in obs_keys} - _, _, next_values = player(torch_obs) + next_values = player.get_values(torch_obs) returns, advantages = gae( local_data["rewards"].to(torch.float64), local_data["values"], diff --git a/sheeprl/algos/a2c/agent.py b/sheeprl/algos/a2c/agent.py index 74fea68c..6cba771d 100644 --- a/sheeprl/algos/a2c/agent.py +++ b/sheeprl/algos/a2c/agent.py @@ -7,11 +7,10 @@ import torch import torch.nn as nn from lightning import Fabric -from lightning.fabric.wrappers import _FabricModule from torch import Tensor from torch.distributions import Distribution, Independent, Normal, OneHotCategorical -from sheeprl.algos.ppo.agent import PPOActor +from sheeprl.algos.ppo.agent import PPOActor, PPOPlayer from sheeprl.models.models import MLP from sheeprl.utils.fabric import get_single_device_fabric @@ -161,7 +160,7 @@ def build_agent( cfg: Dict[str, Any], obs_space: gymnasium.spaces.Dict, agent_state: Optional[Dict[str, Tensor]] = None, -) -> Tuple[_FabricModule, _FabricModule]: +) -> Tuple[A2CAgent, PPOPlayer]: agent = A2CAgent( actions_dim=actions_dim, obs_space=obs_space, @@ -174,16 +173,26 @@ def build_agent( ) if agent_state: agent.load_state_dict(agent_state) - player = copy.deepcopy(agent) + + # Setup player agent + player = PPOPlayer(copy.deepcopy(agent.feature_extractor), copy.deepcopy(agent.actor), copy.deepcopy(agent.critic)) # Setup training agent - agent = fabric.setup_module(agent) + agent.feature_extractor = fabric.setup_module(agent.feature_extractor) 
+ agent.critic = fabric.setup_module(agent.critic) + agent.actor = fabric.setup_module(agent.actor) # Setup player agent fabric_player = get_single_device_fabric(fabric) - player = fabric_player.setup_module(player) + player.feature_extractor = fabric_player.setup_module(player.feature_extractor) + player.critic = fabric_player.setup_module(player.critic) + player.actor = fabric_player.setup_module(player.actor) # Tie weights between the agent and the player - for agent_p, player_p in zip(agent.parameters(), player.parameters()): + for agent_p, player_p in zip(agent.feature_extractor.parameters(), player.feature_extractor.parameters()): + player_p.data = agent_p.data + for agent_p, player_p in zip(agent.actor.parameters(), player.actor.parameters()): + player_p.data = agent_p.data + for agent_p, player_p in zip(agent.critic.parameters(), player.critic.parameters()): player_p.data = agent_p.data return agent, player diff --git a/sheeprl/algos/a2c/evaluate.py b/sheeprl/algos/a2c/evaluate.py index c3ee6a3b..0c2ecd9f 100644 --- a/sheeprl/algos/a2c/evaluate.py +++ b/sheeprl/algos/a2c/evaluate.py @@ -54,5 +54,6 @@ def evaluate_a2c(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) ) # Create the actor and critic models - agent, _ = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) + _, agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) + del _ test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/a2c/utils.py b/sheeprl/algos/a2c/utils.py index 23dd4bf6..39fb3e93 100644 --- a/sheeprl/algos/a2c/utils.py +++ b/sheeprl/algos/a2c/utils.py @@ -4,16 +4,15 @@ import torch from lightning import Fabric -from lightning.fabric.wrappers import _FabricModule -from sheeprl.algos.a2c.agent import A2CAgent +from sheeprl.algos.ppo.agent import PPOPlayer from sheeprl.utils.env import make_env AGGREGATOR_KEYS = {"Rewards/rew_avg", "Game/ep_len_avg", "Loss/value_loss", "Loss/policy_loss"} @torch.no_grad() -def test(agent: A2CAgent | _FabricModule, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): +def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() agent.eval() done = False @@ -28,8 +27,8 @@ def test(agent: A2CAgent | _FabricModule, fabric: Fabric, cfg: Dict[str, Any], l while not done: # Act greedly through the environment - actions, _, _ = agent(obs, greedy=True) - if agent.is_continuous: + actions = agent.get_actions(obs, greedy=True) + if agent.actor.is_continuous: actions = torch.cat(actions, dim=-1) else: actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1) diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index 94353e6e..a654e135 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy from typing import Any, Dict, Optional, Sequence, Tuple import gymnasium @@ -219,7 +220,6 @@ class PlayerDV1(nn.Module): """The model of the DreamerV1 player. Args: - fabric (Fabric): the fabric object. encoder (nn.Module| _FabricModule): the encoder. recurrent_model (nn.Module| _FabricModule): the recurrent model. representation_model (nn.Module| _FabricModule): the representation model. @@ -228,44 +228,35 @@ class PlayerDV1(nn.Module): num_envs (int): the number of environments. 
stochastic_size (int): the size of the stochastic state. recurrent_state_size (int): the size of the recurrent state. + device (str | torch.device): the device where the model is stored. actor_type (str, optional): which actor the player is using ('task' or 'exploration'). Default to None. """ def __init__( self, - fabric: Fabric, - encoder: nn.Module | _FabricModule, - recurrent_model: nn.Module | _FabricModule, - representation_model: nn.Module | _FabricModule, + encoder: MultiEncoder | _FabricModule, + recurrent_model: RecurrentModel | _FabricModule, + representation_model: MLP | _FabricModule, actor: Actor | _FabricModule, actions_dim: Sequence[int], num_envs: int, stochastic_size: int, recurrent_state_size: int, + device: str | torch.device, actor_type: str | None = None, ) -> None: super().__init__() - single_device_fabric = get_single_device_fabric(fabric) - self.encoder = single_device_fabric.setup_module( - getattr(encoder, "module", encoder), - ) - self.recurrent_model = single_device_fabric.setup_module( - getattr(recurrent_model, "module", recurrent_model), - ) - self.representation_model = single_device_fabric.setup_module( - getattr(representation_model, "module", representation_model) - ) - self.actor = single_device_fabric.setup_module( - getattr(actor, "module", actor), - ) - self.device = single_device_fabric.device + self.encoder = encoder + self.recurrent_model = recurrent_model + self.representation_model = representation_model + self.actor = actor self.actions_dim = actions_dim + self.num_envs = num_envs self.stochastic_size = stochastic_size self.recurrent_state_size = recurrent_state_size - self.num_envs = num_envs + self.device = device self.actor_type = actor_type - self.init_states() def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: """Initialize the states and the actions for the ended environments. @@ -285,14 +276,14 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: self.stochastic_state[:, reset_envs] = torch.zeros_like(self.stochastic_state[:, reset_envs]) def get_exploration_actions( - self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None, step: int = 0 + self, obs: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None, step: int = 0 ) -> Sequence[Tensor]: """Return the actions with a certain amount of noise for exploration. Args: obs (Tensor): the current observations. - sample_actions (bool): whether or not to sample the actions. - Default to True. + greedy (bool): whether or not to sample the actions. + Default to False. mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed). Defaults to None. step (int): the step of the training, used for the exploration amount. @@ -301,7 +292,7 @@ def get_exploration_actions( Returns: The actions the agent has to perform (Sequence[Tensor]). """ - actions = self.get_actions(obs, sample_actions=sample_actions, mask=mask) + actions = self.get_actions(obs, greedy=greedy, mask=mask) expl_actions = None if self.actor._expl_amount > 0: expl_actions = self.actor.add_exploration_noise(actions, step=step, mask=mask) @@ -309,14 +300,14 @@ def get_exploration_actions( return expl_actions or actions def get_actions( - self, obs: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None + self, obs: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None ) -> Sequence[Tensor]: """Return the greedy actions. Args: obs (Tensor): the current observations. 
- sample_actions (bool): whether or not to sample the actions. - Default to True. + greedy (bool): whether or not to sample the actions. + Default to False. mask (Dict[str, Tensor], optional): the action mask (whether or not each action can be executed). Defaults to None. @@ -330,7 +321,7 @@ def get_actions( _, self.stochastic_state = compute_stochastic_state( self.representation_model(torch.cat((self.recurrent_state, embedded_obs), -1)), ) - actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask) + actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), greedy, mask) self.actions = torch.cat(actions, -1) return actions @@ -344,7 +335,7 @@ def build_agent( world_model_state: Optional[Dict[str, Tensor]] = None, actor_state: Optional[Dict[str, Tensor]] = None, critic_state: Optional[Dict[str, Tensor]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule]: +) -> Tuple[WorldModel, _FabricModule, _FabricModule, PlayerDV1]: """Build the models and wrap them with Fabric. Args: @@ -365,6 +356,7 @@ def build_agent( reward models and the continue model. The actor (_FabricModule). The critic (_FabricModule). + The player (PlayerDV1). """ world_model_cfg = cfg.algo.world_model actor_cfg = cfg.algo.actor @@ -511,6 +503,20 @@ def build_agent( if critic_state: critic.load_state_dict(critic_state) + # Create the player agent + fabric_player = get_single_device_fabric(fabric) + player = PlayerDV1( + copy.deepcopy(world_model.encoder), + copy.deepcopy(world_model.rssm.recurrent_model), + copy.deepcopy(world_model.rssm.representation_model), + copy.deepcopy(actor), + actions_dim, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric_player.device, + ) + # Setup models with Fabric world_model.encoder = fabric.setup_module(world_model.encoder) world_model.observation_model = fabric.setup_module(world_model.observation_model) @@ -523,4 +529,19 @@ def build_agent( actor = fabric.setup_module(actor) critic = fabric.setup_module(critic) - return world_model, actor, critic + # Setup the player agent with a single-device Fabric + player.encoder = fabric_player.setup_module(player.encoder) + player.recurrent_model = fabric_player.setup_module(player.recurrent_model) + player.representation_model = fabric_player.setup_module(player.representation_model) + player.actor = fabric_player.setup_module(player.actor) + + # Tie weights between the agent and the player + for agent_p, p in zip(world_model.encoder.parameters(), player.encoder.parameters()): + p.data = agent_p.data + for agent_p, p in zip(world_model.rssm.recurrent_model.parameters(), player.recurrent_model.parameters()): + p.data = agent_p.data + for agent_p, p in zip(world_model.rssm.representation_model.parameters(), player.representation_model.parameters()): + p.data = agent_p.data + for agent_p, p in zip(actor.parameters(), player.actor.parameters()): + p.data = agent_p.data + return world_model, actor, critic, player diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 381cc2b6..664769b6 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -17,7 +17,7 @@ from torch.distributions.utils import logits_to_probs from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_agent +from sheeprl.algos.dreamer_v1.agent import WorldModel, build_agent from 
sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss from sheeprl.algos.dreamer_v1.utils import compute_lambda_values from sheeprl.algos.dreamer_v2.utils import test @@ -442,7 +442,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - world_model, actor, critic = build_agent( + world_model, actor, critic, player = build_agent( fabric, actions_dim, is_continuous, @@ -452,17 +452,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["actor"] if cfg.checkpoint.resume_from else None, state["critic"] if cfg.checkpoint.resume_from else None, ) - player = PlayerDV1( - fabric, - world_model.encoder, - world_model.rssm.recurrent_model, - world_model.rssm.representation_model, - actor, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - ) # Optimizers world_optimizer = hydra.utils.instantiate( diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py index 4481a501..ae16efac 100644 --- a/sheeprl/algos/dreamer_v1/evaluate.py +++ b/sheeprl/algos/dreamer_v1/evaluate.py @@ -5,7 +5,7 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.dreamer_v1.agent import PlayerDV1, build_agent +from sheeprl.algos.dreamer_v1.agent import build_agent from sheeprl.algos.dreamer_v2.utils import test from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -44,7 +44,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor, _ = build_agent( + _, _, _, player = build_agent( fabric, actions_dim, is_continuous, @@ -53,16 +53,5 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): state["world_model"], state["actor"], ) - player = PlayerDV1( - fabric, - world_model.encoder, - world_model.rssm.recurrent_model, - world_model.rssm.representation_model, - actor, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - ) - - test(player, fabric, cfg, log_dir, sample_actions=False) + del _ + test(player, fabric, cfg, log_dir, greedy=True) diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index 120b9d7f..0c580c62 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -503,7 +503,7 @@ def _get_expl_amount(self, step: int) -> Tensor: return max(amount, self._expl_min) def forward( - self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None + self, state: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -511,8 +511,8 @@ def forward( Args: state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). - sample_actions (bool): whether or not to sample the actions. - Default to True. + greedy (bool): whether or not to sample the actions. + Default to False. mask (Dict[str, Tensor], optional): the action mask (which actions can be selected). Default to None. 
@@ -536,7 +536,7 @@ def forward( std = 2 * torch.sigmoid((std + self.init_std) / 2) + self.min_std dist = TruncatedNormal(torch.tanh(mean), std, -1, 1) actions_dist = Independent(dist, 1) - if sample_actions: + if not greedy: actions = actions_dist.rsample() else: sample = actions_dist.sample((100,)) @@ -549,7 +549,7 @@ def forward( actions: List[Tensor] = [] for logits in pre_dist: actions_dist.append(OneHotCategoricalStraightThrough(logits=logits)) - if sample_actions: + if not greedy: actions.append(actions_dist[-1].rsample()) else: actions.append(actions_dist[-1].mode) @@ -608,7 +608,7 @@ def __init__( ) def forward( - self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None + self, state: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -616,8 +616,8 @@ def forward( Args: state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). - sample_actions (bool): whether or not to sample the actions. - Default to True. + greedy (bool): whether or not to sample the actions. + Default to False. mask (Dict[str, Tensor], optional): the action mask (which actions can be selected). Default to None. @@ -652,7 +652,7 @@ def forward( elif sampled_action == 18: # Destroy action logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf actions_dist.append(OneHotCategoricalStraightThrough(logits=logits)) - if sample_actions: + if not greedy: actions.append(actions_dist[-1].rsample()) else: actions.append(actions_dist[-1].mode) @@ -737,7 +737,6 @@ class PlayerDV2(nn.Module): The model of the Dreamer_v2 player. Args: - fabric: the fabric of the model. encoder (nn.Module | _FabricModule): the encoder. recurrent_model (nn.Module | _FabricModule): the recurrent model. representation_model (nn.Module | _FabricModule): the representation model. @@ -746,6 +745,7 @@ class PlayerDV2(nn.Module): num_envs (int): the number of environments. stochastic_size (int): the size of the stochastic state. recurrent_state_size (int): the size of the recurrent state. + device (str | torch.device): the device where the model is stored. discrete_size (int): the dimension of a single Categorical variable in the stochastic state (prior or posterior). Defaults to 32. 
@@ -755,7 +755,6 @@ class PlayerDV2(nn.Module): def __init__( self, - fabric: Fabric, encoder: nn.Module | _FabricModule, recurrent_model: nn.Module | _FabricModule, representation_model: nn.Module | _FabricModule, @@ -764,29 +763,21 @@ def __init__( num_envs: int, stochastic_size: int, recurrent_state_size: int, + device: str | torch.device, discrete_size: int = 32, actor_type: str | None = None, ) -> None: super().__init__() - fabric_player = get_single_device_fabric(fabric) - self.encoder = fabric_player.setup_module( - getattr(encoder, "module", encoder), - ) - self.recurrent_model = fabric_player.setup_module( - getattr(recurrent_model, "module", recurrent_model), - ) - self.representation_model = fabric_player.setup_module( - getattr(representation_model, "module", representation_model), - ) - self.actor = fabric_player.setup_module( - getattr(actor, "module", actor), - ) - self.device = fabric_player.device + self.encoder = encoder + self.recurrent_model = recurrent_model + self.representation_model = representation_model + self.actor = actor self.actions_dim = actions_dim + self.num_envs = num_envs self.stochastic_size = stochastic_size - self.discrete_size = discrete_size self.recurrent_state_size = recurrent_state_size - self.num_envs = num_envs + self.device = device + self.discrete_size = discrete_size self.actor_type = actor_type def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: @@ -811,7 +802,7 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: def get_actions( self, obs: Dict[str, Tensor], - sample_actions: bool = True, + greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None, ) -> Sequence[Tensor]: """ @@ -819,8 +810,8 @@ def get_actions( Args: obs (Dict[str, Tensor]): the current observations. - sample_actions (bool): whether or not to sample the actions. - Default to True. + greedy (bool): whether or not to sample the actions. + Default to False. mask (Dict[str, Tensor], optional): the action mask (which actions can be selected). Default to None. @@ -836,7 +827,7 @@ def get_actions( self.stochastic_state = stochastic_state.view( *stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size ) - actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask) + actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), greedy, mask) self.actions = torch.cat(actions, -1) return actions @@ -851,7 +842,7 @@ def build_agent( actor_state: Optional[Dict[str, Tensor]] = None, critic_state: Optional[Dict[str, Tensor]] = None, target_critic_state: Optional[Dict[str, Tensor]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule]: +) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, PlayerDV2]: """Build the models and wrap them with Fabric. 
Args: @@ -1062,6 +1053,21 @@ def build_agent( if critic_state: critic.load_state_dict(critic_state) + # Create the player agent + fabric_player = get_single_device_fabric(fabric) + player = PlayerDV2( + copy.deepcopy(world_model.encoder), + copy.deepcopy(world_model.rssm.recurrent_model), + copy.deepcopy(world_model.rssm.representation_model), + copy.deepcopy(actor), + actions_dim, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric_player.device, + discrete_size=cfg.algo.world_model.discrete_size, + ) + # Setup models with Fabric world_model.encoder = fabric.setup_module(world_model.encoder) world_model.observation_model = fabric.setup_module(world_model.observation_model) @@ -1078,7 +1084,21 @@ def build_agent( target_critic = copy.deepcopy(critic.module) if target_critic_state: target_critic.load_state_dict(target_critic_state) - single_device_fabric = get_single_device_fabric(fabric) - target_critic = single_device_fabric.setup_module(target_critic) - - return world_model, actor, critic, target_critic + target_critic = fabric_player.setup_module(target_critic) + + # Setup the player agent with a single-device Fabric + player.encoder = fabric_player.setup_module(player.encoder) + player.recurrent_model = fabric_player.setup_module(player.recurrent_model) + player.representation_model = fabric_player.setup_module(player.representation_model) + player.actor = fabric_player.setup_module(player.actor) + + # Tie weights between the agent and the player + for agent_p, p in zip(world_model.encoder.parameters(), player.encoder.parameters()): + p.data = agent_p.data + for agent_p, p in zip(world_model.rssm.recurrent_model.parameters(), player.recurrent_model.parameters()): + p.data = agent_p.data + for agent_p, p in zip(world_model.rssm.representation_model.parameters(), player.representation_model.parameters()): + p.data = agent_p.data + for agent_p, p in zip(actor.parameters(), player.actor.parameters()): + p.data = agent_p.data + return world_model, actor, critic, target_critic, player diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 05ab0a86..ab50f618 100644 --- a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -22,7 +22,7 @@ from torch.optim import Optimizer from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_agent +from sheeprl.algos.dreamer_v2.agent import WorldModel, build_agent from sheeprl.algos.dreamer_v2.loss import reconstruction_loss from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test from sheeprl.data.buffers import EnvIndependentReplayBuffer, EpisodeBuffer, SequentialReplayBuffer @@ -450,7 +450,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - world_model, actor, critic, target_critic = build_agent( + world_model, actor, critic, target_critic, player = build_agent( fabric, actions_dim, is_continuous, @@ -461,18 +461,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["critic"] if cfg.checkpoint.resume_from else None, state["target_critic"] if cfg.checkpoint.resume_from else None, ) - player = PlayerDV2( - fabric, - world_model.encoder, - world_model.rssm.recurrent_model, - world_model.rssm.representation_model, - actor, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - 
cfg.algo.world_model.recurrent_model.recurrent_state_size, - discrete_size=cfg.algo.world_model.discrete_size, - ) # Optimizers world_optimizer = hydra.utils.instantiate( diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py index 29515b7d..8b5990b9 100644 --- a/sheeprl/algos/dreamer_v2/evaluate.py +++ b/sheeprl/algos/dreamer_v2/evaluate.py @@ -5,7 +5,7 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.dreamer_v2.agent import PlayerDV2, build_agent +from sheeprl.algos.dreamer_v2.agent import build_agent from sheeprl.algos.dreamer_v2.utils import test from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -44,7 +44,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor, _, _ = build_agent( + _, _, _, _, player = build_agent( fabric, actions_dim, is_continuous, @@ -53,17 +53,5 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): state["world_model"], state["actor"], ) - player = PlayerDV2( - fabric, - world_model.encoder, - world_model.rssm.recurrent_model, - world_model.rssm.representation_model, - actor, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - discrete_size=cfg.algo.world_model.discrete_size, - ) - - test(player, fabric, cfg, log_dir, sample_actions=False) + del _ + test(player, fabric, cfg, log_dir, greedy=True) diff --git a/sheeprl/algos/dreamer_v2/utils.py b/sheeprl/algos/dreamer_v2/utils.py index f3a42f39..ed9debb0 100644 --- a/sheeprl/algos/dreamer_v2/utils.py +++ b/sheeprl/algos/dreamer_v2/utils.py @@ -108,7 +108,7 @@ def test( cfg: Dict[str, Any], log_dir: str, test_name: str = "", - sample_actions: bool = False, + greedy: bool = True, ): """Test the model on the environment with the frozen model. @@ -119,8 +119,8 @@ def test( log_dir (str): the logging directory. test_name (str): the name of the test. Default to "". - sample_actoins (bool): whether or not to sample actions. - Default to False. + greedy (bool): whether or not to sample actions. + Default to True. """ env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))() done = False @@ -140,7 +140,7 @@ def test( elif k in cfg.algo.mlp_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) real_actions = player.get_actions( - preprocessed_obs, sample_actions, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + preprocessed_obs, greedy, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} ) if player.actor.is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 0d8e0a97..77d86697 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -598,7 +598,6 @@ class PlayerDV3(nn.Module): The model of the Dreamer_v3 player. Args: - fabric (_FabricModule): the fabric module. encoder (MultiEncoder): the encoder. rssm (RSSM | DecoupledRSSM): the RSSM model. actor (_FabricModule): the actor. 
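A recurring pattern in the new `build_agent` functions above is that the player is built from `copy.deepcopy` of the training modules, wrapped with a single-device Fabric, and then has its parameters re-pointed at the training parameters with `p.data = agent_p.data`. The following self-contained sketch uses plain `torch` modules with illustrative names (not the SheepRL classes) to show why this keeps the player in sync without any explicit `state_dict` copying.

```python
# Minimal sketch of the weight-tying pattern used in the build_agent changes above.
import copy

import torch
from torch import nn

trainer_net = nn.Linear(8, 2)
player_net = copy.deepcopy(trainer_net)

# Tie: make each player parameter point at the trainer's underlying tensor.
for agent_p, player_p in zip(trainer_net.parameters(), player_net.parameters()):
    player_p.data = agent_p.data

with torch.no_grad():
    trainer_net.weight.add_(1.0)  # simulate an optimizer step on the training module

# The player sees the update immediately, with no state_dict copy.
assert torch.equal(trainer_net.weight, player_net.weight)
```

Because only `.data` is re-assigned, the player's parameters are never registered with the optimizer; the player simply reads the same underlying tensors that training updates.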
@@ -617,7 +616,6 @@ class PlayerDV3(nn.Module): def __init__( self, - fabric: Fabric, encoder: MultiEncoder | _FabricModule, rssm: RSSM | DecoupledRSSM, actor: Actor | MinedojoActor | _FabricModule, @@ -625,40 +623,22 @@ def __init__( num_envs: int, stochastic_size: int, recurrent_state_size: int, + device: str | torch.device, discrete_size: int = 32, actor_type: str | None = None, - decoupled_rssm: bool = False, ) -> None: super().__init__() - single_device_fabric = get_single_device_fabric(fabric) - self.encoder = single_device_fabric.setup_module(getattr(encoder, "module", encoder)) - if decoupled_rssm: - rssm_cls = DecoupledRSSM - else: - rssm_cls = RSSM - self.rssm = rssm_cls( - recurrent_model=single_device_fabric.setup_module( - getattr(rssm.recurrent_model, "module", rssm.recurrent_model) - ), - representation_model=single_device_fabric.setup_module( - getattr(rssm.representation_model, "module", rssm.representation_model) - ), - transition_model=single_device_fabric.setup_module( - getattr(rssm.transition_model, "module", rssm.transition_model) - ), - distribution_cfg=actor.distribution_cfg, - discrete=rssm.discrete, - unimix=rssm.unimix, - ).to(single_device_fabric.device) - self.actor = single_device_fabric.setup_module(getattr(actor, "module", actor)) - self.device = single_device_fabric.device + self.encoder = encoder + self.rssm = rssm + self.actor = actor self.actions_dim = actions_dim + self.num_envs = num_envs self.stochastic_size = stochastic_size - self.discrete_size = discrete_size self.recurrent_state_size = recurrent_state_size - self.num_envs = num_envs + self.device = device + self.discrete_size = discrete_size self.actor_type = actor_type - self.decoupled_rssm = decoupled_rssm + self.decoupled_rssm = isinstance(rssm, DecoupledRSSM) @torch.no_grad() def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: @@ -681,7 +661,7 @@ def init_states(self, reset_envs: Optional[Sequence[int]] = None) -> None: def get_actions( self, obs: Dict[str, Tensor], - sample_actions: bool = True, + greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None, ) -> Sequence[Tensor]: """ @@ -689,8 +669,8 @@ def get_actions( Args: obs (Dict[str, Tensor]): the current observations. - sample_actions (bool): whether or not to sample the actions. - Default to True. + greedy (bool): whether or not to sample the actions. + Default to False. Returns: The actions the agent has to perform. @@ -706,7 +686,7 @@ def get_actions( self.stochastic_state = self.stochastic_state.view( *self.stochastic_state.shape[:-2], self.stochastic_size * self.discrete_size ) - actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), sample_actions, mask) + actions, _ = self.actor(torch.cat((self.stochastic_state, self.recurrent_state), -1), greedy, mask) self.actions = torch.cat(actions, -1) return actions @@ -801,7 +781,7 @@ def __init__( self._action_clip = action_clip def forward( - self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None + self, state: Tensor, greedy: bool = False, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -809,8 +789,8 @@ def forward( Args: state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). - sample_actions (bool): whether or not to sample the actions. - Default to True. 
+ greedy (bool): whether or not to sample the actions. + Default to False. mask (Dict[str, Tensor], optional): the mask to use on the actions. Default to None. @@ -834,7 +814,7 @@ def forward( std = (self.max_std - self.min_std) * torch.sigmoid(std + self.init_std) + self.min_std dist = Normal(torch.tanh(mean), std) actions_dist = Independent(dist, 1) - if sample_actions: + if not greedy: actions = actions_dist.rsample() else: sample = actions_dist.sample((100,)) @@ -850,7 +830,7 @@ def forward( actions: List[Tensor] = [] for logits in pre_dist: actions_dist.append(OneHotCategoricalStraightThrough(logits=self._uniform_mix(logits))) - if sample_actions: + if not greedy: actions.append(actions_dist[-1].rsample()) else: actions.append(actions_dist[-1].mode) @@ -899,7 +879,7 @@ def __init__( ) def forward( - self, state: Tensor, sample_actions: bool = True, mask: Optional[Dict[str, Tensor]] = None + self, state: Tensor, greedy: bool = True, mask: Optional[Dict[str, Tensor]] = None ) -> Tuple[Sequence[Tensor], Sequence[Distribution]]: """ Call the forward method of the actor model and reorganizes the result with shape (batch_size, *, num_actions), @@ -907,7 +887,7 @@ def forward( Args: state (Tensor): the current state of shape (batch_size, *, stochastic_size + recurrent_state_size). - sample_actions (bool): whether or not to sample the actions. + greedy (bool): whether or not to sample the actions. Default to True. mask (Dict[str, Tensor], optional): the mask to apply to the actions. Default to None. @@ -943,7 +923,7 @@ def forward( elif sampled_action == 18: # Destroy action logits[t, b][torch.logical_not(mask["mask_destroy"][t, b])] = -torch.inf actions_dist.append(OneHotCategoricalStraightThrough(logits=logits)) - if sample_actions: + if not greedy: actions.append(actions_dist[-1].rsample()) else: actions.append(actions_dist[-1].mode) @@ -962,7 +942,7 @@ def build_agent( actor_state: Optional[Dict[str, Tensor]] = None, critic_state: Optional[Dict[str, Tensor]] = None, target_critic_state: Optional[Dict[str, Tensor]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule, torch.nn.Module]: +) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, PlayerDV3]: """Build the models and wrap them with Fabric. 
Args: @@ -1207,6 +1187,20 @@ def build_agent( if critic_state: critic.load_state_dict(critic_state) + # Create the player agent + fabric_player = get_single_device_fabric(fabric) + player = PlayerDV3( + copy.deepcopy(world_model.encoder), + copy.deepcopy(world_model.rssm), + copy.deepcopy(actor), + actions_dim, + cfg.env.num_envs, + cfg.algo.world_model.stochastic_size, + cfg.algo.world_model.recurrent_model.recurrent_state_size, + fabric_player.device, + discrete_size=cfg.algo.world_model.discrete_size, + ) + # Setup models with Fabric world_model.encoder = fabric.setup_module(world_model.encoder) world_model.observation_model = fabric.setup_module(world_model.observation_model) @@ -1223,7 +1217,20 @@ def build_agent( target_critic = copy.deepcopy(critic.module) if target_critic_state: target_critic.load_state_dict(target_critic_state) - single_device_fabric = get_single_device_fabric(fabric) - target_critic = single_device_fabric.setup_module(target_critic) - - return world_model, actor, critic, target_critic + target_critic = fabric_player.setup_module(target_critic) + + # Setup the player agent with a single-device Fabric + player.encoder = fabric_player.setup_module(player.encoder) + player.rssm.recurrent_model = fabric_player.setup_module(player.rssm.recurrent_model) + player.rssm.transition_model = fabric_player.setup_module(player.rssm.transition_model) + player.rssm.representation_model = fabric_player.setup_module(player.rssm.representation_model) + player.actor = fabric_player.setup_module(player.actor) + + # Tie weights between the agent and the player + for agent_p, p in zip(world_model.encoder.parameters(), player.encoder.parameters()): + p.data = agent_p.data + for agent_p, p in zip(world_model.rssm.parameters(), player.rssm.parameters()): + p.data = agent_p.data + for agent_p, p in zip(actor.parameters(), player.actor.parameters()): + p.data = agent_p.data + return world_model, actor, critic, target_critic, player diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index cb5a2a56..4d058ce2 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -22,7 +22,7 @@ from torch.optim import Optimizer from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_agent +from sheeprl.algos.dreamer_v3.agent import WorldModel, build_agent from sheeprl.algos.dreamer_v3.loss import reconstruction_loss from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test from sheeprl.data.buffers import EnvIndependentReplayBuffer, SequentialReplayBuffer @@ -429,7 +429,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - world_model, actor, critic, target_critic = build_agent( + world_model, actor, critic, target_critic, player = build_agent( fabric, actions_dim, is_continuous, @@ -440,18 +440,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["critic"] if cfg.checkpoint.resume_from else None, state["target_critic"] if cfg.checkpoint.resume_from else None, ) - player = PlayerDV3( - fabric, - world_model.encoder, - world_model.rssm, - actor, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - discrete_size=cfg.algo.world_model.discrete_size, - decoupled_rssm=cfg.algo.world_model.decoupled_rssm, - ) # Optimizers world_optimizer = 
hydra.utils.instantiate( @@ -775,7 +763,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(player, fabric, cfg, log_dir, sample_actions=True) + test(player, fabric, cfg, log_dir, greedy=False) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.dreamer_v1.utils import log_models diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py index cc0e67f9..ad58fdd8 100644 --- a/sheeprl/algos/dreamer_v3/evaluate.py +++ b/sheeprl/algos/dreamer_v3/evaluate.py @@ -5,7 +5,7 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_agent +from sheeprl.algos.dreamer_v3.agent import build_agent from sheeprl.algos.dreamer_v3.utils import test from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -44,7 +44,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor, _, _ = build_agent( + _, _, _, _, player = build_agent( fabric, actions_dim, is_continuous, @@ -53,17 +53,5 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): state["world_model"], state["actor"], ) - player = PlayerDV3( - fabric, - world_model.encoder, - world_model.rssm, - actor, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - discrete_size=cfg.algo.world_model.discrete_size, - decoupled_rssm=cfg.algo.world_model.decoupled_rssm, - ) - - test(player, fabric, cfg, log_dir, sample_actions=True) + del _ + test(player, fabric, cfg, log_dir, greedy=False) diff --git a/sheeprl/algos/dreamer_v3/utils.py b/sheeprl/algos/dreamer_v3/utils.py index bb8bf297..42d79060 100644 --- a/sheeprl/algos/dreamer_v3/utils.py +++ b/sheeprl/algos/dreamer_v3/utils.py @@ -84,7 +84,7 @@ def test( cfg: Dict[str, Any], log_dir: str, test_name: str = "", - sample_actions: bool = False, + greedy: bool = True, ): """Test the model on the environment with the frozen model. @@ -95,8 +95,8 @@ def test( log_dir (str): the logging directory. test_name (str): the name of the test. Default to "". - sample_actions (bool): whether or not to sample the actions. - Default to False. + greedy (bool): whether or not to sample the actions. + Default to True. 
""" env: gym.Env = make_env(cfg, cfg.seed, 0, log_dir, "test" + (f"_{test_name}" if test_name != "" else ""))() done = False @@ -116,7 +116,7 @@ def test( elif k in cfg.algo.mlp_keys.encoder: preprocessed_obs[k] = v[None, ...].to(device) real_actions = player.get_actions( - preprocessed_obs, sample_actions, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} + preprocessed_obs, greedy, {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")} ) if player.actor.is_continuous: real_actions = torch.cat(real_actions, -1).cpu().numpy() diff --git a/sheeprl/algos/droq/agent.py b/sheeprl/algos/droq/agent.py index 959a56c3..cae92833 100644 --- a/sheeprl/algos/droq/agent.py +++ b/sheeprl/algos/droq/agent.py @@ -9,7 +9,7 @@ from lightning.fabric.wrappers import _FabricModule from torch import Tensor -from sheeprl.algos.sac.agent import SACActor +from sheeprl.algos.sac.agent import SACActor, SACPlayer from sheeprl.models.models import MLP from sheeprl.utils.fabric import get_single_device_fabric @@ -215,7 +215,7 @@ def build_agent( obs_space: gymnasium.spaces.Dict, action_space: gymnasium.spaces.Box, agent_state: Optional[Dict[str, Tensor]] = None, -) -> DROQAgent: +) -> Tuple[DROQAgent, SACPlayer]: act_dim = prod(action_space.shape) obs_dim = sum([prod(obs_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) actor = SACActor( @@ -237,10 +237,26 @@ def build_agent( ] target_entropy = -act_dim agent = DROQAgent( - actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device + actor, + critics, + target_entropy, + alpha=cfg.algo.alpha.alpha, + tau=cfg.algo.tau, + device=fabric.device, ) if agent_state: agent.load_state_dict(agent_state) + + # Setup player agent + player = SACPlayer( + copy.deepcopy(agent.actor.model), + copy.deepcopy(agent.actor.fc_mean), + copy.deepcopy(agent.actor.fc_logstd), + action_low=action_space.low, + action_high=action_space.high, + ) + + # Setup training agent agent.actor = fabric.setup_module(agent.actor) agent.critics = [fabric.setup_module(critic) for critic in agent.critics] @@ -249,4 +265,14 @@ def build_agent( fabric_player = get_single_device_fabric(fabric) agent.qfs_target = nn.ModuleList([fabric_player.setup_module(target) for target in agent.qfs_target]) - return agent + # Setup player agent + player.model = fabric_player.setup_module(player.model) + player.fc_mean = fabric_player.setup_module(player.fc_mean) + player.fc_logstd = fabric_player.setup_module(player.fc_logstd) + player.action_scale = player.action_scale.to(fabric_player.device) + player.action_bias = player.action_bias.to(fabric_player.device) + + # Tie weights between the agent and the player + for agent_p, player_p in zip(agent.actor.parameters(), player.parameters()): + player_p.data = agent_p.data + return agent, player diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 0c54d42d..a917684b 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -21,7 +21,6 @@ from sheeprl.algos.sac.sac import test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env -from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -202,11 +201,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) # Define the agent and the optimizer and setup them with Fabric - 
agent = build_agent( + agent, player = build_agent( fabric, cfg, observation_space, action_space, state["agent"] if cfg.checkpoint.resume_from else None ) - fabric_player = get_single_device_fabric(fabric) - actor = fabric_player.setup_module(agent.actor.module) # Optimizers qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters(), _convert_="all") @@ -308,7 +305,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): else: with torch.inference_mode(): # Sample an action given the observation received by the environment - actions, _ = actor(torch.from_numpy(obs).to(device)) + actions = player(torch.from_numpy(obs).to(device)) actions = actions.cpu().numpy() next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) @@ -427,7 +424,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(actor, fabric, cfg, log_dir) + test(player, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.sac.utils import log_models diff --git a/sheeprl/algos/droq/evaluate.py b/sheeprl/algos/droq/evaluate.py index 2fa3e0c9..695e58b1 100644 --- a/sheeprl/algos/droq/evaluate.py +++ b/sheeprl/algos/droq/evaluate.py @@ -46,5 +46,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): if cfg.metric.log_level > 0: fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) - agent = build_agent(fabric, cfg, observation_space, action_space, state["agent"]) - test(agent.actor, fabric, cfg, log_dir) + _, agent = build_agent(fabric, cfg, observation_space, action_space, state["agent"]) + del _ + test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/p2e_dv1/agent.py b/sheeprl/algos/p2e_dv1/agent.py index 32a6269d..254f269a 100644 --- a/sheeprl/algos/p2e_dv1/agent.py +++ b/sheeprl/algos/p2e_dv1/agent.py @@ -8,12 +8,13 @@ from lightning.pytorch.utilities.seed import isolate_rng from torch import nn -from sheeprl.algos.dreamer_v1.agent import WorldModel +from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel from sheeprl.algos.dreamer_v1.agent import build_agent as dv1_build_agent from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor from sheeprl.models.models import MLP -from sheeprl.utils.utils import init_weights +from sheeprl.utils.fabric import get_single_device_fabric +from sheeprl.utils.utils import init_weights, unwrap_fabric # In order to use the hydra.utils.get_class method, in this way the user can # specify in the configs the name of the class without having to know where @@ -34,7 +35,7 @@ def build_agent( critic_task_state: Optional[Dict[str, torch.Tensor]] = None, actor_exploration_state: Optional[Dict[str, torch.Tensor]] = None, critic_exploration_state: Optional[Dict[str, torch.Tensor]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, _FabricModule, _FabricModule]: +) -> Tuple[WorldModel, nn.ModuleList, _FabricModule, _FabricModule, _FabricModule, _FabricModule, PlayerDV1]: """Build the models and wrap them with Fabric. Args: @@ -64,6 +65,7 @@ def build_agent( The critic_task (_FabricModule): for predicting the values of the task. The actor_exploration (_FabricModule): for exploring the environment. The critic_exploration (_FabricModule): for predicting the values of the exploration. + The player (PlayerDV1): the player object. 
""" world_model_cfg = cfg.algo.world_model actor_cfg = cfg.algo.actor @@ -73,7 +75,7 @@ def build_agent( latent_state_size = world_model_cfg.stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size # Create exploration models - world_model, actor_exploration, critic_exploration = dv1_build_agent( + world_model, actor_exploration, critic_exploration, player = dv1_build_agent( fabric, actions_dim=actions_dim, is_continuous=is_continuous, @@ -83,6 +85,7 @@ def build_agent( actor_state=actor_exploration_state, critic_state=critic_exploration_state, ) + player.actor_type = cfg.algo.player.actor_type actor_cls = hydra.utils.get_class(cfg.algo.actor.cls) actor_task: Union[Actor, MinedojoActor] = actor_cls( latent_state_size=latent_state_size, @@ -141,4 +144,12 @@ def build_agent( for i in range(len(ensembles)): ensembles[i] = fabric.setup_module(ensembles[i]) - return world_model, ensembles, actor_task, critic_task, actor_exploration, critic_exploration + # Setup player agent + if cfg.algo.player.actor_type != "exploration": + fabric_player = get_single_device_fabric(fabric) + player_actor = unwrap_fabric(actor_task) + player.actor = fabric_player.setup_module(player_actor) + for agent_p, p in zip(actor_task.parameters(), player.actor.parameters()): + p.data = agent_p.data + + return world_model, ensembles, actor_task, critic_task, actor_exploration, critic_exploration, player diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py index 2d2d9bf9..4381ae78 100644 --- a/sheeprl/algos/p2e_dv1/evaluate.py +++ b/sheeprl/algos/p2e_dv1/evaluate.py @@ -5,7 +5,6 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.dreamer_v1.agent import PlayerDV1 from sheeprl.algos.dreamer_v2.utils import test from sheeprl.algos.p2e_dv1.agent import build_agent from sheeprl.utils.env import make_env @@ -45,25 +44,15 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, _, actor_task, _, _, _ = build_agent( + cfg.algo.player.actor_type = "task" + _, _, _, _, _, _, player = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, - state["world_model"], - state["actor_task"], + world_model_state=state["world_model"], + actor_task_state=state["actor_task"], ) - player = PlayerDV1( - fabric, - world_model.encoder, - world_model.rssm.recurrent_model, - world_model.rssm.representation_model, - actor_task, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - ) - - test(player, fabric, cfg, log_dir, sample_actions=False) + del _ + test(player, fabric, cfg, log_dir, greedy=True) diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 081a2a79..4f31249a 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -17,7 +17,7 @@ from torch.distributions.utils import logits_to_probs from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel +from sheeprl.algos.dreamer_v1.agent import WorldModel from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss from sheeprl.algos.dreamer_v1.utils import compute_lambda_values from sheeprl.algos.dreamer_v2.utils import test @@ -29,7 +29,7 @@ from sheeprl.utils.metric 
import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import Ratio, save_configs +from sheeprl.utils.utils import Ratio, save_configs, unwrap_fabric # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -433,7 +433,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - world_model, ensembles, actor_task, critic_task, actor_exploration, critic_exploration = build_agent( + world_model, ensembles, actor_task, critic_task, actor_exploration, critic_exploration, player = build_agent( fabric, actions_dim, is_continuous, @@ -447,19 +447,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["critic_exploration"] if cfg.checkpoint.resume_from else None, ) - player = PlayerDV1( - fabric, - world_model.encoder, - world_model.rssm.recurrent_model, - world_model.rssm.representation_model, - actor_exploration, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - actor_type=cfg.algo.player.actor_type, - ) - # Optimizers world_optimizer = hydra.utils.instantiate( cfg.algo.world_model.optimizer, params=world_model.parameters(), _convert_="all" @@ -799,7 +786,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if fabric.is_global_zero and cfg.algo.run_test: player.actor_type = "task" fabric_player = get_single_device_fabric(fabric) - player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) + player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) test(player, fabric, cfg, log_dir, "zero-shot") if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index 7779274a..68e6e75d 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -13,7 +13,6 @@ from lightning.fabric import Fabric from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v1.agent import PlayerDV1 from sheeprl.algos.dreamer_v1.dreamer_v1 import train from sheeprl.algos.dreamer_v2.utils import test from sheeprl.algos.p2e_dv1.agent import build_agent @@ -24,7 +23,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import Ratio, save_configs +from sheeprl.utils.utils import Ratio, save_configs, unwrap_fabric # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -133,7 +132,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - world_model, _, actor_task, critic_task, actor_exploration, _ = build_agent( + world_model, _, actor_task, critic_task, actor_exploration, _, player = build_agent( fabric, actions_dim, is_continuous, @@ -146,19 +145,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): state["actor_exploration"], ) - player = PlayerDV1( - fabric, - world_model.encoder, - world_model.rssm.recurrent_model, - world_model.rssm.representation_model, - actor_exploration if cfg.algo.player.actor_type == 
"exploration" else actor_task, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - actor_type=cfg.algo.player.actor_type, - ) - # Optimizers world_optimizer = hydra.utils.instantiate( cfg.algo.world_model.optimizer, params=world_model.parameters(), _convert_="all" @@ -346,7 +332,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if per_rank_gradient_steps > 0: if player.actor_type != "task": player.actor_type = "task" - player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) + player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) + for agent_p, p in zip(actor_task.parameters(), player.actor.parameters()): + p.data = agent_p.data with timer("Time/train_time", SumMetric, sync_on_compute=cfg.metric.sync_on_compute): sample = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, @@ -445,7 +433,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # task test few-shot if fabric.is_global_zero and cfg.algo.run_test: player.actor_type = "task" - player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) + player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) test(player, fabric, cfg, log_dir, "few-shot") if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/p2e_dv2/agent.py b/sheeprl/algos/p2e_dv2/agent.py index f89243ba..a5a9495c 100644 --- a/sheeprl/algos/p2e_dv2/agent.py +++ b/sheeprl/algos/p2e_dv2/agent.py @@ -11,11 +11,11 @@ from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor -from sheeprl.algos.dreamer_v2.agent import WorldModel +from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel from sheeprl.algos.dreamer_v2.agent import build_agent as dv2_build_agent from sheeprl.models.models import MLP from sheeprl.utils.fabric import get_single_device_fabric -from sheeprl.utils.utils import init_weights +from sheeprl.utils.utils import init_weights, unwrap_fabric # In order to use the hydra.utils.get_class method, in this way the user can # specify in the configs the name of the class without having to know where @@ -38,7 +38,17 @@ def build_agent( actor_exploration_state: Optional[Dict[str, torch.Tensor]] = None, critic_exploration_state: Optional[Dict[str, torch.Tensor]] = None, target_critic_exploration_state: Optional[Dict[str, torch.Tensor]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, nn.Module, _FabricModule, _FabricModule, nn.Module]: +) -> Tuple[ + WorldModel, + nn.ModuleList, + _FabricModule, + _FabricModule, + _FabricModule, + _FabricModule, + _FabricModule, + _FabricModule, + PlayerDV2, +]: """Build the models and wrap them with Fabric. 
Args: @@ -84,7 +94,7 @@ def build_agent( latent_state_size = stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size # Create exploration models - world_model, actor_exploration, critic_exploration, target_critic_exploration = dv2_build_agent( + world_model, actor_exploration, critic_exploration, target_critic_exploration, player = dv2_build_agent( fabric, actions_dim=actions_dim, is_continuous=is_continuous, @@ -178,6 +188,14 @@ def build_agent( for i in range(len(ensembles)): ensembles[i] = fabric.setup_module(ensembles[i]) + # Setup player agent + if cfg.algo.player.actor_type != "exploration": + fabric_player = get_single_device_fabric(fabric) + player_actor = unwrap_fabric(actor_task) + player.actor = fabric_player.setup_module(player_actor) + for agent_p, p in zip(actor_task.parameters(), player.actor.parameters()): + p.data = agent_p.data + return ( world_model, ensembles, @@ -187,4 +205,5 @@ def build_agent( actor_exploration, critic_exploration, target_critic_exploration, + player, ) diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py index 26ef629d..5dcb9b15 100644 --- a/sheeprl/algos/p2e_dv2/evaluate.py +++ b/sheeprl/algos/p2e_dv2/evaluate.py @@ -5,7 +5,6 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.dreamer_v2.agent import PlayerDV2 from sheeprl.algos.dreamer_v2.utils import test from sheeprl.algos.p2e_dv2.agent import build_agent from sheeprl.utils.env import make_env @@ -45,27 +44,15 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, _, actor_task, _, _, _, _, _ = build_agent( + cfg.algo.player.actor_type = "task" + _, _, _, _, _, _, _, _, player = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, - state["world_model"], - None, - state["actor_task"], + world_model_state=state["world_model"], + actor_task_state=state["actor_task"], ) - player = PlayerDV2( - fabric, - world_model.encoder, - world_model.rssm.recurrent_model, - world_model.rssm.representation_model, - actor_task, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - discrete_size=cfg.algo.world_model.discrete_size, - ) - - test(player, fabric, cfg, log_dir, sample_actions=False) + del _ + test(player, fabric, cfg, log_dir, greedy=True) diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index a474f2c2..e0f5a121 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -17,7 +17,7 @@ from torch.distributions.utils import logits_to_probs from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel +from sheeprl.algos.dreamer_v2.agent import WorldModel from sheeprl.algos.dreamer_v2.loss import reconstruction_loss from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test from sheeprl.algos.p2e_dv2.agent import build_agent @@ -28,7 +28,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import Ratio, save_configs +from sheeprl.utils.utils import Ratio, save_configs, unwrap_fabric # Decomment the following line if you are using MineDojo on an headless machine # 
os.environ["MINEDOJO_HEADLESS"] = "1" @@ -553,6 +553,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actor_exploration, critic_exploration, target_critic_exploration, + player, ) = build_agent( fabric, actions_dim, @@ -569,20 +570,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["target_critic_exploration"] if cfg.checkpoint.resume_from else None, ) - player = PlayerDV2( - fabric, - world_model.encoder, - world_model.rssm.recurrent_model, - world_model.rssm.representation_model, - actor_exploration, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - discrete_size=cfg.algo.world_model.discrete_size, - actor_type=cfg.algo.player.actor_type, - ) - # Optimizers world_optimizer = hydra.utils.instantiate( cfg.algo.world_model.optimizer, params=world_model.parameters(), _convert_="all" @@ -950,7 +937,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if fabric.is_global_zero and cfg.algo.run_test: player.actor_type = "task" fabric_player = get_single_device_fabric(fabric) - player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) + player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) test(player, fabric, cfg, log_dir, "zero-shot") if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 4e0dcc6b..bd9c548e 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -13,7 +13,6 @@ from lightning.fabric import Fabric from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v2.agent import PlayerDV2 from sheeprl.algos.dreamer_v2.dreamer_v2 import train from sheeprl.algos.dreamer_v2.utils import test from sheeprl.algos.p2e_dv2.agent import build_agent @@ -24,7 +23,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import Ratio, save_configs +from sheeprl.utils.utils import Ratio, save_configs, unwrap_fabric # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -136,7 +135,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - world_model, _, actor_task, critic_task, target_critic_task, actor_exploration, _, _ = build_agent( + world_model, _, actor_task, critic_task, target_critic_task, actor_exploration, _, _, player = build_agent( fabric, actions_dim, is_continuous, @@ -150,20 +149,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): state["actor_exploration"], ) - player = PlayerDV2( - fabric, - world_model.encoder, - world_model.rssm.recurrent_model, - world_model.rssm.representation_model, - actor_exploration if cfg.algo.player.actor_type == "exploration" else actor_task, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - discrete_size=cfg.algo.world_model.discrete_size, - actor_type=cfg.algo.player.actor_type, - ) - # Optimizers world_optimizer = hydra.utils.instantiate( cfg.algo.world_model.optimizer, params=world_model.parameters(), _convert_="all" @@ -370,7 +355,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any], 
exploration_cfg: Dict[str, Any]): if per_rank_gradient_steps > 0: if player.actor_type != "task": player.actor_type = "task" - player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) + player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) + for agent_p, p in zip(actor_task.parameters(), player.actor.parameters()): + p.data = agent_p.data local_data = rb.sample_tensors( batch_size=cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, @@ -474,7 +461,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # task test few-shot if fabric.is_global_zero and cfg.algo.run_test: player.actor_type = "task" - player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) + player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) test(player, fabric, cfg, log_dir, "few-shot") if not cfg.model_manager.disabled and fabric.is_global_zero: diff --git a/sheeprl/algos/p2e_dv3/agent.py b/sheeprl/algos/p2e_dv3/agent.py index 2f35674d..c1336ccd 100644 --- a/sheeprl/algos/p2e_dv3/agent.py +++ b/sheeprl/algos/p2e_dv3/agent.py @@ -10,11 +10,12 @@ from sheeprl.algos.dreamer_v3.agent import Actor as DV3Actor from sheeprl.algos.dreamer_v3.agent import MinedojoActor as DV3MinedojoActor -from sheeprl.algos.dreamer_v3.agent import WorldModel +from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel from sheeprl.algos.dreamer_v3.agent import build_agent as dv3_build_agent from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights from sheeprl.models.models import MLP from sheeprl.utils.fabric import get_single_device_fabric +from sheeprl.utils.utils import unwrap_fabric # In order to use the hydra.utils.get_class method, in this way the user can # specify in the configs the name of the class without having to know where @@ -36,7 +37,9 @@ def build_agent( target_critic_task_state: Optional[Dict[str, torch.Tensor]] = None, actor_exploration_state: Optional[Dict[str, torch.Tensor]] = None, critics_exploration_state: Optional[Dict[str, Dict[str, Any]]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, nn.Module, _FabricModule, Dict[str, Any]]: +) -> Tuple[ + WorldModel, nn.ModuleList, _FabricModule, _FabricModule, _FabricModule, _FabricModule, Dict[str, Any], PlayerDV3 +]: """Build the models and wrap them with Fabric. 
Args: @@ -82,7 +85,7 @@ def build_agent( latent_state_size = stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size # Create task models - world_model, actor_task, critic_task, target_critic_task = dv3_build_agent( + world_model, actor_task, critic_task, target_critic_task, player = dv3_build_agent( fabric, actions_dim=actions_dim, is_continuous=is_continuous, @@ -200,6 +203,14 @@ def build_agent( for i in range(len(ensembles)): ensembles[i] = fabric.setup_module(ensembles[i]) + # Setup player agent + if cfg.algo.player.actor_type == "exploration": + fabric_player = get_single_device_fabric(fabric) + player_actor = unwrap_fabric(actor_exploration) + player.actor = fabric_player.setup_module(player_actor) + for agent_p, p in zip(actor_exploration.parameters(), player.actor.parameters()): + p.data = agent_p.data + return ( world_model, ensembles, @@ -208,4 +219,5 @@ def build_agent( target_critic_task, actor_exploration, critics_exploration, + player, ) diff --git a/sheeprl/algos/p2e_dv3/evaluate.py b/sheeprl/algos/p2e_dv3/evaluate.py index e59052b7..6049867b 100644 --- a/sheeprl/algos/p2e_dv3/evaluate.py +++ b/sheeprl/algos/p2e_dv3/evaluate.py @@ -5,7 +5,6 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.dreamer_v3.agent import PlayerDV3 from sheeprl.algos.dreamer_v3.utils import test from sheeprl.algos.p2e_dv3.agent import build_agent from sheeprl.utils.env import make_env @@ -45,27 +44,15 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, _, actor, _, _, _, _ = build_agent( + cfg.algo.player.actor_type == "task" + _, _, _, _, _, _, _, player = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, - state["world_model"], - None, - state["actor_task"], + world_model_state=state["world_model"], + actor_task_state=state["actor_task"], ) - player = PlayerDV3( - fabric, - world_model.encoder, - world_model.rssm, - actor, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - discrete_size=cfg.algo.world_model.discrete_size, - actor_type="task", - ) - - test(player, fabric, cfg, log_dir, sample_actions=True) + del _ + test(player, fabric, cfg, log_dir, greedy=False) diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index e74d10db..da0ea8fe 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -15,7 +15,7 @@ from torch.distributions import Distribution, Independent, OneHotCategorical from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel +from sheeprl.algos.dreamer_v3.agent import WorldModel from sheeprl.algos.dreamer_v3.loss import reconstruction_loss from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test from sheeprl.algos.p2e_dv3.agent import build_agent @@ -32,7 +32,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import Ratio, save_configs +from sheeprl.utils.utils import Ratio, save_configs, unwrap_fabric # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -597,6 +597,7 @@ def 
main(fabric: Fabric, cfg: Dict[str, Any]): target_critic_task, actor_exploration, critics_exploration, + player, ) = build_agent( fabric, actions_dim, @@ -612,19 +613,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["critics_exploration"] if cfg.checkpoint.resume_from else None, ) - player = PlayerDV3( - fabric, - world_model.encoder, - world_model.rssm, - actor_exploration, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - discrete_size=cfg.algo.world_model.discrete_size, - actor_type=cfg.algo.player.actor_type, - ) - # Optimizers world_optimizer = hydra.utils.instantiate( cfg.algo.world_model.optimizer, params=world_model.parameters(), _convert_="all" @@ -1045,8 +1033,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if fabric.is_global_zero and cfg.algo.run_test: player.actor_type = "task" fabric_player = get_single_device_fabric(fabric) - player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) - test(player, fabric, cfg, log_dir, "zero-shot") + player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) + test(player, fabric, cfg, log_dir, "zero-shot", greedy=False) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.dreamer_v1.utils import log_models diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index d1110472..2f7b272f 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -11,7 +11,6 @@ from lightning.fabric import Fabric from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v3.agent import PlayerDV3 from sheeprl.algos.dreamer_v3.dreamer_v3 import train from sheeprl.algos.dreamer_v3.utils import Moments, test from sheeprl.algos.p2e_dv3.agent import build_agent @@ -22,7 +21,7 @@ from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import Ratio, save_configs +from sheeprl.utils.utils import Ratio, save_configs, unwrap_fabric @register_algorithm() @@ -130,15 +129,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.algo.mlp_keys.decoder) obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder - ( - world_model, - _, - actor_task, - critic_task, - target_critic_task, - actor_exploration, - _, - ) = build_agent( + (world_model, _, actor_task, critic_task, target_critic_task, actor_exploration, _, player) = build_agent( fabric, actions_dim, is_continuous, @@ -152,20 +143,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): state["actor_exploration"], ) - # initialize the ensembles with different seeds to be sure they have different weights - player = PlayerDV3( - fabric, - world_model.encoder, - world_model.rssm, - actor_exploration if cfg.algo.player.actor_type == "exploration" else actor_task, - actions_dim, - cfg.env.num_envs, - cfg.algo.world_model.stochastic_size, - cfg.algo.world_model.recurrent_model.recurrent_state_size, - discrete_size=cfg.algo.world_model.discrete_size, - actor_type=cfg.algo.player.actor_type, - ) - # Optimizers world_optimizer = hydra.utils.instantiate( cfg.algo.world_model.optimizer, params=world_model.parameters(), _convert_="all" @@ -374,7 +351,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): if per_rank_gradient_steps > 
0: if player.actor_type != "task": player.actor_type = "task" - player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) + player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) + for agent_p, p in zip(actor_task.parameters(), player.actor.parameters()): + p.data = agent_p.data local_data = rb.sample_tensors( cfg.algo.per_rank_batch_size, sequence_length=cfg.algo.per_rank_sequence_length, @@ -483,8 +462,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): # task test few-shot if fabric.is_global_zero and cfg.algo.run_test: player.actor_type = "task" - player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task)) - test(player, fabric, cfg, log_dir, "few-shot") + player.actor = fabric_player.setup_module(unwrap_fabric(actor_task)) + test(player, fabric, cfg, log_dir, "few-shot", greedy=False) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.dreamer_v1.utils import log_models diff --git a/sheeprl/algos/ppo/agent.py b/sheeprl/algos/ppo/agent.py index e77f4d57..51cc0f18 100644 --- a/sheeprl/algos/ppo/agent.py +++ b/sheeprl/algos/ppo/agent.py @@ -151,20 +151,17 @@ def __init__( self.actor = PPOActor(actor_backbone, actor_heads, is_continuous) def forward( - self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None, greedy: bool = False + self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None ) -> Tuple[Sequence[Tensor], Tensor, Tensor, Tensor]: feat = self.feature_extractor(obs) - values = self.critic(feat) actor_out: List[Tensor] = self.actor(feat) + values = self.critic(feat) if self.is_continuous: mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1) std = log_std.exp() normal = Independent(Normal(mean, std), 1) if actions is None: - if greedy: - actions = mean - else: - actions = normal.sample() + actions = normal.sample() else: # always composed by a tuple of one element containing all the # continuous actions @@ -183,10 +180,7 @@ def forward( actions_dist.append(OneHotCategorical(logits=logits)) actions_entropies.append(actions_dist[-1].entropy()) if should_append: - if greedy: - actions.append(actions_dist[-1].mode) - else: - actions.append(actions_dist[-1].sample()) + actions.append(actions_dist[-1].sample()) actions_logprobs.append(actions_dist[-1].log_prob(actions[i])) return ( tuple(actions), @@ -195,10 +189,66 @@ def forward( values, ) - def get_value(self, obs: Dict[str, Tensor]) -> Tensor: + +class PPOPlayer(nn.Module): + def __init__(self, feature_extractor: MultiEncoder, actor: PPOActor, critic: nn.Module) -> None: + super().__init__() + self.feature_extractor = feature_extractor + self.critic = critic + self.actor = actor + + def forward(self, obs: Dict[str, Tensor]) -> Tuple[Sequence[Tensor], Tensor, Tensor]: + feat = self.feature_extractor(obs) + values = self.critic(feat) + actor_out: List[Tensor] = self.actor(feat) + if self.actor.is_continuous: + mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1) + std = log_std.exp() + normal = Independent(Normal(mean, std), 1) + actions = normal.sample() + log_prob = normal.log_prob(actions) + return tuple([actions]), log_prob.unsqueeze(dim=-1), values + else: + actions_dist: List[Distribution] = [] + actions_logprobs: List[Tensor] = [] + actions: List[Tensor] = [] + for i, logits in enumerate(actor_out): + actions_dist.append(OneHotCategorical(logits=logits)) + actions.append(actions_dist[-1].sample()) + actions_logprobs.append(actions_dist[-1].log_prob(actions[i])) + 
return ( + tuple(actions), + torch.stack(actions_logprobs, dim=-1).sum(dim=-1, keepdim=True), + values, + ) + + def get_values(self, obs: Dict[str, Tensor]) -> Tensor: feat = self.feature_extractor(obs) return self.critic(feat) + def get_actions(self, obs: Dict[str, Tensor], greedy: bool = False) -> Sequence[Tensor]: + feat = self.feature_extractor(obs) + actor_out: List[Tensor] = self.actor(feat) + if self.actor.is_continuous: + mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1) + if greedy: + actions = mean + else: + std = log_std.exp() + normal = Independent(Normal(mean, std), 1) + actions = normal.sample() + return tuple([actions]) + else: + actions: List[Tensor] = [] + actions_dist: List[Distribution] = [] + for logits in actor_out: + actions_dist.append(OneHotCategorical(logits=logits)) + if greedy: + actions.append(actions_dist[-1].mode) + else: + actions.append(actions_dist[-1].sample()) + return tuple(actions) + def build_agent( fabric: Fabric, @@ -207,7 +257,7 @@ def build_agent( cfg: Dict[str, Any], obs_space: gymnasium.spaces.Dict, agent_state: Optional[Dict[str, Tensor]] = None, -) -> Tuple[PPOAgent, PPOAgent]: +) -> Tuple[PPOAgent, PPOPlayer]: agent = PPOAgent( actions_dim=actions_dim, obs_space=obs_space, @@ -222,16 +272,26 @@ def build_agent( ) if agent_state: agent.load_state_dict(agent_state) - player = copy.deepcopy(agent) + + # Setup player agent + player = PPOPlayer(copy.deepcopy(agent.feature_extractor), copy.deepcopy(agent.actor), copy.deepcopy(agent.critic)) # Setup training agent - agent = fabric.setup_module(agent) + agent.feature_extractor = fabric.setup_module(agent.feature_extractor) + agent.critic = fabric.setup_module(agent.critic) + agent.actor = fabric.setup_module(agent.actor) # Setup player agent fabric_player = get_single_device_fabric(fabric) - player = fabric_player.setup_module(player) + player.feature_extractor = fabric_player.setup_module(player.feature_extractor) + player.critic = fabric_player.setup_module(player.critic) + player.actor = fabric_player.setup_module(player.actor) # Tie weights between the agent and the player - for agent_p, player_p in zip(agent.parameters(), player.parameters()): + for agent_p, player_p in zip(agent.feature_extractor.parameters(), player.feature_extractor.parameters()): + player_p.data = agent_p.data + for agent_p, player_p in zip(agent.actor.parameters(), player.actor.parameters()): + player_p.data = agent_p.data + for agent_p, player_p in zip(agent.critic.parameters(), player.critic.parameters()): player_p.data = agent_p.data return agent, player diff --git a/sheeprl/algos/ppo/evaluate.py b/sheeprl/algos/ppo/evaluate.py index 4725c2bd..82fa9031 100644 --- a/sheeprl/algos/ppo/evaluate.py +++ b/sheeprl/algos/ppo/evaluate.py @@ -49,7 +49,8 @@ def evaluate_ppo(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) ) # Create the actor and critic models - agent, _ = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) + _, agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) + del _ test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index ba2aa447..d2757486 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -275,7 +275,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_obs = { k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys } 
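The rollout code changed below this point because the new player interface returns a 3-tuple `(actions, logprobs, values)` instead of the training agent's 4-tuple, and exposes `get_values` for bootstrapping truncated episodes and `get_actions(greedy=...)` for evaluation. The snippet below is a minimal stand-in with the same three methods, not the real `PPOPlayer`; observation keys, sizes, and class names are made up for the example.

```python
# Illustrative stand-in for the player interface used by the updated rollout code.
from typing import Dict, Sequence, Tuple

import torch
from torch import Tensor, nn
from torch.distributions import OneHotCategorical


class DummyDiscretePlayer(nn.Module):
    def __init__(self, obs_dim: int = 4, n_actions: int = 3) -> None:
        super().__init__()
        self.actor = nn.Linear(obs_dim, n_actions)
        self.critic = nn.Linear(obs_dim, 1)

    def forward(self, obs: Dict[str, Tensor]) -> Tuple[Sequence[Tensor], Tensor, Tensor]:
        feat = obs["state"]
        dist = OneHotCategorical(logits=self.actor(feat))
        action = dist.sample()
        return (action,), dist.log_prob(action).unsqueeze(-1), self.critic(feat)

    def get_values(self, obs: Dict[str, Tensor]) -> Tensor:
        return self.critic(obs["state"])

    def get_actions(self, obs: Dict[str, Tensor], greedy: bool = False) -> Sequence[Tensor]:
        dist = OneHotCategorical(logits=self.actor(obs["state"]))
        return (dist.mode if greedy else dist.sample(),)


player = DummyDiscretePlayer()
torch_obs = {"state": torch.randn(8, 4)}
actions, logprobs, values = player(torch_obs)                  # 3-tuple, as in the updated rollout
next_values = player.get_values(torch_obs)                     # bootstrap value for truncated episodes
greedy_actions = player.get_actions(torch_obs, greedy=True)    # evaluation-time behaviour
```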
- actions, logprobs, _, values = player(torch_obs) + actions, logprobs, values = player(torch_obs) if is_continuous: real_actions = torch.cat(actions, -1).cpu().numpy() else: @@ -302,10 +302,8 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): torch_v = torch_v.view(-1, *v.shape[-2:]) torch_v = torch_v / 255.0 - 0.5 real_next_obs[k][i] = torch_v - _, _, _, vals = player(real_next_obs) - rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape( - rewards[truncated_envs].shape - ) + vals = player.get_values(real_next_obs).cpu().numpy() + rewards[truncated_envs] += cfg.algo.gamma * vals.reshape(rewards[truncated_envs].shape) dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) @@ -349,7 +347,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.inference_mode(): normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) torch_obs = {k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys} - _, _, _, next_values = player(torch_obs) + next_values = player.get_values(torch_obs) returns, advantages = gae( local_data["rewards"].to(torch.float64), local_data["values"], diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 866228d7..a1d8bb1c 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -16,7 +16,7 @@ from torch.utils.data import BatchSampler, RandomSampler from torchmetrics import SumMetric -from sheeprl.algos.ppo.agent import PPOAgent +from sheeprl.algos.ppo.agent import build_agent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss from sheeprl.algos.ppo.utils import normalize_obs, test from sheeprl.data.buffers import ReplayBuffer @@ -31,15 +31,19 @@ @torch.inference_mode() def player( - fabric: Fabric, cfg: Dict[str, Any], world_collective: TorchCollective, player_trainer_collective: TorchCollective + fabric: Fabric, + world_collective: TorchCollective, + player_trainer_collective: TorchCollective, + cfg: Dict[str, Any], ): - # Initialize the fabric object - log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name, False) - device = fabric.device + # Initialize Fabric player-only + fabric_player = get_single_device_fabric(fabric) + log_dir = get_log_dir(fabric_player, cfg.root_dir, cfg.run_name, False) + device = fabric_player.device # Resume from checkpoint if cfg.checkpoint.resume_from: - state = fabric.load(cfg.checkpoint.resume_from) + state = fabric_player.load(cfg.checkpoint.resume_from) # Environment setup vectorized_env = gym.vector.SyncVectorEnv if cfg.env.sync_env else gym.vector.AsyncVectorEnv @@ -91,9 +95,15 @@ def player( "distribution_cfg": cfg.distribution, "is_continuous": is_continuous, } - agent = PPOAgent(**agent_args).to(device) - fabric_player = get_single_device_fabric(fabric) - agent = fabric_player.setup_module(agent, move_to_device=False) + _, agent = build_agent( + fabric_player, + actions_dim=actions_dim, + is_continuous=is_continuous, + cfg=cfg, + obs_space=observation_space, + agent_state=state["agent"] if cfg.checkpoint.resume_from else None, + ) + del _ if fabric.is_global_zero: save_configs(cfg, log_dir) @@ -198,7 +208,7 @@ def player( torch_obs = { k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys } - actions, logprobs, _, values = agent(torch_obs) + actions, logprobs, values = agent(torch_obs) if is_continuous: real_actions = torch.cat(actions, 
-1).cpu().numpy() else: @@ -225,10 +235,8 @@ def player( torch_v = torch_v.view(-1, *v.shape[-2:]) torch_v = torch_v / 255.0 - 0.5 real_next_obs[k][i] = torch_v - _, _, _, vals = agent(real_next_obs) - rewards[truncated_envs] += cfg.algo.gamma * vals.cpu().numpy().reshape( - rewards[truncated_envs].shape - ) + vals = agent.get_values(real_next_obs).cpu().numpy() + rewards[truncated_envs] += cfg.algo.gamma * vals.reshape(rewards[truncated_envs].shape) dones = np.logical_or(terminated, truncated).reshape(cfg.env.num_envs, -1).astype(np.uint8) rewards = rewards.reshape(cfg.env.num_envs, -1) @@ -270,7 +278,7 @@ def player( # Estimate returns with GAE (https://arxiv.org/abs/1506.02438) normalized_obs = normalize_obs(next_obs, cfg.algo.cnn_keys.encoder, obs_keys) torch_obs = {k: torch.as_tensor(normalized_obs[k], dtype=torch.float32, device=device) for k in obs_keys} - _, _, _, next_values = agent(torch_obs) + next_values = agent.get_values(torch_obs) returns, advantages = gae( local_data["rewards"].to(torch.float64), local_data["values"], @@ -393,7 +401,15 @@ def trainer( world_collective.broadcast_object_list(agent_args, src=0) # Define the agent and the optimizer - agent = PPOAgent(**agent_args[0]) + agent, _ = build_agent( + fabric, + actions_dim=agent_args[0]["actions_dim"], + is_continuous=agent_args[0]["is_continuous"], + cfg=cfg, + obs_space=agent_args[0]["obs_space"], + agent_state=state["agent"] if cfg.checkpoint.resume_from else None, + ) + del _ optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters(), _convert_="all") # Load the state from the checkpoint @@ -402,7 +418,6 @@ def trainer( optimizer.load_state_dict(state["optimizer"]) # Setup agent and optimizer with Fabric - agent = fabric.setup_module(agent) optimizer = fabric.setup_optimizers(optimizer) # Send weights to rank-0, a.k.a. 
the player @@ -483,7 +498,9 @@ def trainer( ): # The Join context is needed because there can be the possibility # that some ranks receive less data - with Join([agent._forward_module]): + with Join( + [agent.feature_extractor._forward_module, agent.actor._forward_module, agent.critic._forward_module] + ): for _ in range(cfg.algo.update_epochs): for batch_idxes in sampler: batch = {k: data[k][batch_idxes] for k in data.keys()} @@ -651,6 +668,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ranks=list(range(1, world_collective.world_size)), timeout=timedelta(days=1) ) if global_rank == 0: - player(fabric, cfg, world_collective, player_trainer_collective) + player(fabric, world_collective, player_trainer_collective, cfg) else: trainer(world_collective, player_trainer_collective, optimization_pg, cfg) diff --git a/sheeprl/algos/ppo/utils.py b/sheeprl/algos/ppo/utils.py index aee52933..e3b55340 100644 --- a/sheeprl/algos/ppo/utils.py +++ b/sheeprl/algos/ppo/utils.py @@ -10,7 +10,7 @@ from lightning.fabric.wrappers import _FabricModule from torch import Tensor -from sheeprl.algos.ppo.agent import PPOAgent, build_agent +from sheeprl.algos.ppo.agent import PPOPlayer, build_agent from sheeprl.utils.env import make_env from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE from sheeprl.utils.utils import unwrap_fabric @@ -23,7 +23,7 @@ @torch.no_grad() -def test(agent: PPOAgent | _FabricModule, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): +def test(agent: PPOPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() agent.eval() done = False @@ -41,8 +41,8 @@ def test(agent: PPOAgent | _FabricModule, fabric: Fabric, cfg: Dict[str, Any], l while not done: # Act greedly through the environment - actions, _, _, _ = agent(obs, greedy=True) - if agent.is_continuous: + actions = agent.get_actions(obs, greedy=True) + if agent.actor.is_continuous: actions = torch.cat(actions, dim=-1) else: actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1) diff --git a/sheeprl/algos/ppo_recurrent/agent.py b/sheeprl/algos/ppo_recurrent/agent.py index 22afc5c8..4b3e9b19 100644 --- a/sheeprl/algos/ppo_recurrent/agent.py +++ b/sheeprl/algos/ppo_recurrent/agent.py @@ -151,7 +151,7 @@ def __init__( ) # Actor - self.actor_backbone = MLP( + actor_backbone = MLP( input_dims=self.rnn_hidden_size, output_dim=None, hidden_sizes=[actor_cfg.dense_units] * actor_cfg.mlp_layers, @@ -165,12 +165,10 @@ def __init__( ), ) if is_continuous: - self.actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, int(sum(actions_dim)) * 2)]) + actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, int(sum(actions_dim)) * 2)]) else: - self.actor_heads = nn.ModuleList( - [nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim] - ) - self.actor = PPOActor(self.actor_backbone, self.actor_heads, is_continuous) + actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim]) + self.actor = PPOActor(actor_backbone, actor_heads, is_continuous) # Initial recurrent states for both the actor and critic rnn self._initial_states: Tensor = self.reset_hidden_states() @@ -190,13 +188,118 @@ def reset_hidden_states(self) -> Tuple[Tensor, Tensor]: ) return states - def get_actions( - self, pre_dist: Tuple[Tensor, ...], actions: Optional[List[Tensor]] = None, greedy: bool = False + def _get_actions( + self, pre_dist: Tuple[Tensor, ...], actions: Optional[List[Tensor]] = None ) -> Tuple[Tuple[Tensor, 
...], Tensor, Tensor]: logprobs = [] entropies = [] sampled_actions = [] if self.is_continuous: + dist = Independent(Normal(*pre_dist), 1) + if actions is None: + sampled_actions.append(dist.sample()) + else: + sampled_actions.append(actions[0]) + entropies.append(dist.entropy()) + logprobs.append(dist.log_prob(actions)) + else: + for i, logits in enumerate(pre_dist): + dist = OneHotCategorical(logits=logits) + if actions is None: + sampled_actions.append(dist.sample()) + else: + sampled_actions.append(actions[i]) + entropies.append(dist.entropy()) + logprobs.append(dist.log_prob(sampled_actions[-1])) + return ( + tuple(sampled_actions), + torch.stack(logprobs, dim=-1).sum(dim=-1, keepdim=True), + torch.stack(entropies, dim=-1).sum(dim=-1, keepdim=True), + ) + + def _get_pre_dist(self, input: Tensor) -> Union[Tuple[Tensor, ...], Tuple[Tensor, Tensor]]: + pre_dist: List[Tensor] = self.actor(input) + if self.is_continuous: + mean, log_std = torch.chunk(pre_dist[0], chunks=2, dim=-1) + std = log_std.exp() + return (mean, std) + else: + return tuple(pre_dist) + + def _get_values(self, input: Tensor) -> Tensor: + return self.critic(input) + + def forward( + self, + obs: Dict[str, Tensor], + prev_actions: Tensor, + prev_states: Tuple[Tensor, Tensor], + actions: Optional[List[Tensor]] = None, + mask: Optional[Tensor] = None, + ) -> Tuple[Tuple[Tensor, ...], Tensor, Tensor, Tensor, Tuple[Tensor, Tensor]]: + """Compute actor logits and critic values. + + Args: + obs (Tensor): observations collected (possibly padded with zeros). + prev_actions (Tensor): the previous actions. + prev_states (Tuple[Tensor, Tensor]): the previous state of the LSTM. + actions (List[Tensor], optional): the actions from the replay buffer. + mask (Tensor, optional): the mask of the padded sequences. + + Returns: + actions (Tuple[Tensor, ...]): the sampled actions + logprobs (Tensor): the log probabilities of the actions w.r.t. their distributions. + entropies (Tensor): the entropies of the actions distributions. + values (Tensor): the state values. + states (Tuple[Tensor, Tensor]): the new recurrent states (hx, cx). 
+ """ + embedded_obs = self.feature_extractor(obs) + out, states = self.rnn(torch.cat((embedded_obs, prev_actions), dim=-1), prev_states, mask) + values = self._get_values(out) + pre_dist = self._get_pre_dist(out) + actions, logprobs, entropies = self._get_actions(pre_dist, actions) + return actions, logprobs, entropies, values, states + + +class RecurrentPPOPlayer(nn.Module): + def __init__( + self, + feature_extractor: MultiEncoder, + rnn: RecurrentModel, + actor: PPOActor, + critic: nn.Module, + rnn_hidden_size: int, + actions_dim: Sequence[int], + ) -> None: + super().__init__() + self.feature_extractor = feature_extractor + self.rnn = rnn + self.critic = critic + self.actor = actor + self.rnn_hidden_size = rnn_hidden_size + self.actions_dim = actions_dim + + @property + def initial_states(self) -> Tuple[Tensor, Tensor]: + return self._initial_states + + @initial_states.setter + def initial_states(self, value: Tuple[Tensor, Tensor]) -> None: + self._initial_states = value + + def reset_hidden_states(self) -> Tuple[Tensor, Tensor]: + states = ( + torch.zeros(1, self.num_envs, self.rnn_hidden_size, device=self.device), + torch.zeros(1, self.num_envs, self.rnn_hidden_size, device=self.device), + ) + return states + + def _get_actions( + self, pre_dist: Tuple[Tensor, ...], actions: Optional[List[Tensor]] = None, greedy: bool = False + ) -> Tuple[Tuple[Tensor, ...], Tensor]: + logprobs = [] + sampled_actions = [] + if self.actor.is_continuous: dist = Independent(Normal(*pre_dist), 1) if greedy: sampled_actions.append(dist.mode) @@ -205,7 +308,6 @@ def get_actions( sampled_actions.append(dist.sample()) else: sampled_actions.append(actions[0]) - entropies.append(dist.entropy()) logprobs.append(dist.log_prob(actions)) else: for i, logits in enumerate(pre_dist): @@ -217,24 +319,22 @@ def get_actions( sampled_actions.append(dist.sample()) else: sampled_actions.append(actions[i]) - entropies.append(dist.entropy()) logprobs.append(dist.log_prob(sampled_actions[-1])) return ( tuple(sampled_actions), torch.stack(logprobs, dim=-1).sum(dim=-1, keepdim=True), - torch.stack(entropies, dim=-1).sum(dim=-1, keepdim=True), ) - def get_pre_dist(self, input: Tensor) -> Union[Tuple[Tensor, ...], Tuple[Tensor, Tensor]]: + def _get_pre_dist(self, input: Tensor) -> Union[Tuple[Tensor, ...], Tuple[Tensor, Tensor]]: pre_dist: List[Tensor] = self.actor(input) - if self.is_continuous: + if self.actor.is_continuous: mean, log_std = torch.chunk(pre_dist[0], chunks=2, dim=-1) std = log_std.exp() return (mean, std) else: return tuple(pre_dist) - def get_values(self, input: Tensor) -> Tensor: + def _get_values(self, input: Tensor) -> Tensor: return self.critic(input) def forward( @@ -264,10 +364,48 @@ def forward( """ embedded_obs = self.feature_extractor(obs) out, states = self.rnn(torch.cat((embedded_obs, prev_actions), dim=-1), prev_states, mask) - values = self.get_values(out) - pre_dist = self.get_pre_dist(out) - actions, logprobs, entropies = self.get_actions(pre_dist, actions, greedy=greedy) - return actions, logprobs, entropies, values, states + values = self._get_values(out) + pre_dist = self._get_pre_dist(out) + actions, logprobs = self._get_actions(pre_dist, actions, greedy=greedy) + return actions, logprobs, values, states + + def get_values( + self, + obs: Dict[str, Tensor], + prev_actions: Tensor, + prev_states: Tuple[Tensor, Tensor], + mask: Optional[Tensor] = None, + ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + embedded_obs = self.feature_extractor(obs) + out, states = self.rnn(torch.cat((embedded_obs, 
prev_actions), dim=-1), prev_states, mask) + return self._get_values(out), states + + def get_actions( + self, + obs: Dict[str, Tensor], + prev_actions: Tensor, + prev_states: Tuple[Tensor, Tensor], + mask: Optional[Tensor] = None, + greedy: bool = False, + ) -> Tuple[Sequence[Tensor], Tuple[Tensor, Tensor]]: + embedded_obs = self.feature_extractor(obs) + out, states = self.rnn(torch.cat((embedded_obs, prev_actions), dim=-1), prev_states, mask) + pre_dist = self._get_pre_dist(out) + sampled_actions = [] + if self.actor.is_continuous: + dist = Independent(Normal(*pre_dist), 1) + if greedy: + sampled_actions.append(dist.mode) + else: + sampled_actions.append(dist.sample()) + else: + for logits in pre_dist: + dist = OneHotCategorical(logits=logits) + if greedy: + sampled_actions.append(dist.mode) + else: + sampled_actions.append(dist.sample()) + return tuple(sampled_actions), states def build_agent( @@ -277,7 +415,7 @@ def build_agent( cfg: Dict[str, Any], obs_space: gymnasium.spaces.Dict, agent_state: Optional[Dict[str, Tensor]] = None, -) -> Tuple[RecurrentPPOAgent, RecurrentPPOAgent]: +) -> Tuple[RecurrentPPOAgent, RecurrentPPOPlayer]: agent = RecurrentPPOAgent( actions_dim=actions_dim, obs_space=obs_space, @@ -295,7 +433,16 @@ def build_agent( ) if agent_state: agent.load_state_dict(agent_state) - player = copy.deepcopy(agent) + + # Setup player agent + player = RecurrentPPOPlayer( + copy.deepcopy(agent.feature_extractor), + copy.deepcopy(agent.rnn), + copy.deepcopy(agent.actor), + copy.deepcopy(agent.critic), + cfg.algo.rnn.lstm.hidden_size, + actions_dim, + ) # Setup training agent agent.feature_extractor = fabric.setup_module(agent.feature_extractor) @@ -311,7 +458,12 @@ def build_agent( player.actor = fabric_player.setup_module(player.actor) # Tie weights between the agent and the player - for agent_p, player_p in zip(agent.parameters(), player.parameters()): + for agent_p, player_p in zip(agent.feature_extractor.parameters(), player.feature_extractor.parameters()): + player_p.data = agent_p.data + for agent_p, player_p in zip(agent.rnn.parameters(), player.rnn.parameters()): + player_p.data = agent_p.data + for agent_p, player_p in zip(agent.actor.parameters(), player.actor.parameters()): + player_p.data = agent_p.data + for agent_p, player_p in zip(agent.critic.parameters(), player.critic.parameters()): player_p.data = agent_p.data - return agent, player diff --git a/sheeprl/algos/ppo_recurrent/evaluate.py b/sheeprl/algos/ppo_recurrent/evaluate.py index 12f57dba..740d5f19 100644 --- a/sheeprl/algos/ppo_recurrent/evaluate.py +++ b/sheeprl/algos/ppo_recurrent/evaluate.py @@ -49,5 +49,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) ) # Create the actor and critic models - agent, _ = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) + _, agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) + del _ test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index 4054261d..dd2a33a7 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -297,7 +297,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # [Seq_len, Batch_size, D] --> [1, num_envs, D] normalized_obs = normalize_obs(obs, cfg.algo.cnn_keys.encoder, obs_keys) torch_obs = {k: torch.as_tensor(v, 
device=device).float() for k, v in normalized_obs.items()} - actions, logprobs, _, values, states = player( + actions, logprobs, values, states = player( torch_obs, prev_actions=torch_prev_actions, prev_states=prev_states ) if is_continuous: @@ -329,12 +329,12 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): if k in cfg.algo.cnn_keys.encoder: torch_v = torch_v.view(1, 1, -1, *torch_v.shape[-2:]) / 255.0 - 0.5 real_next_obs[k][0, i] = torch_v - feat = player.feature_extractor(real_next_obs) - rnn_out, _ = player.rnn( - torch.cat((feat, torch_actions[:, truncated_envs, :]), dim=-1), + vals, _ = player.get_values( + real_next_obs, + torch_actions[:, truncated_envs, :], tuple(s[:, truncated_envs, ...] for s in states), ) - vals = player.get_values(rnn_out).view(rewards[truncated_envs].shape).cpu().numpy() + vals = vals.view(rewards[truncated_envs].shape).cpu().numpy() rewards[truncated_envs] += cfg.algo.gamma * vals.reshape(rewards[truncated_envs].shape) dones = np.logical_or(terminated, truncated).reshape(1, cfg.env.num_envs, -1).astype(np.float32) rewards = rewards.reshape(1, cfg.env.num_envs, -1).astype(np.float32) @@ -389,9 +389,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.inference_mode(): normalized_obs = normalize_obs(obs, cfg.algo.cnn_keys.encoder, obs_keys) torch_obs = {k: torch.as_tensor(v, device=device).float() for k, v in normalized_obs.items()} - feat = player.feature_extractor(torch_obs) - rnn_out, _ = player.rnn(torch.cat((feat, torch_actions), dim=-1), states) - next_values = player.get_values(rnn_out) + next_values, _ = player.get_values(torch_obs, torch_actions, states) returns, advantages = gae( local_data["rewards"].to(torch.float64), local_data["values"], diff --git a/sheeprl/algos/ppo_recurrent/utils.py b/sheeprl/algos/ppo_recurrent/utils.py index 64f388f8..4e643d6d 100644 --- a/sheeprl/algos/ppo_recurrent/utils.py +++ b/sheeprl/algos/ppo_recurrent/utils.py @@ -8,7 +8,7 @@ from sheeprl.algos.ppo.utils import AGGREGATOR_KEYS as ppo_aggregator_keys from sheeprl.algos.ppo.utils import MODELS_TO_REGISTER as ppo_models_to_register -from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent, build_agent +from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOPlayer, build_agent from sheeprl.utils.env import make_env from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE from sheeprl.utils.utils import unwrap_fabric @@ -22,7 +22,7 @@ @torch.no_grad() -def test(agent: "RecurrentPPOAgent", fabric: Fabric, cfg: Dict[str, Any], log_dir: str): +def test(agent: "RecurrentPPOPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() agent.eval() done = False @@ -47,8 +47,8 @@ def test(agent: "RecurrentPPOAgent", fabric: Fabric, cfg: Dict[str, Any], log_di actions = torch.zeros(1, 1, sum(agent.actions_dim), device=fabric.device) while not done: # Act greedly through the environment - actions, _, _, _, state = agent(next_obs, actions, state, greedy=True) - if agent.is_continuous: + actions, state = agent.get_actions(next_obs, actions, state, greedy=True) + if agent.actor.is_continuous: real_actions = torch.cat(actions, -1) actions = torch.cat(actions, dim=-1).view(1, 1, -1) else: diff --git a/sheeprl/algos/sac/agent.py b/sheeprl/algos/sac/agent.py index 63b410a9..0b77738b 100644 --- a/sheeprl/algos/sac/agent.py +++ b/sheeprl/algos/sac/agent.py @@ -89,7 +89,7 @@ def __init__( self.register_buffer("action_scale", torch.tensor((action_high - action_low) / 2.0, dtype=torch.float32)) 
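For the recurrent variant, a similar sketch (again, not part of the patch) of how the `RecurrentPPOPlayer.get_values` method shown above folds the feature extraction, the RNN step and the critic into a single call when bootstrapping truncated episodes; all arguments are assumed to be the tensors built in `ppo_recurrent.py`:

```python
import torch


@torch.inference_mode()
def bootstrap_recurrent(player, real_next_obs, torch_actions, states, rewards, truncated_envs, gamma):
    # get_values runs feature_extractor -> rnn -> critic internally and also
    # returns the updated recurrent states, which are not needed here.
    vals, _ = player.get_values(
        real_next_obs,
        torch_actions[:, truncated_envs, :],
        tuple(s[:, truncated_envs, ...] for s in states),
    )
    vals = vals.view(rewards[truncated_envs].shape).cpu().numpy()
    rewards[truncated_envs] += gamma * vals.reshape(rewards[truncated_envs].shape)
    return rewards
```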
self.register_buffer("action_bias", torch.tensor((action_high + action_low) / 2.0, dtype=torch.float32)) - def forward(self, obs: Tensor, greedy: bool = False) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def forward(self, obs: Tensor) -> Tuple[Tensor, Tensor]: """Given an observation, it returns a tanh-squashed sampled action (correctly rescaled to the environment action bounds) and its log-prob (as defined in Eq. 26 of https://arxiv.org/abs/1812.05905) @@ -103,14 +103,11 @@ def forward(self, obs: Tensor, greedy: bool = False) -> Union[Tensor, Tuple[Tens """ x = self.model(obs) mean = self.fc_mean(x) - if greedy: - return torch.tanh(mean) * self.action_scale + self.action_bias - else: - log_std = self.fc_logstd(x) - std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX).exp() - return self.get_actions_and_log_probs(mean, std) + log_std = self.fc_logstd(x) + std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX).exp() + return self._get_actions_and_log_probs(mean, std) - def get_actions_and_log_probs(self, mean: Tensor, std: Tensor): + def _get_actions_and_log_probs(self, mean: Tensor, std: Tensor) -> Tuple[Tensor, Tensor]: """Given the mean and the std of a Normal distribution, it returns a tanh-squashed sampled action (correctly rescaled to the environment action bounds) and its log-prob (as defined in Eq. 26 of https://arxiv.org/abs/1812.05905) @@ -248,9 +245,6 @@ def log_alpha(self) -> Tensor: def get_actions_and_log_probs(self, obs: Tensor) -> Tuple[Tensor, Tensor]: return self.actor(obs) - def get_greedy_actions(self, obs: Tensor) -> Tensor: - return self.actor.get_greedy_actions(obs) - def get_q_values(self, obs: Tensor, action: Tensor) -> Tensor: return torch.cat([self.qfs[i](obs, action) for i in range(len(self.qfs))], dim=-1) @@ -273,13 +267,60 @@ def qfs_target_ema(self) -> None: target_param.data.copy_(self._tau * param.data + (1 - self._tau) * target_param.data) +class SACPlayer(nn.Module): + def __init__( + self, + feature_extractor: nn.Module, + fc_mean: nn.Module, + fc_logstd: nn.Module, + action_low: Union[SupportsFloat, NDArray] = -1.0, + action_high: Union[SupportsFloat, NDArray] = 1.0, + ): + super().__init__() + self.model = feature_extractor + self.fc_mean = fc_mean + self.fc_logstd = fc_logstd + + # Action rescaling buffers + self.register_buffer("action_scale", torch.tensor((action_high - action_low) / 2.0, dtype=torch.float32)) + self.register_buffer("action_bias", torch.tensor((action_high + action_low) / 2.0, dtype=torch.float32)) + + def forward(self, obs: Tensor, greedy: bool = False) -> Tensor: + """Given an observation, it returns a tanh-squashed + sampled action (correctly rescaled to the environment action bounds) and its + log-prob (as defined in Eq. 
26 of https://arxiv.org/abs/1812.05905) + + Args: + obs (Tensor): the observation tensor + + Returns: + tanh-squashed action, rescaled to the environment action bounds + action log-prob + """ + x = self.model(obs) + mean = self.fc_mean(x) + if greedy: + return torch.tanh(mean) * self.action_scale + self.action_bias + else: + log_std = self.fc_logstd(x) + std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX).exp() + normal = torch.distributions.Normal(mean, std) + x_t = normal.rsample() + y_t = torch.tanh(x_t) + actions = y_t * self.action_scale + self.action_bias + return actions + + def get_actions(self, obs: Tensor, greedy: bool = False) -> Tensor: + return self(obs, greedy=greedy) + + def build_agent( fabric: Fabric, cfg: Dict[str, Any], obs_space: gymnasium.spaces.Dict, action_space: gymnasium.spaces.Box, agent_state: Optional[Dict[str, Tensor]] = None, -) -> SACAgent: +) -> Tuple[SACAgent, SACPlayer]: act_dim = prod(action_space.shape) obs_dim = sum([prod(obs_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) actor = SACActor( @@ -298,6 +339,17 @@ def build_agent( agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) if agent_state: agent.load_state_dict(agent_state) + + # Setup player agent + player = SACPlayer( + copy.deepcopy(agent.actor.model), + copy.deepcopy(agent.actor.fc_mean), + copy.deepcopy(agent.actor.fc_logstd), + action_low=action_space.low, + action_high=action_space.high, + ) + + # Setup training agent agent.actor = fabric.setup_module(agent.actor) agent.critics = [fabric.setup_module(critic) for critic in agent.critics] @@ -306,4 +358,14 @@ def build_agent( fabric_player = get_single_device_fabric(fabric) agent.qfs_target = nn.ModuleList([fabric_player.setup_module(target) for target in agent.qfs_target]) - return agent + # Setup player agent + player.model = fabric_player.setup_module(player.model) + player.fc_mean = fabric_player.setup_module(player.fc_mean) + player.fc_logstd = fabric_player.setup_module(player.fc_logstd) + player.action_scale = player.action_scale.to(fabric_player.device) + player.action_bias = player.action_bias.to(fabric_player.device) + + # Tie weights between the agent and the player + for agent_p, player_p in zip(agent.actor.parameters(), player.parameters()): + player_p.data = agent_p.data + return agent, player diff --git a/sheeprl/algos/sac/evaluate.py b/sheeprl/algos/sac/evaluate.py index 11f70741..2efb1c36 100644 --- a/sheeprl/algos/sac/evaluate.py +++ b/sheeprl/algos/sac/evaluate.py @@ -45,5 +45,6 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): ) fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) - agent = build_agent(fabric, cfg, observation_space, action_space, state["agent"]) - test(agent.actor, fabric, cfg, log_dir) + _, agent = build_agent(fabric, cfg, observation_space, action_space, state["agent"]) + del _ + test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index ec662cd2..4560c28d 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -22,7 +22,6 @@ from sheeprl.algos.sac.utils import test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env -from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -144,18 +143,26 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): 
fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) # Define the agent and the optimizer and setup sthem with Fabric - agent = build_agent( + agent, player = build_agent( fabric, cfg, observation_space, action_space, state["agent"] if cfg.checkpoint.resume_from else None ) - fabric_player = get_single_device_fabric(fabric) - actor = fabric_player.setup_module(agent.actor.module) # Optimizers - qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters(), _convert_="all") + qf_optimizer = hydra.utils.instantiate( + cfg.algo.critic.optimizer, + params=agent.qfs.parameters(), + _convert_="all", + ) actor_optimizer = hydra.utils.instantiate( - cfg.algo.actor.optimizer, params=agent.actor.parameters(), _convert_="all" + cfg.algo.actor.optimizer, + params=agent.actor.parameters(), + _convert_="all", + ) + alpha_optimizer = hydra.utils.instantiate( + cfg.algo.alpha.optimizer, + params=[agent.log_alpha], + _convert_="all", ) - alpha_optimizer = hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha], _convert_="all") if cfg.checkpoint.resume_from: qf_optimizer.load_state_dict(state["qf_optimizer"]) actor_optimizer.load_state_dict(state["actor_optimizer"]) @@ -251,7 +258,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Sample an action given the observation received by the environment with torch.inference_mode(): torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device) - actions, _ = actor(torch_obs) + actions = player(torch_obs) actions = actions.cpu().numpy() next_obs, rewards, terminated, truncated, infos = envs.step(actions) next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) @@ -406,7 +413,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(actor, fabric, cfg, log_dir) + test(player, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.sac.utils import log_models diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index ee0a0ceb..b76337db 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -16,7 +16,7 @@ from torch.utils.data.sampler import BatchSampler from torchmetrics import SumMetric -from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic, build_agent +from sheeprl.algos.sac.agent import SACAgent, SACCritic, build_agent from sheeprl.algos.sac.sac import train from sheeprl.algos.sac.utils import test from sheeprl.data.buffers import ReplayBuffer @@ -31,15 +31,17 @@ @torch.inference_mode() def player( - fabric: Fabric, cfg: Dict[str, Any], world_collective: TorchCollective, player_trainer_collective: TorchCollective + fabric: Fabric, world_collective: TorchCollective, player_trainer_collective: TorchCollective, cfg: Dict[str, Any] ): - log_dir = get_log_dir(fabric, cfg.root_dir, cfg.run_name, False) + # Initialize Fabric player-only + fabric_player = get_single_device_fabric(fabric) + log_dir = get_log_dir(fabric_player, cfg.root_dir, cfg.run_name, False) + device = fabric_player.device rank = fabric.global_rank - device = fabric.device # Resume from checkpoint if cfg.checkpoint.resume_from: - state = fabric.load(cfg.checkpoint.resume_from) + state = fabric_player.load(cfg.checkpoint.resume_from) if len(cfg.algo.cnn_keys.encoder) > 0: warnings.warn("SAC algorithm cannot allow to use images as observations, the CNN keys will be ignored") @@ -89,16 +91,13 @@ def player( # Define the 
agent and the optimizer and setup them with Fabric act_dim = prod(action_space.shape) obs_dim = sum([prod(observation_space[k].shape) for k in cfg.algo.mlp_keys.encoder]) - actor = SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=action_space.low, - action_high=action_space.high, - ).to(device) - fabric_player = get_single_device_fabric(fabric) - actor = fabric_player.setup_module(actor, move_to_device=False) + _, actor = build_agent( + fabric_player, + cfg, + observation_space, + action_space, + state["agent"] if cfg.checkpoint.resume_from else None, + ) flattened_parameters = torch.empty_like( torch.nn.utils.convert_parameters.parameters_to_vector(actor.parameters()), device=device ) @@ -188,7 +187,7 @@ def player( else: # Sample an action given the observation received by the environment torch_obs = torch.as_tensor(obs, dtype=torch.float32, device=device) - actions, _ = actor(torch_obs) + actions = actor(torch_obs) actions = actions.cpu().numpy() next_obs, rewards, terminated, truncated, infos = envs.step(actions) next_obs = np.concatenate([next_obs[k] for k in cfg.algo.mlp_keys.encoder], axis=-1) @@ -383,7 +382,7 @@ def trainer( assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" # Define the agent and the optimizer and setup them with Fabric - agent = build_agent( + agent, _ = build_agent( fabric, cfg, envs.single_observation_space, @@ -579,6 +578,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ranks=list(range(1, world_collective.world_size)), timeout=timedelta(days=1) ) if global_rank == 0: - player(fabric, cfg, world_collective, player_trainer_collective) + player(fabric, world_collective, player_trainer_collective, cfg) else: trainer(world_collective, player_trainer_collective, optimization_pg, cfg) diff --git a/sheeprl/algos/sac/utils.py b/sheeprl/algos/sac/utils.py index 3fe14d31..ae624cf1 100644 --- a/sheeprl/algos/sac/utils.py +++ b/sheeprl/algos/sac/utils.py @@ -8,7 +8,7 @@ from lightning import Fabric from lightning.fabric.wrappers import _FabricModule -from sheeprl.algos.sac.agent import SACActor, build_agent +from sheeprl.algos.sac.agent import SACPlayer, build_agent from sheeprl.utils.env import make_env from sheeprl.utils.imports import _IS_MLFLOW_AVAILABLE from sheeprl.utils.utils import unwrap_fabric @@ -27,7 +27,7 @@ @torch.no_grad() -def test(actor: SACActor, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): +def test(actor: SACPlayer, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, None, 0, log_dir, "test", vector_env_idx=0)() actor.eval() done = False @@ -41,7 +41,7 @@ def test(actor: SACActor, fabric: Fabric, cfg: Dict[str, Any], log_dir: str): ) # [N_envs, N_obs] while not done: # Act greedly through the environment - action = actor(next_obs, greedy=True) + action = actor.get_actions(next_obs, greedy=True) # Single environment step next_obs, reward, done, truncated, info = env.step(action.cpu().numpy().reshape(env.action_space.shape)) diff --git a/sheeprl/algos/sac_ae/agent.py b/sheeprl/algos/sac_ae/agent.py index d33ceb56..5b801cda 100644 --- a/sheeprl/algos/sac_ae/agent.py +++ b/sheeprl/algos/sac_ae/agent.py @@ -261,9 +261,7 @@ def __init__( # Orthogonal init self.apply(weight_init) - def forward( - self, obs: Tensor, detach_encoder_features: bool = False, greedy: bool = False - ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + def forward(self, obs: Tensor, detach_encoder_features: 
bool = False) -> Tuple[Tensor, Tensor]: """Given an observation, it returns a tanh-squashed sampled action (correctly rescaled to the environment action bounds) and its log-prob (as defined in Eq. 26 of https://arxiv.org/abs/1812.05905) @@ -278,16 +276,13 @@ def forward( features = self.encoder(obs, detach_encoder_features=detach_encoder_features) x = self.model(features) mean = self.fc_mean(x) - if greedy: - return torch.tanh(mean) * self.action_scale + self.action_bias - else: - log_std = self.fc_logstd(x) - # log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) - log_std = torch.tanh(log_std) - log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1) - return self.get_actions_and_log_probs(mean, log_std.exp()) + log_std = self.fc_logstd(x) + # log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) + log_std = torch.tanh(log_std) + log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1) + return self.get_actions_and_log_probs(mean, log_std.exp()) - def get_actions_and_log_probs(self, mean: Tensor, std: Tensor): + def get_actions_and_log_probs(self, mean: Tensor, std: Tensor) -> Tuple[Tensor, Tensor]: """Given the mean and the std of a Normal distribution, it returns a tanh-squashed sampled action (correctly rescaled to the environment action bounds) and its log-prob (as defined in Eq. 26 of https://arxiv.org/abs/1812.05905) @@ -425,9 +420,6 @@ def log_alpha(self) -> Tensor: def get_actions_and_log_probs(self, obs: Tensor, detach_encoder_features: bool = False) -> Tuple[Tensor, Tensor]: return self.actor(obs, detach_encoder_features) - def get_greedy_actions(self, obs: Tensor) -> Tensor: - return self.actor.get_greedy_actions(obs) - def get_q_values(self, obs: Tensor, action: Tensor, detach_encoder_features: bool = False) -> Tensor: return self.critic(obs, action, detach_encoder_features) @@ -457,6 +449,58 @@ def critic_encoder_target_ema(self) -> None: target_param.data.copy_(self._encoder_tau * param.data + (1 - self._encoder_tau) * target_param.data) +class SACAEPlayer(nn.Module): + def __init__( + self, + feature_extractor: MultiEncoder, + fc: nn.Module, + fc_mean: nn.Module, + fc_logstd: nn.Module, + action_low: Union[SupportsFloat, NDArray] = -1.0, + action_high: Union[SupportsFloat, NDArray] = 1.0, + ): + super().__init__() + self.encoder = feature_extractor + self.model = fc + self.fc_mean = fc_mean + self.fc_logstd = fc_logstd + + # Action rescaling buffers + self.register_buffer("action_scale", torch.tensor((action_high - action_low) / 2.0, dtype=torch.float32)) + self.register_buffer("action_bias", torch.tensor((action_high + action_low) / 2.0, dtype=torch.float32)) + + def forward(self, obs: Tensor, greedy: bool = False) -> Tensor: + """Given an observation, it returns a tanh-squashed + sampled action (correctly rescaled to the environment action bounds) and its + log-prob (as defined in Eq. 
26 of https://arxiv.org/abs/1812.05905) + + Args: + obs (Tensor): the observation tensor + + Returns: + tanh-squashed action, rescaled to the environment action bounds + action log-prob + """ + features = self.encoder(obs, detach_encoder_features=False) + x = self.model(features) + mean = self.fc_mean(x) + if greedy: + return torch.tanh(mean) * self.action_scale + self.action_bias + else: + log_std = self.fc_logstd(x) + # log_std = torch.clamp(log_std, LOG_STD_MIN, LOG_STD_MAX) + log_std = torch.tanh(log_std) + log_std = LOG_STD_MIN + 0.5 * (LOG_STD_MAX - LOG_STD_MIN) * (log_std + 1) + normal = torch.distributions.Normal(mean, log_std.exp()) + x_t = normal.rsample() + y_t = torch.tanh(x_t) + actions = y_t * self.action_scale + self.action_bias + return actions + + def get_actions(self, obs: Tensor, greedy: bool = False) -> Tensor: + return self(obs, greedy) + + def build_agent( fabric: Fabric, cfg: Dict[str, Any], @@ -465,7 +509,7 @@ def build_agent( agent_state: Optional[Dict[str, Tensor]] = None, encoder_state: Optional[Dict[str, Tensor]] = None, decoder_sate: Optional[Dict[str, Tensor]] = None, -) -> Tuple[SACAEAgent, _FabricModule, _FabricModule]: +) -> Tuple[SACAEAgent, _FabricModule, _FabricModule, SACAEPlayer]: act_dim = prod(action_space.shape) target_entropy = -act_dim @@ -561,6 +605,16 @@ def build_agent( if agent_state: agent.load_state_dict(agent_state) + # Setup player agent + player = SACAEPlayer( + copy.deepcopy(agent.actor.encoder), + copy.deepcopy(agent.actor.model), + copy.deepcopy(agent.actor.fc_mean), + copy.deepcopy(agent.actor.fc_logstd), + action_low=action_space.low, + action_high=action_space.high, + ) + encoder = fabric.setup_module(encoder) decoder = fabric.setup_module(decoder) agent.actor = fabric.setup_module(agent.actor) @@ -571,4 +625,15 @@ def build_agent( fabric_player = get_single_device_fabric(fabric) agent.critic_target = fabric_player.setup_module(agent.critic_target) - return agent, encoder, decoder + # Setup player agent + player.encoder = fabric_player.setup_module(player.encoder) + player.model = fabric_player.setup_module(player.model) + player.fc_mean = fabric_player.setup_module(player.fc_mean) + player.fc_logstd = fabric_player.setup_module(player.fc_logstd) + player.action_scale = player.action_scale.to(fabric_player.device) + player.action_bias = player.action_bias.to(fabric_player.device) + + # Tie weights between the agent and the player + for agent_p, player_p in zip(agent.actor.parameters(), player.parameters()): + player_p.data = agent_p.data + return agent, encoder, decoder, player diff --git a/sheeprl/algos/sac_ae/evaluate.py b/sheeprl/algos/sac_ae/evaluate.py index df3c4314..27a57514 100644 --- a/sheeprl/algos/sac_ae/evaluate.py +++ b/sheeprl/algos/sac_ae/evaluate.py @@ -5,7 +5,7 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.sac_ae.agent import SACAEAgent, build_agent +from sheeprl.algos.sac_ae.agent import build_agent from sheeprl.algos.sac_ae.utils import test from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -38,8 +38,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): fabric.print("Encoder CNN keys:", cfg.algo.cnn_keys.encoder) fabric.print("Encoder MLP keys:", cfg.algo.mlp_keys.encoder) - agent: SACAEAgent - agent, _, _ = build_agent( + _, _, _, agent = build_agent( fabric, cfg, observation_space, action_space, state["agent"], state["encoder"], state["decoder"] ) - test(agent.actor, fabric, cfg, log_dir) + test(agent, fabric, 
cfg, log_dir) diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 86ae8c8f..183b5297 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -2,7 +2,6 @@ import copy import os -import time import warnings from typing import Any, Dict, Optional, Union @@ -26,7 +25,6 @@ from sheeprl.data.buffers import ReplayBuffer from sheeprl.models.models import MultiDecoder, MultiEncoder from sheeprl.utils.env import make_env -from sheeprl.utils.fabric import get_single_device_fabric from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm @@ -195,7 +193,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): obs_keys = cfg.algo.cnn_keys.encoder + cfg.algo.mlp_keys.encoder # Define the agent and the optimizer and setup them with Fabric - agent, encoder, decoder = build_agent( + agent, encoder, decoder, player = build_agent( fabric, cfg, observation_space, @@ -204,20 +202,32 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["encoder"] if cfg.checkpoint.resume_from else None, state["decoder"] if cfg.checkpoint.resume_from else None, ) - fabric_player = get_single_device_fabric(fabric) - actor = fabric_player.setup_module(agent.actor.module) # Optimizers - qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.critic.parameters(), _convert_="all") + qf_optimizer = hydra.utils.instantiate( + cfg.algo.critic.optimizer, + params=agent.critic.parameters(), + _convert_="all", + ) actor_optimizer = hydra.utils.instantiate( - cfg.algo.actor.optimizer, params=agent.actor.parameters(), _convert_="all" + cfg.algo.actor.optimizer, + params=agent.actor.parameters(), + _convert_="all", + ) + alpha_optimizer = hydra.utils.instantiate( + cfg.algo.alpha.optimizer, + params=[agent.log_alpha], + _convert_="all", ) - alpha_optimizer = hydra.utils.instantiate(cfg.algo.alpha.optimizer, params=[agent.log_alpha], _convert_="all") encoder_optimizer = hydra.utils.instantiate( - cfg.algo.encoder.optimizer, params=encoder.parameters(), _convert_="all" + cfg.algo.encoder.optimizer, + params=encoder.parameters(), + _convert_="all", ) decoder_optimizer = hydra.utils.instantiate( - cfg.algo.decoder.optimizer, params=decoder.parameters(), _convert_="all" + cfg.algo.decoder.optimizer, + params=decoder.parameters(), + _convert_="all", ) if cfg.checkpoint.resume_from: @@ -271,7 +281,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): policy_step = state["update"] * cfg.env.num_envs if cfg.checkpoint.resume_from else 0 last_log = state["last_log"] if cfg.checkpoint.resume_from else 0 last_checkpoint = state["last_checkpoint"] if cfg.checkpoint.resume_from else 0 - time.time() policy_steps_per_update = int(cfg.env.num_envs * fabric.world_size) num_updates = int(cfg.algo.total_steps // policy_steps_per_update) if not cfg.dry_run else 1 learning_starts = cfg.algo.learning_starts // policy_steps_per_update if not cfg.dry_run else 0 @@ -322,8 +331,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): with torch.inference_mode(): normalized_obs = {k: v / 255 if k in cfg.algo.cnn_keys.encoder else v for k, v in obs.items()} torch_obs = {k: torch.from_numpy(v).to(device).float() for k, v in normalized_obs.items()} - actions, _ = actor(torch_obs) - actions = actions.cpu().numpy() + actions = player(torch_obs).cpu().numpy() next_obs, rewards, terminated, truncated, infos = envs.step(actions.reshape(envs.action_space.shape)) if cfg.metric.log_level > 0 and "final_info" 
in infos: @@ -483,7 +491,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): envs.close() if fabric.is_global_zero and cfg.algo.run_test: - test(actor, fabric, cfg, log_dir) + test(player, fabric, cfg, log_dir) if not cfg.model_manager.disabled and fabric.is_global_zero: from sheeprl.algos.sac_ae.utils import log_models diff --git a/sheeprl/algos/sac_ae/utils.py b/sheeprl/algos/sac_ae/utils.py index 2a9e542b..07891f02 100644 --- a/sheeprl/algos/sac_ae/utils.py +++ b/sheeprl/algos/sac_ae/utils.py @@ -18,14 +18,14 @@ if TYPE_CHECKING: from mlflow.models.model import ModelInfo - from sheeprl.algos.sac_ae.agent import SACAEContinuousActor + from sheeprl.algos.sac_ae.agent import SACAEPlayer AGGREGATOR_KEYS = AGGREGATOR_KEYS.union({"Loss/reconstruction_loss"}) MODELS_TO_REGISTER = {"agent", "encoder", "decoder"} @torch.no_grad() -def test(actor: "SACAEContinuousActor", fabric: Fabric, cfg: Dict[str, Any], log_dir: str): +def test(actor: "SACAEPlayer", fabric: Fabric, cfg: Dict[str, Any], log_dir: str): env = make_env(cfg, cfg.seed, 0, log_dir, "test", vector_env_idx=0)() cnn_keys = actor.encoder.cnn_keys mlp_keys = actor.encoder.mlp_keys @@ -45,7 +45,7 @@ def test(actor: "SACAEContinuousActor", fabric: Fabric, cfg: Dict[str, Any], log while not done: # Act greedly through the environment - action = actor(next_obs, greedy=True) + action = actor.get_actions(next_obs, greedy=True) # Single environment step o, reward, done, truncated, _ = env.step(action.cpu().numpy().reshape(env.action_space.shape)) diff --git a/sheeprl/configs/algo/ppo_recurrent.yaml b/sheeprl/configs/algo/ppo_recurrent.yaml index 24384415..d4b848f1 100644 --- a/sheeprl/configs/algo/ppo_recurrent.yaml +++ b/sheeprl/configs/algo/ppo_recurrent.yaml @@ -7,22 +7,22 @@ name: ppo_recurrent vf_coef: 0.2 clip_coef: 0.2 ent_coef: 0.001 -clip_vloss: True +clip_vloss: False anneal_lr: False max_grad_norm: 0.5 -anneal_ent_coef: True -normalize_advantages: True +anneal_ent_coef: False +normalize_advantages: False reset_recurrent_state_on_done: True per_rank_sequence_length: ??? # Model related parameters mlp_layers: 1 layer_norm: True -dense_units: 256 +dense_units: 64 dense_act: torch.nn.ReLU rnn: lstm: - hidden_size: 128 + hidden_size: 64 pre_rnn_mlp: bias: True apply: False @@ -36,7 +36,7 @@ rnn: layer_norm: ${algo.layer_norm} dense_units: ${algo.rnn.lstm.hidden_size} encoder: - dense_units: 128 + dense_units: 64 # Optimizer related parameters optimizer: diff --git a/sheeprl/utils/utils.py b/sheeprl/utils/utils.py index 6d38087f..a66f5d84 100644 --- a/sheeprl/utils/utils.py +++ b/sheeprl/utils/utils.py @@ -238,6 +238,14 @@ def print_config( def unwrap_fabric(model: _FabricModule | nn.Module) -> nn.Module: + """Recursively unwrap the model from _FabricModule. This method returns a deep copy of the model. + + Args: + model (_FabricModule | nn.Module): the model to unwrap. + + Returns: + nn.Module: the unwrapped model. + """ model = copy.deepcopy(model) if isinstance(model, _FabricModule): model = model.module
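Finally, a small usage sketch (hypothetical, not part of the patch) of the `unwrap_fabric` helper documented above: because it deep-copies the model before unwrapping, the returned module can be modified or exported without touching the Fabric-wrapped training model. The helper name `export_for_inference` is made up for illustration only:

```python
import torch.nn as nn

from sheeprl.utils.utils import unwrap_fabric


def export_for_inference(wrapped) -> nn.Module:
    # `wrapped` is assumed to be a _FabricModule (or an nn.Module containing
    # Fabric-wrapped children), e.g. the training agent returned by build_agent.
    plain_model = unwrap_fabric(wrapped)  # deep copy, stripped of Fabric wrappers
    plain_model.eval()
    return plain_model
```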