Skip to content

Commit

Permalink
feature(whl): add PC algorithm (#514)
Browse files Browse the repository at this point in the history
* feature(whl): add PC algorithm

* code style

* bug fix and hyperparameter tuning

* debug

* tuning pc

* polish pc

* PC update

* remove files

* remove files

* polish

* reformat

* add bc main entry

* debug

* debug

* polish pc policy

* polish pc entry

* polish
  • Loading branch information
kxzxvbk authored Mar 9, 2023
1 parent f798002 commit 7601e03
Show file tree
Hide file tree
Showing 19 changed files with 1,320 additions and 23 deletions.
1 change: 1 addition & 0 deletions ding/entry/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@
from .application_entry_drex_collect_data import drex_collecting_data
from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream
from .serial_entry_bco import serial_pipeline_bco
from .serial_entry_pc import serial_pipeline_pc
108 changes: 108 additions & 0 deletions ding/entry/serial_entry_pc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import Union, Optional, Tuple
import os
from functools import partial
from copy import deepcopy

import torch
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader

from ding.envs import get_vec_env_setting, create_env_manager
from ding.worker import BaseLearner, InteractionSerialEvaluator
from ding.config import read_config, compile_config
from ding.policy import create_policy
from ding.utils import set_pkg_seed
from ding.utils.data.dataset import load_bfs_datasets


def _evaluate_bfs_dataset(model: torch.nn.Module, dataloader: DataLoader,
                          criterion: torch.nn.Module) -> Tuple[list, list]:
    """
    Overview:
        Run ``model`` over every batch of ``dataloader`` without tracking gradients and collect the \
        per-batch loss and accuracy.
    Arguments:
        - model (:obj:`torch.nn.Module`): Model whose output is a dict containing the key ``'logit'``.
        - dataloader (:obj:`DataLoader`): Loader yielding dict batches with keys ``'obs'``, ``'bfs_in'`` \
            and ``'bfs_out'``.
        - criterion (:obj:`torch.nn.Module`): Loss applied to the flattened logits and labels.
    Returns:
        - losses (:obj:`list`): One python float loss per batch.
        - accuracies (:obj:`list`): One python float accuracy per batch.
    """
    # Follow the device the model actually lives on instead of assuming CUDA is available.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    losses, accuracies = [], []
    with torch.no_grad():
        for batch in dataloader:
            bfs_input_maps = batch['bfs_in'].long()
            bfs_output_maps = batch['bfs_out'].long()
            # NOTE(review): 5 is the hard-coded number of one-hot classes of ``bfs_in``; presumably the
            # number of distinct BFS cell values -- confirm against the dataset.
            bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, 5).float()
            bfs_states = torch.cat([batch['obs'], bfs_input_onehot], dim=-1).to(device)

            # Collapse every leading dim so each row of ``logits`` pairs with one label.
            logits = model(bfs_states)['logit'].flatten(0, -2)
            labels = bfs_output_maps.flatten(0, -1).to(device)

            losses.append(criterion(logits, labels).item())
            preds = torch.argmax(logits, dim=-1)
            accuracies.append((preds == labels).float().mean().item())
    return losses, accuracies


def serial_pipeline_pc(
        input_cfg: Union[str, Tuple[dict, dict]],
        seed: int = 0,
        model: Optional[torch.nn.Module] = None,
        max_iter: int = int(1e6),
) -> Tuple['Policy', bool]:  # noqa
    r"""
    Overview:
        Serial pipeline entry of procedure cloning using BFS as expert policy.
    Arguments:
        - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
            ``str`` type means config file path. \
            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
        - seed (:obj:`int`): Random seed.
        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
        - max_iter (:obj:`Optional[int]`): Max iteration for executing PC training.
    Returns:
        - policy (:obj:`Policy`): The trained policy.
        - convergence (:obj:`bool`): The stop flag: ``True`` when the evaluator reported stop or when \
            ``max_iter`` training iterations were reached.
    """
    if isinstance(input_cfg, str):
        cfg, create_cfg = read_config(input_cfg)
    else:
        cfg, create_cfg = deepcopy(input_cfg)
    cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)

    # Env, Policy
    env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
    evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
    # Random seed
    evaluator_env.seed(cfg.seed, dynamic_seed=False)
    set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
    policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval'])

    # Main components
    tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
    train_dataset, test_dataset = load_bfs_datasets(train_seeds=cfg.train_seeds)
    dataloader = DataLoader(train_dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True)
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )

    # ==========
    # Main loop
    # ==========
    learner.call_hook('before_run')
    stop = False
    # Stays ``None`` when the loop ends before any evaluation, so the final print cannot hit an unbound name.
    reward = None
    iter_cnt = 0
    # The loss is stateless, so build it once rather than once per epoch.
    criterion = torch.nn.CrossEntropyLoss()
    for epoch in range(cfg.policy.learn.train_epoch):
        # train
        for train_batch in dataloader:
            learner.train(train_batch)
            iter_cnt += 1
            if iter_cnt >= max_iter:
                stop = True
                break
        # NOTE(review): this fires at epoch 0 as well (0 % 69 == 0) and only touches the first param group;
        # kept as-is, confirm the schedule is intended.
        if epoch % 69 == 0:
            policy._optimizer.param_groups[0]['lr'] /= 10
        if stop:
            break

        # Evaluation
        losses, acces = _evaluate_bfs_dataset(policy._model, test_dataloader, criterion)
        # An empty test set would otherwise divide by zero here.
        if losses:
            print('Test Finished! Loss: {} acc: {}'.format(sum(losses) / len(losses), sum(acces) / len(acces)))
        stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter)
        # Stop right away instead of training one more full epoch after the evaluator signalled stop.
        if stop:
            break
    learner.call_hook('after_run')
    print('final reward is: {}'.format(reward))
    return policy, stop
2 changes: 1 addition & 1 deletion ding/model/template/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
from .madqn import MADQN
from .vae import VanillaVAE
from .decision_transformer import DecisionTransformer
from .procedure_cloning import ProcedureCloning
from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
28 changes: 18 additions & 10 deletions ding/model/template/bc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
class DiscreteBC(nn.Module):

def __init__(
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
encoder_hidden_size_list: SequenceType = [128, 128, 64],
dueling: bool = True,
head_hidden_size: Optional[int] = None,
head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None
self,
obs_shape: Union[int, SequenceType],
action_shape: Union[int, SequenceType],
encoder_hidden_size_list: SequenceType = [128, 128, 64],
dueling: bool = True,
head_hidden_size: Optional[int] = None,
head_layer_num: int = 1,
activation: Optional[nn.Module] = nn.ReLU(),
norm_type: Optional[str] = None,
strides: Optional[list] = None,
) -> None:
"""
Overview:
Expand Down Expand Up @@ -49,7 +50,14 @@ def __init__(
self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
# Conv Encoder
elif len(obs_shape) == 3:
self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
if not strides:
self.encoder = ConvEncoder(
obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
)
else:
self.encoder = ConvEncoder(
obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type, stride=strides
)
else:
raise RuntimeError(
"not support obs_shape for pre-defined encoder: {}, please customize your own BC".format(obs_shape)
Expand Down
101 changes: 97 additions & 4 deletions ding/model/template/procedure_cloning.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Optional, Tuple
from typing import Optional, Tuple, Union, Dict

import torch
import torch.nn as nn

from ding.utils import MODEL_REGISTRY, SequenceType
from ding.torch_utils.network.transformer import Attention
from ding.torch_utils.network.nn_module import fc_block, build_normalization
Expand Down Expand Up @@ -42,8 +44,8 @@ def forward(self, x: torch.Tensor):
return x


@MODEL_REGISTRY.register('pc')
class ProcedureCloning(nn.Module):
@MODEL_REGISTRY.register('pc_mcts')
class ProcedureCloningMCTS(nn.Module):

def __init__(
self,
Expand All @@ -53,7 +55,7 @@ def __init__(
cnn_activation: Optional[nn.Module] = nn.ReLU(),
cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3],
cnn_stride: SequenceType = [1, 1, 1, 1, 1],
cnn_padding: Optional[SequenceType] = ['same', 'same', 'same', 'same', 'same'],
cnn_padding: Optional[SequenceType] = [1, 1, 1, 1, 1],
mlp_hidden_list: SequenceType = [256, 256],
mlp_activation: Optional[nn.Module] = nn.ReLU(),
att_heads: int = 8,
Expand Down Expand Up @@ -117,3 +119,94 @@ def forward(self, states: torch.Tensor, goals: torch.Tensor,
action_preds = self.predict_action(h[:, 1:, :])

return goal_preds, action_preds


class BFSConvEncoder(nn.Module):
    """
    Overview:
        The ``BFSConvolution Encoder`` used to encode raw 2-dim observations. It is a plain stack of \
        ``nn.Conv2d`` layers with the activation in between, and no activation after the last conv.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            obs_shape: SequenceType,
            hidden_size_list: SequenceType = [32, 64, 64, 128],
            activation: Optional[nn.Module] = nn.ReLU(),
            kernel_size: SequenceType = [8, 4, 3],
            stride: SequenceType = [4, 2, 1],
            padding: Optional[SequenceType] = None,
    ) -> None:
        """
        Overview:
            Init the ``BFSConvolution Encoder`` according to the provided arguments.
        Arguments:
            - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, plus one or more ``input size``. \
                Only ``obs_shape[0]`` is used to build the layers.
            - hidden_size_list (:obj:`SequenceType`): Output channels of the subsequent conv layers. \
                One conv layer is built per entry of ``kernel_size``; extra entries are ignored.
            - activation (:obj:`nn.Module`): Activation inserted between conv layers. \
                Default is ``nn.ReLU()``.
            - kernel_size (:obj:`SequenceType`): Sequence of ``kernel_size`` of subsequent conv layers.
            - stride (:obj:`SequenceType`): Sequence of ``stride`` of subsequent conv layers.
            - padding (:obj:`SequenceType`): Padding added to all four sides of the input for each conv layer. \
                See ``nn.Conv2d`` for more details. Default is ``None``, which means zero padding.
        Raises:
            - ValueError: If ``hidden_size_list``, ``stride`` or ``padding`` has fewer entries than \
                ``kernel_size``.
        """
        super(BFSConvEncoder, self).__init__()
        self.obs_shape = obs_shape
        self.act = activation
        self.hidden_size_list = hidden_size_list

        num_layers = len(kernel_size)
        if padding is None:
            padding = [0] * num_layers
        # Fail early with a clear message instead of an IndexError halfway through construction.
        for arg_name, arg in (('hidden_size_list', hidden_size_list), ('stride', stride), ('padding', padding)):
            if len(arg) < num_layers:
                raise ValueError(
                    "{} has {} entries but kernel_size defines {} conv layers".format(
                        arg_name, len(arg), num_layers
                    )
                )

        layers = []
        in_channels = obs_shape[0]  # in_channel
        # ``zip`` stops after ``num_layers`` items because ``kernel_size`` is the shortest sequence.
        for out_channels, k, s, p in zip(hidden_size_list, kernel_size, stride, padding):
            layers.append(nn.Conv2d(in_channels, out_channels, k, s, p))
            layers.append(self.act)
            in_channels = out_channels
        # Drop the trailing activation so the encoder ends with a raw conv output.
        self.main = nn.Sequential(*layers[:-1])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overview:
            Return output tensor of the env observation.
        Arguments:
            - x (:obj:`torch.Tensor`): Env raw observation.
        Returns:
            - outputs (:obj:`torch.Tensor`): Output feature map of the conv stack.
        Shapes:
            - outputs: :math:`(B, N, H', W')`, where ``N`` is the output channel number of the last conv \
                layer and ``H'``, ``W'`` depend on kernel size, stride and padding.
        """
        return self.main(x)


@MODEL_REGISTRY.register('pc_bfs')
class ProcedureCloningBFS(nn.Module):
    """
    Overview:
        Procedure cloning model for BFS: a stack of 3x3, stride-1, padding-1 conv layers \
        (``BFSConvEncoder``) whose last layer outputs ``action_shape + 1`` channels per cell.
    Interfaces:
        ``__init__``, ``forward``.
    """

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType],
            encoder_hidden_size_list: SequenceType = [128, 128, 256, 256],
    ):
        """
        Overview:
            Build the conv encoder.
        Arguments:
            - obs_shape (:obj:`SequenceType`): Observation shape passed to ``BFSConvEncoder``, whose first \
                entry is used as the input channel number.
            - action_shape (:obj:`int`): Action number; ``action_shape + 1`` is computed, so an integer \
                is expected here.
            - encoder_hidden_size_list (:obj:`SequenceType`): Output channels of the hidden conv layers. \
                This argument is not modified.
        """
        super().__init__()
        # Build a fresh list rather than appending in place: the old in-place ``append`` mutated the shared
        # default (and any caller-supplied list), so every further instantiation got one more layer.
        # The output channel equals to action_shape + 1
        hidden_size_list = list(encoder_hidden_size_list) + [action_shape + 1]
        num_convs = len(hidden_size_list)

        self._encoder = BFSConvEncoder(
            obs_shape=obs_shape,
            hidden_size_list=hidden_size_list,
            kernel_size=(3, ) * num_convs,
            stride=(1, ) * num_convs,
            padding=(1, ) * num_convs,
        )

    def forward(self, x: torch.Tensor) -> Dict:
        """
        Overview:
            Move the last dim of ``x`` to the channel position, run the conv encoder, and move the \
            channel dim back to the end.
        Arguments:
            - x (:obj:`torch.Tensor`): 4-dim input with channels in the last dim.
        Returns:
            - outputs (:obj:`Dict`): Dict with key ``'logit'`` holding the encoder output with channels \
                in the last dim.
        """
        x = x.permute(0, 3, 1, 2)
        x = self._encoder(x)
        return {'logit': x.permute(0, 2, 3, 1)}
20 changes: 13 additions & 7 deletions ding/model/template/tests/test_procedure_cloning.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import torch
import pytest
import numpy as np
from itertools import product

from ding.model.template import ProcedureCloning
from ding.torch_utils import is_differentiable
from ding.utils import squeeze
import torch

from ding.model.template import ProcedureCloningMCTS, ProcedureCloningBFS

B = 4
T = 15
Expand All @@ -19,16 +17,24 @@
@pytest.mark.parametrize('obs_shape, action_dim', args)
class TestProcedureCloning:

def test_procedure_cloning(self, obs_shape, action_dim):
def test_procedure_cloning_mcts(self, obs_shape, action_dim):
inputs = {
'states': torch.randn(B, *obs_shape),
'goals': torch.randn(B, *obs_shape),
'actions': torch.randn(B, T, action_dim)
}
model = ProcedureCloning(obs_shape=obs_shape, action_dim=action_dim)
model = ProcedureCloningMCTS(obs_shape=obs_shape, action_dim=action_dim)

print(model)

goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions'])
assert goal_preds.shape == (B, obs_embeddings)
assert action_preds.shape == (B, T + 1, action_dim)

def test_procedure_cloning_bfs(self, obs_shape, action_dim):
o_shape = (obs_shape[2], obs_shape[0], obs_shape[1])
model = ProcedureCloningBFS(obs_shape=o_shape, action_shape=action_dim)

inputs = torch.randn(B, *obs_shape)
map_preds = model(inputs)
assert map_preds['logit'].shape == (B, obs_shape[0], obs_shape[1], action_dim + 1)
2 changes: 2 additions & 0 deletions ding/policy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,7 @@
from .bc import BehaviourCloningPolicy
from .ibc import IBCPolicy

from .pc import ProcedureCloningBFSPolicy

# new-type policy
from .ppof import PPOFPolicy
Loading

0 comments on commit 7601e03

Please sign in to comment.