diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index d9b2e0e8980..9275f315322 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -416,25 +416,6 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, *
         else:
             self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
 
-        # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641
-        # In C++ terms, this is analogous to creating `DataLoaderAdapter : T`, where T is a DataLoader or
-        # StatefulDataLoader
-        #
-        # The same functionality could be achieved by directly creating the required subclasses for both {DataLoader,
-        # StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional
-        # dispatching scattered throughout various functions and files.
-        #
-        # This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work
-        # transparently.
-        #
-        # A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit),
-        # but this would not be backwards compatible with existing code which assumes
-        # DataLoaderShard/DataLoaderDispatcher are DataLoaders.
-        base_cls = self.__class__
-        base_cls_name = self.__class__.__name__
-        parent_cls_name = self.base_dataloader.__class__
-        self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})
-
         if hasattr(self.base_dataloader, "state_dict"):
             self.dl_state_dict = self.base_dataloader.state_dict()
 
@@ -451,6 +432,18 @@ def state_dict(self):
     def load_state_dict(self, state_dict):
         self.base_dataloader.load_state_dict(state_dict)
 
+    @property
+    def __class__(self):
+        """
+        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
+        returns `True`. This is because some downstream code assumes that `DataLoader` is the base class of the
+        object.
+        """
+        return self.base_dataloader.__class__
+
+    def __len__(self):
+        return len(self.base_dataloader)
+
     def adjust_state_dict_for_prefetch(self):
         """
         Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
@@ -580,6 +573,15 @@ def __iter__(self):
             self.iteration += 1
         self.end()
 
+    def __reduce__(self):
+        """
+        Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
+        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
+        `__class__` member.
+        """
+        args = super().__reduce__()
+        return (DataLoaderShard, *args[1:])
+
     def set_epoch(self, epoch: int):
         # In case it is manually passed in, the user can set it to what they like
         if self.iteration != epoch:
@@ -865,7 +867,7 @@ def set_epoch(self, epoch: int):
             self.dataset.set_epoch(epoch)
 
     def __len__(self):
-        whole_length = super().__len__()
+        whole_length = len(self.base_dataloader)
         if self.split_batches:
             return whole_length
         elif self._drop_last:
@@ -873,6 +875,15 @@
         else:
             return math.ceil(whole_length / self.state.num_processes)
 
+    def __reduce__(self):
+        """
+        Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
+        be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
+        `__class__` member.
+ """ + args = super().__reduce__() + return (DataLoaderDispatcher, *args[1:]) + @property def total_batch_size(self): return ( @@ -1211,6 +1222,18 @@ def __iter__(self): yield batch self.end() + def __len__(self): + return len(self.base_dataloader) - self.skip_batches + + def __reduce__(self): + """ + Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be + explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its + `__class__` member. + """ + args = super().__reduce__() + return (SkipDataLoader, *args[1:]) + def skip_first_batches(dataloader, num_batches=0): """ diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index 5c59a6d7326..899dc6e3f87 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pickle import tempfile import warnings from typing import List @@ -247,6 +248,16 @@ def test_join_raises_warning_for_iterable_when_overriding_even_batches(): assert "only supported for map-style datasets" in str(w[-1].message) +def test_pickle_accelerator(): + accelerator = create_accelerator() + data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4) + _ = accelerator.prepare(data_loader) + pickled_accelerator = pickle.dumps(accelerator) + unpickled_accelerator = pickle.loads(pickled_accelerator) + # TODO: Maybe this should be implemented as __eq__ for AcceleratorState? + assert accelerator.state.__dict__ == unpickled_accelerator.state.__dict__ + + def test_data_loader(data_loader, accelerator): # Prepare the DataLoader data_loader = accelerator.prepare(data_loader) @@ -368,6 +379,9 @@ def main(): test_join_raises_warning_for_non_ddp_distributed(accelerator) accelerator.state.distributed_type = original_state + accelerator.print("Test pickling an accelerator") + test_pickle_accelerator() + dataset = DummyDataset() # Conventional Dataloader with shuffle=False loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 8865e4f0009..00c18506ced 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -27,6 +27,7 @@ from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch from accelerate.accelerator import Accelerator +from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, skip_first_batches from accelerate.state import GradientState, PartialState from accelerate.test_utils import ( require_bnb, @@ -647,6 +648,52 @@ def test_can_unwrap_model(self): model_loaded = pickle.loads(pickle.dumps(model)) model_loaded(inputs) + @parameterized.expand([True, False]) + def test_can_pickle_dataloader(self, dispatch_batches): + """ + Test that pickling a prepared dataloader works. 
+ """ + data = torch.arange(10).to(torch_device) + ds = torch.utils.data.TensorDataset(data) + dl = torch.utils.data.DataLoader(ds) + skip_dl = skip_first_batches(dl, 2) + + # Currently, StatefulDataLoader doesn't seem to support pickling, so we aren't testing that functionality + # TODO: Add support for pickling StatefulDataLoader + dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False) + accelerator = Accelerator(dataloader_config=dataloader_config) + + original_dl, _ = accelerator.prepare(dl, skip_dl) + if dispatch_batches: + assert isinstance(original_dl, DataLoaderDispatcher) + else: + assert isinstance(original_dl, DataLoaderShard) + + prepared_model_dumps = pickle.dumps(accelerator) + + model_loaded = pickle.loads(prepared_model_dumps) + assert len(model_loaded._dataloaders) == 2 + + # Assert equality of recovered and original dataloader + loaded_dl = model_loaded._dataloaders[0] + assert isinstance(loaded_dl, DataLoader) + if dispatch_batches: + assert isinstance(loaded_dl, DataLoaderDispatcher) + else: + assert isinstance(loaded_dl, DataLoaderShard) + assert len(loaded_dl) == len(original_dl) + assert [i for i in loaded_dl] == [i for i in original_dl] + + # Test skip dataloader works as expected as well + loaded_skip_dl = model_loaded._dataloaders[1] + assert isinstance(loaded_skip_dl, DataLoader) + if dispatch_batches: + assert isinstance(loaded_dl, DataLoaderDispatcher) + else: + assert isinstance(loaded_dl, DataLoaderShard) + assert len(loaded_skip_dl) == len(original_dl) - 2 + assert [i for i in loaded_skip_dl] == [i for i in original_dl][2:] + # Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward. @require_torchdata_stateful_dataloader def test_prepared_objects_are_referenced_with_stateful_dataloader(self): diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 1fe34bdc6ed..4ffbf29c134 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -420,6 +420,14 @@ def test_dataloader_inheritance(self): skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2) dl_shard = DataLoaderShard(range(16), batch_size=4) dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4) + + # Test dataloaders are instances of instantiated classes + # These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__ + assert isinstance(skip_dl, SkipDataLoader) + assert isinstance(dl_shard, DataLoaderShard) + assert isinstance(dl_dispatcher, DataLoaderDispatcher) + + # Test dataloaders are instances of base classes assert isinstance(skip_dl, DataLoader) assert isinstance(dl_shard, DataLoader) assert isinstance(dl_dispatcher, DataLoader) @@ -556,6 +564,13 @@ def test_dataloader_inheritance(self): skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True) dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True) dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) + + # Test dataloaders are instances of instantiated classes + # These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__ + assert isinstance(skip_dl, SkipDataLoader) + assert isinstance(dl_shard, DataLoaderShard) + assert isinstance(dl_dispatcher, DataLoaderDispatcher) + assert isinstance(skip_dl, StatefulDataLoader) assert 
isinstance(dl_shard, StatefulDataLoader) assert isinstance(dl_dispatcher, StatefulDataLoader)
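
For context on why the `__class__` property and the `__reduce__` overrides travel together, here is a minimal standalone sketch of the pattern. It is illustrative only: `Wrapped` and `Adapter` are hypothetical stand-ins for `DataLoader`/`StatefulDataLoader` and `DataLoaderAdapter`, not accelerate classes. Spoofing `__class__` keeps `isinstance` checks against the wrapped type working, but default pickling derives its reconstruction arguments from that spoofed class, so each adapter has to point pickle back at its real class in `__reduce__`:

import pickle


class Wrapped:
    """Hypothetical stand-in for torch's DataLoader / StatefulDataLoader."""


class Adapter:
    """Hypothetical stand-in for DataLoaderAdapter and its subclasses."""

    def __init__(self, *args, **kwargs):
        # Tolerate arbitrary arguments: pickle re-invokes this constructor
        # with throwaway values, then restores the real instance __dict__.
        self.base = Wrapped()

    @property
    def __class__(self):
        # isinstance(obj, Wrapped) now holds, while type(obj) still reports
        # the real class, Adapter.
        return self.base.__class__

    def __reduce__(self):
        # object.__reduce__() builds its reconstructor arguments from the
        # spoofed __class__, so swap the real class back in before pickling.
        args = super().__reduce__()
        return (Adapter, *args[1:])


adapter = Adapter()
assert isinstance(adapter, Wrapped)  # via the __class__ property
assert isinstance(adapter, Adapter)  # via type(adapter)

restored = pickle.loads(pickle.dumps(adapter))
assert type(restored) is Adapter  # pickle rebuilt the real class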