
Update LoggingHandler to support logging per interval #16922

Merged on Dec 7, 2019 (10 commits)
29 changes: 23 additions & 6 deletions python/mxnet/gluon/contrib/estimator/event_handler.py
@@ -231,20 +231,25 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, BatchEnd):
         Limit the granularity of metrics displayed during training process.
         verbose=LOG_PER_EPOCH: display metrics every epoch
         verbose=LOG_PER_BATCH: display metrics every batch
+        verbose=LOG_PER_INTERVAL: display metrics every interval of batches
     train_metrics : list of EvalMetrics
         Training metrics to be logged, logged at batch end, epoch end, train end.
     val_metrics : list of EvalMetrics
         Validation metrics to be logged, logged at epoch end, train end.
+    log_interval : int, default 1
+        Logging interval in batches during training; log_interval=1 is equivalent to LOG_PER_BATCH.
     """

     LOG_PER_EPOCH = 1
     LOG_PER_BATCH = 2
+    LOG_PER_INTERVAL = 3

     def __init__(self, verbose=LOG_PER_EPOCH,
                  train_metrics=None,
-                 val_metrics=None):
+                 val_metrics=None,
+                 log_interval=1):
         super(LoggingHandler, self).__init__()
-        if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]:
-            raise ValueError("verbose level must be either LOG_PER_EPOCH or "
-                             "LOG_PER_BATCH, received %s. "
+        if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]:
+            raise ValueError("verbose level must be LOG_PER_EPOCH, "
+                             "LOG_PER_BATCH or LOG_PER_INTERVAL, received %s. "
                              "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)"
@@ -258,6 +263,7 @@ def __init__(self, verbose=LOG_PER_EPOCH,
         # logging handler need to be called at last to make sure all states are updated
         # it will also shut down logging at train end
         self.priority = np.Inf
+        self.log_interval = log_interval

     def train_begin(self, estimator, *args, **kwargs):
         self.train_start = time.time()
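For reference, a minimal usage sketch of the new option with the Gluon Estimator API (the toy network, data, and hyperparameters below are placeholders, not part of this PR):

    import mxnet as mx
    from mxnet import gluon
    from mxnet.gluon.contrib.estimator import Estimator
    from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler

    # Toy model and data so the handler has something to log (illustrative only).
    net = gluon.nn.Dense(2)
    net.initialize()
    loss = gluon.loss.SoftmaxCrossEntropyLoss()
    trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.1})
    dataset = gluon.data.ArrayDataset(
        mx.nd.random.uniform(shape=(100, 8)),
        mx.nd.random.randint(0, 2, shape=(100,)).astype('float32'))
    train_data = gluon.data.DataLoader(dataset, batch_size=10)

    est = Estimator(net=net, loss=loss, trainer=trainer)
    # Aggregate and log metrics once every 5 batches instead of every batch.
    log_handler = LoggingHandler(verbose=LoggingHandler.LOG_PER_INTERVAL,
                                 log_interval=5)
    est.fit(train_data=train_data, epochs=2, event_handlers=[log_handler])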
@@ -275,6 +281,7 @@ def train_begin(self, estimator, *args, **kwargs):
         self.current_epoch = 0
         self.batch_index = 0
         self.processed_samples = 0
+        self.log_interval_time = 0

     def train_end(self, estimator, *args, **kwargs):
         train_time = time.time() - self.train_start
@@ -286,31 +293,41 @@ def train_end(self, estimator, *args, **kwargs):
         estimator.logger.info(msg.rstrip(', '))

     def batch_begin(self, estimator, *args, **kwargs):
-        if self.verbose == self.LOG_PER_BATCH:
+        if self.verbose in [self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]:
             self.batch_start = time.time()

     def batch_end(self, estimator, *args, **kwargs):
-        if self.verbose == self.LOG_PER_BATCH:
+        if self.verbose in [self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]:
             batch_time = time.time() - self.batch_start
             msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index)
             self.processed_samples += kwargs['batch'][0].shape[0]
             msg += '[Samples %s] ' % (self.processed_samples)
-            msg += 'time/batch: %.3fs ' % batch_time
-            for metric in self.train_metrics:
-                # only log current training loss & metric after each batch
-                name, value = metric.get()
-                msg += '%s: %.4f, ' % (name, value)
-            estimator.logger.info(msg.rstrip(', '))
+            self.log_interval_time += batch_time
+
+            if self.verbose == self.LOG_PER_BATCH:
+                msg += 'time/batch: %.3fs ' % batch_time
+                for metric in self.train_metrics:
+                    # only log current training loss & metric after each batch
+                    name, value = metric.get()
+                    msg += '%s: %.4f, ' % (name, value)
+                estimator.logger.info(msg.rstrip(', '))
+            elif self.verbose == self.LOG_PER_INTERVAL and self.batch_index % self.log_interval == 0:
+                msg += 'time/interval: %.3fs ' % self.log_interval_time
+                self.log_interval_time = 0
+                for monitor in self.train_metrics + self.val_metrics:
+                    name, value = monitor.get()
+                    msg += '%s: %.4f, ' % (name, value)
+                estimator.logger.info(msg.rstrip(', '))
         self.batch_index += 1

     def epoch_begin(self, estimator, *args, **kwargs):
-        if self.verbose >= self.LOG_PER_EPOCH:
+        if self.verbose in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]:
             self.epoch_start = time.time()
             estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f",
                                   self.current_epoch, estimator.trainer.learning_rate)

     def epoch_end(self, estimator, *args, **kwargs):
-        if self.verbose >= self.LOG_PER_EPOCH:
+        if self.verbose in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]:
             epoch_time = time.time() - self.epoch_start
             msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time)
             for monitor in self.train_metrics + self.val_metrics:
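With this change, verbose=LOG_PER_INTERVAL emits one aggregated line per interval instead of one per batch. Given the format strings above, an interval line looks roughly like the following (metric names and values are illustrative and depend on the configured metrics):

    INFO:root:[Epoch 0][Batch 5][Samples 60] time/interval: 0.412s accuracy: 0.5312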
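Note that batch_index starts at 0 and is incremented after the check, so the first interval line fires at the end of the first batch and every log_interval batches thereafter. A standalone sketch of the cadence (plain Python, independent of MXNet):

    log_interval = 5
    for batch_index in range(12):
        # Mirrors the check in batch_end above.
        if batch_index % log_interval == 0:
            print('log fires at batch_index', batch_index)  # -> 0, 5, 10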