diff --git a/alpa/serialization.py b/alpa/serialization.py index 79a57e96f..63dd5d4de 100644 --- a/alpa/serialization.py +++ b/alpa/serialization.py @@ -15,7 +15,8 @@ import numpy as np from alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray, - get_global_virtual_physical_mesh) + get_global_virtual_physical_mesh, + get_global_physical_mesh) logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) @@ -156,13 +157,20 @@ def restore_checkpoint(ckpt_dir: Union[str, os.PathLike], step: int, flat_info = tree_leaves(placement_specs) flat_load_state = [] mesh_group = get_global_virtual_physical_mesh().launched_physical_mesh_group - assert mesh_group is not None + physical_mesh = get_global_physical_mesh() + + assert mesh_group is not None or physical_mesh is not None for path, info in zip(state_paths, flat_info): if info is None: logger.warning("Variable is not used, skip loading it") flat_load_state.append(None) - if len(info.mesh_ids) == 1: + elif mesh_group is None: + dist_arr = DistributedArray.load(os.path.join(ckpt_dir, path), + info.aval, physical_mesh, + info.sharding_specs[0]) + flat_load_state.append(dist_arr) + elif len(info.mesh_ids) == 1: dist_arr = DistributedArray.load(os.path.join(ckpt_dir, path), info.aval, mesh_group[info.mesh_ids[0]],