-
Notifications
You must be signed in to change notification settings - Fork 2.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
make the data loader interface more general.
Signed-off-by: Peng Zhang <[email protected]>
- Loading branch information
Showing
11 changed files
with
194 additions
and
163 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
from queue import Queue, Empty | ||
from threading import Thread, Event | ||
|
||
|
||
class BaseDataLoader(object): | ||
def __len__(self): | ||
""" | ||
lenght of the batches to be loaded. | ||
""" | ||
# If we cannot infer the number of iteratios we return 0 | ||
return 0 | ||
|
||
def _process_batch(self, batch): | ||
""" | ||
Hook before output a batch for custom needs. | ||
""" | ||
return batch | ||
|
||
def __iter__(self): | ||
""" | ||
Starting iteration and get batchs | ||
""" | ||
for b in self._iterate(): | ||
yield self._process_batch(b) | ||
|
||
def _iterate(self): | ||
""" | ||
interface for the implimentation of iterate batches | ||
""" | ||
raise NotImplementedError() | ||
|
||
|
||
class BaseAsyncDataLoader(BaseDataLoader): | ||
def __init__(self, maxsize=5): | ||
self.maxsize = maxsize | ||
if self.maxsize > 0: | ||
self.finished_event = Event() | ||
self.q = Queue(self.maxsize) | ||
self.t = Thread(target=self._worker) | ||
self.t.daemon = True | ||
self.started = False | ||
|
||
def __del__(self): | ||
self.close() | ||
|
||
def close(self): | ||
if self.maxsize > 0 and self.started: | ||
self.finished_event.set() | ||
try: | ||
# Free buffer to allow worker to retry | ||
self.q.get_nowait() | ||
except Empty: | ||
pass | ||
self.t.join() | ||
|
||
def _worker(self): | ||
try: | ||
while not self.finished_event.is_set(): | ||
for b in self._iterate(): | ||
if self.finished_event.is_set(): | ||
break | ||
self.q.put(b) | ||
self.q.put(None) | ||
except Exception as ex: | ||
self.q.put(ex) | ||
self.q.put(None) | ||
finally: | ||
self.q.put(None) | ||
|
||
def __iter__(self): | ||
if self.maxsize > 0: | ||
if not self.started: | ||
self.started = True | ||
self.t.start() | ||
while True: | ||
b = self.q.get() | ||
if b is None: | ||
break | ||
if isinstance(b, Exception): | ||
raise b | ||
yield self._process_batch(b) | ||
else: | ||
for b in self._iterate(): | ||
yield self._process_batch(b) | ||
|
||
def _iterate(self): | ||
""" | ||
interface for the implimentation of iterate batches | ||
""" | ||
raise NotImplementedError() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from petastorm.pytorch import BatchedDataLoader | ||
from .base_data_loader import BaseDataLoader, BaseAsyncDataLoader | ||
|
||
|
||
class PytorchDataLoader(BaseDataLoader): | ||
def __init__(self, reader, batch_size, shuffling_queue_capacity): | ||
self.reader = reader | ||
self.batch_size = batch_size | ||
self.shuffling_queue_capacity = shuffling_queue_capacity | ||
print(f"Initializing petastorm dataloader with batch_size {batch_size}" | ||
f" and shuffling_queue_capacity {shuffling_queue_capacity}") | ||
|
||
def __len__(self): | ||
return len(self.reader) | ||
|
||
def _iterate(self): | ||
if self.reader.last_row_consumed: | ||
print(f"Resetting Petastorm reader for {self.reader.dataset.paths}") | ||
self.reader.reset() | ||
|
||
# Re-create the data loader for each iterate. There maybe some left over data | ||
# from last epoch which will cause petastorm's BatchedDataLoader fail to reset. | ||
data_loader = BatchedDataLoader( | ||
self.reader, | ||
batch_size=self.batch_size, | ||
shuffling_queue_capacity=self.shuffling_queue_capacity, | ||
) | ||
|
||
for batch in data_loader: | ||
yield batch | ||
|
||
|
||
class PytorchAsyncDataLoader(BaseAsyncDataLoader): | ||
def __init__(self, reader, batch_size, shuffling_queue_capacity, q_size=64): | ||
super().__init__(q_size) | ||
|
||
self.reader = reader | ||
self.batch_size = batch_size | ||
self.shuffling_queue_capacity = shuffling_queue_capacity | ||
|
||
def __len__(self): | ||
return len(self.reader) | ||
|
||
def _iterate(self): | ||
if self.reader.last_row_consumed: | ||
print(f"Resetting Petastorm reader for {self.reader.dataset.paths}") | ||
self.reader.reset() | ||
|
||
# Re-create the data loader for each iterate. There maybe some left over data | ||
# from last epoch which will cause petastorm's BatchedDataLoader fail to reset. | ||
data_loader = BatchedDataLoader( | ||
self.reader, | ||
batch_size=self.batch_size, | ||
shuffling_queue_capacity=self.shuffling_queue_capacity, | ||
) | ||
|
||
for batch in data_loader: | ||
yield batch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.