diff --git a/examples/nas/multi-trial/nasbench201/base_ops.py b/examples/nas/multi-trial/nasbench201/base_ops.py
new file mode 100644
index 0000000000..31427fa5f8
--- /dev/null
+++ b/examples/nas/multi-trial/nasbench201/base_ops.py
@@ -0,0 +1,138 @@
+import torch
+import torch.nn as nn
+
+
+OPS_WITH_STRIDE = {
+    'none': lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
+    'avg_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'avg'),
+    'max_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'max'),
+    'conv_3x3': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3, 3), (stride, stride), (1, 1), (1, 1)),
+    'conv_1x1': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1, 1), (stride, stride), (0, 0), (1, 1)),
+    'skip_connect': lambda C_in, C_out, stride: nn.Identity() if stride == 1 and C_in == C_out
+    else FactorizedReduce(C_in, C_out, stride),
+}
+
+PRIMITIVES = ['none', 'skip_connect', 'conv_1x1', 'conv_3x3', 'avg_pool_3x3']
+
+
+class ReLUConvBN(nn.Module):
+    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
+        super(ReLUConvBN, self).__init__()
+        self.op = nn.Sequential(
+            nn.ReLU(inplace=False),
+            nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
+                      padding=padding, dilation=dilation, bias=False),
+            nn.BatchNorm2d(C_out)
+        )
+
+    def forward(self, x):
+        return self.op(x)
+
+
+class SepConv(nn.Module):
+    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
+        super(SepConv, self).__init__()
+        self.op = nn.Sequential(
+            nn.ReLU(inplace=False),
+            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
+                      padding=padding, dilation=dilation, groups=C_in, bias=False),
+            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
+            nn.BatchNorm2d(C_out),
+        )
+
+    def forward(self, x):
+        return self.op(x)
+
+
+class Pooling(nn.Module):
+    def __init__(self, C_in, C_out, stride, mode):
+        super(Pooling, self).__init__()
+        if C_in == C_out:
+            self.preprocess = None
+        else:
+            self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1)
+        if mode == 'avg':
+            self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
+        elif mode == 'max':
+            self.op = nn.MaxPool2d(3, stride=stride, padding=1)
+        else:
+            raise ValueError('Invalid mode={:} in Pooling'.format(mode))
+
+    def forward(self, x):
+        if self.preprocess:
+            x = self.preprocess(x)
+        return self.op(x)
+
+
+class Zero(nn.Module):
+    def __init__(self, C_in, C_out, stride):
+        super(Zero, self).__init__()
+        self.C_in = C_in
+        self.C_out = C_out
+        self.stride = stride
+        self.is_zero = True
+
+    def forward(self, x):
+        if self.C_in == self.C_out:
+            if self.stride == 1:
+                return x.mul(0.)
+            else:
+                return x[:, :, ::self.stride, ::self.stride].mul(0.)
+        else:
+            shape = list(x.shape)
+            shape[1] = self.C_out
+            zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
+            return zeros
+
+
+class FactorizedReduce(nn.Module):
+    def __init__(self, C_in, C_out, stride):
+        super(FactorizedReduce, self).__init__()
+        self.stride = stride
+        self.C_in = C_in
+        self.C_out = C_out
+        self.relu = nn.ReLU(inplace=False)
+        if stride == 2:
+            C_outs = [C_out // 2, C_out - C_out // 2]
+            self.convs = nn.ModuleList()
+            for i in range(2):
+                self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
+            self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
+        else:
+            raise ValueError('Invalid stride : {:}'.format(stride))
+        self.bn = nn.BatchNorm2d(C_out)
+
+    def forward(self, x):
+        x = self.relu(x)
+        y = self.pad(x)
+        out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
+        out = self.bn(out)
+        return out
+
+
+class ResNetBasicblock(nn.Module):
+    def __init__(self, inplanes, planes, stride):
+        super(ResNetBasicblock, self).__init__()
+        assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
+        self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
+        self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1)
+        if stride == 2:
+            self.downsample = nn.Sequential(
+                nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
+                nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
+        elif inplanes != planes:
+            self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
+        else:
+            self.downsample = None
+        self.in_dim = inplanes
+        self.out_dim = planes
+        self.stride = stride
+        self.num_conv = 2
+
+    def forward(self, inputs):
+        basicblock = self.conv_a(inputs)
+        basicblock = self.conv_b(basicblock)
+
+        if self.downsample is not None:
+            inputs = self.downsample(inputs)  # residual
+        return inputs + basicblock
diff --git a/examples/nas/multi-trial/nasbench201/network.py b/examples/nas/multi-trial/nasbench201/network.py
new file mode 100644
index 0000000000..8631967867
--- /dev/null
+++ b/examples/nas/multi-trial/nasbench201/network.py
@@ -0,0 +1,162 @@
+import click
+import nni
+import nni.retiarii.evaluator.pytorch.lightning as pl
+import torch.nn as nn
+import torchmetrics
+from nni.retiarii import model_wrapper, serialize, serialize_cls
+from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
+from nni.retiarii.nn.pytorch import NasBench201Cell
+from nni.retiarii.strategy import Random
+from pytorch_lightning.callbacks import LearningRateMonitor
+from timm.optim import RMSpropTF
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from torchvision import transforms
+from torchvision.datasets import CIFAR100
+
+from base_ops import ResNetBasicblock, PRIMITIVES, OPS_WITH_STRIDE
+
+
+@model_wrapper
+class NasBench201(nn.Module):
+    def __init__(self,
+                 stem_out_channels: int = 16,
+                 num_modules_per_stack: int = 5,
+                 num_labels: int = 100):
+        super().__init__()
+        self.channels = C = stem_out_channels
+        self.num_modules = N = num_modules_per_stack
+        self.num_labels = num_labels
+
+        self.stem = nn.Sequential(
+            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(C)
+        )
+
+        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
+        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
+
+        C_prev = C
+        self.cells = nn.ModuleList()
+        for C_curr, reduction in zip(layer_channels, layer_reductions):
+            if reduction:
+                cell = ResNetBasicblock(C_prev, C_curr, 2)
+            else:
+                cell = NasBench201Cell({prim: lambda C_in, C_out, prim=prim: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES},
+                                       C_prev, C_curr, label='cell')
+            self.cells.append(cell)
+            C_prev = C_curr
+
+        self.lastact = nn.Sequential(
+            nn.BatchNorm2d(C_prev),
+            nn.ReLU(inplace=True)
+        )
+        self.global_pooling = nn.AdaptiveAvgPool2d(1)
+        self.classifier = nn.Linear(C_prev, self.num_labels)
+
+    def forward(self, inputs):
+        feature = self.stem(inputs)
+        for cell in self.cells:
+            feature = cell(feature)
+
+        out = self.lastact(feature)
+        out = self.global_pooling(out)
+        out = out.view(out.size(0), -1)
+        logits = self.classifier(out)
+
+        return logits
+
+
+class AccuracyWithLogits(torchmetrics.Accuracy):
+    def update(self, pred, target):
+        return super().update(nn.functional.softmax(pred, dim=-1), target)
+
+
+@serialize_cls
+class NasBench201TrainingModule(pl.LightningModule):
+    def __init__(self, max_epochs=200, learning_rate=0.1, weight_decay=5e-4):
+        super().__init__()
+        self.save_hyperparameters('learning_rate', 'weight_decay', 'max_epochs')
+        self.criterion = nn.CrossEntropyLoss()
+        self.accuracy = AccuracyWithLogits()
+
+    def forward(self, x):
+        y_hat = self.model(x)
+        return y_hat
+
+    def training_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        loss = self.criterion(y_hat, y)
+        self.log('train_loss', loss, prog_bar=True)
+        self.log('train_accuracy', self.accuracy(y_hat, y), prog_bar=True)
+        return loss
+
+    def validation_step(self, batch, batch_idx):
+        x, y = batch
+        y_hat = self(x)
+        self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
+        self.log('val_accuracy', self.accuracy(y_hat, y), prog_bar=True)
+
+    def configure_optimizers(self):
+        optimizer = RMSpropTF(self.parameters(), lr=self.hparams.learning_rate,
+                              weight_decay=self.hparams.weight_decay,
+                              momentum=0.9, alpha=0.9, eps=1.0)
+        return {
+            'optimizer': optimizer,
+            'lr_scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs)
+        }
+
+    def on_validation_epoch_end(self):
+        nni.report_intermediate_result(self.trainer.callback_metrics['val_accuracy'].item())
+
+    def teardown(self, stage):
+        if stage == 'fit':
+            nni.report_final_result(self.trainer.callback_metrics['val_accuracy'].item())
+
+
+@click.command()
+@click.option('--epochs', default=12, help='Training length.')
+@click.option('--batch_size', default=256, help='Batch size.')
+@click.option('--port', default=8081, help='On which port the experiment is run.')
+def _multi_trial_test(epochs, batch_size, port):
+    # Initialize the dataset. Note that the full 50k training / 10k test split is used, which differs slightly from the paper.
+    transf = [
+        transforms.RandomCrop(32, padding=4),
+        transforms.RandomHorizontalFlip()
+    ]
+    normalize = [
+        transforms.ToTensor(),
+        transforms.Normalize([x / 255 for x in [129.3, 124.1, 112.4]], [x / 255 for x in [68.2, 65.4, 70.4]])
+    ]
+    train_dataset = serialize(CIFAR100, 'data', train=True, download=True, transform=transforms.Compose(transf + normalize))
+    test_dataset = serialize(CIFAR100, 'data', train=False, transform=transforms.Compose(normalize))
+
+    # specify training hyper-parameters
+    training_module = NasBench201TrainingModule(max_epochs=epochs)
+    # FIXME: need to fix a bug in serializer for this to work
+    # lr_monitor = serialize(LearningRateMonitor, logging_interval='step')
+    trainer = pl.Trainer(max_epochs=epochs, gpus=1)
+    lightning = pl.Lightning(
+        lightning_module=training_module,
+        trainer=trainer,
+        train_dataloader=pl.DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
+        val_dataloaders=pl.DataLoader(test_dataset, batch_size=batch_size),
+    )
+
+    strategy = Random()
+
+    model = NasBench201()
+
+    exp = RetiariiExperiment(model, lightning, [], strategy)
+
+    exp_config = RetiariiExeConfig('local')
+    exp_config.trial_concurrency = 2
+    exp_config.max_trial_number = 20
+    exp_config.trial_gpu_number = 1
+    exp_config.training_service.use_active_gpu = False
+
+    exp.run(exp_config, port)
+
+
+if __name__ == '__main__':
+    _multi_trial_test()
diff --git a/nni/retiarii/nn/pytorch/component.py b/nni/retiarii/nn/pytorch/component.py
index 8a1470b3b6..383e21e7b5 100644
--- a/nni/retiarii/nn/pytorch/component.py
+++ b/nni/retiarii/nn/pytorch/component.py
@@ -1,4 +1,5 @@
 import copy
+from collections import OrderedDict
 from typing import Callable, List, Union, Tuple, Optional
 
 import torch
@@ -12,7 +13,7 @@
 from ...utils import NoContextError
 
 
-__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator']
+__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']
 
 
 class Repeat(nn.Module):
@@ -147,3 +148,77 @@ def forward(self, x: List[torch.Tensor]):
             current_state = torch.sum(torch.stack(current_state), 0)
             states.append(current_state)
         return torch.cat(states[self.num_predecessors:], 1)
+
+
+class NasBench201Cell(nn.Module):
+    """
+    Cell structure that is proposed in NAS-Bench-201 [nasbench201]_ .
+
+    This cell is a densely connected DAG with ``num_tensors`` nodes, where each node is a tensor.
+    For every i < j, there is an edge from the i-th node to the j-th node.
+    Each edge in this DAG is associated with an operation transforming the hidden state from the source node
+    to the target node. All possible operations are selected from a predefined operation set, defined in ``op_candidates``.
+    Each of the ``op_candidates`` should be a callable that accepts an input dimension and an output dimension,
+    and returns a ``Module``.
+
+    Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`.
+
+    The space size of this cell would be :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation candidates,
+    and :math:`N` is defined by ``num_tensors``.
+
+    Parameters
+    ----------
+    op_candidates : list of callable or dict of callable
+        Operation candidates. Each should be a callable that accepts input and output feature dimensions and returns an ``nn.Module``.
+    in_features : int
+        Input dimension of cell.
+    out_features : int
+        Output dimension of cell.
+    num_tensors : int
+        Number of tensors in the cell (input included). Default: 4
+    label : str
+        Identifier of the cell. Cells sharing the same label will semantically share the same choice.
+
+    References
+    ----------
+    .. [nasbench201] Dong, X. and Yang, Y., 2020. NAS-Bench-201: Extending the scope of reproducible neural architecture search.
+        arXiv preprint arXiv:2001.00326.
+    """
+
+    @staticmethod
+    def _make_dict(x):
+        if isinstance(x, list):
+            return OrderedDict([(str(i), t) for i, t in enumerate(x)])
+        return OrderedDict(x)
+
+    def __init__(self, op_candidates: List[Callable[[int, int], nn.Module]],
+                 in_features: int, out_features: int, num_tensors: int = 4,
+                 label: Optional[str] = None):
+        super().__init__()
+        self._label = generate_new_label(label)
+
+        self.layers = nn.ModuleList()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.num_tensors = num_tensors
+
+        op_candidates = self._make_dict(op_candidates)
+
+        for tid in range(1, num_tensors):
+            node_ops = nn.ModuleList()
+            for j in range(tid):
+                inp = in_features if j == 0 else out_features
+                op_choices = OrderedDict([(key, cls(inp, out_features))
+                                          for key, cls in op_candidates.items()])
+                node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))
+            self.layers.append(node_ops)
+
+    def forward(self, inputs):
+        tensors = [inputs]
+        for layer in self.layers:
+            current_tensor = []
+            for i, op in enumerate(layer):
+                current_tensor.append(op(tensors[i]))
+            current_tensor = torch.sum(torch.stack(current_tensor), 0)
+            tensors.append(current_tensor)
+        return tensors[-1]
diff --git a/test/ut/retiarii/test_highlevel_apis.py b/test/ut/retiarii/test_highlevel_apis.py
index db24f7ec3c..49d4944a6b 100644
--- a/test/ut/retiarii/test_highlevel_apis.py
+++ b/test/ut/retiarii/test_highlevel_apis.py
@@ -493,6 +493,27 @@ def forward(self, x):
             model = mutator.bind_sampler(sampler).apply(model)
         self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
 
+    def test_nasbench201_cell(self):
+        @self.get_serializer()
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.cell = nn.NasBench201Cell([
+                    lambda x, y: nn.Linear(x, y),
+                    lambda x, y: nn.Linear(x, y, bias=False)
+                ], 10, 16)
+
+            def forward(self, x):
+                return self.cell(x)
+
+        raw_model, mutators = self._get_model_with_mutators(Net())
+        for _ in range(10):
+            sampler = EnumerateSampler()
+            model = raw_model
+            for mutator in mutators:
+                model = mutator.bind_sampler(sampler).apply(model)
+            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))
+
 
 class Python(GraphIR):
     def _get_converted_pytorch_model(self, model_ir):
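For reviewers who want to try the new `NasBench201Cell` API outside the full CIFAR-100 example, here is a minimal sketch. The names `TinyNasBench201` and `toy_cell` are made up for illustration; the op factories come from `base_ops.py` in this PR, so the snippet assumes it is run from the example directory. It mirrors the pattern in `network.py`, including the `prim=prim` default-argument binding that keeps each candidate tied to its own primitive instead of Python's late-binding closure resolving every key to the last one.

```python
import torch.nn as nn
from nni.retiarii import model_wrapper
from nni.retiarii.nn.pytorch import NasBench201Cell

# Ops introduced by this PR in examples/nas/multi-trial/nasbench201/base_ops.py.
from base_ops import OPS_WITH_STRIDE, PRIMITIVES


@model_wrapper
class TinyNasBench201(nn.Module):
    """Hypothetical toy model: a conv stem, one searchable NAS-Bench-201 cell, a classifier."""

    def __init__(self, channels: int = 16, num_labels: int = 10):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Conv2d(3, channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        # Bind `prim` as a default argument so each candidate keeps its own primitive.
        self.cell = NasBench201Cell(
            {prim: (lambda C_in, C_out, prim=prim: OPS_WITH_STRIDE[prim](C_in, C_out, 1))
             for prim in PRIMITIVES},
            channels, channels, label='toy_cell')
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(channels, num_labels)

    def forward(self, x):
        x = self.cell(self.stem(x))
        return self.classifier(self.pool(x).flatten(1))
```

An instance of such a model can then be handed to `RetiariiExperiment` together with a Lightning evaluator and an exploration strategy, exactly as `network.py` does.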