Add Seq2seq example #63

Merged (62 commits, Aug 31, 2017)

Commits (62)

81413a1
minor fix
keisukefukuda Apr 13, 2017
7cadc00
Merged https://github.com/pfnet/chainer/pull/2555
keisukefukuda Apr 17, 2017
553e121
Cache pre-processed input data
keisukefukuda Apr 17, 2017
6025ec0
Changed report trigger to 'epoch'
keisukefukuda Apr 17, 2017
74fc0ca
merge master
keisukefukuda Apr 19, 2017
11af9b8
recovered seq2seq example after merging master
keisukefukuda Apr 19, 2017
9dc41f6
Merge branch 'master' into seq2seq
keisukefukuda Apr 28, 2017
b935372
Merge branch 'master' into seq2seq
keisukefukuda May 2, 2017
3f115ce
passes flake8 and autopep8 check
keisukefukuda May 2, 2017
9cfac7e
Merge branch 'master' into seq2seq
keisukefukuda May 8, 2017
5e90bcc
Merge branch 'master' into seq2seq
keisukefukuda May 9, 2017
5a298d3
Merge branch 'master' into seq2seq
keisukefukuda May 17, 2017
147f35c
fix to work with Chainer2
keisukefukuda May 17, 2017
0759bc1
added BleuEvaluator
keisukefukuda May 17, 2017
648d124
fix flake8
keisukefukuda May 17, 2017
2e4c4d8
added updated seq2seq.py and europal.py from the latest chainer repos…
keisukefukuda May 18, 2017
dc38939
reflected the latest commits to Chainer's main repository
keisukefukuda May 18, 2017
7dc4e81
Merge branch 'master' into seq2seq
keisukefukuda May 23, 2017
a37114a
edit code so make it easier to compare with the original seq2seq.py
keisukefukuda May 23, 2017
6050513
Renamed variables to be similar to the original seq2seq example
keisukefukuda May 23, 2017
0006530
Removed get_epoch_trigger
keisukefukuda May 23, 2017
c8080b0
fix flake8
keisukefukuda May 23, 2017
0d179f7
multi-node evaluator
keisukefukuda May 23, 2017
d8b4942
fix flake8/pep8
keisukefukuda May 23, 2017
72f513e
minor fix
keisukefukuda May 24, 2017
badeaa4
Merge branch 'master' into seq2seq
keisukefukuda Jun 8, 2017
3f6749e
WIP
keisukefukuda Jun 14, 2017
bb155d9
Merge branch 'master' into seq2seq
keisukefukuda Jun 16, 2017
5d72622
some minor fixes
keisukefukuda Jul 10, 2017
278d43d
minor fix
keisukefukuda Aug 14, 2017
2fc93d6
added BLEU evaluator & minor options
keisukefukuda Aug 14, 2017
db773b0
Merge branch 'seq2seq' of github.com:pfnet/chainermn into seq2seq
keisukefukuda Aug 14, 2017
df9c099
removed 'train=False' argument
keisukefukuda Aug 14, 2017
b9ca386
minor fix
keisukefukuda Aug 16, 2017
61fe839
Merge branch 'master' into seq2seq
keisukefukuda Aug 16, 2017
876637c
Merge branch 'master' into seq2seq
keisukefukuda Aug 17, 2017
2550e58
minor fixes
keisukefukuda Aug 17, 2017
a8e64a6
Merge branch 'master' into seq2seq
keisukefukuda Aug 18, 2017
55f6130
added README.md in examples/seq2seq directory
keisukefukuda Aug 18, 2017
dac5bc7
Merge branch 'seq2seq' of github.com:pfnet/chainermn into seq2seq
keisukefukuda Aug 18, 2017
6212243
added create_optimizer() function
keisukefukuda Aug 23, 2017
7476c1d
Merge branch 'better-err-msg-in-large-scatter' into seq2seq
keisukefukuda Aug 24, 2017
5f9c40e
added a support DataSizeError in scatter_dataset (based on PR#111)
keisukefukuda Aug 24, 2017
f243bf9
Merge branch 'master' into seq2seq
keisukefukuda Aug 28, 2017
944d507
Merge branch 'master' into seq2seq
keisukefukuda Aug 28, 2017
c6e4076
minor fix
keisukefukuda Aug 30, 2017
f2ee572
removed redundant import
keisukefukuda Aug 30, 2017
3fa6d29
added comments, removed debug print and fix pep8
keisukefukuda Aug 30, 2017
51fefe1
added README
keisukefukuda Aug 30, 2017
ad0106b
removed unused code
keisukefukuda Aug 30, 2017
695477d
minor fix
keisukefukuda Aug 30, 2017
72cf8a1
Merge branch 'better-err-msg-in-large-scatter' into seq2seq
keisukefukuda Aug 30, 2017
9d44a1e
fix pep8
keisukefukuda Aug 30, 2017
a45f8bc
Merge branch 'better-err-msg-in-large-scatter' into seq2seq
keisukefukuda Aug 31, 2017
dc56485
Removed unnecessary file
keisukefukuda Aug 31, 2017
4b924d6
added _get_num_split() and _slices() to calculate best partitioning o…
keisukefukuda Aug 31, 2017
d5774c1
Merge branch 'better-err-msg-in-large-scatter' into seq2seq
keisukefukuda Aug 31, 2017
0566d77
fix pep8
keisukefukuda Aug 31, 2017
fa74e3e
added a docstring
keisukefukuda Aug 31, 2017
db86083
bugfix
keisukefukuda Aug 31, 2017
1fd9fb4
Merge branch 'better-err-msg-in-large-scatter' into seq2seq
keisukefukuda Aug 31, 2017
fa5bd26
Removed unnecessary file
keisukefukuda Aug 31, 2017
1 change: 1 addition & 0 deletions chainermn/__init__.py
@@ -2,6 +2,7 @@

from chainermn.communicators import create_communicator # NOQA
from chainermn.dataset import scatter_dataset # NOQA
from chainermn.dataset import DataSizeError # NOQA
from chainermn.link import MultiNodeChainList # NOQA
from chainermn.multi_node_evaluator import create_multi_node_evaluator # NOQA
from chainermn.multi_node_optimizer import create_multi_node_optimizer # NOQA
61 changes: 58 additions & 3 deletions chainermn/dataset.py
@@ -1,7 +1,53 @@
import chainer.datasets
import math
import numpy
import re
import warnings

import chainer.datasets


class DataSizeError(RuntimeError):
    def __init__(self, ds_size, pickled_size):
        msg = """The dataset was too large to be scattered using MPI.

        The length of the dataset is {} and its size after being pickled
        was {}. In the current MPI specification, the size cannot exceed
        {}, the so-called 'INT_MAX'.

        To solve this problem, please split the dataset into multiple
        pieces and send/recv them separately.

        Recommended sizes are indicated by the ``slices()`` method.
        """

        INT_MAX = 2147483647
        msg = msg.format(ds_size, pickled_size, INT_MAX)
        super(DataSizeError, self).__init__(msg)

        self.pickled_size = pickled_size
        self.max_size = INT_MAX
        self.dataset_len = ds_size

    def num_split(self):
        ps = self.pickled_size
        mx = self.max_size
        return (ps + mx - 1) // mx

    def slices(self):
        ds = self.dataset_len
        nsplit = self.num_split()
        size = math.ceil(ds / nsplit)

        return [(b, min(e, ds)) for b, e in
                ((i * size, (i + 1) * size) for i in range(0, nsplit))]


def _parse_overflow_error(err):
    msg = str(err)
    m = re.search(r'integer (\d+) does not fit in', msg)
    assert m is not None, "'{}' must include size of the message".format(msg)
    return int(m.group(1))


def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None):
"""Scatter the given dataset to the workers in the communicator.
@@ -53,10 +99,19 @@ def scatter_dataset(dataset, comm, root=0, shuffle=False, seed=None):
            if i == root:
                mine = subds
            else:
                comm.send(subds, dest=i)
                try:
                    comm.send(subds, dest=i)
                except OverflowError as e:
                    pickled_size = _parse_overflow_error(e)
                    raise DataSizeError(len(dataset), pickled_size)

        return mine
    else:
        return comm.recv(source=root)
        try:
            return comm.recv(source=0)
        except OverflowError as e:
            pickled_size = _parse_overflow_error(e)
            raise DataSizeError(len(dataset), pickled_size)


def get_n_iterations_for_one_epoch(dataset, local_batch_size, comm):
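For reference, below is a minimal sketch of how the new `DataSizeError` surfaces to user code; the toy dataset and the print are illustrative assumptions, and only `create_communicator`, `scatter_dataset`, and `DataSizeError` come from this PR. As a rough sense of scale, a hypothetical dataset whose pickle is about 5 GB would report `num_split() == 3`, since ceil(5 GB / INT_MAX) = 3.

```python
import chainermn

comm = chainermn.create_communicator()

# Toy dataset for illustration; in the seq2seq example this would be the
# list of numpy arrays built by europal.make_dataset().
train_data = list(range(1000))

try:
    local_data = chainermn.scatter_dataset(train_data, comm, shuffle=True)
except chainermn.DataSizeError as err:
    # The pickled dataset exceeded INT_MAX bytes; err.slices() suggests
    # (begin, end) ranges small enough to scatter piece by piece.
    print('dataset too large; suggested slices:', err.slices())
    raise
```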
32 changes: 32 additions & 0 deletions examples/seq2seq/README.md
@@ -0,0 +1,32 @@
# ChainerMN seq2seq example

An sample implementation of seq2seq model.
Contributor commented: An sample -> A sample


## Data download and setup

First, go to http://www.statmt.org/wmt15/translation-task.html#download and download the necessary datasets.
Let's assume you are in a working directory called `$WMT_DIR`.

```
$ cd $WMT_DIR
$ wget http://www.statmt.org/wmt10/training-giga-fren.tar
$ wget http://www.statmt.org/wmt15/dev-v2.tgz
$ tar -xf training-giga-fren.tar
$ tar -xf dev-v2.tgz
$ ls
dev/ dev-v2.tgz giga-fren.release2.fixed.en.gz giga-fren.release2.fixed.fr.gz training-giga-fren.tar

```

Next, you need to install the required packages.

```
$ pip install nltk progressbar2
```

## Run

```bash
$ cd $CHAINERMN
```

Contributor commented: I think we need a little more lines to run the script 😉

Member Author replied: fixed!


82 changes: 82 additions & 0 deletions examples/seq2seq/europal.orig.py
@@ -0,0 +1,82 @@
from __future__ import unicode_literals

import collections
import gzip
import io
import os
import re

import numpy
import progressbar


split_pattern = re.compile(r'([.,!?"\':;)(])')
digit_pattern = re.compile(r'\d')


def split_sentence(s):
    s = s.lower()
    s = s.replace('\u2019', "'")
    s = digit_pattern.sub('0', s)
    words = []
    for word in s.strip().split():
        words.extend(split_pattern.split(word))
    words = [w for w in words if w]
    return words


def open_file(path):
    if path.endswith('.gz'):
        return gzip.open(path, 'rt', 'utf-8')
    else:
        # Find gzipped version of the file
        gz = path + '.gz'
        if os.path.exists(gz):
            return open_file(gz)
        else:
            return io.open(path, encoding='utf-8', errors='ignore')


def count_lines(path):
    with open_file(path) as f:
        return sum([1 for _ in f])


def read_file(path):
    n_lines = count_lines(path)
    bar = progressbar.ProgressBar()
    with open_file(path) as f:
        for line in bar(f, max_value=n_lines):
            words = split_sentence(line)
            yield words


def count_words(path):
    counts = collections.Counter()
    for words in read_file(path):
        for word in words:
            counts[word] += 1

    vocab = [word for (word, _) in counts.most_common(40000)]
    return vocab


def make_dataset(path, vocab):
    word_id = {word: index for index, word in enumerate(vocab)}
    dataset = []
    token_count = 0
    unknown_count = 0
    for words in read_file(path):
        array = make_array(word_id, words)
        dataset.append(array)
        token_count += array.size
        unknown_count += (array == 1).sum()
    print('# of tokens: %d' % token_count)
    print('# of unknown: %d (%.2f %%)'
          % (unknown_count, 100. * unknown_count / token_count))
    return dataset


def make_array(word_id, words):
    ids = [word_id.get(word, 1) for word in words]
    return numpy.array(ids, 'i')
88 changes: 88 additions & 0 deletions examples/seq2seq/europal.py
@@ -0,0 +1,88 @@
from __future__ import unicode_literals

import collections
import gzip
import io
import os
import re

import numpy
import progressbar


split_pattern = re.compile(r'([.,!?"\':;)(])')
digit_pattern = re.compile(r'\d')


def split_sentence(s):
    s = s.lower()
    s = s.replace('\u2019', "'")
    s = digit_pattern.sub('0', s)
    words = []
    for word in s.strip().split():
        words.extend(split_pattern.split(word))
    words = [w for w in words if w]
    return words


def open_file(path):
    if path.endswith('.gz'):
        return gzip.open(path, 'rt', encoding='utf-8')
    else:
        # Find gzipped version of the file
        gz = path + '.gz'
        if os.path.exists(gz):
            return open_file(gz)
        else:
            return io.open(path, encoding='utf-8', errors='ignore')


def count_lines(path):
    print(path)
    with open_file(path) as f:
        return sum([1 for _ in f])


def read_file(path):
    n_lines = count_lines(path)
    bar = progressbar.ProgressBar()
    with open_file(path) as f:
        for line in bar(f, max_value=n_lines):
            words = split_sentence(line)
            yield words


def count_words(path):
    counts = collections.Counter()
    for words in read_file(path):
        for word in words:
            counts[word] += 1

    vocab = [word for (word, _) in counts.most_common(40000)]
    return vocab


def make_dataset(path, vocab):
    word_id = {word: index for index, word in enumerate(vocab)}
    dataset = []
    token_count = 0
    unknown_count = 0
    for words in read_file(path):
        array = make_array(word_id, words)
        dataset.append(array)
        token_count += array.size
        unknown_count += (array == 1).sum()
    print('# of tokens: %d' % token_count)
    print('# of unknown: %d (%.2f %%)'
          % (unknown_count, 100. * unknown_count / token_count))
    return dataset


def make_array(word_id, words):
    ids = [word_id.get(word, 1) for word in words]
    return numpy.array(ids, 'i')


if __name__ == '__main__':
    vocab = count_words('wmt/giga-fren.release2.fixed.en')
    make_dataset('wmt/giga-fren.release2.fixed.en', vocab)
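
As a quick sanity check of the tokenizer above, here is a small sketch; the sample sentence and the expected token list are worked out by hand from `split_pattern` and `digit_pattern`, the tiny vocabulary is hypothetical, and it assumes the snippet is run from the `examples/seq2seq/` directory.

```python
import europal  # examples/seq2seq/europal.py

# split_sentence lower-cases, maps every digit to '0', and breaks the
# punctuation characters in split_pattern into separate tokens.
tokens = europal.split_sentence("Chainer 2.0 is great, isn't it?")
print(tokens)
# ['chainer', '0', '.', '0', 'is', 'great', ',', 'isn', "'", 't', 'it', '?']

# make_array maps any word missing from word_id to the id 1 ("unknown").
word_id = {'chainer': 0, 'is': 2, 'great': 3}   # hypothetical tiny vocabulary
print(europal.make_array(word_id, tokens))
# -> an integer array with 1 for every out-of-vocabulary token
```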