diff --git a/tensor2tensor/bin/t2t-datagen b/tensor2tensor/bin/t2t-datagen
old mode 100755
new mode 100644
index 67890371b..2ac0f0db2
--- a/tensor2tensor/bin/t2t-datagen
+++ b/tensor2tensor/bin/t2t-datagen
@@ -112,7 +112,7 @@ _SUPPORTED_PROBLEM_GENERATORS = {
           vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15),
       lambda: audio.timit_generator(
           FLAGS.data_dir, FLAGS.tmp_dir, False, 626,
-          vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
+          vocab_filename="vocab.endefr.%d" % 2**15, vocab_size=2**15)),
 }

 # pylint: enable=g-long-lambda
diff --git a/tensor2tensor/bin/t2t-decoder b/tensor2tensor/bin/t2t-decoder
index 5f05f5bcb..c2bf97f94 100644
--- a/tensor2tensor/bin/t2t-decoder
+++ b/tensor2tensor/bin/t2t-decoder
@@ -47,8 +47,9 @@
 flags = tf.flags
 FLAGS = flags.FLAGS

 flags.DEFINE_string("output_dir", "", "Training directory to load from.")
-flags.DEFINE_string("decode_from_file", None, "Path to the source file for decoding")
-flags.DEFINE_string("decode_to_file", None, "Path to the decoded (output) file")
+flags.DEFINE_string("decode_from_file", None, "Path to decode file")
+flags.DEFINE_string("decode_to_file", None,
+                    "Path prefix to inference output file")
 flags.DEFINE_bool("decode_interactive", False, "Interactive local inference mode.")
 flags.DEFINE_integer("decode_shards", 1, "Number of decoding replicas.")
diff --git a/tensor2tensor/bin/t2t-trainer b/tensor2tensor/bin/t2t-trainer
index fc37f27ab..5a2866da6 100644
--- a/tensor2tensor/bin/t2t-trainer
+++ b/tensor2tensor/bin/t2t-trainer
@@ -59,7 +59,7 @@ flags.DEFINE_string("output_dir", "", "Base output directory for run.")
 flags.DEFINE_string("master", "", "Address of TensorFlow master.")
 flags.DEFINE_string("schedule", "train_and_evaluate",
                     "Method of tf.contrib.learn.Experiment to run.")
-flags.DEFINE_bool("profile", False, "Profile performance?")
+

 def main(_):
   tf.logging.set_verbosity(tf.logging.INFO)
@@ -83,26 +83,13 @@ def main(_):
     problem.generate_data(data_dir, tmp_dir)

   # Run the trainer.
-  def run_experiment():
-    trainer_utils.run(
+  trainer_utils.run(
       data_dir=data_dir,
       model=FLAGS.model,
       output_dir=output_dir,
       train_steps=FLAGS.train_steps,
       eval_steps=FLAGS.eval_steps,
       schedule=FLAGS.schedule)
-
-  if FLAGS.profile:
-    with tf.contrib.tfprof.ProfileContext('t2tprof',
-                                          trace_steps=range(100),
-                                          dump_steps=range(100)) as pctx:
-      opts = tf.profiler.ProfileOptionBuilder.time_and_memory()
-      pctx.add_auto_profiling('op', opts, range(100))
-
-      run_experiment()
-
-  else:
-    run_experiment()


 if __name__ == "__main__":
diff --git a/tensor2tensor/data_generators/all_problems.py b/tensor2tensor/data_generators/all_problems.py
index 2aca3d377..c7f364cf1 100644
--- a/tensor2tensor/data_generators/all_problems.py
+++ b/tensor2tensor/data_generators/all_problems.py
@@ -28,7 +28,6 @@
 from tensor2tensor.data_generators import ice_parsing
 from tensor2tensor.data_generators import image
 from tensor2tensor.data_generators import imdb
-from tensor2tensor.data_generators import librispeech
 from tensor2tensor.data_generators import lm1b
 from tensor2tensor.data_generators import multinli
 from tensor2tensor.data_generators import problem_hparams
diff --git a/tensor2tensor/data_generators/cnn_dailymail.py b/tensor2tensor/data_generators/cnn_dailymail.py
index 05b2a1f37..239d1af99 100644
--- a/tensor2tensor/data_generators/cnn_dailymail.py
+++ b/tensor2tensor/data_generators/cnn_dailymail.py
@@ -19,10 +19,9 @@
 from __future__ import division
 from __future__ import print_function

-import io
+import hashlib
 import os
 import tarfile
-import hashlib

 # Dependency imports

@@ -47,7 +46,7 @@
 # Train/Dev/Test Splits for summarization data
 _TRAIN_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt"
 _DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"
-_TEST_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt"
+_TEST_URLS = "https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt"


 # End-of-sentence marker.
@@ -129,7 +128,9 @@ def generate_hash(inp):

   return filelist

-def example_generator(all_files, urls_path, sum_token):
+
+def example_generator(tmp_dir, is_training, sum_token):
+  """Generate examples."""
   def fix_run_on_sents(line):
     if u"@highlight" in line:
       return line
     if line == "":
       return line
     return line + u"."
+  all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
   filelist = example_splits(urls_path, all_files)
   story_summary_split_token = u" <summary> " if sum_token else " "
@@ -168,29 +170,13 @@ def fix_run_on_sents(line):
     yield " ".join(story) + story_summary_split_token + " ".join(summary)

+
 def _story_summary_split(story):
   split_str = u" <summary> "
   split_str_len = len(split_str)
   split_pos = story.find(split_str)
   return story[:split_pos], story[split_pos+split_str_len:]  # story, summary

-def write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training):
-  def write_to_file(all_files, urls_path, data_dir, filename):
-    with io.open(os.path.join(data_dir, filename+".source"), "w") as fstory, io.open(os.path.join(data_dir, filename+".target"), "w") as fsummary:
-      for example in example_generator(all_files, urls_path, sum_token=True):
-        story, summary = _story_summary_split(example)
-        fstory.write(story+"\n")
-        fsummary.write(summary+"\n")
-
-  filename = "cnndm.train" if is_training else "cnndm.dev"
-  tf.logging.info("Writing %s" % filename)
-  write_to_file(all_files, urls_path, data_dir, filename)
-
-  if not is_training:
-    test_urls_path = generator_utils.maybe_download(tmp_dir, "all_test.txt", _TEST_URLS)
-    filename = "cnndm.test"
-    tf.logging.info("Writing %s" % filename)
-    write_to_file(all_files, test_urls_path, data_dir, filename)


 @registry.register_problem
 class SummarizeCnnDailymail32k(problem.Text2TextProblem):
@@ -233,12 +219,10 @@ def use_train_shards_for_dev(self):
     return False

   def generator(self, data_dir, tmp_dir, is_training):
-    all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
     encoder = generator_utils.get_or_generate_vocab_inner(
         data_dir, self.vocab_file, self.targeted_vocab_size,
-        example_generator(all_files, urls_path, sum_token=False))
-    write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training)
-    for example in example_generator(all_files, urls_path, sum_token=True):
+        example_generator(tmp_dir, is_training, sum_token=False))
+    for example in example_generator(tmp_dir, is_training, sum_token=True):
       story, summary = _story_summary_split(example)
       encoded_summary = encoder.encode(summary) + [EOS]
       encoded_story = encoder.encode(story) + [EOS]
diff --git a/tensor2tensor/data_generators/librispeech.py b/tensor2tensor/data_generators/librispeech.py
deleted file mode 100644
index de7ed94cc..000000000
--- a/tensor2tensor/data_generators/librispeech.py
+++ /dev/null
@@ -1,310 +0,0 @@
-from tensor2tensor.data_generators import problem
-from tensor2tensor.utils import registry
-from tensor2tensor.models import transformer
-from tensor2tensor.utils import modality
-from tensor2tensor.layers import common_layers
-from tensor2tensor.data_generators import text_encoder
-import random
-import tensorflow as tf
-import numpy as np
-from tensor2tensor.data_generators import generator_utils
-import os
-from subprocess import call
-import tarfile
-import wave
-
-
-_LIBRISPEECH_TRAIN_DATASETS = [
-    [
-        "http://www.openslr.org/resources/12/train-clean-100.tar.gz",  # pylint: disable=line-too-long
-        "train-clean-100"
-    ],
-    [
-        "http://www.openslr.org/resources/12/train-clean-360.tar.gz",
-        "train-clean-360"
-    ],
-    [
-        "http://www.openslr.org/resources/12/train-other-500.tar.gz",
-        "train-other-500"
-    ],
-]
-_LIBRISPEECH_TEST_DATASETS = [
-    [
-        "http://www.openslr.org/resources/12/dev-clean.tar.gz",
-        "dev-clean"
-    ],
-    [
-        "http://www.openslr.org/resources/12/dev-other.tar.gz",
-        "dev-other"
-    ],
-]
-
-
-def _collect_data(directory, input_ext, transcription_ext):
-  """Traverses directory collecting input and target files."""
-  # Directory from string to tuple pair of strings
-  # key: the filepath to a datafile including the datafile's basename. Example,
-  #   if the datafile was "/path/to/datafile.wav" then the key would be
-  #   "/path/to/datafile"
-  # value: a pair of strings (media_filepath, label)
-  data_files = dict()
-  for root, _, filenames in os.walk(directory):
-    transcripts = [filename for filename in filenames if transcription_ext in filename]
-    for transcript in transcripts:
-      basename = transcript.strip(transcription_ext)
-      transcript_path = os.path.join(root, transcript)
-      with open(transcript_path, 'r') as transcript_file:
-        for transcript_line in transcript_file:
-          line_contents = transcript_line.split(" ", 1)
-          assert len(line_contents) == 2
-          media_base, label = line_contents
-          key = os.path.join(root, media_base)
-          assert key not in data_files
-          media_name = "%s.%s"%(media_base, input_ext)
-          media_path = os.path.join(root, media_name)
-          data_files[key] = (media_path, label)
-  return data_files
-
-
-def _get_audio_data(filepath):
-  # Construct a true .wav file.
-  out_filepath = filepath.strip(".flac") + ".wav"
-  # Assumes sox is installed on system. Sox converts from FLAC to WAV.
-  call(["sox", filepath, out_filepath])
-  wav_file = wave.open(open(out_filepath))
-  frame_count = wav_file.getnframes()
-  byte_array = wav_file.readframes(frame_count)
-
-  data = np.fromstring(byte_array, np.uint8).tolist()
-  return data, frame_count, wav_file.getsampwidth(), wav_file.getnchannels()
-
-
-class LibrispeechTextEncoder(text_encoder.TextEncoder):
-
-  def encode(self, s):
-    return [self._num_reserved_ids + ord(c) for c in s]
-
-  def decode(self, ids):
-    """Transform a sequence of int ids into a human-readable string.
-    EOS is not expected in ids.
-    Args:
-      ids: list of integers to be converted.
-    Returns:
-      s: human-readable string.
-    """
-    decoded_ids = []
-    for id_ in ids:
-      if 0 <= id_ < self._num_reserved_ids:
-        decoded_ids.append(RESERVED_TOKENS[int(id_)])
-      else:
-        decoded_ids.append(id_ - self._num_reserved_ids)
-    return "".join([chr(d) for d in decoded_ids])
-
-
-@registry.register_audio_modality
-class LibrispeechModality(modality.Modality):
-  """Performs strided conv compressions for audio spectral data."""
-
-  def bottom(self, inputs):
-    """Transform input from data space to model space.
-    Args:
-      inputs: A Tensor with shape [batch, ...]
-    Returns:
-      body_input: A Tensor with shape [batch, ?, ?, body_input_depth].
-    """
-    with tf.variable_scope(self.name):
-      # TODO(aidangomez): Will need to sort out a better audio pipeline
-      def xnet_resblock(x, filters, res_relu, name):
-        with tf.variable_scope(name):
-          # We only stride along the length dimension to preserve the spectral
-          # bins (which are tiny in dimensionality relative to length)
-          y = common_layers.separable_conv_block(
-              x,
-              filters, [((1, 1), (3, 3)), ((1, 1), (3, 3))],
-              first_relu=True,
-              padding="SAME",
-              force2d=True,
-              name="sep_conv_block")
-          y = common_layers.pool(y, (3, 3), "MAX", "SAME", strides=(2, 1))
-          return y + common_layers.conv_block(
-              x,
-              filters, [((1, 1), (1, 1))],
-              padding="SAME",
-              strides=(2, 1),
-              first_relu=res_relu,
-              force2d=True,
-              name="res_conv0")
-
-      # Rescale from UINT8 to floats in [-1,-1]
-      signals = (tf.to_float(inputs)-127)/128.
-      #signals = tf.contrib.framework.nest.flatten(signals)
-      signals = tf.squeeze(signals, [2, 3])
-
-      # `stfts` is a complex64 Tensor representing the Short-time Fourier Transform of
-      # each signal in `signals`. Its shape is [batch_size, ?, fft_unique_bins]
-      # where fft_unique_bins = fft_length // 2 + 1 = 513.
-      stfts = tf.contrib.signal.stft(signals, frame_length=1024, frame_step=512,
-                                     fft_length=1024)
-
-      # An energy spectrogram is the magnitude of the complex-valued STFT.
-      # A float32 Tensor of shape [batch_size, ?, 513].
-      magnitude_spectrograms = tf.abs(stfts)
-
-      log_offset = 1e-6
-      log_magnitude_spectrograms = tf.log(magnitude_spectrograms + log_offset)
-
-      # Warp the linear-scale, magnitude spectrograms into the mel-scale.
-      num_spectrogram_bins = magnitude_spectrograms.shape[-1].value
-      lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 7600.0, 64
-      sample_rate = 16000
-      linear_to_mel_weight_matrix = tf.contrib.signal.linear_to_mel_weight_matrix(
-          num_mel_bins, num_spectrogram_bins, sample_rate, lower_edge_hertz,
-          upper_edge_hertz)
-      mel_spectrograms = tf.tensordot(
-          magnitude_spectrograms, linear_to_mel_weight_matrix, 1)
-      # Note: Shape inference for `tf.tensordot` does not currently handle this case.
-      mel_spectrograms.set_shape(magnitude_spectrograms.shape[:-1].concatenate(
-          linear_to_mel_weight_matrix.shape[-1:]))
-
-      # Try without the conversion to MFCCs, first.
-      '''num_mfccs = 13
-      # Keep the first `num_mfccs` MFCCs.
-      mfccs = tf.contrib.signal.mfccs_from_log_mel_spectrograms(
-          log_mel_spectrograms)[..., :num_mfccs]'''
-
-      x = tf.expand_dims(mel_spectrograms, 2)
-      x.set_shape([None, None, None, num_mel_bins])
-      for i in xrange(self._model_hparams.audio_compression):
-        x = xnet_resblock(x, 2**(i + 1), True, "compress_block_%d" % i)
-      return xnet_resblock(x, self._body_input_depth, False,
-                           "compress_block_final")
-
-
-@registry.register_problem()
-class Librispeech(problem.Problem):
-  """Problem spec for English word to dictionary definition."""
-
-  @property
-  def is_character_level(self):
-    return True
-
-  @property
-  def input_space_id(self):
-    return problem.SpaceID.AUDIO_SPECTRAL
-
-  @property
-  def target_space_id(self):
-    return problem.SpaceID.EN_CHR
-
-  @property
-  def num_shards(self):
-    return 100
-
-  @property
-  def use_subword_tokenizer(self):
-    return False
-
-  @property
-  def num_dev_shards(self):
-    return 1
-
-  @property
-  def use_train_shards_for_dev(self):
-    """If true, we only generate training data and hold out shards for dev."""
-    return False
-
-  def feature_encoders(self, _):
-    return {
-        "inputs": text_encoder.TextEncoder(),
-        "targets": LibrispeechTextEncoder(),
-    }
-
-  def example_reading_spec(self):
-    data_fields = {
-        "inputs": tf.VarLenFeature(tf.int64),
-        #"audio/channel_count": tf.FixedLenFeature([], tf.int64),
-        #"audio/sample_count": tf.FixedLenFeature([], tf.int64),
-        #"audio/sample_width": tf.FixedLenFeature([], tf.int64),
-        "targets": tf.VarLenFeature(tf.int64),
-    }
-    data_items_to_decoders = None
-    return (data_fields, data_items_to_decoders)
-
-  def generator(self, data_dir, tmp_dir, training, eos_list=None, start_from=0, how_many=0):
-    eos_list = [1] if eos_list is None else eos_list
-    datasets = (_LIBRISPEECH_TRAIN_DATASETS if training else _LIBRISPEECH_TEST_DATASETS)
-    num_reserved_ids = self.feature_encoders(None)["targets"].num_reserved_ids
-    i = 0
-    for url, subdir in datasets:
-      filename = os.path.basename(url)
-      compressed_file = generator_utils.maybe_download(tmp_dir, filename, url)
-
-      read_type = "r:gz" if filename.endswith("tgz") else "r"
-      with tarfile.open(compressed_file, read_type) as corpus_tar:
-        # Create a subset of files that don't already exist.
-        # tarfile.extractall errors when encountering an existing file
-        # and tarfile.extract is extremely slow
-        members = []
-        for f in corpus_tar:
-          if not os.path.isfile(os.path.join(tmp_dir, f.name)):
-            members.append(f)
-        corpus_tar.extractall(tmp_dir, members=members)
-
-      data_dir = os.path.join(tmp_dir, "LibriSpeech", subdir)
-      data_files = _collect_data(data_dir, "flac", "txt")
-      data_pairs = data_files.values()
-      for media_file, text_data in sorted(data_pairs)[start_from:]:
-        if how_many > 0 and i == how_many:
-          return
-        i += 1
-        audio_data, sample_count, sample_width, num_channels = _get_audio_data(
-            media_file)
-        label = [num_reserved_ids + ord(c) for c in text_data] + eos_list
-        yield {
-            "inputs": audio_data,
-            "audio/channel_count": [num_channels],
-            "audio/sample_count": [sample_count],
-            "audio/sample_width": [sample_width],
-            "targets": label
-        }
-
-  def generate_data(self, data_dir, tmp_dir, task_id=-1):
-    train_paths = self.training_filepaths(data_dir, self.num_shards, shuffled=False)
-    dev_paths = self.dev_filepaths(data_dir, self.num_dev_shards, shuffled=False)
-    if self.use_train_shards_for_dev:
-      all_paths = train_paths + dev_paths
-      generator_utils.generate_files(self.generator(data_dir, tmp_dir, True), all_paths)
-      generator_utils.shuffle_dataset(all_paths)
-    else:
-      generator_utils.generate_dataset_and_shuffle(
-          self.generator(data_dir, tmp_dir, True), train_paths,
-          self.generator(data_dir, tmp_dir, False), dev_paths)
-
-  def hparams(self, defaults, unused_model_hparams):
-    p = defaults
-    p.stop_at_eos = int(False)
-    p.input_modality = { "inputs": ("audio:librispeech_modality", None) }
-    p.target_modality = (registry.Modalities.SYMBOL, 256)
-
-  def preprocess_example(self, example, mode, hparams):
-    return example
-
-# TODO: clean up hparams
-@registry.register_hparams
-def librispeech_hparams():
-  hparams = transformer.transformer_base_single_gpu()  # Or whatever you'd like to build off.
-  hparams.batch_size = 36
-  hparams.audio_compression = 8
-  hparams.hidden_size = 2048
-  hparams.max_input_seq_length = 600000
-  hparams.max_target_seq_length = 350
-  hparams.max_length = hparams.max_input_seq_length
-  hparams.min_length_bucket = hparams.max_input_seq_length // 2
-  hparams.learning_rate = 0.05
-  hparams.train_steps = 5000000
-  hparams.num_hidden_layers = 4
-  return hparams
diff --git a/tensor2tensor/models/transformer_sketch.py b/tensor2tensor/models/transformer_sketch.py
index 45384f065..b6bbb7708 100644
--- a/tensor2tensor/models/transformer_sketch.py
+++ b/tensor2tensor/models/transformer_sketch.py
@@ -66,7 +66,7 @@ def transformer_sketch():
   hparams.learning_rate = 0.2
   hparams.learning_rate_warmup_steps = 10000
   hparams.num_hidden_layers = 6
-  hparams.initializer = "orthogonal"
+  # hparams.initializer = "orthogonal"
   hparams.sampling_method = "random"
   return hparams
diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py
index d825df6f2..629b2ed26 100644
--- a/tensor2tensor/utils/decoding.py
+++ b/tensor2tensor/utils/decoding.py
@@ -83,9 +83,9 @@ def log_decode_results(inputs,
   decoded_targets = None
   if identity_output:
-    decoded_outputs = "".join(map(str, outputs.flatten()))
+    decoded_outputs = " ".join(map(str, outputs.flatten()))
     if targets is not None:
-      decoded_targets = "".join(map(str, targets.flatten()))
+      decoded_targets = " ".join(map(str, targets.flatten()))
   else:
     decoded_outputs = targets_vocab.decode(_save_until_eos(outputs, is_image))
     if targets is not None:
@@ -252,14 +252,17 @@ def input_fn():
   # _decode_batch_input_fn
   sorted_inputs.reverse()
   decodes.reverse()
-  # If decode_to_file was provided use it as the output filename without any change
-  # (except for adding shard_id if using more shards for decoding).
-  # Otherwise, use the input filename plus model, hp, problem, beam, alpha.
-  decode_filename = decode_to_file if decode_to_file else filename
+  # Dumping inputs and outputs to file filename.decodes in
+  # format result\tinput in the same order as original inputs
+  if decode_to_file:
+    output_filename = decode_to_file
+  else:
+    output_filename = filename
   if decode_hp.shards > 1:
-    decode_filename = decode_filename + ("%.2d" % decode_hp.shard_id)
-  if not decode_to_file:
-    decode_filename = _decode_filename(decode_filename, problem_name, decode_hp)
+    base_filename = output_filename + ("%.2d" % decode_hp.shard_id)
+  else:
+    base_filename = output_filename
+  decode_filename = _decode_filename(base_filename, problem_name, decode_hp)
   tf.logging.info("Writing decodes into %s" % decode_filename)
   outfile = tf.gfile.Open(decode_filename, "w")
   for index in range(len(sorted_inputs)):
diff --git a/tensor2tensor/utils/get_cnndm_rouge.sh b/tensor2tensor/utils/get_cnndm_rouge.sh
deleted file mode 100644
index 0f52bb56c..000000000
--- a/tensor2tensor/utils/get_cnndm_rouge.sh
+++ /dev/null
@@ -1,16 +0,0 @@
-#!/bin/bash
-
-# Path to moses dir
-mosesdecoder=$1
-
-# Path to file containing gold summaries, one per line
-targets_file=$2
-# Path to file containing model generated summaries, one per line
-decodes_file=$3
-
-# Tokenize.
-perl $mosesdecoder/scripts/tokenizer/tokenizer.perl -l en < $targets_file > $targets_file.tok
-perl $mosesdecoder/scripts/tokenizer/tokenizer.perl -l en < $decodes_file > $decodes_file.tok
-
-# Get rouge scores
-python get_rouge.py --decodes_filename $decodes_file.tok --targets_filename $targets_file.tok
diff --git a/tensor2tensor/utils/get_rouge.py b/tensor2tensor/utils/get_rouge.py
deleted file mode 100644
index c15545cfd..000000000
--- a/tensor2tensor/utils/get_rouge.py
+++ /dev/null
@@ -1,88 +0,0 @@
-# coding=utf-8
-# Copyright 2017 The Tensor2Tensor Authors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""Computing rouge scores using pyrouge."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os
-import logging
-import shutil
-from tempfile import mkdtemp
-from pprint import pprint
-
-# Dependency imports
-from pyrouge import Rouge155
-
-import numpy as np
-import tensorflow as tf
-
-FLAGS = tf.flags.FLAGS
-
-tf.flags.DEFINE_string("decodes_filename", None, "File containing model generated summaries tokenized")
-tf.flags.DEFINE_string("targets_filename", None, "File containing model target summaries tokenized")
-
-def write_to_file(filename, data):
-  data = ".\n".join(data.split(". "))
-  with open(filename, "w") as fp:
-    fp.write(data)
-
-def prep_data(decode_dir, target_dir):
-  with open(FLAGS.decodes_filename, "rb") as fdecodes, open(FLAGS.targets_filename, "rb") as ftargets:
-    for i, (d, t) in enumerate(zip(fdecodes, ftargets)):
-      write_to_file(os.path.join(decode_dir, "rouge.%06d.txt" % (i+1)), d)
-      write_to_file(os.path.join(target_dir, "rouge.A.%06d.txt" % (i+1)), t)
-
-      if (i+1 % 1000) == 0:
-        tf.logging.into("Written %d examples to file" % i)
-
-def main(_):
-  rouge = Rouge155()
-  rouge.log.setLevel(logging.ERROR)
-  rouge.system_filename_pattern = "rouge.(\d+).txt"
-  rouge.model_filename_pattern = "rouge.[A-Z].#ID#.txt"
-
-  tf.logging.set_verbosity(tf.logging.INFO)
-
-  tmpdir = mkdtemp()
-  tf.logging.info("tmpdir: %s" % tmpdir)
-  # system = decodes/predictions
-  system_dir = os.path.join(tmpdir, 'system')
-  # model = targets/gold
-  model_dir = os.path.join(tmpdir, 'model')
-  os.mkdir(system_dir)
-  os.mkdir(model_dir)
-
-  rouge.system_dir = system_dir
-  rouge.model_dir = model_dir
-
-  prep_data(rouge.system_dir, rouge.model_dir)
-
-  rouge_scores = rouge.convert_and_evaluate()
-  rouge_scores = rouge.output_to_dict(rouge_scores)
-  for prefix in ["rouge_1", "rouge_2", "rouge_l"]:
-    for suffix in ["f_score", "precision", "recall"]:
-      key = "_".join([prefix, suffix])
-      tf.logging.info("%s: %.4f" % (key, rouge_scores[key]))
-
-  # clean up after pyrouge
-  shutil.rmtree(tmpdir)
-  shutil.rmtree(rouge._config_dir)
-  shutil.rmtree(os.path.split(rouge._system_dir)[0])
-
-if __name__=='__main__':
-  tf.app.run()
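
Note on the decoding.py hunk: after this patch, _decode_filename is applied to the output
path unconditionally, so a path given via --decode_to_file now also receives the metadata
suffix (previously it was used verbatim, apart from the shard id). Below is a minimal
standalone sketch of the new control flow; `decorate` is a hypothetical stand-in for
_decode_filename, whose exact suffix format is not shown in this diff.

def pick_decode_filename(decode_to_file, input_filename, shards, shard_id, decorate):
  """Sketch of the new filename logic in decoding.py (not t2t code itself)."""
  # Prefer the user-supplied output path; otherwise fall back to the input file.
  output_filename = decode_to_file if decode_to_file else input_filename
  # With multiple decode shards, each replica writes its own numbered file.
  if shards > 1:
    base_filename = output_filename + ("%.2d" % shard_id)
  else:
    base_filename = output_filename
  # Unlike the pre-patch code, the metadata suffix is added even when
  # decode_to_file was set explicitly.
  return decorate(base_filename)

# Hypothetical usage (the suffix shown is only an example of the shape t2t produces):
suffix = lambda base: base + ".transformer.base.beam4.alpha0.6.decodes"
assert pick_decode_filename("out.txt", "in.txt", 1, 0, suffix) == \
    "out.txt.transformer.base.beam4.alpha0.6.decodes"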