Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add safe_globals to resume training on PyTorch 2.6 #34632

Merged
merged 1 commit into from
Nov 25, 2024

Conversation

dvrogozh
Copy link
Contributor

@dvrogozh dvrogozh commented Nov 6, 2024

Starting from version 2.4 PyTorch introduces a stricter check for the objects which can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True requires allowlisting of such objects.

This commit adds allowlist of some numpy objects used to load model checkpoints. Usage is restricted by context manager. User can still call torch.serialization.add_safe_globals() to add other objects into the safe globals list.

Accelerate library also stepped into same problem and addressed it with PR-3036.

Fixes: #34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
See: huggingface/accelerate#3036

CC: @muellerzr @SunMarc

src/transformers/trainer.py Outdated Show resolved Hide resolved
@dvrogozh
Copy link
Contributor Author

@muellerzr, @SunMarc, @ArthurZucker : can you, please, help comment on this PR? see issue #34631 on details.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice ! Thanks for adding this ! Left a comment

src/transformers/trainer.py Outdated Show resolved Hide resolved
@ydshieh
Copy link
Collaborator

ydshieh commented Nov 15, 2024

I am getting

FAILED tests/trainer/test_trainer.py::TrainerIntegrationTest::test_can_resume_training - AttributeError: module 'numpy' has no attribute 'dtypes'. Did you mean: 'dtype'?

when running

python3 -m pytest tests/trainer/test_trainer.py::TrainerIntegrationTest::test_can_resume_training

against this PR.

@dvrogozh
Copy link
Contributor Author

@ydshieh : this might be due to numpy version. dtypes was added in 1.25 according to https://numpy.org/doc/2.1/reference/routines.dtypes.html#module-numpy.dtypes. Locally I have 1.26.4. Which version do you have?

I will work on using context manager since there is an alignment on that and also tune a list per versioning of numpy.

@ydshieh
Copy link
Collaborator

ydshieh commented Nov 15, 2024

On our CI runner , I get numpy=1.24.3

@mikaylagawarecki
Copy link

The numpy GLOBALs for dtypes that need to be allowlisted might need an if statement depending on whether version < 1.25 or not, there's some documentation on this here https://pytorch.org/docs/main/notes/serialization.html#troubleshooting-weights-only

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @muellerzr if you can have a look as well!

# See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
if version.parse(torch.__version__) >= version.parse("2.4.0"):
torch.serialization.add_safe_globals(
[np.core.multiarray._reconstruct, np.ndarray, np.dtype, np.dtypes.UInt32DType]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could have a SAFE_TRANSFORMERS_GLOBAL with these no? this way people can easily update them?
TBH I prefer the context manager but want to have the least duplication as possible!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found that calling torch.serialization.add_safe_globals() still works to add additional safe global staff. SAFE_TRANSFORMERS_GLOBAL can also be considered. Let me know if you see the need.

allowlist = [np.core.multiarray._reconstruct, np.ndarray, np.dtype]
# numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
# all versions of numpy
allowlist += [type(np.dtype(np.uint32))]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I add any other numpy dtypes in the list? As of now I spotted only np.unit32 in the Transformers list as the one needed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only one I don't see from accelerate is encode, however if things pass here without it it's accelerate specific and we don't need to worry about it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Transformer tests did pass on my side without adding encode. This indeed seems accelerate specific.

Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Just a documentation suggestion but this all looks correct

src/transformers/trainer.py Show resolved Hide resolved
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@dvrogozh
Copy link
Contributor Author

Thanks! Just a documentation suggestion but this all looks correct

@muellerz : done, added a link to Accelerate PR.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM ! Just a nit

src/transformers/trainer.py Outdated Show resolved Hide resolved
@dvrogozh
Copy link
Contributor Author

LGTM ! Just a nit

@SunMarc : addressed, reused approach from accelerate on numpy.core deprecation.

# See: https://github.com/pytorch/pytorch/pull/137602
# See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
# See: https://github.com/huggingface/accelerate/pull/3036
if version.parse(torch.__version__) < version.parse("2.4.0"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a nit: should it be "2.6.0" here or it's really necessary being "2.4.0"?

Copy link
Contributor Author

@dvrogozh dvrogozh Nov 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switched to version < 2.6.0a0. Indeed, on switching to context manager I overlooked that it was introduced later. Overall:

  • torch.serialization.add_safe_globals appeared in pytorch 2.4
  • torch.serialization.safe_globals (context manager) appeared in 2.5
  • And pytorch 2.6 flipped default of weights_only in torch.load from False to True

Overall, it indeed does not make sense to have this code working for versions earlier than 2.6 unless we will start calling torch.load with explicit weights_only=True.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! A tiny question: how to get 2.6.0a0 installed. I know how to install night but it gets dev202411xx instead of a0

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Anyway, good to use a0 here for now. Once 2.6 is released, we can change it to 2.6.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi! A tiny question: how to get 2.6.0a0 installed.

I am getting this building from sources. And <2.6.0 does not work for me on my build. So, 2.6.0a0 is my best effort to get the check working for my current build. I did not know that nightly builds get dev202411xx, I thought they also give a0. I wonder will the check still work for nightly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked. <2.6.0a0 won't work with nightly. So, I switched to a check I ones spotted in a code by Narsil. This should handle both cases, building from sources and using 2.6 nightly (I checked - works for both on my side):

if version.parse(torch.__version__).release < version.parse("2.6").release:

Copy link
Collaborator

@ydshieh ydshieh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks

Starting from version 2.4 PyTorch introduces a stricter check for the objects which
can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True
requires allowlisting of such objects.

This commit adds allowlist of some numpy objects used to load model checkpoints.
Usage is restricted by context manager. User can still additionally call
torch.serialization.add_safe_globals() to add other objects into the safe globals list.

Accelerate library also stepped into same problem and addressed it with PR-3036.

Fixes: huggingface#34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
See: huggingface/accelerate#3036
Signed-off-by: Dmitry Rogozhkin <[email protected]>
@ArthurZucker ArthurZucker merged commit 1339a14 into huggingface:main Nov 25, 2024
24 checks passed
@ArthurZucker
Copy link
Collaborator

Thanks for fixing 🤗

BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
Starting from version 2.4 PyTorch introduces a stricter check for the objects which
can be loaded with torch.load(). Starting from version 2.6 loading with weights_only=True
requires allowlisting of such objects.

This commit adds allowlist of some numpy objects used to load model checkpoints.
Usage is restricted by context manager. User can still additionally call
torch.serialization.add_safe_globals() to add other objects into the safe globals list.

Accelerate library also stepped into same problem and addressed it with PR-3036.

Fixes: huggingface#34631
See: pytorch/pytorch#137602
See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
See: huggingface/accelerate#3036

Signed-off-by: Dmitry Rogozhkin <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

safe_globals are needed to resume training on upcoming PyTorch 2.6
7 participants