remove PP tracer #555

Merged 2 commits on Aug 22, 2024
Changes from 1 commit
14 changes: 0 additions & 14 deletions test_runner.py
@@ -221,20 +221,6 @@ def build_test_list():
             "pp_tp",
             requires_seed_checkpoint=True,
         ),
-        OverrideDefinitions(
-            [
-                [
-                    "--checkpoint.enable_checkpoint",
-                    "--experimental.pipeline_parallel_degree 2",
-                    "--experimental.pipeline_parallel_split_points layers.4",
-                    "--experimental.pipeline_parallel_split_mode tracer",
-                ],
-            ],
-            "PP tracer frontend test",
-            "pp_tracer",
-            requires_seed_checkpoint=True,
-            ngpu=2,
-        ),
         OverrideDefinitions(
             [
                 [
17 changes: 1 addition & 16 deletions torchtitan/config_manager.py
@@ -270,7 +270,7 @@ def __init__(self):
                 the third containing layers.2 and all the remaining layers.
 
                 Note: fully-automated splitting may be enabled in the future,
-                but currently the split points must be specified manually for both manual and tracer.""",
+                but currently the split points must be specified manually.""",
         )
         self.parser.add_argument(
             "--experimental.pipeline_parallel_schedule",
@@ -285,21 +285,6 @@ def __init__(self):
                 Looped schedules (e.g. interleaved_1f1b) require specifying pipeline_parallel_degree = number of ranks,
                 and split_points = number of stages - 1""",
         )
-        self.parser.add_argument(
-            "--experimental.pipeline_parallel_split_mode",
-            type=str,
-            choices=["manual", "tracer"],
-            default="manual",
-            help="""
-                Specify the split method (e.g. the Pipeline Parallelism Front End)
-
-                "manual" means each rank will construct an nn.Module with the appropriate layers and .forward
-                implementation manually, and then wrap it in a PipelineStage.
-
-                "tracer" means the full model will be initialized (via meta device) and then traced into a graph,
-                split via the provided split points, unflattened into an nn.Module,
-                and finally wrapped in a PipelineStage. tracer frontend is currently more experimental.""",
-        )
         self.parser.add_argument(
             "--experimental.pipeline_parallel_microbatches",
             type=int,
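For context on the `pipeline_parallel_split_points` help text kept above: each split point names the module that begins a new stage, so N split points produce N + 1 stages. A minimal sketch of that mapping, with a hypothetical helper and illustrative layer names (this is not torchtitan code):

```python
# Hypothetical illustration of how split points partition layers into stages.
# A split point marks the module that BEGINS the next stage, matching the
# SplitPoint.BEGINNING semantics used by the removed tracer path below.
def assign_stages(layer_names: list[str], split_points: list[str]) -> dict[str, int]:
    stage_of, stage = {}, 0
    for name in layer_names:
        if name in split_points:
            stage += 1  # this layer starts a new pipeline stage
        stage_of[name] = stage
    return stage_of

layers = [f"layers.{i}" for i in range(8)]
# Two split points -> three stages.
print(assign_stages(layers, ["layers.2", "layers.5"]))
```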
73 changes: 4 additions & 69 deletions torchtitan/parallelisms/pipeline_llama.py
@@ -12,7 +12,7 @@
 import torch
 import torch.nn as nn
 from torch.distributed import DeviceMesh
-from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
+from torch.distributed.pipelining import PipelineStage
 
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
@@ -36,20 +36,9 @@ def pipeline_llama(
     model_config: ModelArgs,
     loss_fn: Callable[..., torch.Tensor],
 ):
-    split_mode = job_config.experimental.pipeline_parallel_split_mode
-    valid_split_modes = ("manual", "tracer")
-    if split_mode not in valid_split_modes:
-        raise ValueError(
-            f"Invalid split mode: {split_mode}. Valid split modes: {valid_split_modes}"
-        )
-    if split_mode == "manual":
-        stages, models = pipeline_llama_manual(
-            model, pp_mesh, parallel_dims, job_config, device, model_config
-        )
-    elif split_mode == "tracer":
-        stages, models = pipeline_llama_tracer(
-            model, pp_mesh, parallel_dims, job_config, device, model_config
-        )
+    stages, models = pipeline_llama_manual(
+        model, pp_mesh, parallel_dims, job_config, device, model_config
+    )
 
     pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn)
 
@@ -173,57 +162,3 @@ def _build_stage(stage_idx, start_layer, stop_layer, is_first=False, is_last=Fal
         stages.append(stage)
         models.append(model_chunk)
     return stages, models
-
-
-def pipeline_llama_tracer(
-    model: nn.Module,
-    pp_mesh: DeviceMesh,
-    parallel_dims: ParallelDims,
-    job_config: JobConfig,
-    device: DeviceType,
-    model_config: ModelArgs,
-):
-    if job_config.model.norm_type == "fused_rmsnorm":
-        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr
-        # invocation stride in strict mode from `if dy.stride(-1) != 1:` in
-        # fused_rmsnorm
-        raise NotImplementedError(
-            "fused_rmsnorm is not compatible with Pipeline Tracer yet. "
-            "Please use rmsnorm or layernorm."
-        )
-    if _mixed_precision_dtype(job_config, parallel_dims) != torch.float32:
-        raise NotImplementedError(
-            "Pipeline tracer does not work with FSDP mixed precision yet. "
-            "To work around, set mixed_precision_param to float32."
-        )
-
-    pp_rank = pp_mesh.get_local_rank()
-    pp_size = pp_mesh.size()
-    microbatches = (
-        job_config.experimental.pipeline_parallel_microbatches or parallel_dims.pp
-    )
-    (input,) = _llama_trace_input(job_config, model_config, device=device)
-    stage_idx = pp_rank
-    split_spec = {
-        layer_name: SplitPoint.BEGINNING
-        for layer_name in job_config.experimental.pipeline_parallel_split_points
-    }
-    num_stages = len(split_spec) + 1
-    pipe = pipeline(
-        model,
-        mb_args=(input.chunk(microbatches)[0],),
-        split_spec=split_spec,
-    )
-
-    stages = []
-    models = []
-    for stage_idx in stage_ids_this_rank(pp_rank, pp_size, num_stages, style="loop"):
-        models.append(pipe.get_stage_module(stage_idx))
-        stages.append(
-            pipe.build_stage(
-                stage_idx,
-                device=device,
-                group=pp_mesh.get_group(),
-            )
-        )
-    return stages, models
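Both the removed tracer frontend and the surviving manual path hand out stages per rank via `stage_ids_this_rank(..., style="loop")`. A rough sketch of what a loop-style assignment could look like, written as a standalone function for illustration (the real helper lives in torchtitan and may differ):

```python
# Hypothetical sketch of loop-style stage-to-rank assignment: each rank owns
# every pp_size-th stage, starting from its own rank index.
def loop_stage_ids(pp_rank: int, pp_size: int, num_stages: int) -> list[int]:
    assert num_stages % pp_size == 0, "stages should divide evenly across pp ranks"
    return list(range(pp_rank, num_stages, pp_size))

# With pp_size=2 and 4 stages (e.g. an interleaved_1f1b schedule):
print(loop_stage_ids(pp_rank=0, pp_size=2, num_stages=4))  # [0, 2]
print(loop_stage_ids(pp_rank=1, pp_size=2, num_stages=4))  # [1, 3]
```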
11 changes: 1 addition & 10 deletions train.py
@@ -146,18 +146,14 @@ def loss_fn(pred, labels):
             model, pp_mesh, parallel_dims, job_config, device, model_config, loss_fn
         )
 
-        pp_split_mode = job_config.experimental.pipeline_parallel_split_mode
-
         # For PP with looped schedules, each item in model_parts is one stage-model-chunk.
         # We need to iterate through model_parts to apply SPMD parallelisms, compilation,
         # optimizer, and checkpointing
         for m in model_parts:
             # apply SPMD-style PT-D techniques
             models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
             m.to_empty(device="cuda")
-            # skip traced modules since we do not define init_weights in the traced module
-            if pp_split_mode == "manual":
-                m.init_weights()
+            m.init_weights()
             m.train()
     else:
         # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel
@@ -205,11 +201,6 @@ def loss_fn(pred, labels):
         checkpoint_loaded = checkpoint.load()
 
     if parallel_dims.pp_enabled and not checkpoint_loaded:
-        if pp_split_mode == "tracer":
-            raise RuntimeError(
-                "Pipeline parallelism with tracer mode is not supported without a seed checkpoint."
-            )
-
         # TODO: fix this by allowing each rank to set their own seed
         logger.warning(
             "Pipeline Parallelism is being used without a seed checkpoint. "
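The train.py change above drops the `pp_split_mode == "manual"` guard around `init_weights()`: with the tracer frontend gone, every pipeline model part is a hand-written `nn.Module` that defines `init_weights`, so it can always be called after materialization. A simplified sketch of that meta-device flow, using a hypothetical stand-in module and CPU instead of CUDA:

```python
import torch
import torch.nn as nn

class TinyPart(nn.Module):
    """Hypothetical stand-in for one pipeline stage chunk."""
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def init_weights(self):
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

with torch.device("meta"):
    part = TinyPart()           # parameters are allocated without real storage

part.to_empty(device="cpu")     # materialize storage (train.py uses device="cuda")
part.init_weights()             # always safe now that every part defines init_weights
part.train()
```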