Skip to content

Commit

Permalink
Update stagcn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
QiweiMa-LL authored Oct 14, 2022
1 parent b6cf102 commit d93c55a
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions stagcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def forward(self, x, La):
Ls = []
L1 = La
L0 = torch.eye(nNode).repeat(nSample, 1, 1).cuda() # 单位矩阵
# torch.eye 为了生成nNode个对角线全1,其余部分全0的二维数组
# .repeat()把原始torch位置的数据与repeat对应位置相乘,多出来的维度写在前面
Ls.append(L0)
Ls.append(L1)
Expand Down Expand Up @@ -119,7 +118,7 @@ def dgconstruct(self, time_embedding, source_embedding, target_embedding, core_e
adp = torch.einsum('ai, ijk->ajk', time_embedding, core_embedding)
adp = torch.einsum('bj, ajk->abk', source_embedding, adp)
adp = torch.einsum('ck, abk->abc', target_embedding, adp)
adp = F.softmax(F.relu(adp), dim=-2) # 我改了F.relu(adp)
adp = F.softmax(F.relu(adp), dim=-2)
return adp

def forward(self, x):
Expand Down Expand Up @@ -232,7 +231,7 @@ def __init__(self, ks, kt, bs, T, n, p, out, num_features, adj, n_days):
def forward(self, x):
# attentional mechanisms
x_1 = self.TATT_1(x[:, [0]])
Las = self.adaptivegcn(x[:, [1]]) # dynamic Laplacian matrix multiplication
Las = self.adaptivegcn(x[:, [1]])
x_st1 = self.st_conv1(x_1, Las)
x_st2 = self.st_conv2(x_st1, Las)
return self.output(x_st2)

0 comments on commit d93c55a

Please sign in to comment.