forked from lightaime/deep_gcns_torch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patharchitecture.py
96 lines (81 loc) · 4.65 KB
/
architecture.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch
from gcn_lib.dense import BasicConv, GraphConv2d, PlainDynBlock2d, ResDynBlock2d, DenseDynBlock2d, DenseDilatedKnnGraph
from torch.nn import Sequential as Seq
class DenseDeepGCN(torch.nn.Module):
    """Deep GCN on dense tensors: graph-conv head, dynamic-graph backbone, fused prediction head.

    The constructor reads these attributes from ``opt``: ``n_filters``, ``k``,
    ``act``, ``norm``, ``bias``, ``epsilon``, ``stochastic``, ``conv``,
    ``n_blocks``, ``in_channels``, ``block``, ``dropout`` and ``n_classes``.

    ``opt.block`` (case-insensitive) picks the backbone block: ``'res'`` uses
    ``ResDynBlock2d``, ``'dense'`` uses ``DenseDynBlock2d`` and every other
    value falls back to ``PlainDynBlock2d``.
    """

    def __init__(self, opt):
        super().__init__()
        width = opt.n_filters
        n_neighbors = opt.k
        act, norm, bias = opt.act, opt.norm, opt.bias
        eps = opt.epsilon
        stochastic = opt.stochastic
        conv = opt.conv
        growth = width  # per-block channel growth equals the base width
        self.n_blocks = opt.n_blocks
        # The head counts as the first block, so the backbone holds one fewer.
        depth = self.n_blocks - 1

        self.knn = DenseDilatedKnnGraph(n_neighbors, 1, stochastic, eps)
        self.head = GraphConv2d(opt.in_channels, width, conv, act, norm, bias)

        block_kind = opt.block.lower()
        if block_kind == 'res':
            # Dilation increases by one per block: 1, 2, ..., depth.
            stages = [
                ResDynBlock2d(width, n_neighbors, dilation, conv, act, norm,
                              bias, stochastic, eps)
                for dilation in range(1, depth + 1)
            ]
            fusion_dims = int(width + growth * depth)
        elif block_kind == 'dense':
            # Block idx takes width + growth * idx input channels.
            stages = [
                DenseDynBlock2d(width + growth * idx, growth, n_neighbors,
                                idx + 1, conv, act, norm, bias, stochastic, eps)
                for idx in range(depth)
            ]
            # Arithmetic-series sum (first + last) * n_blocks // 2 of the widths
            # width, width + growth, ..., width + growth * depth.
            # NOTE(review): presumably matches the concatenated output widths
            # of the dense blocks - confirm against DenseDynBlock2d.
            fusion_dims = int((width + width + growth * depth) * self.n_blocks // 2)
        else:
            # Plain blocks are built with stochastic disabled and dilation 1;
            # self.knn above was already built with opt.stochastic.
            stochastic = False
            stages = [
                PlainDynBlock2d(width, n_neighbors, 1, conv, act, norm,
                                bias, stochastic, eps)
                for _ in range(depth)
            ]
            fusion_dims = int(width + growth * depth)
        self.backbone = Seq(*stages)

        self.fusion_block = BasicConv([fusion_dims, 1024], act, norm, bias)
        # The prediction head consumes the 1024 fused channels concatenated
        # with the fusion_dims per-block channels.
        self.prediction = Seq(
            BasicConv([fusion_dims + 1024, 512], act, norm, bias),
            BasicConv([512, 256], act, norm, bias),
            torch.nn.Dropout(p=opt.dropout),
            BasicConv([256, opt.n_classes], None, None, bias),
        )

    def forward(self, inputs):
        """Run the network and return the prediction with its last dim squeezed.

        The first three channels of ``inputs`` (``inputs[:, 0:3]``) are passed
        to the knn module; the full tensor goes to the head.
        NOTE(review): the demo at the bottom of this file feeds a
        (batch, channels, num_points, 1) tensor whose first three channels are
        positions - confirm this is the expected layout.
        """
        graph = self.knn(inputs[:, 0:3])
        stage_outputs = [self.head(inputs, graph)]
        for idx in range(self.n_blocks - 1):
            stage_outputs.append(self.backbone[idx](stage_outputs[-1]))
        feats = torch.cat(stage_outputs, dim=1)

        dim2, dim3 = feats.shape[2], feats.shape[3]
        # Max-pool over the whole (dim2, dim3) extent, then repeat the pooled
        # vector along dim 2 so it can be concatenated with feats.
        pooled = torch.max_pool2d(self.fusion_block(feats), kernel_size=[dim2, dim3])
        pooled = torch.repeat_interleave(pooled, repeats=dim2, dim=2)

        logits = self.prediction(torch.cat((pooled, feats), dim=1))
        return logits.squeeze(-1)
def _str2bool(value):
    """Convert a command-line token to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``"False"``) becomes ``True``. This converter
    accepts true/false, t/f, yes/no, y/n and 1/0 (case-insensitive) and raises
    ``argparse.ArgumentTypeError`` for anything else.
    """
    import argparse

    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    if text in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


def _build_parser():
    """Build the argument parser for the smoke test."""
    import argparse

    parser = argparse.ArgumentParser(description='PyTorch implementation of Deep GCN For semantic segmentation')
    parser.add_argument('--in_channels', default=9, type=int, help='input channels (default:9)')
    parser.add_argument('--n_classes', default=13, type=int, help='num of segmentation classes (default:13)')
    parser.add_argument('--k', default=20, type=int, help='neighbor num (default:20)')
    parser.add_argument('--block', default='res', type=str, help='graph backbone block type {plain, res, dense}')
    parser.add_argument('--conv', default='edge', type=str, help='graph conv layer {edge, mr}')
    parser.add_argument('--act', default='relu', type=str, help='activation layer {relu, prelu, leakyrelu}')
    parser.add_argument('--norm', default='batch', type=str, help='{batch, instance} normalization')
    parser.add_argument('--bias', default=True, type=_str2bool, help='bias of conv layer True or False')
    parser.add_argument('--n_filters', default=64, type=int, help='number of channels of deep features')
    parser.add_argument('--n_blocks', default=7, type=int, help='number of basic blocks')
    parser.add_argument('--dropout', default=0.5, type=float, help='ratio of dropout')
    parser.add_argument('--epsilon', default=0.2, type=float, help='stochastic epsilon for gcn')
    parser.add_argument('--stochastic', default=False, type=_str2bool, help='stochastic for gcn, True or False')
    return parser


def _main(argv=None):
    """Build a DenseDeepGCN from CLI options and push one random batch through it.

    Args:
        argv: optional list of argument strings; ``None`` lets argparse read
            ``sys.argv``.
    """
    seed = 0
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    batch_size = 2
    num_points = 1024
    # Fall back to CPU so the smoke test also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    args = _build_parser().parse_args(argv)

    # Three position channels followed by six feature channels; this matches
    # the default --in_channels of 9.
    pos = torch.rand((batch_size, num_points, 3), dtype=torch.float).to(device)
    x = torch.rand((batch_size, num_points, 6), dtype=torch.float).to(device)
    # (batch, points, 9) -> (batch, 9, points, 1)
    inputs = torch.cat((pos, x), 2).transpose(1, 2).unsqueeze(-1)

    net = DenseDeepGCN(args).to(device)
    print(net)

    out = net(inputs)
    print(out.shape)


if __name__ == "__main__":
    _main()