Skip to content

Commit

Permalink
fix: 🐛 a bug in STAEformer dow embeddings (isse #219)
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Jan 1, 2025
1 parent d91d272 commit 4b5f0ad
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions baselines/STAEformer/arch/staeformer_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,16 @@ def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_s
batch_size = x.shape[0]

if self.tod_embedding_dim > 0:
tod = x[..., 1]
tod = x[..., 1] * self.steps_per_day
if self.dow_embedding_dim > 0:
dow = x[..., 2]
dow = x[..., 2] * 7
x = x[..., : self.input_dim]

x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim)
features = [x]
if self.tod_embedding_dim > 0:
tod_emb = self.tod_embedding(
(tod * self.steps_per_day).long()
tod.long()
) # (batch_size, in_steps, num_nodes, tod_embedding_dim)
features.append(tod_emb)
if self.dow_embedding_dim > 0:
Expand Down

0 comments on commit 4b5f0ad

Please sign in to comment.