From 541bb616e1653f6b1eab78c5b9eb1496dc838534 Mon Sep 17 00:00:00 2001
From: jon-chuang
Date: Sun, 9 Apr 2023 07:14:08 +0800
Subject: [PATCH 1/2] draft

---
 .../cross_mesh_resharding.py | 61 ++++++++++++++-----
 1 file changed, 47 insertions(+), 14 deletions(-)

diff --git a/alpa/pipeline_parallel/cross_mesh_resharding.py b/alpa/pipeline_parallel/cross_mesh_resharding.py
index 388ba8847..d59b94b7b 100644
--- a/alpa/pipeline_parallel/cross_mesh_resharding.py
+++ b/alpa/pipeline_parallel/cross_mesh_resharding.py
@@ -5,8 +5,10 @@
 import math
 import random
 import time
-from typing import List, Any
+from typing import Dict, List, Any, Sequence
+from alpa.pipeline_parallel.schedules import PipelineSchedule
+from jax.core import Var
 
 from jax.interpreters import pxla
 import numpy as np
 import ray
@@ -682,7 +684,11 @@ class ReshardingTaskSpec:
         VirtualDistributedArray.
     """
 
-    def __init__(self, src_array, dst_array, final_dst_spec):
+    def __init__(self,
+                 src_array: VirtualDistributedArray,
+                 dst_array: VirtualDistributedArray,
+                 final_dst_spec
+                 ):
         self.src = src_array
         self.dst = dst_array
         self._dst_tile_to_src_tiles_map = None
@@ -949,7 +955,11 @@ class CrossMeshCommunicator:
         schedule (Any): the pipelining schedule for these stages.
     """
 
-    def __init__(self, sharded_stages, schedule):
+    def __init__(
+        self,
+        sharded_stages: Sequence[XlaShardedPipelineComputation],
+        schedule: PipelineSchedule
+    ):
         if not isinstance(sharded_stages, list):
             raise RuntimeError("Require a list of stages.")
         for s in sharded_stages:
@@ -1091,6 +1101,9 @@ def _create_resharding_specs(self):
             [{} for _ in range(self.num_mesh)]
             for _ in range(self.num_mesh)
         ]
+        # Map each var to the index of the stage that last consumed it, so
+        # later transfers can be sourced from that consumer's mesh.
+        last_seen: Dict[Var, int] = {}
         # find stages that will communicate
        pairs = np.argwhere(deps > 0)
         for i in range(pairs.shape[0]):
@@ -1116,29 +1129,46 @@ def _create_resharding_specs(self):
             out_sharding_specs = src_stage.output_sharding_specs
             in_sharding_specs = dst_stage.input_sharding_specs
 
-            # Make a ReshardSpec for each VirtualDistributedArray
+            # Make a ReshardingTaskSpec for each VirtualDistributedArray
             for var, out_var_index, in_var_index in zip(resharding_vars,
                                                         out_var_indices,
                                                         in_var_indices):
-                src_sharding_spec = out_sharding_specs[out_var_index]
+                if var in last_seen:
+                    last_seen_stage_index = last_seen[var]
+                    last_seen[var] = dst_stage_index
+
+                    last_seen_stage = stages[last_seen_stage_index]
+                    last_seen_var_index = last_seen_stage.invars.index(var)
+                    last_seen_sharding_spec = last_seen_stage.input_sharding_specs[last_seen_var_index]
+
+                    last_seen_mesh_index = stage_placements[last_seen_stage_index]
+                    last_seen_mesh = meshes[last_seen_mesh_index]
+                    final_src_array = VirtualDistributedArray(
+                        device_mesh=last_seen_mesh,
+                        aval=var.aval,
+                        sharding_spec=last_seen_sharding_spec)
+                    final_src_mesh_index = last_seen_mesh_index
+                else:
+                    last_seen[var] = dst_stage_index
+                    src_sharding_spec = out_sharding_specs[out_var_index]
+                    final_src_array = VirtualDistributedArray(
+                        device_mesh=src_mesh,
+                        aval=var.aval,
+                        sharding_spec=src_sharding_spec)
+                    final_src_mesh_index = src_mesh_index
+
                 dst_sharding_spec = in_sharding_specs[in_var_index]
-
                 final_dst_spec = dst_sharding_spec
                 if global_config.resharding_mode == "send_recv":
                     dst_sharding_spec = self._rewrite_allgather_spec(
                         dst_sharding_spec, dst_mesh.num_hosts, var.aval.shape)
-
-                src_array = VirtualDistributedArray(
-                    device_mesh=src_mesh,
-                    aval=var.aval,
-                    sharding_spec=src_sharding_spec)
                 dst_array = VirtualDistributedArray(
                     device_mesh=dst_mesh,
                     aval=var.aval,
                     sharding_spec=dst_sharding_spec)
-                task_spec = ReshardingTaskSpec(src_array, dst_array,
+                task_spec = ReshardingTaskSpec(final_src_array, dst_array,
                                                final_dst_spec)
-                self.resharding_specs[src_mesh_index][dst_mesh_index][
+                self.resharding_specs[final_src_mesh_index][dst_mesh_index][
                     var] = task_spec
 
     def task_spec_iter(self):
@@ -1425,7 +1455,10 @@ def _generate_broadcast_resharding_strategy_by_loads(
         return strategy
 
     @staticmethod
-    def _args_between(src_stage, dst_stage):
+    def _args_between(
+        src_stage: XlaShardedPipelineComputation,
+        dst_stage: XlaShardedPipelineComputation
+    ):
         """Find the variable exchanged between stages."""
         resharding_vars = []
         src_indices = []

From 8af66bb61ff7e1ab9a4033faf2e9897b04cccf20 Mon Sep 17 00:00:00 2001
From: jon-chuang
Date: Sun, 9 Apr 2023 07:18:11 +0800
Subject: [PATCH 2/2] fmt

---
 .../cross_mesh_resharding.py | 30 ++++++++-----------
 1 file changed, 12 insertions(+), 18 deletions(-)

diff --git a/alpa/pipeline_parallel/cross_mesh_resharding.py b/alpa/pipeline_parallel/cross_mesh_resharding.py
index d59b94b7b..19c6c31df 100644
--- a/alpa/pipeline_parallel/cross_mesh_resharding.py
+++ b/alpa/pipeline_parallel/cross_mesh_resharding.py
@@ -684,11 +684,8 @@ class ReshardingTaskSpec:
         VirtualDistributedArray.
     """
 
-    def __init__(self,
-                 src_array: VirtualDistributedArray,
-                 dst_array: VirtualDistributedArray,
-                 final_dst_spec
-                 ):
+    def __init__(self, src_array: VirtualDistributedArray,
+                 dst_array: VirtualDistributedArray, final_dst_spec):
         self.src = src_array
         self.dst = dst_array
         self._dst_tile_to_src_tiles_map = None
@@ -955,11 +952,8 @@ class CrossMeshCommunicator:
         schedule (Any): the pipelining schedule for these stages.
     """
 
-    def __init__(
-        self,
-        sharded_stages: Sequence[XlaShardedPipelineComputation],
-        schedule: PipelineSchedule
-    ):
+    def __init__(self, sharded_stages: Sequence[XlaShardedPipelineComputation],
+                 schedule: PipelineSchedule):
         if not isinstance(sharded_stages, list):
             raise RuntimeError("Require a list of stages.")
         for s in sharded_stages:
@@ -1138,25 +1132,27 @@ def _create_resharding_specs(self):
                     last_seen[var] = dst_stage_index
 
                     last_seen_stage = stages[last_seen_stage_index]
                     last_seen_var_index = last_seen_stage.invars.index(var)
-                    last_seen_sharding_spec = last_seen_stage.input_sharding_specs[last_seen_var_index]
+                    last_seen_sharding_spec = last_seen_stage.input_sharding_specs[
+                        last_seen_var_index]
 
-                    last_seen_mesh_index = stage_placements[last_seen_stage_index]
-                    last_seen_mesh = meshes[last_seen_mesh_index] 
+                    last_seen_mesh_index = stage_placements[
+                        last_seen_stage_index]
+                    last_seen_mesh = meshes[last_seen_mesh_index]
                     final_src_array = VirtualDistributedArray(
                         device_mesh=last_seen_mesh,
                         aval=var.aval,
                         sharding_spec=last_seen_sharding_spec)
                     final_src_mesh_index = last_seen_mesh_index
                 else:
                     last_seen[var] = dst_stage_index
                     src_sharding_spec = out_sharding_specs[out_var_index]
                     final_src_array = VirtualDistributedArray(
                         device_mesh=src_mesh,
                         aval=var.aval,
                         sharding_spec=src_sharding_spec)
                     final_src_mesh_index = src_mesh_index
-                
+
                 dst_sharding_spec = in_sharding_specs[in_var_index]
                 final_dst_spec = dst_sharding_spec
                 if global_config.resharding_mode == "send_recv":
@@ -1455,10 +1451,8 @@ def _generate_broadcast_resharding_strategy_by_loads(
         return strategy
 
     @staticmethod
-    def _args_between(
-        src_stage: XlaShardedPipelineComputation,
-        dst_stage: XlaShardedPipelineComputation
-    ):
+    def _args_between(src_stage: XlaShardedPipelineComputation,
+                      dst_stage: XlaShardedPipelineComputation):
         """Find the variable exchanged between stages."""
         resharding_vars = []
         src_indices = []
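
Reviewer note (not part of the patches above): a minimal standalone sketch of
the `last_seen` bookkeeping that these two commits introduce in
CrossMeshCommunicator._create_resharding_specs. Plain ints and strings stand
in for Alpa's stages, meshes, and jax.core.Var, and `pick_resharding_sources`
with its argument layout is a hypothetical stand-in rather than an Alpa API.
It only illustrates the intended behavior: a var's first transfer is sourced
from its producer's mesh, and each later transfer from the mesh of the stage
that last consumed it.

    from typing import Dict, List, Tuple

    Var = str  # stand-in for jax.core.Var

    def pick_resharding_sources(
            pairs: List[Tuple[int, int, List[Var]]],
            stage_placements: List[int]) -> Dict[Tuple[Var, int], int]:
        # pairs: (src_stage_index, dst_stage_index, vars sent between them),
        # in schedule order. stage_placements: stage index -> mesh index.
        last_seen: Dict[Var, int] = {}
        sources: Dict[Tuple[Var, int], int] = {}
        for src_stage_index, dst_stage_index, resharding_vars in pairs:
            for var in resharding_vars:
                if var in last_seen:
                    # The var already lives on the mesh of its last consumer;
                    # source the transfer from there, not from the producer.
                    final_src_mesh_index = stage_placements[last_seen[var]]
                else:
                    final_src_mesh_index = stage_placements[src_stage_index]
                last_seen[var] = dst_stage_index
                sources[(var, dst_stage_index)] = final_src_mesh_index
        return sources

    # "x" is produced by stage 0 (mesh 0) and consumed by stages 1 and 2
    # (meshes 1 and 2); its second transfer is sourced from mesh 1, not mesh 0.
    print(pick_resharding_sources([(0, 1, ["x"]), (0, 2, ["x"])], [0, 1, 2]))
    # -> {('x', 1): 0, ('x', 2): 1}

This choice of source mesh is why the patched code indexes resharding_specs by
final_src_mesh_index instead of src_mesh_index.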