Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 681197359
  • Loading branch information
IvyZX authored and Flax Authors committed Oct 1, 2024
1 parent 9c162ab commit b722f4d
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from flax import config, core, errors, io, serialization, traverse_util
from flax.training import orbax_utils


_READ_CHECKPOINT_EVENT: str = '/jax/checkpoint/read/durations_sec'
_WRITE_CHECKPOINT_EVENT: str = '/jax/checkpoint/write/durations_sec'

Expand Down Expand Up @@ -1136,6 +1137,7 @@ def restore_checkpoint(
)
return restored

# Legacy Flax checkpoint restoration.
ckpt_size = io.getsize(ckpt_path)
with io.GFile(ckpt_path, 'rb') as fp:
if parallel and fp.seekable():
Expand Down

0 comments on commit b722f4d

Please sign in to comment.