Allow `DataLoaderAdapter` subclasses to be pickled by implementing `__reduce__` (#3074)
Conversation
Thanks for working on this issue so quickly. I agree that pickling the data loader (even if indirectly, via the accelerator) is generally not a good idea. I'll add a note to the skorch docs that this should be avoided. Still, I think it's better to fix this, as we don't know what other code may depend on it, even if only indirectly (e.g. pickle is often used under the hood to pass data between processes).

As to the solution, I made some comments which I think will make the code more robust. Regarding the use of `@property` for `__class__`, I really feel like we're dancing on a razor's edge here (at the limits of the Python data model), though this is probably no worse than what we had before.
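For context, here is a minimal sketch of the pattern under discussion (toy `Adapter`/`Base` classes, not the actual accelerate implementation): defining `__class__` as a property makes attribute lookup find it before `object`'s built-in descriptor, so the adapter masquerades as the wrapped class — and this is exactly what trips up pickle's default machinery.

```python
import pickle

class Base:
    pass

class Adapter:
    """Toy stand-in for DataLoaderAdapter; not the actual implementation."""

    def __init__(self, wrapped):
        self.wrapped = wrapped

    @property
    def __class__(self):
        # Attribute lookup finds this property before object's __class__
        # descriptor, so the adapter reports the wrapped object's class.
        return self.wrapped.__class__

adapter = Adapter(Base())
print(type(adapter))              # <class '__main__.Adapter'>: type() is unaffected
print(adapter.__class__)          # <class '__main__.Base'>: attribute lookup is intercepted
print(isinstance(adapter, Base))  # True, via the __class__ attribute

# The pickling problem this PR addresses: pickle's default reduction uses the
# real type, then cross-checks it against obj.__class__; the mismatch raises
# a PicklingError, hence the explicit __reduce__ implementations below.
try:
    pickle.dumps(adapter)
except pickle.PicklingError as exc:
    print(exc)  # e.g. "args[0] from __newobj__ args has the wrong class"
```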
src/accelerate/data_loader.py (Outdated)

```python
(
    self.base_dataloader.dataset,
    self.device,
    self.rng_types,
    self.synchronized_generator,
    self.skip_batches,
    self.use_stateful_dataloader,
    self._drop_last,
    self._non_blocking,
),
self.__dict__,
```
Instead of hard-coding these arguments, how about calling `super().__reduce__()`?
I don't know if that works. IIUC, `__reduce__` is called to specify how we want to pickle an object for reconstruction. In this context, if I'm understanding right, the current code states "we want to pickle a DataLoaderShard by constructing a DataLoaderShard with these arguments and filling its `__dict__` with my elements."

If we replace this with `return super().__reduce__()`, doesn't that change the meaning to "we want to pickle a DataLoaderShard by turning it into a DataLoader"? In that case, the following would happen:
```python
## Behavior of code using super().__reduce__():

# accelerator prepares a dataloader shard from a dataloader
dl_shard = original_accel.prepare(dl)
assert isinstance(dl_shard, DataLoaderShard)
assert isinstance(original_accel._dataloaders[0], DataLoaderShard)

# Try to pickle and unpickle the dataloader
loaded_accel = pickle.loads(pickle.dumps(original_accel))

# The unpickled accelerator's underlying dataloaders are no longer sharded
assert isinstance(loaded_accel._dataloaders[0], DataLoader)
```
Is this the intended behavior?
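As general background (not from the PR itself), the `__reduce__` protocol works like this: pickle records the returned tuple, and unpickling calls the first element with the second element as arguments, then applies the optional third element as the new object's state. A self-contained illustration with a hypothetical `Point` class:

```python
import pickle

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

    def __reduce__(self):
        # (callable, args_for_callable, state_applied_to___dict__)
        return (Point, (self.x, self.y), self.__dict__)

p = Point(1, 2)
p.label = "set after __init__"     # state not covered by the constructor
q = pickle.loads(pickle.dumps(p))  # calls Point(1, 2), then updates __dict__
assert (q.x, q.y, q.label) == (1, 2, "set after __init__")
```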
Sorry, what I meant is something like:

```python
args = super().__reduce__()
return (DataLoaderShard, *args[1:])
```

That way, the second and third elements (which are roughly the arguments needed to construct the instance) don't need to be hard-coded, only the constructor class. If we hard-code the arguments, we need to remember to update them each time they change on the class.
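A toy version of that suggestion (hypothetical `Parent`/`Child` classes standing in for `DataLoaderAdapter`/`DataLoaderShard`): reuse the tuple the parent builds, but pin the constructor to the subclass so the unpickled object comes back with the right type.

```python
import pickle

class Parent:
    def __init__(self, a, b):
        self.a, self.b = a, b

    def __reduce__(self):
        # The class is pinned explicitly rather than via self.__class__,
        # which in the real code is made unreliable by the __class__ property.
        return (Parent, (self.a, self.b))

class Child(Parent):
    def __reduce__(self):
        # Keep the parent's argument tuple, swap in the subclass as the
        # constructor -- no arguments are hard-coded in the subclass.
        args = super().__reduce__()
        return (Child, *args[1:])

restored = pickle.loads(pickle.dumps(Child(1, 2)))
assert type(restored) is Child and (restored.a, restored.b) == (1, 2)
```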
Done. The above block doesn't include the object's `__dict__`, but I'm not sure whether we need to add that or not. It passes my existing tests as written.
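As a general note on that open question (protocol behavior, not PR-specific): with a two-element tuple, pickle restores only what the constructor sets, so attributes mutated after `__init__` are silently dropped; adding `self.__dict__` as a third element preserves them. Whether that matters here depends on whether these dataloaders carry post-construction state that the tests exercise. A toy illustration:

```python
import pickle

class Box:
    def __init__(self, value):
        self.value = value

    def __reduce__(self):
        return (Box, (self.value,))  # two elements: no state is carried over

b = Box(1)
b.extra = "set after __init__"
restored = pickle.loads(pickle.dumps(b))
assert restored.value == 1
assert not hasattr(restored, "extra")  # post-construction state was lost

# Returning (Box, (self.value,), self.__dict__) instead would also restore
# `extra`, because pickle applies the third element to the new __dict__.
```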
src/accelerate/data_loader.py (Outdated)

```python
(
    self.base_dataloader.dataset,
    self.split_batches,
    self.skip_batches,
    self.use_stateful_dataloader,
    self._drop_last,
    self._non_blocking,
    self.slice_fn,
),
self.__dict__,
```
Same comment as above.
src/accelerate/data_loader.py (Outdated)

```python
(self.base_dataloader.dataset, self.skip_batches, self.use_stateful_dataloader),
self.__dict__,
```
Same comment as above.
We really are. This idea comes up a few times if you look for it online, and unfortunately, given the nature of the problem, this is the absolute last resort: https://stackoverflow.com/a/52172876. The problem is that the more sensible solutions, such as using a metaclass or tampering with […]. I will reiterate: the only reason for this magic code to exist is to make sure that […]
Yeah, I know there is really no good way to get to the goal, so we need to pick our poison. I just needed to express my desperation at the situation ;-)
Thanks for fixing this! It looks like there are some leftover print statements; I'll also run it locally to make sure there's nothing else lingering or unexpected.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
LGTM, thanks for quickly taking up this task and finding a solution. 🤞 that this will avoid further trouble down the line.

A nit based on my personal preference: I like to add a comment to non-obvious code like the `__reduce__` method to indicate why it's needed. That way, the reader doesn't have to git blame and search through GitHub to understand the code.
@byi8220 can you do a quick […]
Added docstrings for […]
Merged via squash commit: Allow `DataLoaderAdapter` subclasses to be pickled by implementing `__reduce__` (#3074)

* initial fix for breaking accelerator pickling
* cleanup
* skip_first_batches should be used on raw dls
* multigpu sanity test
* bugs
* does this work with iterable dsets?
* fix typo
* ignore these commits, i'm just syncing the origin so i can test on my cloud workstation
* comment out failing tests, unsure if those are existing bugs or a recent regression
* torch 2.4.0?
* pickling generator issues
* test_pickle_accelerator
* test_pickle_accelerator should work now
* base.__len__() -> len(base)
* undo reduce
* undo super().__reduce__() again
* pass args through superclass
* remove prints
* doc changes + make style && make quality
What does this PR do?
Fixes #3070
This PR contains the following:
- Clean up the dynamic class resolution of `DataLoaderAdapter` a bit. Instead of performing dynamic class overriding in `__init__`, we now override `__class__` as a property that passes through the base dataloader's class. Note: this means `isinstance(obj, DataLoaderAdapter)` breaks. However, since `DataLoaderAdapter` was introduced very recently, it's very unlikely this will break anything, as there shouldn't be any reason for existing downstream code to check it.
- Implement `__reduce__` in the `DataLoaderAdapter` subclasses to allow pickling of these classes despite the `__class__` override. The implementation simply passes the object's init arguments and `__dict__` back into the constructor (a round-trip sketch follows below).
- Some cleanup in `DataLoaderDispatcher` (use `base_dataloader` instead of `super()`), and some cleanup in `SkipDataLoader` (implement `__len__`).

Some caveats: […]
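To make the pickling fix concrete, here is a hypothetical round-trip check in the spirit of the `test_pickle_accelerator` commit mentioned above (the dataset and assertions are illustrative assumptions, not the PR's actual test):

```python
import pickle

import torch
from torch.utils.data import DataLoader, TensorDataset

from accelerate import Accelerator

accelerator = Accelerator()
dataloader = DataLoader(TensorDataset(torch.arange(16.0)), batch_size=4)
prepared = accelerator.prepare(dataloader)

# Before this PR, pickling the accelerator (and with it the prepared
# dataloader) failed because of the __class__ override.
restored = pickle.loads(pickle.dumps(accelerator))
assert len(restored._dataloaders) == len(accelerator._dataloaders)
```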
Before submitting

- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@muellerzr
@BenjaminBossan