feat: added ensembles creation to build agent function #154

Merged: 1 commit, Nov 21, 2023
41 changes: 34 additions & 7 deletions sheeprl/algos/p2e_dv1/agent.py
@@ -5,6 +5,8 @@
import torch
from lightning.fabric import Fabric
from lightning.fabric.wrappers import _FabricModule
from lightning.pytorch.utilities.seed import isolate_rng
from torch import nn

from sheeprl.algos.dreamer_v1.agent import WorldModel
from sheeprl.algos.dreamer_v1.agent import build_agent as dv1_build_agent
@@ -27,11 +29,12 @@ def build_agent(
cfg: Dict[str, Any],
obs_space: gymnasium.spaces.Dict,
world_model_state: Optional[Dict[str, torch.Tensor]] = None,
ensembles_state: Optional[Dict[str, torch.Tensor]] = None,
actor_task_state: Optional[Dict[str, torch.Tensor]] = None,
critic_task_state: Optional[Dict[str, torch.Tensor]] = None,
actor_exploration_state: Optional[Dict[str, torch.Tensor]] = None,
critic_exploration_state: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, _FabricModule]:
) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, _FabricModule, _FabricModule]:
"""Build the models and wrap them with Fabric.

Args:
@@ -42,6 +45,8 @@
obs_space (Dict[str, Any]): the observation space.
world_model_state (Dict[str, Tensor], optional): the state of the world model.
Default to None.
ensembles_state (Dict[str, Tensor], optional): the state of the ensembles.
Default to None.
actor_task_state (Dict[str, Tensor], optional): the state of the actor_task.
Default to None.
critic_task_state (Dict[str, Tensor], optional): the state of the critic_task.
@@ -53,11 +58,12 @@

Returns:
The world model (WorldModel): composed by the encoder, rssm, observation and
reward models and the continue model.
The actor_task (_FabricModule).
The critic_task (_FabricModule).
The actor_exploration (_FabricModule).
The critic_exploration (_FabricModule).
reward models and the continue model.
The ensembles (_FabricModule): for estimating the intrinsic reward.
The actor_task (_FabricModule): for learning the task.
The critic_task (_FabricModule): for predicting the values of the task.
The actor_exploration (_FabricModule): for exploring the environment.
The critic_exploration (_FabricModule): for predicting the values of the exploration.
"""
world_model_cfg = cfg.algo.world_model
actor_cfg = cfg.algo.actor
@@ -110,4 +116,25 @@
actor_task = fabric.setup_module(actor_task)
critic_task = fabric.setup_module(critic_task)

return world_model, actor_task, critic_task, actor_exploration, critic_exploration
ens_list = []
with isolate_rng():
for i in range(cfg.algo.ensembles.n):
fabric.seed_everything(cfg.seed + i)
ens_list.append(
MLP(
input_dims=(
int(sum(actions_dim))
+ cfg.algo.world_model.recurrent_model.recurrent_state_size
+ cfg.algo.world_model.stochastic_size
),
output_dim=world_model.encoder.cnn_output_dim + world_model.encoder.mlp_output_dim,
hidden_sizes=[cfg.algo.ensembles.dense_units] * cfg.algo.ensembles.mlp_layers,
activation=eval(cfg.algo.ensembles.dense_act),
).apply(init_weights)
)
ensembles = nn.ModuleList(ens_list)
if ensembles_state:
ensembles.load_state_dict(ensembles_state)
fabric.setup_module(ensembles)

return world_model, ensembles, actor_task, critic_task, actor_exploration, critic_exploration
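
The block above seeds each ensemble member with a different seed so that their initial weights differ, and wraps the loop in isolate_rng() so the reseeding does not disturb the RNG stream of the surrounding program. A minimal, self-contained sketch of the same pattern, using plain nn.Sequential members and made-up sizes instead of sheeprl's MLP, and the functional seed_everything utility in place of the fabric.seed_everything method used in the diff:

import torch
from torch import nn
from lightning.fabric.utilities.seed import seed_everything
from lightning.pytorch.utilities.seed import isolate_rng


def build_ensembles(n: int, in_dim: int, out_dim: int, base_seed: int = 42) -> nn.ModuleList:
    members = []
    # Reseed before constructing each member so every member gets distinct
    # weights; isolate_rng() restores the global RNG state on exit.
    with isolate_rng():
        for i in range(n):
            seed_everything(base_seed + i)
            members.append(nn.Sequential(nn.Linear(in_dim, 400), nn.ELU(), nn.Linear(400, out_dim)))
    return nn.ModuleList(members)


ensembles = build_ensembles(n=5, in_dim=256, out_dim=1024)
# The per-member seeds make the members differ from one another:
assert not torch.equal(ensembles[0][0].weight, ensembles[1][0].weight)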
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/evaluate.py
@@ -48,7 +48,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
world_model, actor_task, _, _, _ = build_agent(
world_model, _, actor_task, _, _, _ = build_agent(
fabric,
actions_dim,
is_continuous,
29 changes: 3 additions & 26 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -13,11 +13,9 @@
import torch.nn.functional as F
from lightning.fabric import Fabric
from lightning.fabric.wrappers import _FabricModule, _FabricOptimizer
from lightning.pytorch.utilities.seed import isolate_rng
from mlflow.models.model import ModelInfo
from tensordict import TensorDict
from tensordict.tensordict import TensorDictBase
from torch import nn
from torch.distributions import Bernoulli, Independent, Normal
from torch.utils.data import BatchSampler
from torchmetrics import SumMetric
@@ -28,13 +26,12 @@
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.algos.p2e_dv1.agent import build_agent
from sheeprl.data.buffers import AsyncReplayBuffer
from sheeprl.models.models import MLP
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
from sheeprl.utils.metric import MetricAggregator
from sheeprl.utils.registry import register_algorithm
from sheeprl.utils.timer import timer
from sheeprl.utils.utils import init_weights, polynomial_decay, register_model, unwrap_fabric
from sheeprl.utils.utils import polynomial_decay, register_model, unwrap_fabric

# Uncomment the following line if you are using MineDojo on a headless machine
# os.environ["MINEDOJO_HEADLESS"] = "1"
@@ -488,40 +485,20 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder

world_model, actor_task, critic_task, actor_exploration, critic_exploration = build_agent(
world_model, ensembles, actor_task, critic_task, actor_exploration, critic_exploration = build_agent(
fabric,
actions_dim,
is_continuous,
cfg,
observation_space,
state["world_model"] if cfg.checkpoint.resume_from else None,
state["ensembles"] if cfg.checkpoint.resume_from else None,
state["actor_task"] if cfg.checkpoint.resume_from else None,
state["critic_task"] if cfg.checkpoint.resume_from else None,
state["actor_exploration"] if cfg.checkpoint.resume_from else None,
state["critic_exploration"] if cfg.checkpoint.resume_from else None,
)

# initialize the ensembles with different seeds to be sure they have different weights
ens_list = []
with isolate_rng():
for i in range(cfg.algo.ensembles.n):
fabric.seed_everything(cfg.seed + i)
ens_list.append(
MLP(
input_dims=(
int(sum(actions_dim))
+ cfg.algo.world_model.recurrent_model.recurrent_state_size
+ cfg.algo.world_model.stochastic_size
),
output_dim=world_model.encoder.cnn_output_dim + world_model.encoder.mlp_output_dim,
hidden_sizes=[cfg.algo.ensembles.dense_units] * cfg.algo.ensembles.mlp_layers,
activation=eval(cfg.algo.ensembles.dense_act),
).apply(init_weights)
)
ensembles = nn.ModuleList(ens_list)
if cfg.checkpoint.resume_from:
ensembles.load_state_dict(state["ensembles"])
fabric.setup_module(ensembles)
player = PlayerDV1(
world_model.encoder.module,
world_model.rssm.recurrent_model.module,
3 changes: 2 additions & 1 deletion sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -139,13 +139,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder

world_model, actor_task, critic_task, actor_exploration, _ = build_agent(
world_model, _, actor_task, critic_task, actor_exploration, _ = build_agent(
fabric,
actions_dim,
is_continuous,
cfg,
observation_space,
state["world_model"],
None,
state["actor_task"],
state["critic_task"],
state["actor_exploration"],
36 changes: 3 additions & 33 deletions sheeprl/algos/p2e_dv1/utils.py
@@ -4,11 +4,9 @@
import mlflow
from lightning import Fabric
from mlflow.models.model import ModelInfo
from torch import nn

from sheeprl.algos.dreamer_v1.utils import AGGREGATOR_KEYS as AGGREGATOR_KEYS_DV1
from sheeprl.algos.p2e_dv1.agent import build_agent
from sheeprl.models.models import MLP
from sheeprl.utils.utils import unwrap_fabric

AGGREGATOR_KEYS = {
@@ -62,6 +60,7 @@ def log_models_from_checkpoint(
)
(
world_model,
ensembles,
actor_task,
critic_task,
actor_exploration,
@@ -73,50 +72,21 @@
cfg,
env.observation_space,
state["world_model"],
state["ensembles"] if "exploration" in cfg.algo.name else None,
state["actor_task"],
state["critic_task"],
state["actor_exploration"] if "exploration" in cfg.algo.name else None,
state["critic_exploration"] if "exploration" in cfg.algo.name else None,
)

if "exploration" in cfg.algo.name:
ens_list = []
cfg_ensembles = cfg.algo.ensembles
for i in range(cfg_ensembles.n):
fabric.seed_everything(cfg.seed + i)
ens_list.append(
MLP(
input_dims=int(
sum(actions_dim)
+ cfg.algo.world_model.recurrent_model.recurrent_state_size
+ cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size
),
output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size,
hidden_sizes=[cfg_ensembles.dense_units] * cfg_ensembles.mlp_layers,
activation=eval(cfg_ensembles.dense_act),
flatten_dim=None,
layer_args={"bias": not cfg.algo.ensembles.layer_norm},
norm_layer=(
[nn.LayerNorm for _ in range(cfg_ensembles.mlp_layers)] if cfg_ensembles.layer_norm else None
),
norm_args=(
[{"normalized_shape": cfg_ensembles.dense_units} for _ in range(cfg_ensembles.mlp_layers)]
if cfg_ensembles.layer_norm
else None
),
)
)
ensembles = nn.ModuleList(ens_list)
ensembles.load_state_dict(state["ensembles"])

# Log the model, create a new run if `cfg.run_id` is None.
model_info = {}
with mlflow.start_run(run_id=cfg.run_id, nested=True) as _:
model_info["world_model"] = mlflow.pytorch.log_model(unwrap_fabric(world_model), artifact_path="world_model")
model_info["actor_task"] = mlflow.pytorch.log_model(unwrap_fabric(actor_task), artifact_path="actor_task")
model_info["critic_task"] = mlflow.pytorch.log_model(unwrap_fabric(critic_task), artifact_path="critic_task")
if "exploration" in cfg.algo.name:
model_info["ensembles"] = mlflow.pytorch.log_model(ensembles, artifact_path="ensembles")
model_info["ensembles"] = mlflow.pytorch.log_model(unwrap_fabric(ensembles), artifact_path="ensembles")
model_info["actor_exploration"] = mlflow.pytorch.log_model(
unwrap_fabric(actor_exploration), artifact_path="actor_exploration"
)
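
Since the ensembles are now returned by build_agent (wrapped by Fabric), the logging path above unwraps them with unwrap_fabric like the other models before handing them to MLflow. A minimal, repo-independent sketch of that logging step; the ModuleList below is a placeholder for the real ensembles, and a local MLflow tracking directory (mlruns/) is assumed:

import mlflow
import mlflow.pytorch
from torch import nn

# Placeholder standing in for the ensembles returned by build_agent;
# in sheeprl they are unwrapped with unwrap_fabric() before logging.
ensembles = nn.ModuleList([nn.Linear(8, 4) for _ in range(3)])

# Log the module as an MLflow model artifact, mirroring log_models_from_checkpoint.
with mlflow.start_run() as _:
    model_info = mlflow.pytorch.log_model(ensembles, artifact_path="ensembles")
print(model_info.model_uri)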
58 changes: 50 additions & 8 deletions sheeprl/algos/p2e_dv2/agent.py
@@ -6,6 +6,7 @@
import torch
from lightning.fabric import Fabric
from lightning.fabric.wrappers import _FabricModule
from lightning.pytorch.utilities.seed import isolate_rng
from torch import nn

from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor
@@ -29,13 +30,14 @@ def build_agent(
cfg: Dict[str, Any],
obs_space: gymnasium.spaces.Dict,
world_model_state: Optional[Dict[str, torch.Tensor]] = None,
ensembles_state: Optional[Dict[str, torch.Tensor]] = None,
actor_task_state: Optional[Dict[str, torch.Tensor]] = None,
critic_task_state: Optional[Dict[str, torch.Tensor]] = None,
target_critic_task_state: Optional[Dict[str, torch.Tensor]] = None,
actor_exploration_state: Optional[Dict[str, torch.Tensor]] = None,
critic_exploration_state: Optional[Dict[str, torch.Tensor]] = None,
target_critic_exploration_state: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[WorldModel, _FabricModule, _FabricModule, nn.Module, _FabricModule, _FabricModule, nn.Module]:
) -> Tuple[WorldModel, _FabricModule, _FabricModule, _FabricModule, nn.Module, _FabricModule, _FabricModule, nn.Module]:
"""Build the models and wrap them with Fabric.

Args:
@@ -46,6 +48,8 @@
obs_space (Dict[str, Any]): The observations space of the environment.
world_model_state (Dict[str, Tensor], optional): the state of the world model.
Default to None.
ensembles_state (Dict[str, Tensor], optional): the state of the ensembles.
Default to None.
actor_task_state (Dict[str, Tensor], optional): the state of the actor_task.
Default to None.
critic_task_state (Dict[str, Tensor], optional): the state of the critic_task.
@@ -61,13 +65,14 @@

Returns:
The world model (WorldModel): composed by the encoder, rssm, observation and
reward models and the continue model.
The actor_task (_FabricModule).
The critic_task (_FabricModule).
The target_critic_task (nn.Module).
The actor_exploration (_FabricModule).
The critic_exploration (_FabricModule).
The target_critic_exploration (nn.Module).
reward models and the continue model.
The ensembles (_FabricModule): for estimating the intrinsic reward.
The actor_task (_FabricModule): for learning the task.
The critic_task (_FabricModule): for predicting the values of the task.
The target_critic_task (nn.Module): tracks an EMA of the critic_task weights.
The actor_exploration (_FabricModule): for exploring the environment.
The critic_exploration (_FabricModule): for predicting the values of the exploration.
The target_critic_exploration (nn.Module): tracks an EMA of the critic_exploration weights.
"""
world_model_cfg = cfg.algo.world_model
actor_cfg = cfg.algo.actor
@@ -131,8 +136,45 @@
if target_critic_task_state:
target_critic_task.load_state_dict(target_critic_task_state)

# initialize the ensembles with different seeds to be sure they have different weights
ens_list = []
with isolate_rng():
for i in range(cfg.algo.ensembles.n):
fabric.seed_everything(cfg.seed + i)
ens_list.append(
MLP(
input_dims=int(
sum(actions_dim)
+ cfg.algo.world_model.recurrent_model.recurrent_state_size
+ cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size
),
output_dim=cfg.algo.world_model.stochastic_size * cfg.algo.world_model.discrete_size,
hidden_sizes=[cfg.algo.ensembles.dense_units] * cfg.algo.ensembles.mlp_layers,
activation=eval(cfg.algo.ensembles.dense_act),
flatten_dim=None,
norm_layer=(
[nn.LayerNorm for _ in range(cfg.algo.ensembles.mlp_layers)]
if cfg.algo.ensembles.layer_norm
else None
),
norm_args=(
[
{"normalized_shape": cfg.algo.ensembles.dense_units}
for _ in range(cfg.algo.ensembles.mlp_layers)
]
if cfg.algo.ensembles.layer_norm
else None
),
).apply(init_weights)
)
ensembles = nn.ModuleList(ens_list)
if ensembles_state:
ensembles.load_state_dict(ensembles_state)
fabric.setup_module(ensembles)

return (
world_model,
ensembles,
actor_task,
critic_task,
target_critic_task,
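
Compared with the p2e_dv1 version earlier in this diff, the ensembles here consume and predict the flattened discrete stochastic state (stochastic_size * discrete_size) rather than the encoder embedding. A quick sketch of the dimension arithmetic with illustrative numbers; the values below are assumptions for the example, not sheeprl defaults:

# Illustrative sizes only (not sheeprl defaults).
actions_dim = (4,)              # e.g. one discrete action head with 4 actions
recurrent_state_size = 200
stochastic_size = 30
discrete_size = 32              # used only by the Dreamer-V2 variant
encoder_cnn_output_dim = 1024   # dv1 target: the encoder embedding
encoder_mlp_output_dim = 0

# p2e_dv1: predict the next embedding from (action, recurrent state, stochastic state)
dv1_in = int(sum(actions_dim)) + recurrent_state_size + stochastic_size
dv1_out = encoder_cnn_output_dim + encoder_mlp_output_dim

# p2e_dv2: predict the next discrete stochastic state instead
dv2_in = int(sum(actions_dim)) + recurrent_state_size + stochastic_size * discrete_size
dv2_out = stochastic_size * discrete_size

print(dv1_in, dv1_out)  # 234 1024
print(dv2_in, dv2_out)  # 1164 960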
3 changes: 2 additions & 1 deletion sheeprl/algos/p2e_dv2/evaluate.py
@@ -48,13 +48,14 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
world_model, actor_task, _, _, _, _, _ = build_agent(
world_model, _, actor_task, _, _, _, _, _ = build_agent(
fabric,
actions_dim,
is_continuous,
cfg,
observation_space,
state["world_model"],
None,
state["actor_task"],
)
player = PlayerDV2(