Skip to content

Commit

Permalink
Update Transformer reader (PaddlePaddle#163)
Browse files Browse the repository at this point in the history
* try to fix sampler

* update
  • Loading branch information
FrostML authored Mar 22, 2021
1 parent b5cc5ac commit 3a70ec6
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 335 deletions.
207 changes: 34 additions & 173 deletions benchmark/transformer/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,25 +66,46 @@ def convert_samples(sample):

return source, target

def _max_token_fn(current_idx, current_batch_size, tokens_sofar,
data_source):
return max(tokens_sofar,
len(data_source[current_idx][0]) + 1,
len(data_source[current_idx][1]) + 1)

def _key(size_so_far, minibatch_len):
return size_so_far * minibatch_len

data_loaders = [(None)] * 2
for i, dataset in enumerate(datasets):
dataset = dataset.map(convert_samples, lazy=False).filter(
partial(
min_max_filer, max_len=args.max_length))
batch_sampler = TransformerBatchSampler(
dataset=dataset,

sampler = SamplerHelper(dataset)

if args.sort_type == SortType.GLOBAL:
src_key = (lambda x, data_source: len(data_source[x][0]) + 1)
trg_key = (lambda x, data_source: len(data_source[x][1]) + 1)
# Sort twice
sampler = sampler.sort(key=trg_key).sort(key=src_key)
else:
if args.shuffle:
sampler = sampler.shuffle(seed=args.shuffle_seed)
max_key = (lambda x, data_source: max(len(data_source[x][0]), len(data_source[x][1])) + 1)
if args.sort_type == SortType.POOL:
sampler = sampler.sort(key=max_key, buffer_size=args.pool_size)

batch_sampler = sampler.batch(
batch_size=args.batch_size,
pool_size=args.pool_size,
sort_type=args.sort_type,
shuffle=args.shuffle,
shuffle_batch=args.shuffle_batch,
use_token_batch=True,
max_length=args.max_length,
distribute_mode=True if i == 0 else False,
world_size=dist.get_world_size(),
rank=dist.get_rank(),
pad_seq=args.pad_seq,
bsz_multi=args.bsz_multi)
drop_last=False,
batch_size_fn=_max_token_fn,
key=_key)

if args.shuffle_batch:
batch_sampler = batch_sampler.shuffle(seed=args.shuffle_seed)

if i == 0:
batch_sampler = batch_sampler.shard()

data_loader = DataLoader(
dataset=dataset,
Expand Down Expand Up @@ -196,163 +217,3 @@ class SortType(object):
GLOBAL = 'global'
POOL = 'pool'
NONE = "none"


class SentenceBatchCreator(object):
    """Accumulate items and release them in groups of a fixed count."""

    def __init__(self, batch_size):
        # Items collected since the last completed batch.
        self.batch = []
        self._batch_size = batch_size

    def append(self, info):
        """Add ``info`` to the pending batch.

        Returns:
            The completed batch (a list) once it holds exactly ``batch_size``
            items, after which a new empty pending batch is started;
            otherwise ``None``.
        """
        self.batch.append(info)
        if len(self.batch) != self._batch_size:
            return None
        # Hand back the full batch and start a fresh one.
        completed, self.batch = self.batch, []
        return completed


class TokenBatchCreator(object):
    """Group samples into batches bounded by a padded-token budget.

    The cost of a batch is ``max sample length * number of samples``. When
    adding a sample would push that cost above ``batch_size``, a batch is
    cut from the pending samples and returned.
    """

    def __init__(self, batch_size, bsz_multi=1):
        """
        Args:
            batch_size: Token budget for a single batch.
            bsz_multi: Emitted batches are trimmed to a multiple of this
                count whenever enough samples are pending.
        """
        self._batch = []
        self.max_len = -1
        self._batch_size = batch_size
        self._bsz_multi = bsz_multi

    def append(self, info):
        """Add ``info`` (an object exposing ``max_len``) to the pending batch.

        Returns:
            A non-empty list of samples forming a finished batch when adding
            ``info`` would exceed the token budget; otherwise ``None``.
        """
        cur_len = info.max_len
        max_len = max(self.max_len, cur_len)
        if max_len * (len(self._batch) + 1) > self._batch_size:
            pending = len(self._batch)
            # Prefer the largest multiple of bsz_multi; if fewer than
            # bsz_multi samples are pending, emit all of them instead.
            mode_len = max(pending // self._bsz_multi * self._bsz_multi,
                           pending % self._bsz_multi)
            result = self._batch[:mode_len]
            self._batch = self._batch[mode_len:]
            self._batch.append(info)
            self.max_len = max(b.max_len for b in self._batch)
            # Nothing was pending (a single sample already exceeds the
            # budget): report "no finished batch" rather than an empty list,
            # which callers testing ``is not None`` would record as a batch.
            if not result:
                return None
            return result
        self.max_len = max_len
        self._batch.append(info)
        return None

    @property
    def batch(self):
        """Samples collected so far that have not been emitted yet."""
        return self._batch


class SampleInfo(object):
    """Length bookkeeping for one (source, target) sample."""

    def __init__(self, i, lens, pad_seq=1):
        """
        Args:
            i: Index of the sample in the dataset.
            lens: Pair ``[source_length, target_length]``.
            pad_seq: ``max_len`` is rounded to a multiple of this value.
        """
        src, trg = lens[0], lens[1]
        longest = max(src, trg)

        self.i = i
        # The "+ 1" terms account for bos and eos, as in the original note.
        self.min_len = min(src, trg) + 1
        # Smallest multiple of pad_seq strictly greater than the longest side.
        self.max_len = (longest + pad_seq) // pad_seq * pad_seq
        self.seq_max_len = longest + 1
        self.src_len = src + 1
        self.trg_len = trg + 1


class TransformerBatchSampler(BatchSampler):
    """Batch sampler that orders samples by length and groups them into
    sentence-count or token-count batches, optionally split across ranks.
    """

    def __init__(self,
                 dataset,
                 batch_size,
                 pool_size=10000,
                 sort_type=SortType.NONE,
                 min_length=0,
                 max_length=100,
                 shuffle=False,
                 shuffle_batch=False,
                 use_token_batch=False,
                 clip_last_batch=False,
                 distribute_mode=True,
                 seed=0,
                 world_size=1,
                 rank=0,
                 pad_seq=1,
                 bsz_multi=8):
        """
        Args:
            dataset: Iterable of samples; items 0 and 1 of each sample are
                measured with ``len`` as source and target.
            batch_size: Sentences per batch, or the token budget when
                ``use_token_batch`` is true.
            pool_size: Window size used by ``SortType.POOL`` sorting.
            sort_type: One of the ``SortType`` values.
            min_length: Stored as ``_min_length``; not used by this class.
            max_length: Stored as ``_max_length``; not used by this class.
            shuffle: Shuffle samples before batching (ignored for global sort).
            shuffle_batch: Shuffle the finished batches.
            use_token_batch: Batch by token budget instead of sentence count.
            clip_last_batch: Drop the trailing incomplete batch.
            distribute_mode: Give each rank every ``world_size``-th batch.
            seed: Seed applied to ``np.random``.
            world_size: Number of ranks.
            rank: Rank of the current process.
            pad_seq: Forwarded to ``SampleInfo``.
            bsz_multi: Forwarded to ``TokenBatchCreator``.
        """
        # Mirror every constructor argument as a "_<name>" attribute. This
        # must stay the first statement so that locals() holds only the
        # parameters.
        for arg, value in locals().items():
            if arg != "self":
                setattr(self, "_" + arg, value)
        # NOTE: this is the module-level NumPy generator, so seeding here
        # reseeds the global np.random state.
        self._random = np.random
        self._random.seed(seed)
        # for multi-devices
        self._distribute_mode = distribute_mode
        self._nranks = world_size
        self._local_rank = rank
        self._sample_infos = []
        for i, data in enumerate(self._dataset):
            lens = [len(data[0]), len(data[1])]
            self._sample_infos.append(SampleInfo(i, lens, self._pad_seq))

    def __iter__(self):
        """Yield lists of dataset indices, one list per batch for this rank."""
        # global sort or global shuffle
        if self._sort_type == SortType.GLOBAL:
            infos = sorted(self._sample_infos, key=lambda x: x.trg_len)
            infos = sorted(infos, key=lambda x: x.src_len)
        else:
            if self._shuffle:
                infos = self._sample_infos
                # In-place: the stored sample order changes as well.
                self._random.shuffle(infos)
            else:
                infos = self._sample_infos

            if self._sort_type == SortType.POOL:
                reverse = True
                for i in range(0, len(infos), self._pool_size):
                    # To avoid placing short next to long sentences
                    reverse = not reverse
                    infos[i:i + self._pool_size] = sorted(
                        infos[i:i + self._pool_size],
                        key=lambda x: x.seq_max_len,
                        reverse=reverse)

        batches = []
        batch_creator = TokenBatchCreator(
            self._batch_size,
            self._bsz_multi) if self._use_token_batch else SentenceBatchCreator(
                self._batch_size * self._nranks)

        for info in infos:
            batch = batch_creator.append(info)
            if batch is not None:
                batches.append(batch)

        if not self._clip_last_batch and len(batch_creator.batch) != 0:
            batches.append(batch_creator.batch)

        if self._shuffle_batch:
            self._random.shuffle(batches)

        if not self._use_token_batch:
            # When batching by sentence count, neighbouring batches that are
            # fed to parallel ranks should have similar lengths (and thus
            # similar computational cost) even after shuffling. They are
            # therefore built and shuffled as one combined batch and only
            # split into per-rank batches here.
            batches = [[
                batch[self._batch_size * i:self._batch_size * (i + 1)]
                for i in range(self._nranks)
            ] for batch in batches]
            batches = list(itertools.chain.from_iterable(batches))
        self.batch_number = (len(batches) + self._nranks - 1) // self._nranks

        # for multi-device
        batch_indices = None
        for batch_id, batch in enumerate(batches):
            if not self._distribute_mode or (
                    batch_id % self._nranks == self._local_rank):
                batch_indices = [info.i for info in batch]
                yield batch_indices
        if self._distribute_mode and len(batches) % self._nranks != 0:
            if self._local_rank >= len(batches) % self._nranks:
                # use previous data to pad
                if batch_indices is None:
                    # This rank received no batch at all (fewer batches than
                    # ranks); pad with the last batch overall instead of
                    # failing on an unbound variable.
                    batch_indices = [info.i for info in batches[-1]]
                yield batch_indices

    def __len__(self):
        """Return the number of batches per rank.

        After an iteration has started this is the recorded
        ``batch_number``. Before that it is computed for sentence batching,
        and is ``sys.maxsize`` for token batching, where the count is not
        known in advance.
        """
        if hasattr(self, "batch_number"):
            return self.batch_number
        if not self._use_token_batch:
            batch_number = (
                len(self._dataset) + self._batch_size * self._nranks - 1) // (
                    self._batch_size * self._nranks)
        else:
            # For uncertain batch number, the actual value is self.batch_number
            batch_number = sys.maxsize
        return batch_number
Loading

0 comments on commit 3a70ec6

Please sign in to comment.