-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
61 lines (45 loc) · 1.92 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import torch
from torch import nn
from functools import reduce
from pathlib import Path
from imagen_pytorch.configs import ImagenConfig, ElucidatedImagenConfig
from ema_pytorch import EMA
def exists(val):
    """Report whether ``val`` holds a value, i.e. is not ``None``.

    Falsy values such as ``0``, ``''`` or ``[]`` still count as existing;
    only ``None`` is treated as absent.
    """
    if val is None:
        return False
    return True
def safeget(dictionary, keys, default = None):
    """Look up a nested value in ``dictionary`` by a dotted key path.

    ``keys`` is split on ``'.'`` and each segment is looked up in turn.
    Whenever the current node is not a ``dict``, or a segment is missing,
    ``default`` is used as the current node for the remaining steps.

    Example: ``safeget({'a': {'b': 1}}, 'a.b')`` returns ``1``.
    """
    node = dictionary
    for segment in keys.split('.'):
        if isinstance(node, dict):
            node = node.get(segment, default)
        else:
            node = default
    return node
def load_imagen_from_checkpoint(
    checkpoint_path,
    load_weights = True,
    load_ema_if_available = False
):
    """Rebuild an Imagen model from a checkpoint file saved on disk.

    The loaded checkpoint is read as a mapping with the entries
    ``imagen_type`` (``'original'`` or ``'elucidated'``), ``imagen_params``
    (keyword arguments for the matching config class), ``model`` (state dict
    for the whole model) and, optionally, ``ema`` (state dict for a
    ``ModuleList`` of ``EMA``-wrapped unets).

    Args:
        checkpoint_path: path (``str`` or ``Path``) to the checkpoint file.
        load_weights: when False, return the freshly constructed model
            without loading any weights.
        load_ema_if_available: when True and the checkpoint has an ``ema``
            entry, overwrite each unet's weights with its EMA counterpart.

    Returns:
        The model produced by ``ImagenConfig(**params).create()`` or
        ``ElucidatedImagenConfig(**params).create()``.

    Raises:
        AssertionError: if no file exists at ``checkpoint_path``.
        ValueError: if the checkpoint is missing ``imagen_type`` or
            ``imagen_params``, or names an unknown imagen type.
    """
    model_path = Path(checkpoint_path)
    full_model_path = str(model_path.resolve())
    assert model_path.exists(), f'checkpoint not found at {full_model_path}'

    # NOTE(review): torch.load unpickles arbitrary Python objects; only load
    # checkpoints that come from a trusted source.
    loaded = torch.load(str(model_path), map_location='cpu')

    imagen_params = safeget(loaded, 'imagen_params')
    imagen_type = safeget(loaded, 'imagen_type')

    # Check presence before dispatching on the type. Otherwise a missing type
    # is reported as "unknown imagen type None" instead of this clearer error.
    if not (exists(imagen_params) and exists(imagen_type)):
        raise ValueError('imagen type and configuration not saved in this checkpoint')

    if imagen_type == 'original':
        imagen_klass = ImagenConfig
    elif imagen_type == 'elucidated':
        imagen_klass = ElucidatedImagenConfig
    else:
        raise ValueError(f'unknown imagen type {imagen_type} - you need to instantiate your Imagen with configurations, using classes ImagenConfig or ElucidatedImagenConfig')

    imagen = imagen_klass(**imagen_params).create()

    if not load_weights:
        return imagen

    has_ema = 'ema' in loaded
    should_load_ema = has_ema and load_ema_if_available

    # Always load the regular weights first; EMA weights (if requested and
    # present) then overwrite the unets below.
    imagen.load_state_dict(loaded['model'])

    if not should_load_ema:
        print('loading non-EMA version of unets')
        return imagen

    # Rebuild the same ModuleList-of-EMA structure the 'ema' state dict was
    # saved from, so its keys line up, then copy each EMA model's weights
    # back into the live unet.
    ema_unets = nn.ModuleList([EMA(unet) for unet in imagen.unets])
    ema_unets.load_state_dict(loaded['ema'])

    for unet, ema_unet in zip(imagen.unets, ema_unets):
        unet.load_state_dict(ema_unet.ema_model.state_dict())

    print('loaded EMA version of unets')
    return imagen