Improved fine-tuning, ConvNeXt support, improved training speed of GHNs #7

Merged (10 commits) on Jul 26, 2022
16 changes: 16 additions & 0 deletions README.md
@@ -6,6 +6,8 @@ authors: [Boris Knyazev](http://bknyaz.github.io/), [Michal Drozdzal](https://sc


**Updates**
- [Jul 21, 2022] Fine-tuning of predicted parameters is improved and parameter prediction for [ConvNeXt](https://arxiv.org/abs/2201.03545) is added (see the [report](https://arxiv.org/abs/2207.10049) and the corresponding code changes in [PR#7](https://github.com/facebookresearch/ppuda/pull/7)).
- [Jul 21, 2022] Training speed of GHNs is further improved (see [PR#7](https://github.com/facebookresearch/ppuda/pull/7) for details).
- [Jan 12, 2022] Training speed of GHNs is improved significantly in some cases (see [PR#2](https://github.com/facebookresearch/ppuda/pull/2) for details).
- [Nov 24, 2021] Video of Yannic Kilcher reviewing our paper together with Boris Knyazev is available on [YouTube](https://youtu.be/3HUK2UWzlFA)

@@ -222,6 +224,20 @@ where `$split` is one from `val, test, wide, deep, dense, bnfree, predefined`, `
The parameters predicted by GHN-2 trained on ImageNet can be fine-tuned on any vision dataset, such as CIFAR-10.


**[Update Jul 21, 2022]**

Following the report ([Pretraining a Neural Network before Knowing Its Architecture](https://arxiv.org/abs/2207.10049)), which shows improved fine-tuning results, the following arguments were added to the code: `--opt`, `--init`, `--imsize`, `--beta`, `--layer` (a rough sketch of what they control is given after the examples below).

- For example, to obtain fine-tuning results of `GHN-orth` for **ResNet-50**:
`python experiments/sgd/train_net.py --val --split predefined --arch 0 --epochs 300 -d cifar10 --n_shots 100 --lr 0.01 --wd 0.01 --ckpt ./checkpoints/ghn2_imagenet.pt --opt sgd --init orth --imsize 32 --beta 3e-5 --layer 37`

- For **[ConvNeXt-Base](https://arxiv.org/abs/2201.03545)**:
`python experiments/sgd/train_net.py --val --arch convnext_base -b 48 --epochs 300 -d cifar10 --n_shots 100 --lr 0.001 --wd 0.1 --ckpt ./checkpoints/ghn2_imagenet.pt --opt adamw --init orth --imsize 32 --beta 3e-5 --layer 94`.
Multiple warnings will be printed saying that some ConvNeXt layers are not supported by GHNs; this is intended. Note that the report mistakenly specifies layer 100 as the best value; 94 should be used for better performance.
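For intuition, below is a minimal, hypothetical sketch of what these flags could do, based only on their help strings in `ppuda/config.py` (`--init` selects a random or orthogonal initialization, `--beta` is the standard deviation of Gaussian noise added to parameters, and `--layer` is the layer after which the noise is added). The helper name `init_sketch` and the exact choice of which parameter tensors are modified are assumptions for illustration; the actual logic lives in `ppuda.utils.init`.

```python
import torch

def init_sketch(model, orth=False, beta=0.0, layer=0):
    # Hypothetical reading of the flags: keep the (e.g. GHN-predicted) parameters of the
    # first `layer` parameter tensors and post-process the remaining ones.
    for i, p in enumerate(model.parameters()):
        if i <= layer:
            continue  # leave early parameters untouched
        with torch.no_grad():
            if orth and p.dim() >= 2:
                torch.nn.init.orthogonal_(p)        # '--init orth'
            if beta > 0:
                p.add_(beta * torch.randn_like(p))  # Gaussian noise with std '--beta'
    return model
```

In `experiments/sgd/train_net.py` these flags are passed to `init(model, orth=..., beta=..., layer=...)` right after the predicted or pretrained parameters are loaded.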


Below are the commands to reproduce the original (NeurIPS 2021) results.

### 100-shot CIFAR-10

- Fine-tune **ResNet-50** initialized with the parameters predicted by **GHN-1-ImageNet**: `python experiments/sgd/train_net.py --split predefined --arch 0 --epochs 50 -d cifar10 --n_shots 100 --wd 1e-3 --ckpt ./checkpoints/ghn1_imagenet.pt`
4 changes: 3 additions & 1 deletion examples/torch_models.py
@@ -11,7 +11,9 @@

Example:

python examples/torch_models.py imagenet resnet50
1. python examples/torch_models.py imagenet resnet50

2. python examples/torch_models.py cifar10 convnext_base

"""

81 changes: 57 additions & 24 deletions experiments/sgd/train_net.py
@@ -27,14 +27,15 @@
import torch
import torchvision
import torch.utils
import time
import os
from ppuda.config import init_config
from ppuda.vision.loader import image_loader
from ppuda.deepnets1m.net import Network
from ppuda.deepnets1m.loader import DeepNets1M
from ppuda.deepnets1m.genotypes import ViT, DARTS
import ppuda.deepnets1m.genotypes as genotypes
from ppuda.utils import capacity, adjust_net, infer, pretrained_model, Trainer
from ppuda.utils import capacity, adjust_net, infer, pretrained_model, Trainer, init
from ppuda.ghn.nn import GHN


@@ -45,7 +46,8 @@ def main():
is_imagenet = args.dataset == 'imagenet'
train_queue, valid_queue, num_classes = image_loader(dataset=args.dataset,
data_dir=args.data_dir,
test=True,
test=not args.val,
im_size=args.imsize,
load_train_anyway=True,
batch_size=args.batch_size,
test_batch_size=args.test_batch_size,
@@ -56,48 +58,74 @@
noise=args.noise,
n_shots=args.n_shots)
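# with --val, valid_queue above is built with test=not args.val, i.e. it holds a validation split rather than the test set;
# the test set loader created below is then used only for the final evaluation at the end of main()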

if args.val:
test_queue = image_loader(dataset=args.dataset,
data_dir=args.data_dir,
test=True,
im_size=args.imsize,
test_batch_size=args.test_batch_size)[1]

assert args.arch is not None, 'architecture genotype/index must be specified'

try:
genotype = eval('genotypes.%s' % args.arch)
arch = genotype = eval('genotypes.%s' % args.arch)
net_args = {'C': args.init_channels,
'genotype': genotype,
'n_cells': args.layers,
'C_mult': int(genotype != ViT) + 1, # assume either ViT or DARTS-style architecture
'preproc': genotype != ViT,
'stem_type': 1} # assume that the ImageNet-style stem is used by default
except:
deepnets = DeepNets1M(split=args.split,
nets_dir=args.data_dir,
large_images=is_imagenet,
arch=args.arch)
assert len(deepnets) == 1, 'one architecture must be chosen to train'
graph = deepnets[0]
net_args, idx = graph.net_args, graph.net_idx
if 'norm' in net_args and net_args['norm'] == 'bn':
net_args['norm'] = 'bn-track'
if isinstance(net_args['genotype'], str):
model = adjust_net(eval('torchvision.models.%s(pretrained=%d)' % (net_args['genotype'], args.pretrained)), is_imagenet)
except (SyntaxError, AttributeError):
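# args.arch did not match a genotype in ppuda.deepnets1m.genotypes: try to interpret it as a
# DeepNets-1M architecture index and, failing that, fall back to a torchvision model name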
try:
arch = args.arch
if args.arch is not None:
arch = int(args.arch)
deepnets = DeepNets1M(split=args.split,
nets_dir=args.data_dir,
large_images=is_imagenet,
arch=arch)
assert len(deepnets) == 1, 'one architecture must be chosen to train'
graph = deepnets[0]
net_args, idx = graph.net_args, graph.net_idx
if 'norm' in net_args and net_args['norm'] == 'bn':
net_args['norm'] = 'bn-track'
arch = net_args['genotype']
except ValueError:
arch = args.arch.lower()

if isinstance(arch, str):
model = adjust_net(eval('torchvision.models.%s(pretrained=%d,num_classes=%d)' %
(arch, args.pretrained, 1000 if args.pretrained else num_classes)),
is_imagenet or args.imsize > 32)
else:
model = Network(num_classes=num_classes,
is_imagenet_input=is_imagenet,
is_imagenet_input=is_imagenet or args.imsize > 32,
auxiliary=args.auxiliary,
**net_args)

if args.ckpt is not None or isinstance(model, torchvision.models.ResNet):
if (args.ckpt or (args.pretrained and model.__module__.startswith('torchvision.models'))):
assert bool(args.ckpt is not None) != args.pretrained, 'ckpt and pretrained are mutually exclusive'
model.expected_input_sz = args.imsize
model = pretrained_model(model, args.ckpt, num_classes, args.debug, GHN)

model = init(model,
orth=args.init.lower() == 'orth',
beta=args.beta,
layer=args.layer,
verbose=args.debug > 1)

model = model.train().to(args.device)

print('\nTraining arch={} with {} parameters'.format(args.arch, capacity(model)[1]))
print('Training arch={} with {} parameters'.format(args.arch, capacity(model)[1]))

optimizer = torch.optim.SGD(
model.parameters(),
args.lr,
momentum=args.momentum,
weight_decay=args.wd
)
if args.opt == 'sgd':
optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=args.momentum, weight_decay=args.wd)
elif args.opt == 'adam':
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
elif args.opt == 'adamw':
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
else:
raise NotImplementedError(args.opt)

if is_imagenet:
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1, 0.97)
@@ -137,5 +165,10 @@ def main():

scheduler.step()

if args.val:
infer(model.eval(), test_queue, verbose=True)

print('\ndone at {}!'.format(time.strftime('%Y%m%d-%H%M%S')))

if __name__ == '__main__':
main()
4 changes: 2 additions & 2 deletions experiments/train_ghn.py
@@ -66,7 +66,7 @@ def main():


ghn = GHN(**config,
debug_level=args.debug).to(args.device)
debug_level=0).to(args.device)

if state_dict is not None:
ghn.load_state_dict(state_dict['state_dict'])
@@ -125,7 +125,7 @@ def main():
for nets_args in graphs.net_args:
net = Network(is_imagenet_input=is_imagenet,
num_classes=num_classes,
compress_params=True,
light=True,
**nets_args)
nets_torch.append(net)

68 changes: 68 additions & 0 deletions experiments/train_ghn_stable.py
@@ -0,0 +1,68 @@
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
A wrapper to automatically restart training GHNs from the last saved checkpoint,
e.g. in the case of a CUDA OOM error or a NaN loss, which can frequently occur due to sampling training architectures.

Example:

# To train GHN-2 on CIFAR-10:
python experiments/train_ghn_stable.py experiments/train_ghn.py -m 8 -n -v 50 --ln --name ghn2-cifar10

"""


import os
import sys
from subprocess import PIPE, run

args = sys.argv[1:]
ckpt_args = None
attempts = 0

while attempts < 100: # let's allow for resuming this job 100 times

attempts += 1
print('\nrunning the script time #%d with args:' % attempts, args, ckpt_args, '\n')

result = run(['python'] + args + ([] if ckpt_args is None else ckpt_args),
stderr=PIPE, text=True)

print('script returned:', result)
print('\nreturned code:', result.returncode)


if result.returncode != 0:

print('Script failed!')

print('\nERROR:', result.stderr)

if result.returncode == 2 and result.stderr.find('[Errno 2] No such file or directory') >= 0:
print('\nRun this script as `python experiments/train_ghn_stable.py experiments/train_ghn.py [args]`\n')
break

elif result.stderr.find('RuntimeError') < 0:
print('\nPlease fix the above printed error and restart the script\n')
break

print('restarting the script')
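# the training script presumably prints a line of the form 'use this ckpt for resuming the script: <path>'
# before failing; parse <path> out of stderr and pass it via --ckpt on the next attempt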
n1 = result.stderr.find('use this ckpt for resuming the script:')
if n1 >= 0:
n1 = result.stderr[n1:].find(':') + n1
n2 = result.stderr[n1:].find('\n') + n1
ckpt = result.stderr[n1 + 2 : n2]
print('parsed path:', ckpt, 'exists:', os.path.exists(ckpt))
if os.path.exists(ckpt):
ckpt_args = ['--ckpt', ckpt]
else:
print('saved checkpoint file is missing, will be starting from scratch')
else:
print('no saved checkpoint found')
continue
else:
break
20 changes: 17 additions & 3 deletions ppuda/config.py
@@ -20,6 +20,7 @@
import time
import os
import torch
import torchvision
import torch.backends.cudnn as cudnn
from .utils import set_seed, default_device

@@ -34,8 +35,12 @@
env['git commit'] = 'no git'
env['hostname'] = platform.node()
env['torch'] = torch.__version__
if env['torch'][0] in ['0', '1'] and not env['torch'].startswith('1.9'):
print('WARNING: pytorch >= 1.9 is strongly recommended for this repo!')
env['torchvision'] = torchvision.__version__
try:
assert list(map(lambda x: float(x), env['torch'].split('.')[:2])) >= [1, 9]
except:
print('WARNING: PyTorch version {} is used, but version >= 1.9 is strongly recommended for this repo!'.format(env['torch']))


env['cuda available'] = torch.cuda.is_available()
env['cudnn enabled'] = cudnn.enabled
@@ -65,7 +70,7 @@ def init_config(mode='eval'):
help='number of cpu processes to use')
parser.add_argument('--device', type=str, default=default_device(), help='device: cpu or cuda')
parser.add_argument('--debug', type=int, default=1, help='the level of details printed out, 0 is the minimal level.')
parser.add_argument('--ckpt', type=str, default=None,
parser.add_argument('-C', '--ckpt', type=str, default=None,
help='path to load the network/GHN parameters from')

is_train_ghn = mode == 'train_ghn'
@@ -109,6 +114,8 @@

parser.add_argument('-b', '--batch_size', type=int, default=batch_size, help='image batch size for training')
parser.add_argument('-e', '--epochs', type=int, default=epochs, help='number of epochs to train')
parser.add_argument('--opt', type=str, default='sgd' if is_train_net else 'adam',
choices=['sgd', 'adam', 'adamw'], help='optimizer')
parser.add_argument('--lr', type=float, default=lr, help='initial learning rate')
parser.add_argument('--grad_clip', type=float, default=5, help='grad clip')
parser.add_argument('-l', '--log_interval', type=int, default=10 if is_detection else 100,
@@ -152,6 +159,13 @@
parser.add_argument('--drop_path_prob', type=float, default=0, help='drop path probability')
parser.add_argument('--n_shots', type=int, default=None, help='number of training images per class for fine-tuning experiments')

parser.add_argument('--init', type=str, default='rand', choices=['rand', 'orth'], help='init method')
parser.add_argument('--layer', type=int, default=0, help='layer after which to add noise')
parser.add_argument('--beta', type=float, default=0,
help='standard deviation of the Gaussian noise added to parameters')
parser.add_argument('--imsize', type=int, default=224 if is_imagenet else 32,
choices=[32, 224], help='image size used to train and eval models')
parser.add_argument('--val', action='store_true', default=False, help='evaluate on the validation set')

args = parser.parse_args()

45 changes: 38 additions & 7 deletions ppuda/deepnets1m/graph.py
@@ -202,6 +202,11 @@ def __init__(self, model=None, node_feat=None, node_info=None, A=None, edges=Non
self.nx_graph = None # NetworkX DiGraph instance

if model is not None:

if isinstance(model, torchvision.models.vision_transformer.VisionTransformer):
raise NotImplementedError('Official PyTorch VisionTransformer module is not supported in the graph construction process. '
'Use the deepnets1m.net.Network class and deepnets1m.genotypes.ViT to construct it.')

sz = model.expected_input_sz if hasattr(model, 'expected_input_sz') else 224 # assume ImageNet image width/height by default
self.expected_input_sz = sz if isinstance(sz, (tuple, list)) else (3, sz, sz) # assume images by default
self.n_cells = self.model._n_cells if hasattr(self.model, '_n_cells') else 1
@@ -305,6 +310,18 @@ def traverse_graph(fn):
if hasattr(u[0], 'variable'):
var = u[0].variable
name, module = param_map[id(var)]

if type(module) not in MODULES:
print('WARNING: unrecognized layer {}, params = {}, type = {}'.format(name,
sum([p.numel() for
n, p in
module.named_parameters()
if (n.find('layer_scale') >= 0 or
n.find('weight') >= 0 or
n.find('bias') >= 0)]),
type(module)))
continue

if type(module) in NormLayers and name.find('.bias') >= 0:
continue # do not add biases of NormLayers as nodes
leaf_nodes.append({'id': u[0],
@@ -585,15 +602,20 @@ def _named_modules(self):
Helper function to automatically build the graphs.
:return:
"""

modules = {}
for n, m in self.model.named_modules():
is_w = hasattr(m, 'weight') and m.weight is not None
is_b = hasattr(m, 'bias') and m.bias is not None
if is_w:
modules[n + '.weight'] = (m.weight, m)
if is_b:
modules[n + '.bias'] = (m.bias, m)
for np, p in m.named_parameters(recurse=False):
if p is None:
continue
key = n + '.' + np
if key in modules:
assert id(p) == id(modules[key][0]), (n, np, p.shape, modules[key][0].shape)
continue
modules[key] = (p, m)

n_params = len(list(self.model.named_parameters()))
assert len(modules) == n_params, (len(modules), n_params)

return modules


@@ -758,3 +780,12 @@ def get_conv_name(module, param_name):
'Add': 'sum',
'Cat': 'concat',
}

try:
import torchvision
MODULES[torchvision.models.convnext.LayerNorm2d] = MODULES[nn.LayerNorm]
# MODULES[torchvision.models.convnext.CNBlock] = MODULES[nn.LayerNorm]
# we can pretend that layer_scale in CNBlock is the same as LayerNorm and this way predict all ConvNeXt parameters,
# but this does not have benefits in fine-tuning
except Exception as e:
print(e, 'convnext requires torchvision >= 0.12, current version is ', torchvision.__version__)