diff --git a/README.md b/README.md
index 4694f5a4..ae2731c0 100644
--- a/README.md
+++ b/README.md
@@ -9,8 +9,6 @@
|
|
-
-
|
|
@@ -58,7 +56,7 @@
DOA++(w/o optimizations)1 |
7M |
18d 22h |
- 2726/332832 |
+ 2726/33282 |
N.A. |
1-3080 |
diff --git a/sheeprl/algos/dreamer_v1/agent.py b/sheeprl/algos/dreamer_v1/agent.py
index db9f8db8..5aa56849 100644
--- a/sheeprl/algos/dreamer_v1/agent.py
+++ b/sheeprl/algos/dreamer_v1/agent.py
@@ -324,7 +324,7 @@ def get_greedy_action(
return actions
-def build_models(
+def build_agent(
fabric: Fabric,
actions_dim: Sequence[int],
is_continuous: bool,
diff --git a/sheeprl/algos/dreamer_v1/dreamer_v1.py b/sheeprl/algos/dreamer_v1/dreamer_v1.py
index 0a9e8d6a..48454131 100644
--- a/sheeprl/algos/dreamer_v1/dreamer_v1.py
+++ b/sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -20,7 +20,7 @@
from torch.utils.data import BatchSampler
from torchmetrics import SumMetric
-from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_models
+from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_agent
from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss
from sheeprl.algos.dreamer_v1.utils import compute_lambda_values
from sheeprl.algos.dreamer_v2.utils import test
@@ -444,7 +444,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r
@@ -477,7 +477,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder
- world_model, actor, critic = build_models(
+ world_model, actor, critic = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/dreamer_v1/evaluate.py b/sheeprl/algos/dreamer_v1/evaluate.py
index c2dfb450..5574cbec 100644
--- a/sheeprl/algos/dreamer_v1/evaluate.py
+++ b/sheeprl/algos/dreamer_v1/evaluate.py
@@ -5,7 +5,7 @@
import gymnasium as gym
from lightning import Fabric
-from sheeprl.algos.dreamer_v1.agent import PlayerDV1, build_models
+from sheeprl.algos.dreamer_v1.agent import PlayerDV1, build_agent
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -43,11 +43,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
- world_model, actor, _ = build_models(
+ world_model, actor, _ = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/dreamer_v2/agent.py b/sheeprl/algos/dreamer_v2/agent.py
index cb01db30..ededb565 100644
--- a/sheeprl/algos/dreamer_v2/agent.py
+++ b/sheeprl/algos/dreamer_v2/agent.py
@@ -862,7 +862,7 @@ def get_greedy_action(
return actions
-def build_models(
+def build_agent(
fabric: Fabric,
actions_dim: Sequence[int],
is_continuous: bool,
diff --git a/sheeprl/algos/dreamer_v2/dreamer_v2.py b/sheeprl/algos/dreamer_v2/dreamer_v2.py
index 8f6b7b54..850e9776 100644
--- a/sheeprl/algos/dreamer_v2/dreamer_v2.py
+++ b/sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -25,7 +25,7 @@
from torch.utils.data import BatchSampler
from torchmetrics import SumMetric
-from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_models
+from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_agent
from sheeprl.algos.dreamer_v2.loss import reconstruction_loss
from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test
from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer
@@ -468,7 +468,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r
@@ -501,7 +501,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder
- world_model, actor, critic, target_critic = build_models(
+ world_model, actor, critic, target_critic = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/dreamer_v2/evaluate.py b/sheeprl/algos/dreamer_v2/evaluate.py
index d7d37a6a..ec2e0738 100644
--- a/sheeprl/algos/dreamer_v2/evaluate.py
+++ b/sheeprl/algos/dreamer_v2/evaluate.py
@@ -5,7 +5,7 @@
import gymnasium as gym
from lightning import Fabric
-from sheeprl.algos.dreamer_v2.agent import PlayerDV2, build_models
+from sheeprl.algos.dreamer_v2.agent import PlayerDV2, build_agent
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -43,11 +43,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
- world_model, actor, _, _ = build_models(
+ world_model, actor, _, _ = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/dreamer_v3/agent.py b/sheeprl/algos/dreamer_v3/agent.py
index 0d712b2e..1d7ac574 100644
--- a/sheeprl/algos/dreamer_v3/agent.py
+++ b/sheeprl/algos/dreamer_v3/agent.py
@@ -897,7 +897,7 @@ def add_exploration_noise(
return tuple(expl_actions)
-def build_models(
+def build_agent(
fabric: Fabric,
actions_dim: Sequence[int],
is_continuous: bool,
diff --git a/sheeprl/algos/dreamer_v3/dreamer_v3.py b/sheeprl/algos/dreamer_v3/dreamer_v3.py
index c3588d66..0b360e8f 100644
--- a/sheeprl/algos/dreamer_v3/dreamer_v3.py
+++ b/sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -25,7 +25,7 @@
from torch.utils.data import BatchSampler
from torchmetrics import SumMetric
-from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_models
+from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_agent
from sheeprl.algos.dreamer_v3.loss import reconstruction_loss
from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test
from sheeprl.data.buffers import AsyncReplayBuffer
@@ -402,7 +402,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r
@@ -435,7 +435,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder
- world_model, actor, critic, target_critic = build_models(
+ world_model, actor, critic, target_critic = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/dreamer_v3/evaluate.py b/sheeprl/algos/dreamer_v3/evaluate.py
index 5b1a0f38..ab5ab68f 100644
--- a/sheeprl/algos/dreamer_v3/evaluate.py
+++ b/sheeprl/algos/dreamer_v3/evaluate.py
@@ -5,7 +5,7 @@
import gymnasium as gym
from lightning import Fabric
-from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_models
+from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_agent
from sheeprl.algos.dreamer_v3.utils import test
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -43,11 +43,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
- world_model, actor, _, _ = build_models(
+ world_model, actor, _, _ = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/droq/agent.py b/sheeprl/algos/droq/agent.py
index a3c88b12..cd206e22 100644
--- a/sheeprl/algos/droq/agent.py
+++ b/sheeprl/algos/droq/agent.py
@@ -1,8 +1,11 @@
import copy
-from typing import Sequence, Tuple, Union
+from math import prod
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+import gymnasium
import torch
import torch.nn as nn
+from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor
@@ -198,3 +201,41 @@ def qfs_target_ema(self, critic_idx: int) -> None:
self.qfs_unwrapped[critic_idx].parameters(), self.qfs_target[critic_idx].parameters()
):
target_param.data.copy_(self._tau * param.data + (1 - self._tau) * target_param.data)
+
+
+def build_agent(
+ fabric: Fabric,
+ cfg: Dict[str, Any],
+ obs_space: gymnasium.spaces.Dict,
+ action_space: gymnasium.spaces.Box,
+ agent_state: Optional[Dict[str, Tensor]] = None,
+) -> DROQAgent:
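+ # Build the DROQ agent from the config: a SAC actor plus an ensemble of
+ # dropout critics, optionally restoring a checkpointed state before the
+ # actor and critics are wrapped with Fabric.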
+ act_dim = prod(action_space.shape)
+ obs_dim = sum([prod(obs_space[k].shape) for k in cfg.mlp_keys.encoder])
+ actor = SACActor(
+ observation_dim=obs_dim,
+ action_dim=act_dim,
+ distribution_cfg=cfg.distribution,
+ hidden_size=cfg.algo.actor.hidden_size,
+ action_low=action_space.low,
+ action_high=action_space.high,
+ )
+ critics = [
+ DROQCritic(
+ observation_dim=obs_dim + act_dim,
+ hidden_size=cfg.algo.critic.hidden_size,
+ num_critics=1,
+ dropout=cfg.algo.critic.dropout,
+ )
+ for _ in range(cfg.algo.critic.n)
+ ]
+ target_entropy = -act_dim
+ agent = DROQAgent(
+ actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device
+ )
+ if agent_state:
+ agent.load_state_dict(agent_state)
+ agent.actor = fabric.setup_module(agent.actor)
+ agent.critics = [fabric.setup_module(critic) for critic in agent.critics]
+
+ return agent
diff --git a/sheeprl/algos/droq/droq.py b/sheeprl/algos/droq/droq.py
index 94d3a366..c0c5b71f 100644
--- a/sheeprl/algos/droq/droq.py
+++ b/sheeprl/algos/droq/droq.py
@@ -3,7 +3,6 @@
import copy
import os
import warnings
-from math import prod
from typing import Any, Dict
import gymnasium as gym
@@ -20,8 +19,7 @@
from torch.utils.data.sampler import BatchSampler
from torchmetrics import SumMetric
-from sheeprl.algos.droq.agent import DROQAgent, DROQCritic
-from sheeprl.algos.sac.agent import SACActor
+from sheeprl.algos.droq.agent import DROQAgent, build_agent
from sheeprl.algos.sac.loss import entropy_loss, policy_loss
from sheeprl.algos.sac.sac import test
from sheeprl.data.buffers import ReplayBuffer
@@ -196,33 +194,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder)
# Define the agent and the optimizer and setup them with Fabric
- act_dim = prod(action_space.shape)
- obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder])
- actor = SACActor(
- observation_dim=obs_dim,
- action_dim=act_dim,
- distribution_cfg=cfg.distribution,
- hidden_size=cfg.algo.actor.hidden_size,
- action_low=action_space.low,
- action_high=action_space.high,
+ agent = build_agent(
+ fabric, cfg, observation_space, action_space, state["agent"] if cfg.checkpoint.resume_from else None
)
- critics = [
- DROQCritic(
- observation_dim=obs_dim + act_dim,
- hidden_size=cfg.algo.critic.hidden_size,
- num_critics=1,
- dropout=cfg.algo.critic.dropout,
- )
- for _ in range(cfg.algo.critic.n)
- ]
- target_entropy = -act_dim
- agent = DROQAgent(
- actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device
- )
- if cfg.checkpoint.resume_from:
- agent.load_state_dict(state["agent"])
- agent.actor = fabric.setup_module(agent.actor)
- agent.critics = [fabric.setup_module(critic) for critic in agent.critics]
# Optimizers
qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters())
diff --git a/sheeprl/algos/droq/evaluate.py b/sheeprl/algos/droq/evaluate.py
index 0d082dbf..ce179264 100644
--- a/sheeprl/algos/droq/evaluate.py
+++ b/sheeprl/algos/droq/evaluate.py
@@ -1,13 +1,11 @@
from __future__ import annotations
-from math import prod
from typing import Any, Dict
import gymnasium as gym
from lightning import Fabric
-from sheeprl.algos.droq.agent import DROQAgent, DROQCritic
-from sheeprl.algos.sac.agent import SACActor
+from sheeprl.algos.droq.agent import build_agent
from sheeprl.algos.sac.utils import test
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -47,29 +45,5 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
if cfg.metric.log_level > 0:
fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder)
- act_dim = prod(action_space.shape)
- obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder])
- actor = SACActor(
- observation_dim=obs_dim,
- action_dim=act_dim,
- distribution_cfg=cfg.distribution,
- hidden_size=cfg.algo.actor.hidden_size,
- action_low=action_space.low,
- action_high=action_space.high,
- )
- critics = [
- DROQCritic(
- observation_dim=obs_dim + act_dim,
- hidden_size=cfg.algo.critic.hidden_size,
- num_critics=1,
- dropout=cfg.algo.critic.dropout,
- )
- for _ in range(cfg.algo.critic.n)
- ]
- target_entropy = -act_dim
- agent = DROQAgent(
- actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device
- )
- agent.load_state_dict(state["agent"])
- agent = fabric.setup_module(agent)
+ agent = build_agent(fabric, cfg, observation_space, action_space, state["agent"])
test(agent.actor, fabric, cfg, log_dir)
diff --git a/sheeprl/algos/p2e_dv1/agent.py b/sheeprl/algos/p2e_dv1/agent.py
index 9c844dd8..9fb17606 100644
--- a/sheeprl/algos/p2e_dv1/agent.py
+++ b/sheeprl/algos/p2e_dv1/agent.py
@@ -7,7 +7,7 @@
from lightning.fabric.wrappers import _FabricModule
from sheeprl.algos.dreamer_v1.agent import WorldModel
-from sheeprl.algos.dreamer_v1.agent import build_models as dv1_build_models
+from sheeprl.algos.dreamer_v1.agent import build_agent as dv1_build_agent
from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor
from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor
from sheeprl.models.models import MLP
@@ -20,7 +20,7 @@
MinedojoActor = DV2MinedojoActor
-def build_models(
+def build_agent(
fabric: Fabric,
actions_dim: Sequence[int],
is_continuous: bool,
@@ -67,7 +67,7 @@ def build_models(
latent_state_size = world_model_cfg.stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size
# Create exploration models
- world_model, actor_exploration, critic_exploration = dv1_build_models(
+ world_model, actor_exploration, critic_exploration = dv1_build_agent(
fabric,
actions_dim=actions_dim,
is_continuous=is_continuous,
diff --git a/sheeprl/algos/p2e_dv1/evaluate.py b/sheeprl/algos/p2e_dv1/evaluate.py
index ddc7c294..349f4884 100644
--- a/sheeprl/algos/p2e_dv1/evaluate.py
+++ b/sheeprl/algos/p2e_dv1/evaluate.py
@@ -7,7 +7,7 @@
from sheeprl.algos.dreamer_v1.agent import PlayerDV1
from sheeprl.algos.dreamer_v2.utils import test
-from sheeprl.algos.p2e_dv1.agent import build_models
+from sheeprl.algos.p2e_dv1.agent import build_agent
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
from sheeprl.utils.registry import register_evaluation
@@ -44,11 +44,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
- world_model, actor_task, _, _, _ = build_models(
+ world_model, actor_task, _, _, _ = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
index 41f44078..c81f569e 100644
--- a/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
+++ b/sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -26,7 +26,7 @@
from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss
from sheeprl.algos.dreamer_v1.utils import compute_lambda_values
from sheeprl.algos.dreamer_v2.utils import test
-from sheeprl.algos.p2e_dv1.agent import build_models
+from sheeprl.algos.p2e_dv1.agent import build_agent
from sheeprl.data.buffers import AsyncReplayBuffer
from sheeprl.models.models import MLP
from sheeprl.utils.env import make_env
@@ -455,7 +455,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r
@@ -488,7 +488,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder
- world_model, actor_task, critic_task, actor_exploration, critic_exploration = build_models(
+ world_model, actor_task, critic_task, actor_exploration, critic_exploration = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
index ecb370eb..d8d743e0 100644
--- a/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
+++ b/sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -20,7 +20,7 @@
from sheeprl.algos.dreamer_v1.agent import PlayerDV1
from sheeprl.algos.dreamer_v1.dreamer_v1 import train
from sheeprl.algos.dreamer_v2.utils import test
-from sheeprl.algos.p2e_dv1.agent import build_models
+from sheeprl.algos.p2e_dv1.agent import build_agent
from sheeprl.data.buffers import AsyncReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -106,7 +106,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r
@@ -139,7 +139,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder
- world_model, actor_task, critic_task, actor_exploration, _ = build_models(
+ world_model, actor_task, critic_task, actor_exploration, _ = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/p2e_dv2/agent.py b/sheeprl/algos/p2e_dv2/agent.py
index b40973d6..3ffef72f 100644
--- a/sheeprl/algos/p2e_dv2/agent.py
+++ b/sheeprl/algos/p2e_dv2/agent.py
@@ -11,7 +11,7 @@
from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor
from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor
from sheeprl.algos.dreamer_v2.agent import WorldModel
-from sheeprl.algos.dreamer_v2.agent import build_models as dv2_build_models
+from sheeprl.algos.dreamer_v2.agent import build_agent as dv2_build_agent
from sheeprl.models.models import MLP
from sheeprl.utils.utils import init_weights
@@ -22,7 +22,7 @@
MinedojoActor = DV2MinedojoActor
-def build_models(
+def build_agent(
fabric: Fabric,
actions_dim: Sequence[int],
is_continuous: bool,
@@ -78,7 +78,7 @@ def build_models(
latent_state_size = stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size
# Create exploration models
- world_model, actor_exploration, critic_exploration, target_critic_exploration = dv2_build_models(
+ world_model, actor_exploration, critic_exploration, target_critic_exploration = dv2_build_agent(
fabric,
actions_dim=actions_dim,
is_continuous=is_continuous,
diff --git a/sheeprl/algos/p2e_dv2/evaluate.py b/sheeprl/algos/p2e_dv2/evaluate.py
index f252f4e6..cb420460 100644
--- a/sheeprl/algos/p2e_dv2/evaluate.py
+++ b/sheeprl/algos/p2e_dv2/evaluate.py
@@ -7,7 +7,7 @@
from sheeprl.algos.dreamer_v2.agent import PlayerDV2
from sheeprl.algos.dreamer_v2.utils import test
-from sheeprl.algos.p2e_dv2.agent import build_models
+from sheeprl.algos.p2e_dv2.agent import build_agent
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
from sheeprl.utils.registry import register_evaluation
@@ -44,11 +44,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
- world_model, actor_task, _, _, _, _, _ = build_models(
+ world_model, actor_task, _, _, _, _, _ = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
index 7e4b1cee..da3a1be6 100644
--- a/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
+++ b/sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -25,7 +25,7 @@
from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel
from sheeprl.algos.dreamer_v2.loss import reconstruction_loss
from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, init_weights, test
-from sheeprl.algos.p2e_dv2.agent import build_models
+from sheeprl.algos.p2e_dv2.agent import build_agent
from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer
from sheeprl.models.models import MLP
from sheeprl.utils.distribution import OneHotCategoricalValidateArgs
@@ -570,7 +570,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r
@@ -611,7 +611,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
actor_exploration,
critic_exploration,
target_critic_exploration,
- ) = build_models(
+ ) = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
index 878d7a77..adebc83b 100644
--- a/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
+++ b/sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
@@ -21,7 +21,7 @@
from sheeprl.algos.dreamer_v2.agent import PlayerDV2
from sheeprl.algos.dreamer_v2.dreamer_v2 import train
from sheeprl.algos.dreamer_v2.utils import test
-from sheeprl.algos.p2e_dv2.agent import build_models
+from sheeprl.algos.p2e_dv2.agent import build_agent
from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -110,7 +110,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r
@@ -143,7 +143,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder
- world_model, actor_task, critic_task, target_critic_task, actor_exploration, _, _ = build_models(
+ world_model, actor_task, critic_task, target_critic_task, actor_exploration, _, _ = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/p2e_dv3/agent.py b/sheeprl/algos/p2e_dv3/agent.py
index c4019b1c..586e3e1e 100644
--- a/sheeprl/algos/p2e_dv3/agent.py
+++ b/sheeprl/algos/p2e_dv3/agent.py
@@ -10,7 +10,7 @@
from sheeprl.algos.dreamer_v3.agent import Actor as DV3Actor
from sheeprl.algos.dreamer_v3.agent import MinedojoActor as DV3MinedojoActor
from sheeprl.algos.dreamer_v3.agent import WorldModel
-from sheeprl.algos.dreamer_v3.agent import build_models as dv3_build_models
+from sheeprl.algos.dreamer_v3.agent import build_agent as dv3_build_agent
from sheeprl.algos.dreamer_v3.utils import init_weights, uniform_init_weights
from sheeprl.models.models import MLP
@@ -21,7 +21,7 @@
MinedojoActor = DV3MinedojoActor
-def build_models(
+def build_agent(
fabric: Fabric,
actions_dim: Sequence[int],
is_continuous: bool,
@@ -73,7 +73,7 @@ def build_models(
latent_state_size = stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size
# Create task models
- world_model, actor_task, critic_task, target_critic_task = dv3_build_models(
+ world_model, actor_task, critic_task, target_critic_task = dv3_build_agent(
fabric,
actions_dim=actions_dim,
is_continuous=is_continuous,
diff --git a/sheeprl/algos/p2e_dv3/evaluate.py b/sheeprl/algos/p2e_dv3/evaluate.py
index 97ceb94f..20ccd61d 100644
--- a/sheeprl/algos/p2e_dv3/evaluate.py
+++ b/sheeprl/algos/p2e_dv3/evaluate.py
@@ -7,7 +7,7 @@
from sheeprl.algos.dreamer_v3.agent import PlayerDV3
from sheeprl.algos.dreamer_v3.utils import test
-from sheeprl.algos.p2e_dv3.agent import build_models
+from sheeprl.algos.p2e_dv3.agent import build_agent
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
from sheeprl.utils.registry import register_evaluation
@@ -44,11 +44,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
- world_model, actor, _, _, _, _ = build_models(
+ world_model, actor, _, _, _, _ = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
index 86c9063f..776c5420 100644
--- a/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
+++ b/sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -24,7 +24,7 @@
from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel
from sheeprl.algos.dreamer_v3.loss import reconstruction_loss
from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, init_weights, test
-from sheeprl.algos.p2e_dv3.agent import build_models
+from sheeprl.algos.p2e_dv3.agent import build_agent
from sheeprl.data.buffers import AsyncReplayBuffer
from sheeprl.models.models import MLP
from sheeprl.utils.distribution import (
@@ -599,7 +599,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r
@@ -639,7 +639,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
target_critic_task,
actor_exploration,
critics_exploration,
- ) = build_models(
+ ) = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
index 321f5ea2..ea1eef9f 100644
--- a/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
+++ b/sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -19,7 +19,7 @@
from sheeprl.algos.dreamer_v3.agent import PlayerDV3
from sheeprl.algos.dreamer_v3.dreamer_v3 import train
from sheeprl.algos.dreamer_v3.utils import Moments, test
-from sheeprl.algos.p2e_dv3.agent import build_models
+from sheeprl.algos.p2e_dv3.agent import build_agent
from sheeprl.data.buffers import AsyncReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -104,7 +104,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
is_continuous = isinstance(action_space, gym.spaces.Box)
is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
clip_rewards_fn = lambda r: torch.tanh(r) if cfg.env.clip_rewards else r
@@ -144,7 +144,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
target_critic_task,
actor_exploration,
_,
- ) = build_models(
+ ) = build_agent(
fabric,
actions_dim,
is_continuous,
diff --git a/sheeprl/algos/ppo/agent.py b/sheeprl/algos/ppo/agent.py
index 4efcea6b..6bd6b632 100644
--- a/sheeprl/algos/ppo/agent.py
+++ b/sheeprl/algos/ppo/agent.py
@@ -4,6 +4,8 @@
import gymnasium
import torch
import torch.nn as nn
+from lightning import Fabric
+from lightning.fabric.wrappers import _FabricModule
from torch import Tensor
from torch.distributions import Distribution, Independent, Normal
@@ -62,7 +64,7 @@ def forward(self, obs: Dict[str, Tensor]) -> Tensor:
class PPOAgent(nn.Module):
def __init__(
self,
- actions_dim: List[int],
+ actions_dim: Sequence[int],
obs_space: gymnasium.spaces.Dict,
encoder_cfg: Dict[str, Any],
actor_cfg: Dict[str, Any],
@@ -194,3 +196,30 @@ def get_greedy_actions(self, obs: Dict[str, Tensor]) -> Sequence[Tensor]:
for logits in pre_dist
]
)
+
+
+def build_agent(
+ fabric: Fabric,
+ actions_dim: Sequence[int],
+ is_continuous: bool,
+ cfg: Dict[str, Any],
+ obs_space: gymnasium.spaces.Dict,
+ agent_state: Optional[Dict[str, Tensor]] = None,
+) -> _FabricModule:
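+ # Instantiate the PPO agent from the config, optionally load a checkpointed
+ # state, and return the module wrapped with Fabric.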
+ agent = PPOAgent(
+ actions_dim=actions_dim,
+ obs_space=obs_space,
+ encoder_cfg=cfg.algo.encoder,
+ actor_cfg=cfg.algo.actor,
+ critic_cfg=cfg.algo.critic,
+ cnn_keys=cfg.cnn_keys.encoder,
+ mlp_keys=cfg.mlp_keys.encoder,
+ screen_size=cfg.env.screen_size,
+ distribution_cfg=cfg.distribution,
+ is_continuous=is_continuous,
+ )
+ if agent_state:
+ agent.load_state_dict(agent_state)
+ agent = fabric.setup_module(agent)
+
+ return agent
diff --git a/sheeprl/algos/ppo/evaluate.py b/sheeprl/algos/ppo/evaluate.py
index 82764685..56ca9e1f 100644
--- a/sheeprl/algos/ppo/evaluate.py
+++ b/sheeprl/algos/ppo/evaluate.py
@@ -5,7 +5,7 @@
import gymnasium as gym
from lightning import Fabric
-from sheeprl.algos.ppo.agent import PPOAgent
+from sheeprl.algos.ppo.agent import build_agent
from sheeprl.algos.ppo.utils import test
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -42,26 +42,13 @@ def evaluate_ppo(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
is_continuous = isinstance(env.action_space, gym.spaces.Box)
is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
env.action_space.shape
if is_continuous
else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n])
)
# Create the actor and critic models
- agent = PPOAgent(
- actions_dim=actions_dim,
- obs_space=observation_space,
- encoder_cfg=cfg.algo.encoder,
- actor_cfg=cfg.algo.actor,
- critic_cfg=cfg.algo.critic,
- cnn_keys=cfg.cnn_keys.encoder,
- mlp_keys=cfg.mlp_keys.encoder,
- screen_size=cfg.env.screen_size,
- distribution_cfg=cfg.distribution,
- is_continuous=is_continuous,
- )
- agent.load_state_dict(state["agent"])
- agent = fabric.setup_module(agent)
+ agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"])
test(agent, fabric, cfg, log_dir)
diff --git a/sheeprl/algos/ppo/ppo.py b/sheeprl/algos/ppo/ppo.py
index 8be7f0ca..d8bd32db 100644
--- a/sheeprl/algos/ppo/ppo.py
+++ b/sheeprl/algos/ppo/ppo.py
@@ -19,7 +19,7 @@
from torch.utils.data import BatchSampler, DistributedSampler, RandomSampler
from torchmetrics import SumMetric
-from sheeprl.algos.ppo.agent import PPOAgent
+from sheeprl.algos.ppo.agent import build_agent
from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss
from sheeprl.algos.ppo.utils import normalize_obs, test
from sheeprl.data import ReplayBuffer
@@ -168,23 +168,19 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
is_continuous = isinstance(envs.single_action_space, gym.spaces.Box)
is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
envs.single_action_space.shape
if is_continuous
else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n])
)
# Create the actor and critic models
- agent = PPOAgent(
- actions_dim=actions_dim,
- obs_space=observation_space,
- encoder_cfg=cfg.algo.encoder,
- actor_cfg=cfg.algo.actor,
- critic_cfg=cfg.algo.critic,
- cnn_keys=cfg.cnn_keys.encoder,
- mlp_keys=cfg.mlp_keys.encoder,
- screen_size=cfg.env.screen_size,
- distribution_cfg=cfg.distribution,
- is_continuous=is_continuous,
+ agent = build_agent(
+ fabric,
+ actions_dim,
+ is_continuous,
+ cfg,
+ observation_space,
+ state["agent"] if cfg.checkpoint.resume_from else None,
)
# Define the optimizer
@@ -194,11 +190,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
# Load the state from the checkpoint
if cfg.checkpoint.resume_from:
- agent.load_state_dict(state["agent"])
optimizer.load_state_dict(state["optimizer"])
# Setup agent and optimizer with Fabric
- agent = fabric.setup_module(agent)
optimizer = fabric.setup_optimizers(optimizer)
# Create a metric aggregator to log the metrics
diff --git a/sheeprl/algos/ppo/ppo_decoupled.py b/sheeprl/algos/ppo/ppo_decoupled.py
index 5ab593e6..494189b0 100644
--- a/sheeprl/algos/ppo/ppo_decoupled.py
+++ b/sheeprl/algos/ppo/ppo_decoupled.py
@@ -78,7 +78,7 @@ def player(
is_continuous = isinstance(envs.single_action_space, gym.spaces.Box)
is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
envs.single_action_space.shape
if is_continuous
else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n])
diff --git a/sheeprl/algos/ppo_recurrent/agent.py b/sheeprl/algos/ppo_recurrent/agent.py
index 90ae6ab3..7fd1878f 100644
--- a/sheeprl/algos/ppo_recurrent/agent.py
+++ b/sheeprl/algos/ppo_recurrent/agent.py
@@ -4,6 +4,7 @@
import gymnasium
import torch
import torch.nn as nn
+from lightning import Fabric
from torch import Tensor
from torch.distributions import Independent, Normal
@@ -286,3 +287,33 @@ def forward(
pre_dist = self.get_pre_dist(out)
actions, logprobs, entropies = self.get_sampled_actions(pre_dist, actions)
return actions, logprobs, entropies, values, states
+
+
+def build_agent(
+ fabric: Fabric,
+ actions_dim: Sequence[int],
+ is_continuous: bool,
+ cfg: Dict[str, Any],
+ obs_space: gymnasium.spaces.Dict,
+ agent_state: Optional[Dict[str, Tensor]] = None,
+) -> RecurrentPPOAgent:
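+ # Same pattern as the plain PPO builder: create the recurrent agent,
+ # optionally load a checkpointed state, then set it up with Fabric.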
+ agent = RecurrentPPOAgent(
+ actions_dim=actions_dim,
+ obs_space=obs_space,
+ encoder_cfg=cfg.algo.encoder,
+ rnn_cfg=cfg.algo.rnn,
+ actor_cfg=cfg.algo.actor,
+ critic_cfg=cfg.algo.critic,
+ cnn_keys=cfg.cnn_keys.encoder,
+ mlp_keys=cfg.mlp_keys.encoder,
+ is_continuous=is_continuous,
+ distribution_cfg=cfg.distribution,
+ num_envs=cfg.env.num_envs,
+ screen_size=cfg.env.screen_size,
+ device=fabric.device,
+ )
+ if agent_state:
+ agent.load_state_dict(agent_state)
+ agent = fabric.setup_module(agent)
+
+ return agent
diff --git a/sheeprl/algos/ppo_recurrent/evaluate.py b/sheeprl/algos/ppo_recurrent/evaluate.py
index b0940742..e7aed909 100644
--- a/sheeprl/algos/ppo_recurrent/evaluate.py
+++ b/sheeprl/algos/ppo_recurrent/evaluate.py
@@ -5,7 +5,7 @@
import gymnasium as gym
from lightning import Fabric
-from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent
+from sheeprl.algos.ppo_recurrent.agent import build_agent
from sheeprl.algos.ppo_recurrent.utils import test
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -42,27 +42,11 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
is_continuous = isinstance(env.action_space, gym.spaces.Box)
is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
env.action_space.shape
if is_continuous
else (env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n])
)
# Create the actor and critic models
- agent = RecurrentPPOAgent(
- actions_dim=actions_dim,
- obs_space=observation_space,
- encoder_cfg=cfg.algo.encoder,
- rnn_cfg=cfg.algo.rnn,
- actor_cfg=cfg.algo.actor,
- critic_cfg=cfg.algo.critic,
- cnn_keys=cfg.cnn_keys.encoder,
- mlp_keys=cfg.mlp_keys.encoder,
- is_continuous=is_continuous,
- distribution_cfg=cfg.distribution,
- num_envs=cfg.env.num_envs,
- screen_size=cfg.env.screen_size,
- device=fabric.device,
- )
- agent.load_state_dict(state["agent"])
- agent = fabric.setup_module(agent)
+ agent = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space, state["agent"])
test(agent, fabric, cfg, log_dir)
diff --git a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
index b067a7b9..ec4bad48 100644
--- a/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
+++ b/sheeprl/algos/ppo_recurrent/ppo_recurrent.py
@@ -22,7 +22,7 @@
from sheeprl.algos.ppo.loss import entropy_loss, policy_loss, value_loss
from sheeprl.algos.ppo.utils import normalize_obs
-from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent
+from sheeprl.algos.ppo_recurrent.agent import RecurrentPPOAgent, build_agent
from sheeprl.algos.ppo_recurrent.utils import test
from sheeprl.data.buffers import ReplayBuffer
from sheeprl.utils.env import make_env
@@ -178,37 +178,27 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
is_continuous = isinstance(envs.single_action_space, gym.spaces.Box)
is_multidiscrete = isinstance(envs.single_action_space, gym.spaces.MultiDiscrete)
- actions_dim = (
+ actions_dim = tuple(
envs.single_action_space.shape
if is_continuous
else (envs.single_action_space.nvec.tolist() if is_multidiscrete else [envs.single_action_space.n])
)
# Define the agent and the optimizer
- agent = RecurrentPPOAgent(
- actions_dim=actions_dim,
- obs_space=observation_space,
- encoder_cfg=cfg.algo.encoder,
- rnn_cfg=cfg.algo.rnn,
- actor_cfg=cfg.algo.actor,
- critic_cfg=cfg.algo.critic,
- cnn_keys=cfg.cnn_keys.encoder,
- mlp_keys=cfg.mlp_keys.encoder,
- is_continuous=is_continuous,
- distribution_cfg=cfg.distribution,
- num_envs=cfg.env.num_envs,
- screen_size=cfg.env.screen_size,
- device=device,
+ agent = build_agent(
+ fabric,
+ actions_dim,
+ is_continuous,
+ cfg,
+ observation_space,
+ state["agent"] if cfg.checkpoint.resume_from else None,
)
optimizer = hydra.utils.instantiate(cfg.algo.optimizer, params=agent.parameters())
# Load the state from the checkpoint
if cfg.checkpoint.resume_from:
- agent.load_state_dict(state["agent"])
optimizer.load_state_dict(state["optimizer"])
-
# Setup agent and optimizer with Fabric
- agent = fabric.setup_module(agent)
optimizer = fabric.setup_optimizers(optimizer)
local_vars = locals()
diff --git a/sheeprl/algos/sac/agent.py b/sheeprl/algos/sac/agent.py
index 4e52a08f..215415a5 100644
--- a/sheeprl/algos/sac/agent.py
+++ b/sheeprl/algos/sac/agent.py
@@ -1,8 +1,11 @@
import copy
-from typing import Any, Dict, Sequence, SupportsFloat, Tuple, Union
+from math import prod
+from typing import Any, Dict, Optional, Sequence, SupportsFloat, Tuple, Union
+import gymnasium
import torch
import torch.nn as nn
+from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from numpy.typing import NDArray
from torch import Tensor
@@ -273,3 +276,34 @@ def get_next_target_q_values(self, next_obs: Tensor, rewards: Tensor, dones: Ten
def qfs_target_ema(self) -> None:
for param, target_param in zip(self.qfs_unwrapped.parameters(), self.qfs_target.parameters()):
target_param.data.copy_(self._tau * param.data + (1 - self._tau) * target_param.data)
+
+
+def build_agent(
+ fabric: Fabric,
+ cfg: Dict[str, Any],
+ obs_space: gymnasium.spaces.Dict,
+ action_space: gymnasium.spaces.Box,
+ agent_state: Optional[Dict[str, Tensor]] = None,
+) -> SACAgent:
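+ # Flatten the selected MLP observation keys into a single feature vector,
+ # build the SAC actor and critic ensemble, optionally restore a checkpointed
+ # state, and wrap the actor and critics with Fabric.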
+ act_dim = prod(action_space.shape)
+ obs_dim = sum([prod(obs_space[k].shape) for k in cfg.mlp_keys.encoder])
+ actor = SACActor(
+ observation_dim=obs_dim,
+ action_dim=act_dim,
+ distribution_cfg=cfg.distribution,
+ hidden_size=cfg.algo.actor.hidden_size,
+ action_low=action_space.low,
+ action_high=action_space.high,
+ )
+ critics = [
+ SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1)
+ for _ in range(cfg.algo.critic.n)
+ ]
+ target_entropy = -act_dim
+ agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device)
+ if agent_state:
+ agent.load_state_dict(agent_state)
+ agent.actor = fabric.setup_module(agent.actor)
+ agent.critics = [fabric.setup_module(critic) for critic in agent.critics]
+
+ return agent
diff --git a/sheeprl/algos/sac/evaluate.py b/sheeprl/algos/sac/evaluate.py
index 9d138562..89024046 100644
--- a/sheeprl/algos/sac/evaluate.py
+++ b/sheeprl/algos/sac/evaluate.py
@@ -1,12 +1,11 @@
from __future__ import annotations
-from math import prod
from typing import Any, Dict
import gymnasium as gym
from lightning import Fabric
-from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic
+from sheeprl.algos.sac.agent import build_agent
from sheeprl.algos.sac.utils import test
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -45,22 +44,5 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
)
fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder)
- act_dim = prod(action_space.shape)
- obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder])
- actor = SACActor(
- observation_dim=obs_dim,
- action_dim=act_dim,
- distribution_cfg=cfg.distribution,
- hidden_size=cfg.algo.actor.hidden_size,
- action_low=action_space.low,
- action_high=action_space.high,
- )
- critics = [
- SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1)
- for _ in range(cfg.algo.critic.n)
- ]
- target_entropy = -act_dim
- agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device)
- agent.load_state_dict(state["agent"])
- agent = fabric.setup_module(agent)
+ agent = build_agent(fabric, cfg, observation_space, action_space, state["agent"])
test(agent.actor, fabric, cfg, log_dir)
diff --git a/sheeprl/algos/sac/sac.py b/sheeprl/algos/sac/sac.py
index 58e80cc4..d0a9ec7e 100644
--- a/sheeprl/algos/sac/sac.py
+++ b/sheeprl/algos/sac/sac.py
@@ -3,7 +3,6 @@
import copy
import os
import warnings
-from math import prod
from typing import Any, Dict, Optional
import gymnasium as gym
@@ -21,7 +20,7 @@
from torch.utils.data.sampler import BatchSampler
from torchmetrics import SumMetric
-from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic
+from sheeprl.algos.sac.agent import SACAgent, build_agent
from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss
from sheeprl.algos.sac.utils import test
from sheeprl.data.buffers import ReplayBuffer
@@ -148,26 +147,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder)
# Define the agent and the optimizer and setup them with Fabric
- act_dim = prod(action_space.shape)
- obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder])
- actor = SACActor(
- observation_dim=obs_dim,
- action_dim=act_dim,
- distribution_cfg=cfg.distribution,
- hidden_size=cfg.algo.actor.hidden_size,
- action_low=action_space.low,
- action_high=action_space.high,
+ agent = build_agent(
+ fabric, cfg, observation_space, action_space, state["agent"] if cfg.checkpoint.resume_from else None
)
- critics = [
- SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1)
- for _ in range(cfg.algo.critic.n)
- ]
- target_entropy = -act_dim
- agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device)
- if cfg.checkpoint.resume_from:
- agent.load_state_dict(state["agent"])
- agent.actor = fabric.setup_module(agent.actor)
- agent.critics = [fabric.setup_module(critic) for critic in agent.critics]
# Optimizers
qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters())
diff --git a/sheeprl/algos/sac/sac_decoupled.py b/sheeprl/algos/sac/sac_decoupled.py
index 70be5bad..8881ac2b 100644
--- a/sheeprl/algos/sac/sac_decoupled.py
+++ b/sheeprl/algos/sac/sac_decoupled.py
@@ -20,7 +20,7 @@
from torch.utils.data.sampler import BatchSampler
from torchmetrics import SumMetric
-from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic
+from sheeprl.algos.sac.agent import SACActor, SACAgent, SACCritic, build_agent
from sheeprl.algos.sac.sac import train
from sheeprl.algos.sac.utils import test
from sheeprl.data.buffers import ReplayBuffer
@@ -369,27 +369,13 @@ def trainer(
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
# Define the agent and the optimizer and setup them with Fabric
- act_dim = prod(envs.single_action_space.shape)
- obs_dim = sum([prod(envs.single_observation_space[k].shape) for k in cfg.mlp_keys.encoder])
-
- actor = SACActor(
- observation_dim=obs_dim,
- action_dim=act_dim,
- distribution_cfg=cfg.distribution,
- hidden_size=cfg.algo.actor.hidden_size,
- action_low=envs.single_action_space.low,
- action_high=envs.single_action_space.high,
+ agent = build_agent(
+ fabric,
+ cfg,
+ envs.single_observation_space,
+ envs.single_action_space,
+ state["agent"] if cfg.checkpoint.resume_from else None,
)
- critics = [
- SACCritic(observation_dim=obs_dim + act_dim, hidden_size=cfg.algo.critic.hidden_size, num_critics=1)
- for _ in range(cfg.algo.critic.n)
- ]
- target_entropy = -act_dim
- agent = SACAgent(actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device)
- if cfg.checkpoint.resume_from:
- agent.load_state_dict(state["agent"])
- agent.actor = fabric.setup_module(agent.actor)
- agent.critics = [fabric.setup_module(critic) for critic in agent.critics]
# Optimizers
qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters())
diff --git a/sheeprl/algos/sac_ae/agent.py b/sheeprl/algos/sac_ae/agent.py
index 0265cbe4..4fbbbc57 100644
--- a/sheeprl/algos/sac_ae/agent.py
+++ b/sheeprl/algos/sac_ae/agent.py
@@ -1,16 +1,18 @@
import copy
from math import prod
-from typing import Any, Dict, List, Sequence, SupportsFloat, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, SupportsFloat, Tuple, Union
+import gymnasium
import numpy as np
import torch
import torch.nn as nn
+from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from numpy.typing import NDArray
from torch import Size, Tensor
from sheeprl.algos.sac_ae.utils import weight_init
-from sheeprl.models.models import CNN, MLP, DeCNN, MultiEncoder
+from sheeprl.models.models import CNN, MLP, DeCNN, MultiDecoder, MultiEncoder
LOG_STD_MAX = 2
LOG_STD_MIN = -10
@@ -448,3 +450,114 @@ def critic_encoder_target_ema(self) -> None:
self.critic_unwrapped.encoder.parameters(), self.critic_target.encoder.parameters()
):
target_param.data.copy_(self._encoder_tau * param.data + (1 - self._encoder_tau) * target_param.data)
+
+
+def build_agent(
+ fabric: Fabric,
+ cfg: Dict[str, Any],
+ obs_space: gymnasium.spaces.Dict,
+ action_space: gymnasium.spaces.Box,
+ agent_state: Optional[Dict[str, Tensor]] = None,
+ encoder_state: Optional[Dict[str, Tensor]] = None,
+ decoder_state: Optional[Dict[str, Tensor]] = None,
+) -> Tuple[SACAEAgent, _FabricModule, _FabricModule]:
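+ # Build the SAC-AE agent together with its encoder and decoder, optionally
+ # restoring their checkpointed states, and return the agent plus the
+ # Fabric-wrapped encoder and decoder.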
+ act_dim = prod(action_space.shape)
+ target_entropy = -act_dim
+
+ # Define the encoder and decoder and set them up with Fabric.
+ # The critic and actor encoders are then set to the unwrapped encoder module:
+ # it does not need to be wrapped with the strategy inside the actor and critic
+ cnn_channels = [prod(obs_space[k].shape[:-2]) for k in cfg.cnn_keys.encoder]
+ mlp_dims = [obs_space[k].shape[0] for k in cfg.mlp_keys.encoder]
+ cnn_encoder = (
+ CNNEncoder(
+ in_channels=sum(cnn_channels),
+ features_dim=cfg.algo.encoder.features_dim,
+ keys=cfg.cnn_keys.encoder,
+ screen_size=cfg.env.screen_size,
+ cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier,
+ )
+ if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0
+ else None
+ )
+ mlp_encoder = (
+ MLPEncoder(
+ sum(mlp_dims),
+ cfg.mlp_keys.encoder,
+ cfg.algo.encoder.dense_units,
+ cfg.algo.encoder.mlp_layers,
+ eval(cfg.algo.encoder.dense_act),
+ cfg.algo.encoder.layer_norm,
+ )
+ if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0
+ else None
+ )
+ encoder = MultiEncoder(cnn_encoder, mlp_encoder)
+ cnn_decoder = (
+ CNNDecoder(
+ cnn_encoder.conv_output_shape,
+ features_dim=encoder.output_dim,
+ keys=cfg.cnn_keys.decoder,
+ channels=cnn_channels,
+ screen_size=cfg.env.screen_size,
+ cnn_channels_multiplier=cfg.algo.decoder.cnn_channels_multiplier,
+ )
+ if cfg.cnn_keys.decoder is not None and len(cfg.cnn_keys.decoder) > 0
+ else None
+ )
+ mlp_decoder = (
+ MLPDecoder(
+ encoder.output_dim,
+ mlp_dims,
+ cfg.mlp_keys.decoder,
+ cfg.algo.decoder.dense_units,
+ cfg.algo.decoder.mlp_layers,
+ eval(cfg.algo.decoder.dense_act),
+ cfg.algo.decoder.layer_norm,
+ )
+ if cfg.mlp_keys.decoder is not None and len(cfg.mlp_keys.decoder) > 0
+ else None
+ )
+ decoder = MultiDecoder(cnn_decoder, mlp_decoder)
+ if cfg.checkpoint.resume_from:
+ encoder.load_state_dict(encoder_state)
+ decoder.load_state_dict(decoder_state)
+
+ # Set up the actor and critic: both are initialized with orthogonal weights
+ actor = SACAEContinuousActor(
+ encoder=copy.deepcopy(encoder),
+ action_dim=act_dim,
+ distribution_cfg=cfg.distribution,
+ hidden_size=cfg.algo.actor.hidden_size,
+ action_low=action_space.low,
+ action_high=action_space.high,
+ )
+ qfs = [
+ SACAEQFunction(
+ input_dim=encoder.output_dim, action_dim=act_dim, hidden_size=cfg.algo.critic.hidden_size, output_dim=1
+ )
+ for _ in range(cfg.algo.critic.n)
+ ]
+ critic = SACAECritic(encoder=encoder, qfs=qfs)
+
+ # The agent ties the convolutional and linear weights between the actor and critic encoders
+ agent = SACAEAgent(
+ actor,
+ critic,
+ target_entropy,
+ alpha=cfg.algo.alpha.alpha,
+ tau=cfg.algo.tau,
+ encoder_tau=cfg.algo.encoder.tau,
+ device=fabric.device,
+ )
+
+ if agent_state:
+ agent.load_state_dict(agent_state)
+
+ encoder = fabric.setup_module(encoder)
+ decoder = fabric.setup_module(decoder)
+ agent.actor = fabric.setup_module(agent.actor)
+ agent.critic = fabric.setup_module(agent.critic)
+
+ return agent, encoder, decoder
diff --git a/sheeprl/algos/sac_ae/evaluate.py b/sheeprl/algos/sac_ae/evaluate.py
index 508d4fdf..27d1ed32 100644
--- a/sheeprl/algos/sac_ae/evaluate.py
+++ b/sheeprl/algos/sac_ae/evaluate.py
@@ -1,22 +1,12 @@
from __future__ import annotations
-import copy
-from math import prod
from typing import Any, Dict
import gymnasium as gym
from lightning import Fabric
-from sheeprl.algos.sac_ae.agent import (
- CNNEncoder,
- MLPEncoder,
- SACAEAgent,
- SACAEContinuousActor,
- SACAECritic,
- SACAEQFunction,
-)
+from sheeprl.algos.sac_ae.agent import SACAEAgent, build_agent
from sheeprl.algos.sac_ae.utils import test_sac_ae
-from sheeprl.models.models import MultiEncoder
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
from sheeprl.utils.registry import register_evaluation
@@ -51,68 +41,8 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
fabric.print("Encoder CNN keys:", cfg.cnn_keys.encoder)
fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder)
- act_dim = prod(action_space.shape)
- target_entropy = -act_dim
-
- # Define the encoder and decoder and setup them with fabric.
- # Then we will set the critic encoder and actor decoder as the unwrapped encoder module:
- # we do not need it wrapped with the strategy inside actor and critic
- cnn_channels = [prod(observation_space[k].shape[:-2]) for k in cfg.cnn_keys.encoder]
- mlp_dims = [observation_space[k].shape[0] for k in cfg.mlp_keys.encoder]
- cnn_encoder = (
- CNNEncoder(
- in_channels=sum(cnn_channels),
- features_dim=cfg.algo.encoder.features_dim,
- keys=cfg.cnn_keys.encoder,
- screen_size=cfg.env.screen_size,
- cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier,
- )
- if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0
- else None
- )
- mlp_encoder = (
- MLPEncoder(
- sum(mlp_dims),
- cfg.mlp_keys.encoder,
- cfg.algo.encoder.dense_units,
- cfg.algo.encoder.mlp_layers,
- eval(cfg.algo.encoder.dense_act),
- cfg.algo.encoder.layer_norm,
- )
- if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0
- else None
- )
- encoder = MultiEncoder(cnn_encoder, mlp_encoder)
- encoder.load_state_dict(state["encoder"])
-
- # Setup actor and critic. Those will initialize with orthogonal weights
- # both the actor and critic
- actor = SACAEContinuousActor(
- encoder=copy.deepcopy(encoder),
- action_dim=act_dim,
- distribution_cfg=cfg.distribution,
- hidden_size=cfg.algo.actor.hidden_size,
- action_low=action_space.low,
- action_high=action_space.high,
- )
- qfs = [
- SACAEQFunction(
- input_dim=encoder.output_dim, action_dim=act_dim, hidden_size=cfg.algo.critic.hidden_size, output_dim=1
- )
- for _ in range(cfg.algo.critic.n)
- ]
- critic = SACAECritic(encoder=encoder, qfs=qfs)
-
- # The agent will tied convolutional and linear weights between the encoder actor and critic
- agent = SACAEAgent(
- actor,
- critic,
- target_entropy,
- alpha=cfg.algo.alpha.alpha,
- tau=cfg.algo.tau,
- encoder_tau=cfg.algo.encoder.tau,
- device=fabric.device,
+ agent: SACAEAgent
+ agent, _, _ = build_agent(
+ fabric, cfg, observation_space, action_space, state["agent"], state["encoder"], state["decoder"]
)
- agent.load_state_dict(state["agent"])
- agent.actor = fabric.setup_module(agent.actor)
test_sac_ae(agent.actor, fabric, cfg, log_dir)
diff --git a/sheeprl/algos/sac_ae/sac_ae.py b/sheeprl/algos/sac_ae/sac_ae.py
index 46761bfb..3696d95a 100644
--- a/sheeprl/algos/sac_ae/sac_ae.py
+++ b/sheeprl/algos/sac_ae/sac_ae.py
@@ -4,7 +4,6 @@
import os
import time
import warnings
-from math import prod
from typing import Any, Dict, Optional, Union
import gymnasium as gym
@@ -25,16 +24,7 @@
from torchmetrics import SumMetric
from sheeprl.algos.sac.loss import critic_loss, entropy_loss, policy_loss
-from sheeprl.algos.sac_ae.agent import (
- CNNDecoder,
- CNNEncoder,
- MLPDecoder,
- MLPEncoder,
- SACAEAgent,
- SACAEContinuousActor,
- SACAECritic,
- SACAEQFunction,
-)
+from sheeprl.algos.sac_ae.agent import SACAEAgent, build_agent
from sheeprl.algos.sac_ae.utils import preprocess_obs, test_sac_ae
from sheeprl.data.buffers import ReplayBuffer
from sheeprl.models.models import MultiDecoder, MultiEncoder
@@ -213,95 +203,14 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
# Define the agent and the optimizer and setup them with Fabric
- act_dim = prod(envs.single_action_space.shape)
- target_entropy = -act_dim
-
- # Define the encoder and decoder and setup them with fabric.
- # Then we will set the critic encoder and actor decoder as the unwrapped encoder module:
- # we do not need it wrapped with the strategy inside actor and critic
- cnn_channels = [prod(envs.single_observation_space[k].shape[:-2]) for k in cfg.cnn_keys.encoder]
- mlp_dims = [envs.single_observation_space[k].shape[0] for k in cfg.mlp_keys.encoder]
- cnn_encoder = (
- CNNEncoder(
- in_channels=sum(cnn_channels),
- features_dim=cfg.algo.encoder.features_dim,
- keys=cfg.cnn_keys.encoder,
- screen_size=cfg.env.screen_size,
- cnn_channels_multiplier=cfg.algo.encoder.cnn_channels_multiplier,
- )
- if cfg.cnn_keys.encoder is not None and len(cfg.cnn_keys.encoder) > 0
- else None
- )
- mlp_encoder = (
- MLPEncoder(
- sum(mlp_dims),
- cfg.mlp_keys.encoder,
- cfg.algo.encoder.dense_units,
- cfg.algo.encoder.mlp_layers,
- eval(cfg.algo.encoder.dense_act),
- cfg.algo.encoder.layer_norm,
- )
- if cfg.mlp_keys.encoder is not None and len(cfg.mlp_keys.encoder) > 0
- else None
- )
- encoder = MultiEncoder(cnn_encoder, mlp_encoder)
- cnn_decoder = (
- CNNDecoder(
- cnn_encoder.conv_output_shape,
- features_dim=encoder.output_dim,
- keys=cfg.cnn_keys.decoder,
- channels=cnn_channels,
- screen_size=cfg.env.screen_size,
- cnn_channels_multiplier=cfg.algo.decoder.cnn_channels_multiplier,
- )
- if cfg.cnn_keys.decoder is not None and len(cfg.cnn_keys.decoder) > 0
- else None
- )
- mlp_decoder = (
- MLPDecoder(
- encoder.output_dim,
- mlp_dims,
- cfg.mlp_keys.decoder,
- cfg.algo.decoder.dense_units,
- cfg.algo.decoder.mlp_layers,
- eval(cfg.algo.decoder.dense_act),
- cfg.algo.decoder.layer_norm,
- )
- if cfg.mlp_keys.decoder is not None and len(cfg.mlp_keys.decoder) > 0
- else None
- )
- decoder = MultiDecoder(cnn_decoder, mlp_decoder)
- if cfg.checkpoint.resume_from:
- encoder.load_state_dict(state["encoder"])
- decoder.load_state_dict(state["decoder"])
-
- # Setup actor and critic. Those will initialize with orthogonal weights
- # both the actor and critic
- actor = SACAEContinuousActor(
- encoder=copy.deepcopy(encoder),
- action_dim=act_dim,
- distribution_cfg=cfg.distribution,
- hidden_size=cfg.algo.actor.hidden_size,
- action_low=envs.single_action_space.low,
- action_high=envs.single_action_space.high,
- )
- qfs = [
- SACAEQFunction(
- input_dim=encoder.output_dim, action_dim=act_dim, hidden_size=cfg.algo.critic.hidden_size, output_dim=1
- )
- for _ in range(cfg.algo.critic.n)
- ]
- critic = SACAECritic(encoder=encoder, qfs=qfs)
-
- # The agent will tied convolutional and linear weights between the encoder actor and critic
- agent = SACAEAgent(
- actor,
- critic,
- target_entropy,
- alpha=cfg.algo.alpha.alpha,
- tau=cfg.algo.tau,
- encoder_tau=cfg.algo.encoder.tau,
- device=fabric.device,
+ agent, encoder, decoder = build_agent(
+ fabric,
+ cfg,
+ observation_space,
+ envs.single_action_space,
+ state["agent"] if cfg.checkpoint.resume_from else None,
+ state["encoder"] if cfg.checkpoint.resume_from else None,
+ state["decoder"] if cfg.checkpoint.resume_from else None,
)
# Optimizers
@@ -312,18 +221,12 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
decoder_optimizer = hydra.utils.instantiate(cfg.algo.decoder.optimizer, params=decoder.parameters())
if cfg.checkpoint.resume_from:
- agent.load_state_dict(state["agent"])
qf_optimizer.load_state_dict(state["qf_optimizer"])
actor_optimizer.load_state_dict(state["actor_optimizer"])
alpha_optimizer.load_state_dict(state["alpha_optimizer"])
encoder_optimizer.load_state_dict(state["encoder_optimizer"])
decoder_optimizer.load_state_dict(state["decoder_optimizer"])
- encoder = fabric.setup_module(encoder)
- decoder = fabric.setup_module(decoder)
- agent.actor = fabric.setup_module(agent.actor)
- agent.critic = fabric.setup_module(agent.critic)
-
qf_optimizer, actor_optimizer, alpha_optimizer, encoder_optimizer, decoder_optimizer = fabric.setup_optimizers(
qf_optimizer, actor_optimizer, alpha_optimizer, encoder_optimizer, decoder_optimizer
)
diff --git a/sheeprl/configs/model_manager/default.yaml b/sheeprl/configs/model_manager/default.yaml
index 7ab58825..e397b00a 100644
--- a/sheeprl/configs/model_manager/default.yaml
+++ b/sheeprl/configs/model_manager/default.yaml
@@ -1,2 +1,2 @@
-disabled: False
+disabled: True
models: {}