# convert_lama.py
import coremltools as ct
import torch
from iopaint.model.lama import LaMa
from CoreMLaMa import CoreMLaMa

model_manager = LaMa("cpu")

# Fixed image/mask size.
# Flexible input shapes are not (currently) supported, for various reasons.
size = (800, 800)  # pixel width x height

# Image/mask shapes in PyTorch format
image_shape = (1, 3, size[1], size[0])
mask_shape = (1, 1, size[1], size[0])
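# The shapes above follow the PyTorch NCHW convention,
# (batch, channels, height, width), which is why size[1] (height)
# appears before size[0] (width).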

lama_inpaint_model = model_manager.model
model = CoreMLaMa(lama_inpaint_model).eval()

print("Scripting CoreMLaMa")
jit_model = torch.jit.script(model)
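# torch.jit.script is used rather than torch.jit.trace, so any
# data-dependent control flow in the model is preserved in the
# TorchScript graph instead of being frozen to one execution path.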
print("Converting model")
# Note that ct.ImageType assumes an 8 bpp image, while LaMa
# uses 32-bit FP math internally. Creating a Core ML model
# that can work with 32-bit FP image inputs is on the "To Do"
# list
coreml_model = ct.convert(
jit_model,
convert_to="mlprogram",
compute_precision=ct.precision.FLOAT32,
compute_units=ct.ComputeUnit.CPU_AND_GPU,
inputs=[
ct.ImageType(name="image",
shape=image_shape,
scale=1/255.0),
ct.ImageType(
name="mask",
shape=mask_shape,
color_layout=ct.colorlayout.GRAYSCALE)
],
outputs=[ct.ImageType(name="output")],
skip_model_load=True
)
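# skip_model_load=True makes coremltools return the converted model without
# compiling and loading it into the Core ML framework, so this script can
# convert and save the package even on a machine that cannot run the model.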

coreml_model_file_name = "LaMa.mlpackage"
print(f"Saving model to {coreml_model_file_name}")
coreml_model.save(coreml_model_file_name)
print("Done!")