From 87bfac5c9773a119390a7971025e699674bb6df9 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Thu, 21 Dec 2017 14:40:33 -0800 Subject: [PATCH 1/6] Add EarlyStoppingHook, PlateauOpHook, and MetricsBasedHook base class PiperOrigin-RevId: 179860572 --- tensor2tensor/bin/t2t-trainer | 5 +- tensor2tensor/bin/t2t_trainer.py | 5 +- tensor2tensor/tpu/tpu_trainer.py | 5 +- tensor2tensor/tpu/tpu_trainer_lib.py | 33 ++- tensor2tensor/tpu/tpu_trainer_lib_test.py | 3 +- tensor2tensor/utils/flags.py | 12 +- tensor2tensor/utils/metrics_hook.py | 291 ++++++++++++++++++++++ tensor2tensor/utils/metrics_hook_test.py | 198 +++++++++++++++ 8 files changed, 530 insertions(+), 22 deletions(-) create mode 100644 tensor2tensor/utils/metrics_hook.py create mode 100644 tensor2tensor/utils/metrics_hook_test.py diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer index 7992e9ba9..ed89949ab 100644 --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -77,9 +77,6 @@ def create_hparams(): def create_experiment_fn(): - use_validation_monitor = (FLAGS.schedule in - ["train_and_evaluate", "continuous_train_and_eval"] - and FLAGS.local_eval_frequency) return tpu_trainer_lib.create_experiment_fn( model_name=FLAGS.model, problem_name=get_problem_name(), @@ -92,9 +89,9 @@ def create_experiment_fn(): decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams), use_tfdbg=FLAGS.tfdbg, use_dbgprofile=FLAGS.dbgprofile, - use_validation_monitor=use_validation_monitor, eval_early_stopping_steps=FLAGS.eval_early_stopping_steps, eval_early_stopping_metric=FLAGS.eval_early_stopping_metric, + eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta, eval_early_stopping_metric_minimize=FLAGS. eval_early_stopping_metric_minimize, use_tpu=FLAGS.use_tpu) diff --git a/tensor2tensor/bin/t2t_trainer.py b/tensor2tensor/bin/t2t_trainer.py index d17ff85ea..990035ed0 100644 --- a/tensor2tensor/bin/t2t_trainer.py +++ b/tensor2tensor/bin/t2t_trainer.py @@ -76,9 +76,6 @@ def create_hparams(): def create_experiment_fn(): - use_validation_monitor = (FLAGS.schedule in - ["train_and_evaluate", "continuous_train_and_eval"] - and FLAGS.local_eval_frequency) return tpu_trainer_lib.create_experiment_fn( model_name=FLAGS.model, problem_name=get_problem_name(), @@ -91,9 +88,9 @@ def create_experiment_fn(): decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams), use_tfdbg=FLAGS.tfdbg, use_dbgprofile=FLAGS.dbgprofile, - use_validation_monitor=use_validation_monitor, eval_early_stopping_steps=FLAGS.eval_early_stopping_steps, eval_early_stopping_metric=FLAGS.eval_early_stopping_metric, + eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta, eval_early_stopping_metric_minimize=FLAGS. 
eval_early_stopping_metric_minimize, use_tpu=FLAGS.use_tpu) diff --git a/tensor2tensor/tpu/tpu_trainer.py b/tensor2tensor/tpu/tpu_trainer.py index d17ff85ea..990035ed0 100644 --- a/tensor2tensor/tpu/tpu_trainer.py +++ b/tensor2tensor/tpu/tpu_trainer.py @@ -76,9 +76,6 @@ def create_hparams(): def create_experiment_fn(): - use_validation_monitor = (FLAGS.schedule in - ["train_and_evaluate", "continuous_train_and_eval"] - and FLAGS.local_eval_frequency) return tpu_trainer_lib.create_experiment_fn( model_name=FLAGS.model, problem_name=get_problem_name(), @@ -91,9 +88,9 @@ def create_experiment_fn(): decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams), use_tfdbg=FLAGS.tfdbg, use_dbgprofile=FLAGS.dbgprofile, - use_validation_monitor=use_validation_monitor, eval_early_stopping_steps=FLAGS.eval_early_stopping_steps, eval_early_stopping_metric=FLAGS.eval_early_stopping_metric, + eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta, eval_early_stopping_metric_minimize=FLAGS. eval_early_stopping_metric_minimize, use_tpu=FLAGS.use_tpu) diff --git a/tensor2tensor/tpu/tpu_trainer_lib.py b/tensor2tensor/tpu/tpu_trainer_lib.py index 475d0f1be..be7f00351 100644 --- a/tensor2tensor/tpu/tpu_trainer_lib.py +++ b/tensor2tensor/tpu/tpu_trainer_lib.py @@ -19,10 +19,13 @@ from __future__ import division from __future__ import print_function +import os + # Dependency imports from tensor2tensor.utils import devices from tensor2tensor.utils import expert_utils +from tensor2tensor.utils import metrics_hook from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model @@ -186,7 +189,8 @@ def create_estimator(model_name, def create_hooks(use_tfdbg=False, use_dbgprofile=False, dbgprofile_kwargs=None, - use_validation_monitor=False, validation_monitor_kwargs=None): + use_validation_monitor=False, validation_monitor_kwargs=None, + use_early_stopping=False, early_stopping_kwargs=None): """Create train and eval hooks for Experiment.""" train_monitors = [] eval_hooks = [] @@ -208,6 +212,12 @@ def create_hooks(use_tfdbg=False, use_dbgprofile=False, dbgprofile_kwargs=None, tf.contrib.learn.monitors.ValidationMonitor( hooks=eval_hooks, **validation_monitor_kwargs)) + if use_early_stopping: + hook = metrics_hook.EarlyStoppingHook(**early_stopping_kwargs) + # Adding to both training and eval so that eval aborts as well + train_monitors.append(hook) + eval_hooks.append(hook) + return train_monitors, eval_hooks @@ -224,9 +234,9 @@ def create_experiment(run_config, decode_hparams=None, use_tfdbg=False, use_dbgprofile=False, - use_validation_monitor=False, eval_early_stopping_steps=None, eval_early_stopping_metric=None, + eval_early_stopping_metric_delta=None, eval_early_stopping_metric_minimize=True, use_tpu=False): """Create Experiment.""" @@ -264,12 +274,29 @@ def create_experiment(run_config, early_stopping_rounds=eval_early_stopping_steps, early_stopping_metric=eval_early_stopping_metric, early_stopping_metric_minimize=eval_early_stopping_metric_minimize) + early_stopping_kwargs = dict( + events_dir=os.path.join(run_config.model_dir, "eval_continuous"), + tag=eval_early_stopping_metric, + num_plateau_steps=eval_early_stopping_steps, + plateau_decrease=eval_early_stopping_metric_minimize, + plateau_delta=eval_early_stopping_metric_delta, + every_n_steps=min_eval_frequency) + + # In-process eval (and possible early stopping) + local_schedules = ["train_and_evaluate", "continuous_train_and_eval"] + use_validation_monitor = ( + schedule in local_schedules and min_eval_frequency) + # 
Distributed early stopping + use_early_stopping = ( + schedule not in local_schedules and eval_early_stopping_steps) train_monitors, eval_hooks = create_hooks( use_tfdbg=use_tfdbg, use_dbgprofile=use_dbgprofile, dbgprofile_kwargs=dbgprofile_kwargs, use_validation_monitor=use_validation_monitor, - validation_monitor_kwargs=validation_monitor_kwargs) + use_early_stopping=use_early_stopping, + validation_monitor_kwargs=validation_monitor_kwargs, + early_stopping_kwargs=early_stopping_kwargs) hooks_kwargs = {"train_monitors": train_monitors, "eval_hooks": eval_hooks} # Experiment diff --git a/tensor2tensor/tpu/tpu_trainer_lib_test.py b/tensor2tensor/tpu/tpu_trainer_lib_test.py index e8c1689c7..2a2148afd 100644 --- a/tensor2tensor/tpu/tpu_trainer_lib_test.py +++ b/tensor2tensor/tpu/tpu_trainer_lib_test.py @@ -68,7 +68,8 @@ def testExperiment(self): eval_steps=1, min_eval_frequency=1, use_tpu=False) - run_config = tpu_trainer_lib.create_run_config(num_gpus=0, use_tpu=False) + run_config = tpu_trainer_lib.create_run_config( + model_dir=self.data_dir, num_gpus=0, use_tpu=False) hparams = registry.hparams("transformer_tiny_tpu")() exp = exp_fn(run_config, hparams) exp.test() diff --git a/tensor2tensor/utils/flags.py b/tensor2tensor/utils/flags.py index f4e93a68f..410dccfe1 100644 --- a/tensor2tensor/utils/flags.py +++ b/tensor2tensor/utils/flags.py @@ -55,14 +55,14 @@ flags.DEFINE_integer("train_steps", 250000, "The number of steps to run training for.") flags.DEFINE_string("eval_early_stopping_metric", "loss", - "If --schedule=train_and_evaluate and " - "--eval_early_stopping_steps is not None, then stop when " - "--eval_early_stopping_metric has not decreased for " + "If --eval_early_stopping_steps is not None, then stop " + "when --eval_early_stopping_metric has not decreased for " "--eval_early_stopping_steps") +flags.DEFINE_float("eval_early_stopping_metric_delta", 0.1, + "Delta determining whether metric has plateaued.") flags.DEFINE_integer("eval_early_stopping_steps", None, - "If --schedule=train_and_evaluate and " - "--eval_early_stopping_steps is not None, then stop when " - "--eval_early_stopping_metric has not decreased for " + "If --eval_early_stopping_steps is not None, then stop " + "when --eval_early_stopping_metric has not decreased for " "--eval_early_stopping_steps") flags.DEFINE_bool("eval_early_stopping_metric_minimize", True, "Whether to check for the early stopping metric going down " diff --git a/tensor2tensor/utils/metrics_hook.py b/tensor2tensor/utils/metrics_hook.py new file mode 100644 index 000000000..e5cde12cc --- /dev/null +++ b/tensor2tensor/utils/metrics_hook.py @@ -0,0 +1,291 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Summary-based SessionRunHooks.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +# Dependency imports + +import tensorflow as tf + +from tensorboard.backend.event_processing import event_accumulator +from tensorboard.backend.event_processing import event_multiplexer + + +class MetricsBasedHook(tf.train.SessionRunHook): + """Base class for hooks based on summary metrics. + + Subclasses should override _process_metrics. + + If _process_metrics returns True, calls run_context.request_stop(). + + This can be used to something like "Stop after the loss has stopped decreasing + for 5000 steps. + """ + _RUN_NAME = "run%d" + + def __init__(self, events_dir, subdirs=None, tags=None, every_n_steps=1000): + """Construct MetricsBasedHook. + + Args: + events_dir: str, top-level directory containing events files. + subdirs: list, subdirectories of events_dir that also contain + events files. Use "" to specify the top-level directory. Defaults to + [""]. + tags: list, names of metrics to collect. Default will collect all + metrics. + every_n_steps: int, collect metrics every n steps. + """ + self._events_dir = events_dir + self._subdirs = subdirs or [""] + self._tags = tags + self._every_n_steps = every_n_steps + self._start_step = None + self._event_multiplexer = self._init_multiplexer() + + def _init_multiplexer(self): + dirs = [os.path.join(self._events_dir, subdir) for subdir in self._subdirs] + run_path_map = dict([(self._RUN_NAME % i, d) for i, d in enumerate(dirs)]) + return event_multiplexer.EventMultiplexer(run_path_map) + + def begin(self): + self._global_step_tensor = tf.train.get_global_step() + if self._global_step_tensor is None: + raise RuntimeError("Global step must be created to use MetricsBasedHook.") + + def after_create_session(self, session, coord): + del coord + if self._start_step is None: + self._start_step = session.run(self._global_step_tensor) + + def before_run(self, run_context): + del run_context + return tf.train.SessionRunArgs([self._global_step_tensor]) + + def after_run(self, run_context, run_values): + global_step = run_values.results[0] + if (global_step - self._start_step) % self._every_n_steps != 0: + return + metrics = self._collect_metrics() + self._after_run(run_context, run_values, global_step, metrics) + + def _after_run(self, run_context, run_values, global_step, metrics): + if self._process_metrics(global_step, metrics): + run_context.request_stop() + + def _collect_metrics(self): + self._event_multiplexer.Reload() + subdir_data = {} + for i, subdir in enumerate(self._subdirs): + subdir_metrics = {} + + accum = self._event_multiplexer.GetAccumulator(self._RUN_NAME % i) + for tag in accum.Tags()[event_accumulator.SCALARS]: + steps, vals = zip(*[ + (event.step, event.value) for event in accum.Scalars(tag)]) + subdir_metrics[tag] = (steps, vals) + + subdir_data[subdir] = subdir_metrics + return subdir_data + + def _process_metrics(self, global_step, metrics): + """Process the collected metrics. + + Args: + global_step: int, the current global step value. + metrics: dict. The collected + metrics. subdir_metrics is a dict from tag name to tuple of lists. The + lists are a list of global steps and a list of values. + i.e. subdir_metrics: + `dict global steps, list values>>>` + + Returns: + should_stop: bool. If True, will request that the session stops. 
+ """ + return False + + +class EarlyStoppingHook(MetricsBasedHook): + """EarlyStoppingHook will stop training when a given metric has plateaued.""" + + def __init__(self, + events_dir, + tag, + num_plateau_steps=1000, + plateau_delta=0.1, + plateau_decrease=True, + every_n_steps=1000): + """Create an EarlyStoppingHook. + + This hook will stop training when the metric identified by tag has + plateaued. Plateaued is defined by the metric having stopped + increasing/decreasing (based on plateau_decrease) by plateau_delta for + num_plateau_steps. + + Args: + events_dir: Directory with events files. + tag: Name of metric in TensorBoard. + num_plateau_steps: Number of steps over which to check the plateau. + plateau_delta: delta to define a "plateau". + plateau_decrease: whether to check decrease or increase in the metric. + every_n_steps: how often to run this hook. + + Returns: + An instance of EarlyStoppingHook. + """ + super(EarlyStoppingHook, self).__init__( + events_dir=events_dir, tags=[tag], every_n_steps=every_n_steps) + self._num_plateau_steps = num_plateau_steps + self._plateau_delta = plateau_delta + self._plateau_decrease = plateau_decrease + + def _process_metrics(self, global_step, metrics): + if not metrics: + return + + if not metrics.values()[0]: + return + + # Metrics should have just a single subdir and a single tag + steps, vals = metrics.values()[0][self._tags[0]] + return has_metric_plateaued( + steps, + vals, + num_steps=self._num_plateau_steps, + delta=self._plateau_delta, + decrease=self._plateau_decrease) + + +class PlateauOpHook(MetricsBasedHook): + """Runs an op when a metric has plateaued.""" + + def __init__(self, + events_dir, + tag, + plateau_op, + num_plateau_steps=1000, + plateau_delta=0.1, + plateau_decrease=True, + every_n_steps=1000, + only_once=False): + """See EarlyStoppingHook for args. Runs plateau_op if plateaued.""" + super(PlateauOpHook, self).__init__( + events_dir=events_dir, tags=[tag], every_n_steps=every_n_steps) + self._num_plateau_steps = num_plateau_steps + self._plateau_delta = plateau_delta + self._plateau_decrease = plateau_decrease + self._plateau_op = plateau_op + self._only_once = only_once + self._should_run_op = False + self._ever_ran = False + self._last_metric_step_seen = 0 + + @property + def keep_alive(self): + if self._only_once and self._ever_ran: + return False + return True + + def before_run(self, run_context): + del run_context + + fetches = [self._global_step_tensor] + if self._should_run_op and self.keep_alive: + fetches.append(self._plateau_op) + self._should_run_op = False + self._ever_ran = True + + return tf.train.SessionRunArgs(fetches) + + def _after_run(self, run_context, run_values, global_step, metrics): + del run_context + del run_values + del global_step + + if not self.keep_alive: + return + + if not metrics: + return + + if not metrics.values()[0]: + return + + # There should be only a single subdir and a single tag + steps, vals = metrics.values()[0][self._tags[0]] + + if not steps: + return + + last_step = steps[-1] + if last_step == self._last_metric_step_seen: + return + self._last_metric_step_seen = last_step + + if has_metric_plateaued( + steps, + vals, + num_steps=self._num_plateau_steps, + delta=self._plateau_delta, + decrease=self._plateau_decrease): + self._should_run_op = True + + +def has_metric_plateaued(steps, values, num_steps=100, delta=0.1, + decrease=True): + """Check if metric has plateaued. 
+ + A metric has plateaued if the value has not increased/decreased (depending on + `decrease`) by `delta` for at least `num_steps`. + + Args: + steps: list<int> list of global steps for values. + values: list<float> list of metric values. + num_steps: int, number of steps the metric has to have been plateaued for. + delta: float, how much the metric should have changed by over num_steps. + decrease: bool, whether to check if the metric has decreased by delta or + increased by delta. + + Returns: + bool, whether the metric has plateaued. + """ + assert num_steps > 0 + if len(steps) < 2: + return False + + steps_at_least_num_steps_ago = [ + s for s in steps if s <= (steps[-1] - num_steps) + ] + if not steps_at_least_num_steps_ago: + # Not enough steps yet + return False + delta_step_idx = len(steps_at_least_num_steps_ago) - 1 + + start_val = values[delta_step_idx] + values_to_check = values[delta_step_idx:] + observed_deltas = [] + for val in values_to_check: + if decrease: + observed_delta = start_val - val + else: + observed_delta = val - start_val + observed_deltas.append(observed_delta) + + within_range = [obs < delta for obs in observed_deltas] + return all(within_range) diff --git a/tensor2tensor/utils/metrics_hook_test.py b/tensor2tensor/utils/metrics_hook_test.py new file mode 100644 index 000000000..dc4468cc4 --- /dev/null +++ b/tensor2tensor/utils/metrics_hook_test.py @@ -0,0 +1,198 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +"""Tests for metrics_hook.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import os +import shutil + +# Dependency imports + +from tensor2tensor.utils import metrics_hook + +import tensorflow as tf + + +class DummyHook(metrics_hook.MetricsBasedHook): + + def _process_metrics(self, global_step, metrics): + if metrics: + assert "" in metrics + assert isinstance(metrics[""], dict) + if metrics[""]: + assert "global_step_1" in metrics[""] + self.test_metrics = metrics + if global_step >= 40: + return True + + +class MetricsHookTest(tf.test.TestCase): + + @classmethod + def setUpClass(cls): + cls.base_checkpoint_dir = tf.test.get_temp_dir() + shutil.rmtree(cls.base_checkpoint_dir, ignore_errors=True) + + def ckpt_dir(self, name): + return os.path.join(self.base_checkpoint_dir, name) + + @contextlib.contextmanager + def sess(self, hook, ckpt_dir): + with tf.train.MonitoredTrainingSession( + checkpoint_dir=ckpt_dir, + save_checkpoint_secs=0, + save_summaries_steps=10, + hooks=[hook]) as sess: + self._sess = sess + yield sess + + def flush(self): + self._sess._hooks[1]._summary_writer.flush() + + def testStop(self): + global_step = tf.train.create_global_step() + tf.summary.scalar("global_step", global_step) + incr_global_step = tf.assign_add(global_step, 1) + + ckpt_dir = self.ckpt_dir("stop") + dummy = DummyHook(ckpt_dir, every_n_steps=10) + with self.sess(dummy, ckpt_dir) as sess: + for _ in xrange(20): + sess.run(incr_global_step) + + # Summary files should now have 2 global step values in them + self.flush() + + # Run for 10 more so that the hook gets triggered again + for _ in xrange(10): + sess.run(incr_global_step) + + # Check that the metrics have actually been collected. + self.assertTrue("" in dummy.test_metrics) + metrics = dummy.test_metrics[""] + self.assertTrue("global_step_1" in metrics) + steps, vals = metrics["global_step_1"] + self.assertTrue(len(steps) == len(vals)) + self.assertTrue(len(steps) >= 2) + + # Run for 10 more so that the hook triggers stoppage + for _ in xrange(10): + sess.run(incr_global_step) + + with self.assertRaisesRegexp(RuntimeError, "after should_stop requested"): + sess.run(incr_global_step) + + def testEarlyStoppingHook(self): + global_step = tf.train.create_global_step() + counter = tf.get_variable("count", initializer=0, dtype=tf.int32) + tf.summary.scalar("count", counter) + incr_global_step = tf.assign_add(global_step, 1) + incr_counter = tf.assign_add(counter, 1) + + # Stop if the global step has not gone up by more than 1 in 20 steps. + + ckpt_dir = self.ckpt_dir("early") + stop_hook = metrics_hook.EarlyStoppingHook( + ckpt_dir, + "count_1", + num_plateau_steps=20, + plateau_delta=1., + plateau_decrease=False, + every_n_steps=10) + with self.sess(stop_hook, ckpt_dir) as sess: + for _ in xrange(20): + sess.run((incr_global_step, incr_counter)) + + # Summary files should now have 2 values in them + self.flush() + + # Run for more steps so that the hook gets triggered and we verify that we + # don't stop. + for _ in xrange(30): + sess.run((incr_global_step, incr_counter)) + + self.flush() + + # Run without incrementing the counter + for _ in xrange(40): + sess.run(incr_global_step) + + # Metrics should be written such that now the counter has gone >20 steps + # without being incremented. 
+ self.flush() + + # Check that we ask for stop + with self.assertRaisesRegexp(RuntimeError, "after should_stop requested"): + for _ in xrange(30): + sess.run(incr_global_step) + + def testPlateauOpHook(self): + global_step = tf.train.create_global_step() + counter = tf.get_variable("count", initializer=0, dtype=tf.int32) + indicator = tf.get_variable("indicator", initializer=0, dtype=tf.int32) + tf.summary.scalar("count", counter) + incr_global_step = tf.assign_add(global_step, 1) + incr_counter = tf.assign_add(counter, 1) + incr_indicator = tf.assign_add(indicator, 1) + + # Stop if the global step has not gone up by more than 1 in 20 steps. + + ckpt_dir = self.ckpt_dir("plateauop") + stop_hook = metrics_hook.PlateauOpHook( + ckpt_dir, + "count_1", + incr_indicator, + num_plateau_steps=20, + plateau_delta=1., + plateau_decrease=False, + every_n_steps=10) + with self.sess(stop_hook, ckpt_dir) as sess: + for _ in xrange(20): + sess.run((incr_global_step, incr_counter)) + + # Summary files should now have 2 values in them + self.flush() + + # Run for more steps so that the hook gets triggered and we verify that we + # don't stop. + for _ in xrange(30): + sess.run((incr_global_step, incr_counter)) + + self.flush() + + # Run without incrementing the counter + for _ in xrange(30): + sess.run(incr_global_step) + self.flush() + + self.assertTrue(sess.run(indicator) < 1) + + # Metrics should be written such that now the counter has gone >20 steps + # without being incremented. + # Check that we run the incr_indicator op several times + for _ in xrange(3): + for _ in xrange(10): + sess.run(incr_global_step) + self.flush() + + self.assertTrue(sess.run(indicator) > 1) + +if __name__ == "__main__": + tf.test.main() From 45a4b88bdab90574929d25ef0a8bd0dda3481eb2 Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Thu, 21 Dec 2017 16:12:54 -0800 Subject: [PATCH 2/6] Fix colab notebook PiperOrigin-RevId: 179871302 --- tensor2tensor/notebooks/hello_t2t.ipynb | 52 +------------------------ 1 file changed, 2 insertions(+), 50 deletions(-) diff --git a/tensor2tensor/notebooks/hello_t2t.ipynb b/tensor2tensor/notebooks/hello_t2t.ipynb index 1ff6b1d2b..5b58b042b 100644 --- a/tensor2tensor/notebooks/hello_t2t.ipynb +++ b/tensor2tensor/notebooks/hello_t2t.ipynb @@ -85,6 +85,7 @@ "import os\n", "import collections\n", "\n", + "from tensor2tensor import models\n", "from tensor2tensor import problems\n", "from tensor2tensor.layers import common_layers\n", "from tensor2tensor.tpu import tpu_trainer_lib\n", @@ -1540,55 +1541,6 @@ } ] }, - { - "metadata": { - "id": "a2cL8UwLaSYG", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "source": [ - "# This will eventually be available at\n", - "# tensor2tensor.metrics.create_eager_metrics\n", - "def create_eager_metrics(metric_names):\n", - " \"\"\"Create metrics accumulators and averager for Eager mode.\n", - "\n", - " Args:\n", - " metric_names: list from tensor2tensor.metrics.Metrics\n", - "\n", - " Returns:\n", - " (accum_fn(predictions, targets) => None,\n", - " result_fn() => dict\n", - " \"\"\"\n", - " metric_fns = dict(\n", - " [(name, metrics.METRICS_FNS[name]) for name in metric_names])\n", - " tfe_metrics = dict()\n", - "\n", - " for name in metric_names:\n", - " tfe_metrics[name] = tfe.metrics.Mean(name=name)\n", - "\n", - " def metric_accum(predictions, targets):\n", - " for name, metric_fn in metric_fns.items():\n", - " val, weight = metric_fn(predictions, targets,\n", - " 
weights_fn=common_layers.weights_all)\n", - " tfe_metrics[name](np.squeeze(val), np.squeeze(weight))\n", - "\n", - " def metric_means():\n", - " avgs = {}\n", - " for name in metric_names:\n", - " avgs[name] = tfe_metrics[name].result().numpy()\n", - " return avgs\n", - "\n", - " return metric_accum, metric_means" - ], - "cell_type": "code", - "execution_count": 0, - "outputs": [] - }, { "metadata": { "id": "CIFlkiVOd8jO", @@ -1625,7 +1577,7 @@ "\n", "# Create eval metric accumulators for accuracy (ACC) and accuracy in\n", "# top 5 (ACC_TOP5)\n", - "metrics_accum, metrics_result = create_eager_metrics(\n", + "metrics_accum, metrics_result = metrics.create_eager_metrics(\n", " [metrics.Metrics.ACC, metrics.Metrics.ACC_TOP5])\n", "\n", "for count, example in enumerate(tfe.Iterator(mnist_eval_dataset)):\n", From b10286edfd366e68b12dac8eaf1a7e26305a683e Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Thu, 21 Dec 2017 18:13:56 -0800 Subject: [PATCH 3/6] Pad eval batch to enable multi-device eval; skip T2TModel.top if T2TModel.body returns training loss PiperOrigin-RevId: 179882031 --- setup.py | 2 +- tensor2tensor/bin/t2t-trainer | 6 ++++- tensor2tensor/bin/t2t_trainer.py | 6 ++++- tensor2tensor/data_generators/problem.py | 33 ++++++++++++++++++++++++ tensor2tensor/tpu/tpu_trainer.py | 6 ++++- tensor2tensor/utils/t2t_model.py | 8 ++++-- 6 files changed, 55 insertions(+), 6 deletions(-) diff --git a/setup.py b/setup.py index 01ef5e550..0ae11d780 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.4.0', + version='1.4.1', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer index ed89949ab..9e2ca39b9 100644 --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -61,7 +61,11 @@ try: flags.DEFINE_string("output_dir", "", "Base output directory for run.") flags.DEFINE_string("schedule", "continuous_train_and_eval", "Method of Experiment to run.") - flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.") + flags.DEFINE_integer("eval_steps", 10000, + "Number of steps in evaluation. By default, eval will " + "stop after eval_steps or when it runs through the eval " + "dataset once in full, whichever comes first, so this " + "can be a very large number.") except: # pylint: disable=bare-except pass diff --git a/tensor2tensor/bin/t2t_trainer.py b/tensor2tensor/bin/t2t_trainer.py index 990035ed0..792403062 100644 --- a/tensor2tensor/bin/t2t_trainer.py +++ b/tensor2tensor/bin/t2t_trainer.py @@ -60,7 +60,11 @@ flags.DEFINE_string("output_dir", "", "Base output directory for run.") flags.DEFINE_string("schedule", "continuous_train_and_eval", "Method of Experiment to run.") - flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.") + flags.DEFINE_integer("eval_steps", 10000, + "Number of steps in evaluation. 
By default, eval will " + "stop after eval_steps or when it runs through the eval " + "dataset once in full, whichever comes first, so this " + "can be a very large number.") except: # pylint: disable=bare-except pass diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index e944f15ab..aa1c894db 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -576,6 +576,19 @@ def define_shapes(example): batching_scheme["boundaries"], batching_scheme["batch_sizes"]) + if not is_training: + def _pad_batch(features): + if not config or config.data_parallelism.n <= 1: + return features + tf.logging.warn( + "Padding the batch to ensure that remainder eval batches have " + "a batch size divisible by the number of data shards. This may " + "lead to incorrect metrics for non-zero-padded features, e.g. " + "images. Use a single datashard (i.e. 1 GPU) in that case.") + return pad_batch(features, config.data_parallelism.n) + + dataset = dataset.map(_pad_batch, num_parallel_calls=num_threads) + dataset = dataset.map(define_shapes, num_parallel_calls=num_threads) dataset = dataset.prefetch(1) features = dataset.make_one_shot_iterator().get_next() @@ -930,3 +943,23 @@ def standardize_shapes(features, batch_size=None): t.get_shape().assert_is_fully_defined() return features + + +def pad_batch(features, batch_multiple): + """Pad batch dim of features to nearest multiple of batch_multiple.""" + feature = features.items()[0][1] + batch_size = tf.shape(feature)[0] + mod = batch_size % batch_multiple + has_mod = tf.cast(tf.cast(mod, tf.bool), tf.int32) + batch_padding = batch_multiple * has_mod - mod + + padded_features = {} + for k, feature in features.items(): + rank = len(feature.shape) + paddings = [] + for _ in range(rank): + paddings.append([0, 0]) + paddings[0][1] = batch_padding + padded_feature = tf.pad(feature, paddings) + padded_features[k] = padded_feature + return padded_features diff --git a/tensor2tensor/tpu/tpu_trainer.py b/tensor2tensor/tpu/tpu_trainer.py index 990035ed0..792403062 100644 --- a/tensor2tensor/tpu/tpu_trainer.py +++ b/tensor2tensor/tpu/tpu_trainer.py @@ -60,7 +60,11 @@ flags.DEFINE_string("output_dir", "", "Base output directory for run.") flags.DEFINE_string("schedule", "continuous_train_and_eval", "Method of Experiment to run.") - flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.") + flags.DEFINE_integer("eval_steps", 10000, + "Number of steps in evaluation. 
By default, eval will " + "stop after eval_steps or when it runs through the eval " + "dataset once in full, whichever comes first, so this " + "can be a very large number.") except: # pylint: disable=bare-except pass diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 26854de13..630011541 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -139,13 +139,15 @@ def model_fn_sharded(self, sharded_features): body_out = self.body_sharded( self._to_single_features_dict(transformed_features)) body_out, losses = self._normalize_body_output(body_out) - sharded_logits = dp(self.top, body_out, datashard_to_features) if "training" not in losses: + sharded_logits = dp(self.top, body_out, datashard_to_features) sharded_losses = dp(self.loss, sharded_logits, datashard_to_features) training_loss_dict = average_sharded_losses([{ "training": loss } for loss in sharded_losses]) losses.update(training_loss_dict) + else: + sharded_logits = body_out else: sharded_logits, sharded_losses = dp(self.model_fn, datashard_to_features) losses = average_sharded_losses(sharded_losses) @@ -172,9 +174,11 @@ def model_fn(self, features): body_out = self.body(transformed_features) output, losses = self._normalize_body_output(body_out) - logits = self.top(output, features) if "training" not in losses: + logits = self.top(output, features) losses["training"] = self.loss(logits, features) + else: + logits = output return logits, losses def bottom(self, features): From 83e5949a6c9502623a9ab35c4cb62ad681e23e7f Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Thu, 21 Dec 2017 18:28:29 -0800 Subject: [PATCH 4/6] Rm xrange usage from metrics_hook_test PiperOrigin-RevId: 179882966 --- tensor2tensor/utils/metrics_hook_test.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/tensor2tensor/utils/metrics_hook_test.py b/tensor2tensor/utils/metrics_hook_test.py index dc4468cc4..67c78eb2d 100644 --- a/tensor2tensor/utils/metrics_hook_test.py +++ b/tensor2tensor/utils/metrics_hook_test.py @@ -74,14 +74,14 @@ def testStop(self): ckpt_dir = self.ckpt_dir("stop") dummy = DummyHook(ckpt_dir, every_n_steps=10) with self.sess(dummy, ckpt_dir) as sess: - for _ in xrange(20): + for _ in range(20): sess.run(incr_global_step) # Summary files should now have 2 global step values in them self.flush() # Run for 10 more so that the hook gets triggered again - for _ in xrange(10): + for _ in range(10): sess.run(incr_global_step) # Check that the metrics have actually been collected. @@ -93,7 +93,7 @@ def testStop(self): self.assertTrue(len(steps) >= 2) # Run for 10 more so that the hook triggers stoppage - for _ in xrange(10): + for _ in range(10): sess.run(incr_global_step) with self.assertRaisesRegexp(RuntimeError, "after should_stop requested"): @@ -117,7 +117,7 @@ def testEarlyStoppingHook(self): plateau_decrease=False, every_n_steps=10) with self.sess(stop_hook, ckpt_dir) as sess: - for _ in xrange(20): + for _ in range(20): sess.run((incr_global_step, incr_counter)) # Summary files should now have 2 values in them @@ -125,13 +125,13 @@ def testEarlyStoppingHook(self): # Run for more steps so that the hook gets triggered and we verify that we # don't stop. 
- for _ in xrange(30): + for _ in range(30): sess.run((incr_global_step, incr_counter)) self.flush() # Run without incrementing the counter - for _ in xrange(40): + for _ in range(40): sess.run(incr_global_step) # Metrics should be written such that now the counter has gone >20 steps @@ -140,7 +140,7 @@ def testEarlyStoppingHook(self): # Check that we ask for stop with self.assertRaisesRegexp(RuntimeError, "after should_stop requested"): - for _ in xrange(30): + for _ in range(30): sess.run(incr_global_step) def testPlateauOpHook(self): @@ -164,7 +164,7 @@ def testPlateauOpHook(self): plateau_decrease=False, every_n_steps=10) with self.sess(stop_hook, ckpt_dir) as sess: - for _ in xrange(20): + for _ in range(20): sess.run((incr_global_step, incr_counter)) # Summary files should now have 2 values in them @@ -172,13 +172,13 @@ def testPlateauOpHook(self): # Run for more steps so that the hook gets triggered and we verify that we # don't stop. - for _ in xrange(30): + for _ in range(30): sess.run((incr_global_step, incr_counter)) self.flush() # Run without incrementing the counter - for _ in xrange(30): + for _ in range(30): sess.run(incr_global_step) self.flush() @@ -187,8 +187,8 @@ def testPlateauOpHook(self): # Metrics should be written such that now the counter has gone >20 steps # without being incremented. # Check that we run the incr_indicator op several times - for _ in xrange(3): - for _ in xrange(10): + for _ in range(3): + for _ in range(10): sess.run(incr_global_step) self.flush() From f2b620f7bd3266e911b75690e504c4146b2d2fdf Mon Sep 17 00:00:00 2001 From: Ryan Sepassi Date: Thu, 21 Dec 2017 19:38:07 -0800 Subject: [PATCH 5/6] python3 fix to metrics_hook_test PiperOrigin-RevId: 179886783 --- tensor2tensor/utils/metrics_hook.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensor2tensor/utils/metrics_hook.py b/tensor2tensor/utils/metrics_hook.py index e5cde12cc..964139a42 100644 --- a/tensor2tensor/utils/metrics_hook.py +++ b/tensor2tensor/utils/metrics_hook.py @@ -159,11 +159,11 @@ def _process_metrics(self, global_step, metrics): if not metrics: return - if not metrics.values()[0]: + if not list(metrics.values())[0]: return # Metrics should have just a single subdir and a single tag - steps, vals = metrics.values()[0][self._tags[0]] + steps, vals = list(metrics.values())[0][self._tags[0]] return has_metric_plateaued( steps, vals, @@ -224,11 +224,11 @@ def _after_run(self, run_context, run_values, global_step, metrics): if not metrics: return - if not metrics.values()[0]: + if not list(metrics.values())[0]: return # There should be only a single subdir and a single tag - steps, vals = metrics.values()[0][self._tags[0]] + steps, vals = list(metrics.values())[0][self._tags[0]] if not steps: return From 211eb69bbb3ab0f586e0eb70509ec934b7c7c791 Mon Sep 17 00:00:00 2001 From: Martin Popel Date: Fri, 22 Dec 2017 17:15:07 +0100 Subject: [PATCH 6/6] add executable bit to the scripts in bin/ --- tensor2tensor/bin/t2t-datagen | 0 tensor2tensor/bin/t2t-decoder | 0 tensor2tensor/bin/t2t-make-tf-configs | 0 tensor2tensor/bin/t2t-trainer | 0 4 files changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 tensor2tensor/bin/t2t-datagen mode change 100644 => 100755 tensor2tensor/bin/t2t-decoder mode change 100644 => 100755 tensor2tensor/bin/t2t-make-tf-configs mode change 100644 => 100755 tensor2tensor/bin/t2t-trainer diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen old mode 100644 new mode 100755 diff --git 
a/tensor2tensor/bin/t2t-decoder b/tensor2tensor/bin/t2t-decoder old mode 100644 new mode 100755 diff --git a/tensor2tensor/bin/t2t-make-tf-configs b/tensor2tensor/bin/t2t-make-tf-configs old mode 100644 new mode 100755 diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer old mode 100644 new mode 100755
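
Usage notes (illustrative sketches, not part of the commits above).

A minimal sketch of wiring EarlyStoppingHook from patch 1 into a training loop, mirroring what metrics_hook_test.py does; the checkpoint directory, the "loss" summary, and the thresholds here are placeholder values. PlateauOpHook is wired the same way, except that plateau_op (e.g. a learning-rate decay op) is run instead of a stop being requested.

    import tensorflow as tf

    from tensor2tensor.utils import metrics_hook

    ckpt_dir = "/tmp/early_stop_demo"  # hypothetical directory

    global_step = tf.train.create_global_step()
    # Stand-in metric; it never changes, so it is plateaued from the start.
    loss = tf.get_variable("loss", initializer=1.0)
    # The summary op name collides with the "loss" variable, so the tag that
    # lands in the events file is "loss_1" (same pattern as the tests above).
    tf.summary.scalar("loss", loss)
    incr_global_step = tf.assign_add(global_step, 1)

    # Request a stop once "loss_1" has failed to decrease by at least 0.01
    # over 100 global steps, checking every 50 steps.
    stop_hook = metrics_hook.EarlyStoppingHook(
        ckpt_dir,
        "loss_1",
        num_plateau_steps=100,
        plateau_delta=0.01,
        plateau_decrease=True,
        every_n_steps=50)

    with tf.train.MonitoredTrainingSession(
        checkpoint_dir=ckpt_dir,
        save_summaries_steps=10,
        hooks=[stop_hook]) as sess:
      while not sess.should_stop():
        sess.run(incr_global_step)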
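
has_metric_plateaued is pure Python over parallel (step, value) lists, so its behavior can be checked directly; the step/value series below are made up for illustration. The reference point is the newest value that is at least num_steps old, and only values from that point onward are compared against delta.

    from tensor2tensor.utils import metrics_hook

    steps = [100, 200, 300, 400, 500]
    values = [4.0, 3.0, 2.95, 2.94, 2.93]

    # The newest step at least 200 steps before step 500 is step 300, so the
    # reference value is 2.95. Every later value is within delta=0.1 of it:
    # plateaued.
    print(metrics_hook.has_metric_plateaued(
        steps, values, num_steps=200, delta=0.1, decrease=True))  # True

    # Over a 400-step window the reference value is 4.0, and the metric has
    # since dropped by about 1.0 > delta=0.5, so it is still improving: not
    # plateaued.
    print(metrics_hook.has_metric_plateaued(
        steps, values, num_steps=400, delta=0.5, decrease=True))  # False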
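
pad_batch from patch 3 zero-pads the batch dimension of every feature up to the next multiple of batch_multiple (the number of data shards), so remainder eval batches still divide evenly across devices. A small sketch of the arithmetic, under the Python 2 semantics the committed code assumes (it indexes features.items()); the feature names and shapes are illustrative.

    import tensorflow as tf

    from tensor2tensor.data_generators import problem

    # Batch of 5 with 4 data shards: 5 % 4 == 1, so 4 - 1 == 3 rows of zeros
    # are added and every feature comes back with batch dimension 8.
    features = {
        "inputs": tf.ones([5, 7], dtype=tf.int32),
        "targets": tf.ones([5, 9], dtype=tf.int32),
    }
    padded = problem.pad_batch(features, batch_multiple=4)

    with tf.Session() as sess:
      out = sess.run(padded)
      print(out["inputs"].shape)   # (8, 7)
      print(out["targets"].shape)  # (8, 9)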