From 42a253b90cdc4c936887a2ee7e680d4b63f24384 Mon Sep 17 00:00:00 2001
From: byi8220
Date: Tue, 3 Sep 2024 16:40:14 -0400
Subject: [PATCH 01/19] initial fix for breaking accelerator pickling

---
 src/accelerate/data_loader.py | 85 ++++++++++++++++++++++++++---------
 tests/test_accelerator.py     | 34 ++++++++++++++
 tests/test_data_loader.py     | 15 +++++++
 3 files changed, 114 insertions(+), 20 deletions(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index d9b2e0e8980..7260fc57d1b 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -416,25 +416,6 @@ def __init__(self, dataset, use_stateful_dataloader=False, batch_sampler=None, *
         else:
             self.base_dataloader = DataLoader(dataset, batch_sampler=batch_sampler, **kwargs)
 
-        # Dynamically mixin the parent class. See https://stackoverflow.com/a/31075641
-        # In C++ terms, this is analogous to creating `DataLoaderAdapter : T`, where T is a DataLoader or
-        # StatefulDataLoader
-        #
-        # The same functionality could be achieved by directly creating the required subclasses for both {DataLoader,
-        # StatefulDataLoader}, however that could lead to much messier code, with duplicated classes and conditional
-        # dispatching scattered throughout various functions and files.
-        #
-        # This code is incredibly awkward but it's the only way to make `isinstance(obj, StatefulDataLoader)` work
-        # transparently.
-        #
-        # A more robust solution is for DataLoaderAdapter to not inherit from DataLoader (compose rather than inherit),
-        # but this would not be backwards compatible with existing code which assumes
-        # DataLoaderShard/DataLoaderDispatcher are DataLoaders.
-        base_cls = self.__class__
-        base_cls_name = self.__class__.__name__
-        parent_cls_name = self.base_dataloader.__class__
-        self.__class__ = type(base_cls_name, (base_cls, parent_cls_name), {})
-
         if hasattr(self.base_dataloader, "state_dict"):
             self.dl_state_dict = self.base_dataloader.state_dict()
 
@@ -451,6 +432,18 @@ def state_dict(self):
     def load_state_dict(self, state_dict):
         self.base_dataloader.load_state_dict(state_dict)
 
+    @property
+    def __class__(self):
+        """
+        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
+        returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
+        object.
+        """
+        return self.base_dataloader.__class__
+
+    def __len__(self):
+        return len(self.base_dataloader)
+
     def adjust_state_dict_for_prefetch(self):
         """
         Adjusts the state dict for prefetching. Natively, this will adjust all of the iters yielded keys in
@@ -488,6 +481,17 @@ def _update_state_dict(self):
         self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
 
 
+class DataLoaderAdapterImpl(DataLoaderAdapter, DataLoader):
+    pass
+
+
+if is_torchdata_stateful_dataloader_available():
+    from torchdata.stateful_dataloader import StatefulDataLoader
+
+    class StatefulDataLoaderAdapterImpl(DataLoaderAdapter, StatefulDataLoader):
+        pass
+
+
 class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
     """
     Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
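An aside on the technique this hunk introduces: instead of dynamically re-basing the instance with `type(...)`, the adapter now reports the wrapped loader's class through a `__class__` property. `isinstance()` falls back to `obj.__class__` when it differs from `type(obj)`, so the check stays transparent, but pickle consults the same attribute and would otherwise try to rebuild the wrong class, which is why the hunks below also add explicit `__reduce__` methods. A minimal, self-contained sketch of the pattern, with `Adapter` and `Wrapped` as hypothetical stand-ins rather than the real accelerate classes:

```python
import pickle


class Wrapped:
    """Stands in for torch's DataLoader in this sketch."""


class Adapter:
    """Composes a Wrapped instance while still passing isinstance checks."""

    def __init__(self, base):
        self.base = base

    @property
    def __class__(self):
        # isinstance() consults obj.__class__ when it differs from type(obj),
        # so reporting the wrapped class makes isinstance(adapter, Wrapped)
        # return True even though Adapter does not inherit from Wrapped.
        return self.base.__class__

    def __reduce__(self):
        # Default pickling also reads obj.__class__ and would try to rebuild
        # a bare Wrapped; point pickle back at the real wrapper instead.
        return (Adapter, (self.base,), self.__dict__.copy())


adapter = Adapter(Wrapped())
assert isinstance(adapter, Wrapped)             # the spoofed class is reported
restored = pickle.loads(pickle.dumps(adapter))  # round-trips as an Adapter
assert isinstance(restored, Wrapped)
```

The hand-rolled `__reduce__` tuple is a simplification for the sketch; the later patches in this series settle on delegating to `super().__reduce__()` instead (see the sketch at the end of the series).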
@@ -580,6 +584,22 @@ def __iter__(self): self.iteration += 1 self.end() + def __reduce__(self): + return ( + DataLoaderShard, + ( + self.base_dataloader.dataset, + self.device, + self.rng_types, + self.synchronized_generator, + self.skip_batches, + self.use_stateful_dataloader, + self._drop_last, + self._non_blocking, + ), + self.__dict__, + ) + def set_epoch(self, epoch: int): # In case it is manually passed in, the user can set it to what they like if self.iteration != epoch: @@ -865,7 +885,7 @@ def set_epoch(self, epoch: int): self.dataset.set_epoch(epoch) def __len__(self): - whole_length = super().__len__() + whole_length = self.base_dataloader.__len__() if self.split_batches: return whole_length elif self._drop_last: @@ -873,6 +893,21 @@ def __len__(self): else: return math.ceil(whole_length / self.state.num_processes) + def __reduce__(self): + return ( + DataLoaderDispatcher, + ( + self.base_dataloader.dataset, + self.split_batches, + self.skip_batches, + self.use_stateful_dataloader, + self._drop_last, + self._non_blocking, + self.slice_fn, + ), + self.__dict__, + ) + @property def total_batch_size(self): return ( @@ -1211,6 +1246,16 @@ def __iter__(self): yield batch self.end() + def __len__(self): + return len(self.base_dataloader) - self.skip_batches + + def __reduce__(self): + return ( + SkipDataLoader, + (self.base_dataloader.dataset, self.skip_batches, self.use_stateful_dataloader), + self.__dict__, + ) + def skip_first_batches(dataloader, num_batches=0): """ diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 8865e4f0009..a2e84387ebc 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -27,6 +27,7 @@ from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch from accelerate.accelerator import Accelerator +from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, skip_first_batches from accelerate.state import GradientState, PartialState from accelerate.test_utils import ( require_bnb, @@ -647,6 +648,39 @@ def test_can_unwrap_model(self): model_loaded = pickle.loads(pickle.dumps(model)) model_loaded(inputs) + @parameterized.expand([True, False]) + def test_can_pickle_dataloader(self, dispatch_batches): + """ + Test that pickling a prepared dataloader works. 
+ """ + data = torch.arange(10) + ds = torch.utils.data.TensorDataset(data) + dl = torch.utils.data.DataLoader(ds) + # Currently, StatefulDataLoader doesn't seem to support pickling, so we aren't testing that functionality + # TODO: Add support for pickling StatefulDataLoader + dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False) + accelerator = Accelerator(dataloader_config=dataloader_config) + original_dl = accelerator.prepare(dl) + prepared_model_dumps = pickle.dumps(accelerator) + + model_loaded = pickle.loads(prepared_model_dumps) + # Assert equality of recovered and original dataloader + assert isinstance(model_loaded._dataloaders[0], DataLoader) + if dispatch_batches: + assert isinstance(model_loaded._dataloaders[0], DataLoaderDispatcher) + else: + assert isinstance(model_loaded._dataloaders[0], DataLoaderShard) + assert len(model_loaded._dataloaders[0]) == len(original_dl) + assert [i for i in model_loaded._dataloaders[0]] == [i for i in original_dl] + + # Test skip dataloader works as expected as well + skip_dl = skip_first_batches(original_dl, 2) + assert isinstance(skip_dl, torch.utils.data.DataLoader) + assert len(skip_dl) == len(original_dl) - 2 + orig_items = [i for i in original_dl] + skip_dl_items = [i for i in skip_dl] + assert orig_items[2:] == skip_dl_items + # Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward. @require_torchdata_stateful_dataloader def test_prepared_objects_are_referenced_with_stateful_dataloader(self): diff --git a/tests/test_data_loader.py b/tests/test_data_loader.py index 1fe34bdc6ed..4ffbf29c134 100644 --- a/tests/test_data_loader.py +++ b/tests/test_data_loader.py @@ -420,6 +420,14 @@ def test_dataloader_inheritance(self): skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2) dl_shard = DataLoaderShard(range(16), batch_size=4) dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4) + + # Test dataloaders are instances of instantiated classes + # These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__ + assert isinstance(skip_dl, SkipDataLoader) + assert isinstance(dl_shard, DataLoaderShard) + assert isinstance(dl_dispatcher, DataLoaderDispatcher) + + # Test dataloaders are instances of base classes assert isinstance(skip_dl, DataLoader) assert isinstance(dl_shard, DataLoader) assert isinstance(dl_dispatcher, DataLoader) @@ -556,6 +564,13 @@ def test_dataloader_inheritance(self): skip_dl = SkipDataLoader(range(16), batch_size=4, skip_batches=2, use_stateful_dataloader=True) dl_shard = DataLoaderShard(range(16), batch_size=4, use_stateful_dataloader=True) dl_dispatcher = DataLoaderDispatcher(range(16), batch_size=4, use_stateful_dataloader=True) + + # Test dataloaders are instances of instantiated classes + # These asserts look redundant, but it's worth checking since we are doing magic tricks such as dynamically overriding __class__ + assert isinstance(skip_dl, SkipDataLoader) + assert isinstance(dl_shard, DataLoaderShard) + assert isinstance(dl_dispatcher, DataLoaderDispatcher) + assert isinstance(skip_dl, StatefulDataLoader) assert isinstance(dl_shard, StatefulDataLoader) assert isinstance(dl_dispatcher, StatefulDataLoader) From 364a7fdb5a878e1bbf50f4251eb7743528bc9510 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 3 Sep 2024 16:42:32 -0400 Subject: [PATCH 02/19] cleanup --- src/accelerate/data_loader.py | 17 
++---------------
 1 file changed, 2 insertions(+), 15 deletions(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 7260fc57d1b..7c8ab2dc79c 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -435,9 +435,8 @@ def load_state_dict(self, state_dict):
     @property
     def __class__(self):
         """
-        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
-        returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
-        object.
+        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)` returns true.
+        This is because some downstream code assumes that the `DataLoader` is the base class of the object.
         """
         return self.base_dataloader.__class__
 
@@ -480,18 +479,6 @@ def _update_state_dict(self):
         # Then tag if we are at the end of the dataloader
         self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
 
-
-class DataLoaderAdapterImpl(DataLoaderAdapter, DataLoader):
-    pass
-
-
-if is_torchdata_stateful_dataloader_available():
-    from torchdata.stateful_dataloader import StatefulDataLoader
-
-    class StatefulDataLoaderAdapterImpl(DataLoaderAdapter, StatefulDataLoader):
-        pass
-
-
 class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
     """
     Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.

From 8c4f15aed8ad60648d548bbecacbe74cc58eb98b Mon Sep 17 00:00:00 2001
From: byi8220
Date: Tue, 3 Sep 2024 17:01:54 -0400
Subject: [PATCH 03/19] skip_first_batches should be used on raw dls

---
 tests/test_accelerator.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py
index a2e84387ebc..17114a9fd00 100644
--- a/tests/test_accelerator.py
+++ b/tests/test_accelerator.py
@@ -653,7 +653,7 @@ def test_can_pickle_dataloader(self, dispatch_batches):
         """
         Test that pickling a prepared dataloader works.
""" - data = torch.arange(10) + data = torch.arange(10).to(torch_device) ds = torch.utils.data.TensorDataset(data) dl = torch.utils.data.DataLoader(ds) # Currently, StatefulDataLoader doesn't seem to support pickling, so we aren't testing that functionality @@ -674,8 +674,8 @@ def test_can_pickle_dataloader(self, dispatch_batches): assert [i for i in model_loaded._dataloaders[0]] == [i for i in original_dl] # Test skip dataloader works as expected as well - skip_dl = skip_first_batches(original_dl, 2) - assert isinstance(skip_dl, torch.utils.data.DataLoader) + skip_dl = skip_first_batches(dl, 2) + assert isinstance(skip_dl, DataLoader) assert len(skip_dl) == len(original_dl) - 2 orig_items = [i for i in original_dl] skip_dl_items = [i for i in skip_dl] From 7ce63ef70eb9b377610243d9f60ee90c58498060 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 3 Sep 2024 17:50:29 -0400 Subject: [PATCH 04/19] multigpu sanity test --- src/accelerate/data_loader.py | 1 + .../scripts/test_distributed_data_loop.py | 24 +++++++++++++ tests/test_accelerator.py | 34 ++++++++++++------- 3 files changed, 46 insertions(+), 13 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 7c8ab2dc79c..3854dbf75bc 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -1234,6 +1234,7 @@ def __iter__(self): self.end() def __len__(self): + print("len called") return len(self.base_dataloader) - self.skip_batches def __reduce__(self): diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index 5c59a6d7326..c7823b1c5d2 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pickle import tempfile import warnings from typing import List @@ -339,6 +340,26 @@ def test_stateful_dataloader_save_state(accelerator): finally: accelerator.dataloader_config = old_dataloader_config +def test_pickled_dataloader(accelerator): + # Prepare the DataLoader + data_loader = accelerator.prepare(data_loader) + # Pickle then reload the dataloader + prepared_model_dumps = pickle.dumps(accelerator) + loaded_accelerator = pickle.loads(prepared_model_dumps) + assert len(loaded_accelerator._dataloaders) == 1 + loaded_dataloader = loaded_accelerator._dataloaders[0] + all_examples = [] + for i, batch in enumerate(loaded_dataloader): + index, _ = accelerator.gather_for_metrics((batch["index"], batch["label"])) + all_examples.extend(index.detach().cpu().numpy().tolist()) + + # Sort the examples + sorted_all_examples = sorted(all_examples) + + # Check if all elements are present in the sorted list of iterated samples + assert ( + len(set(sorted_all_examples)) == NUM_ELEMENTS + ), "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes." 
def main(): accelerator = create_accelerator() @@ -389,6 +410,9 @@ def main(): test_stateful_dataloader(accelerator) test_stateful_dataloader_save_state(accelerator) + # Dataloader after pickling + loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) + test_pickled_dataloader(accelerator) accelerator.end_training() diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 17114a9fd00..cc987644193 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -27,7 +27,7 @@ from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch from accelerate.accelerator import Accelerator -from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, skip_first_batches +from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, SkipDataLoader, skip_first_batches from accelerate.state import GradientState, PartialState from accelerate.test_utils import ( require_bnb, @@ -656,30 +656,38 @@ def test_can_pickle_dataloader(self, dispatch_batches): data = torch.arange(10).to(torch_device) ds = torch.utils.data.TensorDataset(data) dl = torch.utils.data.DataLoader(ds) + skip_dl = skip_first_batches(dl, 2) + # Currently, StatefulDataLoader doesn't seem to support pickling, so we aren't testing that functionality # TODO: Add support for pickling StatefulDataLoader dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False) accelerator = Accelerator(dataloader_config=dataloader_config) - original_dl = accelerator.prepare(dl) + original_dl, prepared_skip_dl = accelerator.prepare(dl, skip_dl) prepared_model_dumps = pickle.dumps(accelerator) model_loaded = pickle.loads(prepared_model_dumps) + assert len(model_loaded._dataloaders) == 2 + # Assert equality of recovered and original dataloader - assert isinstance(model_loaded._dataloaders[0], DataLoader) + loaded_dl = model_loaded._dataloaders[0] + assert isinstance(loaded_dl, DataLoader) if dispatch_batches: - assert isinstance(model_loaded._dataloaders[0], DataLoaderDispatcher) + assert isinstance(loaded_dl, DataLoaderDispatcher) else: - assert isinstance(model_loaded._dataloaders[0], DataLoaderShard) - assert len(model_loaded._dataloaders[0]) == len(original_dl) - assert [i for i in model_loaded._dataloaders[0]] == [i for i in original_dl] + assert isinstance(loaded_dl, DataLoaderShard) + assert len(loaded_dl) == len(original_dl) + assert [i for i in loaded_dl] == [i for i in original_dl] # Test skip dataloader works as expected as well - skip_dl = skip_first_batches(dl, 2) - assert isinstance(skip_dl, DataLoader) - assert len(skip_dl) == len(original_dl) - 2 - orig_items = [i for i in original_dl] - skip_dl_items = [i for i in skip_dl] - assert orig_items[2:] == skip_dl_items + loaded_skip_dl = model_loaded._dataloaders[1] + print(model_loaded._dataloaders) + assert isinstance(loaded_skip_dl, DataLoader) + if dispatch_batches: + assert isinstance(loaded_dl, DataLoaderDispatcher) + else: + assert isinstance(loaded_dl, DataLoaderShard) + assert len(loaded_skip_dl) == len(original_dl) - 2 + assert [i for i in loaded_skip_dl] == [i for i in original_dl][2:] # Ideally would be a parameterized test which works with either stateful or non-stateful dataloaders, but dependencies are a bit awkward. 
@require_torchdata_stateful_dataloader From 99f2afe456a8f707c7fa7ab236731f02c00ca5da Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 3 Sep 2024 17:52:29 -0400 Subject: [PATCH 05/19] bugs --- .../test_utils/scripts/test_distributed_data_loop.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index c7823b1c5d2..1660adf13cd 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -340,7 +340,7 @@ def test_stateful_dataloader_save_state(accelerator): finally: accelerator.dataloader_config = old_dataloader_config -def test_pickled_dataloader(accelerator): +def test_pickled_dataloader(data_loader, accelerator): # Prepare the DataLoader data_loader = accelerator.prepare(data_loader) # Pickle then reload the dataloader @@ -412,7 +412,8 @@ def main(): # Dataloader after pickling loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) - test_pickled_dataloader(accelerator) + test_pickled_dataloader(loader, accelerator) + accelerator.end_training() From 3235e994097d6d57a4f9a00271ef5a3aed24d16b Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 3 Sep 2024 18:01:45 -0400 Subject: [PATCH 06/19] does this work with iterable dsets? --- src/accelerate/test_utils/scripts/test_distributed_data_loop.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index 1660adf13cd..b32eb52b435 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -411,7 +411,7 @@ def main(): test_stateful_dataloader_save_state(accelerator) # Dataloader after pickling - loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) + loader = DataLoader(DummyIterableDataset(), shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) test_pickled_dataloader(loader, accelerator) accelerator.end_training() From dff2666282ff28c9de8ae5a97f1fe5aae2eb324e Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 3 Sep 2024 18:13:04 -0400 Subject: [PATCH 07/19] fix typo --- .../test_utils/scripts/test_distributed_data_loop.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index b32eb52b435..8b306617468 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -411,7 +411,8 @@ def main(): test_stateful_dataloader_save_state(accelerator) # Dataloader after pickling - loader = DataLoader(DummyIterableDataset(), shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) + iterable_loader = create_dataloader(accelerator, dataset_size=NUM_ELEMENTS, batch_size=BATCH_SIZE, iterable=True) + loader = DataLoader(DummyIterableDataset(iterable_loader), shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) test_pickled_dataloader(loader, accelerator) accelerator.end_training() From 70929814749c85cf88dceb77ffa3b0ca9d4c8759 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 3 Sep 2024 18:14:24 -0400 Subject: [PATCH 08/19] ignore these commits, i'm just syncing the origin 
so i can test on my cloud workstation --- .../test_utils/scripts/test_distributed_data_loop.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index 8b306617468..aacc21a5dc5 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -412,8 +412,7 @@ def main(): # Dataloader after pickling iterable_loader = create_dataloader(accelerator, dataset_size=NUM_ELEMENTS, batch_size=BATCH_SIZE, iterable=True) - loader = DataLoader(DummyIterableDataset(iterable_loader), shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS) - test_pickled_dataloader(loader, accelerator) + test_pickled_dataloader(iterable_loader, accelerator) accelerator.end_training() From 3b0702b3ce5b552b9c7d77b8c44b8ec1e15ce611 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 3 Sep 2024 18:48:24 -0400 Subject: [PATCH 09/19] comment out failing tests, unsure if those are existing bugs or a recent regression --- .../test_utils/scripts/test_distributed_data_loop.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index aacc21a5dc5..19aa310c82d 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -411,8 +411,10 @@ def main(): test_stateful_dataloader_save_state(accelerator) # Dataloader after pickling - iterable_loader = create_dataloader(accelerator, dataset_size=NUM_ELEMENTS, batch_size=BATCH_SIZE, iterable=True) - test_pickled_dataloader(iterable_loader, accelerator) + # This test case currently fails. + + # iterable_loader = create_dataloader(accelerator, dataset_size=NUM_ELEMENTS, batch_size=BATCH_SIZE, iterable=True) + # test_pickled_dataloader(iterable_loader, accelerator) accelerator.end_training() From b43c85d61ed7d04c316750c29c5f61939ef80516 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 3 Sep 2024 19:57:31 -0400 Subject: [PATCH 10/19] torch 2.4.0? 
--- .../scripts/test_distributed_data_loop.py | 24 ++++--------------- tests/test_accelerator.py | 1 + 2 files changed, 5 insertions(+), 20 deletions(-) diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index 19aa310c82d..52be07fedbc 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -74,7 +74,7 @@ def __iter__(self): def create_accelerator(even_batches=True): dataloader_config = DataLoaderConfiguration(even_batches=even_batches) accelerator = Accelerator(dataloader_config=dataloader_config) - assert accelerator.num_processes == 2, "this script expects that two GPUs are available" + # assert accelerator.num_processes == 2, "this script expects that two GPUs are available" return accelerator @@ -346,20 +346,6 @@ def test_pickled_dataloader(data_loader, accelerator): # Pickle then reload the dataloader prepared_model_dumps = pickle.dumps(accelerator) loaded_accelerator = pickle.loads(prepared_model_dumps) - assert len(loaded_accelerator._dataloaders) == 1 - loaded_dataloader = loaded_accelerator._dataloaders[0] - all_examples = [] - for i, batch in enumerate(loaded_dataloader): - index, _ = accelerator.gather_for_metrics((batch["index"], batch["label"])) - all_examples.extend(index.detach().cpu().numpy().tolist()) - - # Sort the examples - sorted_all_examples = sorted(all_examples) - - # Check if all elements are present in the sorted list of iterated samples - assert ( - len(set(sorted_all_examples)) == NUM_ELEMENTS - ), "Not all the dataset elements have been iterated in an epoch due to duplication of samples across processes." def main(): accelerator = create_accelerator() @@ -410,11 +396,9 @@ def main(): test_stateful_dataloader(accelerator) test_stateful_dataloader_save_state(accelerator) - # Dataloader after pickling - # This test case currently fails. 
- - # iterable_loader = create_dataloader(accelerator, dataset_size=NUM_ELEMENTS, batch_size=BATCH_SIZE, iterable=True) - # test_pickled_dataloader(iterable_loader, accelerator) + # Test pickling an accelerator works + iterable_loader = create_dataloader(accelerator, dataset_size=NUM_ELEMENTS, batch_size=BATCH_SIZE, iterable=True) + test_pickled_dataloader(iterable_loader, accelerator) accelerator.end_training() diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index cc987644193..38d7bdc6058 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -662,6 +662,7 @@ def test_can_pickle_dataloader(self, dispatch_batches): # TODO: Add support for pickling StatefulDataLoader dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False) accelerator = Accelerator(dataloader_config=dataloader_config) + torch.manual_seed(accelerator.process_index) original_dl, prepared_skip_dl = accelerator.prepare(dl, skip_dl) prepared_model_dumps = pickle.dumps(accelerator) From 40ec962a62da8fa2d82bc20a1979f24c202b5895 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 3 Sep 2024 20:35:43 -0400 Subject: [PATCH 11/19] pickling generator issues --- .../test_utils/scripts/test_distributed_data_loop.py | 4 +--- tests/test_accelerator.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index 52be07fedbc..31338561249 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -74,7 +74,7 @@ def __iter__(self): def create_accelerator(even_batches=True): dataloader_config = DataLoaderConfiguration(even_batches=even_batches) accelerator = Accelerator(dataloader_config=dataloader_config) - # assert accelerator.num_processes == 2, "this script expects that two GPUs are available" + assert accelerator.num_processes == 2, "this script expects that two GPUs are available" return accelerator @@ -341,8 +341,6 @@ def test_stateful_dataloader_save_state(accelerator): accelerator.dataloader_config = old_dataloader_config def test_pickled_dataloader(data_loader, accelerator): - # Prepare the DataLoader - data_loader = accelerator.prepare(data_loader) # Pickle then reload the dataloader prepared_model_dumps = pickle.dumps(accelerator) loaded_accelerator = pickle.loads(prepared_model_dumps) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 38d7bdc6058..85dc0a7fa51 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -662,7 +662,7 @@ def test_can_pickle_dataloader(self, dispatch_batches): # TODO: Add support for pickling StatefulDataLoader dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False) accelerator = Accelerator(dataloader_config=dataloader_config) - torch.manual_seed(accelerator.process_index) + original_dl, prepared_skip_dl = accelerator.prepare(dl, skip_dl) prepared_model_dumps = pickle.dumps(accelerator) From 1964011ee4723e459f7c5d3fc125967db1d27b68 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Tue, 3 Sep 2024 20:51:31 -0400 Subject: [PATCH 12/19] test_pickle_accelerator --- .../scripts/test_distributed_data_loop.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py 
b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py
index 31338561249..b6aefe0b4fd 100644
--- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py
+++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py
@@ -247,6 +247,14 @@ def test_join_raises_warning_for_iterable_when_overriding_even_batches():
         assert issubclass(w[-1].category, UserWarning)
         assert "only supported for map-style datasets" in str(w[-1].message)
 
+def test_pickle_accelerator():
+    accelerator = create_accelerator()
+    data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4)
+    _ = accelerator.prepare(data_loader)
+    pickled_accelerator = pickle.dumps(accelerator)
+    unpickled_accelerator = pickle.loads(pickled_accelerator)
+    assert accelerator.state == unpickled_accelerator.state
+
 
 def test_data_loader(data_loader, accelerator):
     # Prepare the DataLoader
@@ -340,11 +348,6 @@ def test_stateful_dataloader_save_state(accelerator):
     finally:
         accelerator.dataloader_config = old_dataloader_config
 
-def test_pickled_dataloader(data_loader, accelerator):
-    # Pickle then reload the dataloader
-    prepared_model_dumps = pickle.dumps(accelerator)
-    loaded_accelerator = pickle.loads(prepared_model_dumps)
-
 def main():
     accelerator = create_accelerator()
     torch.manual_seed(accelerator.process_index)
@@ -373,6 +376,9 @@ def main():
         test_join_raises_warning_for_non_ddp_distributed(accelerator)
         accelerator.state.distributed_type = original_state
 
+    accelerator.print("Test pickling an accelerator")
+    test_pickle_accelerator()
+
     dataset = DummyDataset()
     # Conventional Dataloader with shuffle=False
     loader = DataLoader(dataset, shuffle=False, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
@@ -394,10 +400,6 @@ def main():
         test_stateful_dataloader(accelerator)
         test_stateful_dataloader_save_state(accelerator)
 
-    # Test pickling an accelerator works
-    iterable_loader = create_dataloader(accelerator, dataset_size=NUM_ELEMENTS, batch_size=BATCH_SIZE, iterable=True)
-    test_pickled_dataloader(iterable_loader, accelerator)
-
     accelerator.end_training()
 

From c6a68c3da052c3d022bcc754902f27086e11e4c1 Mon Sep 17 00:00:00 2001
From: byi8220
Date: Tue, 3 Sep 2024 20:59:25 -0400
Subject: [PATCH 13/19] test_pickle_accelerator should work now

---
 .../test_utils/scripts/test_distributed_data_loop.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py
index b6aefe0b4fd..dd49457be89 100644
--- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py
+++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py
@@ -253,7 +253,8 @@ def test_pickle_accelerator():
     _ = accelerator.prepare(data_loader)
     pickled_accelerator = pickle.dumps(accelerator)
     unpickled_accelerator = pickle.loads(pickled_accelerator)
-    assert accelerator.state == unpickled_accelerator.state
+    # TODO: Maybe this should be implemented as __eq__ for AcceleratorState?
+ assert accelerator.state.__dict__ == unpickled_accelerator.state.__dict__ def test_data_loader(data_loader, accelerator): From 0f0a1f135649c73f79618ca2790ef49de411ba57 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 4 Sep 2024 10:19:54 -0400 Subject: [PATCH 14/19] base.__len__() -> len(base) --- src/accelerate/data_loader.py | 31 +++---------------------------- tests/test_accelerator.py | 5 +++++ 2 files changed, 8 insertions(+), 28 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 3854dbf75bc..88e2b3d3620 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -572,20 +572,7 @@ def __iter__(self): self.end() def __reduce__(self): - return ( - DataLoaderShard, - ( - self.base_dataloader.dataset, - self.device, - self.rng_types, - self.synchronized_generator, - self.skip_batches, - self.use_stateful_dataloader, - self._drop_last, - self._non_blocking, - ), - self.__dict__, - ) + return super().__reduce__() def set_epoch(self, epoch: int): # In case it is manually passed in, the user can set it to what they like @@ -872,7 +859,7 @@ def set_epoch(self, epoch: int): self.dataset.set_epoch(epoch) def __len__(self): - whole_length = self.base_dataloader.__len__() + whole_length = len(self.base_dataloader) if self.split_batches: return whole_length elif self._drop_last: @@ -881,19 +868,7 @@ def __len__(self): return math.ceil(whole_length / self.state.num_processes) def __reduce__(self): - return ( - DataLoaderDispatcher, - ( - self.base_dataloader.dataset, - self.split_batches, - self.skip_batches, - self.use_stateful_dataloader, - self._drop_last, - self._non_blocking, - self.slice_fn, - ), - self.__dict__, - ) + return super().__reduce__() @property def total_batch_size(self): diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 85dc0a7fa51..97a8b52039c 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -664,6 +664,11 @@ def test_can_pickle_dataloader(self, dispatch_batches): accelerator = Accelerator(dataloader_config=dataloader_config) original_dl, prepared_skip_dl = accelerator.prepare(dl, skip_dl) + if dispatch_batches: + assert isinstance(original_dl, DataLoaderDispatcher) + else: + assert isinstance(original_dl, DataLoaderShard) + prepared_model_dumps = pickle.dumps(accelerator) model_loaded = pickle.loads(prepared_model_dumps) From 17713e7faa7de576c61943d3755dfa33f367221e Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 4 Sep 2024 10:21:45 -0400 Subject: [PATCH 15/19] undo reduce --- src/accelerate/data_loader.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 88e2b3d3620..843d594242d 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -572,7 +572,20 @@ def __iter__(self): self.end() def __reduce__(self): - return super().__reduce__() + return ( + DataLoaderShard, + ( + self.base_dataloader.dataset, + self.device, + self.rng_types, + self.synchronized_generator, + self.skip_batches, + self.use_stateful_dataloader, + self._drop_last, + self._non_blocking, + ), + self.__dict__, + ) def set_epoch(self, epoch: int): # In case it is manually passed in, the user can set it to what they like From 073d7b340d62f94eafaf50a7e98a7d793ab3eb02 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 4 Sep 2024 10:31:32 -0400 Subject: [PATCH 16/19] undo super().__reduce__() again --- src/accelerate/data_loader.py | 14 +++++++++++++- 1 file changed, 13 
insertions(+), 1 deletion(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 843d594242d..df5ee2b0875 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -881,7 +881,19 @@ def __len__(self): return math.ceil(whole_length / self.state.num_processes) def __reduce__(self): - return super().__reduce__() + return ( + DataLoaderDispatcher, + ( + self.base_dataloader.dataset, + self.split_batches, + self.skip_batches, + self.use_stateful_dataloader, + self._drop_last, + self._non_blocking, + self.slice_fn, + ), + self.__dict__, + ) @property def total_batch_size(self): From ee681ed23cae9c7a2db07c9e9b0f40b89d7ed972 Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 4 Sep 2024 11:39:49 -0400 Subject: [PATCH 17/19] pass args through superclass --- src/accelerate/data_loader.py | 40 +++++++---------------------------- 1 file changed, 8 insertions(+), 32 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index df5ee2b0875..6e33f074aa6 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -572,20 +572,9 @@ def __iter__(self): self.end() def __reduce__(self): - return ( - DataLoaderShard, - ( - self.base_dataloader.dataset, - self.device, - self.rng_types, - self.synchronized_generator, - self.skip_batches, - self.use_stateful_dataloader, - self._drop_last, - self._non_blocking, - ), - self.__dict__, - ) + args = super().__reduce__() + return (DataLoaderShard, *args[1:]) + def set_epoch(self, epoch: int): # In case it is manually passed in, the user can set it to what they like @@ -881,19 +870,8 @@ def __len__(self): return math.ceil(whole_length / self.state.num_processes) def __reduce__(self): - return ( - DataLoaderDispatcher, - ( - self.base_dataloader.dataset, - self.split_batches, - self.skip_batches, - self.use_stateful_dataloader, - self._drop_last, - self._non_blocking, - self.slice_fn, - ), - self.__dict__, - ) + args = super().__reduce__() + return (DataLoaderDispatcher, *args[1:]) @property def total_batch_size(self): @@ -1238,11 +1216,9 @@ def __len__(self): return len(self.base_dataloader) - self.skip_batches def __reduce__(self): - return ( - SkipDataLoader, - (self.base_dataloader.dataset, self.skip_batches, self.use_stateful_dataloader), - self.__dict__, - ) + args = super().__reduce__() + return (SkipDataLoader, *args[1:]) + def skip_first_batches(dataloader, num_batches=0): From 47557b7b56ce3c859f0b933ddb2cebcb73727bdb Mon Sep 17 00:00:00 2001 From: byi8220 Date: Wed, 4 Sep 2024 16:34:08 -0400 Subject: [PATCH 18/19] remove prints --- src/accelerate/data_loader.py | 1 - tests/test_accelerator.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py index 6e33f074aa6..97a99d54ea9 100644 --- a/src/accelerate/data_loader.py +++ b/src/accelerate/data_loader.py @@ -1212,7 +1212,6 @@ def __iter__(self): self.end() def __len__(self): - print("len called") return len(self.base_dataloader) - self.skip_batches def __reduce__(self): diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index 97a8b52039c..f83bfeafa17 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -686,7 +686,6 @@ def test_can_pickle_dataloader(self, dispatch_batches): # Test skip dataloader works as expected as well loaded_skip_dl = model_loaded._dataloaders[1] - print(model_loaded._dataloaders) assert isinstance(loaded_skip_dl, DataLoader) if dispatch_batches: assert isinstance(loaded_dl, 
DataLoaderDispatcher)

From 9bba8f2e7fb99788e747cce922ecd2808f7cfef0 Mon Sep 17 00:00:00 2001
From: byi8220
Date: Thu, 5 Sep 2024 11:06:49 -0400
Subject: [PATCH 19/19] doc changes + make style && make quality

---
 src/accelerate/data_loader.py                 | 23 +++++++++++++++----
 .../scripts/test_distributed_data_loop.py     |  2 ++
 tests/test_accelerator.py                     |  4 ++--
 3 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 97a99d54ea9..9275f315322 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -435,8 +435,9 @@ def load_state_dict(self, state_dict):
     @property
     def __class__(self):
         """
-        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)` returns true.
-        This is because some downstream code assumes that the `DataLoader` is the base class of the object.
+        In order to maintain backwards compatibility with other code, we need to ensure `isinstance(obj, DataLoader)`
+        returns true. This is because some downstream code assumes that the `DataLoader` is the base class of the
+        object.
         """
         return self.base_dataloader.__class__
 
@@ -479,6 +480,7 @@ def _update_state_dict(self):
         # Then tag if we are at the end of the dataloader
         self.dl_state_dict["_iterator_finished"] = self.end_of_dataloader
 
+
 class DataLoaderShard(DataLoaderAdapter, DataLoaderStateMixin):
     """
     Subclass of `DataLoaderAdapter` that will deal with device placement and current distributed setup.
@@ -572,10 +574,14 @@ def __iter__(self):
         self.end()
 
     def __reduce__(self):
+        """
+        Define the `__reduce__` method to ensure a `DataLoaderShard` can be pickled and unpickled. This needs to be
+        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
+        `__class__` member.
+        """
         args = super().__reduce__()
         return (DataLoaderShard, *args[1:])
 
-
     def set_epoch(self, epoch: int):
         # In case it is manually passed in, the user can set it to what they like
         if self.iteration != epoch:
@@ -870,6 +876,11 @@ def __len__(self):
         return math.ceil(whole_length / self.state.num_processes)
 
     def __reduce__(self):
+        """
+        Define the `__reduce__` method to ensure a `DataLoaderDispatcher` can be pickled and unpickled. This needs to
+        be explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
+        `__class__` member.
+        """
         args = super().__reduce__()
         return (DataLoaderDispatcher, *args[1:])
 
@@ -1215,11 +1226,15 @@ def __len__(self):
         return len(self.base_dataloader) - self.skip_batches
 
     def __reduce__(self):
+        """
+        Define the `__reduce__` method to ensure a `SkipDataLoader` can be pickled and unpickled. This needs to be
+        explicitly defined since default pickling behavior is broken by `DataLoaderAdapter` messing with its
+        `__class__` member.
+        """
         args = super().__reduce__()
         return (SkipDataLoader, *args[1:])
 
-
 def skip_first_batches(dataloader, num_batches=0):
     """
     Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
Should not be used if diff --git a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py index dd49457be89..899dc6e3f87 100644 --- a/src/accelerate/test_utils/scripts/test_distributed_data_loop.py +++ b/src/accelerate/test_utils/scripts/test_distributed_data_loop.py @@ -247,6 +247,7 @@ def test_join_raises_warning_for_iterable_when_overriding_even_batches(): assert issubclass(w[-1].category, UserWarning) assert "only supported for map-style datasets" in str(w[-1].message) + def test_pickle_accelerator(): accelerator = create_accelerator() data_loader = create_dataloader(accelerator, dataset_size=32, batch_size=4) @@ -349,6 +350,7 @@ def test_stateful_dataloader_save_state(accelerator): finally: accelerator.dataloader_config = old_dataloader_config + def main(): accelerator = create_accelerator() torch.manual_seed(accelerator.process_index) diff --git a/tests/test_accelerator.py b/tests/test_accelerator.py index f83bfeafa17..00c18506ced 100644 --- a/tests/test_accelerator.py +++ b/tests/test_accelerator.py @@ -27,7 +27,7 @@ from accelerate import DistributedType, infer_auto_device_map, init_empty_weights, load_checkpoint_and_dispatch from accelerate.accelerator import Accelerator -from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, SkipDataLoader, skip_first_batches +from accelerate.data_loader import DataLoaderDispatcher, DataLoaderShard, skip_first_batches from accelerate.state import GradientState, PartialState from accelerate.test_utils import ( require_bnb, @@ -663,7 +663,7 @@ def test_can_pickle_dataloader(self, dispatch_batches): dataloader_config = DataLoaderConfiguration(dispatch_batches=dispatch_batches, use_stateful_dataloader=False) accelerator = Accelerator(dataloader_config=dataloader_config) - original_dl, prepared_skip_dl = accelerator.prepare(dl, skip_dl) + original_dl, _ = accelerator.prepare(dl, skip_dl) if dispatch_batches: assert isinstance(original_dl, DataLoaderDispatcher) else:
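Taken together, the series lands on a compact pattern: a `__class__` property for `isinstance` transparency plus a `__reduce__` that delegates to `object.__reduce__` and swaps the reconstructing callable back to the wrapper subclass. A condensed, runnable sketch of that final shape, again using hypothetical `Adapter`/`Wrapped` names rather than the real `DataLoaderShard`/`DataLoader`:

```python
import pickle


class Wrapped:
    """Stands in for torch's DataLoader / StatefulDataLoader."""


class Adapter:
    def __init__(self, base=None, *args):
        # The extra positional args absorb the (cls, base, state) arguments
        # that the copyreg-style reduce tuple passes to the callable at load
        # time; the attributes that matter are restored from __dict__ after
        # construction.
        self.base = base if isinstance(base, Wrapped) else Wrapped()

    @property
    def __class__(self):
        # Same isinstance spoof as in patch 01.
        return self.base.__class__

    def __reduce__(self):
        # object.__reduce__ builds its tuple from the spoofed obj.__class__,
        # so left alone it would reconstruct a plain Wrapped. Keep everything
        # except the reconstructing callable, which is pointed back at the
        # wrapper, mirroring what patch 17 does for each subclass.
        args = super().__reduce__()
        return (Adapter, *args[1:])


restored = pickle.loads(pickle.dumps(Adapter(Wrapped())))
assert type(restored) is Adapter      # really an Adapter again...
assert isinstance(restored, Wrapped)  # ...that still masquerades as Wrapped
```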