# shapenet_part_loader.py
import torch.utils.data as data
import os
import os.path
import torch
import json
import numpy as np
import sys


def pc_normalize(pc):
    """Center a point cloud at the origin and scale it to fit the unit sphere."""
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc ** 2, axis=1)))
    pc = pc / m
    return pc


class PartDataset(data.Dataset):
    def __init__(self, root='', npoints=2500, classification=False, class_choice=None, split='train', normalize=True):
        assert len(root) > 0, "No data root specified"
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        self.cat = {}
        self.classification = classification
        self.normalize = normalize
        # Each line of synsetoffset2category.txt maps a category name to its
        # synset offset (directory name), e.g. 'Airplane 02691156'.
        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]
        if class_choice is not None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}
            print("Chosen Categories: {}".format(str(self.cat)))
        self.meta = {}
        # The official split files list entries as 'shape_data/<synset>/<token>',
        # so the third '/'-separated field is the model token.
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_train_file_list.json'), 'r') as f:
            train_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_val_file_list.json'), 'r') as f:
            val_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        with open(os.path.join(self.root, 'train_test_split', 'shuffled_test_file_list.json'), 'r') as f:
            test_ids = set([str(d.split('/')[2]) for d in json.load(f)])
        for item in self.cat:
            self.meta[item] = []
            dir_point = os.path.join(self.root, self.cat[item], 'points')
            dir_seg = os.path.join(self.root, self.cat[item], 'points_label')
            fns = sorted(os.listdir(dir_point))
            # Keep only the files whose token (filename without the 4-character
            # extension) belongs to the requested split.
            if split == 'trainval':
                fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))]
            elif split == 'train':
                fns = [fn for fn in fns if fn[0:-4] in train_ids]
            elif split == 'val':
                fns = [fn for fn in fns if fn[0:-4] in val_ids]
            elif split == 'test':
                fns = [fn for fn in fns if fn[0:-4] in test_ids]
            else:
                print('Fail: Unknown ShapeNet split: %s' % split)
                sys.exit(-1)
            for fn in fns:
                token = os.path.splitext(os.path.basename(fn))[0]
                self.meta[item].append((os.path.join(dir_point, token + '.pts'),
                                        os.path.join(dir_seg, token + '.seg'),
                                        self.cat[item], token))
        self.datapath = []
        for item in self.cat:
            for fn in self.meta[item]:
                self.datapath.append((item, fn[0], fn[1], fn[2], fn[3]))
        self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
        print(self.classes)
        self.num_seg_classes = 0
        if not self.classification:
            # Estimate the number of part labels by scanning a 1/50 subsample
            # of the .seg files and keeping the largest label count seen.
            for i in range(len(self.datapath) // 50):
                l = len(np.unique(np.loadtxt(self.datapath[i][2]).astype(np.uint8)))
                if l > self.num_seg_classes:
                    self.num_seg_classes = l
        self.cache = {}  # from index to (point_set, seg, cls, foldername, filename) tuple
        self.cache_size = 18000

    def __getitem__(self, index):
        if index in self.cache:
            point_set, seg, cls, foldername, filename = self.cache[index]
        else:
            fn = self.datapath[index]
            cls = self.classes[self.datapath[index][0]]
            point_set = np.loadtxt(fn[1]).astype(np.float32)
            if self.normalize:
                point_set = pc_normalize(point_set)
            # Part labels in the .seg files are 1-based; shift them to 0-based.
            seg = np.loadtxt(fn[2]).astype(np.int64) - 1
            foldername = fn[3]
            filename = fn[4]
            if len(self.cache) < self.cache_size:
                self.cache[index] = (point_set, seg, cls, foldername, filename)
        # Resample with replacement to a fixed number of points per shape.
        choice = np.random.choice(len(seg), self.npoints, replace=True)
        point_set = point_set[choice, :]
        seg = seg[choice]
        point_set = torch.from_numpy(point_set)
        seg = torch.from_numpy(seg)
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))
        if self.classification:
            return point_set, cls
        else:
            return point_set, seg, cls

    def __len__(self):
        return len(self.datapath)
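

# A minimal usage sketch, not part of the original loader. It assumes the
# ShapeNetPart benchmark has been extracted to the (hypothetical) path below,
# with the standard synsetoffset2category.txt and train_test_split layout.
if __name__ == '__main__':
    dataset = PartDataset(root='./shapenetcore_partanno_segmentation_benchmark_v0',
                          npoints=2500, classification=False,
                          class_choice=['Airplane'], split='train')
    print('dataset size:', len(dataset))
    points, seg, cls = dataset[0]
    print(points.shape, seg.shape, cls.shape)  # [2500, 3], [2500], [1]

    loader = data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=2)
    for points, seg, cls in loader:
        print(points.shape, seg.shape, cls.shape)  # [8, 2500, 3], [8, 2500], [8, 1]
        break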