From 29fee9988636bf3bbd80c220fa75fa2377546e6b Mon Sep 17 00:00:00 2001 From: Michele Milesi Date: Tue, 21 Nov 2023 10:59:26 +0100 Subject: [PATCH] feat: added ensembles creation to build agent function --- sheeprl/algos/p2e_dv1/agent.py | 41 +++++++++++--- sheeprl/algos/p2e_dv1/evaluate.py | 2 +- sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py | 29 +--------- sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py | 3 +- sheeprl/algos/p2e_dv1/utils.py | 36 +----------- sheeprl/algos/p2e_dv2/agent.py | 58 +++++++++++++++++--- sheeprl/algos/p2e_dv2/evaluate.py | 3 +- sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py | 41 +------------- sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py | 3 +- sheeprl/algos/p2e_dv2/utils.py | 36 +----------- sheeprl/algos/p2e_dv3/agent.py | 54 ++++++++++++++++-- sheeprl/algos/p2e_dv3/evaluate.py | 3 +- sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py | 38 +------------ sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py | 2 + sheeprl/algos/p2e_dv3/utils.py | 44 ++++----------- 15 files changed, 169 insertions(+), 224 deletions(-) diff --git a/sheeprl/algos/p2e_dv1/agent.py b/sheeprl/algos/p2e_dv1/agent.py index 9fb17606..53e09231 100644 --- a/sheeprl/algos/p2e_dv1/agent.py +++ b/sheeprl/algos/p2e_dv1/agent.py @@ -5,6 +5,8 @@ import torch from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule +from lightning.pytorch.utilities.seed import isolate_rng +from torch import nn from sheeprl.algos.dreamer_v1.agent import WorldModel from sheeprl.algos.dreamer_v1.agent import build_agent as dv1_build_agent @@ -27,11 +29,12 @@ def build_agent( cfg: Dict[str, Any], obs_space: gymnasium.spaces.Dict, world_model_state: Optional[Dict[str, torch.Tensor]] = None, + ensembles_state: Optional[Dict[str, torch.Tensor]] = None, actor_task_state: Optional[Dict[str, torch.Tensor]] = None, critic_task_state: Optional[Dict[str, torch.Tensor]] = None, actor_exploration_state: Optional[Dict[str, torch.Tensor]] = None, critic_exploration_state: Optional[Dict[str, torch.Tensor]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, _FabricModule]: +) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, _FabricModule, _FabricModule]: """Build the models and wrap them with Fabric. Args: @@ -42,6 +45,8 @@ def build_agent( obs_space (Dict[str, Any]): the observation space. world_model_state (Dict[str, Tensor], optional): the state of the world model. Default to None. + ensembles_state (Dict[str, Tensor], optional): the state of the ensembles. + Default to None. actor_task_state (Dict[str, Tensor], optional): the state of the actor_task. Default to None. critic_task_state (Dict[str, Tensor], optional): the state of the critic_task. @@ -53,11 +58,12 @@ def build_agent( Returns: The world model (WorldModel): composed by the encoder, rssm, observation and - reward models and the continue model. - The actor_task (_FabricModule). - The critic_task (_FabricModule). - The actor_exploration (_FabricModule). - The critic_exploration (_FabricModule). + reward models and the continue model. + The ensembles (_FabricModule): for estimating the intrinsic reward. + The actor_task (_FabricModule): for learning the task. + The critic_task (_FabricModule): for predicting the values of the task. + The actor_exploration (_FabricModule): for exploring the environment. + The critic_exploration (_FabricModule): for predicting the values of the exploration. 
""" world_model_cfg = cfg.algo.world_model actor_cfg = cfg.algo.actor @@ -110,4 +116,25 @@ def build_agent( actor_task = fabric.setup_module(actor_task) critic_task = fabric.setup_module(critic_task) - return world_model, actor_task, critic_task, actor_exploration, critic_exploration + ens_list = [] + with isolate_rng(): + for i in range(cfg.algo.ensembles.n): + fabric.seed_everything(cfg.seed + i) + ens_list.append( + MLP( + input_dims=( + int(sum(actions_dim)) + + cfg.algo.world_model.recurrent_model.recurrent_state_size + + cfg.algo.world_model.stochastic_size + ), + output_dim=world_model.encoder.cnn_output_dim + world_model.encoder.mlp_output_dim, + hidden_sizes=[cfg.algo.ensembles.dense_units] * cfg.algo.ensembles.mlp_layers, + activation=eval(cfg.algo.ensembles.dense_act), + ).apply(init_weights) + ) + ensembles = nn.ModuleList(ens_list) + if ensembles_state: + ensembles.load_state_dict(ensembles_state) + fabric.setup_module(ensembles) + + return world_model, ensembles, actor_task, critic_task, actor_exploration, critic_exploration diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py index 349f4884..76f19156 100644 --- a/sheeprl/algos/p2e_dv1/evaluate.py +++ b/sheeprl/algos/p2e_dv1/evaluate.py @@ -48,7 +48,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor_task, _, _, _ = build_agent( + world_model, _, actor_task, _, _, _ = build_agent( fabric, actions_dim, is_continuous, diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py index c81f569e..8e89bc71 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py @@ -13,11 +13,9 @@ import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer -from lightning.pytorch.utilities.seed import isolate_rng from mlflow.models.model import ModelInfo from tensordict import TensorDict from tensordict.tensordict import TensorDictBase -from torch import nn from torch.distributions import Bernoulli, Independent, Normal from torch.utils.data import BatchSampler from torchmetrics import SumMetric @@ -28,13 +26,12 @@ from sheeprl.algos.dreamer_v2.utils import test from sheeprl.algos.p2e_dv1.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer -from sheeprl.models.models import MLP from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger from sheeprl.utils.metric import MetricAggregator from sheeprl.utils.registry import register_algorithm from sheeprl.utils.timer import timer -from sheeprl.utils.utils import init_weights, polynomial_decay, register_model, unwrap_fabric +from sheeprl.utils.utils import polynomial_decay, register_model, unwrap_fabric # Decomment the following line if you are using MineDojo on an headless machine # os.environ["MINEDOJO_HEADLESS"] = "1" @@ -488,40 +485,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder - world_model, actor_task, critic_task, actor_exploration, critic_exploration = build_agent( + world_model, ensembles, actor_task, critic_task, actor_exploration, critic_exploration = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, 
state["world_model"] if cfg.checkpoint.resume_from else None, + state["ensembles"] if cfg.checkpoint.resume_from else None, state["actor_task"] if cfg.checkpoint.resume_from else None, state["critic_task"] if cfg.checkpoint.resume_from else None, state["actor_exploration"] if cfg.checkpoint.resume_from else None, state["critic_exploration"] if cfg.checkpoint.resume_from else None, ) - # initialize the ensembles with different seeds to be sure they have different weights - ens_list = [] - with isolate_rng(): - for i in range(cfg.algo.ensembles.n): - fabric.seed_everything(cfg.seed + i) - ens_list.append( - MLP( - input_dims=( - int(sum(actions_dim)) - + cfg.algo.world_model.recurrent_model.recurrent_state_size - + cfg.algo.world_model.stochastic_size - ), - output_dim=world_model.encoder.cnn_output_dim + world_model.encoder.mlp_output_dim, - hidden_sizes=[cfg.algo.ensembles.dense_units] * cfg.algo.ensembles.mlp_layers, - activation=eval(cfg.algo.ensembles.dense_act), - ).apply(init_weights) - ) - ensembles = nn.ModuleList(ens_list) - if cfg.checkpoint.resume_from: - ensembles.load_state_dict(state["ensembles"]) - fabric.setup_module(ensembles) player = PlayerDV1( world_model.encoder.module, world_model.rssm.recurrent_model.module, diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py index d8d743e0..4d841065 100644 --- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py +++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py @@ -139,13 +139,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder - world_model, actor_task, critic_task, actor_exploration, _ = build_agent( + world_model, _, actor_task, critic_task, actor_exploration, _ = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, state["world_model"], + None, state["actor_task"], state["critic_task"], state["actor_exploration"], diff --git a/sheeprl/algos/p2e_dv1/utils.py b/sheeprl/algos/p2e_dv1/utils.py index fa462c13..81cf5f34 100644 --- a/sheeprl/algos/p2e_dv1/utils.py +++ b/sheeprl/algos/p2e_dv1/utils.py @@ -4,11 +4,9 @@ import mlflow from lightning import Fabric from mlflow.models.model import ModelInfo -from torch import nn from sheeprl.algos.dreamer_v1.utils import AGGREGATOR_KEYS as AGGREGATOR_KEYS_DV1 from sheeprl.algos.p2e_dv1.agent import build_agent -from sheeprl.models.models import MLP from sheeprl.utils.utils import unwrap_fabric AGGREGATOR_KEYS = { @@ -62,6 +60,7 @@ def log_models_from_checkpoint( ) ( world_model, + ensembles, actor_task, critic_task, actor_exploration, @@ -73,42 +72,13 @@ def log_models_from_checkpoint( cfg, env.observation_space, state["world_model"], + state["ensembles"] if "exploration" in cfg.algo.name else None, state["actor_task"], state["critic_task"], state["actor_exploration"] if "exploration" in cfg.algo.name else None, state["critic_exploration"] if "exploration" in cfg.algo.name else None, ) - if "exploration" in cfg.algo.name: - ens_list = [] - cfg_ensembles = cfg.algo.ensembles - for i in range(cfg_ensembles.n): - fabric.seed_everything(cfg.seed + i) - ens_list.append( - MLP( - input_dims=int( - sum(actions_dim) - + cfg.algo.world_model.recurrent_model.recurrent_state_size - + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size - ), - output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size, - hidden_sizes=[cfg_ensembles.dense_units] * 
cfg_ensembles.mlp_layers, - activation=eval(cfg_ensembles.dense_act), - flatten_dim=None, - layer_args={"bias": not cfg.algo.ensembles.layer_norm}, - norm_layer=( - [nn.LayerNorm for _ in range(cfg_ensembles.mlp_layers)] if cfg_ensembles.layer_norm else None - ), - norm_args=( - [{"normalized_shape": cfg_ensembles.dense_units} for _ in range(cfg_ensembles.mlp_layers)] - if cfg_ensembles.layer_norm - else None - ), - ) - ) - ensembles = nn.ModuleList(ens_list) - ensembles.load_state_dict(state["ensembles"]) - # Log the model, create a new run if `cfg.run_id` is None. model_info = {} with mlflow.start_run(run_id=cfg.run_id, nested=True) as _: @@ -116,7 +86,7 @@ def log_models_from_checkpoint( model_info["actor_task"] = mlflow.pytorch.log_model(unwrap_fabric(actor_task), artifact_path="actor_task") model_info["critic_task"] = mlflow.pytorch.log_model(unwrap_fabric(critic_task), artifact_path="critic_task") if "exploration" in cfg.algo.name: - model_info["ensembles"] = mlflow.pytorch.log_model(ensembles, artifact_path="ensembles") + model_info["ensembles"] = mlflow.pytorch.log_model(unwrap_fabric(ensembles), artifact_path="ensembles") model_info["actor_exploration"] = mlflow.pytorch.log_model( unwrap_fabric(actor_exploration), artifact_path="actor_exploration" ) diff --git a/sheeprl/algos/p2e_dv2/agent.py b/sheeprl/algos/p2e_dv2/agent.py index 3ffef72f..fd12963d 100644 --- a/sheeprl/algos/p2e_dv2/agent.py +++ b/sheeprl/algos/p2e_dv2/agent.py @@ -6,6 +6,7 @@ import torch from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule +from lightning.pytorch.utilities.seed import isolate_rng from torch import nn from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor @@ -29,13 +30,14 @@ def build_agent( cfg: Dict[str, Any], obs_space: gymnasium.spaces.Dict, world_model_state: Optional[Dict[str, torch.Tensor]] = None, + ensembles_state: Optional[Dict[str, torch.Tensor]] = None, actor_task_state: Optional[Dict[str, torch.Tensor]] = None, critic_task_state: Optional[Dict[str, torch.Tensor]] = None, target_critic_task_state: Optional[Dict[str, torch.Tensor]] = None, actor_exploration_state: Optional[Dict[str, torch.Tensor]] = None, critic_exploration_state: Optional[Dict[str, torch.Tensor]] = None, target_critic_exploration_state: Optional[Dict[str, torch.Tensor]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule, nn.Module, _FabricModule, _FabricModule, nn.Module]: +) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, nn.Module, _FabricModule, _FabricModule, nn.Module]: """Build the models and wrap them with Fabric. Args: @@ -46,6 +48,8 @@ def build_agent( obs_space (Dict[str, Any]): The observations space of the environment. world_model_state (Dict[str, Tensor], optional): the state of the world model. Default to None. + ensembles_state (Dict[str, Tensor], optional): the state of the ensembles. + Default to None. actor_task_state (Dict[str, Tensor], optional): the state of the actor_task. Default to None. critic_task_state (Dict[str, Tensor], optional): the state of the critic_task. @@ -61,13 +65,14 @@ def build_agent( Returns: The world model (WorldModel): composed by the encoder, rssm, observation and - reward models and the continue model. - The actor_task (_FabricModule). - The critic_task (_FabricModule). - The target_critic_task (nn.Module). - The actor_exploration (_FabricModule). - The critic_exploration (_FabricModule). - The target_critic_exploration (nn.Module). + reward models and the continue model. 
+ The ensembles (_FabricModule): for estimating the intrinsic reward. + The actor_task (_FabricModule): for learning the task. + The critic_task (_FabricModule): for predicting the values of the task. + The target_critic_task (nn.Module): takes an EMA of the critic_task weights. + The actor_exploration (_FabricModule): for exploring the environment. + The critic_exploration (_FabricModule): for predicting the values of the exploration. + The target_critic_exploration (nn.Module): takes an EMA of the critic_exploration weights. """ world_model_cfg = cfg.algo.world_model actor_cfg = cfg.algo.actor @@ -131,8 +136,45 @@ def build_agent( if target_critic_task_state: target_critic_task.load_state_dict(target_critic_task_state) + # initialize the ensembles with different seeds to be sure they have different weights + ens_list = [] + with isolate_rng(): + for i in range(cfg.algo.ensembles.n): + fabric.seed_everything(cfg.seed + i) + ens_list.append( + MLP( + input_dims=int( + sum(actions_dim) + + cfg.algo.world_model.recurrent_model.recurrent_state_size + + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size + ), + output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size, + hidden_sizes=[cfg.algo.ensembles.dense_units] * cfg.algo.ensembles.mlp_layers, + activation=eval(cfg.algo.ensembles.dense_act), + flatten_dim=None, + norm_layer=( + [nn.LayerNorm for _ in range(cfg.algo.ensembles.mlp_layers)] + if cfg.algo.ensembles.layer_norm + else None + ), + norm_args=( + [ + {"normalized_shape": cfg.algo.ensembles.dense_units} + for _ in range(cfg.algo.ensembles.mlp_layers) + ] + if cfg.algo.ensembles.layer_norm + else None + ), + ).apply(init_weights) + ) + ensembles = nn.ModuleList(ens_list) + if ensembles_state: + ensembles.load_state_dict(ensembles_state) + fabric.setup_module(ensembles) + return ( world_model, + ensembles, actor_task, critic_task, target_critic_task, diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py index cb420460..0cc459f4 100644 --- a/sheeprl/algos/p2e_dv2/evaluate.py +++ b/sheeprl/algos/p2e_dv2/evaluate.py @@ -48,13 +48,14 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor_task, _, _, _, _, _ = build_agent( + world_model, _, actor_task, _, _, _, _, _ = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, state["world_model"], + None, state["actor_task"], ) player = PlayerDV2( diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py index da3a1be6..28a59676 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py @@ -13,7 +13,6 @@ import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer -from lightning.pytorch.utilities.seed import isolate_rng from mlflow.models.model import ModelInfo from tensordict import TensorDict from tensordict.tensordict import TensorDictBase @@ -24,10 +23,9 @@ from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel from sheeprl.algos.dreamer_v2.loss import reconstruction_loss -from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, init_weights, test +from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test from sheeprl.algos.p2e_dv2.agent import
build_agent from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer -from sheeprl.models.models import MLP from sheeprl.utils.distribution import OneHotCategoricalValidateArgs from sheeprl.utils.env import make_env from sheeprl.utils.logger import get_log_dir, get_logger @@ -605,6 +603,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ( world_model, + ensembles, actor_task, critic_task, target_critic_task, @@ -618,6 +617,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg, observation_space, state["world_model"] if cfg.checkpoint.resume_from else None, + state["ensembles"] if cfg.checkpoint.resume_from else None, state["actor_task"] if cfg.checkpoint.resume_from else None, state["critic_task"] if cfg.checkpoint.resume_from else None, state["target_critic_task"] if cfg.checkpoint.resume_from else None, @@ -626,41 +626,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["target_critic_exploration"] if cfg.checkpoint.resume_from else None, ) - # initialize the ensembles with different seeds to be sure they have different weights - ens_list = [] - with isolate_rng(): - for i in range(cfg.algo.ensembles.n): - fabric.seed_everything(cfg.seed + i) - ens_list.append( - MLP( - input_dims=int( - sum(actions_dim) - + cfg.algo.world_model.recurrent_model.recurrent_state_size - + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size - ), - output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size, - hidden_sizes=[cfg.algo.ensembles.dense_units] * cfg.algo.ensembles.mlp_layers, - activation=eval(cfg.algo.ensembles.dense_act), - flatten_dim=None, - norm_layer=( - [nn.LayerNorm for _ in range(cfg.algo.ensembles.mlp_layers)] - if cfg.algo.ensembles.layer_norm - else None - ), - norm_args=( - [ - {"normalized_shape": cfg.algo.ensembles.dense_units} - for _ in range(cfg.algo.ensembles.mlp_layers) - ] - if cfg.algo.ensembles.layer_norm - else None - ), - ).apply(init_weights) - ) - ensembles = nn.ModuleList(ens_list) - if cfg.checkpoint.resume_from: - ensembles.load_state_dict(state["ensembles"]) - fabric.setup_module(ensembles) player = PlayerDV2( world_model.encoder.module, world_model.rssm.recurrent_model.module, diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py index adebc83b..e9600d06 100644 --- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py +++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py @@ -143,13 +143,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder) obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder - world_model, actor_task, critic_task, target_critic_task, actor_exploration, _, _ = build_agent( + world_model, _, actor_task, critic_task, target_critic_task, actor_exploration, _, _ = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, state["world_model"], + None, state["actor_task"], state["critic_task"], state["target_critic_task"], diff --git a/sheeprl/algos/p2e_dv2/utils.py b/sheeprl/algos/p2e_dv2/utils.py index 92e90af9..4388b067 100644 --- a/sheeprl/algos/p2e_dv2/utils.py +++ b/sheeprl/algos/p2e_dv2/utils.py @@ -4,11 +4,9 @@ import mlflow from lightning import Fabric from mlflow.models.model import ModelInfo -from torch import nn from sheeprl.algos.dreamer_v2.utils import AGGREGATOR_KEYS as AGGREGATOR_KEYS_DV2 from sheeprl.algos.p2e_dv2.agent import build_agent -from sheeprl.models.models import MLP from sheeprl.utils.utils import unwrap_fabric 
AGGREGATOR_KEYS = { @@ -64,6 +62,7 @@ def log_models_from_checkpoint( ) ( world_model, + ensembles, actor_task, critic_task, target_critic_task, @@ -77,6 +76,7 @@ def log_models_from_checkpoint( cfg, env.observation_space, state["world_model"], + state["ensembles"] if "exploration" in cfg.algo.name else None, state["actor_task"], state["critic_task"], state["target_critic_task"], @@ -85,36 +85,6 @@ def log_models_from_checkpoint( state["target_critic_exploration"] if "exploration" in cfg.algo.name else None, ) - if "exploration" in cfg.algo.name: - ens_list = [] - cfg_ensembles = cfg.algo.ensembles - for i in range(cfg_ensembles.n): - fabric.seed_everything(cfg.seed + i) - ens_list.append( - MLP( - input_dims=int( - sum(actions_dim) - + cfg.algo.world_model.recurrent_model.recurrent_state_size - + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size - ), - output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size, - hidden_sizes=[cfg_ensembles.dense_units] * cfg_ensembles.mlp_layers, - activation=eval(cfg_ensembles.dense_act), - flatten_dim=None, - layer_args={"bias": not cfg.algo.ensembles.layer_norm}, - norm_layer=( - [nn.LayerNorm for _ in range(cfg_ensembles.mlp_layers)] if cfg_ensembles.layer_norm else None - ), - norm_args=( - [{"normalized_shape": cfg_ensembles.dense_units} for _ in range(cfg_ensembles.mlp_layers)] - if cfg_ensembles.layer_norm - else None - ), - ) - ) - ensembles = nn.ModuleList(ens_list) - ensembles.load_state_dict(state["ensembles"]) - # Log the model, create a new run if `cfg.run_id` is None. model_info = {} with mlflow.start_run(run_id=cfg.run_id, nested=True) as _: @@ -125,7 +95,7 @@ def log_models_from_checkpoint( target_critic_task, artifact_path="target_critic_task" ) if "exploration" in cfg.algo.name: - model_info["ensembles"] = mlflow.pytorch.log_model(ensembles, artifact_path="ensembles") + model_info["ensembles"] = mlflow.pytorch.log_model(unwrap_fabric(ensembles), artifact_path="ensembles") model_info["actor_exploration"] = mlflow.pytorch.log_model( unwrap_fabric(actor_exploration), artifact_path="actor_exploration" ) diff --git a/sheeprl/algos/p2e_dv3/agent.py b/sheeprl/algos/p2e_dv3/agent.py index 586e3e1e..aacfa47c 100644 --- a/sheeprl/algos/p2e_dv3/agent.py +++ b/sheeprl/algos/p2e_dv3/agent.py @@ -5,6 +5,7 @@ import torch from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule +from lightning.pytorch.utilities.seed import isolate_rng from torch import nn from sheeprl.algos.dreamer_v3.agent import Actor as DV3Actor @@ -28,12 +29,13 @@ def build_agent( cfg: Dict[str, Any], obs_space: Dict[str, Any], world_model_state: Optional[Dict[str, torch.Tensor]] = None, + ensembles_state: Optional[Dict[str, torch.Tensor]] = None, actor_task_state: Optional[Dict[str, torch.Tensor]] = None, critic_task_state: Optional[Dict[str, torch.Tensor]] = None, target_critic_task_state: Optional[Dict[str, torch.Tensor]] = None, actor_exploration_state: Optional[Dict[str, torch.Tensor]] = None, critics_exploration_state: Optional[Dict[str, Dict[str, Any]]] = None, -) -> Tuple[WorldModel, _FabricModule, _FabricModule, nn.Module, _FabricModule, Dict[str, Any]]: +) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, nn.Module, _FabricModule, Dict[str, Any]]: """Build the models and wrap them with Fabric. Args: @@ -44,6 +46,8 @@ def build_agent( obs_space (Dict[str, Any]): The observations space of the environment. 
world_model_state (Dict[str, Tensor], optional): the state of the world model. Default to None. + ensembles_state (Dict[str, Tensor], optional): the state of the ensembles. + Default to None. actor_task_state (Dict[str, Tensor], optional): the state of the actor_task. Default to None. critic_task_state (Dict[str, Tensor], optional): the state of the critic_task. Default to None. @@ -58,11 +62,14 @@ def build_agent( Returns: The world model (WorldModel): composed by the encoder, rssm, observation and reward models and the continue model. - The actor_task (_FabricModule). - The critic_task (_FabricModule). - The target_critic_task (nn.Module). - The actor_exploration (_FabricModule). - The critics_exploration (Dict[str, Dict[str, Any]]). + + The ensembles (_FabricModule): for estimating the intrinsic reward. + The actor_task (_FabricModule): for learning the task. + The critic_task (_FabricModule): for predicting the values of the task. + The target_critic_task (nn.Module): takes an EMA of the critic_task weights. + The actor_exploration (_FabricModule): for exploring the environment. + The critics_exploration (Dict[str, Dict[str, Any]]): Python dictionary containing all the exploration critics. + The critic is under the 'module' key, whereas the target critic is under the 'target_module' key. """ world_model_cfg = cfg.algo.world_model actor_cfg = cfg.algo.actor @@ -152,8 +160,42 @@ def build_agent( for c in critics_exploration.values(): c["target_module"].requires_grad_(False) + # initialize the ensembles with different seeds to be sure they have different weights + ens_list = [] + cfg_ensembles = cfg.algo.ensembles + with isolate_rng(): + for i in range(cfg_ensembles.n): + fabric.seed_everything(cfg.seed + i) + ens_list.append( + MLP( + input_dims=int( + sum(actions_dim) + + cfg.algo.world_model.recurrent_model.recurrent_state_size + + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size + ), + output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size, + hidden_sizes=[cfg_ensembles.dense_units] * cfg_ensembles.mlp_layers, + activation=eval(cfg_ensembles.dense_act), + flatten_dim=None, + layer_args={"bias": not cfg.algo.ensembles.layer_norm}, + norm_layer=( + [nn.LayerNorm for _ in range(cfg_ensembles.mlp_layers)] if cfg_ensembles.layer_norm else None + ), + norm_args=( + [{"normalized_shape": cfg_ensembles.dense_units} for _ in range(cfg_ensembles.mlp_layers)] + if cfg_ensembles.layer_norm + else None + ), + ).apply(init_weights) + ) + ensembles = nn.ModuleList(ens_list) + if ensembles_state: + ensembles.load_state_dict(ensembles_state) + fabric.setup_module(ensembles) + return ( world_model, + ensembles, actor_task, critic_task, target_critic_task, diff --git a/sheeprl/algos/p2e_dv3/evaluate.py b/sheeprl/algos/p2e_dv3/evaluate.py index 20ccd61d..8e79c87c 100644 --- a/sheeprl/algos/p2e_dv3/evaluate.py +++ b/sheeprl/algos/p2e_dv3/evaluate.py @@ -48,13 +48,14 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]): action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n]) ) # Create the actor and critic models - world_model, actor, _, _, _, _ = build_agent( + world_model, _, actor, _, _, _, _ = build_agent( fabric, actions_dim, is_continuous, cfg, observation_space, state["world_model"], + None, state["actor_task"], ) player = PlayerDV3( diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py index 776c5420..321b4de1 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py @@ -11,7 +11,6 @@ import torch.nn.functional as F from lightning.fabric import Fabric from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer -from lightning.pytorch.utilities.seed import isolate_rng from mlflow.models.model import ModelInfo from omegaconf import DictConfig from tensordict import TensorDict @@ -23,10 +22,9 @@ from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel from sheeprl.algos.dreamer_v3.loss import reconstruction_loss -from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, init_weights, test +from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test from sheeprl.algos.p2e_dv3.agent import build_agent from sheeprl.data.buffers import AsyncReplayBuffer -from sheeprl.models.models import MLP from sheeprl.utils.distribution import ( MSEDistribution, OneHotCategoricalValidateArgs, @@ -634,6 +632,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): ( world_model, + ensembles, actor_task, critic_task, target_critic_task, @@ -646,6 +645,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): cfg, observation_space, state["world_model"] if cfg.checkpoint.resume_from else None, + state["ensembles"] if cfg.checkpoint.resume_from else None, state["actor_task"] if cfg.checkpoint.resume_from else None, state["critic_task"] if cfg.checkpoint.resume_from else None, state["target_critic_task"] if cfg.checkpoint.resume_from else None, @@ -653,38 +653,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any]): state["critics_exploration"] if cfg.checkpoint.resume_from else None, ) - # initialize the ensembles with different seeds to be sure they have different weights - ens_list = [] - cfg_ensembles = cfg.algo.ensembles - with isolate_rng(): - for i in range(cfg_ensembles.n): - fabric.seed_everything(cfg.seed + i) - ens_list.append( - MLP( - input_dims=int( - sum(actions_dim) - + cfg.algo.world_model.recurrent_model.recurrent_state_size - + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size - ), - output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size, - hidden_sizes=[cfg_ensembles.dense_units] * cfg_ensembles.mlp_layers, - activation=eval(cfg_ensembles.dense_act), - flatten_dim=None, - layer_args={"bias": not cfg.algo.ensembles.layer_norm}, - norm_layer=( - [nn.LayerNorm for _ in range(cfg_ensembles.mlp_layers)] if cfg_ensembles.layer_norm else None - ), - norm_args=( - [{"normalized_shape": cfg_ensembles.dense_units} for _ in range(cfg_ensembles.mlp_layers)] - if cfg_ensembles.layer_norm - else None - ), - ).apply(init_weights) - ) - ensembles = nn.ModuleList(ens_list) - if cfg.checkpoint.resume_from: - ensembles.load_state_dict(state["ensembles"]) - fabric.setup_module(ensembles) player = PlayerDV3( world_model.encoder.module, world_model.rssm, diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py index ea1eef9f..fdba7836 100644 --- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py +++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py @@ -139,6 +139,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): ( world_model, + _, actor_task, critic_task, target_critic_task, @@ -151,6 +152,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]): cfg, observation_space, state["world_model"], + None, 
state["actor_task"], state["critic_task"], state["target_critic_task"], diff --git a/sheeprl/algos/p2e_dv3/utils.py b/sheeprl/algos/p2e_dv3/utils.py index 1407cc01..754c766f 100644 --- a/sheeprl/algos/p2e_dv3/utils.py +++ b/sheeprl/algos/p2e_dv3/utils.py @@ -4,12 +4,10 @@ import mlflow from lightning import Fabric from mlflow.models.model import ModelInfo -from torch import nn from sheeprl.algos.dreamer_v3.utils import AGGREGATOR_KEYS as AGGREGATOR_KEYS_DV3 from sheeprl.algos.dreamer_v3.utils import Moments from sheeprl.algos.p2e_dv3.agent import build_agent -from sheeprl.models.models import MLP from sheeprl.utils.utils import unwrap_fabric AGGREGATOR_KEYS = { @@ -69,13 +67,22 @@ def log_models_from_checkpoint( if is_continuous else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]) ) - world_model, actor_task, critic_task, target_critic_task, actor_exploration, critics_exploration = build_agent( + ( + world_model, + ensembles, + actor_task, + critic_task, + target_critic_task, + actor_exploration, + critics_exploration, + ) = build_agent( fabric, actions_dim, is_continuous, cfg, env.observation_space, state["world_model"], + state["ensembles"] if "exploration" in cfg.algo.name else None, state["actor_task"], state["critic_task"], state["target_critic_task"], @@ -92,35 +99,6 @@ def log_models_from_checkpoint( moments_task.load_state_dict(state["moments_task"]) if "exploration" in cfg.algo.name: - ens_list = [] - cfg_ensembles = cfg.algo.ensembles - for i in range(cfg_ensembles.n): - fabric.seed_everything(cfg.seed + i) - ens_list.append( - MLP( - input_dims=int( - sum(actions_dim) - + cfg.algo.world_model.recurrent_model.recurrent_state_size - + cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size - ), - output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size, - hidden_sizes=[cfg_ensembles.dense_units] * cfg_ensembles.mlp_layers, - activation=eval(cfg_ensembles.dense_act), - flatten_dim=None, - layer_args={"bias": not cfg.algo.ensembles.layer_norm}, - norm_layer=( - [nn.LayerNorm for _ in range(cfg_ensembles.mlp_layers)] if cfg_ensembles.layer_norm else None - ), - norm_args=( - [{"normalized_shape": cfg_ensembles.dense_units} for _ in range(cfg_ensembles.mlp_layers)] - if cfg_ensembles.layer_norm - else None - ), - ) - ) - ensembles = nn.ModuleList(ens_list) - ensembles.load_state_dict(state["ensembles"]) - moments_exploration = { k: Moments( fabric, @@ -145,7 +123,7 @@ def log_models_from_checkpoint( ) model_info["moments_task"] = mlflow.pytorch.log_model(moments_task, artifact_path="moments_task") if "exploration" in cfg.algo.name: - model_info["ensembles"] = mlflow.pytorch.log_model(ensembles, artifact_path="ensembles") + model_info["ensembles"] = mlflow.pytorch.log_model(unwrap_fabric(ensembles), artifact_path="ensembles") model_info["actor_exploration"] = mlflow.pytorch.log_model( unwrap_fabric(actor_exploration), artifact_path="actor_exploration" )