-
Notifications
You must be signed in to change notification settings - Fork 0
/
convert_model.py
executable file
·88 lines (64 loc) · 2.96 KB
/
convert_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python
"""
Copyright 2017-2018 Fizyr (https://fizyr.com)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import argparse
import os
import sys
import keras
import tensorflow as tf
# Allow relative imports when being executed as script.
# When the file is run directly, __package__ is None and the `from ..` imports
# further down cannot be resolved, so the package context is set up by hand.
if __name__ == "__main__" and __package__ is None:
    # Put the directory two levels above this file at the front of the import
    # path (presumably the directory that contains the keras_retinanet
    # package -- confirm against the repository layout).
    sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', '..'))
    # Imported only for its side effect of loading the package; the name
    # itself is not used, hence the noqa marker.
    import keras_retinanet.bin # noqa: F401
    # Declare which package this script belongs to so relative imports work.
    __package__ = "keras_retinanet.bin"
# Change these to absolute imports if you copy this script outside the keras_retinanet package.
from .. import models
from ..utils.config import read_config_file, parse_anchor_parameters
def get_session():
    """ Build the tf session that is used while converting a model.

    CUDA_VISIBLE_DEVICES is set to an empty string before the session is
    created, which is intended to keep the session from using the GPUs.

    Returns
        A tf.Session created from a default tf.ConfigProto.
    """
    # An empty device list leaves no CUDA device visible to this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = ""
    return tf.Session(config=tf.ConfigProto())
def parse_args(args):
    """ Parse the command line arguments of the conversion script.

    Args
        args: List of argument strings, without the program name.

    Returns
        An argparse namespace with the attributes model_in, model_out,
        backbone, nms, class_specific_filter and config.
    """
    cli = argparse.ArgumentParser(description='Script for converting a training model to an inference model.')

    # (name, add_argument keyword options), registered in this order so the
    # positional arguments and the generated help text keep their layout.
    argument_table = (
        ('model_in', {'help': 'The model to convert.'}),
        ('model_out', {'help': 'Path to save the converted model to.'}),
        ('--backbone', {'help': 'The backbone of the model to convert.', 'default': 'resnet50'}),
        # The two switches below are store_false: their destination is True
        # unless the flag is given.
        ('--no-nms', {'help': 'Disables non maximum suppression.', 'dest': 'nms', 'action': 'store_false'}),
        ('--no-class-specific-filter', {'help': 'Disables class specific filtering.', 'dest': 'class_specific_filter', 'action': 'store_false'}),
        ('--config', {'help': 'Path to a configuration parameters .ini file.'}),
    )
    for name, options in argument_table:
        cli.add_argument(name, **options)

    return cli.parse_args(args)
def main(args=None):
    """ Convert a training model to an inference model and save the result.

    Args
        args: Optional list of command line argument strings. When None,
            sys.argv[1:] is parsed instead.
    """
    # fall back to the arguments of the running process
    argv = sys.argv[1:] if args is None else args
    options = parse_args(argv)

    # use the session from get_session() as the keras tensorflow session,
    # so that the conversion does not use the GPUs
    session = get_session()
    keras.backend.tensorflow_backend.set_session(session)

    # anchor parameters stay None unless a config file was given and it
    # contains an 'anchor_parameters' entry
    anchor_parameters = None
    if options.config:
        config = read_config_file(options.config)
        if 'anchor_parameters' in config:
            anchor_parameters = parse_anchor_parameters(config)

    # load the model that should be converted
    training_model = models.load_model(options.model_in, backbone_name=options.backbone)

    # make sure that a training model was loaded
    models.check_training_model(training_model)

    # build the converted model from the training model
    inference_model = models.convert_model(
        training_model,
        nms=options.nms,
        class_specific_filter=options.class_specific_filter,
        anchor_params=anchor_parameters
    )

    # write the converted model to the requested path
    inference_model.save(options.model_out)
# Run the conversion with the process arguments when executed as a script;
# nothing happens when this module is imported.
if __name__ == '__main__':
    main()