diff --git a/matsciml/models/base.py b/matsciml/models/base.py
index e6d215a0..d6309089 100644
--- a/matsciml/models/base.py
+++ b/matsciml/models/base.py
@@ -894,12 +894,26 @@ def _compute_losses(
         targets = self._get_targets(batch)
         predictions = self(batch)
         losses = {}
+        used_norm = []
         for key in self.task_keys:
            target_val = targets[key]
            if self.uses_normalizers:
                target_val = self.normalizers[key].norm(target_val)
+               used_norm.append(key)
            losses[key] = self.loss_func(predictions[key], target_val)
         total_loss: torch.Tensor = sum(losses.values())
+        # Guard against misconfiguration: raise if normalization was intended
+        # but never applied, or if only a subset of the normalizers was used.
+        if len(used_norm) == 0 and self.uses_normalizers:
+            raise RuntimeError(
+                "Target normalization was intended but not used. "
+                f"Please check your config - expected: {self.task_keys}"
+            )
+        if self.uses_normalizers and len(used_norm) != len(self.normalizers):
+            raise RuntimeError(
+                "Normalization was performed, but the number of keys does not match. "
+                f"Expected {len(self.normalizers)} keys, but only used {len(used_norm)}."
+            )
         return {"loss": total_loss, "log": losses}

     def configure_optimizers(self) -> torch.optim.AdamW:
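
For context, below is a minimal, self-contained sketch of the mismatch condition the second new check catches: a normalizer dict carrying more entries than the task keys actually in use. This is not the matsciml API; ZScoreNormalizer, the dict layout, and the literal values are illustrative assumptions.

# Sketch only: ZScoreNormalizer is a hypothetical stand-in for a per-target
# normalizer object exposing a .norm() method, as used in _compute_losses.
class ZScoreNormalizer:
    def __init__(self, mean: float, std: float) -> None:
        self.mean = mean
        self.std = std

    def norm(self, value: float) -> float:
        # Standard z-score scaling of a raw target value.
        return (value - self.mean) / self.std


task_keys = ["energy"]
# Stale config: a "force" normalizer is still present even though the
# task only trains on "energy".
normalizers = {
    "energy": ZScoreNormalizer(mean=-3.2, std=1.1),
    "force": ZScoreNormalizer(mean=0.0, std=0.5),
}

used_norm = []
for key in task_keys:
    normalizers[key].norm(10.0)  # normalize the target for this key
    used_norm.append(key)

# Mirrors the second RuntimeError added in _compute_losses.
if len(used_norm) != len(normalizers):
    raise RuntimeError(
        "Normalization was performed, but the number of keys does not match. "
        f"Expected {len(normalizers)} keys, but only used {len(used_norm)}."
    )

Note that a task key with no matching normalizer entry already fails earlier with a KeyError on self.normalizers[key] inside the loop, so the new length check chiefly guards against stale or surplus normalizer entries like the one above.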