Fix/player build agent (#258)
* Decoupled RSSM for DV3 agent

* Initialize posterior with prior if is_first is True

* Fix PlayerDV3 creation in evaluation

* Fix representation_model

* Fix compute first prior state with a zero posterior

* DV3 replay ratio conversion

* Removed expl parameters dependent on old per_Rank_gradient_steps

* feat: update repeats computation

* feat: update learning starts in config

* fix: remove files

* feat: update repeats

* Let Dv3 compute bootstrap correctly

* feat: added replay ratio and update exploration

* Fix exploration actions computation on DV1

* Fix naming

* Add replay-ratio to SAC

* feat: added replay ratio to p2e algos

* feat: update configs and utils of p2e algos

* Add replay-ratio to SAC-AE

* Add DrOQ replay ratio

* Fix tests

* Fix misspelled

* Fix wrong attribute accessing

* Fix naming and configs

* feat: add terminated and truncated to dreamer, p2e and ppo algos

* fix: dmc wrapper

* feat: update algos to split terminated from truncated

* fix: crafter and diambra wrappers

* feat: replace done with truncated key when the buffer is added to the checkpoint

* Set Distribution.validate_args once at the beginning

* Move validate_args in run method

* Defined PPOPlayer

* Defined RecurrentPPOPlayer

* feat: added truncated/terminated to minedojo environment

* Add SACPlayer

* Fix evaluate.py for PPO and RecurrentPPO

* Fix DrOQ build_agent

* Delete unused training agent during evaluation

* Fix SACPlayer

* Fix DrOQ build_agent

* Fix PPO decoupled creating single-device fabric

* Adapt SAC decoupled to new build_agent

* Fix typings + add get_actions method

* Add SACAEPlayer

* Add A2CPlayer from PPOPlayer

* Fix get_single_device_fabric

* Fix calling get_values on player instead of agent

* Setup PlayerDV1 in build_agent

* Fix typings

* DV2 player

* Remove one weight tie

* Fix DV3 player in build_agent

* Fix return player from build_agent

* Set actor_type during evaluation

* Build player in build_agent for P2E-DV3

* Update comments

* Update dreamer-v3 cfg

* Learnable initial recurrent state choice

* Fix DecoupledRSSM to accept the learnable_initial_recurrent_state flag

* Preserve input dtype after LayerNorm (pytorch/pytorch#66707 (comment))

* Fix imports

* Move hyperparams to rightful key inside world_model

* sample_actions to greedy

* From sample_actions to greedy

* From sample_actions to greedy

* unwrap_fabric before test in p2e

* Fix player in notebook

* Update how-tos

---------

Co-authored-by: Michele Milesi <[email protected]>
belerico and michele-milesi authored Apr 8, 2024
1 parent 5e75246 commit 32a3736
Showing 53 changed files with 1,574 additions and 623 deletions.
334 changes: 330 additions & 4 deletions howto/register_external_algorithm.md

## The agent

The agent is the core of the algorithm and it is defined in the `agent.py` file. It must contain at least a single function called `build_agent` that returns at least a tuple of two `torch.nn.Module`s wrapped with Fabric; the first is the agent used during the training phase, while the second is the player used during environment interaction:

```python
from __future__ import annotations

import copy
from typing import Any, Dict, Sequence, Tuple

import gymnasium
from lightning import Fabric
from lightning.fabric.wrappers import _FabricModule
import torch
from torch import Tensor

from sheeprl.utils.fabric import get_single_device_fabric


class SOTAAgent(torch.nn.Module):
def __init__(self, ...):
        ...

def forward(self, obs: Dict[str, torch.Tensor]) -> Tensor:
...

class SOTAAgentPlayer(torch.nn.Module):
def __init__(self, ...):
...

def forward(self, obs: Dict[str, torch.Tensor]) -> Tensor:
...

def get_actions(self, obs: Dict[str, torch.Tensor], greedy: bool = False) -> Tensor:
...


def build_agent(
fabric: Fabric,
    actions_dim: Sequence[int],
    is_continuous: bool,
cfg: Dict[str, Any],
observation_space: gymnasium.spaces.Dict,
state: Dict[str, Any] | None = None,
) -> Tuple[_FabricModule, _FabricModule]:

# Define the agent here
agent = SOTAAgent(...)
if state:
agent.load_state_dict(state)

# Setup player agent
player = copy.deepcopy(agent)

# Setup the agent with Fabric
    agent = fabric.setup_module(agent)

    # Setup the player agent with a single-device Fabric
fabric_player = get_single_device_fabric(fabric)
player = fabric_player.setup_module(player)

# Tie weights between the agent and the player
for agent_p, player_p in zip(agent.parameters(), player.parameters()):
player_p.data = agent_p.data
return agent, player
```

The player agent is wrapped with a **single-device Fabric**: in this way we keep the same precision and device as the main Fabric object, while the player agent can interact with the environment without going through possible distributed synchronization points.
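
For example, once `build_agent` has returned both modules, the environment-interaction code only ever queries the player, while the training loop uses the Fabric-wrapped agent (a minimal sketch; the `obs` dictionary and the surrounding rollout loop are assumptions, not part of the how-to):

```python
import torch

# `fabric`, `cfg`, `observation_space`, ... come from the algorithm entry point.
agent, player = build_agent(fabric, actions_dim, is_continuous, cfg, observation_space)

# Environment interaction: the player lives on a single device, so sampling
# actions does not trigger any distributed synchronization.
with torch.no_grad():
    actions = player.get_actions(obs, greedy=False)

# Training: losses are computed through `agent`, which is wrapped by the main
# (possibly multi-device) Fabric object, e.g. with fabric.backward(loss).
```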

If the agent is composed of multiple models, each with its own forward method, it is advisable to wrap each of them with the main Fabric object; the same applies to the player agent, whose models must each be wrapped with the single-device Fabric object. Here is the example of the **PPOAgent**:

```python
from __future__ import annotations

import copy
from math import prod
from typing import Any, Dict, List, Optional, Sequence, Tuple

import gymnasium
import torch
import torch.nn as nn
from lightning import Fabric
from torch import Tensor
from torch.distributions import Distribution, Independent, Normal, OneHotCategorical

from sheeprl.models.models import MLP, MultiEncoder, NatureCNN
from sheeprl.utils.fabric import get_single_device_fabric


class CNNEncoder(nn.Module):
def __init__(
self,
in_channels: int,
features_dim: int,
screen_size: int,
keys: Sequence[str],
) -> None:
super().__init__()
self.keys = keys
self.input_dim = (in_channels, screen_size, screen_size)
self.output_dim = features_dim
self.model = NatureCNN(in_channels=in_channels, features_dim=features_dim, screen_size=screen_size)

def forward(self, obs: Dict[str, Tensor]) -> Tensor:
x = torch.cat([obs[k] for k in self.keys], dim=-3)
return self.model(x)


class MLPEncoder(nn.Module):
def __init__(
self,
input_dim: int,
features_dim: int | None,
keys: Sequence[str],
dense_units: int = 64,
mlp_layers: int = 2,
dense_act: nn.Module = nn.ReLU,
layer_norm: bool = False,
) -> None:
super().__init__()
self.keys = keys
self.input_dim = input_dim
self.output_dim = features_dim if features_dim else dense_units
self.model = MLP(
input_dim,
features_dim,
[dense_units] * mlp_layers,
activation=dense_act,
norm_layer=[nn.LayerNorm for _ in range(mlp_layers)] if layer_norm else None,
norm_args=[{"normalized_shape": dense_units} for _ in range(mlp_layers)] if layer_norm else None,
)

def forward(self, obs: Dict[str, Tensor]) -> Tensor:
x = torch.cat([obs[k] for k in self.keys], dim=-1)
return self.model(x)


class PPOActor(nn.Module):
def __init__(self, actor_backbone: torch.nn.Module, actor_heads: torch.nn.ModuleList, is_continuous: bool) -> None:
super().__init__()
self.actor_backbone = actor_backbone
self.actor_heads = actor_heads
self.is_continuous = is_continuous

def forward(self, x: Tensor) -> List[Tensor]:
x = self.actor_backbone(x)
return [head(x) for head in self.actor_heads]


class PPOAgent(nn.Module):
def __init__(
self,
actions_dim: Sequence[int],
obs_space: gymnasium.spaces.Dict,
encoder_cfg: Dict[str, Any],
actor_cfg: Dict[str, Any],
critic_cfg: Dict[str, Any],
cnn_keys: Sequence[str],
mlp_keys: Sequence[str],
screen_size: int,
distribution_cfg: Dict[str, Any],
is_continuous: bool = False,
):
super().__init__()
self.is_continuous = is_continuous
self.distribution_cfg = distribution_cfg
self.actions_dim = actions_dim
in_channels = sum([prod(obs_space[k].shape[:-2]) for k in cnn_keys])
mlp_input_dim = sum([obs_space[k].shape[0] for k in mlp_keys])
cnn_encoder = (
CNNEncoder(in_channels, encoder_cfg.cnn_features_dim, screen_size, cnn_keys)
if cnn_keys is not None and len(cnn_keys) > 0
else None
)
mlp_encoder = (
MLPEncoder(
mlp_input_dim,
encoder_cfg.mlp_features_dim,
mlp_keys,
encoder_cfg.dense_units,
encoder_cfg.mlp_layers,
eval(encoder_cfg.dense_act),
encoder_cfg.layer_norm,
)
if mlp_keys is not None and len(mlp_keys) > 0
else None
)
self.feature_extractor = MultiEncoder(cnn_encoder, mlp_encoder)
features_dim = self.feature_extractor.output_dim
self.critic = MLP(
input_dims=features_dim,
output_dim=1,
hidden_sizes=[critic_cfg.dense_units] * critic_cfg.mlp_layers,
activation=eval(critic_cfg.dense_act),
norm_layer=[nn.LayerNorm for _ in range(critic_cfg.mlp_layers)] if critic_cfg.layer_norm else None,
norm_args=(
[{"normalized_shape": critic_cfg.dense_units} for _ in range(critic_cfg.mlp_layers)]
if critic_cfg.layer_norm
else None
),
)
actor_backbone = (
MLP(
input_dims=features_dim,
output_dim=None,
hidden_sizes=[actor_cfg.dense_units] * actor_cfg.mlp_layers,
activation=eval(actor_cfg.dense_act),
flatten_dim=None,
norm_layer=[nn.LayerNorm] * actor_cfg.mlp_layers if actor_cfg.layer_norm else None,
norm_args=(
[{"normalized_shape": actor_cfg.dense_units} for _ in range(actor_cfg.mlp_layers)]
if actor_cfg.layer_norm
else None
),
)
if actor_cfg.mlp_layers > 0
else nn.Identity()
)
if is_continuous:
actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, sum(actions_dim) * 2)])
else:
actor_heads = nn.ModuleList([nn.Linear(actor_cfg.dense_units, action_dim) for action_dim in actions_dim])
self.actor = PPOActor(actor_backbone, actor_heads, is_continuous)

def forward(
self, obs: Dict[str, Tensor], actions: Optional[List[Tensor]] = None
) -> Tuple[Sequence[Tensor], Tensor, Tensor, Tensor]:
feat = self.feature_extractor(obs)
actor_out: List[Tensor] = self.actor(feat)
values = self.critic(feat)
if self.is_continuous:
mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1)
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
if actions is None:
actions = normal.sample()
else:
# always composed by a tuple of one element containing all the
# continuous actions
actions = actions[0]
log_prob = normal.log_prob(actions)
return tuple([actions]), log_prob.unsqueeze(dim=-1), normal.entropy().unsqueeze(dim=-1), values
else:
should_append = False
actions_logprobs: List[Tensor] = []
actions_entropies: List[Tensor] = []
actions_dist: List[Distribution] = []
if actions is None:
should_append = True
actions: List[Tensor] = []
for i, logits in enumerate(actor_out):
actions_dist.append(OneHotCategorical(logits=logits))
actions_entropies.append(actions_dist[-1].entropy())
if should_append:
actions.append(actions_dist[-1].sample())
actions_logprobs.append(actions_dist[-1].log_prob(actions[i]))
return (
tuple(actions),
torch.stack(actions_logprobs, dim=-1).sum(dim=-1, keepdim=True),
torch.stack(actions_entropies, dim=-1).sum(dim=-1, keepdim=True),
values,
)


class PPOPlayer(nn.Module):
def __init__(self, feature_extractor: MultiEncoder, actor: PPOActor, critic: nn.Module) -> None:
super().__init__()
self.feature_extractor = feature_extractor
self.critic = critic
self.actor = actor

def forward(self, obs: Dict[str, Tensor]) -> Tuple[Sequence[Tensor], Tensor, Tensor]:
feat = self.feature_extractor(obs)
values = self.critic(feat)
actor_out: List[Tensor] = self.actor(feat)
if self.actor.is_continuous:
mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1)
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
actions = normal.sample()
log_prob = normal.log_prob(actions)
return tuple([actions]), log_prob.unsqueeze(dim=-1), values
else:
actions_dist: List[Distribution] = []
actions_logprobs: List[Tensor] = []
actions: List[Tensor] = []
for i, logits in enumerate(actor_out):
actions_dist.append(OneHotCategorical(logits=logits))
actions.append(actions_dist[-1].sample())
actions_logprobs.append(actions_dist[-1].log_prob(actions[i]))
return (
tuple(actions),
torch.stack(actions_logprobs, dim=-1).sum(dim=-1, keepdim=True),
values,
)

def get_values(self, obs: Dict[str, Tensor]) -> Tensor:
feat = self.feature_extractor(obs)
return self.critic(feat)

def get_actions(self, obs: Dict[str, Tensor], greedy: bool = False) -> Sequence[Tensor]:
feat = self.feature_extractor(obs)
actor_out: List[Tensor] = self.actor(feat)
if self.actor.is_continuous:
mean, log_std = torch.chunk(actor_out[0], chunks=2, dim=-1)
if greedy:
actions = mean
else:
std = log_std.exp()
normal = Independent(Normal(mean, std), 1)
actions = normal.sample()
return tuple([actions])
else:
actions: List[Tensor] = []
actions_dist: List[Distribution] = []
for logits in actor_out:
actions_dist.append(OneHotCategorical(logits=logits))
if greedy:
actions.append(actions_dist[-1].mode)
else:
actions.append(actions_dist[-1].sample())
return tuple(actions)


def build_agent(
fabric: Fabric,
actions_dim: Sequence[int],
is_continuous: bool,
cfg: Dict[str, Any],
obs_space: gymnasium.spaces.Dict,
agent_state: Optional[Dict[str, Tensor]] = None,
) -> Tuple[PPOAgent, PPOPlayer]:
agent = PPOAgent(
actions_dim=actions_dim,
obs_space=obs_space,
encoder_cfg=cfg.algo.encoder,
actor_cfg=cfg.algo.actor,
critic_cfg=cfg.algo.critic,
cnn_keys=cfg.algo.cnn_keys.encoder,
mlp_keys=cfg.algo.mlp_keys.encoder,
screen_size=cfg.env.screen_size,
distribution_cfg=cfg.distribution,
is_continuous=is_continuous,
)
if agent_state:
agent.load_state_dict(agent_state)

# Setup player agent
player = PPOPlayer(copy.deepcopy(agent.feature_extractor), copy.deepcopy(agent.actor), copy.deepcopy(agent.critic))

# Setup training agent
agent.feature_extractor = fabric.setup_module(agent.feature_extractor)
agent.critic = fabric.setup_module(agent.critic)
agent.actor = fabric.setup_module(agent.actor)

# Setup player agent
fabric_player = get_single_device_fabric(fabric)
player.feature_extractor = fabric_player.setup_module(player.feature_extractor)
player.critic = fabric_player.setup_module(player.critic)
player.actor = fabric_player.setup_module(player.actor)

# Tie weights between the agent and the player
for agent_p, player_p in zip(agent.feature_extractor.parameters(), player.feature_extractor.parameters()):
player_p.data = agent_p.data
for agent_p, player_p in zip(agent.actor.parameters(), player.actor.parameters()):
player_p.data = agent_p.data
for agent_p, player_p in zip(agent.critic.parameters(), player.critic.parameters()):
player_p.data = agent_p.data
    return agent, player
```
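
The final loops tie the player's parameters to the training agent's: assigning `player_p.data = agent_p.data` makes both modules point to the same underlying tensors, so every optimizer update on the training agent is immediately visible to the player without any explicit copy. A minimal, self-contained sketch of this behaviour (the two linear layers here are purely illustrative, not part of the how-to):

```python
import torch
import torch.nn as nn

# Two independent modules standing in for the training agent and the player.
train_module = nn.Linear(4, 2)
player_module = nn.Linear(4, 2)

# Tie the weights: after this, both modules share the same parameter storage.
for train_p, player_p in zip(train_module.parameters(), player_module.parameters()):
    player_p.data = train_p.data

# One optimization step on the training module...
optimizer = torch.optim.SGD(train_module.parameters(), lr=0.1)
loss = train_module(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()

# ...is immediately reflected in the player, since the tensors are shared.
assert torch.equal(train_module.weight, player_module.weight)
```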

## Loss functions

All the loss functions to be optimized by the agent during training should be defined in the `loss.py` file, even though this is not strictly necessary:
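
As an illustration only (a minimal sketch, not necessarily the loss used by your algorithm), a clipped PPO-style policy loss in `loss.py` could look like this:

```python
import torch
from torch import Tensor


def policy_loss(new_logprobs: Tensor, old_logprobs: Tensor, advantages: Tensor, clip_coef: float) -> Tensor:
    """Clipped surrogate objective, in the PPO style (illustrative sketch)."""
    ratio = (new_logprobs - old_logprobs).exp()
    surrogate = advantages * ratio
    clipped_surrogate = advantages * torch.clamp(ratio, 1.0 - clip_coef, 1.0 + clip_coef)
    return -torch.min(surrogate, clipped_surrogate).mean()
```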
