From 09d1db8995b362c3ace7c44a0496aecd476dd09e Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Wed, 27 Nov 2019 07:59:11 +0000 Subject: [PATCH 01/10] Update LoggingHandler to support logging per interval --- .../gluon/contrib/estimator/event_handler.py | 25 ++++++++++++++++--- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 3cdc407407c1..bff3cdf387cd 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -231,20 +231,25 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat Limit the granularity of metrics displayed during training process. verbose=LOG_PER_EPOCH: display metrics every epoch verbose=LOG_PER_BATCH: display metrics every batch + verbose=LOG_PER_INTERVAL: display metrics every interval of batches train_metrics : list of EvalMetrics Training metrics to be logged, logged at batch end, epoch end, train end. val_metrics : list of EvalMetrics Validation metrics to be logged, logged at epoch end, train end. + log_interval: int, default 1 + Logging interval during training. 1 is equivalent to LOG_PER_BATCH. """ LOG_PER_EPOCH = 1 LOG_PER_BATCH = 2 + LOG_PER_INTERVAL = 3 def __init__(self, verbose=LOG_PER_EPOCH, train_metrics=None, - val_metrics=None): + val_metrics=None, + log_interval=1): super(LoggingHandler, self).__init__() - if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH]: + if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]: raise ValueError("verbose level must be either LOG_PER_EPOCH or " "LOG_PER_BATCH, received %s. " "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)" @@ -258,6 +263,7 @@ def __init__(self, verbose=LOG_PER_EPOCH, # logging handler need to be called at last to make sure all states are updated # it will also shut down logging at train end self.priority = np.Inf + self.log_interval = log_interval def train_begin(self, estimator, *args, **kwargs): self.train_start = time.time() @@ -275,6 +281,7 @@ def train_begin(self, estimator, *args, **kwargs): self.current_epoch = 0 self.batch_index = 0 self.processed_samples = 0 + self.log_interval_time = 0 def train_end(self, estimator, *args, **kwargs): train_time = time.time() - self.train_start @@ -286,21 +293,31 @@ def train_end(self, estimator, *args, **kwargs): estimator.logger.info(msg.rstrip(', ')) def batch_begin(self, estimator, *args, **kwargs): - if self.verbose == self.LOG_PER_BATCH: + if self.verbose >= self.LOG_PER_BATCH: self.batch_start = time.time() def batch_end(self, estimator, *args, **kwargs): - if self.verbose == self.LOG_PER_BATCH: + if self.verbose >= self.LOG_PER_BATCH: batch_time = time.time() - self.batch_start msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index) self.processed_samples += kwargs['batch'][0].shape[0] msg += '[Samples %s] ' % (self.processed_samples) + self.log_interval_time += batch_time + + if self.verbose == self.LOG_PER_BATCH: msg += 'time/batch: %.3fs ' % batch_time for metric in self.train_metrics: # only log current training loss & metric after each batch name, value = metric.get() msg += '%s: %.4f, ' % (name, value) estimator.logger.info(msg.rstrip(', ')) + elif self.verbose == self.LOG_PER_INTERVAL and self.batch_index % self.log_interval == 0: + msg += 'time/interval: %.3fs ' % self.log_interval_time + self.log_interval_time = 0 + for monitor in self.train_metrics + self.val_metrics: + name, value = monitor.get() + msg += '%s: %.4f, ' % (name, value) + estimator.logger.info(msg.rstrip(', ')) self.batch_index += 1 def epoch_begin(self, estimator, *args, **kwargs): From fea0bc6c6ad97bcd812ab5e56afe3fedf6825bf0 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Mon, 2 Dec 2019 03:08:23 +0000 Subject: [PATCH 02/10] Fix the constant variable issue in the logging handler --- python/mxnet/gluon/contrib/estimator/event_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index bff3cdf387cd..849265f08fc0 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -293,11 +293,11 @@ def train_end(self, estimator, *args, **kwargs): estimator.logger.info(msg.rstrip(', ')) def batch_begin(self, estimator, *args, **kwargs): - if self.verbose >= self.LOG_PER_BATCH: + if self.verbose in [self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]: self.batch_start = time.time() def batch_end(self, estimator, *args, **kwargs): - if self.verbose >= self.LOG_PER_BATCH: + if self.verbose in [self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]: batch_time = time.time() - self.batch_start msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index) self.processed_samples += kwargs['batch'][0].shape[0] From 1d588dfbf8bd6b57157da7026e95333fe027a13d Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Mon, 2 Dec 2019 03:24:46 +0000 Subject: [PATCH 03/10] Remove the constant variable hack in the logging handler. --- python/mxnet/gluon/contrib/estimator/event_handler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 849265f08fc0..109edabb7349 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -321,13 +321,13 @@ def batch_end(self, estimator, *args, **kwargs): self.batch_index += 1 def epoch_begin(self, estimator, *args, **kwargs): - if self.verbose >= self.LOG_PER_EPOCH: + if self.verbose in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]: self.epoch_start = time.time() estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f", self.current_epoch, estimator.trainer.learning_rate) def epoch_end(self, estimator, *args, **kwargs): - if self.verbose >= self.LOG_PER_EPOCH: + if self.verbose in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]: epoch_time = time.time() - self.epoch_start msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time) for monitor in self.train_metrics + self.val_metrics: From 20e44a0f46e7bca28e8a905640ddf42e8f7fdb3b Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Tue, 3 Dec 2019 10:35:51 +0000 Subject: [PATCH 04/10] 1) replace LOG_PER_BATCH with LOG_PER_INTERVAL 2) add test case --- .../gluon/contrib/estimator/event_handler.py | 38 ++++------ tests/python/unittest/test_gluon_estimator.py | 75 ++++++++++++++++++- 2 files changed, 86 insertions(+), 27 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 109edabb7349..93ea328d4dfe 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -241,17 +241,16 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat """ LOG_PER_EPOCH = 1 - LOG_PER_BATCH = 2 - LOG_PER_INTERVAL = 3 + LOG_PER_INTERVAL = 2 def __init__(self, verbose=LOG_PER_EPOCH, train_metrics=None, val_metrics=None, log_interval=1): super(LoggingHandler, self).__init__() - if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]: + if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_INTERVAL]: raise ValueError("verbose level must be either LOG_PER_EPOCH or " - "LOG_PER_BATCH, received %s. " + "LOG_PER_INTERVAL, received %s. " "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)" % verbose) self.verbose = verbose @@ -293,41 +292,34 @@ def train_end(self, estimator, *args, **kwargs): estimator.logger.info(msg.rstrip(', ')) def batch_begin(self, estimator, *args, **kwargs): - if self.verbose in [self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]: + if self.verbose == self.LOG_PER_INTERVAL: self.batch_start = time.time() def batch_end(self, estimator, *args, **kwargs): - if self.verbose in [self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]: + if self.verbose == self.LOG_PER_INTERVAL: batch_time = time.time() - self.batch_start msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index) self.processed_samples += kwargs['batch'][0].shape[0] msg += '[Samples %s] ' % (self.processed_samples) self.log_interval_time += batch_time - - if self.verbose == self.LOG_PER_BATCH: - msg += 'time/batch: %.3fs ' % batch_time - for metric in self.train_metrics: - # only log current training loss & metric after each batch - name, value = metric.get() - msg += '%s: %.4f, ' % (name, value) - estimator.logger.info(msg.rstrip(', ')) - elif self.verbose == self.LOG_PER_INTERVAL and self.batch_index % self.log_interval == 0: - msg += 'time/interval: %.3fs ' % self.log_interval_time - self.log_interval_time = 0 - for monitor in self.train_metrics + self.val_metrics: - name, value = monitor.get() - msg += '%s: %.4f, ' % (name, value) - estimator.logger.info(msg.rstrip(', ')) + if self.batch_index % self.log_interval == 0: + msg += 'time/interval: %.3fs ' % self.log_interval_time + self.log_interval_time = 0 + for metric in self.train_metrics: + # only log current training loss & metric after each interval + name, value = metric.get() + msg += '%s: %.4f, ' % (name, value) + estimator.logger.info(msg.rstrip(', ')) self.batch_index += 1 def epoch_begin(self, estimator, *args, **kwargs): - if self.verbose in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]: + if self.verbose in [self.LOG_PER_EPOCH, self.LOG_PER_INTERVAL]: self.epoch_start = time.time() estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f", self.current_epoch, estimator.trainer.learning_rate) def epoch_end(self, estimator, *args, **kwargs): - if self.verbose in [self.LOG_PER_EPOCH, self.LOG_PER_BATCH, self.LOG_PER_INTERVAL]: + if self.verbose in [self.LOG_PER_EPOCH, self.LOG_PER_INTERVAL]: epoch_time = time.time() - self.epoch_start msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time) for monitor in self.train_metrics + self.val_metrics: diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index cf913a6161c0..f28fec6fd49d 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -20,6 +20,7 @@ import sys import unittest import warnings +import re import mxnet as mx from mxnet import gluon @@ -27,7 +28,7 @@ from mxnet.gluon.contrib.estimator import * from mxnet.gluon.contrib.estimator.event_handler import * from nose.tools import assert_raises - +from io import StringIO def _get_test_network(): net = nn.Sequential() @@ -35,10 +36,10 @@ def _get_test_network(): return net -def _get_test_data(): +def _get_test_data(in_size=10): batch_size = 4 - in_data = mx.nd.random.uniform(shape=(10, 3)) - out_data = mx.nd.random.uniform(shape=(10, 4)) + in_data = mx.nd.random.uniform(shape=(in_size, 3)) + out_data = mx.nd.random.uniform(shape=(in_size, 4)) # Input dataloader dataset = gluon.data.dataset.ArrayDataset(in_data, out_data) dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size) @@ -370,3 +371,69 @@ def test_default_handlers(): assert len(handlers) == 4 assert isinstance(handlers[0], MetricHandler) assert isinstance(handlers[3], LoggingHandler) + +def test_logging_interval(): + ''' test different options for logging handler ''' + ''' test case #1: log interval is 1 ''' + batch_size = 4 + data_size = 50 + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + log_interval = 1 + net = _get_test_network() + dataloader, dataiter = _get_test_data(in_size=data_size) + num_epochs = 1 + ctx = mx.cpu() + loss = gluon.loss.L2Loss() + acc = mx.metric.Accuracy() + net.initialize(ctx=ctx) + trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) + logging = LoggingHandler(verbose=LoggingHandler.LOG_PER_INTERVAL, + train_metrics=[acc], log_interval=log_interval) + est = Estimator(net=net, + loss=loss, + metrics=acc, + trainer=trainer, + context=ctx) + + est.fit(train_data=dataloader, + epochs=num_epochs, + event_handlers=[logging]) + + sys.stdout = old_stdout + log_info_list = mystdout.getvalue().splitlines() + info_len = 0 + for info in log_info_list: + match = re.match( + '(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time\/interval: \d+.\d+s' + + ' training accuracy: \d+.\d+)', info) + if match: + info_len += 1 + + assert(info_len == int(data_size/batch_size/log_interval) + 1) + ''' test case #2: log interval is 5 ''' + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + acc = mx.metric.Accuracy() + log_interval = 5 + trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) + logging = LoggingHandler(verbose=LoggingHandler.LOG_PER_INTERVAL, + train_metrics=[acc], log_interval=log_interval) + est = Estimator(net=net, + loss=loss, + metrics=acc, + trainer=trainer, + context=ctx) + est.fit(train_data=dataloader, + epochs=num_epochs, + event_handlers=[logging]) + sys.stdout = old_stdout + log_info_list = mystdout.getvalue().splitlines() + info_len = 0 + for info in log_info_list: + match = re.match( + '(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time\/interval: \d+.\d+s' + + ' training accuracy: \d+.\d+)', info) + if match: + info_len += 1 + assert(info_len == int(data_size/batch_size/log_interval) + 1) From 05018ed4121dde2a696c9b6adab578e586813424 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Thu, 5 Dec 2019 04:26:10 +0000 Subject: [PATCH 05/10] Improve the test script for LoggingHandler --- tests/python/unittest/test_gluon_estimator.py | 74 +---------------- .../unittest/test_gluon_event_handler.py | 82 ++++++++++++++++++- 2 files changed, 82 insertions(+), 74 deletions(-) diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index f28fec6fd49d..518c0a0077ed 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -20,7 +20,6 @@ import sys import unittest import warnings -import re import mxnet as mx from mxnet import gluon @@ -28,7 +27,7 @@ from mxnet.gluon.contrib.estimator import * from mxnet.gluon.contrib.estimator.event_handler import * from nose.tools import assert_raises -from io import StringIO + def _get_test_network(): net = nn.Sequential() @@ -36,10 +35,10 @@ def _get_test_network(): return net -def _get_test_data(in_size=10): +def _get_test_data(): batch_size = 4 - in_data = mx.nd.random.uniform(shape=(in_size, 3)) - out_data = mx.nd.random.uniform(shape=(in_size, 4)) + in_data = mx.nd.random.uniform(shape=(10, 3)) + out_data = mx.nd.random.uniform(shape=(10, 4)) # Input dataloader dataset = gluon.data.dataset.ArrayDataset(in_data, out_data) dataloader = gluon.data.DataLoader(dataset, batch_size=batch_size) @@ -372,68 +371,3 @@ def test_default_handlers(): assert isinstance(handlers[0], MetricHandler) assert isinstance(handlers[3], LoggingHandler) -def test_logging_interval(): - ''' test different options for logging handler ''' - ''' test case #1: log interval is 1 ''' - batch_size = 4 - data_size = 50 - old_stdout = sys.stdout - sys.stdout = mystdout = StringIO() - log_interval = 1 - net = _get_test_network() - dataloader, dataiter = _get_test_data(in_size=data_size) - num_epochs = 1 - ctx = mx.cpu() - loss = gluon.loss.L2Loss() - acc = mx.metric.Accuracy() - net.initialize(ctx=ctx) - trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) - logging = LoggingHandler(verbose=LoggingHandler.LOG_PER_INTERVAL, - train_metrics=[acc], log_interval=log_interval) - est = Estimator(net=net, - loss=loss, - metrics=acc, - trainer=trainer, - context=ctx) - - est.fit(train_data=dataloader, - epochs=num_epochs, - event_handlers=[logging]) - - sys.stdout = old_stdout - log_info_list = mystdout.getvalue().splitlines() - info_len = 0 - for info in log_info_list: - match = re.match( - '(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time\/interval: \d+.\d+s' + - ' training accuracy: \d+.\d+)', info) - if match: - info_len += 1 - - assert(info_len == int(data_size/batch_size/log_interval) + 1) - ''' test case #2: log interval is 5 ''' - old_stdout = sys.stdout - sys.stdout = mystdout = StringIO() - acc = mx.metric.Accuracy() - log_interval = 5 - trainer = gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) - logging = LoggingHandler(verbose=LoggingHandler.LOG_PER_INTERVAL, - train_metrics=[acc], log_interval=log_interval) - est = Estimator(net=net, - loss=loss, - metrics=acc, - trainer=trainer, - context=ctx) - est.fit(train_data=dataloader, - epochs=num_epochs, - event_handlers=[logging]) - sys.stdout = old_stdout - log_info_list = mystdout.getvalue().splitlines() - info_len = 0 - for info in log_info_list: - match = re.match( - '(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time\/interval: \d+.\d+s' + - ' training accuracy: \d+.\d+)', info) - if match: - info_len += 1 - assert(info_len == int(data_size/batch_size/log_interval) + 1) diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index 17c75813d516..2ef8d3b5f603 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -17,13 +17,19 @@ import os import logging +import sys +import re import mxnet as mx from common import TemporaryDirectory from mxnet import nd from mxnet.gluon import nn, loss from mxnet.gluon.contrib.estimator import estimator, event_handler - +from mxnet.gluon.contrib.estimator.event_handler import LoggingHandler +try: + from StringIO import StringIO +except ImportError: + from io import StringIO def _get_test_network(net=nn.Sequential()): net.add(nn.Dense(128, activation='relu', flatten=False), @@ -32,9 +38,9 @@ def _get_test_network(net=nn.Sequential()): return net -def _get_test_data(): - data = nd.ones((32, 100)) - label = nd.zeros((32, 1)) +def _get_test_data(in_size=32): + data = nd.ones((in_size, 100)) + label = nd.zeros((in_size, 1)) data_arr = mx.gluon.data.dataset.ArrayDataset(data, label) return mx.gluon.data.DataLoader(data_arr, batch_size=8) @@ -200,3 +206,71 @@ def epoch_end(self, estimator, *args, **kwargs): est.fit(test_data, event_handlers=[custom_handler], epochs=10) assert custom_handler.num_batch == 5 * 4 assert custom_handler.num_epoch == 5 + +def test_logging_interval(): + ''' test different options for logging handler ''' + ''' test case #1: log interval is 1 ''' + batch_size = 8 + data_size = 100 + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + log_interval = 1 + net = _get_test_network() + dataloader = _get_test_data(in_size=data_size) + num_epochs = 1 + ctx = mx.cpu() + ce_loss = loss.SoftmaxCrossEntropyLoss() + acc = mx.metric.Accuracy() + net.initialize(ctx=ctx) + trainer = mx.gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) + logging = LoggingHandler(verbose=LoggingHandler.LOG_PER_INTERVAL, + train_metrics=[acc], log_interval=log_interval) + est = estimator.Estimator(net=net, + loss=ce_loss, + metrics=acc, + trainer=trainer, + context=ctx) + + est.fit(train_data=dataloader, + epochs=num_epochs, + event_handlers=[logging]) + + sys.stdout = old_stdout + log_info_list = mystdout.getvalue().splitlines() + info_len = 0 + for info in log_info_list: + match = re.match( + '(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time\/interval: \d+.\d+s' + + ' training accuracy: \d+.\d+)', info) + if match: + info_len += 1 + + assert(info_len == int(data_size/batch_size/log_interval) + 1) + ''' test case #2: log interval is 5 ''' + old_stdout = sys.stdout + sys.stdout = mystdout = StringIO() + acc = mx.metric.Accuracy() + log_interval = 5 + trainer = mx.gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) + logging = LoggingHandler(verbose=LoggingHandler.LOG_PER_INTERVAL, + train_metrics=[acc], log_interval=log_interval) + est = estimator.Estimator(net=net, + loss=ce_loss, + metrics=acc, + trainer=trainer, + context=ctx) + est.fit(train_data=dataloader, + epochs=num_epochs, + event_handlers=[logging]) + sys.stdout = old_stdout + log_info_list = mystdout.getvalue().splitlines() + info_len = 0 + for info in log_info_list: + match = re.match( + '(\[Epoch \d+\]\[Batch \d+\]\[Samples \d+\] time\/interval: \d+.\d+s' + + ' training accuracy: \d+.\d+)', info) + if match: + info_len += 1 + + assert(info_len == int(data_size/batch_size/log_interval) + 1) + From 7820ea97213682647649b111c46959f80e30204a Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Thu, 5 Dec 2019 04:28:00 +0000 Subject: [PATCH 06/10] small fix on the test script --- tests/python/unittest/test_gluon_estimator.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/python/unittest/test_gluon_estimator.py b/tests/python/unittest/test_gluon_estimator.py index 518c0a0077ed..cf913a6161c0 100644 --- a/tests/python/unittest/test_gluon_estimator.py +++ b/tests/python/unittest/test_gluon_estimator.py @@ -370,4 +370,3 @@ def test_default_handlers(): assert len(handlers) == 4 assert isinstance(handlers[0], MetricHandler) assert isinstance(handlers[3], LoggingHandler) - From 0a3532b3bca191df41d1879071c498888405a20b Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Thu, 5 Dec 2019 06:38:30 +0000 Subject: [PATCH 07/10] logging handler test case bug fix --- tests/python/unittest/test_gluon_event_handler.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index 2ef8d3b5f603..2719c9293add 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -222,13 +222,11 @@ def test_logging_interval(): ce_loss = loss.SoftmaxCrossEntropyLoss() acc = mx.metric.Accuracy() net.initialize(ctx=ctx) - trainer = mx.gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) logging = LoggingHandler(verbose=LoggingHandler.LOG_PER_INTERVAL, train_metrics=[acc], log_interval=log_interval) est = estimator.Estimator(net=net, loss=ce_loss, metrics=acc, - trainer=trainer, context=ctx) est.fit(train_data=dataloader, @@ -251,13 +249,11 @@ def test_logging_interval(): sys.stdout = mystdout = StringIO() acc = mx.metric.Accuracy() log_interval = 5 - trainer = mx.gluon.Trainer(net.collect_params(), 'sgd', {'learning_rate': 0.001}) logging = LoggingHandler(verbose=LoggingHandler.LOG_PER_INTERVAL, train_metrics=[acc], log_interval=log_interval) est = estimator.Estimator(net=net, loss=ce_loss, metrics=acc, - trainer=trainer, context=ctx) est.fit(train_data=dataloader, epochs=num_epochs, From 458d22b551aadcfbf9599fd90f2e493687cd5cac Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Thu, 5 Dec 2019 07:42:10 +0000 Subject: [PATCH 08/10] remove parameter verbose from LoggingHandler --- .../gluon/contrib/estimator/event_handler.py | 35 +++++++------------ .../unittest/test_gluon_event_handler.py | 6 ++-- 2 files changed, 14 insertions(+), 27 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 93ea328d4dfe..0eae0a6fa143 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -227,33 +227,22 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat Parameters ---------- - verbose : int, default LOG_PER_EPOCH - Limit the granularity of metrics displayed during training process. - verbose=LOG_PER_EPOCH: display metrics every epoch - verbose=LOG_PER_BATCH: display metrics every batch - verbose=LOG_PER_INTERVAL: display metrics every interval of batches train_metrics : list of EvalMetrics Training metrics to be logged, logged at batch end, epoch end, train end. val_metrics : list of EvalMetrics Validation metrics to be logged, logged at epoch end, train end. - log_interval: int, default 1 - Logging interval during training. 1 is equivalent to LOG_PER_BATCH. + log_interval: int or str, default 'epoch' + Logging interval during training. + log_interval='epoch': display metrics every epoch + log_interval=integer k: display metrics every interval of k batches """ - LOG_PER_EPOCH = 1 - LOG_PER_INTERVAL = 2 - - def __init__(self, verbose=LOG_PER_EPOCH, - train_metrics=None, + def __init__(self, train_metrics=None, val_metrics=None, - log_interval=1): + log_interval='epoch'): super(LoggingHandler, self).__init__() - if verbose not in [self.LOG_PER_EPOCH, self.LOG_PER_INTERVAL]: - raise ValueError("verbose level must be either LOG_PER_EPOCH or " - "LOG_PER_INTERVAL, received %s. " - "E.g: LoggingHandler(verbose=LoggingHandler.LOG_PER_EPOCH)" - % verbose) - self.verbose = verbose + if not isinstance(log_interval, int) and log_interval != 'epoch': + raise ValueError("log_interval must be either an integer or string 'epoch'") self.train_metrics = _check_metrics(train_metrics) self.val_metrics = _check_metrics(val_metrics) self.batch_index = 0 @@ -292,11 +281,11 @@ def train_end(self, estimator, *args, **kwargs): estimator.logger.info(msg.rstrip(', ')) def batch_begin(self, estimator, *args, **kwargs): - if self.verbose == self.LOG_PER_INTERVAL: + if isinstance(self.log_interval, int): self.batch_start = time.time() def batch_end(self, estimator, *args, **kwargs): - if self.verbose == self.LOG_PER_INTERVAL: + if isinstance(self.log_interval, int): batch_time = time.time() - self.batch_start msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index) self.processed_samples += kwargs['batch'][0].shape[0] @@ -313,13 +302,13 @@ def batch_end(self, estimator, *args, **kwargs): self.batch_index += 1 def epoch_begin(self, estimator, *args, **kwargs): - if self.verbose in [self.LOG_PER_EPOCH, self.LOG_PER_INTERVAL]: + if isinstance(self.log_interval, int) or self.log_interval == 'epoch': self.epoch_start = time.time() estimator.logger.info("[Epoch %d] Begin, current learning rate: %.4f", self.current_epoch, estimator.trainer.learning_rate) def epoch_end(self, estimator, *args, **kwargs): - if self.verbose in [self.LOG_PER_EPOCH, self.LOG_PER_INTERVAL]: + if isinstance(self.log_interval, int) or self.log_interval == 'epoch': epoch_time = time.time() - self.epoch_start msg = '[Epoch %d] Finished in %.3fs, ' % (self.current_epoch, epoch_time) for monitor in self.train_metrics + self.val_metrics: diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index 2719c9293add..fc069317dd13 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -222,8 +222,7 @@ def test_logging_interval(): ce_loss = loss.SoftmaxCrossEntropyLoss() acc = mx.metric.Accuracy() net.initialize(ctx=ctx) - logging = LoggingHandler(verbose=LoggingHandler.LOG_PER_INTERVAL, - train_metrics=[acc], log_interval=log_interval) + logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval) est = estimator.Estimator(net=net, loss=ce_loss, metrics=acc, @@ -249,8 +248,7 @@ def test_logging_interval(): sys.stdout = mystdout = StringIO() acc = mx.metric.Accuracy() log_interval = 5 - logging = LoggingHandler(verbose=LoggingHandler.LOG_PER_INTERVAL, - train_metrics=[acc], log_interval=log_interval) + logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval) est = estimator.Estimator(net=net, loss=ce_loss, metrics=acc, From 59d7915069700bea54879963024a27b18f28f68e Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Thu, 5 Dec 2019 07:57:56 +0000 Subject: [PATCH 09/10] move log_interval to the first argument --- .../mxnet/gluon/contrib/estimator/event_handler.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/python/mxnet/gluon/contrib/estimator/event_handler.py b/python/mxnet/gluon/contrib/estimator/event_handler.py index 0eae0a6fa143..53ba07dc836a 100644 --- a/python/mxnet/gluon/contrib/estimator/event_handler.py +++ b/python/mxnet/gluon/contrib/estimator/event_handler.py @@ -227,19 +227,19 @@ class LoggingHandler(TrainBegin, TrainEnd, EpochBegin, EpochEnd, BatchBegin, Bat Parameters ---------- - train_metrics : list of EvalMetrics - Training metrics to be logged, logged at batch end, epoch end, train end. - val_metrics : list of EvalMetrics - Validation metrics to be logged, logged at epoch end, train end. log_interval: int or str, default 'epoch' Logging interval during training. log_interval='epoch': display metrics every epoch log_interval=integer k: display metrics every interval of k batches + train_metrics : list of EvalMetrics + Training metrics to be logged, logged at batch end, epoch end, train end. + val_metrics : list of EvalMetrics + Validation metrics to be logged, logged at epoch end, train end. """ - def __init__(self, train_metrics=None, - val_metrics=None, - log_interval='epoch'): + def __init__(self, log_interval='epoch', + train_metrics=None, + val_metrics=None): super(LoggingHandler, self).__init__() if not isinstance(log_interval, int) and log_interval != 'epoch': raise ValueError("log_interval must be either an integer or string 'epoch'") From 777addd0f95014692dacb3b7ae15d45af7724210 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Fri, 6 Dec 2019 08:43:35 +0000 Subject: [PATCH 10/10] resolve unittest mistakes --- tests/python/unittest/test_gluon_event_handler.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_gluon_event_handler.py b/tests/python/unittest/test_gluon_event_handler.py index fc069317dd13..658fb88f47e5 100644 --- a/tests/python/unittest/test_gluon_event_handler.py +++ b/tests/python/unittest/test_gluon_event_handler.py @@ -218,15 +218,12 @@ def test_logging_interval(): net = _get_test_network() dataloader = _get_test_data(in_size=data_size) num_epochs = 1 - ctx = mx.cpu() ce_loss = loss.SoftmaxCrossEntropyLoss() acc = mx.metric.Accuracy() - net.initialize(ctx=ctx) logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval) est = estimator.Estimator(net=net, loss=ce_loss, - metrics=acc, - context=ctx) + metrics=acc) est.fit(train_data=dataloader, epochs=num_epochs, @@ -251,8 +248,7 @@ def test_logging_interval(): logging = LoggingHandler(train_metrics=[acc], log_interval=log_interval) est = estimator.Estimator(net=net, loss=ce_loss, - metrics=acc, - context=ctx) + metrics=acc) est.fit(train_data=dataloader, epochs=num_epochs, event_handlers=[logging])