Use numpy for data sampler and address pull request comments
baijumeswani committed Jan 3, 2022
1 parent 3e745ba commit dfe9f80
Showing 3 changed files with 56 additions and 41 deletions.
2 changes: 1 addition & 1 deletion cmake/onnxruntime_python.cmake
@@ -586,7 +586,7 @@ if (onnxruntime_ENABLE_TRAINING)
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_utils_data_srcs}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/ortmodule/torch_cpp_extensions/cuda/fused_ops/
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/training/utils/data/
)
endif()
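The new destination means the sampler sources are staged under onnxruntime/training/utils/data/ in the build output, which matches the import path used in the docstrings below. A minimal smoke check of that path, assuming a training-enabled onnxruntime build is installed, might look like:

# Hypothetical smoke check (assumption: a wheel built with onnxruntime_ENABLE_TRAINING is installed).
from onnxruntime.training.utils.data import (
    LoadBalancingDistributedSampler,
    LoadBalancingDistributedBatchSampler,
)

print(LoadBalancingDistributedSampler.__name__, LoadBalancingDistributedBatchSampler.__name__)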

93 changes: 54 additions & 39 deletions orttraining/orttraining/python/training/utils/data/sampler.py
@@ -8,19 +8,25 @@
from torch.utils.data.sampler import Sampler
from torch.utils.data.dataset import Dataset
from typing import Optional, Iterator, Callable
from collections import OrderedDict
import numpy as np


# Implementation is heavily derived from bagua/load_balancing_data_loader.py
# Implementation is adapted from bagua/load_balancing_data_loader.py
# https://github.com/BaguaSys/bagua/blob/01874a7c3f90904c37c5612a9db866b5d4b8b5ed/bagua/torch_api/contrib/load_balancing_data_loader.py#L12
class LoadBalancingDistributedSampler:
r"""Sampler that balances the data load across workers based on the sample's complexity.
This sampler uses a :attr:`complexity_fn` to calculate each sample's computational
complexity so that each batch ends up with a similar computational complexity.
This is useful in scenarios like speech and NLP, where each batch has variable
length and distributed training suffers from straggler problem.
The usage is similar to `torch.utils.data.DistributedSampler <https://pytorch.org/docs/stable/data.html?highlight=distributedsampler#torch.utils.data.distributed.DistributedSampler>`_,
where each process loads a subset of the original dataset that is exclusive to it.
length and distributed training suffers from the straggler problem. In such scenarios,
the complexity function could be defined to return the length of the input sample sequence.
The usage is similar to `torch.utils.data.DistributedSampler`, where each process loads a
subset of the original dataset that is exclusive to it.
The sampler sorts the dataset in increasing order of complexity. If :attr:`group_size` is
provided, the sorting happens within dataset groups of size :attr:`group_size`; the group
order is then shuffled and the data is sharded across workers. If :attr:`group_size`
is not provided, the data is sharded across workers first and the data indices for each
worker are then shuffled deterministically.
.. note::
Dataset is assumed to be of constant size (map-style dataset).
Args:
@@ -52,21 +58,20 @@ class LoadBalancingDistributedSampler:
of load balance. 0 means the best load balance, while 1 means the opposite.
.. warning::
In distributed mode, calling the :meth:`set_epoch` method at
the beginning of each epoch **before** creating the `DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`_ iterator
the beginning of each epoch **before** creating the `torch.utils.data.DataLoader` iterator
is necessary to make shuffling work properly across multiple epochs. Otherwise,
the same ordering will always be used.
Example::
Define your :attr:`complexity_fn`, which accepts a dataset sample as its input and produces an integer
as the sample's computational complexity:
>>> dataset = torch.utils.data.TensorDataset(torch.randn(n, 2), torch.randperm(n))
>>> complexity_fn = lambda x: x[1]
>>> dataset = MyVariableSequenceLengthDataset(dataset_samples)
>>> complexity_fn = lambda x: len(x)
Below is the usage of :class:`LoadBalancingDistributedSampler`
and `DataLoader <https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader>`_:
>>> sampler = bagua.torch_api.contrib.LoadBalancingDistributedSampler(
and `torch.utils.data.DataLoader`:
>>> sampler = onnxruntime.training.utils.data.LoadBalancingDistributedSampler(
... dataset,
... complexity_fn=complexity_fn) if is_distributed else None
... complexity_fn=complexity_fn)
>>> loader = torch.utils.data.DataLoader(dataset,
... shuffle=(sampler is None),
... sampler=sampler)
>>>
>>> for epoch in range(start_epoch, n_epochs):
@@ -121,13 +126,8 @@ def __init__(
self.shuffle = shuffle
self.seed = seed

self.sample_complexity_list = [None]*dataset_len
for sample_index in range(dataset_len):
self.sample_complexity_list[sample_index] = \
[sample_index, complexity_fn(self.dataset[sample_index])]

max_complexity = max(self.sample_complexity_list, key=lambda t: t[1])[1]
min_complexity = min(self.sample_complexity_list, key=lambda t: t[1])[1]
self.complexity_fn = complexity_fn
self.sample_complexities = None

if random_level < 0.0 or random_level > 1.0:
raise ValueError(
@@ -136,7 +136,8 @@ def __init__(
)
)

self.random_number = int((max_complexity - min_complexity) * random_level + 1)
self.random_level = random_level
self.random_number = None

def _sort_shard_and_shuffle_dataset(self):
# This method returns a list of dataset sample indices after
@@ -147,17 +148,18 @@ def _sort_shard_and_shuffle_dataset(self):
# Shuffling is done either before sharding on the group indices (if group_size is provided)
# or on the dataset sample indices if the group_size is not provided.

def sort_in_groups(sample_complexity_list, group_size):
def sort_in_groups(sample_complexities, group_size):
"""Sort the dataset samples indices inside each group of size group_size."""
# If the group_size is None, the entire dataset is considered as a single group
if group_size is None:
group_size = len(sample_complexity_list)
group_size = len(sample_complexities)
# Sort the dataset samples inside each group of the dataset based on sample complexity.
for group_begin_index in range(0, len(sample_complexity_list), group_size):
group_end_index = min(group_begin_index + group_size, len(sample_complexity_list))
sample_complexity_list[group_begin_index:group_end_index] = \
sorted(sample_complexity_list[group_begin_index:group_end_index], key=lambda t: t[1])
return sample_complexity_list
for group_begin_index in range(0, len(sample_complexities), group_size):
group_end_index = min(group_begin_index + group_size, len(sample_complexities))
sorted_indices = \
group_begin_index + np.argsort(sample_complexities[group_begin_index:group_end_index, 1])
sample_complexities[group_begin_index:group_end_index, :] = sample_complexities[sorted_indices]
return sample_complexities

def chunks_wrap_padding(dataset_index_list, num_shards):
"""Yield successive num_shards-sized chunks from dataset_index_list."""
@@ -170,43 +172,56 @@ def chunks_wrap_padding(dataset_index_list, num_shards):
yield current_lst
current_lst = []

sample_complexity_list = self.sample_complexity_list.copy()
# Get the samples and their complexities from the complexity_fn
if self.sample_complexities is None:
self.sample_complexities = np.empty((len(self.dataset), 2), dtype=np.int64)
for sample_index in range(len(self.dataset)):
self.sample_complexities[sample_index][0] = sample_index
self.sample_complexities[sample_index][1] = self.complexity_fn(self.dataset[sample_index])

if self.random_number is None:
max_complexity = max(self.sample_complexities, key=lambda t: t[1])[1]
min_complexity = min(self.sample_complexities, key=lambda t: t[1])[1]
self.random_number = int((max_complexity - min_complexity) * self.random_level + 1)

sample_complexities = self.sample_complexities.copy()

# Control the degree of load balancing by modifying the complexities of
# all samples using the random_number.
g = torch.Generator()
g.manual_seed(self.seed + self.epoch)
g = g.manual_seed(self.seed + self.epoch)

if self.random_number > 1:
complexity_random_ints = torch.randint(
self.random_number, (len(sample_complexity_list),), generator=g
self.random_number, (len(sample_complexities),), generator=g
).tolist()

for index, random_int in enumerate(complexity_random_ints):
sample_complexity_list[index][1] += random_int
sample_complexities[index][1] += random_int

# Sort the data based on the computed complexities and group sizes.
ordered_sample_complexity_list = sort_in_groups(sample_complexity_list, self.group_size)
ordered_sample_complexities = sort_in_groups(sample_complexities, self.group_size)

# If group_size is not None, shuffle the index of each group instead
# of shuffling the data indices.
if self.shuffle and self.group_size is not None:
num_groups = (len(self.sample_complexity_list) + self.group_size - 1) // self.group_size
num_groups = (len(self.sample_complexities) + self.group_size - 1) // self.group_size
group_order = torch.randperm(num_groups, generator=g).tolist()
end = 0
sample_complexity_list_copy = sample_complexity_list.copy()
sample_complexities_copy = ordered_sample_complexities.copy()
for group_index in group_order:
original_list_begin_index = self.group_size*group_index
original_list_end_index = min(original_list_begin_index+self.group_size, len(sample_complexity_list))
original_list_end_index = min(original_list_begin_index+self.group_size, len(sample_complexities))
begin = end
end = begin + (original_list_end_index - original_list_begin_index)
sample_complexity_list_copy[begin:end] = sample_complexity_list[original_list_begin_index:original_list_end_index]
ordered_sample_complexity_list = sample_complexity_list_copy
sample_complexities_copy[begin:end, :] = \
sample_complexities[original_list_begin_index:original_list_end_index, :]
ordered_sample_complexities = sample_complexities_copy

# Shard the data across the different workers.
index_chunks = list(
chunks_wrap_padding(
[index_complexity_tuple[0] for index_complexity_tuple in ordered_sample_complexity_list], self.world_size
[index_complexity_tuple[0] for index_complexity_tuple in ordered_sample_complexities], self.world_size
)
)
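As a sanity check, the per-group np.argsort pattern in sort_in_groups above can be exercised on its own; a small standalone sketch with toy complexities and a group size of 2 (values chosen purely for illustration):

import numpy as np

# Toy (sample_index, complexity) rows, mirroring the layout of sample_complexities.
sample_complexities = np.array([[0, 5], [1, 2], [2, 9], [3, 1], [4, 7]], dtype=np.int64)
group_size = 2

for group_begin_index in range(0, len(sample_complexities), group_size):
    group_end_index = min(group_begin_index + group_size, len(sample_complexities))
    # argsort on the complexity column, offset back to absolute row positions
    sorted_indices = group_begin_index + np.argsort(
        sample_complexities[group_begin_index:group_end_index, 1])
    sample_complexities[group_begin_index:group_end_index, :] = sample_complexities[sorted_indices]

print(sample_complexities)
# [[1 2]
#  [0 5]
#  [3 1]
#  [2 9]
#  [4 7]]

Each group is sorted on the complexity column while the sample indices stay paired with their (possibly noise-adjusted) complexities, which is what lets the later sharding step hand out similarly sized samples to each worker.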

@@ -264,7 +279,7 @@ class LoadBalancingDistributedBatchSampler(Sampler):
:attr:`batch_fn` will have the signature of::
def batch_fn(indices: List[int]) -> List[List[int]]
Example::
>>> from bagua.torch_api.contrib import LoadBalancingDistributedSampler, \
>>> from onnxruntime.training.utils.data import LoadBalancingDistributedSampler, \
... LoadBalancingDistributedBatchSampler
>>>
>>> sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
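Pulling the two docstring examples together, end-to-end usage might look like the sketch below. MyVariableSequenceLengthDataset, dataset_samples, n_epochs, and train are hypothetical placeholders carried over from the docstrings:

import torch
from onnxruntime.training.utils.data import LoadBalancingDistributedSampler

dataset = MyVariableSequenceLengthDataset(dataset_samples)  # hypothetical variable-length dataset
complexity_fn = lambda sample: len(sample)                  # complexity = sequence length

sampler = LoadBalancingDistributedSampler(dataset, complexity_fn=complexity_fn)
loader = torch.utils.data.DataLoader(dataset, batch_size=32, sampler=sampler)

for epoch in range(n_epochs):      # n_epochs: hypothetical
    sampler.set_epoch(epoch)       # refresh shuffling before iterating each epoch
    train(loader)                  # train: hypothetical training loop

When batching by a budget rather than a fixed batch size, the documented batch_fn callback (List[int] -> List[List[int]]) could be sketched as follows; the token budget here is an arbitrary illustration:

def batch_fn(indices):
    # Hypothetical batch_fn: greedily pack sample indices until a token budget is reached.
    max_tokens = 4096
    batches, current, running = [], [], 0
    for index in indices:
        current.append(index)
        running += complexity_fn(dataset[index])
        if running >= max_tokens:
            batches.append(current)
            current, running = [], 0
    if current:
        batches.append(current)
    return batches

A batch_fn of this shape is what gets handed to LoadBalancingDistributedBatchSampler; combined with a length-based complexity_fn, it keeps samples of similar length in the same batch, which is the load balancing the class docstring describes.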
@@ -126,7 +126,7 @@ def complexity_fn(sample):
len(samples_and_complexities)+group_size-1) // group_size,
generator=torch.Generator().manual_seed(0)).tolist()
end = 0
for original_group_index, group_index in enumerate(shuffled_group_order):
for group_index in shuffled_group_order:
original_begin = group_index*group_size
original_end = min(original_begin+group_size, len(samples_and_complexities))
begin = end
