-
Notifications
You must be signed in to change notification settings - Fork 2
/
data_utils.py
101 lines (79 loc) · 3.05 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# @Time : Jan. 10, 2019 15:26
# @Author : Veritas YIN
# @FileName : data_utils.py
# @Version : 1.0
# @IDE : PyCharm
# @Github : https://github.com/VeritasYin/Project_Orion
from utils.math_utils import z_score
import numpy as np
import pandas as pd
class Dataset:
    """Hold named data splits together with the normalisation statistics.

    Args:
        data: mapping from a split name (e.g. 'train', 'val', 'test') to
            the stored array for that split.
        stats: mapping providing at least the keys 'mean' and 'std'.
    """

    def __init__(self, data, stats):
        # Name-mangled (stored as ``_Dataset__data``) to discourage direct access.
        self.__data = data
        self.mean = stats['mean']
        self.std = stats['std']

    def get_data(self, type):
        """Return the stored array for split ``type``."""
        split = self.__data[type]
        return split

    def get_stats(self):
        """Return a fresh dict with the 'mean' and 'std' used for normalisation."""
        return dict(mean=self.mean, std=self.std)

    def get_len(self, type):
        """Return ``len()`` of the stored split ``type``."""
        split = self.__data[type]
        return len(split)

    def z_inverse(self, type):
        """Undo z-score scaling on split ``type``: ``data * std + mean``."""
        scaled = self.__data[type] * self.std
        return scaled + self.mean
def seq_gen(len_seq, data_seq, offset, n_frame, n_route, day_slot, C_0=1):
    '''
    Cut a long 2-D series into overlapping fixed-length windows, day by day.

    For each of ``len_seq`` consecutive days starting at day ``offset``, every
    window of ``n_frame`` consecutive rows that fits inside that day (a day is
    ``day_slot`` rows long) is copied out and reshaped to
    ``[n_frame, n_route, C_0]``.

    :param len_seq: number of days to process.
    :param data_seq: 2-D array indexed as ``data_seq[row, :]``.
    :param offset: index of the first day to use.
    :param n_frame: number of rows per window.
    :param n_route: second output dimension of each window.
    :param day_slot: number of rows that make up one day.
    :param C_0: trailing (channel) output dimension, default 1.
    :return: np.ndarray of shape
        ``[len_seq * (day_slot - n_frame + 1), n_frame, n_route, C_0]``.
    '''
    windows_per_day = day_slot - n_frame + 1
    out = np.zeros((len_seq * windows_per_day, n_frame, n_route, C_0))
    row = 0  # equals day_idx * windows_per_day + shift at every write
    for day_idx in range(len_seq):
        day_start = (day_idx + offset) * day_slot
        for shift in range(windows_per_day):
            begin = day_start + shift
            window = data_seq[begin:begin + n_frame, :]
            out[row] = np.reshape(window, [n_frame, n_route, C_0])
            row += 1
    return out
def data_gen(file_path, data_config, n_route, is_csv, n_frame=21, day_slot=288):
    '''
    Load a series from disk and build z-score normalised train/val/test splits.

    The raw series is read either from a header-less CSV file or from an HDF
    file. It is windowed with ``seq_gen`` into consecutive day blocks: the
    first ``n_train`` days become the training set, the next ``n_val`` days the
    validation set and the following ``n_test`` days the test set. All three
    splits are normalised with the mean/std of the training windows.

    :param file_path: path of the input file.
    :param data_config: tuple ``(n_train, n_val, n_test)`` giving the number of
        days in each split.
    :param n_route: second dimension each window is reshaped to.
    :param is_csv: True to read ``file_path`` with ``pd.read_csv``, False to
        read it with ``pd.read_hdf``.
    :param n_frame: number of rows per window, default 21.
    :param day_slot: number of rows per day, default 288.
    :return: Dataset holding the 'train'/'val'/'test' arrays and the training
        stats.
    :raises FileNotFoundError: if ``file_path`` does not exist. The error is
        reported on stdout and then re-raised.
    '''
    n_train, n_val, n_test = data_config
    # generate training, validation and test data
    try:
        if is_csv:
            data_seq = pd.read_csv(file_path, header=None).values
        else:
            # NOTE(review): HDF input is capped at the first 51840 rows
            # (= 180 * 288, presumably 180 days at the default day_slot) —
            # confirm this cap is intended for all HDF inputs.
            data_seq = pd.read_hdf(file_path).values[:51840]
    except FileNotFoundError:
        print(f'ERROR: input file was not found in {file_path}.')
        # Re-raise: continuing would crash below with an UnboundLocalError on
        # `data_seq`, hiding the real cause.
        raise

    seq_train = seq_gen(n_train, data_seq, 0, n_frame, n_route, day_slot)
    seq_val = seq_gen(n_val, data_seq, n_train, n_frame, n_route, day_slot)
    seq_test = seq_gen(n_test, data_seq, n_train + n_val, n_frame, n_route, day_slot)

    # x_stats: dict, the stats for the train dataset, including the value of mean and standard deviation.
    x_stats = {'mean': np.mean(seq_train), 'std': np.std(seq_train)}

    # x_train, x_val, x_test: np.array, [sample_size, n_frame, n_route, channel_size].
    # Validation/test reuse the training statistics.
    x_train = z_score(seq_train, x_stats['mean'], x_stats['std'])
    x_val = z_score(seq_val, x_stats['mean'], x_stats['std'])
    x_test = z_score(seq_test, x_stats['mean'], x_stats['std'])

    x_data = {'train': x_train, 'val': x_val, 'test': x_test}
    dataset = Dataset(x_data, x_stats)
    return dataset
def gen_batch(inputs, batch_size, dynamic_batch=False, shuffle=False):
    '''
    Lazily yield consecutive mini-batches of ``inputs``.

    :param inputs: sequence supporting ``len()`` and slicing. When
        ``shuffle`` is True it is indexed with an integer array, so it is
        presumably a NumPy array — confirm against callers.
    :param batch_size: number of items per batch.
    :param dynamic_batch: if True, the final shorter batch is yielded; if
        False, it is dropped.
    :param shuffle: if True, items are taken in an order produced by
        ``np.random.shuffle`` (drawn from the global NumPy RNG when iteration
        starts).
    :return: generator of batches ``inputs[...]``.
    '''
    total = len(inputs)
    order = None
    if shuffle:
        order = np.arange(total)
        np.random.shuffle(order)

    for begin in range(0, total, batch_size):
        stop = begin + batch_size
        if stop > total and not dynamic_batch:
            # Incomplete trailing batch is discarded unless dynamic_batch.
            break
        stop = min(stop, total)
        yield inputs[order[begin:stop]] if shuffle else inputs[begin:stop]