Skip to content

Commit

Permalink
fix: 🐛 a bug in DGCRN runner (found by @wengwenchao123)
Browse files Browse the repository at this point in the history
  • Loading branch information
zezhishao committed Nov 8, 2024
1 parent e57e96e commit ef7d579
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions baselines/DGCRN/runner/dgcrn_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: b
dict: keys that must be included: inputs, prediction, target
"""

data = self.preprocessing(data)
# preprocess
future_data, history_data = data['target'], data['inputs']
history_data = self.to_running_device(history_data) # B, L, N, C
Expand All @@ -42,4 +43,5 @@ def forward(self, data: tuple, epoch: int = None, iter_num: int = None, train: b
model_return["target"] = self.select_target_features(future_data)
assert list(model_return["prediction"].shape)[:3] == [batch_size, length, num_nodes], \
"error shape of the output, edit the forward function to reshape it to [B, L, N, C]"
model_return = self.postprocessing(model_return)
return model_return

0 comments on commit ef7d579

Please sign in to comment.