
Allow DataLoaderAdapter subclasses to be pickled by implementing __reduce__ #3074

Merged: 20 commits, Sep 5, 2024

Conversation

@byi8220 (Contributor) commented Sep 3, 2024

What does this PR do?

Fixes #3070

This PR contains the following:

  1. Clean up the dynamic class resolution of DataLoaderAdapter a bit. Instead of performing dynamic class overriding in __init__, we override the __class__ property to pass through the base dataloader's class.

    Note: This means isinstance(obj, DataLoaderAdapter) breaks. However, since DataLoaderAdapter was introduced very recently, it's very unlikely this will affect anything; there shouldn't be any reason for existing downstream code to check it.

  2. Implement __reduce__ in DataLoaderAdapter subclasses to allow pickling of these classes despite the __class__ override. The implementation simply passes the object's init arguments and __dict__ back into the constructor (see the sketch after this list).

  3. Some cleanup in DataLoaderDispatcher (use base_dataloader instead of super()), and some cleanup in SkipDataLoader (implement __len__).

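For readers unfamiliar with __reduce__, here is a minimal, self-contained sketch of the pattern described in point 2. SkippingLoader and its skip_batches argument are hypothetical stand-ins used only for illustration; the actual classes and signatures live in src/accelerate/data_loader.py.

import pickle

import torch
from torch.utils.data import DataLoader, TensorDataset


class SkippingLoader(DataLoader):
    def __init__(self, dataset, skip_batches=0, **kwargs):
        super().__init__(dataset, **kwargs)
        self.skip_batches = skip_batches

    def __reduce__(self):
        # Rebuild the object by calling the subclass constructor with its
        # original init arguments; the third tuple element is the state that
        # pickle applies to __dict__ after construction.
        return (
            SkippingLoader,
            (self.dataset, self.skip_batches),
            self.__dict__,
        )


loader = SkippingLoader(TensorDataset(torch.arange(8)), skip_batches=1, batch_size=2)
restored = pickle.loads(pickle.dumps(loader))
assert restored.skip_batches == 1
assert restored.batch_size == 2  # restored from __dict__
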
Some caveats:

  1. StatefulDataLoader's state_dict might not be pickleable.
  2. Pickling in general seems dubious and fragile.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@muellerzr
@BenjaminBossan

@BenjaminBossan (Member) left a comment:

Thanks for working on this issue so quickly. I agree that pickling the data loader (even if indirectly, via the accelerator) is not generally a good idea. I'll add a note to the skorch docs that this should be avoided. Still, I think it's better to fix this, as we don't know what other code may depend on it, even if only indirectly (e.g. pickle is often used under the hood to pass data between processes).

As to the solution, I made some comments which I think will make the code more robust. Regarding the use of @property for __class__, I really feel like we're dancing on a razor's edge here (the limits of the Python data model), though this is probably no worse than what we had before.

Comment on lines 577 to 587
(
self.base_dataloader.dataset,
self.device,
self.rng_types,
self.synchronized_generator,
self.skip_batches,
self.use_stateful_dataloader,
self._drop_last,
self._non_blocking,
),
self.__dict__,
@BenjaminBossan (Member):

Instead of hard-coding these arguments, how about calling super().__reduce__()?

@byi8220 (Contributor, Author):

I don't know if that works. IIUC, __reduce__ is called to specify how we want to pickle an object for reconstruction. In this context, if I'm understanding right, the current code states "We want to pickle a DataLoaderShard by constructing a DataLoaderShard with these arguments and filling its __dict__ with my elements."

If we replace this with return super().__reduce__(), doesn't that change it to "We want to pickle a DataLoaderShard by turning it into a DataLoader"? In that case, the following would happen:

## Behavior of code using super().__reduce__():

# accelerator prepares a dataloader shard from a dataloader
dl_shard = original_accel.prepare(dl)
assert isinstance(dl_shard, DataLoaderShard)
assert isinstance(original_accel._dataloaders[0], DataLoaderShard)

# Try to pickle and unpickle the dataloader
loaded_accel = pickle.loads(pickle.dumps(original_accel))

# The unpickled accelerator's underlying dataloaders are no longer sharded
assert isinstance(loaded_accel._dataloaders[0], DataLoader)

Is this the intended behavior?

@BenjaminBossan (Member):

Sorry, what I meant is something like:

args = super().__reduce__()
return (DataLoaderShard, *args[1:])

That way, the 2nd and 3rd elements of the tuple (which are roughly the arguments needed to construct the instance) don't need to be hard-coded, only the constructor class. If we hard-code the arguments, we need to remember to update them each time they change on the class.

@byi8220 (Contributor, Author):

Done. The above block doesn't include the object's __dict__, but I'm not sure whether we need to add that or not. It passes my existing tests as written.

src/accelerate/data_loader.py (outdated review comment, resolved)
Comment on lines 886 to 895
(
self.base_dataloader.dataset,
self.split_batches,
self.skip_batches,
self.use_stateful_dataloader,
self._drop_last,
self._non_blocking,
self.slice_fn,
),
self.__dict__,
@BenjaminBossan (Member):

Same comment as above.

Comment on lines 1243 to 1244
(self.base_dataloader.dataset, self.skip_batches, self.use_stateful_dataloader),
self.__dict__,
@BenjaminBossan (Member):

Same comment as above.

@byi8220 (Contributor, Author) commented Sep 4, 2024

Regarding the use of @property for __class__, I really feel like we're dancing on a razor's edge here (the limits of the Python data model)

We really are. This idea comes up a few times if you look for it online, and unfortunately, given the nature of the problem, it is the absolute last resort: https://stackoverflow.com/a/52172876. The problem is that the more sensible solutions, such as using a metaclass or tampering with __instancecheck__, require us to have control over the base class.

I will reiterate: the only reason for this magic code to exist is to make sure that isinstance(dl_shard, DataLoader) works. If this requirement is lifted, we don't need to mess with the class structure at all.
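
For concreteness, a minimal, self-contained sketch of the trick described in that StackOverflow answer; Adapter here is a hypothetical stand-in, not accelerate's actual DataLoaderAdapter.

import torch
from torch.utils.data import DataLoader, TensorDataset


class Adapter:
    """Wraps a DataLoader and reports the wrapped loader's class as its own."""

    def __init__(self, dataset, **kwargs):
        self.base_dataloader = DataLoader(dataset, **kwargs)

    @property
    def __class__(self):
        # isinstance() falls back to obj.__class__ when the plain type check
        # fails, so this makes isinstance(adapter, DataLoader) return True even
        # though Adapter does not inherit from DataLoader.
        return self.base_dataloader.__class__

    def __iter__(self):
        return iter(self.base_dataloader)


adapter = Adapter(TensorDataset(torch.arange(8)), batch_size=2)
assert isinstance(adapter, DataLoader)  # True, thanks to the __class__ pass-through

The flip side is that pickle also consults __class__ while serializing, which is why the explicit __reduce__ implementations in this PR are needed.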

@BenjaminBossan (Member):

Yeah, I know there is really no good way to get to the goal, so we need to pick our poison. I just needed to express my desperation at the situation ;-)

@muellerzr (Collaborator) left a comment:

Thanks for fixing this! Looks like there are some leftover print statements; I'll also run it locally to make sure there's nothing else lingering or unexpected.

src/accelerate/data_loader.py (outdated review comment, resolved)
tests/test_accelerator.py (outdated review comment, resolved)
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

byi8220 marked this pull request as ready for review on September 4, 2024 at 20:34
@BenjaminBossan (Member) left a comment:

LGTM, thanks for quickly taking up this task and finding a solution. 🤞 that this will avoid further trouble down the line.

A nit based on my personal preference: I like to add a comment to non-obvious code like the __reduce__ method to indicate why it's needed. This way, the reader doesn't have to git blame and search through GitHub to understand the code.

@muellerzr (Collaborator):

@byi8220 can you do a quick make style; make quality? Then we can include it in the patchfix coming out soon :)

@byi8220 (Contributor, Author) commented Sep 5, 2024

Added docstrings for the __reduce__ methods and ran make style && make quality.

muellerzr merged commit f1ca8ac into huggingface:main on Sep 5, 2024
25 checks passed
muellerzr pushed a commit that referenced this pull request Sep 5, 2024
…educe__` (#3074)

* initial fix for breaking accelerator pickling

* cleanup

* skip_first_batches should be used on raw dls

* multigpu sanity test

* bugs

* does this work with iterable dsets?

* fix typo

* ignore these commits, i'm just syncing the origin so i can test on my cloud workstation

* comment out failing tests, unsure if those are existing bugs or a recent regression

* torch 2.4.0?

* pickling generator issues

* test_pickle_accelerator

* test_pickle_accelerator should work now)

* base.__len__() -> len(base)

* undo reduce

* undo super().__reduce__() again

* pass args through superclass

* remove prints

* doc changes + make style && make quality

Successfully merging this pull request may close these issues.

Pickling the accelerator after preparing data loader no longer possible
4 participants