forked from proteus1991/GridDehazeNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresidual_dense_block.py
49 lines (41 loc) · 1.46 KB
/
residual_dense_block.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""
paper: GridDehazeNet: Attention-Based Multi-Scale Network for Image Dehazing
file: residual_dense_block.py
about: build the Residual Dense Block
author: Xiaohong Liu
date: 01/08/19
"""
# --- Imports --- #
import torch
import torch.nn as nn
import torch.nn.functional as F
# --- Build dense --- #
class MakeDense(nn.Module):
    """One dense layer: convolution + ReLU, concatenated onto its own input.

    The output keeps every input channel and appends ``growth_rate`` new ones,
    so it has ``in_channels + growth_rate`` channels.
    """

    def __init__(self, in_channels, growth_rate, kernel_size=3):
        """
        :param in_channels: number of channels of the input feature map
        :param growth_rate: number of new channels produced by the convolution
        :param kernel_size: side length of the square convolution kernel
        """
        super(MakeDense, self).__init__()
        # Symmetric zero padding of (k - 1) // 2 preserves the spatial size
        # exactly when k is odd; even kernels are completed in forward().
        self.conv = nn.Conv2d(
            in_channels,
            growth_rate,
            kernel_size=kernel_size,
            padding=(kernel_size - 1) // 2,
        )

    def forward(self, x):
        """
        :param x: input feature map with ``in_channels`` channels on dim 1
        :return: ``x`` concatenated along dim 1 with the ReLU-activated
                 convolution of ``x``
        """
        kernel_h, kernel_w = self.conv.kernel_size
        pad_h, pad_w = self.conv.padding
        # Rows/columns still missing for the convolution output to match the
        # input size: 0 for odd kernels, 1 for even kernels.
        short_h = kernel_h - 1 - 2 * pad_h
        short_w = kernel_w - 1 - 2 * pad_w

        conv_in = x
        if short_h or short_w:
            # Zero-pad the bottom/right edge so the concatenation below sees
            # matching spatial dimensions even for even kernel sizes.
            conv_in = F.pad(x, (0, short_w, 0, short_h))

        out = F.relu(self.conv(conv_in))
        return torch.cat((x, out), 1)
# --- Build the Residual Dense Block --- #
class RDB(nn.Module):
    """Residual Dense Block.

    A chain of ``MakeDense`` layers, each of which appends ``growth_rate``
    channels to its input, followed by a 1x1 convolution that maps the
    accumulated channels back to ``in_channels`` and a residual addition of
    the block input.
    """

    def __init__(self, in_channels, num_dense_layer, growth_rate):
        """
        :param in_channels: number of channels of the block input (and output)
        :param num_dense_layer: the number of dense layers inside this block
        :param growth_rate: number of channels each dense layer adds
        """
        super(RDB, self).__init__()
        # Layer ``idx`` receives the original channels plus the growth_rate
        # channels contributed by each of the ``idx`` layers before it.
        dense_layers = [
            MakeDense(in_channels + idx * growth_rate, growth_rate)
            for idx in range(num_dense_layer)
        ]
        self.residual_dense_layers = nn.Sequential(*dense_layers)

        # The 1x1 convolution brings the channel count back to in_channels so
        # that the residual addition in forward() is shape-compatible.
        fused_channels = in_channels + len(dense_layers) * growth_rate
        self.conv_1x1 = nn.Conv2d(fused_channels, in_channels, kernel_size=1, padding=0)

    def forward(self, x):
        """
        :param x: input feature map with ``in_channels`` channels on dim 1
        :return: ``x`` plus the 1x1-convolved output of the dense layers
        """
        out = self.residual_dense_layers(x)
        out = self.conv_1x1(out)
        return out + x