Call to unwrap_fabric before p2e test
belerico committed Apr 8, 2024
1 parent 9736d3a · commit ac74c1d
Showing 5 changed files with 203 additions and 8 deletions.
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv1/p2e_dv1_exploration.py
@@ -29,7 +29,7 @@
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
 from sheeprl.utils.timer import timer
-from sheeprl.utils.utils import Ratio, save_configs
+from sheeprl.utils.utils import Ratio, save_configs, unwrap_fabric
 
 # Decomment the following line if you are using MineDojo on an headless machine
 # os.environ["MINEDOJO_HEADLESS"] = "1"
@@ -786,7 +786,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     if fabric.is_global_zero and cfg.algo.run_test:
         player.actor_type = "task"
         fabric_player = get_single_device_fabric(fabric)
-        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
+        player.actor = fabric_player.setup_module(unwrap_fabric(actor_task))
         test(player, fabric, cfg, log_dir, "zero-shot")
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
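The same one-line substitution is applied in the two other exploration scripts below. For context, here is a minimal sketch of what a Fabric-unwrapping helper does; the name unwrap_fabric_sketch and its body are illustrative assumptions, not the actual sheeprl.utils.utils.unwrap_fabric. Compared with the previous getattr(actor_task, "module", actor_task), a dedicated helper can also walk through nested wrappers.

# Hypothetical sketch (assumption): recover the plain nn.Module from a
# Fabric-wrapped module. The real helper lives in sheeprl.utils.utils.
import torch.nn as nn


def unwrap_fabric_sketch(model: nn.Module) -> nn.Module:
    # Fabric wrappers expose the wrapped module via a `module` attribute;
    # follow that chain until a plain nn.Module is reached.
    while isinstance(getattr(model, "module", None), nn.Module):
        model = model.module
    return model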
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv2/p2e_dv2_exploration.py
@@ -28,7 +28,7 @@
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
 from sheeprl.utils.timer import timer
-from sheeprl.utils.utils import Ratio, save_configs
+from sheeprl.utils.utils import Ratio, save_configs, unwrap_fabric
 
 # Decomment the following line if you are using MineDojo on an headless machine
 # os.environ["MINEDOJO_HEADLESS"] = "1"
@@ -937,7 +937,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     if fabric.is_global_zero and cfg.algo.run_test:
         player.actor_type = "task"
         fabric_player = get_single_device_fabric(fabric)
-        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
+        player.actor = fabric_player.setup_module(unwrap_fabric(actor_task))
         test(player, fabric, cfg, log_dir, "zero-shot")
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
4 changes: 2 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_exploration.py
@@ -32,7 +32,7 @@
 from sheeprl.utils.metric import MetricAggregator
 from sheeprl.utils.registry import register_algorithm
 from sheeprl.utils.timer import timer
-from sheeprl.utils.utils import Ratio, save_configs
+from sheeprl.utils.utils import Ratio, save_configs, unwrap_fabric
 
 # Decomment the following line if you are using MineDojo on an headless machine
 # os.environ["MINEDOJO_HEADLESS"] = "1"
@@ -1033,7 +1033,7 @@ def main(fabric: Fabric, cfg: Dict[str, Any]):
     if fabric.is_global_zero and cfg.algo.run_test:
         player.actor_type = "task"
         fabric_player = get_single_device_fabric(fabric)
-        player.actor = fabric_player.setup_module(getattr(actor_task, "module", actor_task))
+        player.actor = fabric_player.setup_module(unwrap_fabric(actor_task))
         test(player, fabric, cfg, log_dir, "zero-shot", greedy=False)
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
2 changes: 0 additions & 2 deletions sheeprl/algos/p2e_dv3/p2e_dv3_finetuning.py
@@ -463,8 +463,6 @@ def main(fabric: Fabric, cfg: Dict[str, Any], exploration_cfg: Dict[str, Any]):
     if fabric.is_global_zero and cfg.algo.run_test:
         player.actor_type = "task"
         player.actor = fabric_player.setup_module(unwrap_fabric(actor_task))
-        for agent_p, p in zip(actor_task.parameters(), player.actor.parameters()):
-            p.data = agent_p.data
         test(player, fabric, cfg, log_dir, "few-shot", greedy=False)
 
     if not cfg.model_manager.disabled and fabric.is_global_zero:
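The parameter-copy loop deleted above was syncing actor_task into the freshly set-up player actor. A plausible reason it can go: Fabric's setup_module wraps the module rather than cloning it, so the wrapper shares the same underlying parameter tensors. Below is a minimal, hedged check of that assumption (single CPU device, default precision); it is not code from the repository.

# Hedged illustration, not repository code: verify that Fabric.setup_module
# shares parameter storage with the original module on a single CPU device.
import torch.nn as nn
from lightning.fabric import Fabric

fabric = Fabric(accelerator="cpu", devices=1)
fabric.launch()
net = nn.Linear(4, 2)
wrapped = fabric.setup_module(net)
for p_src, p_dst in zip(net.parameters(), wrapped.parameters()):
    assert p_src.data_ptr() == p_dst.data_ptr()  # same tensors, no copy needed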
197 changes: 197 additions & 0 deletions sheeprl/models/impala.py
@@ -0,0 +1,197 @@
import math
from copy import deepcopy
from typing import Dict, Optional, Sequence, Tuple

import torch as th
import torch.nn as nn
import torch.nn.functional as F


class FanInInitReLULayer(nn.Module):
    def __init__(
        self,
        inchan: int,
        outchan: int,
        layer_type: str = "conv",
        init_scale: float = 1.0,
        batch_norm: bool = False,
        batch_norm_kwargs: Dict = {},
        group_norm_groups: Optional[int] = None,
        layer_norm: bool = False,
        use_activation: bool = True,
        **layer_kwargs,
    ):
        super().__init__()

        # Normalization
        self.norm = None
        if batch_norm:
            self.norm = nn.BatchNorm2d(inchan, **batch_norm_kwargs)
        elif group_norm_groups is not None:
            self.norm = nn.GroupNorm(group_norm_groups, inchan)
        elif layer_norm:
            self.norm = nn.LayerNorm(inchan)

        # Layer
        layer = dict(conv=nn.Conv2d, conv3d=nn.Conv3d, linear=nn.Linear)[layer_type]
        self.layer = layer(inchan, outchan, bias=self.norm is None, **layer_kwargs)
        self.use_activation = use_activation

        # Initialization
        self.layer.weight.data *= init_scale / self.layer.weight.norm(
            dim=tuple(range(1, self.layer.weight.data.ndim)), p=2, keepdim=True
        )
        if self.layer.bias is not None:
            self.layer.bias.data *= 0

    def forward(self, x: th.Tensor):
        if self.norm is not None:
            x = self.norm(x)
        x = self.layer(x)
        if self.use_activation:
            x = F.relu(x, inplace=True)
        return x


class CnnBasicBlock(nn.Module):
    def __init__(
        self,
        inchan: int,
        init_scale: float = 1.0,
        init_norm_kwargs: Dict = {},
    ):
        super().__init__()

        # Layers
        s = math.sqrt(init_scale)
        self.conv0 = FanInInitReLULayer(
            inchan,
            inchan,
            kernel_size=3,
            padding=1,
            init_scale=s,
            **init_norm_kwargs,
        )
        self.conv1 = FanInInitReLULayer(
            inchan,
            inchan,
            kernel_size=3,
            padding=1,
            init_scale=s,
            **init_norm_kwargs,
        )

    def forward(self, x: th.Tensor) -> th.Tensor:
        x = x + self.conv1(self.conv0(x))
        return x


class CnnDownStack(nn.Module):
    def __init__(
        self,
        inchan: int,
        nblock: int,
        outchan: int,
        init_scale: float = 1.0,
        pool: bool = True,
        post_pool_groups: Optional[int] = None,
        init_norm_kwargs: Dict = {},
        first_conv_norm: bool = False,
        **kwargs,
    ):
        super().__init__()

        # Params
        self.inchan = inchan
        self.outchan = outchan
        self.pool = pool

        # Layers
        first_conv_init_kwargs = deepcopy(init_norm_kwargs)
        if not first_conv_norm:
            first_conv_init_kwargs["group_norm_groups"] = None
            first_conv_init_kwargs["batch_norm"] = False
        self.firstconv = FanInInitReLULayer(
            inchan,
            outchan,
            kernel_size=3,
            padding=1,
            **first_conv_init_kwargs,
        )
        self.post_pool_groups = post_pool_groups
        if post_pool_groups is not None:
            self.n = nn.GroupNorm(post_pool_groups, outchan)
        self.blocks = nn.ModuleList(
            [
                CnnBasicBlock(
                    outchan,
                    init_scale=init_scale / math.sqrt(nblock),
                    init_norm_kwargs=init_norm_kwargs,
                    **kwargs,
                )
                for _ in range(nblock)
            ]
        )

    def forward(self, x: th.Tensor) -> th.Tensor:
        x = self.firstconv(x)
        if self.pool:
            x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
            if self.post_pool_groups is not None:
                x = self.n(x)
        for block in self.blocks:
            x = block(x)
        return x

    def output_shape(self, inshape: Sequence[int]) -> Tuple[int, int, int]:
        c, h, w = inshape
        assert c == self.inchan
        if self.pool:
            return (self.outchan, (h + 1) // 2, (w + 1) // 2)
        else:
            return (self.outchan, h, w)


class ImpalaCNN(nn.Module):
    def __init__(
        self,
        inshape: Sequence[int],
        chans: Sequence[int],
        outsize: int,
        nblock: int,
        init_norm_kwargs: Dict = {},
        dense_init_norm_kwargs: Dict = {},
        first_conv_norm: bool = False,
        **kwargs,
    ):
        super().__init__()

        # Layers
        curshape = inshape
        self.stacks = nn.ModuleList()
        for i, outchan in enumerate(chans):
            stack = CnnDownStack(
                curshape[0],
                nblock=nblock,
                outchan=outchan,
                init_scale=1.0 / math.sqrt(len(chans)),
                init_norm_kwargs=init_norm_kwargs,
                first_conv_norm=first_conv_norm if i == 0 else True,
                **kwargs,
            )
            self.stacks.append(stack)
            curshape = stack.output_shape(curshape)
        self.dense = FanInInitReLULayer(
            math.prod(curshape),
            outsize,
            layer_type="linear",
            init_scale=1.4,
            **dense_init_norm_kwargs,
        )

    def forward(self, x: th.Tensor) -> th.Tensor:
        for stack in self.stacks:
            x = stack(x)
        x = x.reshape(x.size(0), -1)
        x = self.dense(x)
        return x
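
A short usage sketch of the new ImpalaCNN follows; the input resolution, channel schedule, and embedding size are illustrative assumptions, not values taken from sheeprl's configs.

import torch as th

# Illustrative values only (assumptions, not sheeprl defaults).
cnn = ImpalaCNN(
    inshape=(3, 64, 64),  # (channels, height, width)
    chans=(16, 32, 32),   # output channels of the three CnnDownStacks
    outsize=256,          # size of the final dense embedding
    nblock=2,             # residual CnnBasicBlocks per stack
)
obs = th.zeros(8, 3, 64, 64)  # batch of 8 RGB frames
feats = cnn(obs)              # each stack halves H and W: 64 -> 32 -> 16 -> 8
print(feats.shape)            # torch.Size([8, 256])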
