From 97a1c0f875dc6d974f3578be0da8ff29ad39e3e2 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 26 Jul 2024 13:17:51 -0700 Subject: [PATCH] [NeMo-UX] log val loss (#9814) (#9831) Signed-off-by: ashors1 Co-authored-by: Anna Shors <71393111+ashors1@users.noreply.github.com> --- nemo/lightning/pytorch/strategies.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index 57cd33a612ae4..a17bdd60c77cc 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -441,7 +441,9 @@ def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OU kwargs = self._update_step_kwargs(dataloader_iter, kwargs, "validation") with self.precision_plugin.val_step_context(): # TODO: Do we need this? - return self.model(dataloader_iter, forward_only=True, *args, **kwargs) + out = self.model(dataloader_iter, forward_only=True, *args, **kwargs) + self.lightning_module.log('val_loss', out, rank_zero_only=True, batch_size=1) + return out @override def test_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT: