Skip to content
This repository has been archived by the owner on Oct 19, 2024. It is now read-only.

[FIX] restore_checkpoint without mesh group #854

Merged
merged 6 commits into from
Jan 17, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions alpa/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
import numpy as np

from alpa.device_mesh import (DistributedArray, ReplicatedDistributedArray,
get_global_virtual_physical_mesh)
get_global_virtual_physical_mesh,
merrymercy marked this conversation as resolved.
Show resolved Hide resolved
get_global_physical_mesh)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
Expand Down Expand Up @@ -156,13 +157,21 @@ def restore_checkpoint(ckpt_dir: Union[str, os.PathLike], step: int,
flat_info = tree_leaves(placement_specs)
flat_load_state = []
mesh_group = get_global_virtual_physical_mesh().launched_physical_mesh_group
assert mesh_group is not None
physical_mesh = get_global_physical_mesh()

assert mesh_group is not None or physical_mesh is not None

for path, info in zip(state_paths, flat_info):
if info is None:
logger.warning("Variable is not used, skip loading it")
flat_load_state.append(None)
if len(info.mesh_ids) == 1:
elif mesh_group is None:
dist_arr = DistributedArray.load(os.path.join(ckpt_dir,
path), info.aval,
merrymercy marked this conversation as resolved.
Show resolved Hide resolved
physical_mesh,
info.sharding_specs[0])
flat_load_state.append(dist_arr)
elif len(info.mesh_ids) == 1:
dist_arr = DistributedArray.load(os.path.join(ckpt_dir,
path), info.aval,
mesh_group[info.mesh_ids[0]],
Expand Down