Fix/dreamer v3 #253

Merged (62 commits) on Apr 3, 2024
Changes from 59 commits
Commits (62)
3711d05
Decoupled RSSM for DV3 agent
belerico Feb 8, 2024
e80e9d5
Initialize posterior with prior if is_first is True
belerico Feb 8, 2024
b23112a
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Feb 12, 2024
f47b8f9
Fix PlayerDV3 creation in evaluation
belerico Feb 12, 2024
e42c83d
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Feb 26, 2024
2ec4fbb
Fix representation_model
belerico Feb 26, 2024
3a5380b
Fix compute first prior state with a zero posterior
belerico Feb 27, 2024
42d9433
DV3 replay ratio conversion
belerico Feb 29, 2024
750f671
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Feb 29, 2024
b06433b
Removed expl parameters dependent on old per_Rank_gradient_steps
belerico Feb 29, 2024
20cc43e
Merge branch 'main' of https://github.com/Eclectic-Sheep/sheeprl into…
belerico Mar 4, 2024
37d0e86
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into feature…
michele-milesi Mar 18, 2024
704b0ce
feat: update repeats computation
michele-milesi Mar 18, 2024
20905f0
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into feature…
michele-milesi Mar 28, 2024
e1290ee
feat: update learning starts in config
michele-milesi Mar 28, 2024
1f0c0ef
fix: remove files
michele-milesi Mar 28, 2024
cd4a4c4
feat: update repeats
michele-milesi Mar 28, 2024
b17d451
Let Dv3 compute bootstrap correctly
belerico Mar 28, 2024
e8c9049
feat: added replay ratio and update exploration
michele-milesi Mar 28, 2024
88c6968
Fix exploration actions computation on DV1
belerico Mar 28, 2024
a5c957c
Fix naming
belerico Mar 28, 2024
c36577d
Add replay-ratio to SAC
belerico Mar 28, 2024
0bc9f07
feat: added replay ratio to p2e algos
michele-milesi Mar 28, 2024
b5fbe5d
feat: update configs and utils of p2e algos
michele-milesi Mar 28, 2024
24c9352
Add replay-ratio to SAC-AE
belerico Mar 28, 2024
a11b558
Merge branch 'feature/replay-ratio' of https://github.com/Eclectic-Sh…
belerico Mar 28, 2024
32b89b4
Add DrOQ replay ratio
belerico Mar 29, 2024
d057886
Fix tests
belerico Mar 29, 2024
b9044a3
Fix mispelled
belerico Mar 29, 2024
5bd7d75
Fix wrong attribute accesing
belerico Mar 29, 2024
8d94f68
FIx naming and configs
belerico Mar 29, 2024
cae85a3
Merge branch 'fix/dv3-continue-on-terminated' of github.com:Eclectic-…
michele-milesi Mar 29, 2024
e5dd8fd
feat: add terminated and truncated to dreamer, p2e and ppo algos
michele-milesi Mar 29, 2024
fdd4579
fix: dmc wrapper
michele-milesi Mar 29, 2024
a2a2690
feat: update algos to split terminated from truncated
michele-milesi Mar 29, 2024
74bfb6b
fix: crafter and diambra wrappers
michele-milesi Mar 29, 2024
3d1f2c9
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into fix/ter…
michele-milesi Mar 30, 2024
05e4370
feat: replace done with truncated key in when the buffer is added to …
michele-milesi Mar 30, 2024
87c9098
feat: added truncated/terminated to minedojo environment
michele-milesi Mar 30, 2024
e137a38
feat: added terminated/truncated to minerl and super mario bros envs
michele-milesi Apr 2, 2024
b557835
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into fix/ter…
michele-milesi Apr 2, 2024
64d3c81
docs: update howto
michele-milesi Apr 2, 2024
2e156f3
fix: minedojo wrapper
michele-milesi Apr 2, 2024
0167fd5
docs: update
michele-milesi Apr 2, 2024
09e051e
fix: minedojo
michele-milesi Apr 2, 2024
dacd425
update dependencies
michele-milesi Apr 2, 2024
f2557a3
fix: minedojo
michele-milesi Apr 2, 2024
5bf50dd
fix: dv3 small configs
michele-milesi Apr 2, 2024
f58a3c2
fix: episode buffer and tests
michele-milesi Apr 2, 2024
d19a8ba
feat: added possibility to choose layernorm and kwargs
michele-milesi Apr 2, 2024
3c9e1f6
Merge branch 'main' of github.com:Eclectic-Sheep/sheeprl into fix/dre…
michele-milesi Apr 2, 2024
4f15952
feat: added first recurrent state learnable in dv3
michele-milesi Apr 2, 2024
66d4d92
feat: update dv3 ww configs
michele-milesi Apr 2, 2024
d536264
feat: learned initial recurrent state when resetting the player states)
michele-milesi Apr 2, 2024
6108afb
fix: env interaction
michele-milesi Apr 2, 2024
01484e9
fix: avoid to rewrite with layer_norm kwargs
michele-milesi Apr 3, 2024
fd62f69
fix: avoid to rewrite with layer_norm kwargs
michele-milesi Apr 3, 2024
2df458a
fix: dv2 LayerNormGruCell creation
michele-milesi Apr 3, 2024
fe4807a
fix: dv2 LayerNormGruCell creation
michele-milesi Apr 3, 2024
6a8c6cb
fix: update p2e dv3 + fix tests
michele-milesi Apr 3, 2024
1e6f309
fix: tests
michele-milesi Apr 3, 2024
e77c52d
fix: tests
michele-milesi Apr 3, 2024
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v1/dreamer_v1.py
@@ -594,7 +594,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
-real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step)
+real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
4 changes: 3 additions & 1 deletion sheeprl/algos/dreamer_v2/agent.py
@@ -278,7 +278,9 @@ def __init__(
norm_layer=[nn.LayerNorm] if layer_norm else None,
norm_args=[{"normalized_shape": dense_units}] if layer_norm else None,
)
-self.rnn = LayerNormGRUCell(dense_units, recurrent_state_size, bias=True, batch_first=False, layer_norm=True)
+self.rnn = LayerNormGRUCell(
+dense_units, recurrent_state_size, bias=True, batch_first=False, layer_norm_cls=nn.LayerNorm
+)

def forward(self, input: Tensor, recurrent_state: Tensor) -> Tensor:
"""
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v2/dreamer_v2.py
@@ -620,7 +620,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
-real_actions = actions = player.get_actions(normalized_obs, mask)
+real_actions = actions = player.get_actions(normalized_obs, mask=mask)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
296 changes: 161 additions & 135 deletions sheeprl/algos/dreamer_v3/agent.py

Large diffs are not rendered by default.
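The `agent.py` rewrite is too large to render here. Per the commit list, it decouples the RSSM and makes the first recurrent state learnable instead of resetting it to zeros. As a rough, hypothetical sketch of that idea (not the actual sheeprl implementation, whose module and parameter names may differ):

```python
import torch
from torch import nn


class LearnableInitialState(nn.Module):
    """Trainable replacement for the all-zeros recurrent state used at episode reset."""

    def __init__(self, recurrent_state_size: int) -> None:
        super().__init__()
        self.initial = nn.Parameter(torch.zeros(1, recurrent_state_size))

    def forward(self, batch_size: int) -> torch.Tensor:
        # tanh keeps the learned state bounded; expand broadcasts it over the batch
        return torch.tanh(self.initial).expand(batch_size, -1)
```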

2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v3/dreamer_v3.py
@@ -586,7 +586,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
-real_actions = actions = player.get_actions(preprocessed_obs, mask)
+real_actions = actions = player.get_actions(preprocessed_obs, mask=mask)
actions = torch.cat(actions, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()
2 changes: 1 addition & 1 deletion sheeprl/algos/dreamer_v3/utils.py
@@ -54,7 +54,7 @@ def __init__(
self.register_buffer("high", torch.zeros((), dtype=torch.float32))

def forward(self, x: Tensor, fabric: Fabric) -> Any:
-gathered_x = fabric.all_gather(x).detach()
+gathered_x = fabric.all_gather(x).float().detach()
low = torch.quantile(gathered_x, self._percentile_low)
high = torch.quantile(gathered_x, self._percentile_high)
self.low = self._decay * self.low + (1 - self._decay) * low
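The added `.float()` matters under the new bf16-mixed setup: `torch.quantile` only supports float32/float64 inputs, so taking percentiles of gathered bf16 returns would otherwise raise. A minimal sketch of the failure mode and the fix (the tensor is a stand-in for the gathered values):

```python
import torch

returns = torch.randn(1024, dtype=torch.bfloat16)  # stand-in for fabric.all_gather(x) under bf16

# torch.quantile(returns, 0.05)  # raises: quantile() expects a float or double input

low = torch.quantile(returns.float(), 0.05)   # cast first, as the diff above now does
high = torch.quantile(returns.float(), 0.95)
print(low.item(), high.item())
```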
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -620,7 +620,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
-real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step)
+real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv1/p2e_dv1_finetuning.py
@@ -276,7 +276,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
-real_actions = actions = player.get_exploration_actions(normalized_obs, mask, step=policy_step)
+real_actions = actions = player.get_exploration_actions(normalized_obs, mask=mask, step=policy_step)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -757,7 +757,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
-real_actions = actions = player.get_actions(normalized_obs, mask)
+real_actions = actions = player.get_actions(normalized_obs, mask=mask)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv2/p2e_dv2_finetuning.py
@@ -297,7 +297,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
mask = {k: v for k, v in normalized_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
-real_actions = actions = player.get_actions(normalized_obs, mask)
+real_actions = actions = player.get_actions(normalized_obs, mask=mask)
actions = torch.cat(actions, -1).view(cfg.env.num_envs, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, -1).cpu().numpy()
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -827,7 +827,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
-real_actions = actions = player.get_actions(preprocessed_obs, mask)
+real_actions = actions = player.get_actions(preprocessed_obs, mask=mask)
actions = torch.cat(actions, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()
2 changes: 1 addition & 1 deletion sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -286,7 +286,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
mask = {k: v for k, v in preprocessed_obs.items() if k.startswith("mask")}
if len(mask) == 0:
mask = None
-real_actions = actions = player.get_actions(preprocessed_obs, mask)
+real_actions = actions = player.get_actions(preprocessed_obs, mask=mask)
actions = torch.cat(actions, -1).cpu().numpy()
if is_continuous:
real_actions = torch.cat(real_actions, dim=-1).cpu().numpy()
29 changes: 19 additions & 10 deletions sheeprl/configs/algo/dreamer_v3.yaml
@@ -25,7 +25,14 @@ mlp_keys:
decoder: ${algo.mlp_keys.encoder}

# Model related parameters
-layer_norm: True
+cnn_layer_norm:
+  cls: sheeprl.utils.model.LayerNormChannelLastFP32
+  kw:
+    eps: 1e-3
+mlp_layer_norm:
+  cls: sheeprl.utils.model.LayerNormFP32
+  kw:
+    eps: 1e-3
dense_units: 1024
mlp_layers: 5
dense_act: torch.nn.SiLU
@@ -51,41 +58,43 @@ world_model:
cnn_act: ${algo.cnn_act}
dense_act: ${algo.dense_act}
mlp_layers: ${algo.mlp_layers}
-layer_norm: ${algo.layer_norm}
+cnn_layer_norm: ${algo.cnn_layer_norm}
+mlp_layer_norm: ${algo.mlp_layer_norm}
dense_units: ${algo.dense_units}

# Recurrent model
recurrent_model:
recurrent_state_size: 4096
-layer_norm: True
+layer_norm: ${algo.mlp_layer_norm}
dense_units: ${algo.dense_units}

# Prior
transition_model:
hidden_size: 1024
dense_act: ${algo.dense_act}
-layer_norm: ${algo.layer_norm}
+layer_norm: ${algo.mlp_layer_norm}

# Posterior
representation_model:
hidden_size: 1024
dense_act: ${algo.dense_act}
-layer_norm: ${algo.layer_norm}
+layer_norm: ${algo.mlp_layer_norm}

# Decoder
observation_model:
cnn_channels_multiplier: ${algo.world_model.encoder.cnn_channels_multiplier}
cnn_act: ${algo.cnn_act}
dense_act: ${algo.dense_act}
mlp_layers: ${algo.mlp_layers}
-layer_norm: ${algo.layer_norm}
+cnn_layer_norm: ${algo.cnn_layer_norm}
+mlp_layer_norm: ${algo.mlp_layer_norm}
dense_units: ${algo.dense_units}

# Reward model
reward_model:
dense_act: ${algo.dense_act}
mlp_layers: ${algo.mlp_layers}
-layer_norm: ${algo.layer_norm}
+layer_norm: ${algo.mlp_layer_norm}
dense_units: ${algo.dense_units}
bins: 255

@@ -94,7 +103,7 @@ world_model:
learnable: True
dense_act: ${algo.dense_act}
mlp_layers: ${algo.mlp_layers}
-layer_norm: ${algo.layer_norm}
+layer_norm: ${algo.mlp_layer_norm}
dense_units: ${algo.dense_units}

# World model optimizer
@@ -112,7 +121,7 @@ actor:
init_std: 2.0
dense_act: ${algo.dense_act}
mlp_layers: ${algo.mlp_layers}
-layer_norm: ${algo.layer_norm}
+layer_norm: ${algo.mlp_layer_norm}
dense_units: ${algo.dense_units}
clip_gradients: 100.0
unimix: ${algo.unimix}
@@ -136,7 +145,7 @@ actor:
critic:
dense_act: ${algo.dense_act}
mlp_layers: ${algo.mlp_layers}
-layer_norm: ${algo.layer_norm}
+layer_norm: ${algo.mlp_layer_norm}
dense_units: ${algo.dense_units}
per_rank_target_network_update_freq: 1
tau: 0.02
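The single `layer_norm: True` flag is replaced by `cnn_layer_norm`/`mlp_layer_norm` entries that name a norm class (`cls`) plus its keyword arguments (`kw`). As a hypothetical illustration of how such an entry can be turned into a layer-norm factory (sheeprl's actual config-resolution code may differ):

```python
import importlib
from typing import Any, Callable, Dict


def make_norm_factory(cfg: Dict[str, Any]) -> Callable[[int], Any]:
    """Turn a {'cls': dotted.path, 'kw': {...}} entry into a norm-layer factory."""
    module_name, _, class_name = cfg["cls"].rpartition(".")
    norm_cls = getattr(importlib.import_module(module_name), class_name)
    kw = cfg.get("kw") or {}

    def factory(normalized_shape: int):
        return norm_cls(normalized_shape, **kw)

    return factory


# e.g. mlp_norm = make_norm_factory({"cls": "sheeprl.utils.model.LayerNormFP32", "kw": {"eps": 1e-3}})
# layer = mlp_norm(1024)
```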
15 changes: 14 additions & 1 deletion sheeprl/configs/exp/dreamer_v3_dmc_walker_walk.yaml
@@ -38,9 +38,22 @@ algo:
- rgb
mlp_keys:
encoder: []
-learning_starts: 1024
+learning_starts: 1300
+replay_ratio: 0.5

# Metric
metric:
log_every: 5000

+fabric:
+accelerator: cuda
+precision: bf16-mixed
+# precision: None
+# plugins:
+# - _target_: lightning.fabric.plugins.precision.MixedPrecision
+# precision: 16-mixed
+# device: cuda
+# scaler:
+# _target_: torch.cuda.amp.GradScaler
+# init_scale: 1e4
+# growth_interval: 1000
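For reference, the active `bf16-mixed` setting and the commented-out fp16 alternative roughly correspond to the following Fabric construction (a sketch mirroring the YAML above; the fp16 plugin route is not enabled in this config and assumes a CUDA machine):

```python
import torch
from lightning.fabric import Fabric
from lightning.fabric.plugins.precision import MixedPrecision

# What the uncommented config selects: bf16 autocast on CUDA, no loss scaling needed.
fabric = Fabric(accelerator="cuda", precision="bf16-mixed")

# Rough equivalent of the commented-out block: fp16 autocast with a gradient scaler.
scaler = torch.cuda.amp.GradScaler(init_scale=1e4, growth_interval=1000)
fabric_fp16 = Fabric(
    accelerator="cuda",
    plugins=[MixedPrecision(precision="16-mixed", device="cuda", scaler=scaler)],
)
```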
23 changes: 15 additions & 8 deletions sheeprl/models/models.py
@@ -4,7 +4,7 @@

import warnings
from math import prod
-from typing import Dict, Optional, Sequence, Union, no_type_check
+from typing import Any, Callable, Dict, Optional, Sequence, Union, no_type_check

import torch
import torch.nn.functional as F
@@ -342,23 +342,30 @@ class LayerNormGRUCell(nn.Module):
Defaults to True.
batch_first (bool, optional): whether the first dimension represent the batch dimension or not.
Defaults to False.
-layer_norm (bool, optional): whether to apply a LayerNorm after the input projection.
-Defaults to False.
+layer_norm_cls (Callable[..., nn.Module]): the layer norm to apply after the input projection.
+Defaults to nn.Identiy.
+layer_norm_kw (Dict[str, Any]): the kwargs of the layer norm.
+Default to {}.
"""

def __init__(
-self, input_size: int, hidden_size: int, bias: bool = True, batch_first: bool = False, layer_norm: bool = False
+self,
+input_size: int,
+hidden_size: int,
+bias: bool = True,
+batch_first: bool = False,
+layer_norm_cls: Callable[..., nn.Module] = nn.Identity,
+layer_norm_kw: Dict[str, Any] = {},
) -> None:
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.bias = bias
self.batch_first = batch_first
self.linear = nn.Linear(input_size + hidden_size, 3 * hidden_size, bias=self.bias)
-if layer_norm:
-self.layer_norm = torch.nn.LayerNorm(3 * hidden_size)
-else:
-self.layer_norm = nn.Identity()
+# Avoid multiple values for the `normalized_shape` argument
+layer_norm_kw.pop("normalized_shape", None)
+self.layer_norm = layer_norm_cls(3 * hidden_size, **layer_norm_kw)

def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
is_3d = input.dim() == 3
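With the old boolean flag gone, callers now inject the norm class and its kwargs into `LayerNormGRUCell`. A minimal usage sketch of the new constructor (sizes and tensor shapes are illustrative; the `eps` value mirrors the new config defaults):

```python
import torch
from torch import nn

from sheeprl.models.models import LayerNormGRUCell

# Old: LayerNormGRUCell(1024, 4096, bias=True, batch_first=False, layer_norm=True)
# New: the norm class and kwargs are passed explicitly, so FP32 variants can be plugged in.
cell = LayerNormGRUCell(
    1024,                        # input size (dense_units in the default config)
    4096,                        # hidden size (recurrent_state_size)
    bias=True,
    batch_first=False,
    layer_norm_cls=nn.LayerNorm,
    layer_norm_kw={"eps": 1e-3},
)

x = torch.randn(1, 8, 1024)      # (time, batch, features), since batch_first=False
h = torch.zeros(1, 8, 4096)
h = cell(x, h)
```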
17 changes: 17 additions & 0 deletions sheeprl/utils/model.py
@@ -1,6 +1,7 @@
"""
Adapted from: https://github.com/thu-ml/tianshou/blob/master/tianshou/utils/net/common.py
"""

from typing import Any, Dict, List, Optional, Tuple, Type, Union

import torch
@@ -233,3 +234,19 @@ def forward(self, x: Tensor) -> Tensor:
x = super().forward(x)
x = x.permute(0, 3, 1, 2)
return x


class LayerNormChannelLastFP32(LayerNormChannelLast):
def forward(self, x: Tensor) -> Tensor:
input_dtype = x.dtype
x = x.to(torch.float32)
out = super().forward(x)
return out.to(input_dtype)


class LayerNormFP32(nn.LayerNorm):
def forward(self, x: Tensor) -> Tensor:
input_dtype = x.dtype
x = x.to(torch.float32)
out = super().forward(x)
return out.to(input_dtype)
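These new wrappers keep the LayerNorm computation in float32 even when activations flow through the network in bf16, then cast the result back to the input dtype. A small sketch of the intended behaviour (sizes are illustrative):

```python
import torch

from sheeprl.utils.model import LayerNormFP32

ln = LayerNormFP32(1024, eps=1e-3)              # same signature as nn.LayerNorm
x = torch.randn(8, 1024, dtype=torch.bfloat16)
y = ln(x)                                       # statistics computed in float32
assert y.dtype == torch.bfloat16                # output cast back to the input dtype
```

`LayerNormChannelLastFP32` does the same for image features, reusing the channel-last permutation of its parent class.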