
Commit a82231e
Generating raw data files, completed pipeline for rouge
urvashik committed Nov 14, 2017
1 parent f711de9 commit a82231e
Showing 3 changed files with 29 additions and 10 deletions.
29 changes: 24 additions & 5 deletions tensor2tensor/data_generators/cnn_dailymail.py
@@ -19,6 +19,7 @@
from __future__ import division
from __future__ import print_function

+import io
import os
import tarfile
import hashlib
@@ -45,7 +46,7 @@
# Train/Dev/Test Splits for summarization data
_TRAIN_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_train.txt"
_DEV_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_val.txt"
_TEST_URLS = "https://github.com/abisee/cnn-dailymail/blob/master/url_lists/all_test.txt"
_TEST_URLS = "https://raw.githubusercontent.com/abisee/cnn-dailymail/master/url_lists/all_test.txt"

# End-of-sentence marker.
EOS = text_encoder.EOS_ID
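
Note on the URL fix: the old _TEST_URLS pointed at GitHub's HTML blob viewer, so fetching it returned an HTML page rather than the plain-text URL list; the raw.githubusercontent.com form serves the file itself. A minimal sketch of how the constant is consumed, mirroring the maybe_download call this commit adds in write_raw_text_to_files below (the tmp_dir value is a placeholder):

from tensor2tensor.data_generators import generator_utils

tmp_dir = "/tmp/t2t_datagen"  # placeholder scratch directory
_TEST_URLS = ("https://raw.githubusercontent.com/abisee/cnn-dailymail/"
              "master/url_lists/all_test.txt")

# Downloads all_test.txt into tmp_dir on first use and returns the local path;
# subsequent calls reuse the cached copy.
test_urls_path = generator_utils.maybe_download(tmp_dir, "all_test.txt", _TEST_URLS)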
@@ -117,14 +118,13 @@ def generate_hash(inp):

return filelist

-def example_generator(tmp_dir, is_training, sum_token):
+def example_generator(all_files, urls_path, sum_token):
  def fix_run_on_sents(line):
    if u"@highlight" in line: return line
    if line=="": return line
    if line[-1] in END_TOKENS: return line
    return line + u"."

-  all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
  filelist = example_splits(urls_path, all_files)
  story_summary_split_token = u" <summary> " if sum_token else " "
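
With this refactor, example_generator no longer downloads the corpora itself: callers fetch once via _maybe_download_corpora and reuse the result for both the vocabulary pass and the example pass. A sketch of the new calling convention, taken from the generator() change later in this commit (tmp_dir and is_training are placeholders):

from tensor2tensor.data_generators.cnn_dailymail import (
    _maybe_download_corpora, example_generator)

tmp_dir, is_training = "/tmp/t2t_datagen", True  # placeholders

# One download now serves both passes over the data.
all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)

# Plain text for vocabulary building (no <summary> delimiter)...
vocab_stream = example_generator(all_files, urls_path, sum_token=False)
# ...and delimited text for producing (story, summary) pairs.
example_stream = example_generator(all_files, urls_path, sum_token=True)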

@@ -156,6 +156,23 @@ def _story_summary_split(story):
  split_pos = story.find(split_str)
  return story[:split_pos], story[split_pos+split_str_len:]  # story, summary

+def write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training):
+  def write_to_file(all_files, urls_path, data_dir, filename):
+    with io.open(os.path.join(data_dir, filename+".source"), "w") as fstory, io.open(os.path.join(data_dir, filename+".target"), "w") as fsummary:
+      for example in example_generator(all_files, urls_path, sum_token=True):
+        story, summary = _story_summary_split(example)
+        fstory.write(story+"\n")
+        fsummary.write(summary+"\n")
+
+  filename = "cnndm.train" if is_training else "cnndm.dev"
+  tf.logging.info("Writing %s" % filename)
+  write_to_file(all_files, urls_path, data_dir, filename)
+
+  if not is_training:
+    test_urls_path = generator_utils.maybe_download(tmp_dir, "all_test.txt", _TEST_URLS)
+    filename = "cnndm.test"
+    tf.logging.info("Writing %s" % filename)
+    write_to_file(all_files, test_urls_path, data_dir, filename)

@registry.register_problem
class SummarizeCnnDailymail32k(problem.Text2TextProblem):
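
The new write_raw_text_to_files emits parallel files, cnndm.{train,dev,test}.source and .target, with one story per line in .source and the matching summary on the same line of .target; these are the raw references the ROUGE pipeline below consumes. A hypothetical sanity check that the two files stay aligned (read_raw_pairs is not part of this commit):

import io
import os

def read_raw_pairs(data_dir, prefix="cnndm.dev"):
  """Yields (story, summary) pairs from the parallel files written above."""
  source = os.path.join(data_dir, prefix + ".source")
  target = os.path.join(data_dir, prefix + ".target")
  with io.open(source) as fstory, io.open(target) as fsummary:
    for story, summary in zip(fstory, fsummary):
      yield story.strip(), summary.strip()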
@@ -198,10 +215,12 @@ def use_train_shards_for_dev(self):
    return False

  def generator(self, data_dir, tmp_dir, is_training):
+    all_files, urls_path = _maybe_download_corpora(tmp_dir, is_training)
    encoder = generator_utils.get_or_generate_vocab_inner(
        data_dir, self.vocab_file, self.targeted_vocab_size,
-        example_generator(tmp_dir, is_training, sum_token=False))
-    for example in example_generator(tmp_dir, is_training, sum_token=True):
+        example_generator(all_files, urls_path, sum_token=False))
+    write_raw_text_to_files(all_files, urls_path, data_dir, tmp_dir, is_training)
+    for example in example_generator(all_files, urls_path, sum_token=True):
      story, summary = _story_summary_split(example)
      encoded_summary = encoder.encode(summary) + [EOS]
      encoded_story = encoder.encode(story) + [EOS]
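
Here encoder is the SubwordTextEncoder returned by get_or_generate_vocab_inner. A hedged round-trip sketch of the encode/decode calls used above, assuming encoder is in scope (the sample sentence is made up):

from tensor2tensor.data_generators import text_encoder

EOS = text_encoder.EOS_ID

# encode() maps text to subword ids; the EOS id is appended manually, as above.
ids = encoder.encode(u"police found the suspect on tuesday.") + [EOS]

# SubwordTextEncoder is designed to be invertible, so decoding the ids
# (minus the EOS we appended) should recover the original string.
text = encoder.decode(ids[:-1])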
3 changes: 3 additions & 0 deletions tensor2tensor/utils/get_cnndm_rouge.sh
@@ -1,8 +1,11 @@
#!/bin/bash

+# Path to moses dir
mosesdecoder=$1

+# Path to file containing gold summaries, one per line
targets_file=$2
+# Path to file containing model generated summaries, one per line
decodes_file=$3

# Tokenize.
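
A hypothetical invocation matching the three positional arguments documented above; every path here is a placeholder:

import subprocess

# $1: moses checkout, $2: gold summaries (one per line), $3: model decodes.
subprocess.check_call([
    "bash", "tensor2tensor/utils/get_cnndm_rouge.sh",
    "/path/to/mosesdecoder",  # placeholder moses dir
    "cnndm.test.target",      # placeholder gold summaries file
    "decodes.txt",            # placeholder model decodes file
])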
7 changes: 2 additions & 5 deletions tensor2tensor/utils/get_rouge.py
@@ -37,10 +37,7 @@
tf.flags.DEFINE_string("targets_filename", None, "File containing model target summaries tokenized")

def write_to_file(filename, data):
-  # TODO: ensure the output format (chars split by spaces) was as intended
-  data = ".\n".join(data.split(". "))
-  if len(data.strip()) == 0:
-    print(data, filename)
  with open(filename, "w") as fp:
    fp.write(data)

@@ -63,9 +60,9 @@ def main(_):

  tmpdir = mkdtemp()
  tf.logging.info("tmpdir: %s" % tmpdir)
-  # system = decodes
+  # system = decodes/predictions
  system_dir = os.path.join(tmpdir, 'system')
-  # model = gold
+  # model = targets/gold
  model_dir = os.path.join(tmpdir, 'model')
  os.mkdir(system_dir)
  os.mkdir(model_dir)
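
The system/model naming follows the ROUGE-1.5.5 convention: the system directory holds the model's decodes and the model directory holds the gold targets. A hedged sketch of how such a layout is typically scored with the pyrouge package (an assumption here; the filename patterns are illustrative, not taken from this script):

from pyrouge import Rouge155

r = Rouge155()
r.system_dir = system_dir                       # decodes/predictions
r.model_dir = model_dir                         # targets/gold
r.system_filename_pattern = r"rouge.(\d+).txt"  # illustrative
r.model_filename_pattern = "rouge.[A-Z].#ID#.txt"
print(r.convert_and_evaluate())  # prints the ROUGE-1/2/L report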
