import torch
from networks.discriminator import Discriminator
from networks.generator import Generator
import torch.nn.functional as F
from torch import nn, optim
import os
from vgg19 import VGGLoss
from torch.nn.parallel import DistributedDataParallel as DDP


def requires_grad(net, flag=True):
    # Toggle gradient computation for every parameter of a network; used to
    # freeze one network while the other is updated (alternating GAN steps).
    for p in net.parameters():
        p.requires_grad = flag


class Trainer(nn.Module):
    def __init__(self, args, device, rank):
        super(Trainer, self).__init__()

        self.args = args
        self.batch_size = args.batch_size

        self.gen = Generator(args.size, args.latent_dim_style, args.latent_dim_motion,
                             args.channel_multiplier).to(device)
        self.dis = Discriminator(args.size, args.channel_multiplier).to(device)

        # distributed computing
        self.gen = DDP(self.gen, device_ids=[rank], find_unused_parameters=True)
        self.dis = DDP(self.dis, device_ids=[rank], find_unused_parameters=True)

        # lazy-regularization correction as in StyleGAN2: when a regularizer
        # runs only every k steps, lr and betas are rescaled by k / (k + 1)
        g_reg_ratio = args.g_reg_every / (args.g_reg_every + 1)
        d_reg_ratio = args.d_reg_every / (args.d_reg_every + 1)

        self.g_optim = optim.Adam(
            self.gen.parameters(),
            lr=args.lr * g_reg_ratio,
            betas=(0 ** g_reg_ratio, 0.99 ** g_reg_ratio)
        )
        self.d_optim = optim.Adam(
            self.dis.parameters(),
            lr=args.lr * d_reg_ratio,
            betas=(0 ** d_reg_ratio, 0.99 ** d_reg_ratio)
        )

        # perceptual (VGG19) reconstruction loss
        self.criterion_vgg = VGGLoss().to(rank)
    def g_nonsaturating_loss(self, fake_pred):
        # non-saturating generator loss on the discriminator logits
        return F.softplus(-fake_pred).mean()

    def d_nonsaturating_loss(self, fake_pred, real_pred):
        # logistic discriminator loss: push real logits up, fake logits down
        real_loss = F.softplus(-real_pred)
        fake_loss = F.softplus(fake_pred)

        return real_loss.mean() + fake_loss.mean()
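
    # Note (added): F.softplus(-x) == -log(sigmoid(x)), so the two methods
    # above are the standard non-saturating logistic GAN objectives:
    #   L_G = E[softplus(-D(G(z)))],  L_D = E[softplus(-D(x))] + E[softplus(D(G(z)))]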

    def gen_update(self, img_source, img_target):
        self.gen.train()
        self.gen.zero_grad()

        # update only the generator in this step
        requires_grad(self.gen, True)
        requires_grad(self.dis, False)

        img_target_recon = self.gen(img_source, img_target)
        img_recon_pred = self.dis(img_target_recon)

        # reconstruction (perceptual + L1) and adversarial terms, summed unweighted
        vgg_loss = self.criterion_vgg(img_target_recon, img_target).mean()
        l1_loss = F.l1_loss(img_target_recon, img_target)
        gan_g_loss = self.g_nonsaturating_loss(img_recon_pred)

        g_loss = vgg_loss + l1_loss + gan_g_loss

        g_loss.backward()
        self.g_optim.step()

        return vgg_loss, l1_loss, gan_g_loss, img_target_recon

    def dis_update(self, img_real, img_recon):
        self.dis.zero_grad()

        # update only the discriminator in this step
        requires_grad(self.gen, False)
        requires_grad(self.dis, True)

        real_img_pred = self.dis(img_real)
        # detach so no gradients flow back into the generator
        recon_img_pred = self.dis(img_recon.detach())

        d_loss = self.d_nonsaturating_loss(recon_img_pred, real_img_pred)
        d_loss.backward()
        self.d_optim.step()

        return d_loss

    def sample(self, img_source, img_target):
        with torch.no_grad():
            self.gen.eval()

            # drive the source image with the target's motion
            img_recon = self.gen(img_source, img_target)
            # with no driving image (None), the generator appears to return a
            # reference reconstruction of the source itself
            img_source_ref = self.gen(img_source, None)

        return img_recon, img_source_ref

    def resume(self, resume_ckpt):
        print("load model:", resume_ckpt)
        ckpt = torch.load(resume_ckpt)
        # the starting iteration is encoded in the checkpoint filename,
        # zero-padded by save(), e.g. "012345.pt"
        ckpt_name = os.path.basename(resume_ckpt)
        start_iter = int(os.path.splitext(ckpt_name)[0])

        # load into .module to strip the DDP wrapper
        self.gen.module.load_state_dict(ckpt["gen"])
        self.dis.module.load_state_dict(ckpt["dis"])
        self.g_optim.load_state_dict(ckpt["g_optim"])
        self.d_optim.load_state_dict(ckpt["d_optim"])

        return start_iter

    def save(self, idx, checkpoint_path):
        torch.save(
            {
                "gen": self.gen.module.state_dict(),
                "dis": self.dis.module.state_dict(),
                "g_optim": self.g_optim.state_dict(),
                "d_optim": self.d_optim.state_dict(),
                "args": self.args
            },
            f"{checkpoint_path}/{str(idx).zfill(6)}.pt"
        )
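

# ---------------------------------------------------------------------------
# Usage sketch (added, not part of the original file): a minimal training
# step, assuming a per-rank DataLoader that yields (img_source, img_target)
# batches and an `args` namespace with the fields read in __init__ (size,
# latent_dim_style, latent_dim_motion, channel_multiplier, batch_size, lr,
# g_reg_every, d_reg_every). `loader` and `checkpoint_path` are hypothetical.
#
#   trainer = Trainer(args, device=rank, rank=rank)
#   for idx, (img_source, img_target) in enumerate(loader):
#       img_source = img_source.to(rank)
#       img_target = img_target.to(rank)
#       vgg_loss, l1_loss, gan_g_loss, img_recon = trainer.gen_update(img_source, img_target)
#       d_loss = trainer.dis_update(img_target, img_recon)
#       if idx % 10000 == 0:
#           trainer.save(idx, checkpoint_path)
# ---------------------------------------------------------------------------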