
Commit

Move Estimator input_fn and model_fn construction into Problem and T2TModel, respectively, which allows subclassing

PiperOrigin-RevId: 177229237
Ryan Sepassi committed Nov 29, 2017
1 parent 5adacd0 commit 398e85b
Showing 5 changed files with 467 additions and 373 deletions.
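The commit message notes that this refactor allows subclassing. As a minimal sketch of what that enables (hypothetical code, not part of this commit), a custom Problem could override input_pipeline to control how its Estimator input_fn reads data:

    from tensor2tensor.data_generators import problem
    from tensor2tensor.utils import registry

    @registry.register_problem
    class MyCustomProblem(problem.Problem):  # hypothetical subclass

      def input_pipeline(self, mode, hparams, params=None, config=None):
        # Reuse the default pipeline, then adapt its output; a subclass could
        # instead build its own tf.data pipeline from scratch here.
        features, targets = super(MyCustomProblem, self).input_pipeline(
            mode, hparams, params=params, config=config)
        return features, targets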
117 changes: 117 additions & 0 deletions tensor2tensor/data_generators/problem.py
@@ -24,6 +24,7 @@
import six
from tensor2tensor.data_generators import generator_utils
from tensor2tensor.data_generators import text_encoder
from tensor2tensor.utils import data_reader
from tensor2tensor.utils import metrics
from tensor2tensor.utils import registry
import tensorflow as tf
@@ -457,6 +458,90 @@ def feature_info(self):
    self._feature_info = features
    return features

  def make_estimator_input_fn(self, mode, hparams):
    """Return an input_fn for use with tf.estimator.Estimator."""

    def estimator_input_fn(params, config):
      return self.input_pipeline(mode, hparams, params=params, config=config)

    return estimator_input_fn
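  # Usage sketch (illustrative, not part of this commit): the returned closure
  # has the (params, config) signature that the Estimator machinery supplies,
  # so a trainer might do, roughly:
  #
  #   train_input_fn = problem.make_estimator_input_fn(
  #       tf.estimator.ModeKeys.TRAIN, hparams)
  #   estimator.train(train_input_fn, steps=train_steps)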

  def input_pipeline(self, mode, hparams, params=None, config=None):
    """Builds input pipeline for problem.

    Args:
      mode: tf.estimator.ModeKeys
      hparams: HParams, model hparams
      params: dict, may include "batch_size"
      config: RunConfig; if passed, should include t2t_device_info dict

    Returns:
      (features_dict<str name, Tensor feature>, Tensor targets)
    """
    tf.logging.warning("Problem.input_pipeline implements a subset of "
                       "input_fn_builder.build_input_fn and is currently only "
                       "used in tpu_trainer.")
    is_training = mode == tf.estimator.ModeKeys.TRAIN
    num_threads = 4 if is_training else 1
    batch_size = _get_batch_size(params, hparams, config)

    def valid_size(example):
      return data_reader.example_valid_size(example, hparams.min_length,
                                            hparams.max_length)

    def define_shapes(example):
      """Set the right shapes for the features."""
      inputs = example["inputs"]
      targets = example["targets"]

      # Ensure inputs and targets are proper rank.
      while len(inputs.get_shape()) < 4:
        inputs = tf.expand_dims(inputs, axis=-1)
      while len(targets.get_shape()) < 4:
        targets = tf.expand_dims(targets, axis=-1)

      example["inputs"] = inputs
      example["targets"] = targets

      # Ensure batch size is set on all features.
      for _, t in six.iteritems(example):
        shape = t.get_shape().as_list()
        shape[0] = batch_size
        t.set_shape(t.get_shape().merge_with(shape))
        # Assert shapes are fully known.
        t.get_shape().assert_is_fully_defined()

      return example

    # Read and preprocess.
    data_dir = hparams.data_dir
    dataset = self.dataset(
        mode=mode, data_dir=data_dir, num_threads=num_threads, hparams=hparams)
    dataset = dataset.map(
        data_reader.cast_int64_to_int32, num_threads=num_threads)
    if is_training:
      dataset = dataset.repeat(None)

    # Batch (and pad).
    # TODO(rsepassi): Add support for bucketing by length
    if _are_shapes_fully_defined(dataset.output_shapes):
      dataset = dataset.apply(
          tf.contrib.data.batch_and_drop_remainder(batch_size))
    else:
      # If shapes are not fully defined, filter out long examples and pad to
      # hparams.max_length.
      dataset = dataset.filter(valid_size)
      padded_shapes = _fill_shape_nones(
          dataset.output_shapes, none_filler=hparams.max_length)
      dataset = dataset.apply(
          tf.contrib.data.padded_batch_and_drop_remainder(batch_size,
                                                          padded_shapes))

    dataset = dataset.map(define_shapes, num_parallel_calls=num_threads)
    dataset = dataset.prefetch(1)
    features = dataset.make_one_shot_iterator().get_next()

    return features, features["targets"]
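  # Shape sketch (illustrative, not part of this diff): after define_shapes,
  # every feature is rank 4 with a static batch dimension, e.g.
  # features["inputs"] has shape [batch_size, length, 1, 1], and the second
  # element of the returned tuple is features["targets"].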


class FeatureInfo(object):

@@ -693,3 +778,35 @@ def eval_metrics(self):
        metrics.Metrics.APPROX_BLEU, metrics.Metrics.ROUGE_2_F,
        metrics.Metrics.ROUGE_L_F
    ]


def _are_shapes_fully_defined(shapes_dict):
  for shape in shapes_dict.values():
    if not shape.is_fully_defined():
      return False
  return True


def _get_batch_size(params, hparams, config):
  """Batch size determined by params dict, HParams, and RunConfig."""
  # If params specifies batch size, use that. TPUEstimator passes batch size in
  # params.
  batch_size = params and params.get("batch_size")

  # If not set, then we're running on CPU/GPU, so use the batch size from the
  # hparams, and multiply by the number of data shards.
  if not batch_size:
    batch_size = hparams.tpu_batch_size_per_shard
    if config:
      batch_size *= config.t2t_device_info["num_shards"]

  return batch_size
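# Example (illustrative, not part of this diff): with
# hparams.tpu_batch_size_per_shard = 16 and a RunConfig whose t2t_device_info
# reports num_shards = 8, the CPU/GPU path yields batch_size = 16 * 8 = 128;
# when TPUEstimator passes params["batch_size"], that value is used unchanged.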


def _fill_shape_nones(shapes_dict, none_filler=None):
  padded_shapes = {}
  for key, shape in six.iteritems(shapes_dict):
    padded_shapes[key] = [
        (dim if dim is not None else none_filler) for dim in shape.as_list()
    ]
  return padded_shapes
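# Example (illustrative, not part of this diff): given
# output_shapes = {"inputs": [None, 1, 1]} and none_filler = 256,
# _fill_shape_nones returns {"inputs": [256, 1, 1]}, which
# padded_batch_and_drop_remainder then uses as the padded shape.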
2 changes: 1 addition & 1 deletion tensor2tensor/models/transformer.py
@@ -225,7 +225,7 @@ def _fast_decode(self,
    inputs = features["inputs"]
    batch_size = common_layers.shape_list(inputs)[0]
    target_modality = self._problem_hparams.target_modality
-    if t2t_model.is_class_modality(target_modality):
+    if target_modality.is_class_modality:
      decode_length = 1
    else:
      decode_length = common_layers.shape_list(inputs)[1] + decode_length