-
Notifications
You must be signed in to change notification settings - Fork 19
/
batch_gen.py
125 lines (101 loc) · 4.97 KB
/
batch_gen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
'''
Adapted from https://github.com/yabufarha/ms-tcn
'''
import torch
import numpy as np
import random
from grid_sampler import GridSampler, TimeWarpLayer
class BatchGenerator(object):
    '''
    Iterate over a frame-wise labelled video dataset in padded mini-batches.

    Each video name listed by read_data() maps to a ground-truth text file
    (gt_path + name, one action label per line) and a feature array
    (features_path + base name + '.npy', frames along axis 1).
    next_batch() loads them, subsamples frames by sample_rate and pads the
    batch to its longest sequence.
    '''

    def __init__(self, num_classes, actions_dict, gt_path, features_path, sample_rate):
        '''
        :param num_classes: number of action classes; size of the mask's class axis
        :param actions_dict: maps a label string found in the gt files to its integer class id
        :param gt_path: prefix concatenated (as a string) with each video name to get its gt file
        :param features_path: prefix concatenated with the video's base name + '.npy' to get its features
        :param sample_rate: temporal stride used to subsample frames
        '''
        self.index = 0
        self.num_classes = num_classes
        self.actions_dict = actions_dict
        self.gt_path = gt_path
        self.features_path = features_path
        self.sample_rate = sample_rate
        self.timewarp_layer = TimeWarpLayer()
        # Parallel lists filled by read_data(); empty until then so that
        # has_next()/merge() work on a fresh generator.
        self.list_of_examples = []
        self.gts = []
        self.features = []

    def reset(self):
        '''Rewind to the first example and reshuffle the example order.'''
        self.index = 0
        self.my_shuffle()

    def has_next(self):
        '''Return True while there are examples not yet returned by next_batch().'''
        return self.index < len(self.list_of_examples)

    def read_data(self, vid_list_file):
        '''
        Load video names (one per line) from vid_list_file and build gt/feature paths.

        Blank lines are ignored. A last line without a trailing newline is kept.

        :param vid_list_file: path of a text file listing one video name per line
        '''
        with open(vid_list_file, 'r') as file_ptr:
            self.list_of_examples = [line for line in file_ptr.read().splitlines() if line.strip()]
        self.gts = [self.gt_path + vid for vid in self.list_of_examples]
        # Feature file name = text before the first '.' of the video name + '.npy'.
        self.features = [self.features_path + vid.split('.')[0] + '.npy' for vid in self.list_of_examples]
        self.my_shuffle()

    def my_shuffle(self):
        '''
        Shuffle list_of_examples, gts and features in place with one shared permutation.

        A single permutation is drawn from the module-level `random` generator and
        applied to all three lists, so entry k of each list still refers to the same
        video. The global RNG is not re-seeded.

        :raises ValueError: if the three lists do not have the same length
        '''
        n = len(self.list_of_examples)
        if len(self.gts) != n or len(self.features) != n:
            raise ValueError(
                'list_of_examples, gts and features must have the same length, got {}, {}, {}'.format(
                    n, len(self.gts), len(self.features)))
        order = list(range(n))
        random.shuffle(order)
        # Slice assignment keeps the same list objects (in place, like random.shuffle).
        self.list_of_examples[:] = [self.list_of_examples[k] for k in order]
        self.gts[:] = [self.gts[k] for k in order]
        self.features[:] = [self.features[k] for k in order]

    def warp_video(self, batch_input_tensor, batch_target_tensor):
        '''
        Warp features and labels along time with a grid from GridSampler(T).sample(bs).

        Used by next_batch(if_warp=True) as data augmentation.

        :param batch_input_tensor: (bs, C_in, L_in)
        :param batch_target_tensor: (bs, L_in)
        :return: warped input (bilinear sampling) and warped target (nearest sampling, long)
        '''
        bs, _, T = batch_input_tensor.shape
        grid_sampler = GridSampler(T)
        grid = grid_sampler.sample(bs)
        grid = torch.from_numpy(grid).float()
        # Cast the input to float32 to match the float32 grid (no-op for float32 features).
        # NOTE(review): assumes TimeWarpLayer needs input and grid of the same dtype - confirm.
        warped_batch_input_tensor = self.timewarp_layer(batch_input_tensor.float(), grid, mode='bilinear')
        # Labels are warped as a 1-channel signal. Nearest sampling keeps them
        # integral; bilinear would blend class ids.
        batch_target_tensor = batch_target_tensor.unsqueeze(1).float()
        warped_batch_target_tensor = self.timewarp_layer(batch_target_tensor, grid, mode='nearest')
        warped_batch_target_tensor = warped_batch_target_tensor.squeeze(1).long()  # back to (bs, L)
        return warped_batch_input_tensor, warped_batch_target_tensor

    def merge(self, bg, suffix):
        '''
        Append another BatchGenerator's examples to this one, e.g.
            a.merge(b, suffix='@1')

        :param bg: BatchGenerator whose examples, gt paths and feature paths are appended
        :param suffix: string appended to bg's video names to identify where they came from;
                       bg's gt/feature paths are appended unchanged
        '''
        # Build the new names first so merging a generator into itself cannot loop.
        self.list_of_examples += [vid + suffix for vid in bg.list_of_examples]
        self.gts += bg.gts
        self.features += bg.features
        print('Merge! Dataset length:{}'.format(len(self.list_of_examples)))

    def next_batch(self, batch_size, if_warp=False):
        '''
        Load the next batch_size videos and pad them to a common length.

        if_warp=True is a strong data augmentation (see grid_sampler.py for details).
        Features are cut to the number of labelled frames, so input and target always
        have the same length. Both are then subsampled by sample_rate.
        Call only while has_next() is True; an empty batch raises ValueError.

        :param batch_size: number of videos to take from the current position
        :param if_warp: if True, every video is passed through warp_video()
        :return: tuple of
                 - batch_input_tensor: (bs, C_in, L) float, zero-padded
                 - batch_target_tensor: (bs, L) long, padded with -100
                   (presumably the loss's ignore_index; PyTorch CrossEntropyLoss defaults to -100)
                 - mask: (bs, num_classes, L) float, 1 on valid frames and 0 on padding
                 - list of the video names in the batch
                 where L is the longest subsampled sequence in the batch.
        '''
        start = self.index
        stop = start + batch_size
        batch = self.list_of_examples[start:stop]
        batch_gts = self.gts[start:stop]
        batch_features = self.features[start:stop]
        self.index = stop

        batch_input = []
        batch_target = []
        for gt_file, feature_file in zip(batch_gts, batch_features):
            features = np.load(feature_file)  # frames along axis 1
            with open(gt_file, 'r') as file_ptr:
                content = file_ptr.read().splitlines()
            # Use only frames that have both a feature vector and a label.
            num_frames = min(features.shape[1], len(content))
            classes = np.array([self.actions_dict[label] for label in content[:num_frames]], dtype=np.int64)
            batch_input.append(features[:, :num_frames:self.sample_rate])
            batch_target.append(classes[::self.sample_rate])

        lengths = [len(target) for target in batch_target]
        max_len = max(lengths)
        bs = len(batch_input)
        batch_input_tensor = torch.zeros(bs, batch_input[0].shape[0], max_len, dtype=torch.float)  # bs, C_in, L_in
        batch_target_tensor = torch.full((bs, max_len), -100, dtype=torch.long)
        mask = torch.zeros(bs, self.num_classes, max_len, dtype=torch.float)
        for i, (feature, target, seq_len) in enumerate(zip(batch_input, batch_target, lengths)):
            feature_tensor = torch.from_numpy(feature)
            target_tensor = torch.from_numpy(target)
            if if_warp:
                warped_input, warped_target = self.warp_video(feature_tensor.unsqueeze(0), target_tensor.unsqueeze(0))
                feature_tensor, target_tensor = warped_input.squeeze(0), warped_target.squeeze(0)
            batch_input_tensor[i, :, :seq_len] = feature_tensor
            batch_target_tensor[i, :seq_len] = target_tensor
            mask[i, :, :seq_len] = 1
        return batch_input_tensor, batch_target_tensor, mask, batch
if __name__ == "__main__":
    # Nothing to run: this module only provides BatchGenerator for import.
    pass