
Allow DataLoaderAdapter subclasses to be pickled by implementing __reduce__ #3074

Merged · 20 commits · Sep 5, 2024
Changes from 17 commits
51 changes: 30 additions & 21 deletions src/accelerate/data_loader.py
@@ -416,25 +416,6 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, **kwargs):
else:
self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)

# Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641
# In C++ terms, this is analogous to creating `DataLoaderAdapter<T> : T`, where T is a DataLoader or
# StatefulDataLoader
#
# The same functionality could be achieved by directly creating the required subclasses for both {DataLoader,
# StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional
# dispatching scattered throughout various functions and files.
#
# This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work
# transparently.
#
# A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit),
# but this would not be backwards compatible with existing code which assumes
# DataLoaderShard/DataLoaderDispatcher are DataLoaders.
base_cls = self.__class__
base_cls_name = self.__class__.__name__
parent_cls_name = self.base_dataloader.__class__
self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})
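For context on why the dynamic mixin had to go: pickle serializes a class by its module-qualified name, and a class built on the fly with `type(...)` cannot be found under that name. A minimal standalone sketch of the failure, using hypothetical `Adapter`/`Base` classes rather than the accelerate types:

```python
import pickle


class Base:
    pass


class Adapter(Base):
    def __init__(self):
        # Re-point __class__ at a class created on the fly, mirroring the
        # removed `type(base_cls_name, (base_cls, parent_cls_name), {})` call.
        self.__class__ = type("Adapter", (Adapter, Base), {})


try:
    pickle.dumps(Adapter())
except pickle.PicklingError as exc:
    # Pickle looks "Adapter" up in its defining module, finds the original
    # class, and refuses because it is not the dynamically created one.
    print(exc)
```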

if hasattr(self.base_dataloader, "state_dict"):
self.dl_state_dict = self.base_dataloader.state_dict()

@@ -451,6 +432,17 @@ def state_dict(self):
def load_state_dict(self, state_dict):
self.base_dataloader.load_state_dict(state_dict)

@property
def __class__(self):
"""
        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)` returns `True`.
This is because some downstream code assumes that the `DataLoader` is the base class of the object.
"""
return self.base_dataloader.__class__

def __len__(self):
return len(self.base_dataloader)
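The trick above works because `isinstance` consults both the object's real type and its reported `__class__`. A minimal sketch of that behavior with hypothetical `Wrapper`/`Inner` classes (the same pattern `unittest.mock` uses to spoof specs):

```python
class Inner:
    pass


class Wrapper:
    def __init__(self):
        self.inner = Inner()

    @property
    def __class__(self):
        # Report the wrapped object's class, as DataLoaderAdapter does
        # with base_dataloader.
        return self.inner.__class__


w = Wrapper()
assert isinstance(w, Inner)    # True: isinstance also checks __class__
assert isinstance(w, Wrapper)  # True: the real type still matches
assert type(w) is Wrapper      # type() ignores the property
```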

def adjust_state_dict_for_prefetch(self):
"""
Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
@@ -487,7 +479,6 @@ def _update_state_dict(self):
# Then tag if we are at the end of the dataloader
self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader


class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
"""
Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
@@ -580,6 +571,11 @@ def __iter__(self):
self.iteration += 1
self.end()

def __reduce__(self):
args = super().__reduce__()
return (DataLoaderShard, *args[1:])
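Because the `__class__` property now reports the wrapped loader's class, pickle's default reduction would name the wrong type; the override swaps the concrete subclass in as the reconstruction callable and passes the rest of the reduce tuple (constructor args and instance state) through unchanged. A minimal sketch of the mechanism, continuing the hypothetical `Wrapper`/`Inner` example rather than the accelerate code:

```python
import pickle


class Inner:
    pass


class Wrapper:
    def __init__(self, *args):
        # Tolerate the positional args carried by the default reduce tuple;
        # the pickled __dict__ is restored on top of this instance afterwards.
        self.inner = Inner()

    @property
    def __class__(self):
        return self.inner.__class__

    def __reduce__(self):
        # The default reduction names Inner (via the lying __class__);
        # substitute the real class and keep the remaining elements.
        args = super().__reduce__()
        return (Wrapper, *args[1:])


restored = pickle.loads(pickle.dumps(Wrapper()))
assert type(restored) is Wrapper  # reconstructed as the real class
```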


def set_epoch(self, epoch: int):
# In case it is manually passed in, the user can set it to what they like
if self.iteration != epoch:
@@ -865,14 +861,18 @@ def set_epoch(self, epoch: int):
self.dataset.set_epoch(epoch)

def __len__(self):
-        whole_length = super().__len__()
+        whole_length = len(self.base_dataloader)
if self.split_batches:
return whole_length
elif self._drop_last:
return whole_length // self.state.num_processes
else:
return math.ceil(whole_length / self.state.num_processes)
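A quick numerical check of the branches, assuming hypothetical values of 10 whole batches across 4 processes:

```python
import math

whole_length, num_processes = 10, 4

# split_batches: every process sees every (smaller) batch -> 10
# _drop_last:    the trailing remainder is dropped -> 10 // 4 == 2
# otherwise:     the last shard may be partial    -> ceil(10 / 4) == 3
assert whole_length // num_processes == 2
assert math.ceil(whole_length / num_processes) == 3
```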

def __reduce__(self):
args = super().__reduce__()
return (DataLoaderDispatcher, *args[1:])

@property
def total_batch_size(self):
return (
@@ -1211,6 +1211,15 @@ def __iter__(self):
yield batch
self.end()

def __len__(self):
print("len called")
return len(self.base_dataloader) - self.skip_batches

def __reduce__(self):
args = super().__reduce__()
return (SkipDataLoader, *args[1:])
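Taken together, the overrides are meant to make a round trip like the following work (a sketch mirroring the tests below, assuming a picklable dataset):

```python
import pickle

from accelerate.data_loader import SkipDataLoader

skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
assert len(skip_dl) == 2  # 4 batches in total, minus the 2 skipped

restored = pickle.loads(pickle.dumps(skip_dl))
assert type(restored) is SkipDataLoader  # thanks to __reduce__
assert [b.tolist() for b in restored] == [b.tolist() for b in skip_dl]
```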



def skip_first_batches(dataloader, num_batches=0):
"""
14 changes: 13 additions & 1 deletion src/accelerate/test_utils/scripts/test_distributed_data_loop.py
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import pickle
import tempfile
import warnings
from typing import List
@@ -246,6 +247,15 @@ def test_join_raises_warning_for_iterable_when_overriding_even_batches():
assert issubclass(w[-1].category, UserWarning)
assert "only supported for map-style datasets" in str(w[-1].message)

def test_pickle_accelerator():
accelerator = create_accelerator()
data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
_ = accelerator.prepare(data_loader)
pickled_accelerator = pickle.dumps(accelerator)
unpickled_accelerator = pickle.loads(pickled_accelerator)
# TODO: Maybe this should be implemented as __eq__ for AcceleratorState?
assert accelerator.state.__dict__ == unpickled_accelerator.state.__dict__


def test_data_loader(data_loader, accelerator):
# Prepare the DataLoader
@@ -339,7 +349,6 @@ def test_stateful_dataloader_save_state(accelerator):
finally:
accelerator.dataloader_config = old_dataloader_config


def main():
accelerator = create_accelerator()
torch.manual_seed(accelerator.process_index)
@@ -368,6 +377,9 @@ def main():
test_join_raises_warning_for_non_ddp_distributed(accelerator)
accelerator.state.distributed_type = original_state

accelerator.print("Test pickling an accelerator")
test_pickle_accelerator()

dataset = DummyDataset()
# Conventional Dataloader with shuffle=False
loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
48 changes: 48 additions & 0 deletions tests/test_accelerator.py
@@ -27,6 +27,7 @@

from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch
from accelerate.accelerator import Accelerator
from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, SkipDataLoader, skip_first_batches
from accelerate.state import GradientState, PartialState
from accelerate.test_utils import (
require_bnb,
@@ -647,6 +648,53 @@ def test_can_unwrap_model(self):
model_loaded = pickle.loads(pickle.dumps(model))
model_loaded(inputs)

@parameterized.expand([True, False])
def test_can_pickle_dataloader(self, dispatch_batches):
"""
Test that pickling a prepared dataloader works.
"""
data = torch.arange(10).to(torch_device)
ds = torch.utils.data.TensorDataset(data)
dl = torch.utils.data.DataLoader(ds)
skip_dl = skip_first_batches(dl, 2)

# Currently, StatefulDataLoader doesn't seem to support pickling, so we aren't testing that functionality
# TODO: Add support for pickling StatefulDataLoader
dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False)
accelerator = Accelerator(dataloader_config=dataloader_config)

original_dl, prepared_skip_dl = accelerator.prepare(dl, skip_dl)
if dispatch_batches:
assert isinstance(original_dl, DataLoaderDispatcher)
else:
assert isinstance(original_dl, DataLoaderShard)

prepared_model_dumps = pickle.dumps(accelerator)

model_loaded = pickle.loads(prepared_model_dumps)
assert len(model_loaded._dataloaders) == 2

# Assert equality of recovered and original dataloader
loaded_dl = model_loaded._dataloaders[0]
assert isinstance(loaded_dl, DataLoader)
if dispatch_batches:
assert isinstance(loaded_dl, DataLoaderDispatcher)
else:
assert isinstance(loaded_dl, DataLoaderShard)
assert len(loaded_dl) == len(original_dl)
assert [i for i in loaded_dl] == [i for i in original_dl]

# Test skip dataloader works as expected as well
loaded_skip_dl = model_loaded._dataloaders[1]
assert isinstance(loaded_skip_dl, DataLoader)
        if dispatch_batches:
            assert isinstance(loaded_skip_dl, DataLoaderDispatcher)
        else:
            assert isinstance(loaded_skip_dl, DataLoaderShard)
assert len(loaded_skip_dl) == len(original_dl) - 2
assert [i for i in loaded_skip_dl] == [i for i in original_dl][2:]

    # Ideally this would be a parameterized test working with either stateful or non-stateful dataloaders, but the dependencies are a bit awkward.
@require_torchdata_stateful_dataloader
def test_prepared_objects_are_referenced_with_stateful_dataloader(self):
15 changes: 15 additions & 0 deletions tests/test_data_loader.py
@@ -420,6 +420,14 @@ def test_dataloader_inheritance(self):
skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2)
dl_shard = DataLoaderShard(range(16), batch_size=4)
dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4)

# Test dataloaders are instances of instantiated classes
# These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__
assert isinstance(skip_dl, SkipDataLoader)
assert isinstance(dl_shard, DataLoaderShard)
assert isinstance(dl_dispatcher, DataLoaderDispatcher)

# Test dataloaders are instances of base classes
assert isinstance(skip_dl, DataLoader)
assert isinstance(dl_shard, DataLoader)
assert isinstance(dl_dispatcher, DataLoader)
@@ -556,6 +564,13 @@ def test_dataloader_inheritance(self):
skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True)
dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True)
dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True)

# Test dataloaders are instances of instantiated classes
# These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__
assert isinstance(skip_dl, SkipDataLoader)
assert isinstance(dl_shard, DataLoaderShard)
assert isinstance(dl_dispatcher, DataLoaderDispatcher)

assert isinstance(skip_dl, StatefulDataLoader)
assert isinstance(dl_shard, StatefulDataLoader)
assert isinstance(dl_dispatcher, StatefulDataLoader)