From 93ffecf0ac4d55ccd2537b4c0fb4ec68e4d4454b Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Mon, 12 Oct 2020 20:21:53 +0800 Subject: [PATCH 1/4] add votenet model convert script --- tools/convert_votenet_checkpoints.py | 108 +++++++++++++++++++++++++++ 1 file changed, 108 insertions(+) create mode 100644 tools/convert_votenet_checkpoints.py diff --git a/tools/convert_votenet_checkpoints.py b/tools/convert_votenet_checkpoints.py new file mode 100644 index 0000000000..276bfce606 --- /dev/null +++ b/tools/convert_votenet_checkpoints.py @@ -0,0 +1,108 @@ +import argparse +import torch +from mmcv import Config +from mmcv.runner import load_state_dict + +from mmdet3d.models import build_detector + + +def parse_args(): + parser = argparse.ArgumentParser( + description='MMDet test (and eval) a model') + parser.add_argument('config', help='test config file path') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument('--out', help='path of the output checkpoint file') + parser.add_argument( + '--model', + choices=['sunrgbd', 'scannet'], + default='sunrgbd', + help='type of the model') + args = parser.parse_args() + return args + + +def main(): + args = parse_args() + cfg = Config.fromfile(args.config) + + if args.model == 'scannet': + NUM_CLASSES = 18 + else: + NUM_CLASSES = 10 + + EXTRACT_KEYS = { + 'bbox_head.conv_pred.conv_cls.weight': + ('bbox_head.conv_pred.conv_out.weight', [(0, 2), (-NUM_CLASSES, -1)]), + 'bbox_head.conv_pred.conv_cls.bias': + ('bbox_head.conv_pred.conv_out.bias', [(0, 2), (-NUM_CLASSES, -1)]), + 'bbox_head.conv_pred.conv_reg.weight': + ('bbox_head.conv_pred.conv_out.weight', [(2, -NUM_CLASSES)]), + 'bbox_head.conv_pred.conv_reg.bias': + ('bbox_head.conv_pred.conv_out.bias', [(2, -NUM_CLASSES)]) + } + + RENAME_KEYS = { + 'bbox_head.conv_pred.shared_convs.layer0.conv.weight': + 'bbox_head.conv_pred.0.conv.weight', + 'bbox_head.conv_pred.shared_convs.layer0.conv.bias': + 'bbox_head.conv_pred.0.conv.bias', + 'bbox_head.conv_pred.shared_convs.layer0.bn.weight': + 'bbox_head.conv_pred.0.bn.weight', + 'bbox_head.conv_pred.shared_convs.layer0.bn.bias': + 'bbox_head.conv_pred.0.bn.bias', + 'bbox_head.conv_pred.shared_convs.layer0.bn.running_mean': + 'bbox_head.conv_pred.0.bn.running_mean', + 'bbox_head.conv_pred.shared_convs.layer0.bn.running_var': + 'bbox_head.conv_pred.0.bn.running_var', + 'bbox_head.conv_pred.shared_convs.layer1.conv.weight': + 'bbox_head.conv_pred.1.conv.weight', + 'bbox_head.conv_pred.shared_convs.layer1.conv.bias': + 'bbox_head.conv_pred.1.conv.bias', + 'bbox_head.conv_pred.shared_convs.layer1.bn.weight': + 'bbox_head.conv_pred.1.bn.weight', + 'bbox_head.conv_pred.shared_convs.layer1.bn.bias': + 'bbox_head.conv_pred.1.bn.bias', + 'bbox_head.conv_pred.shared_convs.layer1.bn.running_mean': + 'bbox_head.conv_pred.1.bn.running_mean', + 'bbox_head.conv_pred.shared_convs.layer1.bn.running_var': + 'bbox_head.conv_pred.1.bn.running_var' + } + + DEL_KEYS = [ + 'bbox_head.conv_pred.0.bn.num_batches_tracked', + 'bbox_head.conv_pred.1.bn.num_batches_tracked' + ] + + # build the model and load checkpoint + model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) + checkpoint = torch.load(args.checkpoint) + orig_ckpt = checkpoint['state_dict'] + converted_ckpt = orig_ckpt.copy() + + for new_key, old_key in RENAME_KEYS.items(): + converted_ckpt[new_key] = converted_ckpt.pop(old_key) + + for new_key, (old_key, indices) in EXTRACT_KEYS.items(): + cur_layers = orig_ckpt[old_key] + converted_layers = [] + for (start, end) in indices: + if end != -1: + converted_layers.append(cur_layers[start:end]) + else: + converted_layers.append(cur_layers[start:]) + converted_layers = torch.cat(converted_layers, 0) + converted_ckpt[new_key] = converted_layers + if old_key in converted_ckpt.keys(): + converted_ckpt.pop(old_key) + + for key in DEL_KEYS: + converted_ckpt.pop(key) + + # check the converted checkpoint by loading to the model + load_state_dict(model, converted_ckpt, strict=True) + checkpoint['state_dict'] = converted_ckpt + torch.save(checkpoint, args.out) + + +if __name__ == '__main__': + main() From 184f7f13e9a3bfb627ae7a9962e32c9581173cff Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Mon, 12 Oct 2020 20:29:20 +0800 Subject: [PATCH 2/4] fix bugs of votenet config --- configs/_base_/models/votenet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/configs/_base_/models/votenet.py b/configs/_base_/models/votenet.py index 4764aa2a22..3bddf30ace 100644 --- a/configs/_base_/models/votenet.py +++ b/configs/_base_/models/votenet.py @@ -38,7 +38,8 @@ mlp_channels=[256, 128, 128, 128], use_xyz=True, normalize_xyz=True), - pred_layer_cfg=dict(in_channels=128, shared_conv_channels=(128, 128)), + pred_layer_cfg=dict( + in_channels=128, shared_conv_channels=(128, 128), bias=True), conv_cfg=dict(type='Conv1d'), norm_cfg=dict(type='BN1d'), objectness_loss=dict( From f76c77c3f0cb7e2e1be8ce31b92f4363678ec6ab Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Wed, 21 Oct 2020 15:26:56 +0800 Subject: [PATCH 3/4] modify scripts --- tools/convert_votenet_checkpoints.py | 136 +++++++++++++++++---------- 1 file changed, 85 insertions(+), 51 deletions(-) diff --git a/tools/convert_votenet_checkpoints.py b/tools/convert_votenet_checkpoints.py index 276bfce606..eef02f38c4 100644 --- a/tools/convert_votenet_checkpoints.py +++ b/tools/convert_votenet_checkpoints.py @@ -1,4 +1,5 @@ import argparse +import tempfile import torch from mmcv import Config from mmcv.runner import load_state_dict @@ -8,27 +9,88 @@ def parse_args(): parser = argparse.ArgumentParser( - description='MMDet test (and eval) a model') - parser.add_argument('config', help='test config file path') + description='MMDet3D upgrade model version(before v0.6.0) of VoteNet') parser.add_argument('checkpoint', help='checkpoint file') parser.add_argument('--out', help='path of the output checkpoint file') - parser.add_argument( - '--model', - choices=['sunrgbd', 'scannet'], - default='sunrgbd', - help='type of the model') args = parser.parse_args() return args +def parse_config(config_strings): + """Parse config from strings. + + Args: + config_strings (string): strings of model config. + + Returns: + Config: model config + """ + temp_file = tempfile.NamedTemporaryFile() + config_path = f'{temp_file.name}.py' + with open(config_path, 'w') as f: + f.write(config_strings) + + config = Config.fromfile(config_path) + + # Update backbone config + if 'pool_mod' in config.model.backbone: + config.model.backbone.pop('pool_mod') + + if 'sa_cfg' not in config.model.backbone: + config.model.backbone['sa_cfg'] = dict( + type='PointSAModule', + pool_mod='max', + use_xyz=True, + normalize_xyz=True) + + if 'type' not in config.model.bbox_head.vote_aggregation_cfg: + config.model.bbox_head.vote_aggregation_cfg['type'] = 'PointSAModule' + + # Update bbox_head config + if 'pred_layer_cfg' not in config.model.bbox_head: + config.model.bbox_head['pred_layer_cfg'] = dict( + in_channels=128, shared_conv_channels=(128, 128), bias=True) + + if 'feat_channels' in config.model.bbox_head: + config.model.bbox_head.pop('feat_channels') + + if 'vote_moudule_cfg' in config.model.bbox_head: + config.model.bbox_head['vote_module_cfg'] = config.model.bbox_head.pop( + 'vote_moudule_cfg') + + if config.model.bbox_head.vote_aggregation_cfg.use_xyz: + config.model.bbox_head.vote_aggregation_cfg.mlp_channels[0] -= 3 + + temp_file.close() + + return config + + def main(): args = parse_args() - cfg = Config.fromfile(args.config) + checkpoint = torch.load(args.checkpoint) + cfg = parse_config(checkpoint['meta']['config']) + # Build the model and load checkpoint + model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) + orig_ckpt = checkpoint['state_dict'] + converted_ckpt = orig_ckpt.copy() - if args.model == 'scannet': + if cfg['dataset_type'] == 'ScanNetDataset': NUM_CLASSES = 18 - else: + elif cfg['dataset_type'] == 'SUNRGBDDataset': NUM_CLASSES = 10 + else: + raise NotImplementedError + + RENAME_PREFIX = { + 'bbox_head.conv_pred.0': 'bbox_head.conv_pred.shared_convs.layer0', + 'bbox_head.conv_pred.1': 'bbox_head.conv_pred.shared_convs.layer1' + } + + DEL_KEYS = [ + 'bbox_head.conv_pred.0.bn.num_batches_tracked', + 'bbox_head.conv_pred.1.bn.num_batches_tracked' + ] EXTRACT_KEYS = { 'bbox_head.conv_pred.conv_cls.weight': @@ -41,47 +103,22 @@ def main(): ('bbox_head.conv_pred.conv_out.bias', [(2, -NUM_CLASSES)]) } - RENAME_KEYS = { - 'bbox_head.conv_pred.shared_convs.layer0.conv.weight': - 'bbox_head.conv_pred.0.conv.weight', - 'bbox_head.conv_pred.shared_convs.layer0.conv.bias': - 'bbox_head.conv_pred.0.conv.bias', - 'bbox_head.conv_pred.shared_convs.layer0.bn.weight': - 'bbox_head.conv_pred.0.bn.weight', - 'bbox_head.conv_pred.shared_convs.layer0.bn.bias': - 'bbox_head.conv_pred.0.bn.bias', - 'bbox_head.conv_pred.shared_convs.layer0.bn.running_mean': - 'bbox_head.conv_pred.0.bn.running_mean', - 'bbox_head.conv_pred.shared_convs.layer0.bn.running_var': - 'bbox_head.conv_pred.0.bn.running_var', - 'bbox_head.conv_pred.shared_convs.layer1.conv.weight': - 'bbox_head.conv_pred.1.conv.weight', - 'bbox_head.conv_pred.shared_convs.layer1.conv.bias': - 'bbox_head.conv_pred.1.conv.bias', - 'bbox_head.conv_pred.shared_convs.layer1.bn.weight': - 'bbox_head.conv_pred.1.bn.weight', - 'bbox_head.conv_pred.shared_convs.layer1.bn.bias': - 'bbox_head.conv_pred.1.bn.bias', - 'bbox_head.conv_pred.shared_convs.layer1.bn.running_mean': - 'bbox_head.conv_pred.1.bn.running_mean', - 'bbox_head.conv_pred.shared_convs.layer1.bn.running_var': - 'bbox_head.conv_pred.1.bn.running_var' - } - - DEL_KEYS = [ - 'bbox_head.conv_pred.0.bn.num_batches_tracked', - 'bbox_head.conv_pred.1.bn.num_batches_tracked' - ] - - # build the model and load checkpoint - model = build_detector(cfg.model, train_cfg=None, test_cfg=cfg.test_cfg) - checkpoint = torch.load(args.checkpoint) - orig_ckpt = checkpoint['state_dict'] - converted_ckpt = orig_ckpt.copy() + # Delete some useless keys + for key in DEL_KEYS: + converted_ckpt.pop(key) + # Rename keys with specific prefix + RENAME_KEYS = dict() + for old_key in converted_ckpt.keys(): + for rename_prefix in RENAME_PREFIX.keys(): + if rename_prefix in old_key: + new_key = old_key.replace(rename_prefix, + RENAME_PREFIX[rename_prefix]) + RENAME_KEYS[new_key] = old_key for new_key, old_key in RENAME_KEYS.items(): converted_ckpt[new_key] = converted_ckpt.pop(old_key) + # Extract weights and rename the keys for new_key, (old_key, indices) in EXTRACT_KEYS.items(): cur_layers = orig_ckpt[old_key] converted_layers = [] @@ -95,10 +132,7 @@ def main(): if old_key in converted_ckpt.keys(): converted_ckpt.pop(old_key) - for key in DEL_KEYS: - converted_ckpt.pop(key) - - # check the converted checkpoint by loading to the model + # Check the converted checkpoint by loading to the model load_state_dict(model, converted_ckpt, strict=True) checkpoint['state_dict'] = converted_ckpt torch.save(checkpoint, args.out) From 23458a79162d58c66fc469cecb36ad623a001565 Mon Sep 17 00:00:00 2001 From: zhoujiaming Date: Wed, 21 Oct 2020 15:31:02 +0800 Subject: [PATCH 4/4] add docstring --- tools/convert_votenet_checkpoints.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tools/convert_votenet_checkpoints.py b/tools/convert_votenet_checkpoints.py index eef02f38c4..5996ec5af8 100644 --- a/tools/convert_votenet_checkpoints.py +++ b/tools/convert_votenet_checkpoints.py @@ -67,6 +67,12 @@ def parse_config(config_strings): def main(): + """Convert keys in checkpoints for VoteNet. + + There can be some breaking changes during the development of mmdetection3d, + and this tool is used for upgrading checkpoints trained with old versions + (before v0.6.0) to the latest one. + """ args = parse_args() checkpoint = torch.load(args.checkpoint) cfg = parse_config(checkpoint['meta']['config'])