diff --git a/flax/training/checkpoints.py b/flax/training/checkpoints.py index 747ba63431..3d48731305 100644 --- a/flax/training/checkpoints.py +++ b/flax/training/checkpoints.py @@ -76,7 +76,7 @@ # Orbax main checkpoint file name. ORBAX_CKPT_FILENAME = 'checkpoint' -ORBAX_MANIFEST_OCDBT = 'manifest.ocdbt' +ORBAX_METADATA_FILENAME = '_METADATA' PyTree = Any @@ -123,7 +123,7 @@ def _safe_remove(path: str): def _is_orbax_checkpoint(path: str) -> bool: return io.exists(os.path.join(path, ORBAX_CKPT_FILENAME)) or io.exists( - os.path.join(path, ORBAX_MANIFEST_OCDBT) + os.path.join(path, ORBAX_METADATA_FILENAME) )