Skip to content

Commit

Permalink
fix log epoch (#1986)
Browse files Browse the repository at this point in the history
* fix logs

* fix tests

* remove dead loggers

* fix tests

* add prints

* fix logs
  • Loading branch information
mvpatel2000 authored Feb 22, 2023
1 parent b16fb66 commit ddf179c
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 7 deletions.
5 changes: 1 addition & 4 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1921,7 +1921,7 @@ def _train_loop(self) -> None:
try:
if int(self.state.timestamp.batch_in_epoch) == 0:
self.engine.run_event(Event.EPOCH_START)
- self.logger.log_metrics({'epoch': int(self.state.timestamp.epoch)})
+ self.logger.log_metrics({'trainer/epoch': int(self.state.timestamp.epoch)})

dataloader = self.state.dataloader
if isinstance(dataloader, DataLoader) and isinstance(dataloader.sampler, DistributedSampler):
Expand Down Expand Up @@ -2760,9 +2760,6 @@ def _eval_loop(

self.engine.run_event(Event.EVAL_BATCH_END)

- self.logger.log_metrics({'epoch': self.state.timestamp.epoch.value})
- self.logger.log_metrics({'trainer/global_step': self.state.timestamp.batch.value})

self._compute_and_log_metrics(dataloader_label=dataloader_label, metrics=metrics)

self.engine.run_event(Event.EVAL_END)
Expand Down
3 changes: 2 additions & 1 deletion tests/loggers/test_cometml_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,8 @@ def test_comet_ml_logging_train_loop(monkeypatch, tmp_path):

# Check that basic metrics appear in the comet logs
assert len([
- metric_msg for metric_msg in msg_type_to_msgs['metric_msg'] if metric_msg['metric']['metricName'] == 'epoch'
+ metric_msg for metric_msg in msg_type_to_msgs['metric_msg']
+ if metric_msg['metric']['metricName'] == 'trainer/epoch'
]) == 2

# Check that basic params appear in the comet logs
Expand Down
3 changes: 1 addition & 2 deletions tests/metrics/test_current_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,12 @@ def test_current_metrics(eval_interval: str,):
num_expected_calls += (train_subset_num_batches + 1) * num_epochs
# computed at eval end
if compute_val_metrics:
- num_calls_per_eval = 3 # metrics + epoch + trainer/global_step
num_evals = 0
if eval_interval == '1ba':
num_evals += train_subset_num_batches * num_epochs
if eval_interval == '1ep':
num_evals += num_epochs
- num_expected_calls += (num_calls_per_eval) * num_evals
+ num_expected_calls += num_evals

num_actual_calls = len(mock_logger_destination.log_metrics.mock_calls)

Expand Down

0 comments on commit ddf179c

Please sign in to comment.