Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

Commit

Permalink
[FIX] restore_checkpoint without mesh group (#854)
Browse files Browse the repository at this point in the history
  • Loading branch information
zw123han authored Jan 17, 2023
1 parent 2c5164f commit 589053f
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions alpa/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import numpy as np

from alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray,
get_global_virtual_physical_mesh)
get_global_virtual_physical_mesh,
get_global_physical_mesh)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -156,13 +157,20 @@ def restore_checkpoint(ckpt_dir: Union[str, os.PathLike], step: int,
flat_info = tree_leaves(placement_specs)
flat_load_state = []
mesh_group = get_global_virtual_physical_mesh().launched_physical_mesh_group
assert mesh_group is not None
physical_mesh = get_global_physical_mesh()

assert mesh_group is not None or physical_mesh is not None

for path, info in zip(state_paths, flat_info):
if info is None:
logger.warning("Variable is not used, skip loading it")
flat_load_state.append(None)
if len(info.mesh_ids) == 1:
elif mesh_group is None:
dist_arr = DistributedArray.load(os.path.join(ckpt_dir, path),
info.aval, physical_mesh,
info.sharding_specs[0])
flat_load_state.append(dist_arr)
elif len(info.mesh_ids) == 1:
dist_arr = DistributedArray.load(os.path.join(ckpt_dir,
path), info.aval,
mesh_group[info.mesh_ids[0]],
Expand Down

0 comments on commit 589053f

Please sign in to comment.