This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

v1.4.1 #484

Merged: 8 commits, Dec 22, 2017.

Changes from all commits:
2 changes: 1 addition & 1 deletion .travis.yml

@@ -14,7 +14,7 @@ env:
   - T2T_DATA_DIR=/tmp/t2t-data
   - T2T_TRAIN_DIR=/tmp/t2t-train
 script:
-  - pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/tpu/tpu_trainer_lib_test.py
+  - pytest --ignore=tensor2tensor/utils/registry_test.py --ignore=tensor2tensor/problems_test.py --ignore=tensor2tensor/tpu/tpu_trainer_lib_test.py --ignore=tensor2tensor/data_generators/algorithmic_math_test.py
   - pytest tensor2tensor/utils/registry_test.py
   - pytest tensor2tensor/tpu/tpu_trainer_lib_test.py
   - t2t-datagen 2>&1 | grep translate && echo passed
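The newly ignored algorithmic_math_test.py corresponds to the TODO added to that file further down in this diff: the test is flaky, so it is skipped on Travis. It can still be run on its own with pytest tensor2tensor/data_generators/algorithmic_math_test.py.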
3 changes: 2 additions & 1 deletion setup.py

@@ -5,7 +5,7 @@

 setup(
     name='tensor2tensor',
-    version='1.4.0',
+    version='1.4.1',
     description='Tensor2Tensor',
     author='Google Inc.',
     author_email='[email protected]',
@@ -30,6 +30,7 @@
         'gym',
         'numpy',
         'requests',
+        'scipy',
         'sympy',
         'six',
     ],
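With scipy added to install_requires, a pip install of tensor2tensor 1.4.1 now pulls in scipy automatically, alongside the existing numpy and sympy requirements.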
14 changes: 8 additions & 6 deletions tensor2tensor/bin/t2t-trainer

@@ -45,6 +45,7 @@ flags.DEFINE_string("t2t_usr_dir", "",
                     "The imported files should contain registrations, "
                     "e.g. @registry.register_model calls, that will then be "
                     "available to the t2t-trainer.")
+flags.DEFINE_integer("random_seed", 1234, "Random seed.")
 flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.")
 flags.DEFINE_integer("iterations_per_loop", 1000,
                      "Number of iterations in a TPU training loop.")
@@ -61,7 +62,11 @@ try:
   flags.DEFINE_string("output_dir", "", "Base output directory for run.")
   flags.DEFINE_string("schedule", "continuous_train_and_eval",
                       "Method of Experiment to run.")
-  flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.")
+  flags.DEFINE_integer("eval_steps", 10000,
+                       "Number of steps in evaluation. By default, eval will "
+                       "stop after eval_steps or when it runs through the eval "
+                       "dataset once in full, whichever comes first, so this "
+                       "can be a very large number.")
 except:  # pylint: disable=bare-except
   pass

@@ -77,9 +82,6 @@ def create_hparams():


 def create_experiment_fn():
-  use_validation_monitor = (FLAGS.schedule in
-                            ["train_and_evaluate", "continuous_train_and_eval"]
-                            and FLAGS.local_eval_frequency)
   return tpu_trainer_lib.create_experiment_fn(
       model_name=FLAGS.model,
       problem_name=get_problem_name(),
@@ -92,9 +94,9 @@ def create_experiment_fn():
       decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
       use_tfdbg=FLAGS.tfdbg,
       use_dbgprofile=FLAGS.dbgprofile,
-      use_validation_monitor=use_validation_monitor,
       eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
       eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
+      eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
       eval_early_stopping_metric_minimize=FLAGS.
       eval_early_stopping_metric_minimize,
       use_tpu=FLAGS.use_tpu)
@@ -170,7 +172,7 @@ def execute_schedule(exp):

 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
-  tf.set_random_seed(123)
+  tpu_trainer_lib.set_random_seed(FLAGS.random_seed)
   usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
   log_registry()
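The hard-coded tf.set_random_seed(123) in main is replaced by a tpu_trainer_lib.set_random_seed helper driven by the new --random_seed flag. A minimal sketch of what such a helper typically does, assuming it follows the common pattern of seeding all three generators that training code touches (the actual implementation lives in tensor2tensor/tpu/tpu_trainer_lib.py and may differ):

import random

import numpy as np
import tensorflow as tf


def set_random_seed(seed):
  # Seed the TensorFlow graph-level, Python stdlib, and NumPy generators so
  # a run is reproducible for a fixed --random_seed value. Sketch only; the
  # real helper in tpu_trainer_lib may do more or less than this.
  tf.set_random_seed(seed)
  random.seed(seed)
  np.random.seed(seed)

Seeding all three matters because data generators often mix NumPy and stdlib randomness with TensorFlow ops.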
14 changes: 8 additions & 6 deletions tensor2tensor/bin/t2t_trainer.py

@@ -44,6 +44,7 @@
                     "The imported files should contain registrations, "
                     "e.g. @registry.register_model calls, that will then be "
                     "available to the t2t-trainer.")
+flags.DEFINE_integer("random_seed", 1234, "Random seed.")
 flags.DEFINE_integer("tpu_num_shards", 8, "Number of tpu shards.")
 flags.DEFINE_integer("iterations_per_loop", 1000,
                      "Number of iterations in a TPU training loop.")
@@ -60,7 +61,11 @@
   flags.DEFINE_string("output_dir", "", "Base output directory for run.")
   flags.DEFINE_string("schedule", "continuous_train_and_eval",
                       "Method of Experiment to run.")
-  flags.DEFINE_integer("eval_steps", 200, "Number of steps in evaluation.")
+  flags.DEFINE_integer("eval_steps", 10000,
+                       "Number of steps in evaluation. By default, eval will "
+                       "stop after eval_steps or when it runs through the eval "
+                       "dataset once in full, whichever comes first, so this "
+                       "can be a very large number.")
 except:  # pylint: disable=bare-except
   pass

@@ -76,9 +81,6 @@ def create_hparams():


 def create_experiment_fn():
-  use_validation_monitor = (FLAGS.schedule in
-                            ["train_and_evaluate", "continuous_train_and_eval"]
-                            and FLAGS.local_eval_frequency)
   return tpu_trainer_lib.create_experiment_fn(
       model_name=FLAGS.model,
       problem_name=get_problem_name(),
@@ -91,9 +93,9 @@ def create_experiment_fn():
       decode_hparams=decoding.decode_hparams(FLAGS.decode_hparams),
       use_tfdbg=FLAGS.tfdbg,
       use_dbgprofile=FLAGS.dbgprofile,
-      use_validation_monitor=use_validation_monitor,
       eval_early_stopping_steps=FLAGS.eval_early_stopping_steps,
       eval_early_stopping_metric=FLAGS.eval_early_stopping_metric,
+      eval_early_stopping_metric_delta=FLAGS.eval_early_stopping_metric_delta,
       eval_early_stopping_metric_minimize=FLAGS.
       eval_early_stopping_metric_minimize,
       use_tpu=FLAGS.use_tpu)
@@ -169,7 +171,7 @@ def execute_schedule(exp):

 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
-  tf.set_random_seed(123)
+  tpu_trainer_lib.set_random_seed(FLAGS.random_seed)
   usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)
   log_registry()
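t2t_trainer.py receives the same edits as the t2t-trainer binary above; the two entry points are kept in sync. One of those edits forwards eval_early_stopping_metric_delta to create_experiment_fn. As a hypothetical illustration of the stopping rule this family of flags describes (the real check lives in the experiment plumbing, and its exact windowing may differ), stopping when the metric improves by less than delta over the last N evals could look like:

def should_stop_early(metric_history, steps, delta, minimize=True):
  # metric_history holds one metric value per completed evaluation, oldest
  # first. Hypothetical helper for illustration; not the actual
  # tensor2tensor implementation.
  if len(metric_history) <= steps:
    return False
  window = metric_history[-(steps + 1):]
  baseline = window[0]
  if minimize:
    improvement = baseline - min(window[1:])
  else:
    improvement = max(window[1:]) - baseline
  return improvement < delta

For instance, with steps=3 and delta=0.01 on a loss being minimized, training stops once the best of the last three evals beats the value from three evals back by less than 0.01.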
1 change: 1 addition & 0 deletions tensor2tensor/data_generators/algorithmic_math_test.py

@@ -14,6 +14,7 @@
 # limitations under the License.

 """Tests for tensor2tensor.data_generators.algorithmic_math."""
+# TODO(rsepassi): This test is flaky. Disable, remove, or update.

 from __future__ import absolute_import
 from __future__ import division
203 changes: 34 additions & 169 deletions tensor2tensor/data_generators/librispeech.py

@@ -16,23 +16,14 @@
 """Librispeech dataset."""

 import os
-from subprocess import call
 import tarfile
-import wave

 # Dependency imports

-import numpy as np
-
 from tensor2tensor.data_generators import generator_utils
 from tensor2tensor.data_generators import problem
-from tensor2tensor.data_generators import text_encoder
-from tensor2tensor.layers import common_layers
-from tensor2tensor.utils import modality
+from tensor2tensor.data_generators import speech_recognition
 from tensor2tensor.utils import registry

-import tensorflow as tf
-

 _LIBRISPEECH_TRAIN_DATASETS = [
     [
@@ -86,130 +77,13 @@ def _collect_data(directory, input_ext, transcription_ext):
   return data_files


-def _get_audio_data(filepath):
-  # Construct a true .wav file.
-  out_filepath = filepath.strip(".flac") + ".wav"
-  # Assumes sox is installed on system. Sox converts from FLAC to WAV.
-  call(["sox", filepath, out_filepath])
-  wav_file = wave.open(open(out_filepath))
-  frame_count = wav_file.getnframes()
-  byte_array = wav_file.readframes(frame_count)
-
-  data = np.fromstring(byte_array, np.uint8).tolist()
-  return data, frame_count, wav_file.getsampwidth(), wav_file.getnchannels()
-
-
-class LibrispeechTextEncoder(text_encoder.TextEncoder):
-
-  def encode(self, s):
-    return [self._num_reserved_ids + ord(c) for c in s]
-
-  def decode(self, ids):
-    """Transform a sequence of int ids into a human-readable string.
-
-    EOS is not expected in ids.
-
-    Args:
-      ids: list of integers to be converted.
-    Returns:
-      s: human-readable string.
-    """
-    decoded_ids = []
-    for id_ in ids:
-      if 0 <= id_ < self._num_reserved_ids:
-        decoded_ids.append(text_encoder.RESERVED_TOKENS[int(id_)])
-      else:
-        decoded_ids.append(id_ - self._num_reserved_ids)
-    return "".join([chr(d) for d in decoded_ids])
-
-
-@registry.register_audio_modality
-class LibrispeechModality(modality.Modality):
-  """Performs strided conv compressions for audio spectral data."""
-
-  def bottom(self, inputs):
-    """Transform input from data space to model space.
-
-    Args:
-      inputs: A Tensor with shape [batch, ...]
-    Returns:
-      body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
-    """
-    with tf.variable_scope(self.name):
-      # TODO(aidangomez): Will need to sort out a better audio pipeline
-      def xnet_resblock(x, filters, res_relu, name):
-        with tf.variable_scope(name):
-          # We only stride along the length dimension to preserve the spectral
-          # bins (which are tiny in dimensionality relative to length)
-          y = common_layers.separable_conv_block(
-              x,
-              filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
-              first_relu=True,
-              padding="SAME",
-              force2d=True,
-              name="sep_conv_block")
-          y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 1))
-          return y + common_layers.conv_block(
-              x,
-              filters, [((1, 1), (1, 1))],
-              padding="SAME",
-              strides=(2, 1),
-              first_relu=res_relu,
-              force2d=True,
-              name="res_conv0")
-
-      # Rescale from UINT8 to floats in [-1,-1]
-      signals = (tf.to_float(inputs)-127)/128.
-      signals = tf.squeeze(signals, [2, 3])
-
-      # `stfts` is a complex64 Tensor representing the short-time Fourier
-      # Transform of each signal in `signals`. Its shape is
-      # [batch_size, ?, fft_unique_bins]
-      # where fft_unique_bins = fft_length // 2 + 1 = 513.
-      stfts = tf.contrib.signal.stft(signals, frame_length=1024,
-                                     frame_step=512, fft_length=1024)
-
-      # An energy spectrogram is the magnitude of the complex-valued STFT.
-      # A float32 Tensor of shape [batch_size, ?, 513].
-      magnitude_spectrograms = tf.abs(stfts)
-
-      # Warp the linear-scale, magnitude spectrograms into the mel-scale.
-      num_spectrogram_bins = magnitude_spectrograms.shape[-1].value
-      lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 64
-      sample_rate = 16000
-      linear_to_mel_weight_matrix = (
-          tf.contrib.signal.linear_to_mel_weight_matrix(
-              num_mel_bins, num_spectrogram_bins, sample_rate,
-              lower_edge_hertz, upper_edge_hertz))
-      mel_spectrograms = tf.tensordot(
-          magnitude_spectrograms, linear_to_mel_weight_matrix, 1)
-      # Note: Shape inference for tensordot does not currently handle this
-      # case.
-      mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
-          linear_to_mel_weight_matrix.shape[-1:]))
-
-      x = tf.expand_dims(mel_spectrograms, 2)
-      x.set_shape([None, None, None, num_mel_bins])
-      for i in xrange(self._model_hparams.audio_compression):
-        x = xnet_resblock(x, 2**(i + 1), True, "compress_block_%d" % i)
-      return xnet_resblock(x, self._body_input_depth, False,
-                           "compress_block_final")
-
-
 @registry.register_problem()
-class Librispeech(problem.Problem):
-  """Problem spec for English word to dictionary definition."""
+class Librispeech(speech_recognition.SpeechRecognitionProblem):
+  """Problem spec for Librispeech using clean and noisy data."""

-  @property
-  def is_character_level(self):
-    return True
-
-  @property
-  def input_space_id(self):
-    return problem.SpaceID.AUDIO_SPECTRAL
-
-  @property
-  def target_space_id(self):
-    return problem.SpaceID.EN_CHR
+  # Select only the clean data
+  TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS
+  DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS

   @property
   def num_shards(self):
@@ -228,26 +102,8 @@ def use_train_shards_for_dev(self):
     """If true, we only generate training data and hold out shards for dev."""
     return False

-  def feature_encoders(self, _):
-    return {
-        "inputs": text_encoder.TextEncoder(),
-        "targets": LibrispeechTextEncoder(),
-    }
-
-  def example_reading_spec(self):
-    data_fields = {
-        "inputs": tf.VarLenFeature(tf.int64),
-        "targets": tf.VarLenFeature(tf.int64),
-    }
-    data_items_to_decoders = None
-    return (data_fields, data_items_to_decoders)
-
-  def generator(self, data_dir, tmp_dir, training,
+  def generator(self, data_dir, tmp_dir, datasets,
                 eos_list=None, start_from=0, how_many=0):
-    eos_list = [1] if eos_list is None else eos_list
-    datasets = (_LIBRISPEECH_TRAIN_DATASETS if training
-                else _LIBRISPEECH_TEST_DATASETS)
-    num_reserved_ids = self.feature_encoders(None)["targets"].num_reserved_ids
     i = 0
     for url, subdir in datasets:
       filename = os.path.basename(url)
@@ -267,44 +123,53 @@ def generator(self, data_dir, tmp_dir, training,
       data_dir = os.path.join(tmp_dir, "LibriSpeech", subdir)
       data_files = _collect_data(data_dir, "flac", "txt")
       data_pairs = data_files.values()
+
+      encoders = self.feature_encoders(None)
+      audio_encoder = encoders["waveforms"]
+      text_encoder = encoders["targets"]
+
       for media_file, text_data in sorted(data_pairs)[start_from:]:
         if how_many > 0 and i == how_many:
           return
         i += 1
-        audio_data, sample_count, sample_width, num_channels = _get_audio_data(
-            media_file)
-        label = [num_reserved_ids + ord(c) for c in text_data] + eos_list
         yield {
-            "inputs": audio_data,
-            "audio/channel_count": [num_channels],
-            "audio/sample_count": [sample_count],
-            "audio/sample_width": [sample_width],
-            "targets": label
+            "waveforms": audio_encoder.encode(media_file),
+            "targets": text_encoder.encode(text_data)
         }

   def generate_data(self, data_dir, tmp_dir, task_id=-1):
     train_paths = self.training_filepaths(
         data_dir, self.num_shards, shuffled=False)
     dev_paths = self.dev_filepaths(
         data_dir, self.num_dev_shards, shuffled=False)

     if self.use_train_shards_for_dev:
       all_paths = train_paths + dev_paths
       generator_utils.generate_files(
-          self.generator(data_dir, tmp_dir, True), all_paths)
+          self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), all_paths)
       generator_utils.shuffle_dataset(all_paths)
     else:
       generator_utils.generate_dataset_and_shuffle(
-          self.generator(data_dir, tmp_dir, True), train_paths,
-          self.generator(data_dir, tmp_dir, False), dev_paths)
+          self.generator(data_dir, tmp_dir, self.TRAIN_DATASETS), train_paths,
+          self.generator(data_dir, tmp_dir, self.DEV_DATASETS), dev_paths)

-  def hparams(self, defaults, unused_model_hparams):
-    p = defaults
-    p.stop_at_eos = int(False)
-    p.input_modality = {"inputs": ("audio:librispeech_modality", None)}
-    p.target_modality = (registry.Modalities.SYMBOL, 256)
-
-  def preprocess_example(self, example, mode, hparams):
-    return example
+
+@registry.register_problem()
+class LibrispeechCleanSmall(Librispeech):
+  """Problem spec for Librispeech using 100h clean train data."""
+
+  # Select only the clean data
+  TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:1]
+  DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]
+
+
+@registry.register_problem()
+class LibrispeechClean(Librispeech):
+  """Problem spec for Librispeech using 460h clean train data."""
+
+  # Select only the clean data
+  TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[:2]
+  DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[:1]


 # TODO(lukaszkaiser): clean up hparams or remove from here.
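The librispeech.py rewrite drops the bespoke audio pipeline (the sox/wave-based _get_audio_data, LibrispeechTextEncoder, and LibrispeechModality) in favor of the shared speech_recognition.SpeechRecognitionProblem, and expresses dataset subsets as TRAIN_DATASETS/DEV_DATASETS class attributes on registered subclasses. Purely as a hypothetical illustration of that pattern (this class is not part of the PR, and it assumes the noisy sets occupy the remaining entries of the dataset lists):

# Hypothetical subclass following the same pattern as LibrispeechCleanSmall
# and LibrispeechClean above; not part of this PR.
@registry.register_problem()
class LibrispeechNoisy(Librispeech):
  """Problem spec for Librispeech using only the noisy train data."""

  TRAIN_DATASETS = _LIBRISPEECH_TRAIN_DATASETS[2:]
  DEV_DATASETS = _LIBRISPEECH_TEST_DATASETS[1:]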