Feature/add build agents #153

Merged 6 commits on Nov 20, 2023 (changes shown from 5 commits)
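This PR renames the per-algorithm build_models helpers to build_agent (Dreamer V1/V2/V3 and both P2E variants) and gives DroQ its own build_agent, moving agent construction out of droq.py and into agent.py. Every algorithm therefore exposes the same entry point. A minimal call-site sketch following the Dreamer V1 hunks below (the keyword arguments after is_continuous are unchanged from build_models and elided here):

# Sketch only: fabric, actions_dim, is_continuous and the remaining
# keyword arguments come from the training script, as in
# sheeprl/algos/dreamer_v1/dreamer_v1.py below.
from sheeprl.algos.dreamer_v1.agent import build_agent  # formerly build_models

world_model, actor, critic = build_agent(
    fabric,
    actions_dim,
    is_continuous,
    # ...remaining arguments as before...
)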
4 changes: 1 addition & 3 deletions README.md
@@ -9,8 +9,6 @@
<tr>
<td><img src="https://github.com/Eclectic-Sheep/sheeprl/assets/18405289/6efd09f0-df91-4da0-971d-92e0213b8835" width="200px"></td>
<td><img src="https://github.com/Eclectic-Sheep/sheeprl/assets/18405289/dbba57db-6ef5-4db4-9c53-d7b5f303033a" width="200px"></td>
</tr>
<tr>
<td><img src="https://github.com/Eclectic-Sheep/sheeprl/assets/18405289/3f38e5eb-aadd-4402-a698-695d1f99c048" width="200px"></td>
<td><img src="https://github.com/Eclectic-Sheep/sheeprl/assets/18405289/93749119-fe61-44f1-94bb-fdb89c1869b5" width="200px"></td>
</tr>
@@ -58,7 +56,7 @@
<td>DOA++(w/o optimizations)<sup>1</sup></td>
<td>7M</td>
<td>18d 22h</td>
<td>2726/33283<sup>2</sup></td>
<td>2726/3328<sup>2</sup></td>
<td>N.A.</td>
<td>1-3080</td>
</tr>
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v1/agent.py
@@ -324,7 +324,7 @@ def get_greedy_action(
return actions


def build_models(
def build_agent(
fabric: Fabric,
actions_dim: Sequence[int],
is_continuous: bool,
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -20,7 +20,7 @@
from torch.utils.data import BatchSampler
from torchmetrics import SumMetric

from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_models
from sheeprl.algos.dreamer_v1.agent import PlayerDV1, WorldModel, build_agent
from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss
from sheeprl.algos.dreamer_v1.utils import compute_lambda_values
from sheeprl.algos.dreamer_v2.utils import test
@@ -477,7 +477,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder

world_model, actor, critic = build_models(
world_model, actor, critic = build_agent(
fabric,
actions_dim,
is_continuous,
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v1/evaluate.py
@@ -5,7 +5,7 @@
import gymnasium as gym
from lightning import Fabric

from sheeprl.algos.dreamer_v1.agent import PlayerDV1, build_models
from sheeprl.algos.dreamer_v1.agent import PlayerDV1, build_agent
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -47,7 +47,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
world_model, actor, _ = build_models(
world_model, actor, _ = build_agent(
fabric,
actions_dim,
is_continuous,
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v2/agent.py
@@ -862,7 +862,7 @@ def get_greedy_action(
return actions


def build_models(
def build_agent(
fabric: Fabric,
actions_dim: Sequence[int],
is_continuous: bool,
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -25,7 +25,7 @@
from torch.utils.data import BatchSampler
from torchmetrics import SumMetric

from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_models
from sheeprl.algos.dreamer_v2.agent import PlayerDV2, WorldModel, build_agent
from sheeprl.algos.dreamer_v2.loss import reconstruction_loss
from sheeprl.algos.dreamer_v2.utils import compute_lambda_values, test
from sheeprl.data.buffers import AsyncReplayBuffer, EpisodeBuffer
@@ -501,7 +501,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder

world_model, actor, critic, target_critic = build_models(
world_model, actor, critic, target_critic = build_agent(
fabric,
actions_dim,
is_continuous,
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v2/evaluate.py
@@ -5,7 +5,7 @@
import gymnasium as gym
from lightning import Fabric

from sheeprl.algos.dreamer_v2.agent import PlayerDV2, build_models
from sheeprl.algos.dreamer_v2.agent import PlayerDV2, build_agent
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -47,7 +47,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
world_model, actor, _, _ = build_models(
world_model, actor, _, _ = build_agent(
fabric,
actions_dim,
is_continuous,
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v3/agent.py
@@ -897,7 +897,7 @@ def add_exploration_noise(
return tuple(expl_actions)


def build_models(
def build_agent(
fabric: Fabric,
actions_dim: Sequence[int],
is_continuous: bool,
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -25,7 +25,7 @@
from torch.utils.data import BatchSampler
from torchmetrics import SumMetric

from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_models
from sheeprl.algos.dreamer_v3.agent import PlayerDV3, WorldModel, build_agent
from sheeprl.algos.dreamer_v3.loss import reconstruction_loss
from sheeprl.algos.dreamer_v3.utils import Moments, compute_lambda_values, test
from sheeprl.data.buffers import AsyncReplayBuffer
@@ -435,7 +435,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder

world_model, actor, critic, target_critic = build_models(
world_model, actor, critic, target_critic = build_agent(
fabric,
actions_dim,
is_continuous,
4 changes: 2 additions & 2 deletions sheeprl/algos/dreamer_v3/evaluate.py
@@ -5,7 +5,7 @@
import gymnasium as gym
from lightning import Fabric

from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_models
from sheeprl.algos.dreamer_v3.agent import PlayerDV3, build_agent
from sheeprl.algos.dreamer_v3.utils import test
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -47,7 +47,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
world_model, actor, _, _ = build_models(
world_model, actor, _, _ = build_agent(
fabric,
actions_dim,
is_continuous,
43 changes: 42 additions & 1 deletion sheeprl/algos/droq/agent.py
@@ -1,8 +1,11 @@
import copy
from typing import Sequence, Tuple, Union
from math import prod
from typing import Any, Dict, Optional, Sequence, Tuple, Union

import gymnasium
import torch
import torch.nn as nn
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
from torch import Tensor

@@ -198,3 +201,41 @@ def qfs_target_ema(self, critic_idx: int) -> None:
self.qfs_unwrapped[critic_idx].parameters(), self.qfs_target[critic_idx].parameters()
):
target_param.data.copy_(self._tau * param.data + (1 - self._tau) * target_param.data)


def build_agent(
fabric: Fabric,
cfg: Dict[str, Any],
obs_space: gymnasium.spaces.Dict,
action_space: gymnasium.spaces.Box,
agent_state: Optional[Dict[str, Tensor]] = None,
) -> DROQAgent:
act_dim = prod(action_space.shape)
obs_dim = sum([prod(obs_space[k].shape) for k in cfg.mlp_keys.encoder])
actor = SACActor(
observation_dim=obs_dim,
action_dim=act_dim,
distribution_cfg=cfg.distribution,
hidden_size=cfg.algo.actor.hidden_size,
action_low=action_space.low,
action_high=action_space.high,
)
critics = [
DROQCritic(
observation_dim=obs_dim + act_dim,
hidden_size=cfg.algo.critic.hidden_size,
num_critics=1,
dropout=cfg.algo.critic.dropout,
)
for _ in range(cfg.algo.critic.n)
]
target_entropy = -act_dim
agent = DROQAgent(
actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device
)
if agent_state:
agent.load_state_dict(agent_state)
agent.actor = fabric.setup_module(agent.actor)
agent.critics = [fabric.setup_module(critic) for critic in agent.critics]

return agent
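The new builder mirrors the construction block removed from droq.py below: it derives act_dim and obs_dim from the action/observation spaces and cfg.mlp_keys.encoder, builds one SACActor plus cfg.algo.critic.n dropout critics, optionally restores agent_state, and wraps the modules with fabric.setup_module. A standalone sketch of calling it with toy spaces and a hand-written config (the numeric values and the contents of the distribution block are illustrative assumptions; real runs pass the Hydra cfg exactly as droq.py does):

# Sketch only: toy spaces and a minimal config exposing the fields that
# build_agent reads (cfg.mlp_keys.encoder, cfg.distribution, cfg.algo.*).
import gymnasium as gym
import numpy as np
from lightning import Fabric
from omegaconf import OmegaConf

from sheeprl.algos.droq.agent import build_agent

cfg = OmegaConf.create(
    {
        "mlp_keys": {"encoder": ["state"]},
        "distribution": {"validate_args": False},  # assumed minimal distribution config
        "algo": {
            "actor": {"hidden_size": 256},
            "critic": {"hidden_size": 256, "dropout": 0.01, "n": 2},
            "alpha": {"alpha": 1.0},
            "tau": 0.005,
        },
    }
)
obs_space = gym.spaces.Dict({"state": gym.spaces.Box(-np.inf, np.inf, shape=(8,))})
action_space = gym.spaces.Box(-1.0, 1.0, shape=(2,))

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()
agent = build_agent(fabric, cfg, obs_space, action_space)  # pass agent_state to resume from a checkpoint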
32 changes: 3 additions & 29 deletions sheeprl/algos/droq/droq.py
@@ -3,7 +3,6 @@
import copy
import os
import warnings
from math import prod
from typing import Any, Dict

import gymnasium as gym
@@ -20,8 +19,7 @@
from torch.utils.data.sampler import BatchSampler
from torchmetrics import SumMetric

from sheeprl.algos.droq.agent import DROQAgent, DROQCritic
from sheeprl.algos.sac.agent import SACActor
from sheeprl.algos.droq.agent import DROQAgent, build_agent
from sheeprl.algos.sac.loss import entropy_loss, policy_loss
from sheeprl.algos.sac.sac import test
from sheeprl.data.buffers import ReplayBuffer
@@ -196,33 +194,9 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Encoder MLP keys:", cfg.mlp_keys.encoder)

# Define the agent and the optimizer and setup them with Fabric
act_dim = prod(action_space.shape)
obs_dim = sum([prod(observation_space[k].shape) for k in cfg.mlp_keys.encoder])
actor = SACActor(
observation_dim=obs_dim,
action_dim=act_dim,
distribution_cfg=cfg.distribution,
hidden_size=cfg.algo.actor.hidden_size,
action_low=action_space.low,
action_high=action_space.high,
agent = build_agent(
fabric, cfg, observation_space, action_space, state["agent"] if cfg.checkpoint.resume_from else None
)
critics = [
DROQCritic(
observation_dim=obs_dim + act_dim,
hidden_size=cfg.algo.critic.hidden_size,
num_critics=1,
dropout=cfg.algo.critic.dropout,
)
for _ in range(cfg.algo.critic.n)
]
target_entropy = -act_dim
agent = DROQAgent(
actor, critics, target_entropy, alpha=cfg.algo.alpha.alpha, tau=cfg.algo.tau, device=fabric.device
)
if cfg.checkpoint.resume_from:
agent.load_state_dict(state["agent"])
agent.actor = fabric.setup_module(agent.actor)
agent.critics = [fabric.setup_module(critic) for critic in agent.critics]

# Optimizers
qf_optimizer = hydra.utils.instantiate(cfg.algo.critic.optimizer, params=agent.qfs.parameters())
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv1/agent.py
@@ -7,7 +7,7 @@
from lightning.fabric.wrappers import _FabricModule

from sheeprl.algos.dreamer_v1.agent import WorldModel
from sheeprl.algos.dreamer_v1.agent import build_models as dv1_build_models
from sheeprl.algos.dreamer_v1.agent import build_agent as dv1_build_agent
from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor
from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor
from sheeprl.models.models import MLP
@@ -20,7 +20,7 @@
MinedojoActor = DV2MinedojoActor


def build_models(
def build_agent(
fabric: Fabric,
actions_dim: Sequence[int],
is_continuous: bool,
@@ -67,7 +67,7 @@ def build_models(
latent_state_size = world_model_cfg.stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size

# Create exploration models
world_model, actor_exploration, critic_exploration = dv1_build_models(
world_model, actor_exploration, critic_exploration = dv1_build_agent(
fabric,
actions_dim=actions_dim,
is_continuous=is_continuous,
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv1/evaluate.py
@@ -7,7 +7,7 @@

from sheeprl.algos.dreamer_v1.agent import PlayerDV1
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.algos.p2e_dv1.agent import build_models
from sheeprl.algos.p2e_dv1.agent import build_agent
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
from sheeprl.utils.registry import register_evaluation
@@ -48,7 +48,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
world_model, actor_task, _, _, _ = build_models(
world_model, actor_task, _, _, _ = build_agent(
fabric,
actions_dim,
is_continuous,
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -26,7 +26,7 @@
from sheeprl.algos.dreamer_v1.loss import actor_loss, critic_loss, reconstruction_loss
from sheeprl.algos.dreamer_v1.utils import compute_lambda_values
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.algos.p2e_dv1.agent import build_models
from sheeprl.algos.p2e_dv1.agent import build_agent
from sheeprl.data.buffers import AsyncReplayBuffer
from sheeprl.models.models import MLP
from sheeprl.utils.env import make_env
@@ -488,7 +488,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder

world_model, actor_task, critic_task, actor_exploration, critic_exploration = build_models(
world_model, actor_task, critic_task, actor_exploration, critic_exploration = build_agent(
fabric,
actions_dim,
is_continuous,
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -20,7 +20,7 @@
from sheeprl.algos.dreamer_v1.agent import PlayerDV1
from sheeprl.algos.dreamer_v1.dreamer_v1 import train
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.algos.p2e_dv1.agent import build_models
from sheeprl.algos.p2e_dv1.agent import build_agent
from sheeprl.data.buffers import AsyncReplayBuffer
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
@@ -139,7 +139,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
fabric.print("Decoder MLP keys:", cfg.mlp_keys.decoder)
obs_keys = cfg.cnn_keys.encoder + cfg.mlp_keys.encoder

world_model, actor_task, critic_task, actor_exploration, _ = build_models(
world_model, actor_task, critic_task, actor_exploration, _ = build_agent(
fabric,
actions_dim,
is_continuous,
6 changes: 3 additions & 3 deletions sheeprl/algos/p2e_dv2/agent.py
@@ -11,7 +11,7 @@
from sheeprl.algos.dreamer_v2.agent import Actor as DV2Actor
from sheeprl.algos.dreamer_v2.agent import MinedojoActor as DV2MinedojoActor
from sheeprl.algos.dreamer_v2.agent import WorldModel
from sheeprl.algos.dreamer_v2.agent import build_models as dv2_build_models
from sheeprl.algos.dreamer_v2.agent import build_agent as dv2_build_agent
from sheeprl.models.models import MLP
from sheeprl.utils.utils import init_weights

@@ -22,7 +22,7 @@
MinedojoActor = DV2MinedojoActor


def build_models(
def build_agent(
fabric: Fabric,
actions_dim: Sequence[int],
is_continuous: bool,
@@ -78,7 +78,7 @@ def build_models(
latent_state_size = stochastic_size + world_model_cfg.recurrent_model.recurrent_state_size

# Create exploration models
world_model, actor_exploration, critic_exploration, target_critic_exploration = dv2_build_models(
world_model, actor_exploration, critic_exploration, target_critic_exploration = dv2_build_agent(
fabric,
actions_dim=actions_dim,
is_continuous=is_continuous,
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv2/evaluate.py
@@ -7,7 +7,7 @@

from sheeprl.algos.dreamer_v2.agent import PlayerDV2
from sheeprl.algos.dreamer_v2.utils import test
from sheeprl.algos.p2e_dv2.agent import build_models
from sheeprl.algos.p2e_dv2.agent import build_agent
from sheeprl.utils.env import make_env
from sheeprl.utils.logger import get_log_dir, get_logger
from sheeprl.utils.registry import register_evaluation
@@ -48,7 +48,7 @@ def evaluate(fabric: Fabric, cfg: Dict[str, Any], state: Dict[str, Any]):
action_space.shape if is_continuous else (action_space.nvec.tolist() if is_multidiscrete else [action_space.n])
)
# Create the actor and critic models
world_model, actor_task, _, _, _, _, _ = build_models(
world_model, actor_task, _, _, _, _, _ = build_agent(
fabric,
actions_dim,
is_continuous,
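One thing the rename does not change: the return arity of build_agent still differs per algorithm, which is why the call sites above unpack different numbers of values. A quick reference reconstructed from the hunks in this PR (the p2e_dv2 entries beyond actor_task are left as placeholders, matching the underscores in its evaluate.py):

# Dreamer V1:  world_model, actor, critic = build_agent(...)
# Dreamer V2:  world_model, actor, critic, target_critic = build_agent(...)
# Dreamer V3:  world_model, actor, critic, target_critic = build_agent(...)
# P2E DV1:     world_model, actor_task, critic_task, actor_exploration, critic_exploration = build_agent(...)
# P2E DV2:     world_model, actor_task, *rest = build_agent(...)   # 7 values in total
# DroQ:        agent = build_agent(...)                            # single DROQAgent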