Commit b5b016f

remove PP tracer
ghstack-source-id: 6e56baaa8c48c451b6dac1945310e62d310b19f7
Pull Request resolved: #555
tianyu-l committed Aug 22, 2024
1 parent 90c889e commit b5b016f
Showing 4 changed files with 6 additions and 109 deletions.
14 changes: 0 additions & 14 deletions test_runner.py
@@ -221,20 +221,6 @@ def build_test_list():
"pp_tp",
requires_seed_checkpoint=True,
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_split_mode tracer",
],
],
"PP tracer frontend test",
"pp_tracer",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[
17 changes: 1 addition & 16 deletions torchtitan/config_manager.py
@@ -270,7 +270,7 @@ def __init__(self):
the third containing layers.2 and all the remaining layers.
Note: fully-automated splitting may be enabled in the future,
but currently the split points must be specified manually for both manual and tracer.""",
but currently the split points must be specified manually.""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_schedule",
@@ -285,21 +285,6 @@ def __init__(self):
Looped schedules (e.g. interleaved_1f1b) require specifying pipeline_parallel_degree = number of ranks,
and split_points = number of stages - 1""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_split_mode",
type=str,
choices=["manual", "tracer"],
default="manual",
help="""
Specify the split method (e.g. the Pipeline Parallelism Front End)
"manual" means each rank will construct an nn.Module with the appropriate layers and .forward
implementation manually, and then wrap it in a PipelineStage.
"tracer" means the full model will be initialized (via meta device) and then traced into a graph,
split via the provided split points, unflattened into an nn.Module,
and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""",
)
self.parser.add_argument(
"--experimental.pipeline_parallel_microbatches",
type=int,
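Note: as an illustration of the help text above, here is a minimal sketch (not torchtitan's actual code; the function name and return shape are assumptions) of how manual split points partition transformer layers into len(split_points) + 1 stages:

def layers_per_stage(num_layers: int, split_points: list[str]) -> list[list[str]]:
    # Each split point "layers.i" opens a new stage beginning at layer i.
    boundaries = sorted(int(p.split(".")[1]) for p in split_points)
    edges = [0] + boundaries + [num_layers]
    return [
        [f"layers.{i}" for i in range(start, stop)]
        for start, stop in zip(edges[:-1], edges[1:])
    ]

# 4 layers with split points ["layers.0", "layers.2"] -> 3 stages:
# [[], ["layers.0", "layers.1"], ["layers.2", "layers.3"]]
# (the empty first group would hold only the embeddings)
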
73 changes: 4 additions & 69 deletions torchtitan/parallelisms/pipeline_llama.py
@@ -12,7 +12,7 @@
import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
from torch.distributed.pipelining import PipelineStage

from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
from torchtitan.logging import logger
@@ -36,20 +36,9 @@ def pipeline_llama(
model_config: ModelArgs,
loss_fn: Callable[..., torch.Tensor],
):
split_mode = job_config.experimental.pipeline_parallel_split_mode
valid_split_modes = ("manual", "tracer")
if split_mode not in valid_split_modes:
raise ValueError(
f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
)
if split_mode == "manual":
stages, models = pipeline_llama_manual(
model, pp_mesh, parallel_dims, job_config, device, model_config
)
elif split_mode == "tracer":
stages, models = pipeline_llama_tracer(
model, pp_mesh, parallel_dims, job_config, device, model_config
)
stages, models = pipeline_llama_manual(
model, pp_mesh, parallel_dims, job_config, device, model_config
)

pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)

@@ -173,57 +162,3 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=False):
stages.append(stage)
models.append(model_chunk)
return stages, models


def pipeline_llama_tracer(
model: nn.Module,
pp_mesh: DeviceMesh,
parallel_dims: ParallelDims,
job_config: JobConfig,
device: DeviceType,
model_config: ModelArgs,
):
if job_config.model.norm_type == "fused_rmsnorm":
# TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr
# invocation stride in strict mode from `if dy.stride(-1) != 1:` in
# fused_rmsnorm
raise NotImplementedError(
"fused_rmsnorm is not compatible with Pipeline Tracer yet. "
"Please use rmsnorm or layernorm."
)
if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
raise NotImplementedError(
"Pipeline tracer does not work with FSDP mixed precision yet. "
"To work around, set mixed_precision_param to float32."
)

pp_rank = pp_mesh.get_local_rank()
pp_size = pp_mesh.size()
microbatches = (
job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
)
(input,) = _llama_trace_input(job_config, model_config, device=device)
stage_idx = pp_rank
split_spec = {
layer_name: SplitPoint.BEGINNING
for layer_name in job_config.experimental.pipeline_parallel_split_points
}
num_stages = len(split_spec) + 1
pipe = pipeline(
model,
mb_args=(input.chunk(microbatches)[0],),
split_spec=split_spec,
)

stages = []
models = []
for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):
models.append(pipe.get_stage_module(stage_idx))
stages.append(
pipe.build_stage(
stage_idx,
device=device,
group=pp_mesh.get_group(),
)
)
return stages, models
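
Note: both the removed tracer path and the surviving manual path assign stages to ranks via stage_ids_this_rank(..., style="loop"). A minimal sketch of that loop-style assignment, assuming num_stages is a multiple of pp_size (the function name below is hypothetical):

def loop_style_stage_ids(pp_rank: int, pp_size: int, num_stages: int) -> tuple[int, ...]:
    # Loop style: rank r owns stages r, r + pp_size, r + 2 * pp_size, ...
    assert num_stages % pp_size == 0, "looped schedules need num_stages divisible by pp_size"
    return tuple(pp_rank + s * pp_size for s in range(num_stages // pp_size))

# pp_size=2, num_stages=4: rank 0 -> (0, 2), rank 1 -> (1, 3),
# i.e. two model chunks per rank for an interleaved schedule.
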
11 changes: 1 addition & 10 deletions train.py
@@ -146,18 +146,14 @@ def loss_fn(pred, labels):
model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
)

pp_split_mode = job_config.experimental.pipeline_parallel_split_mode

# For PP with looped schedules, each item in model_parts is one stage-model-chunk.
# We need to iterate through model_parts to apply SPMD parallelisms, compilation,
# optimizer, and checkpointing
for m in model_parts:
# apply SPMD-style PT-D techniques
models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
m.to_empty(device="cuda")
# skip traced modules since we do not define init_weights in the traced module
if pp_split_mode == "manual":
m.init_weights()
m.init_weights()
m.train()
else:
# apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
@@ -205,11 +201,6 @@ def loss_fn(pred, labels):
checkpoint_loaded = checkpoint.load()

if parallel_dims.pp_enabled and not checkpoint_loaded:
if pp_split_mode == "tracer":
raise RuntimeError(
"Pipeline parallelism with tracer mode is not supported without a seed checkpoint."
)

# TODO: fix this by allowing each rank to set their own seed
logger.warning(
"Pipeline Parallelism is being used without a seed checkpoint. "
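
Note: the init_weights simplification above works because every remaining model part defines init_weights, which holds once traced modules are gone. A minimal sketch of the meta-device pattern the loop relies on, using a hypothetical TinyChunk module in place of a real model chunk (CPU here; train.py uses CUDA):

import torch
import torch.nn as nn

class TinyChunk(nn.Module):
    # Hypothetical stand-in for one pipeline model chunk.
    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def init_weights(self):
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

with torch.device("meta"):
    m = TinyChunk()               # parameters have shapes but no storage
m.to_empty(device="cpu")          # materialize uninitialized storage
m.init_weights()                  # each chunk initializes its own weights
m.train()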
