-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathbuilder.py
100 lines (84 loc) · 2.91 KB
/
builder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
# --------------------------------------------------------
# OctFormer: Octree-based Transformers for 3D Point Clouds
# Copyright (c) 2023 Peng-Shuai Wang <[email protected]>
# Licensed under The MIT License [see LICENSE for details]
# Written by Peng-Shuai Wang
# --------------------------------------------------------
import ocnn
import torch
import datasets
import models
def octsegformer_large(in_channels, out_channels, **kwargs):
    """Build the large OctFormer segmentation network.

    Relative to ``octsegformer`` this variant doubles every entry of
    ``channels`` and ``num_heads``; all other settings are identical.

    Args:
        in_channels: Input channel count (``flags.channel`` when built via
            ``get_segmentation_model``).
        out_channels: Output channel count (``flags.nout`` when built via
            ``get_segmentation_model``), forwarded to ``models.OctFormerSeg``.
        **kwargs: Accepted so every builder shares one call signature, and
            ignored. In particular a ``nempty`` keyword has no effect: the
            model is always built with ``nempty=True``.

    Returns:
        The ``models.OctFormerSeg`` instance.
    """
    config = dict(
        channels=[192, 384, 768, 768],
        num_blocks=[2, 2, 18, 2],
        num_heads=[12, 24, 48, 48],
        patch_size=32,
        dilation=4,
        drop_path=0.5,
        nempty=True,
        stem_down=2,
        head_up=2,
        fpn_channel=168,
        head_drop=[0.5, 0.5],
    )
    return models.OctFormerSeg(in_channels, out_channels, **config)
def octsegformer(in_channels, out_channels, **kwargs):
    """Build the default (base-size) OctFormer segmentation network.

    Args:
        in_channels: Input channel count (``flags.channel`` when built via
            ``get_segmentation_model``).
        out_channels: Output channel count (``flags.nout`` when built via
            ``get_segmentation_model``), forwarded to ``models.OctFormerSeg``.
        **kwargs: Accepted so every builder shares one call signature, and
            ignored. In particular a ``nempty`` keyword has no effect: the
            model is always built with ``nempty=True``.

    Returns:
        The ``models.OctFormerSeg`` instance.
    """
    config = dict(
        channels=[96, 192, 384, 384],
        num_blocks=[2, 2, 18, 2],
        num_heads=[6, 12, 24, 24],
        patch_size=32,
        dilation=4,
        drop_path=0.5,
        nempty=True,
        stem_down=2,
        head_up=2,
        fpn_channel=168,
        head_drop=[0.5, 0.5],
    )
    return models.OctFormerSeg(in_channels, out_channels, **config)
def octsegformer_small(in_channels, out_channels, **kwargs):
    """Build the small OctFormer segmentation network.

    Identical to ``octsegformer`` except that the third entry of
    ``num_blocks`` is 6 instead of 18.

    Args:
        in_channels: Input channel count (``flags.channel`` when built via
            ``get_segmentation_model``).
        out_channels: Output channel count (``flags.nout`` when built via
            ``get_segmentation_model``), forwarded to ``models.OctFormerSeg``.
        **kwargs: Accepted so every builder shares one call signature, and
            ignored. In particular a ``nempty`` keyword has no effect: the
            model is always built with ``nempty=True``.

    Returns:
        The ``models.OctFormerSeg`` instance.
    """
    config = dict(
        channels=[96, 192, 384, 384],
        num_blocks=[2, 2, 6, 2],
        num_heads=[6, 12, 24, 24],
        patch_size=32,
        dilation=4,
        drop_path=0.5,
        nempty=True,
        stem_down=2,
        head_up=2,
        fpn_channel=168,
        head_drop=[0.5, 0.5],
    )
    return models.OctFormerSeg(in_channels, out_channels, **config)
def octsegformer_cls(in_channels, out_channels, nempty=None, **kwargs):
    """Build the OctFormer classification network.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count, forwarded to
            ``models.OctFormerCls``.
        nempty: Forwarded unchanged as the ``nempty`` argument of
            ``models.OctFormerCls``. Required. It may be passed positionally,
            as ``nempty=``, or via the legacy misspelled keyword ``nemtpy=``
            (the original name of this parameter).
        **kwargs: Any other extra keyword arguments are ignored.

    Returns:
        The ``models.OctFormerCls`` instance.

    Raises:
        TypeError: If no ``nempty`` value is supplied, or if values are
            supplied for both ``nempty`` and ``nemtpy``.
    """
    # The parameter used to be spelled ``nemtpy``; keep honouring that
    # keyword so callers written against the old signature still work.
    if 'nemtpy' in kwargs:
        legacy_value = kwargs.pop('nemtpy')
        if nempty is not None:
            raise TypeError(
                "octsegformer_cls() got values for both 'nempty' and 'nemtpy'")
        nempty = legacy_value
    if nempty is None:
        # Mirrors the TypeError the old required positional parameter gave.
        raise TypeError("octsegformer_cls() missing required argument 'nempty'")
    return models.OctFormerCls(
        in_channels, out_channels,
        channels=[96, 192],
        num_blocks=[6, 6],
        num_heads=[6, 12],
        patch_size=32, dilation=2,
        drop_path=0.5, nempty=nempty,
        stem_down=2, head_drop=0.5)
def get_segmentation_model(flags):
    """Instantiate the segmentation network selected by ``flags.name``.

    ``flags.name`` is lower-cased and looked up among ``octsegformer``,
    ``octsegformer_large`` and ``octsegformer_small``.

    Args:
        flags: Config object; the attributes ``channel``, ``nout``,
            ``interp``, ``nempty`` and ``name`` are read.

    Returns:
        Whatever the selected builder returns.

    Raises:
        KeyError: If the lower-cased ``flags.name`` is not a known builder.
    """
    # ``interp`` and ``nempty`` are forwarded, but the OctFormer segmentation
    # builders absorb them via ``**kwargs`` without using them.
    builder_kwargs = dict(
        in_channels=flags.channel,
        out_channels=flags.nout,
        interp=flags.interp,
        nempty=flags.nempty,
    )
    registry = {
        'octsegformer': octsegformer,
        'octsegformer_large': octsegformer_large,
        'octsegformer_small': octsegformer_small,
    }
    build = registry[flags.name.lower()]
    return build(**builder_kwargs)
def get_classification_model(flags):
    """Instantiate the classification network selected by ``flags.name``.

    Supported names (matched case-insensitively): ``lenet``, ``hrnet`` and
    ``octformercls``.

    Args:
        flags: Config object; reads ``name``, ``channel``, ``nout``,
            ``nempty``, and (for LeNet/HRNet only) ``stages``.

    Returns:
        The constructed model.

    Raises:
        ValueError: If ``flags.name`` is not one of the supported names.
    """
    name = flags.name.lower()
    if name == 'lenet':
        return ocnn.models.LeNet(
            flags.channel, flags.nout, flags.stages, flags.nempty)
    if name == 'hrnet':
        return ocnn.models.HRNet(
            flags.channel, flags.nout, flags.stages, nempty=flags.nempty)
    if name == 'octformercls':
        # The OctFormer classifier has fixed stages; ``flags.stages`` is unused.
        return octsegformer_cls(flags.channel, flags.nout, flags.nempty)
    # Previously a bare ``ValueError`` with no hint of what went wrong.
    raise ValueError(
        f"Unknown classification model {flags.name!r}; "
        "expected one of 'lenet', 'hrnet', 'octformercls'.")
def get_segmentation_dataset(flags):
    """Return the segmentation dataset chosen by ``flags.name``.

    Supported names (matched case-insensitively): ``shapenet``, ``scannet``
    and ``kitti``. The whole ``flags`` object is passed on to the matching
    ``datasets`` factory, and that factory's result is returned unchanged.

    Args:
        flags: Config object; ``name`` selects the dataset.

    Returns:
        Whatever the selected ``datasets`` factory returns.

    Raises:
        ValueError: If ``flags.name`` is not one of the supported names.
    """
    name = flags.name.lower()
    if name == 'shapenet':
        return datasets.get_shapenet_seg_dataset(flags)
    if name == 'scannet':
        return datasets.get_scannet_dataset(flags)
    if name == 'kitti':
        return datasets.get_kitti_dataset(flags)
    # Previously a bare ``ValueError`` with no hint of what went wrong.
    raise ValueError(
        f"Unknown segmentation dataset {flags.name!r}; "
        "expected one of 'shapenet', 'scannet', 'kitti'.")