Commit 3f52fb9
Raw text data generation for train, dev and test; script to compute rouge using the official pyrouge dist; a bash script to tokenize targets and predictions before computing rouge
urvashik committed Nov 14, 2017
2 parents 5233253 + a82231e commit 3f52fb9
Showing 4 changed files with 131 additions and 11 deletions.
34 changes: 25 additions & 9 deletions tensor2tensor/data_generators/cnn_dailymail.py
@@ -19,9 +19,10 @@
 from __future__ import division
 from __future__ import print_function
 
+import hashlib
+import io
 import os
 import tarfile
-import hashlib
 
 # Dependency imports
 
@@ -46,7 +47,7 @@
 # Train/Dev/Test Splits for summarization data
 _TRAIN_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt"
 _DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"
-_TEST_URLS = "https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt"
+_TEST_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt"
 
 
 # End-of-sentence marker.
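
(Note: the previous test URL pointed at a github.com/.../blob/... page, which serves HTML; raw.githubusercontent.com serves the plain url list, matching the train and dev URLs.)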
@@ -128,9 +129,7 @@ def generate_hash(inp):
 
   return filelist
 
-
-def example_generator(tmp_dir, is_training, sum_token):
-  """Generate examples."""
+def example_generator(all_files, urls_path, sum_token):
   def fix_run_on_sents(line):
     if u"@highlight" in line:
       return line
@@ -140,7 +139,6 @@ def fix_run_on_sents(line):
       return line
     return line + u"."
 
-  all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
   filelist = example_splits(urls_path, all_files)
   story_summary_split_token = u" <summary> " if sum_token else " "
 
@@ -170,13 +168,29 @@ def fix_run_on_sents(line):
 
     yield " ".join(story) + story_summary_split_token + " ".join(summary)
 
-
 def _story_summary_split(story):
   split_str = u" <summary> "
   split_str_len = len(split_str)
   split_pos = story.find(split_str)
   return story[:split_pos], story[split_pos+split_str_len:]  # story, summary
 
+def write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training):
+  def write_to_file(all_files, urls_path, data_dir, filename):
+    with io.open(os.path.join(data_dir, filename+".source"), "w") as fstory, io.open(os.path.join(data_dir, filename+".target"), "w") as fsummary:
+      for example in example_generator(all_files, urls_path, sum_token=True):
+        story, summary = _story_summary_split(example)
+        fstory.write(story+"\n")
+        fsummary.write(summary+"\n")
+
+  filename = "cnndm.train" if is_training else "cnndm.dev"
+  tf.logging.info("Writing %s" % filename)
+  write_to_file(all_files, urls_path, data_dir, filename)
+
+  if not is_training:
+    test_urls_path = generator_utils.maybe_download(tmp_dir, "all_test.txt", _TEST_URLS)
+    filename = "cnndm.test"
+    tf.logging.info("Writing %s" % filename)
+    write_to_file(all_files, test_urls_path, data_dir, filename)
 
 @registry.register_problem
 class SummarizeCnnDailymail32k(problem.Text2TextProblem):
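
write_raw_text_to_files above materializes each split as parallel .source/.target files, one example per line. A minimal sketch of that pairing with toy examples (stdlib only; file names and data are illustrative):

    import io
    import os

    data_dir = "."
    examples = [u"story one <summary> summary one",
                u"story two <summary> summary two"]
    with io.open(os.path.join(data_dir, "cnndm.dev.source"), "w") as fstory, \
        io.open(os.path.join(data_dir, "cnndm.dev.target"), "w") as fsummary:
      for example in examples:
        story, summary = example.split(u" <summary> ")
        fstory.write(story + u"\n")
        fsummary.write(summary + u"\n")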
@@ -219,10 +233,12 @@ def use_train_shards_for_dev(self):
     return False
 
   def generator(self, data_dir, tmp_dir, is_training):
+    all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
     encoder = generator_utils.get_or_generate_vocab_inner(
         data_dir, self.vocab_file, self.targeted_vocab_size,
-        example_generator(tmp_dir, is_training, sum_token=False))
-    for example in example_generator(tmp_dir, is_training, sum_token=True):
+        example_generator(all_files, urls_path, sum_token=False))
+    write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training)
+    for example in example_generator(all_files, urls_path, sum_token=True):
       story, summary = _story_summary_split(example)
       encoded_summary = encoder.encode(summary) + [EOS]
       encoded_story = encoder.encode(story) + [EOS]
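The split in _story_summary_split keys on the literal u" <summary> " token that example_generator inserts between story and summary. A standalone sketch of the same logic on a toy string (no tensor2tensor imports; names are illustrative):

    def story_summary_split(example):
      split_str = u" <summary> "
      split_pos = example.find(split_str)
      return example[:split_pos], example[split_pos + len(split_str):]

    example = u"Some article text. <summary> A one-line highlight."
    story, summary = story_summary_split(example)
    assert story == u"Some article text."
    assert summary == u"A one-line highlight."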
4 changes: 2 additions & 2 deletions tensor2tensor/utils/decoding.py
@@ -83,9 +83,9 @@ def log_decode_results(inputs,
 
   decoded_targets = None
   if identity_output:
-    decoded_outputs = " ".join(map(str, outputs.flatten()))
+    decoded_outputs = "".join(map(str, outputs.flatten()))
     if targets is not None:
-      decoded_targets = " ".join(map(str, targets.flatten()))
+      decoded_targets = "".join(map(str, targets.flatten()))
   else:
     decoded_outputs = targets_vocab.decode(_save_until_eos(outputs, is_image))
     if targets is not None:
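Note the behavioral change in the identity_output branch: flattened ids are now joined with an empty string instead of a space. A tiny sketch of the difference (numpy only; toy array):

    import numpy as np

    outputs = np.array([[4, 17, 2]])
    print(" ".join(map(str, outputs.flatten())))  # before: "4 17 2"
    print("".join(map(str, outputs.flatten())))   # after: "4172"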
16 changes: 16 additions & 0 deletions tensor2tensor/utils/get_cnndm_rouge.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+# Path to moses dir
+mosesdecoder=$1
+
+# Path to file containing gold summaries, one per line
+targets_file=$2
+# Path to file containing model generated summaries, one per line
+decodes_file=$3
+
+# Tokenize.
+perl $mosesdecoder/scripts/tokenizer/tokenizer.perl -l en < $targets_file > $targets_file.tok
+perl $mosesdecoder/scripts/tokenizer/tokenizer.perl -l en < $decodes_file > $decodes_file.tok
+
+# Get rouge scores
+python get_rouge.py --decodes_filename $decodes_file.tok --targets_filename $targets_file.tok
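
Usage sketch (illustrative paths; run from tensor2tensor/utils so the relative get_rouge.py call resolves): bash get_cnndm_rouge.sh ~/mosesdecoder cnndm.test.target decodes.txt. The script writes cnndm.test.target.tok and decodes.txt.tok next to the inputs, then hands both to get_rouge.py.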
88 changes: 88 additions & 0 deletions tensor2tensor/utils/get_rouge.py
@@ -0,0 +1,88 @@
+# coding=utf-8
+# Copyright 2017 The Tensor2Tensor Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Computing rouge scores using pyrouge."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import logging
+import shutil
+from tempfile import mkdtemp
+from pprint import pprint
+
+# Dependency imports
+from pyrouge import Rouge155
+
+import numpy as np
+import tensorflow as tf
+
+FLAGS = tf.flags.FLAGS
+
+tf.flags.DEFINE_string("decodes_filename", None, "File containing tokenized model-generated summaries")
+tf.flags.DEFINE_string("targets_filename", None, "File containing tokenized target summaries")
+
+def write_to_file(filename, data):
+  data = ".\n".join(data.split(". "))
+  with open(filename, "w") as fp:
+    fp.write(data)
+
+def prep_data(decode_dir, target_dir):
+  with open(FLAGS.decodes_filename, "rb") as fdecodes, open(FLAGS.targets_filename, "rb") as ftargets:
+    for i, (d, t) in enumerate(zip(fdecodes, ftargets)):
+      write_to_file(os.path.join(decode_dir, "rouge.%06d.txt" % (i + 1)), d)
+      write_to_file(os.path.join(target_dir, "rouge.A.%06d.txt" % (i + 1)), t)
+
+      if (i + 1) % 1000 == 0:
+        tf.logging.info("Written %d examples to file" % (i + 1))
+
+def main(_):
+  rouge = Rouge155()
+  rouge.log.setLevel(logging.ERROR)
+  rouge.system_filename_pattern = r"rouge.(\d+).txt"
+  rouge.model_filename_pattern = "rouge.[A-Z].#ID#.txt"
+
+  tf.logging.set_verbosity(tf.logging.INFO)
+
+  tmpdir = mkdtemp()
+  tf.logging.info("tmpdir: %s" % tmpdir)
+  # system = decodes/predictions
+  system_dir = os.path.join(tmpdir, "system")
+  # model = targets/gold
+  model_dir = os.path.join(tmpdir, "model")
+  os.mkdir(system_dir)
+  os.mkdir(model_dir)
+
+  rouge.system_dir = system_dir
+  rouge.model_dir = model_dir
+
+  prep_data(rouge.system_dir, rouge.model_dir)
+
+  rouge_scores = rouge.convert_and_evaluate()
+  rouge_scores = rouge.output_to_dict(rouge_scores)
+  for prefix in ["rouge_1", "rouge_2", "rouge_l"]:
+    for suffix in ["f_score", "precision", "recall"]:
+      key = "_".join([prefix, suffix])
+      tf.logging.info("%s: %.4f" % (key, rouge_scores[key]))
+
+  # Clean up after pyrouge.
+  shutil.rmtree(tmpdir)
+  shutil.rmtree(rouge._config_dir)
+  shutil.rmtree(os.path.split(rouge._system_dir)[0])
+
+if __name__ == "__main__":
+  tf.app.run()
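
pyrouge pairs system (decode) and model (target) files by the numeric ID captured by the filename patterns set in main. A minimal sketch of the directory layout prep_data builds, using only the standard library and toy summaries:

    import os
    from tempfile import mkdtemp

    tmpdir = mkdtemp()
    system_dir = os.path.join(tmpdir, "system")  # decodes/predictions
    model_dir = os.path.join(tmpdir, "model")    # targets/gold
    os.mkdir(system_dir)
    os.mkdir(model_dir)

    decodes = ["the cat sat. it purred.", "rain fell all day."]
    targets = ["a cat sat down. it purred.", "it rained all day."]
    for i, (d, t) in enumerate(zip(decodes, targets)):
      # One file per example; IDs must match across the two dirs.
      with open(os.path.join(system_dir, "rouge.%06d.txt" % (i + 1)), "w") as f:
        f.write(".\n".join(d.split(". ")))
      with open(os.path.join(model_dir, "rouge.A.%06d.txt" % (i + 1)), "w") as f:
        f.write(".\n".join(t.split(". ")))

    print(sorted(os.listdir(system_dir)))  # ['rouge.000001.txt', 'rouge.000002.txt']
    print(sorted(os.listdir(model_dir)))   # ['rouge.A.000001.txt', 'rouge.A.000002.txt']

The script itself would then be invoked as, e.g., python get_rouge.py --decodes_filename decodes.txt.tok --targets_filename targets.txt.tok (flag values illustrative).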
