Skip to content

Commit

Permalink
make the data loader interface more general.
Browse files Browse the repository at this point in the history
Signed-off-by: Peng Zhang <[email protected]>
  • Loading branch information
irasit committed May 26, 2021
1 parent 36350e2 commit a24d6fa
Show file tree
Hide file tree
Showing 11 changed files with 194 additions and 163 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).
## [Unreleased] - YYYY-MM-DD

### Added
Custom spark data loader interface. ([#2938](https://github.com/horovod/horovod/issues/2938))

### Changed

Expand Down
2 changes: 2 additions & 0 deletions docs/pytorch.rst
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,5 @@ Start the training job and specify the number of workers on the command line as
You can find an example of use pytorch lightning trainer with horovod backend in `pytorch_lightning_mnist.py script <../examples/pytorch/pytorch_lightning_mnist.py>`__

See the PyTorch Lightning `docs <https://pytorch-lightning.readthedocs.io/en/stable/multi_gpu.html#horovod>`_ for more details.

A pytorch-lightning based spark estimator trainer is also added example is in `pytorch_lightning_spark_mnist.py <../examples/spark/pytorch/pytorch_lightning_spark_mnist.py>`__
2 changes: 2 additions & 0 deletions docs/spark.rst
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ logging (for Tensorboard) using the Estimator ``Store`` abstraction. Stores are
artifacts including intermediate representations of the training data. Horovod natively supports stores for HDFS
and local filesystems.

Petastorm based data loader is used by default, but user can define a custom data loader by override the `base_data_loader` interface.

End-to-end example
------------------
`keras_spark_rossmann_estimator.py script <../examples/spark/keras/keras_spark_rossmann_estimator.py>`__ provides
Expand Down
1 change: 0 additions & 1 deletion examples/spark/pytorch/pytorch_lightning_spark_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,6 @@ def __init__(self):
self.fc2 = nn.Linear(50, 10)

def forward(self, x):
# raise RuntimeError("x shape is {}".format(x.shape))
x = x.float().reshape((-1, 1, 28, 28))
x = F.relu(F.max_pool2d(self.conv1(x), 2))
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
Expand Down
151 changes: 0 additions & 151 deletions horovod/spark/common/data_loader.py

This file was deleted.

Empty file.
90 changes: 90 additions & 0 deletions horovod/spark/data_loaders/base_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from queue import Queue, Empty
from threading import Thread, Event


class BaseDataLoader(object):
def __len__(self):
"""
lenght of the batches to be loaded.
"""
# If we cannot infer the number of iteratios we return 0
return 0

def _process_batch(self, batch):
"""
Hook before output a batch for custom needs.
"""
return batch

def __iter__(self):
"""
Starting iteration and get batchs
"""
for b in self._iterate():
yield self._process_batch(b)

def _iterate(self):
"""
interface for the implimentation of iterate batches
"""
raise NotImplementedError()


class BaseAsyncDataLoader(BaseDataLoader):
def __init__(self, maxsize=5):
self.maxsize = maxsize
if self.maxsize > 0:
self.finished_event = Event()
self.q = Queue(self.maxsize)
self.t = Thread(target=self._worker)
self.t.daemon = True
self.started = False

def __del__(self):
self.close()

def close(self):
if self.maxsize > 0 and self.started:
self.finished_event.set()
try:
# Free buffer to allow worker to retry
self.q.get_nowait()
except Empty:
pass
self.t.join()

def _worker(self):
try:
while not self.finished_event.is_set():
for b in self._iterate():
if self.finished_event.is_set():
break
self.q.put(b)
self.q.put(None)
except Exception as ex:
self.q.put(ex)
self.q.put(None)
finally:
self.q.put(None)

def __iter__(self):
if self.maxsize > 0:
if not self.started:
self.started = True
self.t.start()
while True:
b = self.q.get()
if b is None:
break
if isinstance(b, Exception):
raise b
yield self._process_batch(b)
else:
for b in self._iterate():
yield self._process_batch(b)

def _iterate(self):
"""
interface for the implimentation of iterate batches
"""
raise NotImplementedError()
58 changes: 58 additions & 0 deletions horovod/spark/data_loaders/pytorch_data_loaders.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
from petastorm.pytorch import BatchedDataLoader
from .base_data_loader import BaseDataLoader, BaseAsyncDataLoader


class PytorchDataLoader(BaseDataLoader):
def __init__(self, reader, batch_size, shuffling_queue_capacity):
self.reader = reader
self.batch_size = batch_size
self.shuffling_queue_capacity = shuffling_queue_capacity
print(f"Initializing petastorm dataloader with batch_size {batch_size}"
f" and shuffling_queue_capacity {shuffling_queue_capacity}")

def __len__(self):
return len(self.reader)

def _iterate(self):
if self.reader.last_row_consumed:
print(f"Resetting Petastorm reader for {self.reader.dataset.paths}")
self.reader.reset()

# Re-create the data loader for each iterate. There maybe some left over data
# from last epoch which will cause petastorm's BatchedDataLoader fail to reset.
data_loader = BatchedDataLoader(
self.reader,
batch_size=self.batch_size,
shuffling_queue_capacity=self.shuffling_queue_capacity,
)

for batch in data_loader:
yield batch


class PytorchAsyncDataLoader(BaseAsyncDataLoader):
def __init__(self, reader, batch_size, shuffling_queue_capacity, q_size=64):
super().__init__(q_size)

self.reader = reader
self.batch_size = batch_size
self.shuffling_queue_capacity = shuffling_queue_capacity

def __len__(self):
return len(self.reader)

def _iterate(self):
if self.reader.last_row_consumed:
print(f"Resetting Petastorm reader for {self.reader.dataset.paths}")
self.reader.reset()

# Re-create the data loader for each iterate. There maybe some left over data
# from last epoch which will cause petastorm's BatchedDataLoader fail to reset.
data_loader = BatchedDataLoader(
self.reader,
batch_size=self.batch_size,
shuffling_queue_capacity=self.shuffling_queue_capacity,
)

for batch in data_loader:
yield batch
2 changes: 1 addition & 1 deletion horovod/spark/lightning/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ class TorchEstimator(HorovodEstimator, TorchEstimatorParamsWritable,
'Name of the dataloader class.')

loader_num_epochs = Param(Params._dummy(), 'loader_num_epochs',
'Number of epochs whcih data loader reads in each iteration. If set to None, reader will be in infinite loop mode.')
'An epoch is a single pass over all rows in the dataset. If set to None, reader will be in infinite loop mode, and generate unlimite data as needed. ')

@keyword_only
def __init__(self,
Expand Down
10 changes: 4 additions & 6 deletions horovod/spark/lightning/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,6 @@ def train(serialized_model):

model = deserialize(serialized_model)

# _train_steps_per_epoch = train_steps_per_epoch if train_steps_per_epoch else 1.0
# _val_steps_per_epoch = val_steps_per_epoch if val_steps_per_epoch else 1.0
_train_steps_per_epoch = train_steps_per_epoch
if _train_steps_per_epoch is None:
_train_steps_per_epoch = int(math.floor(float(train_rows) / batch_size / hvd.size()))
Expand Down Expand Up @@ -247,7 +245,7 @@ def make_petastorm_reader(model, data_path, dataloader_attr, reader_worker_count
**reader_factory_kwargs) as reader:
def dataloader_fn():
return data_loader_cls(reader, batch_size=batch_size,
shuffling_queue_capacity=calculate_shuffle_buffer_size())
shuffling_queue_capacity=calculate_shuffle_buffer_size())
try:
setattr(model, dataloader_attr, dataloader_fn)
yield
Expand Down Expand Up @@ -303,9 +301,9 @@ def calculate_shuffle_buffer_size():

def _create_dataloader(feature_columns, input_shapes, metadata, data_loader_cls=None):
if data_loader_cls is None:
# set PetastormAsyncDataLoader as default
from horovod.spark.common.data_loader import PetastormAsyncDataLoader
data_loader_cls = PetastormAsyncDataLoader
# set PytorchAsyncDataLoader as default
from horovod.spark.data_loaders.pytorch_data_loaders import PytorchAsyncDataLoader
data_loader_cls = PytorchAsyncDataLoader

print(f"Using dataloader: {data_loader_cls}")

Expand Down
Loading

0 comments on commit a24d6fa

Please sign in to comment.