diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 39dfa8725b..60ef6c4651 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -45,6 +45,7 @@ from flax import config, core, errors, io, serialization, traverse_util from flax.training import orbax_utils + _READ_CHECKPOINT_EVENT: str = '/jax/checkpoint/read/durations_sec' _WRITE_CHECKPOINT_EVENT: str = '/jax/checkpoint/write/durations_sec' @@ -1136,6 +1137,7 @@ def restore_checkpoint( ) return restored + # Legacy Flax checkpoint restoration. ckpt_size = io.getsize(ckpt_path) with io.GFile(ckpt_path, 'rb') as fp: if parallel and fp.seekable():