diff --git a/baselines/STAEformer/arch/staeformer_arch.py b/baselines/STAEformer/arch/staeformer_arch.py index a2a0dfcb..aa6085c6 100644 --- a/baselines/STAEformer/arch/staeformer_arch.py +++ b/baselines/STAEformer/arch/staeformer_arch.py @@ -201,16 +201,16 @@ def forward(self, history_data: torch.Tensor, future_data: torch.Tensor, batch_s batch_size = x.shape[0] if self.tod_embedding_dim > 0: - tod = x[..., 1] + tod = x[..., 1] * self.steps_per_day if self.dow_embedding_dim > 0: - dow = x[..., 2] + dow = x[..., 2] * 7 x = x[..., : self.input_dim] x = self.input_proj(x) # (batch_size, in_steps, num_nodes, input_embedding_dim) features = [x] if self.tod_embedding_dim > 0: tod_emb = self.tod_embedding( - (tod * self.steps_per_day).long() + tod.long() ) # (batch_size, in_steps, num_nodes, tod_embedding_dim) features.append(tod_emb) if self.dow_embedding_dim > 0: