diff --git a/test_runner.py b/test_runner.py
index 477305865..245bf0eee 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -221,20 +221,6 @@ def build_test_list():
             "pp_tp",
             requires_seed_checkpoint=True,
         ),
-        OverrideDefinitions(
-            [
-                [
-                    "--checkpoint.enable_checkpoint",
-                    "--experimental.pipeline_parallel_degree 2",
-                    "--experimental.pipeline_parallel_split_points layers.4",
-                    "--experimental.pipeline_parallel_split_mode tracer",
-                ],
-            ],
-            "PP tracer frontend test",
-            "pp_tracer",
-            requires_seed_checkpoint=True,
-            ngpu=2,
-        ),
         OverrideDefinitions(
             [
                 [
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 56080e0d5..3ba1d1029 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -270,7 +270,7 @@ def __init__(self):
                 the third containing layers.2 and all the remaining layers.
 
                 Note: fully-automated splitting may be enabled in the future,
-                but currently the split points must be specified manually for both manual and tracer.""",
+                but currently the split points must be specified manually.""",
         )
         self.parser.add_argument(
             "--experimental.pipeline_parallel_schedule",
@@ -285,21 +285,6 @@ def __init__(self):
                 Looped schedules (e.g. interleaved_1f1b) require specifying pipeline_paralle_degree = number of ranks,
                 and split_points = number of stages - 1""",
         )
-        self.parser.add_argument(
-            "--experimental.pipeline_parallel_split_mode",
-            type=str,
-            choices=["manual", "tracer"],
-            default="manual",
-            help="""
-                Specify the split method (e.g. the Pipeline Parallelism Front End)
-
-                "manual" means each rank will construct an nn.Module with the appropriate layers and .forward
-                implementation manually, and then wrap it in a PipelineStage.
-
-                "tracer" means the full model will be initialized (via meta device) and then traced into a graph,
-                split via the provided split points, unflattened into an nn.Module,
-                and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""",
-        )
         self.parser.add_argument(
             "--experimental.pipeline_parallel_microbatches",
             type=int,
diff --git a/torchtitan/parallelisms/pipeline_llama.py b/torchtitan/parallelisms/pipeline_llama.py
index 679832705..7bce2fe66 100644
--- a/torchtitan/parallelisms/pipeline_llama.py
+++ b/torchtitan/parallelisms/pipeline_llama.py
@@ -12,7 +12,7 @@
 import torch
 import torch.nn as nn
 from torch.distributed import DeviceMesh
-from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
+from torch.distributed.pipelining import PipelineStage
 
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
@@ -36,20 +36,9 @@ def pipeline_llama(
     model_config: ModelArgs,
     loss_fn: Callable[..., torch.Tensor],
 ):
-    split_mode = job_config.experimental.pipeline_parallel_split_mode
-    valid_split_modes = ("manual", "tracer")
-    if split_mode not in valid_split_modes:
-        raise ValueError(
-            f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
-        )
-    if split_mode == "manual":
-        stages, models = pipeline_llama_manual(
-            model, pp_mesh, parallel_dims, job_config, device, model_config
-        )
-    elif split_mode == "tracer":
-        stages, models = pipeline_llama_tracer(
-            model, pp_mesh, parallel_dims, job_config, device, model_config
-        )
+    stages, models = pipeline_llama_manual_split(
+        model, pp_mesh, parallel_dims, job_config, device, model_config
+    )
 
     pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
 
@@ -73,7 +62,7 @@ def _mixed_precision_dtype(
     return TORCH_DTYPE_MAP[mp_arg] if parallel_dims.dp_enabled else default
 
 
-def pipeline_llama_manual(
+def pipeline_llama_manual_split(
     whole_model: nn.Module,
     pp_mesh: DeviceMesh,
     parallel_dims: ParallelDims,
@@ -173,57 +162,3 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
         stages.append(stage)
         models.append(model_chunk)
     return stages, models
-
-
-def pipeline_llama_tracer(
-    model: nn.Module,
-    pp_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-    job_config: JobConfig,
-    device: DeviceType,
-    model_config: ModelArgs,
-):
-    if job_config.model.norm_type == "fused_rmsnorm":
-        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr
-        # invocation stride in strict mode from `if dy.stride(-1) != 1:` in
-        # fused_rmsnorm
-        raise NotImplementedError(
-            "fused_rmsnorm is not compatible with Pipeline Tracer yet. "
-            "Please use rmsnorm or layernorm."
-        )
-    if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
-        raise NotImplementedError(
-            "Pipeline tracer does not work with FSDP mixed precision yet. "
-            "To work around, set mixed_precision_param to float32."
-        )
-
-    pp_rank = pp_mesh.get_local_rank()
-    pp_size = pp_mesh.size()
-    microbatches = (
-        job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
-    )
-    (input,) = _llama_trace_input(job_config, model_config, device=device)
-    stage_idx = pp_rank
-    split_spec = {
-        layer_name: SplitPoint.BEGINNING
-        for layer_name in job_config.experimental.pipeline_parallel_split_points
-    }
-    num_stages = len(split_spec) + 1
-    pipe = pipeline(
-        model,
-        mb_args=(input.chunk(microbatches)[0],),
-        split_spec=split_spec,
-    )
-
-    stages = []
-    models = []
-    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):
-        models.append(pipe.get_stage_module(stage_idx))
-        stages.append(
-            pipe.build_stage(
-                stage_idx,
-                device=device,
-                group=pp_mesh.get_group(),
-            )
-        )
-    return stages, models
diff --git a/train.py b/train.py
index 3f07d3c7b..69a17500f 100644
--- a/train.py
+++ b/train.py
@@ -146,8 +146,6 @@ def loss_fn(pred, labels):
             model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
 
-        pp_split_mode = job_config.experimental.pipeline_parallel_split_mode
-
         # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
         # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
         # optimizer, and checkpointing
@@ -155,9 +153,7 @@ def loss_fn(pred, labels):
             # apply SPMD-style PT-D techniques
             models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
             m.to_empty(device="cuda")
-            # skip traced modules since we do not define init_weights in the traced module
-            if pp_split_mode == "manual":
-                m.init_weights()
+            m.init_weights()
             m.train()
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
@@ -205,11 +201,6 @@ def loss_fn(pred, labels):
     checkpoint_loaded = checkpoint.load()
 
     if parallel_dims.pp_enabled and not checkpoint_loaded:
-        if pp_split_mode == "tracer":
-            raise RuntimeError(
-                "Pipeline parallelism with tracer mode is not supported without a seed checkpoint."
-            )
-
         # TODO: fix this by allowing each rank to set their own seed
         logger.warning(
             "Pipeline Parallelism is being used without a seed checkpoint. "