diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index ae67807da3..90cd25c41b 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1921,7 +1921,7 @@ def _train_loop(self) -> None:
             try:
                 if int(self.state.timestamp.batch_in_epoch) == 0:
                     self.engine.run_event(Event.EPOCH_START)
-                    self.logger.log_metrics({'epoch': int(self.state.timestamp.epoch)})
+                    self.logger.log_metrics({'trainer/epoch': int(self.state.timestamp.epoch)})
 
                 dataloader = self.state.dataloader
                 if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
@@ -2760,9 +2760,6 @@ def _eval_loop(
 
                     self.engine.run_event(Event.EVAL_BATCH_END)
 
-            self.logger.log_metrics({'epoch': self.state.timestamp.epoch.value})
-            self.logger.log_metrics({'trainer/global_step': self.state.timestamp.batch.value})
-
             self._compute_and_log_metrics(dataloader_label=dataloader_label, metrics=metrics)
 
             self.engine.run_event(Event.EVAL_END)
diff --git a/tests/loggers/test_cometml_logger.py b/tests/loggers/test_cometml_logger.py
index b57a9d0d8f..797a71b3ef 100644
--- a/tests/loggers/test_cometml_logger.py
+++ b/tests/loggers/test_cometml_logger.py
@@ -194,7 +194,8 @@ def test_comet_ml_logging_train_loop(monkeypatch, tmp_path):
 
     # Check that basic metrics appear in the comet logs
     assert len([
-        metric_msg for metric_msg in msg_type_to_msgs['metric_msg'] if metric_msg['metric']['metricName'] == 'epoch'
+        metric_msg for metric_msg in msg_type_to_msgs['metric_msg']
+        if metric_msg['metric']['metricName'] == 'trainer/epoch'
     ]) == 2
 
     # Check that basic params appear in the comet logs
diff --git a/tests/metrics/test_current_metrics.py b/tests/metrics/test_current_metrics.py
index 1887b706a4..4ca4344f61 100644
--- a/tests/metrics/test_current_metrics.py
+++ b/tests/metrics/test_current_metrics.py
@@ -110,13 +110,12 @@ def test_current_metrics(eval_interval: str,):
     num_expected_calls += (train_subset_num_batches + 1) * num_epochs
     # computed at eval end
     if compute_val_metrics:
-        num_calls_per_eval = 3  # metrics + epoch + trainer/global_step
         num_evals = 0
         if eval_interval == '1ba':
             num_evals += train_subset_num_batches * num_epochs
         if eval_interval == '1ep':
             num_evals += num_epochs
-        num_expected_calls += (num_calls_per_eval) * num_evals
+        num_expected_calls += num_evals
 
     num_actual_calls = len(mock_logger_destination.log_metrics.mock_calls)
 
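
Taken together, these changes namespace the train loop's epoch metric as `trainer/epoch` and drop the two extra `log_metrics` calls (`epoch` and `trainer/global_step`) the eval loop used to make, so each eval now contributes a single metrics call. Below is a minimal, self-contained sketch of what a logger destination observes after this change, using only `unittest.mock` in the same style as the updated tests; the `metrics/eval/Accuracy` key is a hypothetical placeholder, not a key guaranteed by Composer:

```python
from unittest.mock import MagicMock

# Stand-in for a Composer LoggerDestination; we only track which
# metric keys get logged and how many calls are made.
mock_logger = MagicMock()

# What the train loop now emits at EPOCH_START (per this diff):
mock_logger.log_metrics({'trainer/epoch': 0})

# What the eval loop now emits: just the computed eval metrics.
# The separate 'epoch' and 'trainer/global_step' calls are gone,
# so a single eval contributes exactly one log_metrics call.
mock_logger.log_metrics({'metrics/eval/Accuracy': 0.91})  # hypothetical key

# Call counting, as in tests/metrics/test_current_metrics.py:
# one call per eval instead of the old num_calls_per_eval = 3.
assert len(mock_logger.log_metrics.mock_calls) == 2

# Key filtering, as in tests/loggers/test_cometml_logger.py:
# the epoch metric now appears under the renamed 'trainer/epoch' key.
epoch_calls = [c for c in mock_logger.log_metrics.mock_calls if 'trainer/epoch' in c.args[0]]
assert len(epoch_calls) == 1
```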