Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add safe_globals to resume training on PyTorch 2.6 #34632

Merged
merged 1 commit into from
Nov 25, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,25 @@ def _get_fsdp_ckpt_kwargs():
return {}


def safe_globals():
# Starting from version 2.4 PyTorch introduces a check for the objects loaded
# with torch.load(weights_only=True). Starting from 2.6 weights_only=True becomes
# a default and requires allowlisting of objects being loaded.
# See: https://github.com/pytorch/pytorch/pull/137602
# See: https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals
dvrogozh marked this conversation as resolved.
Show resolved Hide resolved
# See: https://github.com/huggingface/accelerate/pull/3036
if version.parse(torch.__version__).release < version.parse("2.6").release:
return contextlib.nullcontext()

np_core = np._core if version.parse(np.__version__) >= version.parse("2.0.0") else np.core
allowlist = [np_core.multiarray._reconstruct, np.ndarray, np.dtype]
# numpy >1.25 defines numpy.dtypes.UInt32DType, but below works for
# all versions of numpy
allowlist += [type(np.dtype(np.uint32))]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I add any other numpy dtypes in the list? As of now I spotted only np.unit32 in the Transformers list as the one needed.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only one I don't see from accelerate is encode, however if things pass here without it it's accelerate specific and we don't need to worry about it

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Transformer tests did pass on my side without adding encode. This indeed seems accelerate specific.


return torch.serialization.safe_globals(allowlist)


if TYPE_CHECKING:
import optuna

Expand Down Expand Up @@ -3052,7 +3071,8 @@ def _load_rng_state(self, checkpoint):
)
return

checkpoint_rng_state = torch.load(rng_file)
with safe_globals():
checkpoint_rng_state = torch.load(rng_file)
random.setstate(checkpoint_rng_state["python"])
np.random.set_state(checkpoint_rng_state["numpy"])
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
Expand Down
Loading