Pytorch->Caffe #241
Hi @TonyTangYu, the tutorial is available at https://github.com/Microsoft/MMdnn#support-frameworks. Here is an approach for your problem. You have to make a few modifications to get this working.

First, git clone the …, then comment out the model-loading lines in `PytorchParser.__init__` (in `mmdnn/conversion/pytorch/pytorch_parser.py`) so that the constructor takes an already-built model object instead of a file name. It might look like this:

```python
def __init__(self, model_file_name, input_shape):
    super(PytorchParser, self).__init__()

    # if not os.path.exists(model_file_name):
    #     print("Pytorch model file [{}] is not found.".format(model_file_name))
    #     assert False
    # # test
    # # cpu: https://github.com/pytorch/pytorch/issues/5286
    # try:
    #     model = torch.load(model_file_name)
    # except:
    #     model = torch.load(model_file_name, map_location='cpu')

    # `model_file_name` is now the model object itself, not a path.
    model = model_file_name
    self.weight_loaded = True

    # Build network graph
    self.pytorch_graph = PytorchGraph(model)
    # Prepend a batch dimension of 1, e.g. [3, 224, 224] -> (1, 3, 224, 224).
    self.input_shape = tuple([1] + input_shape)
    self.pytorch_graph.build(self.input_shape)
    self.state_dict = self.pytorch_graph.state_dict
    self.shape_dict = self.pytorch_graph.shape_dict
```

Then install your local, modified copy:

```bash
pip install -e . -U
```

Secondly, get the 'FD-mobile' repo.
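The parser change above drops file loading entirely. A less invasive variant would accept either a path or a model object. The helper below is a hypothetical sketch of that branching (`resolve_model` and its `loader` parameter are not part of MMdnn); `loader` stands in for `torch.load` so the logic can be shown without PyTorch installed. Unlike the original code, it raises `FileNotFoundError` instead of printing and asserting. `batch_input_shape` just illustrates how `input_shape` gains its batch dimension.

```python
import os
from typing import Any, Callable, Sequence, Tuple


def resolve_model(model_or_path: Any, loader: Callable[..., Any]) -> Any:
    """Return a model object, loading it from disk only when given a path."""
    # Already a model object (e.g. the nn.Module built in evaluate.py): use as-is.
    if not isinstance(model_or_path, str):
        return model_or_path
    # Otherwise treat the argument as a file path, as the unmodified parser does.
    if not os.path.exists(model_or_path):
        raise FileNotFoundError(
            "Pytorch model file [{}] is not found.".format(model_or_path))
    try:
        return loader(model_or_path)
    except Exception:
        # Same fallback as the original code: retry with tensors mapped to CPU.
        return loader(model_or_path, map_location="cpu")


def batch_input_shape(input_shape: Sequence[int]) -> Tuple[int, ...]:
    """Prepend a batch dimension of 1, as the constructor does with input_shape."""
    return tuple([1] + list(input_shape))
```

Inside the parser this would read something like `model = resolve_model(model_file_name, torch.load)`, so existing callers passing a file name keep working.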
Also, in `evaluate.py`, comment out lines 45 to 56 and add this below them:

```python
size = 224
from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser
pytorchparser = PytorchParser(model, [3, size, size])
IR_file = 'FD_mobile'
pytorchparser.run(IR_file)
```

The modified `main()` then looks like this:

```python
def main():
    # json, os, torch, `parser` and `models` come from the module level of evaluate.py.
    global args, best_prec1, last_epoch
    args = parser.parse_args()
    with open(args.data_config, 'r') as json_file:
        data_config = json.load(json_file)
    with open(args.model_config, 'r') as json_file:
        model_config = json.load(json_file)
    if not os.path.exists(args.checkpoint):
        raise RuntimeError('checkpoint `{}` does not exist.'.format(args.checkpoint))

    # create model
    print('==> Creating model `{}`...'.format(model_config['name']))
    model = models.get_model(data_config['name'], model_config)
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    print('==> Checkpoint name is `{}`.'.format(checkpoint['name']))
    model.load_state_dict(checkpoint['state_dict'])

    # Convert the loaded model to MMdnn IR
    # (added below the commented-out lines 45 to 56).
    size = 224
    from mmdnn.conversion.pytorch.pytorch_parser import PytorchParser
    pytorchparser = PytorchParser(model, [3, size, size])
    IR_file = 'FD_mobile'
    pytorchparser.run(IR_file)
```

Run the script to get the IR structure files:

```bash
python evaluate.py --data config/imagenet/data-config/imagenet-test.json --model config/imagenet/model-config/fd-mobilenet/1x-FDMobileNet-224.json --checkpoint saved_models/1x-FDMobileNet-224.pth.tar
```

The result is:

```
IR network structure is saved as [FD_mobile.json].
IR network structure is saved as [FD_mobile.pb].
IR weights are saved as [FD_mobile.npy].
```

After that, convert the IR to Caffe code:

```bash
mmtocode -f caffe -n FD_mobile.pb -w FD_mobile.npy -d caffe_converted.py -dw caffe_converted.npy
```

```
Parse file [FD_mobile.pb] with binary format successfully.
Target network code snippet is saved as [caffe_converted.py].
Target weights are saved as [caffe_converted.npy].
```

Finally, generate the Caffe model:

```bash
mmtomodel -f caffe -in caffe_converted.py -iw caffe_converted.npy -o caffe_target
```

You will get this result:

```
Caffe model files are saved as [caffe_target.prototxt] and [caffe_target.caffemodel], generated by [caffe_converted.py] and [caffe_converted.npy].
```
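The two CLI steps can also be chained from Python. This is a minimal sketch, assuming `mmtocode` and `mmtomodel` are on your `PATH` after installing MMdnn; it reuses the file names from this thread and only passes the flags shown above. The function names are illustrative, not part of MMdnn.

```python
import os
import subprocess
from typing import List


def missing_ir_files(ir_name: str) -> List[str]:
    """Return the IR files that mmtocode needs but that do not exist yet."""
    return [p for p in (ir_name + ".pb", ir_name + ".npy") if not os.path.exists(p)]


def conversion_commands(ir_name: str = "FD_mobile",
                        code_name: str = "caffe_converted",
                        model_name: str = "caffe_target") -> List[List[str]]:
    """Build the IR -> Caffe code -> Caffe model commands as argv lists."""
    ir_to_code = [
        "mmtocode", "-f", "caffe",
        "-n", ir_name + ".pb",       # IR network structure
        "-w", ir_name + ".npy",      # IR weights
        "-d", code_name + ".py",     # generated Caffe network code
        "-dw", code_name + ".npy",   # converted weights
    ]
    code_to_model = [
        "mmtomodel", "-f", "caffe",
        "-in", code_name + ".py",
        "-iw", code_name + ".npy",
        "-o", model_name,            # prefix for .prototxt / .caffemodel
    ]
    return [ir_to_code, code_to_model]


def convert_to_caffe(ir_name: str = "FD_mobile",
                     code_name: str = "caffe_converted",
                     model_name: str = "caffe_target") -> None:
    """Run both steps in order, stopping at the first failure."""
    missing = missing_ir_files(ir_name)
    if missing:
        raise FileNotFoundError("Run evaluate.py first; missing: " + ", ".join(missing))
    for argv in conversion_commands(ir_name, code_name, model_name):
        # check=True raises CalledProcessError if a step exits non-zero.
        subprocess.run(argv, check=True)
```

Calling `convert_to_caffe()` from the directory containing `FD_mobile.pb` and `FD_mobile.npy` should produce the same files as running the two commands by hand.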
Platform: Ubuntu 16.04
Python version: Python 3.5.2
Source framework with version (like Tensorflow 1.4.1 with GPU):
PyTorch 0.3.1 with GPU
Destination framework with version (like CNTK 2.3 with GPU):
Caffe
Pre-trained model path (webpath or webdisk path):
https://github.com/clavichord93/FD-MobileNet/tree/master/saved_models
Running scripts:
FD-MobileNet is a new deep learning model based on MobileNet. I want to convert a PyTorch FD-MobileNet model into a Caffe model, but I don't know how to use MMdnn, and the example usage doesn't give me a clear clue. How can I do this, and is there documentation listing all MMdnn commands?