-
Notifications
You must be signed in to change notification settings - Fork 61
/
export_coreml.py
43 lines (33 loc) · 1.2 KB
/
export_coreml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from timm import create_model
import model
import utils
import torch
import torchvision
from argparse import ArgumentParser
# Command-line options for the Core ML export script.
parser = ArgumentParser(description='Export a timm model to Core ML.')
parser.add_argument(
    '--model', default='repvit_m1_1', type=str,
    help='model name passed to timm.create_model (default: %(default)s)',
)
parser.add_argument(
    '--resolution', default=224, type=int,
    help='side length in pixels of the square example input used for '
         'tracing (default: %(default)s)',
)
parser.add_argument(
    '--ckpt', default=None, type=str,
    help="optional checkpoint path; its 'model' entry is loaded as the "
         "state dict",
)
if __name__ == "__main__":
    # Export the model selected by --model to a Core ML file:
    # build it with timm, optionally load a checkpoint, trace it with
    # TorchScript, convert with coremltools, and save under ./coreml/.
    import os

    # Imported up front so a missing coremltools install fails before the
    # (slower) model construction and tracing steps.
    import coremltools as ct

    args = parser.parse_args()

    # Bound to `net` rather than `model` so the module imported as `model`
    # at the top of the file is not rebound at module scope.
    net = create_model(args.model, distillation=True)
    if args.ckpt:
        # map_location='cpu' lets checkpoints that were saved from a GPU
        # load on machines without CUDA.
        checkpoint = torch.load(args.ckpt, map_location='cpu')
        net.load_state_dict(checkpoint['model'])
    # NOTE(review): presumably fuses BatchNorm layers for inference;
    # confirm against utils.replace_batchnorm.
    utils.replace_batchnorm(net)
    net.eval()

    # Trace the model with random data.
    resolution = args.resolution
    example_input = torch.rand(1, 3, resolution, resolution)
    traced_model = torch.jit.trace(net, example_input)
    # Run the traced module once as a smoke test before converting it.
    traced_model(example_input)

    # Convert to a Core ML neural network using the Unified Conversion API,
    # declaring the input as an image. The target format is explicit because
    # the result is saved with the .mlmodel extension, which coremltools
    # rejects for ML Program models.
    mlmodel = ct.convert(
        traced_model,
        inputs=[ct.ImageType(shape=example_input.shape)],
        convert_to="neuralnetwork",
    )

    # Save the converted model, creating the output directory on first use.
    os.makedirs("coreml", exist_ok=True)
    mlmodel.save(f"coreml/{args.model}_{resolution}.mlmodel")