diff --git a/tensor2tensor/data_generators/cnn_dailymail.py b/tensor2tensor/data_generators/cnn_dailymail.py index 239d1af99..05b2a1f37 100644 --- a/tensor2tensor/data_generators/cnn_dailymail.py +++ b/tensor2tensor/data_generators/cnn_dailymail.py @@ -19,9 +19,10 @@ from __future__ import division from __future__ import print_function -import hashlib +import io import os import tarfile +import hashlib # Dependency imports @@ -46,7 +47,7 @@ # Train/Dev/Test Splits for summarization data _TRAIN_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt" _DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt" -_TEST_URLS = "https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt" +_TEST_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt" # End-of-sentence marker. @@ -128,9 +129,7 @@ def generate_hash(inp): return filelist - -def example_generator(tmp_dir, is_training, sum_token): - """Generate examples.""" +def example_generator(all_files, urls_path, sum_token): def fix_run_on_sents(line): if u"@highlight" in line: return line @@ -140,7 +139,6 @@ def fix_run_on_sents(line): return line return line + u"." - all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training) filelist = example_splits(urls_path, all_files) story_summary_split_token = u" " if sum_token else " " @@ -170,13 +168,29 @@ def fix_run_on_sents(line): yield " ".join(story) + story_summary_split_token + " ".join(summary) - def _story_summary_split(story): split_str = u" " split_str_len = len(split_str) split_pos = story.find(split_str) return story[:split_pos], story[split_pos+split_str_len:] # story, summary +def write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training): + def write_to_file(all_files, urls_path, data_dir, filename): + with io.open(os.path.join(data_dir, filename+".source"), "w") as fstory, io.open(os.path.join(data_dir, filename+".target"), "w") as fsummary: + for example in example_generator(all_files, urls_path, sum_token=True): + story, summary = _story_summary_split(example) + fstory.write(story+"\n") + fsummary.write(summary+"\n") + + filename = "cnndm.train" if is_training else "cnndm.dev" + tf.logging.info("Writing %s" % filename) + write_to_file(all_files, urls_path, data_dir, filename) + + if not is_training: + test_urls_path = generator_utils.maybe_download(tmp_dir, "all_test.txt", _TEST_URLS) + filename = "cnndm.test" + tf.logging.info("Writing %s" % filename) + write_to_file(all_files, test_urls_path, data_dir, filename) @registry.register_problem class SummarizeCnnDailymail32k(problem.Text2TextProblem): @@ -219,10 +233,12 @@ def use_train_shards_for_dev(self): return False def generator(self, data_dir, tmp_dir, is_training): + all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training) encoder = generator_utils.get_or_generate_vocab_inner( data_dir, self.vocab_file, self.targeted_vocab_size, - example_generator(tmp_dir, is_training, sum_token=False)) - for example in example_generator(tmp_dir, is_training, sum_token=True): + example_generator(all_files, urls_path, sum_token=False)) + write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training) + for example in example_generator(all_files, urls_path, sum_token=True): story, summary = _story_summary_split(example) encoded_summary = encoder.encode(summary) + [EOS] encoded_story = encoder.encode(story) + [EOS] diff --git a/tensor2tensor/utils/decoding.py b/tensor2tensor/utils/decoding.py index d6dc5f1db..d825df6f2 100644 --- a/tensor2tensor/utils/decoding.py +++ b/tensor2tensor/utils/decoding.py @@ -83,9 +83,9 @@ def log_decode_results(inputs, decoded_targets = None if identity_output: - decoded_outputs = " ".join(map(str, outputs.flatten())) + decoded_outputs = "".join(map(str, outputs.flatten())) if targets is not None: - decoded_targets = " ".join(map(str, targets.flatten())) + decoded_targets = "".join(map(str, targets.flatten())) else: decoded_outputs = targets_vocab.decode(_save_until_eos(outputs, is_image)) if targets is not None: diff --git a/tensor2tensor/utils/get_cnndm_rouge.sh b/tensor2tensor/utils/get_cnndm_rouge.sh new file mode 100644 index 000000000..0f52bb56c --- /dev/null +++ b/tensor2tensor/utils/get_cnndm_rouge.sh @@ -0,0 +1,16 @@ +#!/bin/bash + +# Path to moses dir +mosesdecoder=$1 + +# Path to file containing gold summaries, one per line +targets_file=$2 +# Path to file containing model generated summaries, one per line +decodes_file=$3 + +# Tokenize. +perl $mosesdecoder/scripts/tokenizer/tokenizer.perl -l en < $targets_file > $targets_file.tok +perl $mosesdecoder/scripts/tokenizer/tokenizer.perl -l en < $decodes_file > $decodes_file.tok + +# Get rouge scores +python get_rouge.py --decodes_filename $decodes_file.tok --targets_filename $targets_file.tok diff --git a/tensor2tensor/utils/get_rouge.py b/tensor2tensor/utils/get_rouge.py new file mode 100644 index 000000000..c15545cfd --- /dev/null +++ b/tensor2tensor/utils/get_rouge.py @@ -0,0 +1,88 @@ +# coding=utf-8 +# Copyright 2017 The Tensor2Tensor Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Computing rouge scores using pyrouge.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging +import shutil +from tempfile import mkdtemp +from pprint import pprint + +# Dependency imports +from pyrouge import Rouge155 + +import numpy as np +import tensorflow as tf + +FLAGS = tf.flags.FLAGS + +tf.flags.DEFINE_string("decodes_filename", None, "File containing model generated summaries tokenized") +tf.flags.DEFINE_string("targets_filename", None, "File containing model target summaries tokenized") + +def write_to_file(filename, data): + data = ".\n".join(data.split(". ")) + with open(filename, "w") as fp: + fp.write(data) + +def prep_data(decode_dir, target_dir): + with open(FLAGS.decodes_filename, "rb") as fdecodes, open(FLAGS.targets_filename, "rb") as ftargets: + for i, (d, t) in enumerate(zip(fdecodes, ftargets)): + write_to_file(os.path.join(decode_dir, "rouge.%06d.txt" % (i+1)), d) + write_to_file(os.path.join(target_dir, "rouge.A.%06d.txt" % (i+1)), t) + + if (i+1 % 1000) == 0: + tf.logging.into("Written %d examples to file" % i) + +def main(_): + rouge = Rouge155() + rouge.log.setLevel(logging.ERROR) + rouge.system_filename_pattern = "rouge.(\d+).txt" + rouge.model_filename_pattern = "rouge.[A-Z].#ID#.txt" + + tf.logging.set_verbosity(tf.logging.INFO) + + tmpdir = mkdtemp() + tf.logging.info("tmpdir: %s" % tmpdir) + # system = decodes/predictions + system_dir = os.path.join(tmpdir, 'system') + # model = targets/gold + model_dir = os.path.join(tmpdir, 'model') + os.mkdir(system_dir) + os.mkdir(model_dir) + + rouge.system_dir = system_dir + rouge.model_dir = model_dir + + prep_data(rouge.system_dir, rouge.model_dir) + + rouge_scores = rouge.convert_and_evaluate() + rouge_scores = rouge.output_to_dict(rouge_scores) + for prefix in ["rouge_1", "rouge_2", "rouge_l"]: + for suffix in ["f_score", "precision", "recall"]: + key = "_".join([prefix, suffix]) + tf.logging.info("%s: %.4f" % (key, rouge_scores[key])) + + # clean up after pyrouge + shutil.rmtree(tmpdir) + shutil.rmtree(rouge._config_dir) + shutil.rmtree(os.path.split(rouge._system_dir)[0]) + +if __name__=='__main__': + tf.app.run()