-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
DLinear.py
110 lines (96 loc) · 4.4 KB
/
DLinear.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import torch.nn as nn
import torch.nn.functional as F
from layers.Autoformer_EncDec import series_decomp
class Model(nn.Module):
    """
    DLinear model.

    Paper link: https://arxiv.org/pdf/2205.13504.pdf

    The input window is split by ``series_decomp`` into a seasonal and a
    trend component. Each component is mapped from ``seq_len`` to
    ``pred_len`` time steps by a linear layer applied along the time axis,
    and the two projections are summed.
    """

    def __init__(self, configs, individual=False):
        """
        Build the decomposition block and the seasonal/trend linear layers.

        Args:
            configs: Experiment configuration object. Reads ``task_name``,
                ``seq_len``, ``pred_len`` (forecasting tasks only),
                ``moving_avg``, ``enc_in`` and, when ``task_name`` is
                ``'classification'``, ``num_class``.
            individual: If True, every variate (channel) gets its own pair
                of seasonal/trend linear layers. If False (default), a
                single pair of linear layers is shared by all variates.
        """
        super().__init__()
        self.task_name = configs.task_name
        self.seq_len = configs.seq_len
        # Non-forecasting tasks produce an output as long as the input
        # window, so the "prediction" length equals the input length.
        if self.task_name in ('classification', 'anomaly_detection', 'imputation'):
            self.pred_len = configs.seq_len
        else:
            self.pred_len = configs.pred_len
        # Series decomposition block from Autoformer.
        # NOTE: the attribute keeps its original spelling so existing
        # references to ``model.decompsition`` continue to work.
        self.decompsition = series_decomp(configs.moving_avg)
        self.individual = individual
        self.channels = configs.enc_in

        if self.individual:
            self.Linear_Seasonal = nn.ModuleList()
            self.Linear_Trend = nn.ModuleList()
            for _ in range(self.channels):
                self.Linear_Seasonal.append(self._averaging_linear())
                self.Linear_Trend.append(self._averaging_linear())
        else:
            self.Linear_Seasonal = self._averaging_linear()
            self.Linear_Trend = self._averaging_linear()

        if self.task_name == 'classification':
            # Flattened encoder output is (pred_len == seq_len) * enc_in.
            self.projection = nn.Linear(
                configs.enc_in * configs.seq_len, configs.num_class)

    def _averaging_linear(self):
        """Return a seq_len -> pred_len linear layer whose weight starts with every entry set to 1 / seq_len."""
        layer = nn.Linear(self.seq_len, self.pred_len)
        # Each output step starts as the plain mean of the input window;
        # the bias keeps nn.Linear's own initialisation.
        layer.weight = nn.Parameter(
            (1 / self.seq_len) * torch.ones([self.pred_len, self.seq_len]))
        return layer

    def encoder(self, x):
        """
        Decompose ``x`` and project both components along the time axis.

        Args:
            x: Tensor with time on dim 1 and channels on dim 2, i.e.
                [B, seq_len, channels] (dim 1 must equal ``seq_len`` for the
                linear layers to apply after the permute below).

        Returns:
            Tensor of shape [B, pred_len, channels]: seasonal projection
            plus trend projection.
        """
        # assumes series_decomp returns (seasonal, trend), each shaped
        # like x — TODO confirm against layers.Autoformer_EncDec
        seasonal_init, trend_init = self.decompsition(x)
        # Move time to the last axis so nn.Linear acts along time.
        seasonal_init = seasonal_init.permute(0, 2, 1)
        trend_init = trend_init.permute(0, 2, 1)
        if self.individual:
            # new_zeros allocates directly with the source tensor's dtype
            # and device, avoiding a CPU allocation plus a host->device
            # copy on every forward pass.
            seasonal_output = seasonal_init.new_zeros(
                [seasonal_init.size(0), seasonal_init.size(1), self.pred_len])
            trend_output = trend_init.new_zeros(
                [trend_init.size(0), trend_init.size(1), self.pred_len])
            for i in range(self.channels):
                seasonal_output[:, i, :] = self.Linear_Seasonal[i](
                    seasonal_init[:, i, :])
                trend_output[:, i, :] = self.Linear_Trend[i](
                    trend_init[:, i, :])
        else:
            seasonal_output = self.Linear_Seasonal(seasonal_init)
            trend_output = self.Linear_Trend(trend_init)
        x = seasonal_output + trend_output
        # Back to [B, pred_len, channels].
        return x.permute(0, 2, 1)

    def forecast(self, x_enc):
        """Forecast ``pred_len`` steps from ``x_enc`` via the shared encoder."""
        return self.encoder(x_enc)

    def imputation(self, x_enc):
        """Reconstruct the input window (pred_len == seq_len) via the shared encoder."""
        return self.encoder(x_enc)

    def anomaly_detection(self, x_enc):
        """Reconstruct the input window (pred_len == seq_len) via the shared encoder."""
        return self.encoder(x_enc)

    def classification(self, x_enc):
        """
        Classify ``x_enc`` by flattening the encoder output and projecting it.

        Returns:
            Tensor of shape [B, num_class] (unnormalised scores).
        """
        enc_out = self.encoder(x_enc)
        # (batch_size, seq_len * enc_in)
        output = enc_out.reshape(enc_out.shape[0], -1)
        # (batch_size, num_class)
        output = self.projection(output)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """
        Dispatch to the task selected by ``self.task_name``.

        Only ``x_enc`` is used; ``x_mark_enc``, ``x_dec``, ``x_mark_dec``
        and ``mask`` are accepted to match the common model interface.

        Returns:
            [B, L, D] for forecasting/imputation/anomaly detection,
            [B, num_class] for classification, or None when ``task_name``
            is not one of the recognised tasks.
        """
        if self.task_name in ('long_term_forecast', 'short_term_forecast'):
            dec_out = self.forecast(x_enc)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc)
            return dec_out  # [B, N]
        return None