-
Notifications
You must be signed in to change notification settings - Fork 150
/
ts2vec.py
319 lines (264 loc) · 14 KB
/
ts2vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
import torch
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import numpy as np
from models import TSEncoder
from models.losses import hierarchical_contrastive_loss
from utils import take_per_row, split_with_nan, centerize_vary_length_series, torch_pad_nan
import math
class TS2Vec:
    '''The TS2Vec model.

    Wraps a ``TSEncoder`` (kept in ``self._net`` and optimised during ``fit``)
    together with a ``torch.optim.swa_utils.AveragedModel`` copy of it
    (``self.net``) that is refreshed after every optimiser step. Inference
    (``encode``) and (de)serialisation (``save`` / ``load``) use the averaged
    model.
    '''

    def __init__(
        self,
        input_dims,
        output_dims=320,
        hidden_dims=64,
        depth=10,
        device='cuda',
        lr=0.001,
        batch_size=16,
        max_train_length=None,
        temporal_unit=0,
        after_iter_callback=None,
        after_epoch_callback=None
    ):
        ''' Initialize a TS2Vec model.

        Args:
            input_dims (int): The input dimension. For a univariate time series, this should be set to 1.
            output_dims (int): The representation dimension.
            hidden_dims (int): The hidden dimension of the encoder.
            depth (int): The number of hidden residual blocks in the encoder.
            device (Union[str, torch.device]): The device used for training and inference (anything accepted by ``Module.to``).
            lr (float): The learning rate.
            batch_size (int): The batch size.
            max_train_length (Union[int, NoneType]): The maximum allowed sequence length for training. A sequence longer than <max_train_length> is split into several sequences before training, and batches that are still longer are randomly windowed to <max_train_length>.
            temporal_unit (int): The minimum unit to perform temporal contrast. When training on a very long sequence, this param helps to reduce the cost of time and memory.
            after_iter_callback (Union[Callable, NoneType]): A callback called as ``callback(model, loss)`` after each iteration.
            after_epoch_callback (Union[Callable, NoneType]): A callback called as ``callback(model, epoch_loss)`` after each epoch.
        '''
        super().__init__()
        self.device = device
        self.lr = lr
        self.batch_size = batch_size
        self.max_train_length = max_train_length
        self.temporal_unit = temporal_unit

        # `_net` is the network that receives gradient updates; `net` is its
        # running average and is what `encode`, `save` and `load` operate on.
        self._net = TSEncoder(input_dims=input_dims, output_dims=output_dims, hidden_dims=hidden_dims, depth=depth).to(self.device)
        self.net = torch.optim.swa_utils.AveragedModel(self._net)
        self.net.update_parameters(self._net)

        self.after_iter_callback = after_iter_callback
        self.after_epoch_callback = after_epoch_callback

        # Progress counters persist across successive `fit` calls.
        self.n_epochs = 0
        self.n_iters = 0

    def fit(self, train_data, n_epochs=None, n_iters=None, verbose=False):
        ''' Training the TS2Vec model.

        Args:
            train_data (numpy.ndarray): The training data. It should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN.
            n_epochs (Union[int, NoneType]): The number of epochs. When this reaches, the training stops.
            n_iters (Union[int, NoneType]): The number of iterations. When this reaches, the training stops. If both n_epochs and n_iters are not specified, a default setting would be used that sets n_iters to 200 for a dataset with size <= 100000, 600 otherwise.
            verbose (bool): Whether to print the training loss after each epoch.

        Returns:
            loss_log: a list containing the mean training loss of each completed epoch.
        '''
        assert train_data.ndim == 3

        if n_iters is None and n_epochs is None:
            n_iters = 200 if train_data.size <= 100000 else 600  # default param for n_iters

        # Split over-long series along the time axis into extra instances.
        if self.max_train_length is not None:
            sections = train_data.shape[1] // self.max_train_length
            if sections >= 2:
                train_data = np.concatenate(split_with_nan(train_data, sections, axis=1), axis=0)

        # If any instance is entirely NaN at the first or last timestamp,
        # re-center the variable-length series.
        temporal_missing = np.isnan(train_data).all(axis=-1).any(axis=0)
        if temporal_missing[0] or temporal_missing[-1]:
            train_data = centerize_vary_length_series(train_data)

        # Drop instances that contain nothing but NaN.
        train_data = train_data[~np.isnan(train_data).all(axis=2).all(axis=1)]

        train_dataset = TensorDataset(torch.from_numpy(train_data).to(torch.float))
        train_loader = DataLoader(train_dataset, batch_size=min(self.batch_size, len(train_dataset)), shuffle=True, drop_last=True)

        optimizer = torch.optim.AdamW(self._net.parameters(), lr=self.lr)

        loss_log = []

        while True:
            if n_epochs is not None and self.n_epochs >= n_epochs:
                break

            cum_loss = 0
            n_epoch_iters = 0

            interrupted = False
            for batch in train_loader:
                if n_iters is not None and self.n_iters >= n_iters:
                    interrupted = True
                    break

                x = batch[0]
                if self.max_train_length is not None and x.size(1) > self.max_train_length:
                    window_offset = np.random.randint(x.size(1) - self.max_train_length + 1)
                    x = x[:, window_offset : window_offset + self.max_train_length]
                x = x.to(self.device)

                # Sample two overlapping crops:
                #   view 1 covers [crop_eleft, crop_right)
                #   view 2 covers [crop_left, crop_eright)
                # They share the segment [crop_left, crop_right) of length
                # crop_l. A per-row offset shifts both crops together.
                ts_l = x.size(1)
                crop_l = np.random.randint(low=2 ** (self.temporal_unit + 1), high=ts_l+1)
                crop_left = np.random.randint(ts_l - crop_l + 1)
                crop_right = crop_left + crop_l
                crop_eleft = np.random.randint(crop_left + 1)
                crop_eright = np.random.randint(low=crop_right, high=ts_l + 1)
                crop_offset = np.random.randint(low=-crop_eleft, high=ts_l - crop_eright + 1, size=x.size(0))

                optimizer.zero_grad()

                # Keep only the representations of the shared segment:
                # the tail of view 1 and the head of view 2.
                out1 = self._net(take_per_row(x, crop_offset + crop_eleft, crop_right - crop_eleft))
                out1 = out1[:, -crop_l:]

                out2 = self._net(take_per_row(x, crop_offset + crop_left, crop_eright - crop_left))
                out2 = out2[:, :crop_l]

                loss = hierarchical_contrastive_loss(
                    out1,
                    out2,
                    temporal_unit=self.temporal_unit
                )

                loss.backward()
                optimizer.step()
                self.net.update_parameters(self._net)

                cum_loss += loss.item()
                n_epoch_iters += 1

                self.n_iters += 1

                if self.after_iter_callback is not None:
                    self.after_iter_callback(self, loss.item())

            # An epoch cut short by the iteration budget is not logged.
            if interrupted:
                break

            cum_loss /= n_epoch_iters
            loss_log.append(cum_loss)
            if verbose:
                print(f"Epoch #{self.n_epochs}: loss={cum_loss}")
            self.n_epochs += 1

            if self.after_epoch_callback is not None:
                self.after_epoch_callback(self, cum_loss)

        return loss_log

    def _eval_with_pooling(self, x, mask=None, slicing=None, encoding_window=None):
        ''' Run the averaged encoder on `x` and apply the requested temporal pooling.

        Args:
            x (torch.Tensor): Input batch; moved to `self.device` before the forward pass.
            mask: Forwarded unchanged as the second argument of the encoder.
            slicing (Union[slice, NoneType]): Optional slice applied on the time axis (dim 1).
            encoding_window (Union[str, int, NoneType]): 'full_series', 'multiscale', an int kernel size, or None for no pooling.

        Returns:
            torch.Tensor: The (pooled) representations on the CPU.
        '''
        out = self.net(x.to(self.device, non_blocking=True), mask)
        if encoding_window == 'full_series':
            # Slice first, then collapse the whole remaining time axis to length 1.
            if slicing is not None:
                out = out[:, slicing]
            out = F.max_pool1d(
                out.transpose(1, 2),
                kernel_size = out.size(1),
            ).transpose(1, 2)

        elif isinstance(encoding_window, int):
            out = F.max_pool1d(
                out.transpose(1, 2),
                kernel_size = encoding_window,
                stride = 1,
                padding = encoding_window // 2
            ).transpose(1, 2)
            # An even kernel with this padding yields one extra timestamp.
            if encoding_window % 2 == 0:
                out = out[:, :-1]
            if slicing is not None:
                out = out[:, slicing]

        elif encoding_window == 'multiscale':
            # Concatenate max-pooled features for kernel sizes 3, 5, 9, ...
            # along the feature axis.
            p = 0
            reprs = []
            while (1 << p) + 1 < out.size(1):
                t_out = F.max_pool1d(
                    out.transpose(1, 2),
                    kernel_size = (1 << (p + 1)) + 1,
                    stride = 1,
                    padding = 1 << p
                ).transpose(1, 2)
                if slicing is not None:
                    t_out = t_out[:, slicing]
                reprs.append(t_out)
                p += 1
            out = torch.cat(reprs, dim=-1)

        else:
            if slicing is not None:
                out = out[:, slicing]

        return out.cpu()

    def encode(self, data, mask=None, encoding_window=None, causal=False, sliding_length=None, sliding_padding=0, batch_size=None):
        ''' Compute representations using the model.

        Args:
            data (numpy.ndarray): This should have a shape of (n_instance, n_timestamps, n_features). All missing data should be set to NaN.
            mask (str): The mask used by encoder can be specified with this parameter. This can be set to 'binomial', 'continuous', 'all_true', 'all_false' or 'mask_last'.
            encoding_window (Union[str, int]): When this param is specified, the computed representation would the max pooling over this window. This can be set to 'full_series', 'multiscale' or an integer specifying the pooling kernel size.
            causal (bool): When this param is set to True, the future informations would not be encoded into representation of each timestamp.
            sliding_length (Union[int, NoneType]): The length of sliding window. When this param is specified, a sliding inference would be applied on the time series.
            sliding_padding (int): This param specifies the contextual data length used for inference every sliding windows.
            batch_size (Union[int, NoneType]): The batch size used for inference. If not specified, this would be the same batch size as training.

        Returns:
            repr (numpy.ndarray): The representations for data. With
                encoding_window='full_series' the time axis is removed, giving
                one vector per instance, both with and without sliding
                inference.
        '''
        assert self.net is not None, 'please train or load a net first'
        assert data.ndim == 3
        if batch_size is None:
            batch_size = self.batch_size
        n_samples, ts_l, _ = data.shape

        org_training = self.net.training
        self.net.eval()

        dataset = TensorDataset(torch.from_numpy(data).to(torch.float))
        loader = DataLoader(dataset, batch_size=batch_size)

        try:
            with torch.no_grad():
                output = []
                for batch in loader:
                    x = batch[0]
                    if sliding_length is not None:
                        # Part of each padded window that belongs to the
                        # window itself (the left context is dropped).
                        window_slice = slice(sliding_padding, sliding_padding+sliding_length)
                        reprs = []
                        # When the whole dataset is smaller than one batch,
                        # several windows are stacked along the batch axis
                        # and evaluated in a single forward pass.
                        if n_samples < batch_size:
                            calc_buffer = []
                            calc_buffer_l = 0
                        for i in range(0, ts_l, sliding_length):
                            l = i - sliding_padding
                            r = i + sliding_length + (sliding_padding if not causal else 0)
                            # Windows reaching past either end of the series
                            # are padded so every window has the same length.
                            x_sliding = torch_pad_nan(
                                x[:, max(l, 0) : min(r, ts_l)],
                                left=-l if l<0 else 0,
                                right=r-ts_l if r>ts_l else 0,
                                dim=1
                            )
                            if n_samples < batch_size:
                                if calc_buffer_l + n_samples > batch_size:
                                    out = self._eval_with_pooling(
                                        torch.cat(calc_buffer, dim=0),
                                        mask,
                                        slicing=window_slice,
                                        encoding_window=encoding_window
                                    )
                                    reprs += torch.split(out, n_samples)
                                    calc_buffer = []
                                    calc_buffer_l = 0
                                calc_buffer.append(x_sliding)
                                calc_buffer_l += n_samples
                            else:
                                out = self._eval_with_pooling(
                                    x_sliding,
                                    mask,
                                    slicing=window_slice,
                                    encoding_window=encoding_window
                                )
                                reprs.append(out)

                        # Flush whatever is left in the buffer.
                        if n_samples < batch_size:
                            if calc_buffer_l > 0:
                                out = self._eval_with_pooling(
                                    torch.cat(calc_buffer, dim=0),
                                    mask,
                                    slicing=window_slice,
                                    encoding_window=encoding_window
                                )
                                reprs += torch.split(out, n_samples)
                                calc_buffer = []
                                calc_buffer_l = 0

                        out = torch.cat(reprs, dim=1)
                        if encoding_window == 'full_series':
                            # Pool the per-window vectors over the window
                            # axis. max_pool1d returns (batch, channels, 1),
                            # so the pooled axis to drop is the LAST one;
                            # squeezing dim 1 (channels) left a stray
                            # trailing axis in the result.
                            out = F.max_pool1d(
                                out.transpose(1, 2).contiguous(),
                                kernel_size = out.size(1),
                            ).squeeze(2)
                    else:
                        out = self._eval_with_pooling(x, mask, encoding_window=encoding_window)
                        if encoding_window == 'full_series':
                            out = out.squeeze(1)

                    output.append(out)

                output = torch.cat(output, dim=0)
        finally:
            # Restore the previous train/eval mode even if inference fails.
            self.net.train(org_training)

        return output.numpy()

    def save(self, fn):
        ''' Save the model to a file.

        Only the state dict of the averaged model (`self.net`) is written.

        Args:
            fn (str): filename.
        '''
        torch.save(self.net.state_dict(), fn)

    def load(self, fn):
        ''' Load the model from a file.

        The state dict is mapped onto `self.device` and loaded into the
        averaged model (`self.net`).

        Args:
            fn (str): filename.
        '''
        state_dict = torch.load(fn, map_location=self.device)
        self.net.load_state_dict(state_dict)