fsdp support #28

Open · wants to merge 5 commits into base: main
14 changes: 14 additions & 0 deletions README.md
@@ -36,6 +36,20 @@ Example

To build a model, you can refer to [example/bert](https://github.com/dptech-corp/Uni-Core/tree/main/examples/bert).

FSDP Example
------------

To use FSDP (fully sharded data parallel) distributed training, you can refer to [example/bert/train_bert_test_fsdp.sh](https://github.com/dptech-corp/Uni-Core/tree/main/examples/bert).

- Install fairscale: `pip install fairscale` (a quick import check is sketched below).
- Modify your original training script to set `--ddp-backend fully_sharded`.
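
Before launching, you can optionally verify that fairscale's FSDP is importable. A minimal sketch, using the same import that Uni-Core attempts when the backend is enabled:

```python
# Minimal sanity check: confirm fairscale's FSDP can be imported before
# switching --ddp-backend to fully_sharded.
try:
    from fairscale.nn.data_parallel import FullyShardedDataParallel  # noqa: F401
    print("fairscale FSDP is available")
except ImportError:
    print("fairscale is not installed; run: pip install fairscale")
```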

Note
- Currently, only `--fp16` training is supported with the `fully_sharded` backend.
- While FSDP is fully compatible with pointwise optimizers (e.g., Adam, AdamW, Adadelta, Adamax, SGD, etc.), it is not currently compatible with non-pointwise optimizers (e.g., Adagrad, Adafactor, LAMB, etc.).
- FSDP depends on flattening the parameters, so models that currently require `--fp16-no-flatten-grads` may not be supported.


Related projects
----------------

14 changes: 14 additions & 0 deletions examples/bert/train_bert_test_fsdp.sh
@@ -0,0 +1,14 @@
[ -z "${MASTER_PORT}" ] && MASTER_PORT=10086
[ -z "${n_gpu}" ] && n_gpu=$(nvidia-smi -L | wc -l)
export NCCL_ASYNC_ERROR_HANDLING=1
export OMP_NUM_THREADS=1
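# launch one training process per GPU; --ddp-backend=fully_sharded enables FSDP (which currently requires --fp16)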
python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) ./example_data --user-dir . --valid-subset valid \
--num-workers 0 --ddp-backend="fully_sharded" \
--task bert --loss masked_lm --arch bert_base \
--optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-6 --clip-norm 1.0 \
--lr-scheduler polynomial_decay --lr 1e-4 --warmup-updates 100 --total-num-update 10000 --batch-size 4 \
--update-freq 1 --seed 1 \
--fp16 --fp16-init-scale 4 --fp16-scale-window 256 --tensorboard-logdir ./tsb/ \
--max-update 10000 --log-interval 100 --log-format simple \
--save-interval-updates 5000 --validate-interval-updates 5000 --keep-interval-updates 30 --no-epoch-checkpoints \
--save-dir ./save
4 changes: 4 additions & 0 deletions unicore/checkpoint_utils.py
@@ -89,7 +89,11 @@ def save_checkpoint(args, trainer, epoch_itr, val_loss, ckp_copy_thread, do_save
    if args.no_save or not do_save:
        return

    trainer.consolidate_optimizer()

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return

    write_timer = meters.StopwatchMeter()
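
The three trainer hooks used above are what make FSDP checkpointing work: sharded optimizer state may need to be consolidated, only some ranks actually write files, and the full `state_dict()` involves collective communication, so every rank must still call it. A hypothetical sketch of how such hooks could be implemented (the attribute names beyond those in the diff are assumptions, not Uni-Core's actual trainer):

```python
class TrainerCheckpointHooks:
    """Illustrative only: mirrors the hook names called by save_checkpoint."""

    def __init__(self, optimizer, data_parallel_rank: int,
                 is_fsdp: bool, use_sharded_state: bool):
        self.optimizer = optimizer
        self.data_parallel_rank = data_parallel_rank
        self.is_fsdp = is_fsdp
        self.use_sharded_state = use_sharded_state

    def consolidate_optimizer(self):
        # With sharded optimizer state, gather the full state onto rank 0
        # before saving; harmless no-op otherwise.
        if hasattr(self.optimizer, "consolidate_state_dict"):
            self.optimizer.consolidate_state_dict()

    @property
    def should_save_checkpoint_on_current_rank(self) -> bool:
        # Sharded state: every rank writes its own shard.
        # Full state: only data-parallel rank 0 writes the checkpoint.
        return self.use_sharded_state or self.data_parallel_rank == 0

    @property
    def always_call_state_dict_during_save_checkpoint(self) -> bool:
        # FSDP's full state_dict() uses collective ops, so non-saving ranks
        # must still participate in the call.
        return self.is_fsdp and not self.use_sharded_state
```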
9 changes: 9 additions & 0 deletions unicore/distributed/__init__.py
@@ -7,6 +7,15 @@
from .module_proxy_wrapper import ModuleProxyWrapper
from .legacy_distributed_data_parallel import LegacyDistributedDataParallel

from .fully_sharded_data_parallel import (
    fsdp_enable_wrap,
    fsdp_wrap,
    FullyShardedDataParallel,
)

__all__ = [
    "ModuleProxyWrapper",
    "fsdp_enable_wrap",
    "fsdp_wrap",
    "FullyShardedDataParallel",
]
144 changes: 144 additions & 0 deletions unicore/distributed/fully_sharded_data_parallel.py
@@ -0,0 +1,144 @@
# Copyright (c) DP Technology.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import contextlib
from typing import Optional

import torch
from unicore.distributed import utils as dist_utils


try:
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    has_FSDP = True
except ImportError:
    FSDP = torch.nn.Module
    has_FSDP = False


class FullyShardedDataParallel(FSDP):
    """
    A small wrapper around fairscale's FullyShardedDataParallel (FSDP) with some
    Uni-Core-specific checkpoint saving/loading logic.

    Args:
        use_sharded_state (bool): if True, then ``state_dict`` will return
            ``FSDP.local_state_dict`` and ``load_state_dict`` will call
            ``FSDP.load_local_state_dict``. Otherwise, ``state_dict`` will
            return the full model weights on data parallel rank 0 (empty on
            other ranks) and ``load_state_dict`` will broadcast model weights
            from rank 0 to other ranks.
    """

    def __init__(self, *args, use_sharded_state: bool = False, **kwargs):
        if not has_FSDP:
            raise ImportError(
                "Cannot find FullyShardedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )
        super().__init__(*args, **kwargs)
        self.use_sharded_state = use_sharded_state

    @property
    def unwrapped_module(self) -> torch.nn.Module:
        if self.flatten_parameters:
            return self.module.module
        else:
            return self.module

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        if self.use_sharded_state:
            return super().local_state_dict(
                destination=destination, prefix=prefix, keep_vars=keep_vars
            )
        else:
            if self.rank == 0:
                return super().state_dict(
                    destination=destination, prefix=prefix, keep_vars=keep_vars
                )
            else:
                # We must call state_dict() due to use of communication
                # primitives. But we don't use the result.
                super().state_dict()
                return destination or {}

    def load_state_dict(self, state_dict, strict=True, model_cfg=None):
        if self.use_sharded_state:
            return super().load_local_state_dict(state_dict, strict=strict)
        else:
            state_dict = dist_utils.broadcast_object(
                state_dict, src_rank=0, group=self.process_group
            )
            return super().load_state_dict(state_dict, strict=strict)


class DummyProcessGroup:
    def __init__(self, rank: int, size: int):
        self._rank = rank
        self._size = size

    def rank(self) -> int:
        return self._rank

    def size(self) -> int:
        return self._size


@contextlib.contextmanager
def fsdp_enable_wrap(cfg):
    try:
        from fairscale.nn import enable_wrap
    except ImportError:
        raise ImportError(
            "Cannot find FullyShardedDataParallel. "
            "Please install fairscale with: pip install fairscale"
        )
    if cfg.memory_efficient_fp16:
        assert cfg.fp16  # memory_efficient_fp16 should imply fp16
    group = dist_utils.get_data_parallel_group()
    if group is None and cfg.distributed_world_size == 1:
        group = DummyProcessGroup(rank=0, size=1)
    fsdp_config = {
        "process_group": group,
        "reshard_after_forward": not cfg.no_reshard_after_forward,
        "mixed_precision": cfg.fp16 and not cfg.memory_efficient_fp16,
        "fp32_reduce_scatter": cfg.fp32_reduce_scatter,
        "flatten_parameters": not cfg.not_fsdp_flatten_parameters,
        "cpu_offload": cfg.cpu_offload,
        "compute_dtype": torch.float16 if cfg.fp16 else torch.float32,
        "bucket_cap_mb": cfg.bucket_cap_mb,
        "state_dict_device": torch.device("cpu"),  # reduce GPU mem usage
    }
    with enable_wrap(
        wrapper_cls=FullyShardedDataParallel,
        use_sharded_state=cfg.use_sharded_state,
        **fsdp_config,
    ):
        yield


def fsdp_wrap(module, min_num_params: Optional[int] = None, **kwargs):
    """
    Helper to wrap layers/modules in FSDP. This falls back to a no-op if
    fairscale is not available.

    Args:
        module (nn.Module): module to (maybe) wrap
        min_num_params (int, Optional): minimum number of layer params to wrap
    """
    try:
        from fairscale.nn import wrap

        if min_num_params is not None:
            num_params = sum(p.numel() for p in module.parameters())
            if num_params >= min_num_params:
                return wrap(module, **kwargs)
            else:
                return module
        else:
            return wrap(module, **kwargs)
    except ImportError:
        return module
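
For reference, `fsdp_enable_wrap` and `fsdp_wrap` are meant to be used together: the context manager fixes the FSDP configuration, and `fsdp_wrap` wraps modules inside it. Below is a standalone construction sketch for a single process; the `SimpleNamespace` only mimics the cfg attributes read above and is not a real Uni-Core config object, and multi-GPU use additionally needs an initialized `torch.distributed` process group plus fairscale installed.

```python
from types import SimpleNamespace

import torch.nn as nn

from unicore.distributed import fsdp_enable_wrap, fsdp_wrap

# Illustrative stand-in for the training args; only the fields that
# fsdp_enable_wrap reads are provided. fp16=True mirrors the supported setup.
cfg = SimpleNamespace(
    fp16=True,
    memory_efficient_fp16=False,
    distributed_world_size=1,
    no_reshard_after_forward=False,
    fp32_reduce_scatter=False,
    not_fsdp_flatten_parameters=False,
    cpu_offload=False,
    bucket_cap_mb=25,
    use_sharded_state=False,
)

with fsdp_enable_wrap(cfg):
    # Layers below min_num_params are returned unwrapped; the root module
    # is always wrapped so the whole model ends up sharded.
    blocks = [fsdp_wrap(nn.Linear(512, 512), min_num_params=10**6) for _ in range(4)]
    model = fsdp_wrap(nn.Sequential(*blocks))
```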
14 changes: 14 additions & 0 deletions unicore/models/distributed_unicore_model.py
@@ -61,6 +61,20 @@ def DistributedUnicoreModel(args, model, process_group, device):
        )
        # forward missing getattr and state_dict/load_state_dict to orig model
        wrapped_model = ModuleProxyWrapper(wrapped_model)
    elif args.ddp_backend == "fully_sharded":
        try:
            from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
        except ImportError:
            raise ImportError(
                "Cannot find FullyShardedDataParallel. "
                "Please install fairscale with: pip install fairscale"
            )
        assert isinstance(model, FSDP), "expected model to already be wrapped in FSDP"
        wrapped_model = model
        if args.memory_efficient_fp16:
            wrapped_model = wrapped_model.half()
        if not args.cpu_offload:
            wrapped_model = wrapped_model.to(device=device)
    else:
        raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)
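
Note the ordering this branch assumes: with `--ddp-backend fully_sharded` the model must already be built inside `fsdp_enable_wrap(...)`, so it reaches this function as an FSDP instance. Also, with `--cpu-offload` the wrapped model is intentionally left on the CPU; fairscale's FSDP moves parameters to the GPU on demand during the forward/backward pass.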

24 changes: 14 additions & 10 deletions unicore/modules/transformer_decoder.py
@@ -11,6 +11,8 @@
from . import TransformerDecoderLayer, LayerNorm
from .transformer_encoder import relative_position_bucket

from unicore.distributed import fsdp_wrap


def fill_with_neg_inf(t):
    return t.fill_(float("-inf"))
@@ -60,16 +62,18 @@ def __init__(

        self.layers = nn.ModuleList(
            [
                fsdp_wrap(
                    TransformerDecoderLayer(
                        embed_dim=self.embed_dim,
                        ffn_embed_dim=ffn_embed_dim,
                        attention_heads=attention_heads,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
                        activation_dropout=activation_dropout,
                        activation_fn=activation_fn,
                        post_ln=post_ln,
                    )
                )
                for _ in range(decoder_layers)
            ]
23 changes: 13 additions & 10 deletions unicore/modules/transformer_encoder.py
@@ -12,6 +12,8 @@
import torch.nn.functional as F
from . import TransformerEncoderLayer, LayerNorm

from unicore.distributed import fsdp_wrap


def init_bert_params(module):
    if not getattr(module, 'can_global_init', True):
@@ -80,16 +82,17 @@ def __init__(

        self.layers = nn.ModuleList(
            [
                fsdp_wrap(
                    TransformerEncoderLayer(
                        embed_dim=self.embed_dim,
                        ffn_embed_dim=ffn_embed_dim,
                        attention_heads=attention_heads,
                        dropout=dropout,
                        attention_dropout=attention_dropout,
                        activation_dropout=activation_dropout,
                        activation_fn=activation_fn,
                        post_ln=post_ln,
                    )
                )
                for _ in range(encoder_layers)
            ]
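
Wrapping each encoder/decoder layer in `fsdp_wrap` (here and in `transformer_decoder.py` above) produces nested FSDP instances, so only one layer's full parameters need to be gathered at a time during the forward and backward passes; without fairscale, `fsdp_wrap` is a no-op and the layers behave exactly as before. A minimal sketch of the same pattern on a toy module (layer sizes are illustrative; outside an active `fsdp_enable_wrap(...)` context the wrapping simply returns the module unchanged):

```python
import torch.nn as nn

from unicore.distributed import fsdp_wrap


class TinyStack(nn.Module):
    """Illustrative only: each block is wrapped separately so FSDP can
    shard/gather one block's parameters at a time."""

    def __init__(self, depth: int = 4, dim: int = 256):
        super().__init__()
        self.layers = nn.ModuleList(
            [fsdp_wrap(nn.Linear(dim, dim)) for _ in range(depth)]
        )

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x
```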
4 changes: 3 additions & 1 deletion unicore/optim/__init__.py
@@ -12,11 +12,13 @@
from unicore.optim.unicore_optimizer import (  # noqa
    UnicoreOptimizer,
)
from unicore.optim.fp16_optimizer import FP16Optimizer, MemoryEfficientFP16Optimizer


__all__ = [
    "UnicoreOptimizer",
    "FP16Optimizer",
    "MemoryEfficientFP16Optimizer",
]