Skip to content
This repository has been archived by the owner on Jul 7, 2023. It is now read-only.

Commit

Permalink
More tiny sketch fixes
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 175621997
  • Loading branch information
Ryan Sepassi committed Nov 29, 2017
1 parent 92983ea commit 6cf47f9
Show file tree
Hide file tree
Showing 10 changed files with 28 additions and 468 deletions.
2 changes: 1 addition & 1 deletion tensor2tensor/bin/t2t-datagen
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ _SUPPORTED_PROBLEM_GENERATORS = {
vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
lambda: audio.timit_generator(
FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
}

# pylint: enable=g-long-lambda
Expand Down
5 changes: 3 additions & 2 deletions tensor2tensor/bin/t2t-decoder
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ flags = tf.flags
FLAGS = flags.FLAGS

flags.DEFINE_string("output_dir", "", "Training directory to load from.")
flags.DEFINE_string("decode_from_file", None, "Path to the source file for decoding")
flags.DEFINE_string("decode_to_file", None, "Path to the decoded (output) file")
flags.DEFINE_string("decode_from_file", None, "Path to decode file")
flags.DEFINE_string("decode_to_file", None,
"Path prefix to inference output file")
flags.DEFINE_bool("decode_interactive", False,
"Interactive local inference mode.")
flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
Expand Down
17 changes: 2 additions & 15 deletions tensor2tensor/bin/t2t-trainer
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ flags.DEFINE_string("output_dir", "", "Base output directory for run.")
flags.DEFINE_string("master", "", "Address of TensorFlow master.")
flags.DEFINE_string("schedule", "train_and_evaluate",
"Method of tf.contrib.learn.Experiment to run.")
flags.DEFINE_bool("profile", False, "Profile performance?")


def main(_):
tf.logging.set_verbosity(tf.logging.INFO)
Expand All @@ -83,26 +83,13 @@ def main(_):
problem.generate_data(data_dir, tmp_dir)

# Run the trainer.
def run_experiment():
trainer_utils.run(
trainer_utils.run(
data_dir=data_dir,
model=FLAGS.model,
output_dir=output_dir,
train_steps=FLAGS.train_steps,
eval_steps=FLAGS.eval_steps,
schedule=FLAGS.schedule)

if FLAGS.profile:
with tf.contrib.tfprof.ProfileContext('t2tprof',
trace_steps=range(100),
dump_steps=range(100)) as pctx:
opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
pctx.add_auto_profiling('op', opts, range(100))

run_experiment()

else:
run_experiment()


if __name__ == "__main__":
Expand Down
1 change: 0 additions & 1 deletion tensor2tensor/data_generators/all_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
from tensor2tensor.data_generators import ice_parsing
from tensor2tensor.data_generators import image
from tensor2tensor.data_generators import imdb
from tensor2tensor.data_generators import librispeech
from tensor2tensor.data_generators import lm1b
from tensor2tensor.data_generators import multinli
from tensor2tensor.data_generators import problem_hparams
Expand Down
34 changes: 9 additions & 25 deletions tensor2tensor/data_generators/cnn_dailymail.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@
from __future__ import division
from __future__ import print_function

import io
import hashlib
import os
import tarfile
import hashlib

# Dependency imports

Expand All @@ -47,7 +46,7 @@
# Train/Dev/Test Splits for summarization data
_TRAIN_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt"
_DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"
_TEST_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt"
_TEST_URLS = "https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt"


# End-of-sentence marker.
Expand Down Expand Up @@ -129,7 +128,9 @@ def generate_hash(inp):

return filelist

def example_generator(all_files, urls_path, sum_token):

def example_generator(tmp_dir, is_training, sum_token):
"""Generate examples."""
def fix_run_on_sents(line):
if u"@highlight" in line:
return line
Expand All @@ -139,6 +140,7 @@ def fix_run_on_sents(line):
return line
return line + u"."

all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
filelist = example_splits(urls_path, all_files)
story_summary_split_token = u" <summary> " if sum_token else " "

Expand Down Expand Up @@ -168,29 +170,13 @@ def fix_run_on_sents(line):

yield " ".join(story) + story_summary_split_token + " ".join(summary)


def _story_summary_split(story):
split_str = u" <summary> "
split_str_len = len(split_str)
split_pos = story.find(split_str)
return story[:split_pos], story[split_pos+split_str_len:] # story, summary

def write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training):
def write_to_file(all_files, urls_path, data_dir, filename):
with io.open(os.path.join(data_dir, filename+".source"), "w") as fstory, io.open(os.path.join(data_dir, filename+".target"), "w") as fsummary:
for example in example_generator(all_files, urls_path, sum_token=True):
story, summary = _story_summary_split(example)
fstory.write(story+"\n")
fsummary.write(summary+"\n")

filename = "cnndm.train" if is_training else "cnndm.dev"
tf.logging.info("Writing %s" % filename)
write_to_file(all_files, urls_path, data_dir, filename)

if not is_training:
test_urls_path = generator_utils.maybe_download(tmp_dir, "all_test.txt", _TEST_URLS)
filename = "cnndm.test"
tf.logging.info("Writing %s" % filename)
write_to_file(all_files, test_urls_path, data_dir, filename)

@registry.register_problem
class SummarizeCnnDailymail32k(problem.Text2TextProblem):
Expand Down Expand Up @@ -233,12 +219,10 @@ def use_train_shards_for_dev(self):
return False

def generator(self, data_dir, tmp_dir, is_training):
all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
encoder = generator_utils.get_or_generate_vocab_inner(
data_dir, self.vocab_file, self.targeted_vocab_size,
example_generator(all_files, urls_path, sum_token=False))
write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training)
for example in example_generator(all_files, urls_path, sum_token=True):
example_generator(tmp_dir, is_training, sum_token=False))
for example in example_generator(tmp_dir, is_training, sum_token=True):
story, summary = _story_summary_split(example)
encoded_summary = encoder.encode(summary) + [EOS]
encoded_story = encoder.encode(story) + [EOS]
Expand Down
Loading

0 comments on commit 6cf47f9

Please sign in to comment.