From 7601e03466dc0fcef47b4a30568c3bb7e608f086 Mon Sep 17 00:00:00 2001 From: Wang hl <59834623+kxzxvbk@users.noreply.github.com> Date: Thu, 9 Mar 2023 01:41:27 -0500 Subject: [PATCH] feature(whl): add PC algorithm (#514) * feature(whl): add PC algorithm * code style * bug fix and hyperparameter tuning * debug * tuning pc * polish pc * PC update * remove files * remove files * polish * reformat * add bc main entry * debug * debug * polish pc policy * polish pc entry * polish --- ding/entry/__init__.py | 1 + ding/entry/serial_entry_pc.py | 108 +++++ ding/model/template/__init__.py | 2 +- ding/model/template/bc.py | 28 +- ding/model/template/procedure_cloning.py | 101 ++++- .../template/tests/test_procedure_cloning.py | 20 +- ding/policy/__init__.py | 2 + ding/policy/pc.py | 187 +++++++++ ding/utils/__init__.py | 1 + ding/utils/bfs_helper.py | 59 +++ ding/utils/data/dataset.py | 77 +++- ding/utils/tests/test_bfs_helper.py | 26 ++ dizoo/maze/__init__.py | 3 + dizoo/maze/config/maze_bc_config.py | 56 +++ dizoo/maze/config/maze_pc_config.py | 57 +++ dizoo/maze/entry/maze_bc_main.py | 206 ++++++++++ dizoo/maze/envs/__init__.py | 1 + dizoo/maze/envs/maze_env.py | 380 ++++++++++++++++++ dizoo/maze/envs/test_maze_env.py | 28 ++ 19 files changed, 1320 insertions(+), 23 deletions(-) create mode 100644 ding/entry/serial_entry_pc.py create mode 100644 ding/policy/pc.py create mode 100644 ding/utils/bfs_helper.py create mode 100644 ding/utils/tests/test_bfs_helper.py create mode 100644 dizoo/maze/__init__.py create mode 100644 dizoo/maze/config/maze_bc_config.py create mode 100644 dizoo/maze/config/maze_pc_config.py create mode 100644 dizoo/maze/entry/maze_bc_main.py create mode 100644 dizoo/maze/envs/__init__.py create mode 100644 dizoo/maze/envs/maze_env.py create mode 100644 dizoo/maze/envs/test_maze_env.py diff --git a/ding/entry/__init__.py b/ding/entry/__init__.py index 1e90351ee4..e0501b12db 100644 --- a/ding/entry/__init__.py +++ b/ding/entry/__init__.py @@ -27,3 +27,4 @@ from .application_entry_drex_collect_data import drex_collecting_data from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream from .serial_entry_bco import serial_pipeline_bco +from .serial_entry_pc import serial_pipeline_pc diff --git a/ding/entry/serial_entry_pc.py b/ding/entry/serial_entry_pc.py new file mode 100644 index 0000000000..386d6f0ec9 --- /dev/null +++ b/ding/entry/serial_entry_pc.py @@ -0,0 +1,108 @@ +from typing import Union, Optional, Tuple +import os +from functools import partial +from copy import deepcopy + +import torch +from tensorboardX import SummaryWriter +from torch.utils.data import DataLoader + +from ding.envs import get_vec_env_setting, create_env_manager +from ding.worker import BaseLearner, InteractionSerialEvaluator +from ding.config import read_config, compile_config +from ding.policy import create_policy +from ding.utils import set_pkg_seed +from ding.utils.data.dataset import load_bfs_datasets + + +def serial_pipeline_pc( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + max_iter=int(1e6), +) -> Union['Policy', bool]: # noqa + r""" + Overview: + Serial pipeline entry of procedure cloning using BFS as expert policy. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. 
+ - max_iter (:obj:`Optional[int]`): Max iteration for executing PC training. + Returns: + - policy (:obj:`Policy`): Converged policy. + - convergence (:obj:`bool`): whether the training is converged + """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = deepcopy(input_cfg) + cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) + + # Env, Policy + env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + # Random seed + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval']) + + # Main components + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + train_data, test_data = load_bfs_datasets(train_seeds=cfg.train_seeds) + dataloader = DataLoader(train_data, batch_size=cfg.policy.learn.batch_size, shuffle=True) + test_dataloader = DataLoader(test_data, batch_size=cfg.policy.learn.batch_size, shuffle=True) + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + evaluator = InteractionSerialEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + + # ========== + # Main loop + # ========== + learner.call_hook('before_run') + stop = False + iter_cnt = 0 + for epoch in range(cfg.policy.learn.train_epoch): + # train + criterion = torch.nn.CrossEntropyLoss() + for i, train_data in enumerate(dataloader): + learner.train(train_data) + iter_cnt += 1 + if iter_cnt >= max_iter: + stop = True + break + if epoch % 69 == 0: + policy._optimizer.param_groups[0]['lr'] /= 10 + if stop: + break + losses = [] + acces = [] + # Evaluation + for _, test_data in enumerate(test_dataloader): + observations, bfs_input_maps, bfs_output_maps = test_data['obs'], test_data['bfs_in'].long(), \ + test_data['bfs_out'].long() + states = observations + bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, 5).float() + + bfs_states = torch.cat([ + states, + bfs_input_onehot, + ], dim=-1).cuda() + logits = policy._model(bfs_states)['logit'] + logits = logits.flatten(0, -2) + labels = bfs_output_maps.flatten(0, -1).cuda() + + loss = criterion(logits, labels).item() + preds = torch.argmax(logits, dim=-1) + acc = torch.sum((preds == labels)) / preds.shape[0] + + losses.append(loss) + acces.append(acc) + print('Test Finished! 
Loss: {} acc: {}'.format(sum(losses) / len(losses), sum(acces) / len(acces))) + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) + learner.call_hook('after_run') + print('final reward is: {}'.format(reward)) + return policy, stop diff --git a/ding/model/template/__init__.py b/ding/model/template/__init__.py index 11f7aa35b5..e994286ac3 100644 --- a/ding/model/template/__init__.py +++ b/ding/model/template/__init__.py @@ -22,4 +22,4 @@ from .madqn import MADQN from .vae import VanillaVAE from .decision_transformer import DecisionTransformer -from .procedure_cloning import ProcedureCloning +from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS diff --git a/ding/model/template/bc.py b/ding/model/template/bc.py index 0da1f3e171..b40ef4f118 100644 --- a/ding/model/template/bc.py +++ b/ding/model/template/bc.py @@ -13,15 +13,16 @@ class DiscreteBC(nn.Module): def __init__( - self, - obs_shape: Union[int, SequenceType], - action_shape: Union[int, SequenceType], - encoder_hidden_size_list: SequenceType = [128, 128, 64], - dueling: bool = True, - head_hidden_size: Optional[int] = None, - head_layer_num: int = 1, - activation: Optional[nn.Module] = nn.ReLU(), - norm_type: Optional[str] = None + self, + obs_shape: Union[int, SequenceType], + action_shape: Union[int, SequenceType], + encoder_hidden_size_list: SequenceType = [128, 128, 64], + dueling: bool = True, + head_hidden_size: Optional[int] = None, + head_layer_num: int = 1, + activation: Optional[nn.Module] = nn.ReLU(), + norm_type: Optional[str] = None, + strides: Optional[list] = None, ) -> None: """ Overview: @@ -49,7 +50,14 @@ def __init__( self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type) # Conv Encoder elif len(obs_shape) == 3: - self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type) + if not strides: + self.encoder = ConvEncoder( + obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type + ) + else: + self.encoder = ConvEncoder( + obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type, stride=strides + ) else: raise RuntimeError( "not support obs_shape for pre-defined encoder: {}, please customize your own BC".format(obs_shape) diff --git a/ding/model/template/procedure_cloning.py b/ding/model/template/procedure_cloning.py index a86e813933..d0d6ffcbd0 100644 --- a/ding/model/template/procedure_cloning.py +++ b/ding/model/template/procedure_cloning.py @@ -1,6 +1,8 @@ -from typing import Optional, Tuple +from typing import Optional, Tuple, Union, Dict + import torch import torch.nn as nn + from ding.utils import MODEL_REGISTRY, SequenceType from ding.torch_utils.network.transformer import Attention from ding.torch_utils.network.nn_module import fc_block, build_normalization @@ -42,8 +44,8 @@ def forward(self, x: torch.Tensor): return x -@MODEL_REGISTRY.register('pc') -class ProcedureCloning(nn.Module): +@MODEL_REGISTRY.register('pc_mcts') +class ProcedureCloningMCTS(nn.Module): def __init__( self, @@ -53,7 +55,7 @@ def __init__( cnn_activation: Optional[nn.Module] = nn.ReLU(), cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3], cnn_stride: SequenceType = [1, 1, 1, 1, 1], - cnn_padding: Optional[SequenceType] = ['same', 'same', 'same', 'same', 'same'], + cnn_padding: Optional[SequenceType] = [1, 1, 1, 1, 1], mlp_hidden_list: SequenceType = [256, 256], mlp_activation: Optional[nn.Module] = nn.ReLU(), att_heads: int = 8, @@ -117,3 +119,94 @@ def 
forward(self, states: torch.Tensor, goals: torch.Tensor, action_preds = self.predict_action(h[:, 1:, :]) return goal_preds, action_preds + + +class BFSConvEncoder(nn.Module): + """ + Overview: The ``BFSConvolution Encoder`` used to encode raw 2-dim observations. And output a feature map with the + same height and width as input. Interfaces: ``__init__``, ``forward``. + """ + + def __init__( + self, + obs_shape: SequenceType, + hidden_size_list: SequenceType = [32, 64, 64, 128], + activation: Optional[nn.Module] = nn.ReLU(), + kernel_size: SequenceType = [8, 4, 3], + stride: SequenceType = [4, 2, 1], + padding: Optional[SequenceType] = None, + ) -> None: + """ + Overview: + Init the ``BFSConvolution Encoder`` according to the provided arguments. + Arguments: + - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, plus one or more ``input size``. + - hidden_size_list (:obj:`SequenceType`): Sequence of ``hidden_size`` of subsequent conv layers \ + and the final dense layer. + - activation (:obj:`nn.Module`): Type of activation to use in the conv ``layers`` and ``ResBlock``. \ + Default is ``nn.ReLU()``. + - kernel_size (:obj:`SequenceType`): Sequence of ``kernel_size`` of subsequent conv layers. + - stride (:obj:`SequenceType`): Sequence of ``stride`` of subsequent conv layers. + - padding (:obj:`SequenceType`): Padding added to all four sides of the input for each conv layer. \ + See ``nn.Conv2d`` for more details. Default is ``None``. + """ + super(BFSConvEncoder, self).__init__() + self.obs_shape = obs_shape + self.act = activation + self.hidden_size_list = hidden_size_list + if padding is None: + padding = [0 for _ in range(len(kernel_size))] + + layers = [] + input_size = obs_shape[0] # in_channel + for i in range(len(kernel_size)): + layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i])) + layers.append(self.act) + input_size = hidden_size_list[i] + layers = layers[:-1] + self.main = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Overview: + Return output tensor of the env observation. + Arguments: + - x (:obj:`torch.Tensor`): Env raw observation. + Returns: + - outputs (:obj:`torch.Tensor`): Output embedding tensor. + Shapes: + - outputs: :math:`(B, N, H, W)`, where ``N = hidden_size_list[-1]``. 
+ """ + return self.main(x) + + +@MODEL_REGISTRY.register('pc_bfs') +class ProcedureCloningBFS(nn.Module): + + def __init__( + self, + obs_shape: Union[int, SequenceType], + action_shape: Union[int, SequenceType], + encoder_hidden_size_list: SequenceType = [128, 128, 256, 256], + ): + super().__init__() + num_layers = len(encoder_hidden_size_list) + + kernel_sizes = (3, ) * (num_layers + 1) + stride_sizes = (1, ) * (num_layers + 1) + padding_sizes = (1, ) * (num_layers + 1) + # The output channel equals to action_shape + 1 + encoder_hidden_size_list.append(action_shape + 1) + + self._encoder = BFSConvEncoder( + obs_shape=obs_shape, + hidden_size_list=encoder_hidden_size_list, + kernel_size=kernel_sizes, + stride=stride_sizes, + padding=padding_sizes, + ) + + def forward(self, x: torch.Tensor) -> Dict: + x = x.permute(0, 3, 1, 2) + x = self._encoder(x) + return {'logit': x.permute(0, 2, 3, 1)} diff --git a/ding/model/template/tests/test_procedure_cloning.py b/ding/model/template/tests/test_procedure_cloning.py index e169ec2cee..5a52542879 100644 --- a/ding/model/template/tests/test_procedure_cloning.py +++ b/ding/model/template/tests/test_procedure_cloning.py @@ -1,11 +1,9 @@ -import torch import pytest -import numpy as np from itertools import product -from ding.model.template import ProcedureCloning -from ding.torch_utils import is_differentiable -from ding.utils import squeeze +import torch + +from ding.model.template import ProcedureCloningMCTS, ProcedureCloningBFS B = 4 T = 15 @@ -19,16 +17,24 @@ @pytest.mark.parametrize('obs_shape, action_dim', args) class TestProcedureCloning: - def test_procedure_cloning(self, obs_shape, action_dim): + def test_procedure_cloning_mcts(self, obs_shape, action_dim): inputs = { 'states': torch.randn(B, *obs_shape), 'goals': torch.randn(B, *obs_shape), 'actions': torch.randn(B, T, action_dim) } - model = ProcedureCloning(obs_shape=obs_shape, action_dim=action_dim) + model = ProcedureCloningMCTS(obs_shape=obs_shape, action_dim=action_dim) print(model) goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions']) assert goal_preds.shape == (B, obs_embeddings) assert action_preds.shape == (B, T + 1, action_dim) + + def test_procedure_cloning_bfs(self, obs_shape, action_dim): + o_shape = (obs_shape[2], obs_shape[0], obs_shape[1]) + model = ProcedureCloningBFS(obs_shape=o_shape, action_shape=action_dim) + + inputs = torch.randn(B, *obs_shape) + map_preds = model(inputs) + assert map_preds['logit'].shape == (B, obs_shape[0], obs_shape[1], action_dim + 1) diff --git a/ding/policy/__init__.py b/ding/policy/__init__.py index a636a91191..15575c7d30 100644 --- a/ding/policy/__init__.py +++ b/ding/policy/__init__.py @@ -46,5 +46,7 @@ from .bc import BehaviourCloningPolicy from .ibc import IBCPolicy +from .pc import ProcedureCloningBFSPolicy + # new-type policy from .ppof import PPOFPolicy diff --git a/ding/policy/pc.py b/ding/policy/pc.py new file mode 100644 index 0000000000..3788c70012 --- /dev/null +++ b/ding/policy/pc.py @@ -0,0 +1,187 @@ +import math +from typing import List, Dict, Any, Tuple +from collections import namedtuple + +import torch +import torch.nn as nn +from torch.optim import Adam, SGD, AdamW +from torch.optim.lr_scheduler import LambdaLR + +from ding.policy import Policy +from ding.model import model_wrap +from ding.torch_utils import to_device +from ding.utils import EasyTimer +from ding.utils import POLICY_REGISTRY + + +@POLICY_REGISTRY.register('pc_bfs') +class ProcedureCloningBFSPolicy(Policy): + + def 
default_model(self) -> Tuple[str, List[str]]: + return 'pc_bfs', ['ding.model.template.procedure_cloning'] + + config = dict( + type='pc', + cuda=False, + on_policy=False, + continuous=False, + max_bfs_steps=100, + learn=dict( + multi_gpu=False, + update_per_collect=1, + batch_size=32, + learning_rate=1e-5, + lr_decay=False, + decay_epoch=30, + decay_rate=0.1, + warmup_lr=1e-4, + warmup_epoch=3, + optimizer='SGD', + momentum=0.9, + weight_decay=1e-4, + ), + collect=dict( + unroll_len=1, + noise=False, + noise_sigma=0.2, + noise_range=dict( + min=-0.5, + max=0.5, + ), + ), + eval=dict(), + other=dict(replay_buffer=dict(replay_buffer_size=10000)), + ) + + def _init_learn(self): + assert self._cfg.learn.optimizer in ['SGD', 'Adam'] + if self._cfg.learn.optimizer == 'SGD': + self._optimizer = SGD( + self._model.parameters(), + lr=self._cfg.learn.learning_rate, + weight_decay=self._cfg.learn.weight_decay, + momentum=self._cfg.learn.momentum + ) + elif self._cfg.learn.optimizer == 'Adam': + if self._cfg.learn.weight_decay is None: + self._optimizer = Adam( + self._model.parameters(), + lr=self._cfg.learn.learning_rate, + ) + else: + self._optimizer = AdamW( + self._model.parameters(), + lr=self._cfg.learn.learning_rate, + weight_decay=self._cfg.learn.weight_decay + ) + if self._cfg.learn.lr_decay: + + def lr_scheduler_fn(epoch): + if epoch <= self._cfg.learn.warmup_epoch: + return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate + else: + ratio = (epoch - self._cfg.learn.warmup_epoch) // self._cfg.learn.decay_epoch + return math.pow(self._cfg.learn.decay_rate, ratio) + + self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn) + self._timer = EasyTimer(cuda=True) + self._learn_model = model_wrap(self._model, 'base') + self._learn_model.reset() + self._max_bfs_steps = self._cfg.max_bfs_steps + self._maze_size = self._cfg.maze_size + self._num_actions = self._cfg.num_actions + + self._loss = nn.CrossEntropyLoss() + + def process_states(self, observations, maze_maps): + """Returns [B, W, W, 3] binary values. 
Channels are (wall; goal; obs)""" + loc = torch.nn.functional.one_hot( + (observations[:, 0] * self._maze_size + observations[:, 1]).long(), + self._maze_size * self._maze_size, + ).long() + loc = torch.reshape(loc, [observations.shape[0], self._maze_size, self._maze_size]) + states = torch.cat([maze_maps, loc], dim=-1).long() + return states + + def _forward_learn(self, data): + if self._cuda: + collated_data = to_device(data, self._device) + else: + collated_data = data + observations = collated_data['obs'], + bfs_input_maps, bfs_output_maps = collated_data['bfs_in'].long(), collated_data['bfs_out'].long() + states = observations + bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, self._num_actions + 1).float() + + bfs_states = torch.cat([ + states, + bfs_input_onehot, + ], dim=-1) + logits = self._model(bfs_states)['logit'] + logits = logits.flatten(0, -2) + labels = bfs_output_maps.flatten(0, -1) + + loss = self._loss(logits, labels) + preds = torch.argmax(logits, dim=-1) + acc = torch.sum((preds == labels)) / preds.shape[0] + + self._optimizer.zero_grad() + loss.backward() + self._optimizer.step() + pred_loss = loss.item() + + cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups] + cur_lr = sum(cur_lr) / len(cur_lr) + return {'cur_lr': cur_lr, 'total_loss': pred_loss, 'acc': acc} + + def _monitor_vars_learn(self): + return ['cur_lr', 'total_loss', 'acc'] + + def _init_eval(self): + self._eval_model = model_wrap(self._model, wrapper_name='base') + self._eval_model.reset() + + def _forward_eval(self, data): + if self._cuda: + data = to_device(data, self._device) + max_len = self._max_bfs_steps + data_id = list(data.keys()) + output = {} + + for ii in data_id: + states = data[ii].unsqueeze(0) + bfs_input_maps = self._num_actions * torch.ones([1, self._maze_size, self._maze_size]).long() + if self._cuda: + bfs_input_maps = to_device(bfs_input_maps, self._device) + xy = torch.where(states[:, :, :, -1] == 1) + observation = (xy[1][0].item(), xy[2][0].item()) + + i = 0 + while bfs_input_maps[0, observation[0], observation[1]].item() == self._num_actions and i < max_len: + bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, self._num_actions + 1).long() + + bfs_states = torch.cat([ + states, + bfs_input_onehot, + ], dim=-1) + logits = self._model(bfs_states)['logit'] + bfs_input_maps = torch.argmax(logits, dim=-1) + i += 1 + output[ii] = bfs_input_maps[0, observation[0], observation[1]] + if self._cuda: + output[ii] = {'action': to_device(output[ii], 'cpu'), 'info': {}} + if output[ii]['action'].item() == self._num_actions: + output[ii]['action'] = torch.randint(low=0, high=self._num_actions, size=[1])[0] + return output + + def _init_collect(self) -> None: + raise NotImplementedError + + def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]: + raise NotImplementedError + + def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict: + raise NotImplementedError + + def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + raise NotImplementedError diff --git a/ding/utils/__init__.py b/ding/utils/__init__.py index b900e0e26d..2980747f6a 100644 --- a/ding/utils/__init__.py +++ b/ding/utils/__init__.py @@ -29,6 +29,7 @@ from .type_helper import SequenceType from .render_helper import render, fps from .fast_copy import fastcopy +from .bfs_helper import get_vi_sequence if ding.enable_linklink: from .linklink_dist_helper import get_rank, get_world_size, dist_mode, dist_init, 
dist_finalize, \ diff --git a/ding/utils/bfs_helper.py b/ding/utils/bfs_helper.py new file mode 100644 index 0000000000..2bef851ccf --- /dev/null +++ b/ding/utils/bfs_helper.py @@ -0,0 +1,59 @@ +import numpy as np +import torch + + +# BFS algorithm +def get_vi_sequence(env, observation): + """Returns [L, W, W] optimal actions.""" + xy = np.where(observation[Ellipsis, -1] == 1) + start_x, start_y = xy[0][0], xy[1][0] + target_location = env.target_location + nav_map = env.nav_map + current_points = [target_location] + chosen_actions = {target_location: 0} + visited_points = {target_location: True} + vi_sequence = [] + + vi_map = np.full((env.size, env.size), fill_value=env.n_action, dtype=np.int32) + + found_start = False + while current_points and not found_start: + next_points = [] + for point_x, point_y in current_points: + for (action, (next_point_x, next_point_y)) in [(0, (point_x - 1, point_y)), (1, (point_x, point_y - 1)), + (2, (point_x + 1, point_y)), (3, (point_x, point_y + 1))]: + + if (next_point_x, next_point_y) in visited_points: + continue + + if not (0 <= next_point_x < len(nav_map) and 0 <= next_point_y < len(nav_map[next_point_x])): + continue + + if nav_map[next_point_x][next_point_y] == 'x': + continue + + next_points.append((next_point_x, next_point_y)) + visited_points[(next_point_x, next_point_y)] = True + chosen_actions[(next_point_x, next_point_y)] = action + vi_map[next_point_x, next_point_y] = action + + if next_point_x == start_x and next_point_y == start_y: + found_start = True + vi_sequence.append(vi_map.copy()) + current_points = next_points + track_back = [] + if found_start: + cur_x, cur_y = start_x, start_y + while cur_x != target_location[0] or cur_y != target_location[1]: + act = vi_sequence[-1][cur_x, cur_y] + track_back.append((torch.FloatTensor(env.process_states([cur_x, cur_y], env.get_maze_map())), act)) + if act == 0: + cur_x += 1 + elif act == 1: + cur_y += 1 + elif act == 2: + cur_x -= 1 + elif act == 3: + cur_y -= 1 + + return np.array(vi_sequence), track_back diff --git a/ding/utils/data/dataset.py b/ding/utils/data/dataset.py index a2dec4c2ea..ba577ba4a2 100644 --- a/ding/utils/data/dataset.py +++ b/ding/utils/data/dataset.py @@ -1,13 +1,16 @@ from typing import List, Dict, Tuple import pickle + +import easydict import torch import numpy as np from ditk import logging from copy import deepcopy +from torch.utils.data import Dataset from dataclasses import dataclass from easydict import EasyDict -from torch.utils.data import Dataset +from ding.utils.bfs_helper import get_vi_sequence from ding.utils import DATASET_REGISTRY, import_module from ding.rl_utils import discount_cumsum @@ -423,6 +426,78 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso return timesteps, states, actions, returns_to_go, traj_mask +class PCDataset(Dataset): + + def __init__(self, all_data): + self._data = all_data + + def __getitem__(self, item): + return {'obs': self._data[0][item], 'bfs_in': self._data[1][item], 'bfs_out': self._data[2][item]} + + def __len__(self): + return self._data[0].shape[0] + + +def load_bfs_datasets(train_seeds=1, test_seeds=5): + from dizoo.maze.envs import Maze + + def load_env(seed): + ccc = easydict.EasyDict({'size': 16}) + e = Maze(ccc) + e.seed(seed) + e.reset() + return e + + envs = [load_env(i) for i in range(train_seeds + test_seeds)] + + observations_train = [] + observations_test = [] + bfs_input_maps_train = [] + bfs_input_maps_test = [] + bfs_output_maps_train = [] + bfs_output_maps_test = [] + for 
idx, env in enumerate(envs): + if idx < train_seeds: + observations = observations_train + bfs_input_maps = bfs_input_maps_train + bfs_output_maps = bfs_output_maps_train + else: + observations = observations_test + bfs_input_maps = bfs_input_maps_test + bfs_output_maps = bfs_output_maps_test + + start_obs = env.process_states(env._get_obs(), env.get_maze_map()) + _, track_back = get_vi_sequence(env, start_obs) + env_observations = torch.stack([track_back[i][0] for i in range(len(track_back))], dim=0) + + for i in range(env_observations.shape[0]): + bfs_sequence, _ = get_vi_sequence(env, env_observations[i].numpy().astype(np.int32)) # [L, W, W] + bfs_input_map = env.n_action * np.ones([env.size, env.size], dtype=np.long) + + for j in range(bfs_sequence.shape[0]): + bfs_input_maps.append(torch.from_numpy(bfs_input_map)) + bfs_output_maps.append(torch.from_numpy(bfs_sequence[j])) + observations.append(env_observations[i]) + bfs_input_map = bfs_sequence[j] + + train_data = PCDataset( + ( + torch.stack(observations_train, dim=0), + torch.stack(bfs_input_maps_train, dim=0), + torch.stack(bfs_output_maps_train, dim=0), + ) + ) + test_data = PCDataset( + ( + torch.stack(observations_test, dim=0), + torch.stack(bfs_input_maps_test, dim=0), + torch.stack(bfs_output_maps_test, dim=0), + ) + ) + + return train_data, test_data + + @DATASET_REGISTRY.register('bco') class BCODataset(Dataset): diff --git a/ding/utils/tests/test_bfs_helper.py b/ding/utils/tests/test_bfs_helper.py new file mode 100644 index 0000000000..7f095907a0 --- /dev/null +++ b/ding/utils/tests/test_bfs_helper.py @@ -0,0 +1,26 @@ +import easydict +import numpy +import pytest + +from ding.utils import get_vi_sequence +from dizoo.maze.envs.maze_env import Maze + + +@pytest.mark.unittest +class TestBFSHelper: + + def test_bfs(self): + + def load_env(seed): + ccc = easydict.EasyDict({'size': 16}) + e = Maze(ccc) + e.seed(seed) + e.reset() + return e + + env = load_env(314) + start_obs = env.process_states(env._get_obs(), env.get_maze_map()) + vi_sequence, track_back = get_vi_sequence(env, start_obs) + assert vi_sequence.shape[1:] == (16, 16) + assert track_back[0][0].shape == (16, 16, 3) + assert isinstance(track_back[0][1], numpy.int32) diff --git a/dizoo/maze/__init__.py b/dizoo/maze/__init__.py new file mode 100644 index 0000000000..3bfa255bd5 --- /dev/null +++ b/dizoo/maze/__init__.py @@ -0,0 +1,3 @@ +from gym.envs.registration import register + +register(id='Maze', entry_point='dizoo.maze.envs:Maze') diff --git a/dizoo/maze/config/maze_bc_config.py b/dizoo/maze/config/maze_bc_config.py new file mode 100644 index 0000000000..18c9d6ade8 --- /dev/null +++ b/dizoo/maze/config/maze_bc_config.py @@ -0,0 +1,56 @@ +from easydict import EasyDict + +maze_size = 16 +num_actions = 4 +maze_pc_config = dict( + exp_name="maze_bc_seed0", + env=dict( + collector_env_num=1, + evaluator_env_num=5, + n_evaluator_episode=5, + env_id='Maze', + size=maze_size, + wall_type='tunnel', + stop_value=1 + ), + policy=dict( + cuda=True, + maze_size=maze_size, + num_actions=num_actions, + max_bfs_steps=100, + model=dict( + obs_shape=[3, maze_size, maze_size], + action_shape=num_actions, + encoder_hidden_size_list=[ + 128, + 256, + 512, + 1024, + ], + strides=[1, 1, 1, 1] + ), + learn=dict( + # update_per_collect=4, + batch_size=256, + learning_rate=0.005, + train_epoch=5000, + optimizer='SGD', + ), + eval=dict(evaluator=dict(n_episode=5)), + collect=dict(), + ), +) +maze_pc_config = EasyDict(maze_pc_config) +main_config = maze_pc_config +maze_pc_create_config = 
dict( + env=dict( + type='maze', + import_names=['dizoo.maze.envs.maze_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='bc'), +) +maze_pc_create_config = EasyDict(maze_pc_create_config) +create_config = maze_pc_create_config + +# You can run `dizoo/maze/entry/maze_bc_main.py` to run this config. diff --git a/dizoo/maze/config/maze_pc_config.py b/dizoo/maze/config/maze_pc_config.py new file mode 100644 index 0000000000..2a5f40b278 --- /dev/null +++ b/dizoo/maze/config/maze_pc_config.py @@ -0,0 +1,57 @@ +from easydict import EasyDict + +maze_size = 16 +num_actions = 4 +maze_pc_config = dict( + exp_name="maze_pc_seed0", + train_seeds=5, + env=dict( + collector_env_num=8, + evaluator_env_num=5, + n_evaluator_episode=5, + env_id='Maze', + size=maze_size, + wall_type='tunnel', + stop_value=1, + ), + policy=dict( + cuda=True, + maze_size=maze_size, + num_actions=num_actions, + max_bfs_steps=100, + model=dict( + obs_shape=[8, maze_size, maze_size], + action_shape=num_actions, + encoder_hidden_size_list=[ + 128, + 256, + 512, + 1024, + ], + ), + learn=dict( + batch_size=32, + learning_rate=0.0005, + train_epoch=100, + optimizer='Adam', + ), + eval=dict(evaluator=dict(n_episode=5)), + collect=dict(), + ), +) +maze_pc_config = EasyDict(maze_pc_config) +main_config = maze_pc_config +maze_pc_create_config = dict( + env=dict( + type='maze', + import_names=['dizoo.maze.envs.maze_env'], + ), + env_manager=dict(type='subprocess'), + policy=dict(type='pc_bfs'), +) +maze_pc_create_config = EasyDict(maze_pc_create_config) +create_config = maze_pc_create_config + +if __name__ == '__main__': + from ding.entry import serial_pipeline_pc + serial_pipeline_pc([maze_pc_config, maze_pc_create_config], seed=0) diff --git a/dizoo/maze/entry/maze_bc_main.py b/dizoo/maze/entry/maze_bc_main.py new file mode 100644 index 0000000000..efd9b6d2a8 --- /dev/null +++ b/dizoo/maze/entry/maze_bc_main.py @@ -0,0 +1,206 @@ +from typing import Union, Optional, Tuple +import os +from functools import partial +from copy import deepcopy + +import easydict +import torch +import numpy as np +from tensorboardX import SummaryWriter +from torch.utils.data import DataLoader, Dataset + +from ding.envs import get_vec_env_setting, create_env_manager +from ding.worker import BaseLearner, InteractionSerialEvaluator +from ding.config import read_config, compile_config +from ding.policy import create_policy +from ding.utils import set_pkg_seed +from dizoo.maze.envs.maze_env import Maze + + +# BFS algorithm +def get_vi_sequence(env, observation): + """Returns [L, W, W] optimal actions.""" + xy = np.where(observation[Ellipsis, -1] == 1) + start_x, start_y = xy[0][0], xy[1][0] + target_location = env.target_location + nav_map = env.nav_map + current_points = [target_location] + chosen_actions = {target_location: 0} + visited_points = {target_location: True} + vi_sequence = [] + + vi_map = np.full((env.size, env.size), fill_value=env.n_action, dtype=np.int32) + + found_start = False + while current_points and not found_start: + next_points = [] + for point_x, point_y in current_points: + for (action, (next_point_x, next_point_y)) in [(0, (point_x - 1, point_y)), (1, (point_x, point_y - 1)), + (2, (point_x + 1, point_y)), (3, (point_x, point_y + 1))]: + + if (next_point_x, next_point_y) in visited_points: + continue + + if not (0 <= next_point_x < len(nav_map) and 0 <= next_point_y < len(nav_map[next_point_x])): + continue + + if nav_map[next_point_x][next_point_y] == 'x': + continue + + next_points.append((next_point_x, 
next_point_y)) + visited_points[(next_point_x, next_point_y)] = True + chosen_actions[(next_point_x, next_point_y)] = action + vi_map[next_point_x, next_point_y] = action + + if next_point_x == start_x and next_point_y == start_y: + found_start = True + vi_sequence.append(vi_map.copy()) + current_points = next_points + track_back = [] + if found_start: + cur_x, cur_y = start_x, start_y + while cur_x != target_location[0] or cur_y != target_location[1]: + act = vi_sequence[-1][cur_x, cur_y] + track_back.append(( + torch.FloatTensor(env.process_states([cur_x, cur_y], env.get_maze_map())), + act)) + if act == 0: + cur_x += 1 + elif act == 1: + cur_y += 1 + elif act == 2: + cur_x -= 1 + elif act == 3: + cur_y -= 1 + + return np.array(vi_sequence), track_back + + +class BCDataset(Dataset): + + def __init__(self, all_data): + self._data = all_data + + def __getitem__(self, item): + return {'obs': self._data[item][0], 'action': self._data[item][1]} + + def __len__(self): + return len(self._data) + + +def load_bc_dataset(train_seeds=1, test_seeds=1, batch_size=32): + def load_env(seed): + ccc = easydict.EasyDict({'size': 16}) + e = Maze(ccc) + e.seed(seed) + e.reset() + return e + + envs = [load_env(i) for i in range(train_seeds + test_seeds)] + data_train = [] + data_test = [] + + for idx, env in enumerate(envs): + if idx < train_seeds: + data = data_train + else: + data = data_test + + start_obs = env.process_states(env._get_obs(), env.get_maze_map()) + _, track_back = get_vi_sequence(env, start_obs) + + data += track_back + + + train_data = BCDataset( + data_train + ) + test_data = BCDataset( + data_test + ) + + train_dataset = DataLoader(train_data, batch_size=batch_size, shuffle=True) + test_dataset = DataLoader(test_data, batch_size=batch_size, shuffle=True) + return train_dataset, test_dataset + + +def serial_pipeline_bc( + input_cfg: Union[str, Tuple[dict, dict]], + seed: int = 0, + model: Optional[torch.nn.Module] = None, + max_iter=int(1e6), +) -> Union['Policy', bool]: # noqa + r""" + Overview: + Serial pipeline entry of imitation learning. + Arguments: + - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \ + ``str`` type means config file path. \ + ``Tuple[dict, dict]`` type means [user_config, create_cfg]. + - seed (:obj:`int`): Random seed. + - data_path (:obj:`str`): Path of training data. + - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module. + Returns: + - policy (:obj:`Policy`): Converged policy. 
+ - convergence (:obj:`bool`): whether il training is converged + """ + if isinstance(input_cfg, str): + cfg, create_cfg = read_config(input_cfg) + else: + cfg, create_cfg = deepcopy(input_cfg) + cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg) + + # Env, Policy + env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env) + evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg]) + # Random seed + evaluator_env.seed(cfg.seed, dynamic_seed=False) + set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda) + policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval']) + + # Main components + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + dataloader, test_dataloader = load_bc_dataset() + learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name) + evaluator = InteractionSerialEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + + # ========== + # Main loop + # ========== + learner.call_hook('before_run') + stop = False + iter_cnt = 0 + for epoch in range(cfg.policy.learn.train_epoch): + # Evaluate policy performance + loss_list = [] + for _, bat in enumerate(test_dataloader): + bat['action'] = bat['action'].long() + res = policy._forward_eval(bat['obs']) + res = torch.argmax(res['logit'], dim=1) + loss_list.append(torch.sum(res == bat['action'].squeeze(-1)).item() / bat['action'].shape[0]) + label = 'validation_acc' + tb_logger.add_scalar(label, sum(loss_list) / len(loss_list), iter_cnt) + for i, train_data in enumerate(dataloader): + if evaluator.should_eval(learner.train_iter): + stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter) + if stop: + break + train_data['action'] = train_data['action'].long() + learner.train(train_data) + iter_cnt += 1 + if iter_cnt >= max_iter: + stop = True + break + if stop: + break + + learner.call_hook('after_run') + print('final reward is: {}'.format(reward)) + return policy, stop + + +if __name__ == '__main__': + from dizoo.maze.config.maze_bc_config import main_config, create_config + serial_pipeline_bc([main_config, create_config], seed=0) diff --git a/dizoo/maze/envs/__init__.py b/dizoo/maze/envs/__init__.py new file mode 100644 index 0000000000..ab42c5b39d --- /dev/null +++ b/dizoo/maze/envs/__init__.py @@ -0,0 +1 @@ +from .maze_env import Maze diff --git a/dizoo/maze/envs/maze_env.py b/dizoo/maze/envs/maze_env.py new file mode 100644 index 0000000000..67cb1469d4 --- /dev/null +++ b/dizoo/maze/envs/maze_env.py @@ -0,0 +1,380 @@ +from typing import List + +import copy +import numpy as np +import gym +from gym import spaces +from gym.utils import seeding + +from ding.envs import BaseEnvTimestep +from ding.utils import ENV_REGISTRY + + +@ENV_REGISTRY.register('maze') +class Maze(gym.Env): + """ + Environment with random maze layouts. The ASCII representation of the mazes include the following objects: + - ``: empty + - `x`: wall + - `S`: the start location (optional) + - `T`: the target location. 
+ """ + KEY_EMPTY = 0 + KEY_WALL = 1 + KEY_TARGET = 2 + KEY_START = 3 + ASCII_MAP = { + KEY_EMPTY: ' ', + KEY_WALL: 'x', + KEY_TARGET: 'T', + KEY_START: 'S', + } + + def __init__( + self, + cfg, + ): + self._size = cfg.size + self._init_flag = False + self._random_start = True + self._seed = None + self._step = 0 + + def reset(self): + self.active_init() + obs = self._get_obs() + self._step = 0 + return self.process_states(obs, self.get_maze_map()) + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self._seed = seed + self._dynamic_seed = dynamic_seed + np.random.seed(self._seed) + + def active_init( + self, + tabular_obs=False, + reward_fn=lambda x, y, tx, ty: 1 if (x == tx and y == ty) else 0, + done_fn=lambda x, y, tx, ty: x == tx and y == ty + ): + self._maze = self.generate_maze(self.size, self._seed, 'tunnel') + self._num_maze_keys = len(Maze.ASCII_MAP.keys()) + nav_map = self.maze_to_ascii(self._maze) + self._map = nav_map + self._tabular_obs = tabular_obs + self._reward_fn = reward_fn + self._done_fn = done_fn + if self._reward_fn is None: + self._reward_fn = lambda x, y, tx, ty: float(x == tx and y == ty) + if self._done_fn is None: + self._done_fn = lambda x, y, tx, ty: False + + self._max_x = len(self._map) + if not self._max_x: + raise ValueError('Invalid map.') + self._max_y = len(self._map[0]) + if not all(len(m) == self._max_y for m in self._map): + raise ValueError('Invalid map.') + self._start_x, self._start_y = self._find_initial_point() + self._target_x, self._target_y = self._find_target_point() + self._x, self._y = self._start_x, self._start_y + + self._n_state = self._max_x * self._max_y + self._n_action = 4 + + if self._tabular_obs: + self.observation_space = spaces.Discrete(self._n_state) + else: + self.observation_space = spaces.Box(low=0.0, high=np.inf, shape=(16, 16, 3)) + + self.action_space = spaces.Discrete(self._n_action) + self.reward_space = spaces.Box(low=0, high=1, shape=(1, ), dtype=np.float32) + + def random_start(self): + init_x, init_y = self._x, self._y + while True: # Find empty grid cell. + self._x = self.np_random.integers(self._max_x) + self._y = self.np_random.integers(self._max_y) + if self._map[self._x][self._y] != 'x': + break + ret = copy.deepcopy(self.process_states(self._get_obs(), self.get_maze_map())) + self._x, self._y = init_x, init_y + return ret + + def close(self) -> None: + if self._init_flag: + self._env.close() + self._init_flag = False + + @property + def num_maze_keys(self): + return self._num_maze_keys + + @property + def size(self): + return self._size + + def process_states(self, observations, maze_maps): + """Returns [B, W, W, 3] binary values. 
Channels are (wall; goal; obs)""" + loc = np.eye(self._size * self._size, dtype=np.long)[observations[0] * self._size + observations[1]] + loc = np.reshape(loc, [self._size, self._size]) + maze_maps = maze_maps.astype(np.long) + + states = np.concatenate([maze_maps, loc[Ellipsis, None]], axis=-1, dtype=np.long) + return states + + def get_maze_map(self, stacked=True): + if not stacked: + return self._maze.copy() + wall = self._maze.copy() + target_x, target_y = self.target_location + assert wall[target_x][target_y] == Maze.KEY_TARGET + wall[target_x][target_y] = 0 + target = np.zeros((self._size, self._size)) + target[target_x][target_y] = 1 + assert wall[self._start_x][self._start_y] == Maze.KEY_START + wall[self._start_x][self._start_y] = 0 + return np.stack([wall, target], axis=-1) + + def generate_maze(self, size, seed, wall_type): + rng, _ = seeding.np_random(seed) + maze = np.full((size, size), fill_value=Maze.KEY_EMPTY, dtype=int) + + if wall_type == 'none': + maze[[0, -1], :] = Maze.KEY_WALL + maze[:, [0, -1]] = Maze.KEY_WALL + elif wall_type == 'tunnel': + self.sample_wall(maze, rng) + elif wall_type.startswith('blocks:'): + maze[[0, -1], :] = Maze.KEY_WALL + maze[:, [0, -1]] = Maze.KEY_WALL + self.sample_blocks(maze, rng, int(wall_type.split(':')[-1])) + else: + raise ValueError('Unknown wall type: %s' % wall_type) + + loc_target = self.sample_location(maze, rng) + maze[loc_target] = Maze.KEY_TARGET + + loc_start = self.sample_location(maze, rng) + maze[loc_start] = Maze.KEY_START + self._start_x, self._start_y = loc_start + + return maze + + def sample_blocks(self, maze, rng, num_blocks): + """Sample single-block 'wall' or 'obstacles'.""" + for _ in range(num_blocks): + loc = self.sample_location(maze, rng) + maze[loc] = Maze.KEY_WALL + + def sample_wall( + self, maze, rng, shortcut_prob=0.1, inner_wall_thickness=1, outer_wall_thickness=1, corridor_thickness=2 + ): + room = maze + + # step 1: fill everything as wall + room[:] = Maze.KEY_WALL + + # step 2: prepare + # we move two pixels at a time, because the walls are also occupying pixels + delta = inner_wall_thickness + corridor_thickness + dx = [delta, -delta, 0, 0] + dy = [0, 0, delta, -delta] + + def get_loc_type(y, x): + # remember there is a outside wall of 1 pixel surrounding the room + if (y < outer_wall_thickness or y + corridor_thickness - 1 >= room.shape[0] - outer_wall_thickness): + return 'invalid' + if (x < outer_wall_thickness or x + corridor_thickness - 1 >= room.shape[1] - outer_wall_thickness): + return 'invalid' + # already visited + if room[y, x] == Maze.KEY_EMPTY: + return 'occupied' + return 'valid' + + def connect_pixel(y, x, ny, nx): + pixel = Maze.KEY_EMPTY + if ny == y: + room[y:y + corridor_thickness, min(x, nx):max(x, nx) + corridor_thickness] = pixel + else: + room[min(y, ny):max(y, ny) + corridor_thickness, x:x + corridor_thickness] = pixel + + def carve_passage_from(y, x): + room[y, x] = Maze.KEY_EMPTY + for direction in rng.permutation(len(dx)): + ny = y + dy[direction] + nx = x + dx[direction] + + loc_type = get_loc_type(ny, nx) + if loc_type == 'invalid': + continue + elif loc_type == 'valid': + connect_pixel(y, x, ny, nx) + # recursion + carve_passage_from(ny, nx) + else: + # occupied + # we create shortcut with some probability, this is because + # we do not want to restrict to only one feasible path. 
+ if rng.random() < shortcut_prob: + connect_pixel(y, x, ny, nx) + + carve_passage_from(outer_wall_thickness, outer_wall_thickness) + + def sample_location(self, maze, rng): + for _ in range(1000): + x, y = rng.integers(low=1, high=self._size, size=2) + if maze[x, y] == Maze.KEY_EMPTY: + return x, y + raise ValueError('Cannot sample empty location, make maze bigger?') + + @staticmethod + def key_to_ascii(key): + if key in Maze.ASCII_MAP: + return Maze.ASCII_MAP[key] + assert (Maze.KEY_OBJ <= key < Maze.KEY_OBJ + Maze.MAX_OBJ_TYPES) + return chr(ord('1') + key - Maze.KEY_OBJ) + + def maze_to_ascii(self, maze): + return [[Maze.key_to_ascii(x) for x in row] for row in maze] + + def tabular_obs_action(self, status_obs, action, include_maze_layout=False): + tabular_obs = self.get_tabular_obs(status_obs) + multiplier = self._n_action + if include_maze_layout: + multiplier += self._num_maze_keys + return multiplier * tabular_obs + action + + @staticmethod + def create_collector_env_cfg(cfg: dict) -> List[dict]: + collector_env_num = cfg.pop('collector_env_num') + cfg = copy.deepcopy(cfg) + cfg.is_train = True + return [cfg for _ in range(collector_env_num)] + + @staticmethod + def create_evaluator_env_cfg(cfg: dict) -> List[dict]: + evaluator_env_num = cfg.pop('evaluator_env_num') + cfg = copy.deepcopy(cfg) + cfg.is_train = False + return [cfg for _ in range(evaluator_env_num)] + + @property + def nav_map(self): + return self._map + + @property + def n_state(self): + return self._n_state + + @property + def n_action(self): + return self._n_action + + @property + def target_location(self): + return self._target_x, self._target_y + + @property + def tabular_obs(self): + return self._tabular_obs + + def _find_initial_point(self): + for x in range(self._max_x): + for y in range(self._max_y): + if self._map[x][y] == 'S': + break + if self._map[x][y] == 'S': + break + else: + return None, None + + return x, y + + def _find_target_point(self): + for x in range(self._max_x): + for y in range(self._max_y): + if self._map[x][y] == 'T': + break + if self._map[x][y] == 'T': + break + else: + raise ValueError('Target point not found in map.') + + return x, y + + def _get_obs(self): + if self._tabular_obs: + return self._x * self._max_y + self._y + else: + return np.array([self._x, self._y]) + + def get_tabular_obs(self, status_obs): + return self._max_y * status_obs[..., 0] + status_obs[..., 1] + + def get_xy(self, state): + x = state / self._max_y + y = state % self._max_y + return x, y + + def step(self, action): + last_x, last_y = self._x, self._y + if action == 0: + if self._x < self._max_x - 1: + self._x += 1 + elif action == 1: + if self._y < self._max_y - 1: + self._y += 1 + elif action == 2: + if self._x > 0: + self._x -= 1 + elif action == 3: + if self._y > 0: + self._y -= 1 + + if self._map[self._x][self._y] == 'x': + self._x, self._y = last_x, last_y + self._step += 1 + reward = self._reward_fn(self._x, self._y, self._target_x, self._target_y) + done = self._done_fn(self._x, self._y, self._target_x, self._target_y) + info = {} + if self._step > 100: + done = True + if done: + info['final_eval_reward'] = reward + info['eval_episode_return'] = reward + return BaseEnvTimestep(self.process_states(self._get_obs(), self.get_maze_map()), reward, done, info) + + +def get_value_map(env): + """Returns [W, W, A] one-hot VI actions.""" + target_location = env.target_location + nav_map = env.nav_map + current_points = [target_location] + chosen_actions = {target_location: 0} + visited_points = {target_location: 
True} + + while current_points: + next_points = [] + for point_x, point_y in current_points: + for (action, (next_point_x, next_point_y)) in [(0, (point_x - 1, point_y)), (1, (point_x, point_y - 1)), + (2, (point_x + 1, point_y)), (3, (point_x, point_y + 1))]: + + if (next_point_x, next_point_y) in visited_points: + continue + + if not (0 <= next_point_x < len(nav_map) and 0 <= next_point_y < len(nav_map[next_point_x])): + continue + + if nav_map[next_point_x][next_point_y] == 'x': + continue + + next_points.append((next_point_x, next_point_y)) + visited_points[(next_point_x, next_point_y)] = True + chosen_actions[(next_point_x, next_point_y)] = action + current_points = next_points + + value_map = np.zeros([env.size, env.size, env.n_action]) + for (x, y), action in chosen_actions.items(): + value_map[x][y][action] = 1 + return value_map diff --git a/dizoo/maze/envs/test_maze_env.py b/dizoo/maze/envs/test_maze_env.py new file mode 100644 index 0000000000..b8350d46d3 --- /dev/null +++ b/dizoo/maze/envs/test_maze_env.py @@ -0,0 +1,28 @@ +import pytest +import os +import numpy as np +from dizoo.maze.envs.maze_env import Maze +from easydict import EasyDict +import copy + + +@pytest.mark.envtest +class TestMazeEnv: + + def test_maze(self): + env = Maze(EasyDict({'size': 16})) + env.seed(314) + assert env._seed == 314 + obs = env.reset() + assert obs.shape == (16, 16, 3) + min_val, max_val = 0, 3 + for i in range(100): + random_action = np.random.randint(min_val, max_val, size=(1, )) + timestep = env.step(random_action) + print(timestep) + print(timestep.obs.max()) + assert isinstance(timestep.obs, np.ndarray) + assert isinstance(timestep.done, bool) + if timestep.done: + env.reset() + env.close()
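
A minimal usage sketch of the BFS expert helper introduced by this patch. It is not part of the diff; the imports, shapes and call pattern are taken from ding/utils/tests/test_bfs_helper.py above (a 16x16 maze is assumed), not verified against any released package.

    # Sketch: exercise the BFS expert on a single maze instance.
    from easydict import EasyDict

    from ding.utils import get_vi_sequence
    from dizoo.maze.envs.maze_env import Maze

    env = Maze(EasyDict({'size': 16}))
    env.seed(314)
    env.reset()

    # Current agent location stacked with the maze map: (16, 16, 3) binary channels.
    start_obs = env.process_states(env._get_obs(), env.get_maze_map())

    # vi_sequence: [L, W, W] intermediate BFS action maps (the "procedure" to clone);
    # track_back: (state, expert_action) pairs along the optimal path to the target.
    vi_sequence, track_back = get_vi_sequence(env, start_obs)
    assert vi_sequence.shape[1:] == (16, 16)
    assert track_back[0][0].shape == (16, 16, 3)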
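
A sketch of launching the new procedure-cloning pipeline itself, mirroring the __main__ block of dizoo/maze/config/maze_pc_config.py in this patch; the optional max_iter argument of serial_pipeline_pc defaults to 1e6.

    # Sketch: run BFS procedure cloning on the maze task with the shipped config.
    from ding.entry import serial_pipeline_pc
    from dizoo.maze.config.maze_pc_config import main_config, create_config

    # Returns the trained policy and a bool indicating whether training is judged converged.
    policy, stop = serial_pipeline_pc([main_config, create_config], seed=0)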