This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

[Retiarii] NAS-Bench-201 #3920

Merged
merged 6 commits on Jul 14, 2021
138 changes: 138 additions & 0 deletions examples/nas/multi-trial/nasbench201/base_ops.py
@@ -0,0 +1,138 @@
import torch
import torch.nn as nn


OPS_WITH_STRIDE = {
'none': lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
'avg_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'avg'),
'max_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'max'),
'conv_3x3': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3, 3), (stride, stride), (1, 1), (1, 1)),
'conv_1x1': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1, 1), (stride, stride), (0, 0), (1, 1)),
'skip_connect': lambda C_in, C_out, stride: nn.Identity() if stride == 1 and C_in == C_out
else FactorizedReduce(C_in, C_out, stride),
}

PRIMITIVES = ['none', 'skip_connect', 'conv_1x1', 'conv_3x3', 'avg_pool_3x3']


class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(C_out)
)

def forward(self, x):
return self.op(x)


class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out),
)

def forward(self, x):
return self.op(x)


class Pooling(nn.Module):
def __init__(self, C_in, C_out, stride, mode):
super(Pooling, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1)
if mode == 'avg':
self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max':
self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else:
raise ValueError('Invalid mode={:} in Pooling'.format(mode))

def forward(self, x):
if self.preprocess:
x = self.preprocess(x)
return self.op(x)


class Zero(nn.Module):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True

def forward(self, x):
if self.C_in == self.C_out:
if self.stride == 1:
return x.mul(0.)
else:
return x[:, :, ::self.stride, ::self.stride].mul(0.)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros


class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
self.bn = nn.BatchNorm2d(C_out)

def forward(self, x):
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out


class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2

def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)

if self.downsample is not None:
inputs = self.downsample(inputs) # residual
return inputs + basicblock
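
A quick shape sanity check for these candidate ops (an illustrative sketch, not part of the diff; assumes base_ops.py is importable from the working directory):

import torch
from base_ops import OPS_WITH_STRIDE, PRIMITIVES

x = torch.randn(2, 16, 32, 32)  # dummy CIFAR-sized feature map
for name in PRIMITIVES:
    # with stride=1 and C_in == C_out, every candidate preserves the input shape
    op = OPS_WITH_STRIDE[name](16, 16, 1)
    assert op(x).shape == x.shape, name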
162 changes: 162 additions & 0 deletions examples/nas/multi-trial/nasbench201/network.py
@@ -0,0 +1,162 @@
import click
import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn
import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench201Cell
from nni.retiarii.strategy import Random
from pytorch_lightning.callbacks import LearningRateMonitor
from timm.optim import RMSpropTF
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torchvision.datasets import CIFAR100

from base_ops import ResNetBasicblock, PRIMITIVES, OPS_WITH_STRIDE


@model_wrapper
class NasBench201(nn.Module):
def __init__(self,
stem_out_channels: int = 16,
num_modules_per_stack: int = 5,
num_labels: int = 100):
super().__init__()
self.channels = C = stem_out_channels
self.num_modules = N = num_modules_per_stack
self.num_labels = num_labels

self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C)
)

layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

C_prev = C
self.cells = nn.ModuleList()
for C_curr, reduction in zip(layer_channels, layer_reductions):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
                # bind prim per iteration; a bare closure would capture only the last primitive
                cell = NasBench201Cell({prim: lambda C_in, C_out, prim=prim: OPS_WITH_STRIDE[prim](C_in, C_out, 1)
                                        for prim in PRIMITIVES},
                                       C_prev, C_curr, label='cell')
self.cells.append(cell)
C_prev = C_curr

self.lastact = nn.Sequential(
nn.BatchNorm2d(C_prev),
nn.ReLU(inplace=True)
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, self.num_labels)

def forward(self, inputs):
feature = self.stem(inputs)
for cell in self.cells:
feature = cell(feature)

out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)

return logits


class AccuracyWithLogits(torchmetrics.Accuracy):
def update(self, pred, target):
        return super().update(nn.functional.softmax(pred, dim=-1), target)


@serialize_cls
class NasBench201TrainingModule(pl.LightningModule):
def __init__(self, max_epochs=200, learning_rate=0.1, weight_decay=5e-4):
super().__init__()
self.save_hyperparameters('learning_rate', 'weight_decay', 'max_epochs')
self.criterion = nn.CrossEntropyLoss()
self.accuracy = AccuracyWithLogits()

def forward(self, x):
y_hat = self.model(x)
return y_hat

def training_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
loss = self.criterion(y_hat, y)
self.log('train_loss', loss, prog_bar=True)
self.log('train_accuracy', self.accuracy(y_hat, y), prog_bar=True)
return loss

def validation_step(self, batch, batch_idx):
x, y = batch
y_hat = self(x)
self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
self.log('val_accuracy', self.accuracy(y_hat, y), prog_bar=True)

def configure_optimizers(self):
optimizer = RMSpropTF(self.parameters(), lr=self.hparams.learning_rate,
weight_decay=self.hparams.weight_decay,
momentum=0.9, alpha=0.9, eps=1.0)
return {
'optimizer': optimizer,
            'lr_scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs)
}

def on_validation_epoch_end(self):
nni.report_intermediate_result(self.trainer.callback_metrics['val_accuracy'].item())

def teardown(self, stage):
if stage == 'fit':
nni.report_final_result(self.trainer.callback_metrics['val_accuracy'].item())


@click.command()
@click.option('--epochs', default=12, help='Training length.')
@click.option('--batch_size', default=256, help='Batch size.')
@click.option('--port', default=8081, help='On which port the experiment is run.')
def _multi_trial_test(epochs, batch_size, port):
    # initialize dataset; the full 50k train + 10k test split is used, which differs slightly from the paper
transf = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip()
]
normalize = [
transforms.ToTensor(),
transforms.Normalize([x / 255 for x in [129.3, 124.1, 112.4]], [x / 255 for x in [68.2, 65.4, 70.4]])
]
train_dataset = serialize(CIFAR100, 'data', train=True, download=True, transform=transforms.Compose(transf + normalize))
test_dataset = serialize(CIFAR100, 'data', train=False, transform=transforms.Compose(normalize))

# specify training hyper-parameters
training_module = NasBench201TrainingModule(max_epochs=epochs)
# FIXME: need to fix a bug in serializer for this to work
# lr_monitor = serialize(LearningRateMonitor, logging_interval='step')
trainer = pl.Trainer(max_epochs=epochs, gpus=1)
lightning = pl.Lightning(
lightning_module=training_module,
trainer=trainer,
train_dataloader=pl.DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
val_dataloaders=pl.DataLoader(test_dataset, batch_size=batch_size),
)

strategy = Random()

model = NasBench201()

exp = RetiariiExperiment(model, lightning, [], strategy)

exp_config = RetiariiExeConfig('local')
exp_config.trial_concurrency = 2
exp_config.max_trial_number = 20
exp_config.trial_gpu_number = 1
exp_config.training_service.use_active_gpu = False

exp.run(exp_config, port)


if __name__ == '__main__':
_multi_trial_test()
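
A usage sketch for this script (assumes it is saved as network.py, with NNI, PyTorch Lightning, timm, and torchvision installed, and one GPU available since trial_gpu_number is 1):

python network.py --epochs 12 --batch_size 256 --port 8081

The Random strategy then samples architectures from the cell search space, running at most 20 trials with 2 trials concurrently, each reporting val_accuracy back to NNI.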
77 changes: 76 additions & 1 deletion nni/retiarii/nn/pytorch/component.py
@@ -1,4 +1,5 @@
import copy
from collections import OrderedDict
from typing import Callable, List, Union, Tuple, Optional

import torch
@@ -12,7 +13,7 @@
from ...utils import NoContextError


__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator']
__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']


class Repeat(nn.Module):
@@ -147,3 +148,77 @@ def forward(self, x: List[torch.Tensor]):
current_state = torch.sum(torch.stack(current_state), 0)
states.append(current_state)
return torch.cat(states[self.num_predecessors:], 1)


class NasBench201Cell(nn.Module):
Contributor
It is supposed to be supported by the base execution engine. Have you tested it?

Contributor Author
There have been unit tests on both execution engines.

"""
    Cell structure proposed in NAS-Bench-201 [nasbench201]_.

    This cell is a densely connected DAG with ``num_tensors`` nodes, where each node is a tensor.
    For every i < j, there is an edge from the i-th node to the j-th node.
Each edge in this DAG is associated with an operation transforming the hidden state from the source node
to the target node. All possible operations are selected from a predefined operation set, defined in ``op_candidates``.
Each of the ``op_candidates`` should be a callable that accepts input dimension and output dimension,
and returns a ``Module``.

    Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`. For example,
    with ``in_features=16`` and ``out_features=32``, an input of shape :math:`[2, 16, 32, 32]` yields an output of shape :math:`[2, 32, 32, 32]`.

The space size of this cell would be :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation candidates,
and :math:`N` is defined by ``num_tensors``.

Parameters
----------
    op_candidates : list of callable or dict of (str, callable)
        Operation candidates. Each should be a callable that accepts the input and output feature dimensions and returns an ``nn.Module``.
in_features : int
Input dimension of cell.
out_features : int
Output dimension of cell.
num_tensors : int
Number of tensors in the cell (input included). Default: 4
label : str
        Identifier of the cell. Cells sharing the same label semantically share the same choice.

References
----------
.. [nasbench201] Dong, X. and Yang, Y., 2020. Nas-bench-201: Extending the scope of reproducible neural architecture search.
arXiv preprint arXiv:2001.00326.
"""

@staticmethod
def _make_dict(x):
if isinstance(x, list):
return OrderedDict([(str(i), t) for i, t in enumerate(x)])
return OrderedDict(x)

def __init__(self, op_candidates: List[Callable[[int, int], nn.Module]],
in_features: int, out_features: int, num_tensors: int = 4,
label: Optional[str] = None):
super().__init__()
self._label = generate_new_label(label)

self.layers = nn.ModuleList()
self.in_features = in_features
self.out_features = out_features
self.num_tensors = num_tensors

op_candidates = self._make_dict(op_candidates)

for tid in range(1, num_tensors):
node_ops = nn.ModuleList()
for j in range(tid):
inp = in_features if j == 0 else out_features
op_choices = OrderedDict([(key, cls(inp, out_features))
for key, cls in op_candidates.items()])
node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))
self.layers.append(node_ops)

def forward(self, inputs):
tensors = [inputs]
for layer in self.layers:
current_tensor = []
for i, op in enumerate(layer):
current_tensor.append(op(tensors[i]))
current_tensor = torch.sum(torch.stack(current_tensor), 0)
tensors.append(current_tensor)
return tensors[-1]
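
As a construction sketch (illustrative, not part of this diff; assumes the example's base_ops.py is on the path), a cell with the NAS-Bench-201 primitives could be built as follows; note the prim=prim default argument, which binds each lambda to its own primitive:

from nni.retiarii.nn.pytorch import NasBench201Cell
from base_ops import OPS_WITH_STRIDE, PRIMITIVES

# stride-1 candidates; prim=prim avoids the classic late-binding closure pitfall
ops = {prim: (lambda C_in, C_out, prim=prim: OPS_WITH_STRIDE[prim](C_in, C_out, 1))
       for prim in PRIMITIVES}
cell = NasBench201Cell(ops, in_features=16, out_features=16,
                       num_tensors=4, label='demo_cell')
# 5 candidates on N(N-1)/2 = 6 edges -> 5**6 = 15625 architectures per cell

Within a RetiariiExperiment, each edge is a LayerChoice that the search strategy resolves; calling forward directly outside an experiment is not meaningful.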