Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Executable bit #486

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name='tensor2tensor',
version='1.4.0',
version='1.4.1',
description='Tensor2Tensor',
author='Google Inc.',
author_email='[email protected]',
Expand Down
Empty file modified tensor2tensor/bin/t2t-datagen
100644 → 100755
Empty file.
Empty file modified tensor2tensor/bin/t2t-decoder
100644 → 100755
Empty file.
Empty file modified tensor2tensor/bin/t2t-make-tf-configs
100644 → 100755
Empty file.
11 changes: 6 additions & 5 deletions tensor2tensor/bin/t2t-trainer
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ try:
flags.DEFINE_string("output_dir", "", "Base output directory for run.")
flags.DEFINE_string("schedule", "continuous_train_and_eval",
"Method of Experiment to run.")
flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.")
flags.DEFINE_integer("eval_steps", 10000,
"Number of steps in evaluation. By default, eval will "
"stop after eval_steps or when it runs through the eval "
"dataset once in full, whichever comes first, so this "
"can be a very large number.")
except: # pylint: disable=bare-except
pass

Expand All @@ -77,9 +81,6 @@ def create_hparams():


def create_experiment_fn():
use_validation_monitor = (FLAGS.schedule in
["train_and_evaluate", "continuous_train_and_eval"]
and FLAGS.local_eval_frequency)
return tpu_trainer_lib.create_experiment_fn(
model_name=FLAGS.model,
problem_name=get_problem_name(),
Expand All @@ -92,9 +93,9 @@ def create_experiment_fn():
decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
use_tfdbg=FLAGS.tfdbg,
use_dbgprofile=FLAGS.dbgprofile,
use_validation_monitor=use_validation_monitor,
eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
eval_early_stopping_metric_minimize=FLAGS.
eval_early_stopping_metric_minimize,
use_tpu=FLAGS.use_tpu)
Expand Down
11 changes: 6 additions & 5 deletions tensor2tensor/bin/t2t_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@
flags.DEFINE_string("output_dir", "", "Base output directory for run.")
flags.DEFINE_string("schedule", "continuous_train_and_eval",
"Method of Experiment to run.")
flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.")
flags.DEFINE_integer("eval_steps", 10000,
"Number of steps in evaluation. By default, eval will "
"stop after eval_steps or when it runs through the eval "
"dataset once in full, whichever comes first, so this "
"can be a very large number.")
except: # pylint: disable=bare-except
pass

Expand All @@ -76,9 +80,6 @@ def create_hparams():


def create_experiment_fn():
use_validation_monitor = (FLAGS.schedule in
["train_and_evaluate", "continuous_train_and_eval"]
and FLAGS.local_eval_frequency)
return tpu_trainer_lib.create_experiment_fn(
model_name=FLAGS.model,
problem_name=get_problem_name(),
Expand All @@ -91,9 +92,9 @@ def create_experiment_fn():
decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
use_tfdbg=FLAGS.tfdbg,
use_dbgprofile=FLAGS.dbgprofile,
use_validation_monitor=use_validation_monitor,
eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
eval_early_stopping_metric_minimize=FLAGS.
eval_early_stopping_metric_minimize,
use_tpu=FLAGS.use_tpu)
Expand Down
33 changes: 33 additions & 0 deletions tensor2tensor/data_generators/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,19 @@ def define_shapes(example):
batching_scheme["boundaries"],
batching_scheme["batch_sizes"])

if not is_training:
def _pad_batch(features):
if not config or config.data_parallelism.n <= 1:
return features
tf.logging.warn(
"Padding the batch to ensure that remainder eval batches have "
"a batch size divisible by the number of data shards. This may "
"lead to incorrect metrics for non-zero-padded features, e.g. "
"images. Use a single datashard (i.e. 1 GPU) in that case.")
return pad_batch(features, config.data_parallelism.n)

dataset = dataset.map(_pad_batch, num_parallel_calls=num_threads)

dataset = dataset.map(define_shapes, num_parallel_calls=num_threads)
dataset = dataset.prefetch(1)
features = dataset.make_one_shot_iterator().get_next()
Expand Down Expand Up @@ -930,3 +943,23 @@ def standardize_shapes(features, batch_size=None):
t.get_shape().assert_is_fully_defined()

return features


def pad_batch(features, batch_multiple):
"""Pad batch dim of features to nearest multiple of batch_multiple."""
feature = features.items()[0][1]
batch_size = tf.shape(feature)[0]
mod = batch_size % batch_multiple
has_mod = tf.cast(tf.cast(mod, tf.bool), tf.int32)
batch_padding = batch_multiple * has_mod - mod

padded_features = {}
for k, feature in features.items():
rank = len(feature.shape)
paddings = []
for _ in range(rank):
paddings.append([0, 0])
paddings[0][1] = batch_padding
padded_feature = tf.pad(feature, paddings)
padded_features[k] = padded_feature
return padded_features
52 changes: 2 additions & 50 deletions tensor2tensor/notebooks/hello_t2t.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@
"import os\n",
"import collections\n",
"\n",
"from tensor2tensor import models\n",
"from tensor2tensor import problems\n",
"from tensor2tensor.layers import common_layers\n",
"from tensor2tensor.tpu import tpu_trainer_lib\n",
Expand Down Expand Up @@ -1540,55 +1541,6 @@
}
]
},
{
"metadata": {
"id": "a2cL8UwLaSYG",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"source": [
"# This will eventually be available at\n",
"# tensor2tensor.metrics.create_eager_metrics\n",
"def create_eager_metrics(metric_names):\n",
" \"\"\"Create metrics accumulators and averager for Eager mode.\n",
"\n",
" Args:\n",
" metric_names: list<str> from tensor2tensor.metrics.Metrics\n",
"\n",
" Returns:\n",
" (accum_fn(predictions, targets) => None,\n",
" result_fn() => dict<str metric_name, float avg_val>\n",
" \"\"\"\n",
" metric_fns = dict(\n",
" [(name, metrics.METRICS_FNS[name]) for name in metric_names])\n",
" tfe_metrics = dict()\n",
"\n",
" for name in metric_names:\n",
" tfe_metrics[name] = tfe.metrics.Mean(name=name)\n",
"\n",
" def metric_accum(predictions, targets):\n",
" for name, metric_fn in metric_fns.items():\n",
" val, weight = metric_fn(predictions, targets,\n",
" weights_fn=common_layers.weights_all)\n",
" tfe_metrics[name](np.squeeze(val), np.squeeze(weight))\n",
"\n",
" def metric_means():\n",
" avgs = {}\n",
" for name in metric_names:\n",
" avgs[name] = tfe_metrics[name].result().numpy()\n",
" return avgs\n",
"\n",
" return metric_accum, metric_means"
],
"cell_type": "code",
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "CIFlkiVOd8jO",
Expand Down Expand Up @@ -1625,7 +1577,7 @@
"\n",
"# Create eval metric accumulators for accuracy (ACC) and accuracy in\n",
"# top 5 (ACC_TOP5)\n",
"metrics_accum, metrics_result = create_eager_metrics(\n",
"metrics_accum, metrics_result = metrics.create_eager_metrics(\n",
" [metrics.Metrics.ACC, metrics.Metrics.ACC_TOP5])\n",
"\n",
"for count, example in enumerate(tfe.Iterator(mnist_eval_dataset)):\n",
Expand Down
11 changes: 6 additions & 5 deletions tensor2tensor/tpu/tpu_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,11 @@
flags.DEFINE_string("output_dir", "", "Base output directory for run.")
flags.DEFINE_string("schedule", "continuous_train_and_eval",
"Method of Experiment to run.")
flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.")
flags.DEFINE_integer("eval_steps", 10000,
"Number of steps in evaluation. By default, eval will "
"stop after eval_steps or when it runs through the eval "
"dataset once in full, whichever comes first, so this "
"can be a very large number.")
except: # pylint: disable=bare-except
pass

Expand All @@ -76,9 +80,6 @@ def create_hparams():


def create_experiment_fn():
use_validation_monitor = (FLAGS.schedule in
["train_and_evaluate", "continuous_train_and_eval"]
and FLAGS.local_eval_frequency)
return tpu_trainer_lib.create_experiment_fn(
model_name=FLAGS.model,
problem_name=get_problem_name(),
Expand All @@ -91,9 +92,9 @@ def create_experiment_fn():
decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
use_tfdbg=FLAGS.tfdbg,
use_dbgprofile=FLAGS.dbgprofile,
use_validation_monitor=use_validation_monitor,
eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
eval_early_stopping_metric_minimize=FLAGS.
eval_early_stopping_metric_minimize,
use_tpu=FLAGS.use_tpu)
Expand Down
33 changes: 30 additions & 3 deletions tensor2tensor/tpu/tpu_trainer_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,13 @@
from __future__ import division
from __future__ import print_function

import os

# Dependency imports

from tensor2tensor.utils import devices
from tensor2tensor.utils import expert_utils
from tensor2tensor.utils import metrics_hook
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model

Expand Down Expand Up @@ -186,7 +189,8 @@ def create_estimator(model_name,


def create_hooks(use_tfdbg=False, use_dbgprofile=False, dbgprofile_kwargs=None,
use_validation_monitor=False, validation_monitor_kwargs=None):
use_validation_monitor=False, validation_monitor_kwargs=None,
use_early_stopping=False, early_stopping_kwargs=None):
"""Create train and eval hooks for Experiment."""
train_monitors = []
eval_hooks = []
Expand All @@ -208,6 +212,12 @@ def create_hooks(use_tfdbg=False, use_dbgprofile=False, dbgprofile_kwargs=None,
tf.contrib.learn.monitors.ValidationMonitor(
hooks=eval_hooks, **validation_monitor_kwargs))

if use_early_stopping:
hook = metrics_hook.EarlyStoppingHook(**early_stopping_kwargs)
# Adding to both training and eval so that eval aborts as well
train_monitors.append(hook)
eval_hooks.append(hook)

return train_monitors, eval_hooks


Expand All @@ -224,9 +234,9 @@ def create_experiment(run_config,
decode_hparams=None,
use_tfdbg=False,
use_dbgprofile=False,
use_validation_monitor=False,
eval_early_stopping_steps=None,
eval_early_stopping_metric=None,
eval_early_stopping_metric_delta=None,
eval_early_stopping_metric_minimize=True,
use_tpu=False):
"""Create Experiment."""
Expand Down Expand Up @@ -264,12 +274,29 @@ def create_experiment(run_config,
early_stopping_rounds=eval_early_stopping_steps,
early_stopping_metric=eval_early_stopping_metric,
early_stopping_metric_minimize=eval_early_stopping_metric_minimize)
early_stopping_kwargs = dict(
events_dir=os.path.join(run_config.model_dir, "eval_continuous"),
tag=eval_early_stopping_metric,
num_plateau_steps=eval_early_stopping_steps,
plateau_decrease=eval_early_stopping_metric_minimize,
plateau_delta=eval_early_stopping_metric_delta,
every_n_steps=min_eval_frequency)

# In-process eval (and possible early stopping)
local_schedules = ["train_and_evaluate", "continuous_train_and_eval"]
use_validation_monitor = (
schedule in local_schedules and min_eval_frequency)
# Distributed early stopping
use_early_stopping = (
schedule not in local_schedules and eval_early_stopping_steps)
train_monitors, eval_hooks = create_hooks(
use_tfdbg=use_tfdbg,
use_dbgprofile=use_dbgprofile,
dbgprofile_kwargs=dbgprofile_kwargs,
use_validation_monitor=use_validation_monitor,
validation_monitor_kwargs=validation_monitor_kwargs)
use_early_stopping=use_early_stopping,
validation_monitor_kwargs=validation_monitor_kwargs,
early_stopping_kwargs=early_stopping_kwargs)
hooks_kwargs = {"train_monitors": train_monitors, "eval_hooks": eval_hooks}

# Experiment
Expand Down
3 changes: 2 additions & 1 deletion tensor2tensor/tpu/tpu_trainer_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ def testExperiment(self):
eval_steps=1,
min_eval_frequency=1,
use_tpu=False)
run_config = tpu_trainer_lib.create_run_config(num_gpus=0, use_tpu=False)
run_config = tpu_trainer_lib.create_run_config(
model_dir=self.data_dir, num_gpus=0, use_tpu=False)
hparams = registry.hparams("transformer_tiny_tpu")()
exp = exp_fn(run_config, hparams)
exp.test()
Expand Down
12 changes: 6 additions & 6 deletions tensor2tensor/utils/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,14 @@
flags.DEFINE_integer("train_steps", 250000,
"The number of steps to run training for.")
flags.DEFINE_string("eval_early_stopping_metric", "loss",
"If --schedule=train_and_evaluate and "
"--eval_early_stopping_steps is not None, then stop when "
"--eval_early_stopping_metric has not decreased for "
"If --eval_early_stopping_steps is not None, then stop "
"when --eval_early_stopping_metric has not decreased for "
"--eval_early_stopping_steps")
flags.DEFINE_float("eval_early_stopping_metric_delta", 0.1,
"Delta determining whether metric has plateaued.")
flags.DEFINE_integer("eval_early_stopping_steps", None,
"If --schedule=train_and_evaluate and "
"--eval_early_stopping_steps is not None, then stop when "
"--eval_early_stopping_metric has not decreased for "
"If --eval_early_stopping_steps is not None, then stop "
"when --eval_early_stopping_metric has not decreased for "
"--eval_early_stopping_steps")
flags.DEFINE_bool("eval_early_stopping_metric_minimize", True,
"Whether to check for the early stopping metric going down "
Expand Down
Loading