diff --git a/.travis.yml b/.travis.yml index b67c74b1d..7841b0b7e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,7 +14,7 @@ env: - T2T_DATA_DIR=/tmp/t2t-data - T2T_TRAIN_DIR=/tmp/t2t-train script: - - pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/tpu/tpu_trainer_lib_test.py + - pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/tpu/tpu_trainer_lib_test.py --ignore=tensor2tensor/data_generators/algorithmic_math_test.py - pytest tensor2tensor/utils/registry_test.py - pytest tensor2tensor/tpu/tpu_trainer_lib_test.py - t2t-datagen 2>&1 | grep translate && echo passed diff --git a/setup.py b/setup.py index 01ef5e550..fb2b6492d 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name='tensor2tensor', - version='1.4.0', + version='1.4.1', description='Tensor2Tensor', author='Google Inc.', author_email='no-reply@google.com', @@ -30,6 +30,7 @@ 'gym', 'numpy', 'requests', + 'scipy', 'sympy', 'six', ], diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer index 7992e9ba9..70435094a 100644 --- a/tensor2tensor/bin/t2t-trainer +++ b/tensor2tensor/bin/t2t-trainer @@ -45,6 +45,7 @@ flags.DEFINE_string("t2t_usr_dir", "", "The imported files should contain registrations, " "e.g. @registry.register_model calls, that will then be " "available to the t2t-trainer.") +flags.DEFINE_integer("random_seed", 1234, "Random seed.") flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.") flags.DEFINE_integer("iterations_per_loop", 1000, "Number of iterations in a TPU training loop.") @@ -61,7 +62,11 @@ try: flags.DEFINE_string("output_dir", "", "Base output directory for run.") flags.DEFINE_string("schedule", "continuous_train_and_eval", "Method of Experiment to run.") - flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.") + flags.DEFINE_integer("eval_steps", 10000, + "Number of steps in evaluation. By default, eval will " + "stop after eval_steps or when it runs through the eval " + "dataset once in full, whichever comes first, so this " + "can be a very large number.") except: # pylint: disable=bare-except pass @@ -77,9 +82,6 @@ def create_hparams(): def create_experiment_fn(): - use_validation_monitor = (FLAGS.schedule in - ["train_and_evaluate", "continuous_train_and_eval"] - and FLAGS.local_eval_frequency) return tpu_trainer_lib.create_experiment_fn( model_name=FLAGS.model, problem_name=get_problem_name(), @@ -92,9 +94,9 @@ def create_experiment_fn(): decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams), use_tfdbg=FLAGS.tfdbg, use_dbgprofile=FLAGS.dbgprofile, - use_validation_monitor=use_validation_monitor, eval_early_stopping_steps=FLAGS.eval_early_stopping_steps, eval_early_stopping_metric=FLAGS.eval_early_stopping_metric, + eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta, eval_early_stopping_metric_minimize=FLAGS. eval_early_stopping_metric_minimize, use_tpu=FLAGS.use_tpu) @@ -170,7 +172,7 @@ def execute_schedule(exp): def main(_): tf.logging.set_verbosity(tf.logging.INFO) - tf.set_random_seed(123) + tpu_trainer_lib.set_random_seed(FLAGS.random_seed) usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) log_registry() diff --git a/tensor2tensor/bin/t2t_trainer.py b/tensor2tensor/bin/t2t_trainer.py index d17ff85ea..571a21839 100644 --- a/tensor2tensor/bin/t2t_trainer.py +++ b/tensor2tensor/bin/t2t_trainer.py @@ -44,6 +44,7 @@ "The imported files should contain registrations, " "e.g. 
@registry.register_model calls, that will then be " "available to the t2t-trainer.") +flags.DEFINE_integer("random_seed", 1234, "Random seed.") flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.") flags.DEFINE_integer("iterations_per_loop", 1000, "Number of iterations in a TPU training loop.") @@ -60,7 +61,11 @@ flags.DEFINE_string("output_dir", "", "Base output directory for run.") flags.DEFINE_string("schedule", "continuous_train_and_eval", "Method of Experiment to run.") - flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.") + flags.DEFINE_integer("eval_steps", 10000, + "Number of steps in evaluation. By default, eval will " + "stop after eval_steps or when it runs through the eval " + "dataset once in full, whichever comes first, so this " + "can be a very large number.") except: # pylint: disable=bare-except pass @@ -76,9 +81,6 @@ def create_hparams(): def create_experiment_fn(): - use_validation_monitor = (FLAGS.schedule in - ["train_and_evaluate", "continuous_train_and_eval"] - and FLAGS.local_eval_frequency) return tpu_trainer_lib.create_experiment_fn( model_name=FLAGS.model, problem_name=get_problem_name(), @@ -91,9 +93,9 @@ def create_experiment_fn(): decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams), use_tfdbg=FLAGS.tfdbg, use_dbgprofile=FLAGS.dbgprofile, - use_validation_monitor=use_validation_monitor, eval_early_stopping_steps=FLAGS.eval_early_stopping_steps, eval_early_stopping_metric=FLAGS.eval_early_stopping_metric, + eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta, eval_early_stopping_metric_minimize=FLAGS. eval_early_stopping_metric_minimize, use_tpu=FLAGS.use_tpu) @@ -169,7 +171,7 @@ def execute_schedule(exp): def main(_): tf.logging.set_verbosity(tf.logging.INFO) - tf.set_random_seed(123) + tpu_trainer_lib.set_random_seed(FLAGS.random_seed) usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) log_registry() diff --git a/tensor2tensor/data_generators/algorithmic_math_test.py b/tensor2tensor/data_generators/algorithmic_math_test.py index 7cd67a83c..c7fdfa156 100644 --- a/tensor2tensor/data_generators/algorithmic_math_test.py +++ b/tensor2tensor/data_generators/algorithmic_math_test.py @@ -14,6 +14,7 @@ # limitations under the License. """Tests for tensor2tensor.data_generators.algorithmic_math.""" +# TODO(rsepassi): This test is flaky. Disable, remove, or update. from __future__ import absolute_import from __future__ import division diff --git a/tensor2tensor/data_generators/librispeech.py b/tensor2tensor/data_generators/librispeech.py index d6a07a391..ad8e931d8 100644 --- a/tensor2tensor/data_generators/librispeech.py +++ b/tensor2tensor/data_generators/librispeech.py @@ -16,23 +16,14 @@ """Librispeech dataset.""" import os -from subprocess import call import tarfile -import wave # Dependency imports -import numpy as np - from tensor2tensor.data_generators import generator_utils -from tensor2tensor.data_generators import problem -from tensor2tensor.data_generators import text_encoder -from tensor2tensor.layers import common_layers -from tensor2tensor.utils import modality +from tensor2tensor.data_generators import speech_recognition from tensor2tensor.utils import registry -import tensorflow as tf - _LIBRISPEECH_TRAIN_DATASETS = [ [ @@ -86,130 +77,13 @@ def _collect_data(directory, input_ext, transcription_ext): return data_files -def _get_audio_data(filepath): - # Construct a true .wav file. - out_filepath = filepath.strip(".flac") + ".wav" - # Assumes sox is installed on system. Sox converts from FLAC to WAV. 
- call(["sox", filepath, out_filepath]) - wav_file = wave.open(open(out_filepath)) - frame_count = wav_file.getnframes() - byte_array = wav_file.readframes(frame_count) - - data = np.fromstring(byte_array, np.uint8).tolist() - return data, frame_count, wav_file.getsampwidth(), wav_file.getnchannels() - - -class LibrispeechTextEncoder(text_encoder.TextEncoder): - - def encode(self, s): - return [self._num_reserved_ids + ord(c) for c in s] - - def decode(self, ids): - """Transform a sequence of int ids into a human-readable string. - - EOS is not expected in ids. - - Args: - ids: list of integers to be converted. - Returns: - s: human-readable string. - """ - decoded_ids = [] - for id_ in ids: - if 0 <= id_ < self._num_reserved_ids: - decoded_ids.append(text_encoder.RESERVED_TOKENS[int(id_)]) - else: - decoded_ids.append(id_ - self._num_reserved_ids) - return "".join([chr(d) for d in decoded_ids]) - - -@registry.register_audio_modality -class LibrispeechModality(modality.Modality): - """Performs strided conv compressions for audio spectral data.""" - - def bottom(self, inputs): - """Transform input from data space to model space. - - Args: - inputs: A Tensor with shape [batch, ...] - Returns: - body_input: A Tensor with shape [batch, ?, ?, body_input_depth]. - """ - with tf.variable_scope(self.name): - # TODO(aidangomez): Will need to sort out a better audio pipeline - def xnet_resblock(x, filters, res_relu, name): - with tf.variable_scope(name): - # We only stride along the length dimension to preserve the spectral - # bins (which are tiny in dimensionality relative to length) - y = common_layers.separable_conv_block( - x, - filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))], - first_relu=True, - padding="SAME", - force2d=True, - name="sep_conv_block") - y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 1)) - return y + common_layers.conv_block( - x, - filters, [((1, 1), (1, 1))], - padding="SAME", - strides=(2, 1), - first_relu=res_relu, - force2d=True, - name="res_conv0") - - # Rescale from UINT8 to floats in [-1,-1] - signals = (tf.to_float(inputs)-127)/128. - signals = tf.squeeze(signals, [2, 3]) - - # `stfts` is a complex64 Tensor representing the short-time Fourier - # Transform of each signal in `signals`. Its shape is - # [batch_size, ?, fft_unique_bins] - # where fft_unique_bins = fft_length // 2 + 1 = 513. - stfts = tf.contrib.signal.stft(signals, frame_length=1024, frame_step=512, - fft_length=1024) - - # An energy spectrogram is the magnitude of the complex-valued STFT. - # A float32 Tensor of shape [batch_size, ?, 513]. - magnitude_spectrograms = tf.abs(stfts) - - # Warp the linear-scale, magnitude spectrograms into the mel-scale. - num_spectrogram_bins = magnitude_spectrograms.shape[-1].value - lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 64 - sample_rate = 16000 - linear_to_mel_weight_matrix = ( - tf.contrib.signal.linear_to_mel_weight_matrix( - num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz, - upper_edge_hertz)) - mel_spectrograms = tf.tensordot( - magnitude_spectrograms, linear_to_mel_weight_matrix, 1) - # Note: Shape inference for tensordot does not currently handle this case. 
- mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate( - linear_to_mel_weight_matrix.shape[-1:])) - - x = tf.expand_dims(mel_spectrograms, 2) - x.set_shape([None, None, None, num_mel_bins]) - for i in xrange(self._model_hparams.audio_compression): - x = xnet_resblock(x, 2**(i + 1), True, "compress_block_%d" % i) - return xnet_resblock(x, self._body_input_depth, False, - "compress_block_final") - - @registry.register_problem() -class Librispeech(problem.Problem): - """Problem spec for English word to dictionary definition.""" +class Librispeech(speech_recognition.SpeechRecognitionProblem): + """Problem spec for Librispeech using clean and noisy data.""" - @property - def is_character_level(self): - return True - - @property - def input_space_id(self): - return problem.SpaceID.AUDIO_SPECTRAL - - @property - def target_space_id(self): - return problem.SpaceID.EN_CHR + # Use all of the data, both clean and noisy + TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS + DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS @property def num_shards(self): @@ -228,26 +102,8 @@ def use_train_shards_for_dev(self): """If true, we only generate training data and hold out shards for dev.""" return False - def feature_encoders(self, _): - return { - "inputs": text_encoder.TextEncoder(), - "targets": LibrispeechTextEncoder(), - } - - def example_reading_spec(self): - data_fields = { - "inputs": tf.VarLenFeature(tf.int64), - "targets": tf.VarLenFeature(tf.int64), - } - data_items_to_decoders = None - return (data_fields, data_items_to_decoders) - - def generator(self, data_dir, tmp_dir, training, + def generator(self, data_dir, tmp_dir, datasets, eos_list=None, start_from=0, how_many=0): - eos_list = [1] if eos_list is None else eos_list - datasets = (_LIBRISPEECH_TRAIN_DATASETS if training - else _LIBRISPEECH_TEST_DATASETS) - num_reserved_ids = self.feature_encoders(None)["targets"].num_reserved_ids i = 0 for url, subdir in datasets: filename = os.path.basename(url) @@ -267,19 +123,18 @@ def generator(self, data_dir, tmp_dir, training, data_dir = os.path.join(tmp_dir, "LibriSpeech", subdir) data_files = _collect_data(data_dir, "flac", "txt") data_pairs = data_files.values() + + encoders = self.feature_encoders(None) + audio_encoder = encoders["waveforms"] + text_encoder = encoders["targets"] + for media_file, text_data in sorted(data_pairs)[start_from:]: if how_many > 0 and i == how_many: return i += 1 - audio_data, sample_count, sample_width, num_channels = _get_audio_data( - media_file) - label = [num_reserved_ids + ord(c) for c in text_data] + eos_list yield { - "inputs": audio_data, - "audio/channel_count": [num_channels], - "audio/sample_count": [sample_count], - "audio/sample_width": [sample_width], - "targets": label + "waveforms": audio_encoder.encode(media_file), + "targets": text_encoder.encode(text_data) } def generate_data(self, data_dir, tmp_dir, task_id=-1): @@ -287,24 +142,34 @@ def generate_data(self, data_dir, tmp_dir, task_id=-1): data_dir, self.num_shards, shuffled=False) dev_paths = self.dev_filepaths( data_dir, self.num_dev_shards, shuffled=False) + if self.use_train_shards_for_dev: all_paths = train_paths + dev_paths generator_utils.generate_files( - self.generator(data_dir, tmp_dir, True), all_paths) + self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), all_paths) generator_utils.shuffle_dataset(all_paths) else: generator_utils.generate_dataset_and_shuffle( - self.generator(data_dir, tmp_dir, True), train_paths, - self.generator(data_dir, tmp_dir, False), dev_paths) +
self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), train_paths, + self.generator(data_dir, tmp_dir, self.DEV_DATASETS), dev_paths) - def hparams(self, defaults, unused_model_hparams): - p = defaults - p.stop_at_eos = int(False) - p.input_modality = {"inputs": ("audio:librispeech_modality", None)} - p.target_modality = (registry.Modalities.SYMBOL, 256) - def preprocess_example(self, example, mode, hparams): - return example +@registry.register_problem() +class LibrispeechCleanSmall(Librispeech): + """Problem spec for Librispeech using 100h clean train data.""" + + # Select only the clean data + TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:1] + DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1] + + +@registry.register_problem() +class LibrispeechClean(Librispeech): + """Problem spec for Librispeech using 460h clean train data.""" + + # Select only the clean data + TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:2] + DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1] # TODO(lukaszkaiser): clean up hparams or remove from here. diff --git a/tensor2tensor/data_generators/problem.py b/tensor2tensor/data_generators/problem.py index e944f15ab..52d7bdab2 100644 --- a/tensor2tensor/data_generators/problem.py +++ b/tensor2tensor/data_generators/problem.py @@ -576,6 +576,19 @@ def define_shapes(example): batching_scheme["boundaries"], batching_scheme["batch_sizes"]) + if not is_training: + def _pad_batch(features): + if not config or config.data_parallelism.n <= 1: + return features + tf.logging.warn( + "Padding the batch to ensure that remainder eval batches have " + "a batch size divisible by the number of data shards. This may " + "lead to incorrect metrics for non-zero-padded features, e.g. " + "images. Use a single datashard (i.e. 1 GPU) in that case.") + return pad_batch(features, config.data_parallelism.n) + + dataset = dataset.map(_pad_batch, num_parallel_calls=num_threads) + dataset = dataset.map(define_shapes, num_parallel_calls=num_threads) dataset = dataset.prefetch(1) features = dataset.make_one_shot_iterator().get_next() @@ -930,3 +943,23 @@ def standardize_shapes(features, batch_size=None): t.get_shape().assert_is_fully_defined() return features + + +def pad_batch(features, batch_multiple): + """Pad batch dim of features to nearest multiple of batch_multiple.""" + feature = list(features.items())[0][1] + batch_size = tf.shape(feature)[0] + mod = batch_size % batch_multiple + has_mod = tf.cast(tf.cast(mod, tf.bool), tf.int32) + batch_padding = batch_multiple * has_mod - mod + + padded_features = {} + for k, feature in features.items(): + rank = len(feature.shape) + paddings = [] + for _ in range(rank): + paddings.append([0, 0]) + paddings[0][1] = batch_padding + padded_feature = tf.pad(feature, paddings) + padded_features[k] = padded_feature + return padded_features diff --git a/tensor2tensor/data_generators/speech_recognition.py b/tensor2tensor/data_generators/speech_recognition.py new file mode 100644 index 000000000..c54878045 --- /dev/null +++ b/tensor2tensor/data_generators/speech_recognition.py @@ -0,0 +1,332 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Common classes for automatic speech recognition (ASR) datasets. + +The audio import uses sox to generate normalized waveforms; please install +it as appropriate (e.g. using apt-get or yum). +""" + +import functools +import os +from subprocess import call +import tempfile + +# Dependency imports + +import numpy as np +from scipy.io import wavfile +import scipy.signal + +from tensor2tensor.data_generators import problem +from tensor2tensor.data_generators import text_encoder +from tensor2tensor.layers import common_layers +from tensor2tensor.utils import modality +from tensor2tensor.utils import registry + +import tensorflow as tf + + +# +# ASR Feature pipeline in TF. +# +def add_delta_deltas(filterbanks, name=None): + """Compute time first and second-order derivative channels. + + Args: + filterbanks: float32 tensor with shape [batch_size, len, num_bins, 1] + name: scope name + + Returns: + float32 tensor with shape [batch_size, len, num_bins, 3] + """ + delta_filter = np.array([2, 1, 0, -1, -2]) + delta_delta_filter = scipy.signal.convolve(delta_filter, delta_filter, "full") + + delta_filter_stack = np.array( + [[0] * 4 + [1] + [0] * 4, [0] * 2 + list(delta_filter) + [0] * 2, + list(delta_delta_filter)], + dtype=np.float32).T[:, None, None, :] + + delta_filter_stack /= np.sqrt( + np.sum(delta_filter_stack**2, axis=0, keepdims=True)) + + filterbanks = tf.nn.conv2d( + filterbanks, delta_filter_stack, [1, 1, 1, 1], "SAME", data_format="NHWC", + name=name) + return filterbanks + + +def compute_mel_filterbank_features( + waveforms, + sample_rate=16000, dither=1.0 / np.iinfo(np.int16).max, preemphasis=0.97, + frame_length=25, frame_step=10, fft_length=None, + window_fn=functools.partial(tf.contrib.signal.hann_window, periodic=True), + lower_edge_hertz=80.0, upper_edge_hertz=7600.0, num_mel_bins=80, + log_noise_floor=1e-3): + """Implement mel-filterbank extraction using tf ops. + + Args: + waveforms: float32 tensor with shape [batch_size, max_len] + sample_rate: sampling rate of the waveform + dither: stddev of Gaussian noise added to waveform to prevent quantization + artefacts + preemphasis: waveform high-pass filtering constant + frame_length: frame length in ms + frame_step: frame step in ms + fft_length: number of fft bins + window_fn: windowing function + lower_edge_hertz: lowest frequency of the filterbank + upper_edge_hertz: highest frequency of the filterbank + num_mel_bins: filterbank size + log_noise_floor: clip small values to prevent numeric overflow in log + Returns: + filterbanks: float32 tensor with shape [batch_size, len, num_bins, 1] + """ + # `stfts` is a complex64 Tensor representing the short-time Fourier + # Transform of each signal in `signals`.
Its shape is + # [batch_size, ?, fft_unique_bins] + # where fft_unique_bins = fft_length // 2 + 1 + if dither > 0: + waveforms += tf.random_normal(tf.shape(waveforms), stddev=dither) + if preemphasis > 0: + waveforms = waveforms[:, 1:] - preemphasis * waveforms[:, :-1] + frame_length = int(frame_length * sample_rate / 1e3) + frame_step = int(frame_step * sample_rate / 1e3) + if fft_length is None: + fft_length = int(2**(np.ceil(np.log2(frame_length)))) + stfts = tf.contrib.signal.stft( + waveforms, + frame_length=frame_length, + frame_step=frame_step, + fft_length=fft_length, + window_fn=window_fn, + pad_end=True) + + # An energy spectrogram is the magnitude of the complex-valued STFT. + # A float32 Tensor of shape [batch_size, ?, 257]. + magnitude_spectrograms = tf.abs(stfts) + + # Warp the linear-scale, magnitude spectrograms into the mel-scale. + num_spectrogram_bins = magnitude_spectrograms.shape[-1].value + linear_to_mel_weight_matrix = ( + tf.contrib.signal.linear_to_mel_weight_matrix( + num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz, + upper_edge_hertz)) + mel_spectrograms = tf.tensordot( + magnitude_spectrograms, linear_to_mel_weight_matrix, 1) + # Note: Shape inference for tensordot does not currently handle this case. + mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate( + linear_to_mel_weight_matrix.shape[-1:])) + + log_mel_sgram = tf.log(tf.maximum(log_noise_floor, mel_spectrograms)) + + return tf.expand_dims(log_mel_sgram, -1) + + +# +# Audio problem definition +# +class AudioEncoder(object): + """Encoder class for saving and loading waveforms.""" + + def __init__(self, num_reserved_ids=0, sample_rate=16000): + assert num_reserved_ids == 0 + self._sample_rate = sample_rate + + @property + def num_reserved_ids(self): + return 0 + + def encode(self, s): + """Transform a string with a filename into a list of float32. + + Args: + s: path to the file with a waveform. + + Returns: + samples: list of float32s + """ + # Make sure that the data is a single channel, 16bit, 16kHz wave. + # TODO(chorowski): the directory may not be writable, this should fall back + # to a temp path, and provide instructions for installing sox. + if not s.endswith(".wav"): + out_filepath = s + ".wav" + if not os.path.exists(out_filepath): + call(["sox", "-r", "16k", "-b", "16", "-c", "1", s, out_filepath]) + s = out_filepath + rate, data = wavfile.read(s) + assert rate == self._sample_rate + assert len(data.shape) == 1 + if data.dtype not in [np.float32, np.float64]: + data = data.astype(np.float32) / np.iinfo(data.dtype).max + return data.tolist() + + def decode(self, ids): + """Transform a sequence of float32 into a waveform. + + Args: + ids: list of floats to be converted. + + Returns: + Path to the temporary file where the waveform was saved. + + Raises: + ValueError: if the ids are not of the appropriate size. + """ + _, tmp_file_path = tempfile.mkstemp() + wavfile.write(tmp_file_path, self._sample_rate, np.asarray(ids)) + return tmp_file_path + + def decode_list(self, ids): + """Transform a sequence of float32 into a waveform file. + + Args: + ids: list of floats to be converted. + + Returns: + Singleton list: path to the temporary file where the wavfile was saved.
+ """ + return [self.decode(ids)] + + @property + def vocab_size(self): + return 256 + + +class SpeechRecognitionProblem(problem.Problem): + """Base class for speech recognition problems.""" + + def hparams(self, defaults, model_hparams): + p = model_hparams + # Filterbank extraction + p.add_hparam("audio_sample_rate", 16000) + p.add_hparam("audio_preemphasis", 0.97) + p.add_hparam("audio_dither", 1.0 / np.iinfo(np.int16).max) + p.add_hparam("audio_frame_length", 25.0) + p.add_hparam("audio_frame_step", 10.0) + p.add_hparam("audio_lower_edge_hertz", 20.0) + p.add_hparam("audio_upper_edge_hertz", 8000.0) + p.add_hparam("audio_num_mel_bins", 80) + p.add_hparam("audio_add_delta_deltas", True) + + p = defaults + # p.stop_at_eos = int(False) + p.input_modality = {"inputs": ("audio:speech_recognition_modality", None)} + p.target_modality = (registry.Modalities.SYMBOL, 256) + + @property + def is_character_level(self): + return True + + @property + def input_space_id(self): + return problem.SpaceID.AUDIO_SPECTRAL + + @property + def target_space_id(self): + return problem.SpaceID.EN_CHR + + def feature_encoders(self, _): + return { + "waveforms": AudioEncoder(), + "targets": text_encoder.ByteTextEncoder(), + } + + def example_reading_spec(self): + data_fields = { + "waveforms": tf.VarLenFeature(tf.float32), + "targets": tf.VarLenFeature(tf.int64), + } + + data_items_to_decoders = None + + return data_fields, data_items_to_decoders + + def preprocess_example(self, example, mode, hparams): + p = hparams + waveforms = tf.expand_dims(example["waveforms"], 0) + mel_fbanks = compute_mel_filterbank_features( + waveforms, + sample_rate=p.audio_sample_rate, + dither=p.audio_dither, + preemphasis=p.audio_preemphasis, + frame_length=p.audio_frame_length, + frame_step=p.audio_frame_step, + lower_edge_hertz=p.audio_lower_edge_hertz, + upper_edge_hertz=p.audio_upper_edge_hertz, + num_mel_bins=p.audio_num_mel_bins) + if p.audio_add_delta_deltas: + mel_fbanks = add_delta_deltas(mel_fbanks) + fbank_size = common_layers.shape_list(mel_fbanks) + assert fbank_size[0] == 1 + # Later models like to flatten the two spatial dims. Instead, we add a unit + # spatial dim and flatten the frequencies and channels. + example["inputs"] = tf.reshape( + mel_fbanks, [fbank_size[1], 1, fbank_size[2] * fbank_size[3]]) + return super(SpeechRecognitionProblem, self + ).preprocess_example(example, mode, hparams) + + +@registry.register_audio_modality +class SpeechRecognitionModality(modality.Modality): + """Common ASR filterbank processing.""" + + def bottom(self, inputs): + """Use batchnorm instead of CMVN and shorten the stft with strided convs. 
+ + Args: + inputs: float32 tensor with shape [batch_size, len, 1, freqs * channels] + + Returns: + float32 tensor with shape [batch_size, shorter_len, 1, hidden_size] + """ + p = self._model_hparams + training = p.mode == tf.estimator.ModeKeys.TRAIN + + with tf.variable_scope(self.name): + x = inputs + num_mel_bins = p.audio_num_mel_bins + num_channels = 3 if p.audio_add_delta_deltas else 1 + # The convention is that the models are flattened along the spatial + # dimensions; thus the speech preprocessor treats frequencies and channels + # as image colors (last axis) + x.set_shape([None, None, 1, num_mel_bins * num_channels]) + + # This replaces CMVN estimation on data + x = tf.layers.batch_normalization( + x, axis=3, center=False, scale=False, training=training) + + xshape = common_layers.shape_list(x) + # restore batch_size x time x frequency x channel layout + x = tf.reshape(x, [xshape[0], xshape[1], num_mel_bins, num_channels]) + + # TODO(chorowski): how to specify bottom's hparams and avoid hardcoding? + for _ in range(2): + x = tf.layers.conv2d( + x, 128, (3, 3), (2, 2), use_bias=False) + x = tf.layers.batch_normalization(x, axis=3, training=training) + x = tf.nn.relu(x) + + xshape = common_layers.shape_list(x) + # apply a conv that will remove all frequencies and at the same time + # project the output into desired hidden_size + x = tf.layers.conv2d(x, p.hidden_size, (3, xshape[2]), use_bias=False) + assert common_layers.shape_list(x)[2] == 1 + x = tf.layers.batch_normalization(x, axis=3, training=training) + x = tf.nn.relu(x) + return x diff --git a/tensor2tensor/models/transformer.py b/tensor2tensor/models/transformer.py index e9c272d7c..de812b64b 100644 --- a/tensor2tensor/models/transformer.py +++ b/tensor2tensor/models/transformer.py @@ -53,7 +53,8 @@ def encode(self, inputs, target_space, hparams, features=None): """Encode transformer inputs. Args: - inputs: Transformer inputs [batch_size, input_length, hidden_dim] + inputs: Transformer inputs [batch_size, input_length, input_height, + hidden_dim] which will be flattened along the two spatial dimensions. target_space: scalar, target space ID. hparams: hyperparameters for model. features: optionally pass the entire features dictionary as well.
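As a quick sanity check on the new front end, here is a minimal usage sketch of compute_mel_filterbank_features and add_delta_deltas from the speech_recognition module above, outside of the Problem/Modality plumbing. It assumes TF 1.x with tf.contrib.signal available and feeds a hypothetical batch of two one-second 16 kHz clips; with the default 10 ms frame step and pad_end=True, one second of audio yields 100 frames.

    import numpy as np
    import tensorflow as tf

    from tensor2tensor.data_generators import speech_recognition

    # Hypothetical batch: 2 mono clips, 1 second each at 16 kHz.
    waveforms = tf.constant(np.random.randn(2, 16000).astype(np.float32))

    # Log-mel filterbanks: [2, num_frames, 80, 1].
    fbanks = speech_recognition.compute_mel_filterbank_features(
        waveforms, sample_rate=16000, num_mel_bins=80)

    # Stack first- and second-order time derivatives as extra channels:
    # [2, num_frames, 80, 3]. preprocess_example then flattens this to
    # [num_frames, 1, 80 * 3], which the modality's bottom un-flattens.
    fbanks = speech_recognition.add_delta_deltas(fbanks)

    with tf.Session() as sess:
      print(sess.run(tf.shape(fbanks)))  # [2, 100, 80, 3]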
diff --git a/tensor2tensor/notebooks/hello_t2t.ipynb b/tensor2tensor/notebooks/hello_t2t.ipynb index 1ff6b1d2b..5b58b042b 100644 --- a/tensor2tensor/notebooks/hello_t2t.ipynb +++ b/tensor2tensor/notebooks/hello_t2t.ipynb @@ -85,6 +85,7 @@ "import os\n", "import collections\n", "\n", + "from tensor2tensor import models\n", "from tensor2tensor import problems\n", "from tensor2tensor.layers import common_layers\n", "from tensor2tensor.tpu import tpu_trainer_lib\n", @@ -1540,55 +1541,6 @@ } ] }, - { - "metadata": { - "id": "a2cL8UwLaSYG", - "colab_type": "code", - "colab": { - "autoexec": { - "startup": false, - "wait_interval": 0 - } - } - }, - "source": [ - "# This will eventually be available at\n", - "# tensor2tensor.metrics.create_eager_metrics\n", - "def create_eager_metrics(metric_names):\n", - " \"\"\"Create metrics accumulators and averager for Eager mode.\n", - "\n", - " Args:\n", - " metric_names: list from tensor2tensor.metrics.Metrics\n", - "\n", - " Returns:\n", - " (accum_fn(predictions, targets) => None,\n", - " result_fn() => dict\n", - " \"\"\"\n", - " metric_fns = dict(\n", - " [(name, metrics.METRICS_FNS[name]) for name in metric_names])\n", - " tfe_metrics = dict()\n", - "\n", - " for name in metric_names:\n", - " tfe_metrics[name] = tfe.metrics.Mean(name=name)\n", - "\n", - " def metric_accum(predictions, targets):\n", - " for name, metric_fn in metric_fns.items():\n", - " val, weight = metric_fn(predictions, targets,\n", - " weights_fn=common_layers.weights_all)\n", - " tfe_metrics[name](np.squeeze(val), np.squeeze(weight))\n", - "\n", - " def metric_means():\n", - " avgs = {}\n", - " for name in metric_names:\n", - " avgs[name] = tfe_metrics[name].result().numpy()\n", - " return avgs\n", - "\n", - " return metric_accum, metric_means" - ], - "cell_type": "code", - "execution_count": 0, - "outputs": [] - }, { "metadata": { "id": "CIFlkiVOd8jO", @@ -1625,7 +1577,7 @@ "\n", "# Create eval metric accumulators for accuracy (ACC) and accuracy in\n", "# top 5 (ACC_TOP5)\n", - "metrics_accum, metrics_result = create_eager_metrics(\n", + "metrics_accum, metrics_result = metrics.create_eager_metrics(\n", " [metrics.Metrics.ACC, metrics.Metrics.ACC_TOP5])\n", "\n", "for count, example in enumerate(tfe.Iterator(mnist_eval_dataset)):\n", diff --git a/tensor2tensor/tpu/tpu_trainer.py b/tensor2tensor/tpu/tpu_trainer.py index d17ff85ea..571a21839 100644 --- a/tensor2tensor/tpu/tpu_trainer.py +++ b/tensor2tensor/tpu/tpu_trainer.py @@ -44,6 +44,7 @@ "The imported files should contain registrations, " "e.g. @registry.register_model calls, that will then be " "available to the t2t-trainer.") +flags.DEFINE_integer("random_seed", 1234, "Random seed.") flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.") flags.DEFINE_integer("iterations_per_loop", 1000, "Number of iterations in a TPU training loop.") @@ -60,7 +61,11 @@ flags.DEFINE_string("output_dir", "", "Base output directory for run.") flags.DEFINE_string("schedule", "continuous_train_and_eval", "Method of Experiment to run.") - flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.") + flags.DEFINE_integer("eval_steps", 10000, + "Number of steps in evaluation. 
By default, eval will " + "stop after eval_steps or when it runs through the eval " + "dataset once in full, whichever comes first, so this " + "can be a very large number.") except: # pylint: disable=bare-except pass @@ -76,9 +81,6 @@ def create_hparams(): def create_experiment_fn(): - use_validation_monitor = (FLAGS.schedule in - ["train_and_evaluate", "continuous_train_and_eval"] - and FLAGS.local_eval_frequency) return tpu_trainer_lib.create_experiment_fn( model_name=FLAGS.model, problem_name=get_problem_name(), @@ -91,9 +93,9 @@ def create_experiment_fn(): decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams), use_tfdbg=FLAGS.tfdbg, use_dbgprofile=FLAGS.dbgprofile, - use_validation_monitor=use_validation_monitor, eval_early_stopping_steps=FLAGS.eval_early_stopping_steps, eval_early_stopping_metric=FLAGS.eval_early_stopping_metric, + eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta, eval_early_stopping_metric_minimize=FLAGS. eval_early_stopping_metric_minimize, use_tpu=FLAGS.use_tpu) @@ -169,7 +171,7 @@ def execute_schedule(exp): def main(_): tf.logging.set_verbosity(tf.logging.INFO) - tf.set_random_seed(123) + tpu_trainer_lib.set_random_seed(FLAGS.random_seed) usr_dir.import_usr_dir(FLAGS.t2t_usr_dir) log_registry() diff --git a/tensor2tensor/tpu/tpu_trainer_lib.py b/tensor2tensor/tpu/tpu_trainer_lib.py index 475d0f1be..bde85e4db 100644 --- a/tensor2tensor/tpu/tpu_trainer_lib.py +++ b/tensor2tensor/tpu/tpu_trainer_lib.py @@ -19,10 +19,16 @@ from __future__ import division from __future__ import print_function +import os +import random + # Dependency imports +import numpy as np + from tensor2tensor.utils import devices from tensor2tensor.utils import expert_utils +from tensor2tensor.utils import metrics_hook from tensor2tensor.utils import registry from tensor2tensor.utils import t2t_model @@ -186,7 +192,8 @@ def create_estimator(model_name, def create_hooks(use_tfdbg=False, use_dbgprofile=False, dbgprofile_kwargs=None, - use_validation_monitor=False, validation_monitor_kwargs=None): + use_validation_monitor=False, validation_monitor_kwargs=None, + use_early_stopping=False, early_stopping_kwargs=None): """Create train and eval hooks for Experiment.""" train_monitors = [] eval_hooks = [] @@ -208,6 +215,12 @@ def create_hooks(use_tfdbg=False, use_dbgprofile=False, dbgprofile_kwargs=None, tf.contrib.learn.monitors.ValidationMonitor( hooks=eval_hooks, **validation_monitor_kwargs)) + if use_early_stopping: + hook = metrics_hook.EarlyStoppingHook(**early_stopping_kwargs) + # Adding to both training and eval so that eval aborts as well + train_monitors.append(hook) + eval_hooks.append(hook) + return train_monitors, eval_hooks @@ -224,9 +237,9 @@ def create_experiment(run_config, decode_hparams=None, use_tfdbg=False, use_dbgprofile=False, - use_validation_monitor=False, eval_early_stopping_steps=None, eval_early_stopping_metric=None, + eval_early_stopping_metric_delta=None, eval_early_stopping_metric_minimize=True, use_tpu=False): """Create Experiment.""" @@ -264,12 +277,29 @@ def create_experiment(run_config, early_stopping_rounds=eval_early_stopping_steps, early_stopping_metric=eval_early_stopping_metric, early_stopping_metric_minimize=eval_early_stopping_metric_minimize) + early_stopping_kwargs = dict( + events_dir=os.path.join(run_config.model_dir, "eval_continuous"), + tag=eval_early_stopping_metric, + num_plateau_steps=eval_early_stopping_steps, + plateau_decrease=eval_early_stopping_metric_minimize, + plateau_delta=eval_early_stopping_metric_delta, + 
every_n_steps=min_eval_frequency) + + # In-process eval (and possible early stopping) + local_schedules = ["train_and_evaluate", "continuous_train_and_eval"] + use_validation_monitor = ( + schedule in local_schedules and min_eval_frequency) + # Distributed early stopping + use_early_stopping = ( + schedule not in local_schedules and eval_early_stopping_steps) train_monitors, eval_hooks = create_hooks( use_tfdbg=use_tfdbg, use_dbgprofile=use_dbgprofile, dbgprofile_kwargs=dbgprofile_kwargs, use_validation_monitor=use_validation_monitor, - validation_monitor_kwargs=validation_monitor_kwargs) + use_early_stopping=use_early_stopping, + validation_monitor_kwargs=validation_monitor_kwargs, + early_stopping_kwargs=early_stopping_kwargs) hooks_kwargs = {"train_monitors": train_monitors, "eval_hooks": eval_hooks} # Experiment @@ -309,3 +339,9 @@ def add_problem_hparams(hparams, problems): hparams.problem_instances.append(problem) hparams.problems.append(p_hparams) + + +def set_random_seed(seed): + tf.set_random_seed(seed) + random.seed(seed) + np.random.seed(seed) diff --git a/tensor2tensor/tpu/tpu_trainer_lib_test.py b/tensor2tensor/tpu/tpu_trainer_lib_test.py index e8c1689c7..2a2148afd 100644 --- a/tensor2tensor/tpu/tpu_trainer_lib_test.py +++ b/tensor2tensor/tpu/tpu_trainer_lib_test.py @@ -68,7 +68,8 @@ def testExperiment(self): eval_steps=1, min_eval_frequency=1, use_tpu=False) - run_config = tpu_trainer_lib.create_run_config(num_gpus=0, use_tpu=False) + run_config = tpu_trainer_lib.create_run_config( + model_dir=self.data_dir, num_gpus=0, use_tpu=False) hparams = registry.hparams("transformer_tiny_tpu")() exp = exp_fn(run_config, hparams) exp.test() diff --git a/tensor2tensor/utils/flags.py b/tensor2tensor/utils/flags.py index f4e93a68f..410dccfe1 100644 --- a/tensor2tensor/utils/flags.py +++ b/tensor2tensor/utils/flags.py @@ -55,14 +55,14 @@ flags.DEFINE_integer("train_steps", 250000, "The number of steps to run training for.") flags.DEFINE_string("eval_early_stopping_metric", "loss", - "If --schedule=train_and_evaluate and " - "--eval_early_stopping_steps is not None, then stop when " - "--eval_early_stopping_metric has not decreased for " + "If --eval_early_stopping_steps is not None, then stop " + "when --eval_early_stopping_metric has not decreased for " "--eval_early_stopping_steps") +flags.DEFINE_float("eval_early_stopping_metric_delta", 0.1, + "Delta determining whether metric has plateaued.") flags.DEFINE_integer("eval_early_stopping_steps", None, - "If --schedule=train_and_evaluate and " - "--eval_early_stopping_steps is not None, then stop when " - "--eval_early_stopping_metric has not decreased for " + "If --eval_early_stopping_steps is not None, then stop " + "when --eval_early_stopping_metric has not decreased for " "--eval_early_stopping_steps") flags.DEFINE_bool("eval_early_stopping_metric_minimize", True, "Whether to check for the early stopping metric going down " diff --git a/tensor2tensor/utils/metrics_hook.py b/tensor2tensor/utils/metrics_hook.py new file mode 100644 index 000000000..964139a42 --- /dev/null +++ b/tensor2tensor/utils/metrics_hook.py @@ -0,0 +1,291 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Summary-based SessionRunHooks.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +# Dependency imports + +import tensorflow as tf + +from tensorboard.backend.event_processing import event_accumulator +from tensorboard.backend.event_processing import event_multiplexer + + +class MetricsBasedHook(tf.train.SessionRunHook): + """Base class for hooks based on summary metrics. + + Subclasses should override _process_metrics. + + If _process_metrics returns True, calls run_context.request_stop(). + + This can be used to implement something like "Stop after the loss has + stopped decreasing for 5000 steps." + """ + _RUN_NAME = "run%d" + + def __init__(self, events_dir, subdirs=None, tags=None, every_n_steps=1000): + """Construct MetricsBasedHook. + + Args: + events_dir: str, top-level directory containing events files. + subdirs: list, subdirectories of events_dir that also contain + events files. Use "" to specify the top-level directory. Defaults to + [""]. + tags: list, names of metrics to collect. Default will collect all + metrics. + every_n_steps: int, collect metrics every n steps. + """ + self._events_dir = events_dir + self._subdirs = subdirs or [""] + self._tags = tags + self._every_n_steps = every_n_steps + self._start_step = None + self._event_multiplexer = self._init_multiplexer() + + def _init_multiplexer(self): + dirs = [os.path.join(self._events_dir, subdir) for subdir in self._subdirs] + run_path_map = dict([(self._RUN_NAME % i, d) for i, d in enumerate(dirs)]) + return event_multiplexer.EventMultiplexer(run_path_map) + + def begin(self): + self._global_step_tensor = tf.train.get_global_step() + if self._global_step_tensor is None: + raise RuntimeError("Global step must be created to use MetricsBasedHook.") + + def after_create_session(self, session, coord): + del coord + if self._start_step is None: + self._start_step = session.run(self._global_step_tensor) + + def before_run(self, run_context): + del run_context + return tf.train.SessionRunArgs([self._global_step_tensor]) + + def after_run(self, run_context, run_values): + global_step = run_values.results[0] + if (global_step - self._start_step) % self._every_n_steps != 0: + return + metrics = self._collect_metrics() + self._after_run(run_context, run_values, global_step, metrics) + + def _after_run(self, run_context, run_values, global_step, metrics): + if self._process_metrics(global_step, metrics): + run_context.request_stop() + + def _collect_metrics(self): + self._event_multiplexer.Reload() + subdir_data = {} + for i, subdir in enumerate(self._subdirs): + subdir_metrics = {} + + accum = self._event_multiplexer.GetAccumulator(self._RUN_NAME % i) + for tag in accum.Tags()[event_accumulator.SCALARS]: + steps, vals = zip(*[ + (event.step, event.value) for event in accum.Scalars(tag)]) + subdir_metrics[tag] = (steps, vals) + + subdir_data[subdir] = subdir_metrics + return subdir_data + + def _process_metrics(self, global_step, metrics): + """Process the collected metrics. + + Args: + global_step: int, the current global step value.
+ metrics: dict. The collected + metrics. subdir_metrics is a dict from tag name to tuple of lists. The + lists are a list of global steps and a list of values. + i.e. subdir_metrics: + `dict<str subdir, dict<str tag, tuple<list<int> global steps, + list<float> values>>>` + + Returns: + should_stop: bool. If True, will request that the session stops. + """ + return False + + +class EarlyStoppingHook(MetricsBasedHook): + """EarlyStoppingHook will stop training when a given metric has plateaued.""" + + def __init__(self, + events_dir, + tag, + num_plateau_steps=1000, + plateau_delta=0.1, + plateau_decrease=True, + every_n_steps=1000): + """Create an EarlyStoppingHook. + + This hook will stop training when the metric identified by tag has + plateaued. Plateaued is defined by the metric having stopped + increasing/decreasing (based on plateau_decrease) by plateau_delta for + num_plateau_steps. + + Args: + events_dir: Directory with events files. + tag: Name of metric in TensorBoard. + num_plateau_steps: Number of steps over which to check the plateau. + plateau_delta: delta to define a "plateau". + plateau_decrease: whether to check decrease or increase in the metric. + every_n_steps: how often to run this hook. + + Returns: + An instance of EarlyStoppingHook. + """ + super(EarlyStoppingHook, self).__init__( + events_dir=events_dir, tags=[tag], every_n_steps=every_n_steps) + self._num_plateau_steps = num_plateau_steps + self._plateau_delta = plateau_delta + self._plateau_decrease = plateau_decrease + + def _process_metrics(self, global_step, metrics): + if not metrics: + return + + if not list(metrics.values())[0]: + return + + # Metrics should have just a single subdir and a single tag + steps, vals = list(metrics.values())[0][self._tags[0]] + return has_metric_plateaued( + steps, + vals, + num_steps=self._num_plateau_steps, + delta=self._plateau_delta, + decrease=self._plateau_decrease) + + +class PlateauOpHook(MetricsBasedHook): + """Runs an op when a metric has plateaued.""" + + def __init__(self, + events_dir, + tag, + plateau_op, + num_plateau_steps=1000, + plateau_delta=0.1, + plateau_decrease=True, + every_n_steps=1000, + only_once=False): + """See EarlyStoppingHook for args.
Runs plateau_op if plateaued.""" + super(PlateauOpHook, self).__init__( + events_dir=events_dir, tags=[tag], every_n_steps=every_n_steps) + self._num_plateau_steps = num_plateau_steps + self._plateau_delta = plateau_delta + self._plateau_decrease = plateau_decrease + self._plateau_op = plateau_op + self._only_once = only_once + self._should_run_op = False + self._ever_ran = False + self._last_metric_step_seen = 0 + + @property + def keep_alive(self): + if self._only_once and self._ever_ran: + return False + return True + + def before_run(self, run_context): + del run_context + + fetches = [self._global_step_tensor] + if self._should_run_op and self.keep_alive: + fetches.append(self._plateau_op) + self._should_run_op = False + self._ever_ran = True + + return tf.train.SessionRunArgs(fetches) + + def _after_run(self, run_context, run_values, global_step, metrics): + del run_context + del run_values + del global_step + + if not self.keep_alive: + return + + if not metrics: + return + + if not list(metrics.values())[0]: + return + + # There should be only a single subdir and a single tag + steps, vals = list(metrics.values())[0][self._tags[0]] + + if not steps: + return + + last_step = steps[-1] + if last_step == self._last_metric_step_seen: + return + self._last_metric_step_seen = last_step + + if has_metric_plateaued( + steps, + vals, + num_steps=self._num_plateau_steps, + delta=self._plateau_delta, + decrease=self._plateau_decrease): + self._should_run_op = True + + +def has_metric_plateaued(steps, values, num_steps=100, delta=0.1, + decrease=True): + """Check if metric has plateaued. + + A metric has plateaued if the value has not increased/decreased (depending on + `decrease`) by `delta` for at least `num_steps`. + + Args: + steps: list<int> list of global steps for values. + values: list<float> list of metric values. + num_steps: int, number of steps the metric has to have been plateaued for. + delta: float, how much the metric should have changed by over num_steps. + decrease: bool, whether to check if the metric has decreased by delta or + increased by delta. + + Returns: + bool, whether the metric has plateaued. + """ + assert num_steps > 0 + if len(steps) < 2: + return False + + steps_at_least_num_steps_ago = [ + s for s in steps if s <= (steps[-1] - num_steps) + ] + if not steps_at_least_num_steps_ago: + # Not enough steps yet + return False + delta_step_idx = len(steps_at_least_num_steps_ago) - 1 + + start_val = values[delta_step_idx] + values_to_check = values[delta_step_idx:] + observed_deltas = [] + for val in values_to_check: + if decrease: + observed_delta = start_val - val + else: + observed_delta = val - start_val + observed_deltas.append(observed_delta) + + within_range = [obs < delta for obs in observed_deltas] + return all(within_range) diff --git a/tensor2tensor/utils/metrics_hook_test.py b/tensor2tensor/utils/metrics_hook_test.py new file mode 100644 index 000000000..67c78eb2d --- /dev/null +++ b/tensor2tensor/utils/metrics_hook_test.py @@ -0,0 +1,198 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for metrics_hook.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import contextlib +import os +import shutil + +# Dependency imports + +from tensor2tensor.utils import metrics_hook + +import tensorflow as tf + + +class DummyHook(metrics_hook.MetricsBasedHook): + + def _process_metrics(self, global_step, metrics): + if metrics: + assert "" in metrics + assert isinstance(metrics[""], dict) + if metrics[""]: + assert "global_step_1" in metrics[""] + self.test_metrics = metrics + if global_step >= 40: + return True + + +class MetricsHookTest(tf.test.TestCase): + + @classmethod + def setUpClass(cls): + cls.base_checkpoint_dir = tf.test.get_temp_dir() + shutil.rmtree(cls.base_checkpoint_dir, ignore_errors=True) + + def ckpt_dir(self, name): + return os.path.join(self.base_checkpoint_dir, name) + + @contextlib.contextmanager + def sess(self, hook, ckpt_dir): + with tf.train.MonitoredTrainingSession( + checkpoint_dir=ckpt_dir, + save_checkpoint_secs=0, + save_summaries_steps=10, + hooks=[hook]) as sess: + self._sess = sess + yield sess + + def flush(self): + self._sess._hooks[1]._summary_writer.flush() + + def testStop(self): + global_step = tf.train.create_global_step() + tf.summary.scalar("global_step", global_step) + incr_global_step = tf.assign_add(global_step, 1) + + ckpt_dir = self.ckpt_dir("stop") + dummy = DummyHook(ckpt_dir, every_n_steps=10) + with self.sess(dummy, ckpt_dir) as sess: + for _ in range(20): + sess.run(incr_global_step) + + # Summary files should now have 2 global step values in them + self.flush() + + # Run for 10 more so that the hook gets triggered again + for _ in range(10): + sess.run(incr_global_step) + + # Check that the metrics have actually been collected. + self.assertTrue("" in dummy.test_metrics) + metrics = dummy.test_metrics[""] + self.assertTrue("global_step_1" in metrics) + steps, vals = metrics["global_step_1"] + self.assertTrue(len(steps) == len(vals)) + self.assertTrue(len(steps) >= 2) + + # Run for 10 more so that the hook triggers stoppage + for _ in range(10): + sess.run(incr_global_step) + + with self.assertRaisesRegexp(RuntimeError, "after should_stop requested"): + sess.run(incr_global_step) + + def testEarlyStoppingHook(self): + global_step = tf.train.create_global_step() + counter = tf.get_variable("count", initializer=0, dtype=tf.int32) + tf.summary.scalar("count", counter) + incr_global_step = tf.assign_add(global_step, 1) + incr_counter = tf.assign_add(counter, 1) + + # Stop if the count has not gone up by more than 1 in 20 steps. + + ckpt_dir = self.ckpt_dir("early") + stop_hook = metrics_hook.EarlyStoppingHook( + ckpt_dir, + "count_1", + num_plateau_steps=20, + plateau_delta=1., + plateau_decrease=False, + every_n_steps=10) + with self.sess(stop_hook, ckpt_dir) as sess: + for _ in range(20): + sess.run((incr_global_step, incr_counter)) + + # Summary files should now have 2 values in them + self.flush() + + # Run for more steps so that the hook gets triggered and we verify that we + # don't stop. + for _ in range(30): + sess.run((incr_global_step, incr_counter)) + + self.flush() + + # Run without incrementing the counter + for _ in range(40): + sess.run(incr_global_step) + + # Metrics should be written such that now the counter has gone >20 steps + # without being incremented.
+ self.flush() + + # Check that we ask for stop + with self.assertRaisesRegexp(RuntimeError, "after should_stop requested"): + for _ in range(30): + sess.run(incr_global_step) + + def testPlateauOpHook(self): + global_step = tf.train.create_global_step() + counter = tf.get_variable("count", initializer=0, dtype=tf.int32) + indicator = tf.get_variable("indicator", initializer=0, dtype=tf.int32) + tf.summary.scalar("count", counter) + incr_global_step = tf.assign_add(global_step, 1) + incr_counter = tf.assign_add(counter, 1) + incr_indicator = tf.assign_add(indicator, 1) + + # Run plateau_op if the count has not gone up by more than 1 in 20 steps. + + ckpt_dir = self.ckpt_dir("plateauop") + stop_hook = metrics_hook.PlateauOpHook( + ckpt_dir, + "count_1", + incr_indicator, + num_plateau_steps=20, + plateau_delta=1., + plateau_decrease=False, + every_n_steps=10) + with self.sess(stop_hook, ckpt_dir) as sess: + for _ in range(20): + sess.run((incr_global_step, incr_counter)) + + # Summary files should now have 2 values in them + self.flush() + + # Run for more steps so that the hook gets triggered and we verify that we + # don't stop. + for _ in range(30): + sess.run((incr_global_step, incr_counter)) + + self.flush() + + # Run without incrementing the counter + for _ in range(30): + sess.run(incr_global_step) + self.flush() + + self.assertTrue(sess.run(indicator) < 1) + + # Metrics should be written such that now the counter has gone >20 steps + # without being incremented. + # Check that we run the incr_indicator op several times + for _ in range(3): + for _ in range(10): + sess.run(incr_global_step) + self.flush() + + self.assertTrue(sess.run(indicator) > 1) + +if __name__ == "__main__": + tf.test.main() diff --git a/tensor2tensor/utils/t2t_model.py b/tensor2tensor/utils/t2t_model.py index 26854de13..630011541 100644 --- a/tensor2tensor/utils/t2t_model.py +++ b/tensor2tensor/utils/t2t_model.py @@ -139,13 +139,15 @@ def model_fn_sharded(self, sharded_features): body_out = self.body_sharded( self._to_single_features_dict(transformed_features)) body_out, losses = self._normalize_body_output(body_out) - sharded_logits = dp(self.top, body_out, datashard_to_features) if "training" not in losses: + sharded_logits = dp(self.top, body_out, datashard_to_features) sharded_losses = dp(self.loss, sharded_logits, datashard_to_features) training_loss_dict = average_sharded_losses([{ "training": loss } for loss in sharded_losses]) losses.update(training_loss_dict) + else: + sharded_logits = body_out else: sharded_logits, sharded_losses = dp(self.model_fn, datashard_to_features) losses = average_sharded_losses(sharded_losses) @@ -172,9 +174,11 @@ def model_fn(self, features): body_out = self.body(transformed_features) output, losses = self._normalize_body_output(body_out) - logits = self.top(output, features) if "training" not in losses: + logits = self.top(output, features) losses["training"] = self.loss(logits, features) + else: + logits = output return logits, losses def bottom(self, features):
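To make the plateau criterion used by EarlyStoppingHook and PlateauOpHook concrete, here is a small worked example against has_metric_plateaued as defined above; the step/value numbers are made up.

    from tensor2tensor.utils import metrics_hook

    # A loss sampled every 100 steps: it falls quickly, then flattens out.
    steps = [0, 100, 200, 300, 400]
    values = [1.0, 0.95, 0.93, 0.92, 0.91]

    # Looking back num_steps=200 from step 400, the comparison starts at
    # step 200 (value 0.93). The largest decrease since then is
    # 0.93 - 0.91 = 0.02, which is below delta=0.1, so the metric counts
    # as plateaued and EarlyStoppingHook would request a stop.
    print(metrics_hook.has_metric_plateaued(
        steps, values, num_steps=200, delta=0.1, decrease=True))  # True

    # Over the full 400-step window the loss did drop by more than
    # delta=0.04 (1.0 - 0.91 = 0.09), so it is not considered plateaued.
    print(metrics_hook.has_metric_plateaued(
        steps, values, num_steps=400, delta=0.04, decrease=True))  # False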