-
Notifications
You must be signed in to change notification settings - Fork 11
/
training_data_loader.py
88 lines (68 loc) · 2.69 KB
/
training_data_loader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright (c) Meta, Inc. and its affiliates.
# Copyright (c) Stanford University
import os
import random
import time

import numpy as np
import torch
from torch.utils.data import Dataset
class TrainSubDataset(Dataset):
    """
    Randomly downsample the whole training (mainly AMASS) dataset so it fits in memory.

    Each data point for the Transformer is a time window of ``seq_length``
    frames; materializing every possible window would be too large. Instead,
    for each segment listed in ``info_path`` this draws, without replacement,
    ``max(round(num_valid_window_ends / down_sample_rate), 1)`` window end
    frames. After each epoch (downsampled data exhausted) you should construct
    a new instance, which re-draws the windows.

    Sampling uses the global ``random`` module; seed it outside this class.
    """

    def __init__(
        self,
        seq_length,
        info_path,
        imu_combine_path,
        s_combine_path,
        with_acc_sum=True,
        imu_sum_combine_path=None,
    ):
        """
        Load the concatenated arrays and materialize the sampled windows.

        Args:
            seq_length: number of frames per IMU / input-state window.
            info_path: .npy file with one ``[start_t, end_t, down_sample_rate]``
                row per segment. Entries are fed to ``range`` so they must be
                integers.
            imu_combine_path: .npy file of IMU features, sliced by frame along
                axis 0.
            s_combine_path: .npy file of state features, sliced by frame along
                axis 0.
            with_acc_sum: if True, also load the "sum_imu" array and
                concatenate it onto the IMU features in ``__getitem__``.
            imu_sum_combine_path: explicit path of the "sum_imu" array. When
                None, it is derived from ``imu_combine_path`` by replacing
                "imu" with "sum_imu" in the file name only (parent directories
                are left untouched).
        """
        start_time = time.time()
        IMU_c = np.load(imu_combine_path)
        S_c = np.load(s_combine_path)
        infos = list(np.load(info_path))
        if with_acc_sum:
            if imu_sum_combine_path is None:
                imu_sum_combine_path = self._default_sum_path(imu_combine_path)
            IMU_sum_c = np.load(imu_sum_combine_path)
        else:
            IMU_sum_c = None

        IMU = []
        S = []
        IMU_sum = []
        for info in infos:
            for t in self._sample_window_ends(info, seq_length):
                # IMU window: frames [t - seq_length, t)
                IMU.append(IMU_c[(t - seq_length):t, :])
                # State window keeps one extra frame (t) so __getitem__ can
                # split it into inputs and one-step-ahead targets.
                S.append(S_c[(t - seq_length):(t + 1), :])
                if with_acc_sum:
                    IMU_sum.append(IMU_sum_c[(t - seq_length):t, :])

        self.IMU = torch.from_numpy(np.array(IMU))
        self.IMU_sum = torch.from_numpy(np.array(IMU_sum))
        self.S = torch.from_numpy(np.array(S))
        if IMU:
            self.size = (self.IMU.shape[0], self.IMU.shape[1])
        else:
            # No window could be sampled: np.array([]) is 1-D, so there is no
            # shape[1]; report an empty dataset instead of raising IndexError.
            self.size = (0, seq_length)
        self.seq_length = seq_length
        self.with_acc_sum = with_acc_sum
        print("load time", time.time() - start_time)
        print("IMU shape", self.IMU.size())
        print("IMU sum shape", self.IMU_sum.size())
        print("S shape", self.S.size())

    @staticmethod
    def _default_sum_path(imu_combine_path):
        """Derive the "sum_imu" path by rewriting only the file-name component."""
        head, tail = os.path.split(imu_combine_path)
        return os.path.join(head, tail.replace("imu", "sum_imu"))

    @staticmethod
    def _sample_window_ends(info, seq_length):
        """Randomly draw window end frames ``t`` for one segment row."""
        # each info is [start_t, end_t, down sample]
        start_t, end_t, down_sample_rate = tuple(info)
        # Lower bound keeps t - seq_length >= start_t. Upper bound excludes
        # end_t - 1, so the last state frame read (index t) is at most end_t - 2.
        time_range = range(start_t + seq_length, end_t - 1)
        if len(time_range) == 0:
            return []
        # int() makes k a plain Python int regardless of what round() returns
        # for a NumPy scalar in the installed NumPy version.
        num_samples = max(int(round(len(time_range) / down_sample_rate)), 1)
        return random.sample(time_range, k=num_samples)

    def __getitem__(self, index):
        """
        Return ``(x_imu, x_s, y_s_n)`` for window ``index``.

        ``x_s`` is the stored state window without its last frame and
        ``y_s_n`` is the same window without its first frame, i.e. the
        inputs shifted one frame ahead. When ``with_acc_sum`` is set, the
        "sum_imu" window is concatenated onto ``x_imu`` along dim 1.
        """
        x_imu = self.IMU[index]
        x_s = self.S[index, :-1, :]
        y_s_n = self.S[index, 1:, :]
        if self.with_acc_sum:
            x_imu = torch.cat((x_imu, self.IMU_sum[index]), dim=1)
        return x_imu, x_s, y_s_n

    def __len__(self):
        """Number of sampled windows."""
        return self.size[0]