-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdrn_default.py
109 lines (86 loc) · 3.24 KB
/
drn_default.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn as nn
from model import common
def make_model(opt):
    """Build the network described by ``opt``.

    Args:
        opt: Option object, forwarded unchanged to ``DRN``.

    Returns:
        A newly constructed ``DRN`` instance.
    """
    model = DRN(opt)
    return model
class DRN(nn.Module):
    """DRN model: a chain of down blocks followed by a chain of up blocks.

    The input is first enlarged by ``max(opt.scale)`` with bicubic
    interpolation and run through ``len(opt.scale)`` down blocks. The same
    number of up blocks follows; the output of each up block is concatenated
    along dim 1 with the features saved before the matching down block. An
    output image is produced by a tail convolution before the first up block
    and after every up block.

    Args:
        opt: Option object. The attributes read here are ``scale`` (a sized
            collection of scale factors), ``n_blocks``, ``n_feats``,
            ``n_colors`` and ``rgb_range``; ``opt`` itself is also handed to
            ``common.DownBlock``.
        conv: Factory called as ``conv(in_channels, out_channels,
            kernel_size)`` to build the convolution layers.
    """

    def __init__(self, opt, conv=common.default_conv):
        super().__init__()
        self.opt = opt
        self.scale = opt.scale
        self.phase = len(opt.scale)

        depth = self.phase
        rcab_count = opt.n_blocks
        base = opt.n_feats
        ksize = 3
        # One activation instance is shared by every RCAB.
        relu = nn.ReLU(True)

        def width(level):
            # Channel count used at a given level: it doubles per level.
            return base * 2 ** level

        def rcab_stack(channels):
            # ``rcab_count`` residual blocks operating on ``channels`` maps.
            return [
                common.RCAB(conv, channels, ksize, act=relu)
                for _ in range(rcab_count)
            ]

        self.upsample = nn.Upsample(
            scale_factor=max(opt.scale), mode='bicubic', align_corners=False
        )

        mean = (0.4488, 0.4371, 0.4040)
        std = (1.0, 1.0, 1.0)
        self.sub_mean = common.MeanShift(opt.rgb_range, mean, std)

        self.head = conv(opt.n_colors, base, ksize)

        # Down path: level p maps width(p) channels to width(p + 1).
        self.down = nn.ModuleList([
            common.DownBlock(opt, 2, width(p), width(p), width(p + 1))
            for p in range(depth)
        ])

        # RCAB stacks for the up path. The stacks of the later up blocks are
        # created before the stack of the first up block, so that modules are
        # instantiated in a fixed, reproducible sequence.
        later_bodies = [rcab_stack(width(p)) for p in range(depth, 1, -1)]
        first_body = rcab_stack(width(depth))
        bodies = [first_body] + later_bodies

        # First up block: consumes the output of the last down block.
        samplers = [[
            common.Upsampler(conv, 2, width(depth), act=False),
            conv(width(depth), width(depth - 1), kernel_size=1),
        ]]
        # Remaining up blocks: the input width is doubled because forward()
        # concatenates the saved down-path features before these blocks.
        samplers.extend(
            [
                common.Upsampler(conv, 2, 2 * width(p), act=False),
                conv(2 * width(p), width(p - 1), kernel_size=1),
            ]
            for p in range(depth - 1, 0, -1)
        )

        self.up_blocks = nn.ModuleList([
            nn.Sequential(*bodies[i], *samplers[i]) for i in range(depth)
        ])

        # Tail convolutions that turn features into output images: one for
        # the features before any up block, then one per up block.
        tail_convs = [conv(width(depth), opt.n_colors, ksize)]
        tail_convs.extend(
            conv(width(p), opt.n_colors, ksize) for p in range(depth, 0, -1)
        )
        self.tail = nn.ModuleList(tail_convs)

        # NOTE(review): the trailing ``1`` presumably selects the inverse
        # (add-back) shift -- confirm against common.MeanShift.
        self.add_mean = common.MeanShift(opt.rgb_range, mean, std, 1)

    def forward(self, x):
        """Run the network and collect one output image per stage.

        Args:
            x: Input tensor, passed to ``nn.Upsample`` and then to the head
                convolution.

        Returns:
            A list with ``self.phase + 1`` tensors: the image computed from
            the features after the last down block, followed by the image
            computed after each up block, in order.
        """
        # Upsample to the target size, normalise, extract head features.
        feats = self.head(self.sub_mean(self.upsample(x)))

        # Down path: remember the features entering every down block.
        skips = []
        for down_block in self.down:
            skips.append(feats)
            feats = down_block(feats)

        # Image from the features produced by the last down block.
        outputs = [self.add_mean(self.tail[0](feats))]

        # Up path: the most recently saved features are consumed first.
        for step, up_block in enumerate(self.up_blocks, start=1):
            feats = torch.cat((up_block(feats), skips.pop()), 1)
            outputs.append(self.add_mean(self.tail[step](feats)))

        return outputs