-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
Nonstationary_Transformer.py
218 lines (188 loc) · 9.49 KB
/
Nonstationary_Transformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import torch
import torch.nn as nn
from layers.Transformer_EncDec import Decoder, DecoderLayer, Encoder, EncoderLayer
from layers.SelfAttention_Family import DSAttention, AttentionLayer
from layers.Embed import DataEmbedding
import torch.nn.functional as F
class Projector(nn.Module):
    '''
    MLP to learn the De-stationary factors
    Paper link: https://openreview.net/pdf?id=ucNDIDRNjjv

    A 1-D circular convolution (treating the seq_len time steps as input
    channels) collapses the raw series to B x 1 x E, which is concatenated
    with a per-series statistic (B x 1 x E) and fed through an MLP that
    produces ``output_dim`` values per sample.

    Args:
        enc_in: number of variables E.
        seq_len: input length S (used as the conv's input channels).
        hidden_dims: hidden widths of the MLP; must have at least
            ``hidden_layers`` entries.
        hidden_layers: number of hidden Linear+ReLU layers.
        output_dim: size O of the output vector.
        kernel_size: conv kernel size. Must be odd so that the circular
            padding keeps the length E needed for the concatenation.

    Raises:
        ValueError: if ``kernel_size`` is even.
    '''

    def __init__(self, enc_in, seq_len, hidden_dims, hidden_layers, output_dim, kernel_size=3):
        super(Projector, self).__init__()
        if kernel_size % 2 == 0:
            # An even kernel can never produce an output of length E with
            # symmetric padding, so forward() would fail on torch.cat anyway.
            raise ValueError(f"kernel_size must be odd, got {kernel_size}")
        self.series_conv = nn.Conv1d(in_channels=seq_len, out_channels=1, kernel_size=kernel_size,
                                     padding=self._circular_padding(kernel_size),
                                     padding_mode='circular', bias=False)

        layers = [nn.Linear(2 * enc_in, hidden_dims[0]), nn.ReLU()]
        for i in range(hidden_layers - 1):
            layers += [nn.Linear(hidden_dims[i], hidden_dims[i + 1]), nn.ReLU()]
        layers += [nn.Linear(hidden_dims[-1], output_dim, bias=False)]
        self.backbone = nn.Sequential(*layers)

    @staticmethod
    def _circular_padding(kernel_size):
        """Return the Conv1d ``padding`` that keeps the length unchanged for an odd kernel."""
        import re

        half = (kernel_size - 1) // 2
        # Compare numeric (major, minor) parts instead of raw version strings,
        # which order e.g. '1.10.0' before '1.5.0' lexicographically.
        parts = re.findall(r'\d+', str(torch.__version__))[:2]
        if len(parts) < 2:
            return half
        # torch < 1.5 halved the requested circular padding internally, so the
        # original code passed twice the amount (2 for kernel_size=3).
        return half if tuple(int(p) for p in parts) >= (1, 5) else 2 * half

    def forward(self, x, stats):
        """Map a raw series and its statistic to the learned factor.

        Args:
            x: B x S x E raw input series.
            stats: B x 1 x E per-series statistic (mean or std).

        Returns:
            B x O tensor.
        """
        batch_size = x.shape[0]
        x = self.series_conv(x)  # B x 1 x E
        x = torch.cat([x, stats], dim=1)  # B x 2 x E
        x = x.view(batch_size, -1)  # B x 2E
        return self.backbone(x)  # B x O
class Model(nn.Module):
    """
    Non-stationary Transformer.

    Paper link: https://openreview.net/pdf?id=ucNDIDRNjjv

    Each input series is normalized over the time axis; the de-stationary
    factors tau (B x 1, positive) and delta (B x S) are learned from the raw
    series by two Projector MLPs and passed to every DSAttention layer.
    The heads that are built depend on ``configs.task_name``:
    'long_term_forecast' / 'short_term_forecast' (decoder),
    'imputation' / 'anomaly_detection' (linear projection) or
    'classification' (flatten + linear projection).
    """

    def __init__(self, configs):
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.pred_len = configs.pred_len
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len

        # Embedding
        self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
                                           configs.dropout)

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        DSAttention(False, configs.factor, attention_dropout=configs.dropout,
                                    output_attention=False), configs.d_model, configs.n_heads),
                    configs.d_model,
                    configs.d_ff,
                    dropout=configs.dropout,
                    activation=configs.activation
                ) for _ in range(configs.e_layers)
            ],
            norm_layer=torch.nn.LayerNorm(configs.d_model)
        )

        # Decoder (forecasting tasks only)
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            self.dec_embedding = DataEmbedding(configs.dec_in, configs.d_model, configs.embed, configs.freq,
                                               configs.dropout)
            self.decoder = Decoder(
                [
                    DecoderLayer(
                        AttentionLayer(
                            DSAttention(True, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=False),
                            configs.d_model, configs.n_heads),
                        AttentionLayer(
                            DSAttention(False, configs.factor, attention_dropout=configs.dropout,
                                        output_attention=False),
                            configs.d_model, configs.n_heads),
                        configs.d_model,
                        configs.d_ff,
                        dropout=configs.dropout,
                        activation=configs.activation,
                    )
                    for _ in range(configs.d_layers)
                ],
                norm_layer=torch.nn.LayerNorm(configs.d_model),
                projection=nn.Linear(configs.d_model, configs.c_out, bias=True)
            )
        if self.task_name in ('imputation', 'anomaly_detection'):
            self.projection = nn.Linear(configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(configs.d_model * configs.seq_len, configs.num_class)

        self.tau_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len, hidden_dims=configs.p_hidden_dims,
                                     hidden_layers=configs.p_hidden_layers, output_dim=1)
        self.delta_learner = Projector(enc_in=configs.enc_in, seq_len=configs.seq_len,
                                       hidden_dims=configs.p_hidden_dims, hidden_layers=configs.p_hidden_layers,
                                       output_dim=configs.seq_len)

    @staticmethod
    def _normalize(x_enc):
        """Normalize each series over dim 1.

        Returns (x_normalized, mean_enc, std_enc); the statistics are
        B x 1 x E and detached from the graph.
        """
        mean_enc = x_enc.mean(1, keepdim=True).detach()  # B x 1 x E
        x_centered = x_enc - mean_enc
        std_enc = torch.sqrt(
            torch.var(x_centered, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()  # B x 1 x E
        return x_centered / std_enc, mean_enc, std_enc

    @staticmethod
    def _masked_normalize(x_enc, mask):
        """Normalize each series over dim 1 using only positions where mask == 1.

        Masked positions are zeroed before computing the mean and are zero in
        the returned normalized tensor. The observed-count is clamped to at
        least 1 so a fully masked channel yields mean 0 / std sqrt(1e-5)
        instead of inf/NaN.

        Returns (x_normalized, mean_enc, std_enc); the statistics are
        B x 1 x E and detached from the graph.
        """
        missing = mask == 0
        count = torch.sum(mask == 1, dim=1, keepdim=True).clamp(min=1)
        mean_enc = (torch.sum(x_enc.masked_fill(missing, 0), dim=1, keepdim=True) / count).detach()
        x_centered = (x_enc - mean_enc).masked_fill(missing, 0)
        std_enc = torch.sqrt(
            torch.sum(x_centered * x_centered, dim=1, keepdim=True) / count + 1e-5).detach()
        return x_centered / std_enc, mean_enc, std_enc

    @staticmethod
    def _build_decoder_input(x_enc, x_dec, label_len, pred_len):
        """Concatenate the last ``label_len`` encoder steps with ``pred_len`` zeros.

        Start indices are computed from the lengths explicitly:
        ``x[:, -0:]`` would select the whole tensor rather than nothing when
        ``label_len`` is 0.
        """
        known = x_enc[:, max(x_enc.shape[1] - label_len, 0):, :]
        placeholder = torch.zeros_like(x_dec[:, max(x_dec.shape[1] - pred_len, 0):, :])
        return torch.cat([known, placeholder.to(known.device)], dim=1)

    def _destationary_factors(self, x_raw, mean_enc, std_enc):
        """Return (tau, delta) learned from the raw series and its statistics.

        tau: B x 1, made positive via exp(); delta: B x S.
        """
        tau = self.tau_learner(x_raw, std_enc).exp()
        delta = self.delta_learner(x_raw, mean_enc)
        return tau, delta

    def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
        """Encode the normalized input, decode label+prediction steps, de-normalize."""
        x_raw = x_enc.clone().detach()
        x_enc, mean_enc, std_enc = self._normalize(x_enc)
        tau, delta = self._destationary_factors(x_raw, mean_enc, std_enc)

        x_dec_new = self._build_decoder_input(x_enc, x_dec, self.label_len, self.pred_len)

        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, _ = self.encoder(enc_out, attn_mask=None, tau=tau, delta=delta)
        dec_out = self.dec_embedding(x_dec_new, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=None, cross_mask=None, tau=tau, delta=delta)
        return dec_out * std_enc + mean_enc

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        """Reconstruct the full window from observed (mask == 1) values."""
        x_raw = x_enc.clone().detach()
        x_enc, mean_enc, std_enc = self._masked_normalize(x_enc, mask)
        tau, delta = self._destationary_factors(x_raw, mean_enc, std_enc)

        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, _ = self.encoder(enc_out, attn_mask=None, tau=tau, delta=delta)
        dec_out = self.projection(enc_out)
        return dec_out * std_enc + mean_enc

    def anomaly_detection(self, x_enc):
        """Reconstruct the input window (no time-mark features)."""
        x_raw = x_enc.clone().detach()
        x_enc, mean_enc, std_enc = self._normalize(x_enc)
        tau, delta = self._destationary_factors(x_raw, mean_enc, std_enc)

        enc_out = self.enc_embedding(x_enc, None)
        enc_out, _ = self.encoder(enc_out, attn_mask=None, tau=tau, delta=delta)
        dec_out = self.projection(enc_out)
        return dec_out * std_enc + mean_enc

    def classification(self, x_enc, x_mark_enc):
        """Return class logits; ``x_mark_enc`` is used as a padding mask over time steps."""
        x_raw = x_enc.clone().detach()
        # NOTE(review): unlike the other tasks, the embedding below receives the
        # un-normalized x_enc; mean/std only condition tau/delta. Behavior kept
        # as-is — confirm whether this is intentional.
        _, mean_enc, std_enc = self._normalize(x_enc)
        tau, delta = self._destationary_factors(x_raw, mean_enc, std_enc)

        enc_out = self.enc_embedding(x_enc, None)
        enc_out, _ = self.encoder(enc_out, attn_mask=None, tau=tau, delta=delta)

        # Output
        output = self.act(enc_out)  # the output transformer encoder/decoder embeddings don't include non-linearity
        output = self.dropout(output)
        output = output * x_mark_enc.unsqueeze(-1)  # zero-out padding embeddings
        output = output.reshape(output.shape[0], -1)  # (batch_size, seq_length * d_model)
        return self.projection(output)  # (batch_size, num_classes)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch to the method for ``self.task_name``; returns None for unknown tasks."""
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            return self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            return self.anomaly_detection(x_enc)  # [B, L, D]
        if self.task_name == 'classification':
            return self.classification(x_enc, x_mark_enc)  # [B, num_class]
        return None