diff --git a/README.md b/README.md index 4694f5a4..ae2731c0 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,6 @@ - - @@ -58,7 +56,7 @@ DOA++(w/o optimizations)1 7M 18d 22h - 2726/332832 + 2726/33282 N.A. 1-3080 diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py index db9f8db8..5aa56849 100644 --- a/sheeprl/algos/dreamer_v1/agent.py +++ b/sheeprl/algos/dreamer_v1/agent.py @@ -324,7 +324,7 @@ def get_greedy_action( return actions -def build_models( +def build_agent( fabric: Fabric, actions_dim: Sequence[int], is_continuous: bool, diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py index 0a9e8d6a..48454131 100644 --- a/sheeprl/algos/dreamer_v1/dreamer_v1.py +++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py @@ -20,7 +20,7 @@ from torch.utils.data import BatchSampler from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_models +from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_agent from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss from sheeprl.algos.dreamer_v1.utils import compute_lambda_values from sheeprl.algos.dreamer_v2.utils import test @@ -444,7 +444,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -477,7 +477,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder - world_model, actor, critic = build_models( + world_model, actor, critic = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py index c2dfb450..5574cbec 100644 --- a/sheeprl/algos/dreamer_v1/evaluate.py +++ b/sheeprl/algos/dreamer_v1/evaluate.py @@ -5,7 +5,7 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.dreamer_v1.agent import PlayerDV1, build_models +from sheeprl.algos.dreamer_v1.agent import PlayerDV1, build_agent from sheeprl.algos.dreamer_v2.utils import test from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -43,11 +43,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor, _ = build_models( + world_model, actor, _ = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py index cb01db30..ededb565 100644 --- a/sheeprl/algos/dreamer_v2/agent.py +++ b/sheeprl/algos/dreamer_v2/agent.py @@ -862,7 +862,7 @@ def get_greedy_action( return actions -def build_models( +def build_agent( fabric: Fabric, actions_dim: Sequence[int], is_continuous: bool, diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py index 8f6b7b54..850e9776 100644 --- 
a/sheeprl/algos/dreamer_v2/dreamer_v2.py +++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py @@ -25,7 +25,7 @@ from torch.utils.data import BatchSampler from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_models +from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_agent from sheeprl.algos.dreamer_v2.loss import reconstruction_loss from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer @@ -468,7 +468,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -501,7 +501,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder - world_model, actor, critic, target_critic = build_models( + world_model, actor, critic, target_critic = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py index d7d37a6a..ec2e0738 100644 --- a/sheeprl/algos/dreamer_v2/evaluate.py +++ b/sheeprl/algos/dreamer_v2/evaluate.py @@ -5,7 +5,7 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.dreamer_v2.agent import PlayerDV2, build_models +from sheeprl.algos.dreamer_v2.agent import PlayerDV2, build_agent from sheeprl.algos.dreamer_v2.utils import test from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -43,11 +43,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor, _, _ = build_models( + world_model, actor, _, _ = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py index 0d712b2e..1d7ac574 100644 --- a/sheeprl/algos/dreamer_v3/agent.py +++ b/sheeprl/algos/dreamer_v3/agent.py @@ -897,7 +897,7 @@ def add_exploration_noise( return tuple(expl_actions) -def build_models( +def build_agent( fabric: Fabric, actions_dim: Sequence[int], is_continuous: bool, diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py index c3588d66..0b360e8f 100644 --- a/sheeprl/algos/dreamer_v3/dreamer_v3.py +++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py @@ -25,7 +25,7 @@ from torch.utils.data import BatchSampler from torchmetrics import SumMetric -from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_models +from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_agent from sheeprl.algos.dreamer_v3.loss import reconstruction_loss from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test from sheeprl.data.buffers import AsyncReplayBuffer @@ -402,7 +402,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) 
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -435,7 +435,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder - world_model, actor, critic, target_critic = build_models( + world_model, actor, critic, target_critic = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py index 5b1a0f38..ab5ab68f 100644 --- a/sheeprl/algos/dreamer_v3/evaluate.py +++ b/sheeprl/algos/dreamer_v3/evaluate.py @@ -5,7 +5,7 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_models +from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_agent from sheeprl.algos.dreamer_v3.utils import test from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -43,11 +43,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor, _, _ = build_models( + world_model, actor, _, _ = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/droq/agent.py b/sheeprl/algos/droq/agent.py index a3c88b12..cd206e22 100644 --- a/sheeprl/algos/droq/agent.py +++ b/sheeprl/algos/droq/agent.py @@ -1,8 +1,11 @@ import copy -from typing import Sequence, Tuple, Union +from math import prod +from typing import Any, Dict, Optional, Sequence, Tuple, Union +import gymnasium import torch import torch.nn as nn +from lightning import Fabric from lightning.fabric.wrappers import _FabricModule from torch import Tensor @@ -198,3 +201,41 @@ def qfs_target_ema(self, critic_idx: int) -> None: self.qfs_unwrapped[critic_idx].parameters(), self.qfs_target[critic_idx].parameters() ): target_param.data.copy_(self._tau * param.data + (1 - self._tau) * target_param.data) + + +def build_agent( + fabric: Fabric, + cfg: Dict[str, Any], + obs_space: gymnasium.spaces.Dict, + action_space: gymnasium.spaces.Box, + agent_state: Optional[Dict[str, Tensor]] = None, +) -> DROQAgent: + act_dim = prod(action_space.shape) + obs_dim = sum([prod(obs_space[k].shape) for k in cfg.mlp_keys.encoder]) + actor = SACActor( + observation_dim=obs_dim, + action_dim=act_dim, + distribution_cfg=cfg.distribution, + hidden_size=cfg.algo.actor.hidden_size, + action_low=action_space.low, + action_high=action_space.high, + ) + critics = [ + DROQCritic( + observation_dim=obs_dim + act_dim, + hidden_size=cfg.algo.critic.hidden_size, + num_critics=1, + dropout=cfg.algo.critic.dropout, + ) + for _ in range(cfg.algo.critic.n) + ] + target_entropy = -act_dim + agent = DROQAgent( + actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device + ) + if agent_state: + agent.load_state_dict(agent_state) + agent.actor = fabric.setup_module(agent.actor) + agent.critics = [fabric.setup_module(critic) for critic in agent.critics] + + return agent diff --git 
a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py index 94d3a366..c0c5b71f 100644 --- a/sheeprl/algos/droq/droq.py +++ b/sheeprl/algos/droq/droq.py @@ -3,7 +3,6 @@ import copy import os import warnings -from math import prod from typing import Any, Dict import gymnasium as gym @@ -20,8 +19,7 @@ from torch.utils.data.sampler import BatchSampler from torchmetrics import SumMetric -from sheeprl.algos.droq.agent import DROQAgent, DROQCritic -from sheeprl.algos.sac.agent import SACActor +from sheeprl.algos.droq.agent import DROQAgent, build_agent from sheeprl.algos.sac.loss import entropy_loss, policy_loss from sheeprl.algos.sac.sac import test from sheeprl.data.buffers import ReplayBuffer @@ -196,33 +194,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) # Define the agent and the optimizer and setup them with Fabric - act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) - actor = SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=action_space.low, - action_high=action_space.high, + agent = build_agent( + fabric, cfg, observation_space, action_space, state["agent"] if cfg.checkpoint.resume_from else None ) - critics = [ - DROQCritic( - observation_dim=obs_dim + act_dim, - hidden_size=cfg.algo.critic.hidden_size, - num_critics=1, - dropout=cfg.algo.critic.dropout, - ) - for _ in range(cfg.algo.critic.n) - ] - target_entropy = -act_dim - agent = DROQAgent( - actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device - ) - if cfg.checkpoint.resume_from: - agent.load_state_dict(state["agent"]) - agent.actor = fabric.setup_module(agent.actor) - agent.critics = [fabric.setup_module(critic) for critic in agent.critics] # Optimizers qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()) diff --git a/sheeprl/algos/droq/evaluate.py b/sheeprl/algos/droq/evaluate.py index 0d082dbf..ce179264 100644 --- a/sheeprl/algos/droq/evaluate.py +++ b/sheeprl/algos/droq/evaluate.py @@ -1,13 +1,11 @@ from __future__ import annotations -from math import prod from typing import Any, Dict import gymnasium as gym from lightning import Fabric -from sheeprl.algos.droq.agent import DROQAgent, DROQCritic -from sheeprl.algos.sac.agent import SACActor +from sheeprl.algos.droq.agent import build_agent from sheeprl.algos.sac.utils import test from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -47,29 +45,5 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): if cfg.metric.log_level > 0: fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) - actor = SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=action_space.low, - action_high=action_space.high, - ) - critics = [ - DROQCritic( - observation_dim=obs_dim + act_dim, - hidden_size=cfg.algo.critic.hidden_size, - num_critics=1, - dropout=cfg.algo.critic.dropout, - ) - for _ in range(cfg.algo.critic.n) - ] - target_entropy = -act_dim - agent = DROQAgent( - actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device - ) - 
agent.load_state_dict(state["agent"]) - agent = fabric.setup_module(agent) + agent = build_agent(fabric, cfg, observation_space, action_space, state["agent"]) test(agent.actor, fabric, cfg, log_dir) diff --git a/sheeprl/algos/p2e_dv1/agent.py b/sheeprl/algos/p2e_dv1/agent.py index 9c844dd8..9fb17606 100644 --- a/sheeprl/algos/p2e_dv1/agent.py +++ b/sheeprl/algos/p2e_dv1/agent.py @@ -7,7 +7,7 @@ from lightning.fabric.wrappers import _FabricModule from sheeprl.algos.dreamer_v1.agent import WorldModel -from sheeprl.algos.dreamer_v1.agent import build_models as dv1_build_models +from sheeprl.algos.dreamer_v1.agent import build_agent as dv1_build_agent from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor from sheeprl.models.models import MLP @@ -20,7 +20,7 @@ MinedojoActor = DV2MinedojoActor -def build_models( +def build_agent( fabric: Fabric, actions_dim: Sequence[int], is_continuous: bool, @@ -67,7 +67,7 @@ def build_models( latent_state_size = world_model_cfg.stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size # Create exploration models - world_model, actor_exploration, critic_exploration = dv1_build_models( + world_model, actor_exploration, critic_exploration = dv1_build_agent( fabric, actions_dim=actions_dim, is_continuous=is_continuous, diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py index ddc7c294..349f4884 100644 --- a/sheeprl/algos/p2e_dv1/evaluate.py +++ b/sheeprl/algos/p2e_dv1/evaluate.py @@ -7,7 +7,7 @@ from sheeprl.algos.dreamer_v1.agent import PlayerDV1 from sheeprl.algos.dreamer_v2.utils import test -from sheeprl.algos.p2e_dv1.agent import build_models +from sheeprl.algos.p2e_dv1.agent import build_agent from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @@ -44,11 +44,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor_task, _, _, _ = build_models( + world_model, actor_task, _, _, _ = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index 41f44078..c81f569e 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -26,7 +26,7 @@ from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss from sheeprl.algos.dreamer_v1.utils import compute_lambda_values from sheeprl.algos.dreamer_v2.utils import test -from sheeprl.algos.p2e_dv1.agent import build_models +from sheeprl.algos.p2e_dv1.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer from sheeprl.models.models import MLP from sheeprl.utils.env import make_env @@ -455,7 +455,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: 
torch.tanh(r) if cfg.env.clip_rewards else r @@ -488,7 +488,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder - world_model, actor_task, critic_task, actor_exploration, critic_exploration = build_models( + world_model, actor_task, critic_task, actor_exploration, critic_exploration = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index ecb370eb..d8d743e0 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -20,7 +20,7 @@ from sheeprl.algos.dreamer_v1.agent import PlayerDV1 from sheeprl.algos.dreamer_v1.dreamer_v1 import train from sheeprl.algos.dreamer_v2.utils import test -from sheeprl.algos.p2e_dv1.agent import build_models +from sheeprl.algos.p2e_dv1.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -106,7 +106,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -139,7 +139,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder - world_model, actor_task, critic_task, actor_exploration, _ = build_models( + world_model, actor_task, critic_task, actor_exploration, _ = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/p2e_dv2/agent.py b/sheeprl/algos/p2e_dv2/agent.py index b40973d6..3ffef72f 100644 --- a/sheeprl/algos/p2e_dv2/agent.py +++ b/sheeprl/algos/p2e_dv2/agent.py @@ -11,7 +11,7 @@ from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor from sheeprl.algos.dreamer_v2.agent import WorldModel -from sheeprl.algos.dreamer_v2.agent import build_models as dv2_build_models +from sheeprl.algos.dreamer_v2.agent import build_agent as dv2_build_agent from sheeprl.models.models import MLP from sheeprl.utils.utils import init_weights @@ -22,7 +22,7 @@ MinedojoActor = DV2MinedojoActor -def build_models( +def build_agent( fabric: Fabric, actions_dim: Sequence[int], is_continuous: bool, @@ -78,7 +78,7 @@ def build_models( latent_state_size = stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size # Create exploration models - world_model, actor_exploration, critic_exploration, target_critic_exploration = dv2_build_models( + world_model, actor_exploration, critic_exploration, target_critic_exploration = dv2_build_agent( fabric, actions_dim=actions_dim, is_continuous=is_continuous, diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py index f252f4e6..cb420460 100644 --- a/sheeprl/algos/p2e_dv2/evaluate.py +++ b/sheeprl/algos/p2e_dv2/evaluate.py @@ -7,7 +7,7 @@ from sheeprl.algos.dreamer_v2.agent import PlayerDV2 from sheeprl.algos.dreamer_v2.utils import test -from sheeprl.algos.p2e_dv2.agent import build_models +from 
sheeprl.algos.p2e_dv2.agent import build_agent from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @@ -44,11 +44,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor_task, _, _, _, _, _ = build_models( + world_model, actor_task, _, _, _, _, _ = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index 7e4b1cee..da3a1be6 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -25,7 +25,7 @@ from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel from sheeprl.algos.dreamer_v2.loss import reconstruction_loss from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, init_weights, test -from sheeprl.algos.p2e_dv2.agent import build_models +from sheeprl.algos.p2e_dv2.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer from sheeprl.models.models import MLP from sheeprl.utils.distribution import OneHotCategoricalValidateArgs @@ -570,7 +570,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -611,7 +611,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): actor_exploration, critic_exploration, target_critic_exploration, - ) = build_models( + ) = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index 878d7a77..adebc83b 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -21,7 +21,7 @@ from sheeprl.algos.dreamer_v2.agent import PlayerDV2 from sheeprl.algos.dreamer_v2.dreamer_v2 import train from sheeprl.algos.dreamer_v2.utils import test -from sheeprl.algos.p2e_dv2.agent import build_models +from sheeprl.algos.p2e_dv2.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -110,7 +110,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -143,7 +143,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder - world_model, actor_task, critic_task, target_critic_task, 
actor_exploration, _, _ = build_models( + world_model, actor_task, critic_task, target_critic_task, actor_exploration, _, _ = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/p2e_dv3/agent.py b/sheeprl/algos/p2e_dv3/agent.py index c4019b1c..586e3e1e 100644 --- a/sheeprl/algos/p2e_dv3/agent.py +++ b/sheeprl/algos/p2e_dv3/agent.py @@ -10,7 +10,7 @@ from sheeprl.algos.dreamer_v3.agent import Actor as DV3Actor from sheeprl.algos.dreamer_v3.agent import MinedojoActor as DV3MinedojoActor from sheeprl.algos.dreamer_v3.agent import WorldModel -from sheeprl.algos.dreamer_v3.agent import build_models as dv3_build_models +from sheeprl.algos.dreamer_v3.agent import build_agent as dv3_build_agent from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights from sheeprl.models.models import MLP @@ -21,7 +21,7 @@ MinedojoActor = DV3MinedojoActor -def build_models( +def build_agent( fabric: Fabric, actions_dim: Sequence[int], is_continuous: bool, @@ -73,7 +73,7 @@ def build_models( latent_state_size = stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size # Create task models - world_model, actor_task, critic_task, target_critic_task = dv3_build_models( + world_model, actor_task, critic_task, target_critic_task = dv3_build_agent( fabric, actions_dim=actions_dim, is_continuous=is_continuous, diff --git a/sheeprl/algos/p2e_dv3/evaluate.py b/sheeprl/algos/p2e_dv3/evaluate.py index 97ceb94f..20ccd61d 100644 --- a/sheeprl/algos/p2e_dv3/evaluate.py +++ b/sheeprl/algos/p2e_dv3/evaluate.py @@ -7,7 +7,7 @@ from sheeprl.algos.dreamer_v3.agent import PlayerDV3 from sheeprl.algos.dreamer_v3.utils import test -from sheeprl.algos.p2e_dv3.agent import build_models +from sheeprl.algos.p2e_dv3.agent import build_agent from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @@ -44,11 +44,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor, _, _, _, _ = build_models( + world_model, actor, _, _, _, _ = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 86c9063f..776c5420 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -24,7 +24,7 @@ from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel from sheeprl.algos.dreamer_v3.loss import reconstruction_loss from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, init_weights, test -from sheeprl.algos.p2e_dv3.agent import build_models +from sheeprl.algos.p2e_dv3.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer from sheeprl.models.models import MLP from sheeprl.utils.distribution import ( @@ -599,7 +599,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = 
lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -639,7 +639,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): target_critic_task, actor_exploration, critics_exploration, - ) = build_models( + ) = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index 321f5ea2..ea1eef9f 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -19,7 +19,7 @@ from sheeprl.algos.dreamer_v3.agent import PlayerDV3 from sheeprl.algos.dreamer_v3.dreamer_v3 import train from sheeprl.algos.dreamer_v3.utils import Moments, test -from sheeprl.algos.p2e_dv3.agent import build_models +from sheeprl.algos.p2e_dv3.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -104,7 +104,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): is_continuous = isinstance(action_space, gym.spaces.Box) is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r @@ -144,7 +144,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): target_critic_task, actor_exploration, _, - ) = build_models( + ) = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/ppo/agent.py b/sheeprl/algos/ppo/agent.py index 4efcea6b..6bd6b632 100644 --- a/sheeprl/algos/ppo/agent.py +++ b/sheeprl/algos/ppo/agent.py @@ -4,6 +4,8 @@ import gymnasium import torch import torch.nn as nn +from lightning import Fabric +from lightning.fabric.wrappers import _FabricModule from torch import Tensor from torch.distributions import Distribution, Independent, Normal @@ -62,7 +64,7 @@ def forward(self, obs: Dict[str, Tensor]) -> Tensor: class PPOAgent(nn.Module): def __init__( self, - actions_dim: List[int], + actions_dim: Sequence[int], obs_space: gymnasium.spaces.Dict, encoder_cfg: Dict[str, Any], actor_cfg: Dict[str, Any], @@ -194,3 +196,30 @@ def get_greedy_actions(self, obs: Dict[str, Tensor]) -> Sequence[Tensor]: for logits in pre_dist ] ) + + +def build_agent( + fabric: Fabric, + actions_dim: Sequence[int], + is_continuous: bool, + cfg: Dict[str, Any], + obs_space: gymnasium.spaces.Dict, + agent_state: Optional[Dict[str, Tensor]] = None, +) -> _FabricModule: + agent = PPOAgent( + actions_dim=actions_dim, + obs_space=obs_space, + encoder_cfg=cfg.algo.encoder, + actor_cfg=cfg.algo.actor, + critic_cfg=cfg.algo.critic, + cnn_keys=cfg.cnn_keys.encoder, + mlp_keys=cfg.mlp_keys.encoder, + screen_size=cfg.env.screen_size, + distribution_cfg=cfg.distribution, + is_continuous=is_continuous, + ) + if agent_state: + agent.load_state_dict(agent_state) + agent = fabric.setup_module(agent) + + return agent diff --git a/sheeprl/algos/ppo/evaluate.py b/sheeprl/algos/ppo/evaluate.py index 82764685..56ca9e1f 100644 --- a/sheeprl/algos/ppo/evaluate.py +++ b/sheeprl/algos/ppo/evaluate.py @@ -5,7 +5,7 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.ppo.agent import PPOAgent +from sheeprl.algos.ppo.agent import build_agent from sheeprl.algos.ppo.utils import test from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -42,26 
+42,13 @@ def evaluate_ppo(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(env.action_space, gym.spaces.Box) is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( env.action_space.shape if is_continuous else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) ) # Create the actor and critic models - agent = PPOAgent( - actions_dim=actions_dim, - obs_space=observation_space, - encoder_cfg=cfg.algo.encoder, - actor_cfg=cfg.algo.actor, - critic_cfg=cfg.algo.critic, - cnn_keys=cfg.cnn_keys.encoder, - mlp_keys=cfg.mlp_keys.encoder, - screen_size=cfg.env.screen_size, - distribution_cfg=cfg.distribution, - is_continuous=is_continuous, - ) - agent.load_state_dict(state["agent"]) - agent = fabric.setup_module(agent) + agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py index 8be7f0ca..d8bd32db 100644 --- a/sheeprl/algos/ppo/ppo.py +++ b/sheeprl/algos/ppo/ppo.py @@ -19,7 +19,7 @@ from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler from torchmetrics import SumMetric -from sheeprl.algos.ppo.agent import PPOAgent +from sheeprl.algos.ppo.agent import build_agent from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss from sheeprl.algos.ppo.utils import normalize_obs, test from sheeprl.data import ReplayBuffer @@ -168,23 +168,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(envs.single_action_space, gym.spaces.Box) is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( envs.single_action_space.shape if is_continuous else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n]) ) # Create the actor and critic models - agent = PPOAgent( - actions_dim=actions_dim, - obs_space=observation_space, - encoder_cfg=cfg.algo.encoder, - actor_cfg=cfg.algo.actor, - critic_cfg=cfg.algo.critic, - cnn_keys=cfg.cnn_keys.encoder, - mlp_keys=cfg.mlp_keys.encoder, - screen_size=cfg.env.screen_size, - distribution_cfg=cfg.distribution, - is_continuous=is_continuous, + agent = build_agent( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["agent"] if cfg.checkpoint.resume_from else None, ) # Define the optimizer @@ -194,11 +190,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): # Load the state from the checkpoint if cfg.checkpoint.resume_from: - agent.load_state_dict(state["agent"]) optimizer.load_state_dict(state["optimizer"]) # Setup agent and optimizer with Fabric - agent = fabric.setup_module(agent) optimizer = fabric.setup_optimizers(optimizer) # Create a metric aggregator to log the metrics diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py index 5ab593e6..494189b0 100644 --- a/sheeprl/algos/ppo/ppo_decoupled.py +++ b/sheeprl/algos/ppo/ppo_decoupled.py @@ -78,7 +78,7 @@ def player( is_continuous = isinstance(envs.single_action_space, gym.spaces.Box) is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( envs.single_action_space.shape if is_continuous else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n]) diff --git a/sheeprl/algos/ppo_recurrent/agent.py b/sheeprl/algos/ppo_recurrent/agent.py index 
90ae6ab3..7fd1878f 100644 --- a/sheeprl/algos/ppo_recurrent/agent.py +++ b/sheeprl/algos/ppo_recurrent/agent.py @@ -4,6 +4,7 @@ import gymnasium import torch import torch.nn as nn +from lightning import Fabric from torch import Tensor from torch.distributions import Independent, Normal @@ -286,3 +287,33 @@ def forward( pre_dist = self.get_pre_dist(out) actions, logprobs, entropies = self.get_sampled_actions(pre_dist, actions) return actions, logprobs, entropies, values, states + + +def build_agent( + fabric: Fabric, + actions_dim: Sequence[int], + is_continuous: bool, + cfg: Dict[str, Any], + obs_space: gymnasium.spaces.Dict, + agent_state: Optional[Dict[str, Tensor]] = None, +) -> RecurrentPPOAgent: + agent = RecurrentPPOAgent( + actions_dim=actions_dim, + obs_space=obs_space, + encoder_cfg=cfg.algo.encoder, + rnn_cfg=cfg.algo.rnn, + actor_cfg=cfg.algo.actor, + critic_cfg=cfg.algo.critic, + cnn_keys=cfg.cnn_keys.encoder, + mlp_keys=cfg.mlp_keys.encoder, + is_continuous=is_continuous, + distribution_cfg=cfg.distribution, + num_envs=cfg.env.num_envs, + screen_size=cfg.env.screen_size, + device=fabric.device, + ) + if agent_state: + agent.load_state_dict(agent_state) + agent = fabric.setup_module(agent) + + return agent diff --git a/sheeprl/algos/ppo_recurrent/evaluate.py b/sheeprl/algos/ppo_recurrent/evaluate.py index b0940742..e7aed909 100644 --- a/sheeprl/algos/ppo_recurrent/evaluate.py +++ b/sheeprl/algos/ppo_recurrent/evaluate.py @@ -5,7 +5,7 @@ import gymnasium as gym from lightning import Fabric -from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent +from sheeprl.algos.ppo_recurrent.agent import build_agent from sheeprl.algos.ppo_recurrent.utils import test from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -42,27 +42,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): is_continuous = isinstance(env.action_space, gym.spaces.Box) is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( env.action_space.shape if is_continuous else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) ) # Create the actor and critic models - agent = RecurrentPPOAgent( - actions_dim=actions_dim, - obs_space=observation_space, - encoder_cfg=cfg.algo.encoder, - rnn_cfg=cfg.algo.rnn, - actor_cfg=cfg.algo.actor, - critic_cfg=cfg.algo.critic, - cnn_keys=cfg.cnn_keys.encoder, - mlp_keys=cfg.mlp_keys.encoder, - is_continuous=is_continuous, - distribution_cfg=cfg.distribution, - num_envs=cfg.env.num_envs, - screen_size=cfg.env.screen_size, - device=fabric.device, - ) - agent.load_state_dict(state["agent"]) - agent = fabric.setup_module(agent) + agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"]) test(agent, fabric, cfg, log_dir) diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py index b067a7b9..ec4bad48 100644 --- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py +++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py @@ -22,7 +22,7 @@ from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss from sheeprl.algos.ppo.utils import normalize_obs -from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent +from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent, build_agent from sheeprl.algos.ppo_recurrent.utils import test from sheeprl.data.buffers import ReplayBuffer from sheeprl.utils.env import make_env @@ -178,37 +178,27 @@ def 
main(fabric: Fabric, cfg: Dict[str, Any]): is_continuous = isinstance(envs.single_action_space, gym.spaces.Box) is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete) - actions_dim = ( + actions_dim = tuple( envs.single_action_space.shape if is_continuous else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n]) ) # Define the agent and the optimizer - agent = RecurrentPPOAgent( - actions_dim=actions_dim, - obs_space=observation_space, - encoder_cfg=cfg.algo.encoder, - rnn_cfg=cfg.algo.rnn, - actor_cfg=cfg.algo.actor, - critic_cfg=cfg.algo.critic, - cnn_keys=cfg.cnn_keys.encoder, - mlp_keys=cfg.mlp_keys.encoder, - is_continuous=is_continuous, - distribution_cfg=cfg.distribution, - num_envs=cfg.env.num_envs, - screen_size=cfg.env.screen_size, - device=device, + agent = build_agent( + fabric, + actions_dim, + is_continuous, + cfg, + observation_space, + state["agent"] if cfg.checkpoint.resume_from else None, ) optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters()) # Load the state from the checkpoint if cfg.checkpoint.resume_from: - agent.load_state_dict(state["agent"]) optimizer.load_state_dict(state["optimizer"]) - # Setup agent and optimizer with Fabric - agent = fabric.setup_module(agent) optimizer = fabric.setup_optimizers(optimizer) local_vars = locals() diff --git a/sheeprl/algos/sac/agent.py b/sheeprl/algos/sac/agent.py index 4e52a08f..215415a5 100644 --- a/sheeprl/algos/sac/agent.py +++ b/sheeprl/algos/sac/agent.py @@ -1,8 +1,11 @@ import copy -from typing import Any, Dict, Sequence, SupportsFloat, Tuple, Union +from math import prod +from typing import Any, Dict, Optional, Sequence, SupportsFloat, Tuple, Union +import gymnasium import torch import torch.nn as nn +from lightning import Fabric from lightning.fabric.wrappers import _FabricModule from numpy.typing import NDArray from torch import Tensor @@ -273,3 +276,34 @@ def get_next_target_q_values(self, next_obs: Tensor, rewards: Tensor, dones: Ten def qfs_target_ema(self) -> None: for param, target_param in zip(self.qfs_unwrapped.parameters(), self.qfs_target.parameters()): target_param.data.copy_(self._tau * param.data + (1 - self._tau) * target_param.data) + + +def build_agent( + fabric: Fabric, + cfg: Dict[str, Any], + obs_space: gymnasium.spaces.Dict, + action_space: gymnasium.spaces.Box, + agent_state: Optional[Dict[str, Tensor]] = None, +) -> SACAgent: + act_dim = prod(action_space.shape) + obs_dim = sum([prod(obs_space[k].shape) for k in cfg.mlp_keys.encoder]) + actor = SACActor( + observation_dim=obs_dim, + action_dim=act_dim, + distribution_cfg=cfg.distribution, + hidden_size=cfg.algo.actor.hidden_size, + action_low=action_space.low, + action_high=action_space.high, + ) + critics = [ + SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) + for _ in range(cfg.algo.critic.n) + ] + target_entropy = -act_dim + agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) + if agent_state: + agent.load_state_dict(agent_state) + agent.actor = fabric.setup_module(agent.actor) + agent.critics = [fabric.setup_module(critic) for critic in agent.critics] + + return agent diff --git a/sheeprl/algos/sac/evaluate.py b/sheeprl/algos/sac/evaluate.py index 9d138562..89024046 100644 --- a/sheeprl/algos/sac/evaluate.py +++ b/sheeprl/algos/sac/evaluate.py @@ -1,12 +1,11 @@ from __future__ import annotations -from math import prod from typing 
import Any, Dict import gymnasium as gym from lightning import Fabric -from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic +from sheeprl.algos.sac.agent import build_agent from sheeprl.algos.sac.utils import test from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -45,22 +44,5 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): ) fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) - actor = SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=action_space.low, - action_high=action_space.high, - ) - critics = [ - SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) - for _ in range(cfg.algo.critic.n) - ] - target_entropy = -act_dim - agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) - agent.load_state_dict(state["agent"]) - agent = fabric.setup_module(agent) + agent = build_agent(fabric, cfg, observation_space, action_space, state["agent"]) test(agent.actor, fabric, cfg, log_dir) diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py index 58e80cc4..d0a9ec7e 100644 --- a/sheeprl/algos/sac/sac.py +++ b/sheeprl/algos/sac/sac.py @@ -3,7 +3,6 @@ import copy import os import warnings -from math import prod from typing import Any, Dict, Optional import gymnasium as gym @@ -21,7 +20,7 @@ from torch.utils.data.sampler import BatchSampler from torchmetrics import SumMetric -from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic +from sheeprl.algos.sac.agent import SACAgent, build_agent from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss from sheeprl.algos.sac.utils import test from sheeprl.data.buffers import ReplayBuffer @@ -148,26 +147,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) # Define the agent and the optimizer and setup sthem with Fabric - act_dim = prod(action_space.shape) - obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder]) - actor = SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=action_space.low, - action_high=action_space.high, + agent = build_agent( + fabric, cfg, observation_space, action_space, state["agent"] if cfg.checkpoint.resume_from else None ) - critics = [ - SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) - for _ in range(cfg.algo.critic.n) - ] - target_entropy = -act_dim - agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) - if cfg.checkpoint.resume_from: - agent.load_state_dict(state["agent"]) - agent.actor = fabric.setup_module(agent.actor) - agent.critics = [fabric.setup_module(critic) for critic in agent.critics] # Optimizers qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()) diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py index 70be5bad..8881ac2b 100644 --- a/sheeprl/algos/sac/sac_decoupled.py +++ b/sheeprl/algos/sac/sac_decoupled.py @@ -20,7 +20,7 @@ from torch.utils.data.sampler import BatchSampler from torchmetrics import 
SumMetric -from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic +from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic, build_agent from sheeprl.algos.sac.sac import train from sheeprl.algos.sac.utils import test from sheeprl.data.buffers import ReplayBuffer @@ -369,27 +369,13 @@ def trainer( assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported" # Define the agent and the optimizer and setup them with Fabric - act_dim = prod(envs.single_action_space.shape) - obs_dim = sum([prod(envs.single_observation_space[k].shape) for k in cfg.mlp_keys.encoder]) - - actor = SACActor( - observation_dim=obs_dim, - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=envs.single_action_space.low, - action_high=envs.single_action_space.high, + agent = build_agent( + fabric, + cfg, + envs.single_observation_space, + envs.single_action_space, + state["agent"] if cfg.checkpoint.resume_from else None, ) - critics = [ - SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1) - for _ in range(cfg.algo.critic.n) - ] - target_entropy = -act_dim - agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device) - if cfg.checkpoint.resume_from: - agent.load_state_dict(state["agent"]) - agent.actor = fabric.setup_module(agent.actor) - agent.critics = [fabric.setup_module(critic) for critic in agent.critics] # Optimizers qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters()) diff --git a/sheeprl/algos/sac_ae/agent.py b/sheeprl/algos/sac_ae/agent.py index 0265cbe4..4fbbbc57 100644 --- a/sheeprl/algos/sac_ae/agent.py +++ b/sheeprl/algos/sac_ae/agent.py @@ -1,16 +1,18 @@ import copy from math import prod -from typing import Any, Dict, List, Sequence, SupportsFloat, Tuple, Union +from typing import Any, Dict, List, Optional, Sequence, SupportsFloat, Tuple, Union +import gymnasium import numpy as np import torch import torch.nn as nn +from lightning import Fabric from lightning.fabric.wrappers import _FabricModule from numpy.typing import NDArray from torch import Size, Tensor from sheeprl.algos.sac_ae.utils import weight_init -from sheeprl.models.models import CNN, MLP, DeCNN, MultiEncoder +from sheeprl.models.models import CNN, MLP, DeCNN, MultiDecoder, MultiEncoder LOG_STD_MAX = 2 LOG_STD_MIN = -10 @@ -448,3 +450,114 @@ def critic_encoder_target_ema(self) -> None: self.critic_unwrapped.encoder.parameters(), self.critic_target.encoder.parameters() ): target_param.data.copy_(self._encoder_tau * param.data + (1 - self._encoder_tau) * target_param.data) + + +def build_agent( + fabric: Fabric, + cfg: Dict[str, Any], + obs_space: gymnasium.spaces.Dict, + action_space: gymnasium.spaces.Box, + agent_state: Optional[Dict[str, Tensor]] = None, + encoder_state: Optional[Dict[str, Tensor]] = None, + decoder_sate: Optional[Dict[str, Tensor]] = None, +) -> Tuple[SACAEAgent, _FabricModule, _FabricModule]: + act_dim = prod(action_space.shape) + target_entropy = -act_dim + + # Define the encoder and decoder and setup them with fabric. 
+ # Then we will set the critic encoder and actor decoder as the unwrapped encoder module: + # we do not need it wrapped with the strategy inside actor and critic + cnn_channels = [prod(obs_space[k].shape[:-2]) for k in cfg.cnn_keys.encoder] + mlp_dims = [obs_space[k].shape[0] for k in cfg.mlp_keys.encoder] + cnn_encoder = ( + CNNEncoder( + in_channels=sum(cnn_channels), + features_dim=cfg.algo.encoder.features_dim, + keys=cfg.cnn_keys.encoder, + screen_size=cfg.env.screen_size, + cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier, + ) + if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0 + else None + ) + mlp_encoder = ( + MLPEncoder( + sum(mlp_dims), + cfg.mlp_keys.encoder, + cfg.algo.encoder.dense_units, + cfg.algo.encoder.mlp_layers, + eval(cfg.algo.encoder.dense_act), + cfg.algo.encoder.layer_norm, + ) + if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0 + else None + ) + encoder = MultiEncoder(cnn_encoder, mlp_encoder) + cnn_decoder = ( + CNNDecoder( + cnn_encoder.conv_output_shape, + features_dim=encoder.output_dim, + keys=cfg.cnn_keys.decoder, + channels=cnn_channels, + screen_size=cfg.env.screen_size, + cnn_channels_multiplier=cfg.algo.decoder.cnn_channels_multiplier, + ) + if cfg.cnn_keys.decoder is not None and len(cfg.cnn_keys.decoder) > 0 + else None + ) + mlp_decoder = ( + MLPDecoder( + encoder.output_dim, + mlp_dims, + cfg.mlp_keys.decoder, + cfg.algo.decoder.dense_units, + cfg.algo.decoder.mlp_layers, + eval(cfg.algo.decoder.dense_act), + cfg.algo.decoder.layer_norm, + ) + if cfg.mlp_keys.decoder is not None and len(cfg.mlp_keys.decoder) > 0 + else None + ) + decoder = MultiDecoder(cnn_decoder, mlp_decoder) + if cfg.checkpoint.resume_from: + encoder.load_state_dict(encoder_state) + decoder.load_state_dict(decoder_sate) + + # Setup actor and critic. 
Those will initialize with orthogonal weights + # both the actor and critic + actor = SACAEContinuousActor( + encoder=copy.deepcopy(encoder), + action_dim=act_dim, + distribution_cfg=cfg.distribution, + hidden_size=cfg.algo.actor.hidden_size, + action_low=action_space.low, + action_high=action_space.high, + ) + qfs = [ + SACAEQFunction( + input_dim=encoder.output_dim, action_dim=act_dim, hidden_size=cfg.algo.critic.hidden_size, output_dim=1 + ) + for _ in range(cfg.algo.critic.n) + ] + critic = SACAECritic(encoder=encoder, qfs=qfs) + + # The agent will tied convolutional and linear weights between the encoder actor and critic + agent = SACAEAgent( + actor, + critic, + target_entropy, + alpha=cfg.algo.alpha.alpha, + tau=cfg.algo.tau, + encoder_tau=cfg.algo.encoder.tau, + device=fabric.device, + ) + + if agent_state: + agent.load_state_dict(agent_state) + + encoder = fabric.setup_module(encoder) + decoder = fabric.setup_module(decoder) + agent.actor = fabric.setup_module(agent.actor) + agent.critic = fabric.setup_module(agent.critic) + + return agent, encoder, decoder diff --git a/sheeprl/algos/sac_ae/evaluate.py b/sheeprl/algos/sac_ae/evaluate.py index 508d4fdf..27d1ed32 100644 --- a/sheeprl/algos/sac_ae/evaluate.py +++ b/sheeprl/algos/sac_ae/evaluate.py @@ -1,22 +1,12 @@ from __future__ import annotations -import copy -from math import prod from typing import Any, Dict import gymnasium as gym from lightning import Fabric -from sheeprl.algos.sac_ae.agent import ( - CNNEncoder, - MLPEncoder, - SACAEAgent, - SACAEContinuousActor, - SACAECritic, - SACAEQFunction, -) +from sheeprl.algos.sac_ae.agent import SACAEAgent, build_agent from sheeprl.algos.sac_ae.utils import test_sac_ae -from sheeprl.models.models import MultiEncoder from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.registry import register_evaluation @@ -51,68 +41,8 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder) fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder) - act_dim = prod(action_space.shape) - target_entropy = -act_dim - - # Define the encoder and decoder and setup them with fabric. - # Then we will set the critic encoder and actor decoder as the unwrapped encoder module: - # we do not need it wrapped with the strategy inside actor and critic - cnn_channels = [prod(observation_space[k].shape[:-2]) for k in cfg.cnn_keys.encoder] - mlp_dims = [observation_space[k].shape[0] for k in cfg.mlp_keys.encoder] - cnn_encoder = ( - CNNEncoder( - in_channels=sum(cnn_channels), - features_dim=cfg.algo.encoder.features_dim, - keys=cfg.cnn_keys.encoder, - screen_size=cfg.env.screen_size, - cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier, - ) - if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0 - else None - ) - mlp_encoder = ( - MLPEncoder( - sum(mlp_dims), - cfg.mlp_keys.encoder, - cfg.algo.encoder.dense_units, - cfg.algo.encoder.mlp_layers, - eval(cfg.algo.encoder.dense_act), - cfg.algo.encoder.layer_norm, - ) - if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0 - else None - ) - encoder = MultiEncoder(cnn_encoder, mlp_encoder) - encoder.load_state_dict(state["encoder"]) - - # Setup actor and critic. 
Those will initialize with orthogonal weights - # both the actor and critic - actor = SACAEContinuousActor( - encoder=copy.deepcopy(encoder), - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=action_space.low, - action_high=action_space.high, - ) - qfs = [ - SACAEQFunction( - input_dim=encoder.output_dim, action_dim=act_dim, hidden_size=cfg.algo.critic.hidden_size, output_dim=1 - ) - for _ in range(cfg.algo.critic.n) - ] - critic = SACAECritic(encoder=encoder, qfs=qfs) - - # The agent will tied convolutional and linear weights between the encoder actor and critic - agent = SACAEAgent( - actor, - critic, - target_entropy, - alpha=cfg.algo.alpha.alpha, - tau=cfg.algo.tau, - encoder_tau=cfg.algo.encoder.tau, - device=fabric.device, + agent: SACAEAgent + agent, _, _ = build_agent( + fabric, cfg, observation_space, action_space, state["agent"], state["encoder"], state["decoder"] ) - agent.load_state_dict(state["agent"]) - agent.actor = fabric.setup_module(agent.actor) test_sac_ae(agent.actor, fabric, cfg, log_dir) diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py index 46761bfb..3696d95a 100644 --- a/sheeprl/algos/sac_ae/sac_ae.py +++ b/sheeprl/algos/sac_ae/sac_ae.py @@ -4,7 +4,6 @@ import os import time import warnings -from math import prod from typing import Any, Dict, Optional, Union import gymnasium as gym @@ -25,16 +24,7 @@ from torchmetrics import SumMetric from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss -from sheeprl.algos.sac_ae.agent import ( - CNNDecoder, - CNNEncoder, - MLPDecoder, - MLPEncoder, - SACAEAgent, - SACAEContinuousActor, - SACAECritic, - SACAEQFunction, -) +from sheeprl.algos.sac_ae.agent import SACAEAgent, build_agent from sheeprl.algos.sac_ae.utils import preprocess_obs, test_sac_ae from sheeprl.data.buffers import ReplayBuffer from sheeprl.models.models import MultiDecoder, MultiEncoder @@ -213,95 +203,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) # Define the agent and the optimizer and setup them with Fabric - act_dim = prod(envs.single_action_space.shape) - target_entropy = -act_dim - - # Define the encoder and decoder and setup them with fabric. 
- # Then we will set the critic encoder and actor decoder as the unwrapped encoder module: - # we do not need it wrapped with the strategy inside actor and critic - cnn_channels = [prod(envs.single_observation_space[k].shape[:-2]) for k in cfg.cnn_keys.encoder] - mlp_dims = [envs.single_observation_space[k].shape[0] for k in cfg.mlp_keys.encoder] - cnn_encoder = ( - CNNEncoder( - in_channels=sum(cnn_channels), - features_dim=cfg.algo.encoder.features_dim, - keys=cfg.cnn_keys.encoder, - screen_size=cfg.env.screen_size, - cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier, - ) - if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0 - else None - ) - mlp_encoder = ( - MLPEncoder( - sum(mlp_dims), - cfg.mlp_keys.encoder, - cfg.algo.encoder.dense_units, - cfg.algo.encoder.mlp_layers, - eval(cfg.algo.encoder.dense_act), - cfg.algo.encoder.layer_norm, - ) - if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0 - else None - ) - encoder = MultiEncoder(cnn_encoder, mlp_encoder) - cnn_decoder = ( - CNNDecoder( - cnn_encoder.conv_output_shape, - features_dim=encoder.output_dim, - keys=cfg.cnn_keys.decoder, - channels=cnn_channels, - screen_size=cfg.env.screen_size, - cnn_channels_multiplier=cfg.algo.decoder.cnn_channels_multiplier, - ) - if cfg.cnn_keys.decoder is not None and len(cfg.cnn_keys.decoder) > 0 - else None - ) - mlp_decoder = ( - MLPDecoder( - encoder.output_dim, - mlp_dims, - cfg.mlp_keys.decoder, - cfg.algo.decoder.dense_units, - cfg.algo.decoder.mlp_layers, - eval(cfg.algo.decoder.dense_act), - cfg.algo.decoder.layer_norm, - ) - if cfg.mlp_keys.decoder is not None and len(cfg.mlp_keys.decoder) > 0 - else None - ) - decoder = MultiDecoder(cnn_decoder, mlp_decoder) - if cfg.checkpoint.resume_from: - encoder.load_state_dict(state["encoder"]) - decoder.load_state_dict(state["decoder"]) - - # Setup actor and critic. 
Those will initialize with orthogonal weights - # both the actor and critic - actor = SACAEContinuousActor( - encoder=copy.deepcopy(encoder), - action_dim=act_dim, - distribution_cfg=cfg.distribution, - hidden_size=cfg.algo.actor.hidden_size, - action_low=envs.single_action_space.low, - action_high=envs.single_action_space.high, - ) - qfs = [ - SACAEQFunction( - input_dim=encoder.output_dim, action_dim=act_dim, hidden_size=cfg.algo.critic.hidden_size, output_dim=1 - ) - for _ in range(cfg.algo.critic.n) - ] - critic = SACAECritic(encoder=encoder, qfs=qfs) - - # The agent will tied convolutional and linear weights between the encoder actor and critic - agent = SACAEAgent( - actor, - critic, - target_entropy, - alpha=cfg.algo.alpha.alpha, - tau=cfg.algo.tau, - encoder_tau=cfg.algo.encoder.tau, - device=fabric.device, + agent, encoder, decoder = build_agent( + fabric, + cfg, + observation_space, + envs.single_action_space, + state["agent"] if cfg.checkpoint.resume_from else None, + state["encoder"] if cfg.checkpoint.resume_from else None, + state["decoder"] if cfg.checkpoint.resume_from else None, ) # Optimizers @@ -312,18 +221,12 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): decoder_optimizer = hydra.utils.instantiate(cfg.algo.decoder.optimizer, params=decoder.parameters()) if cfg.checkpoint.resume_from: - agent.load_state_dict(state["agent"]) qf_optimizer.load_state_dict(state["qf_optimizer"]) actor_optimizer.load_state_dict(state["actor_optimizer"]) alpha_optimizer.load_state_dict(state["alpha_optimizer"]) encoder_optimizer.load_state_dict(state["encoder_optimizer"]) decoder_optimizer.load_state_dict(state["decoder_optimizer"]) - encoder = fabric.setup_module(encoder) - decoder = fabric.setup_module(decoder) - agent.actor = fabric.setup_module(agent.actor) - agent.critic = fabric.setup_module(agent.critic) - qf_optimizer, actor_optimizer, alpha_optimizer, encoder_optimizer, decoder_optimizer = fabric.setup_optimizers( qf_optimizer, actor_optimizer, alpha_optimizer, encoder_optimizer, decoder_optimizer ) diff --git a/sheeprl/configs/model_manager/default.yaml b/sheeprl/configs/model_manager/default.yaml index 7ab58825..e397b00a 100644 --- a/sheeprl/configs/model_manager/default.yaml +++ b/sheeprl/configs/model_manager/default.yaml @@ -1,2 +1,2 @@ -disabled: False +disabled: True models: {}