[doc] Minor additions to ShardedDDP docs (#299)
blefaudeux authored Jan 8, 2021
1 parent 11beea6 commit b202804
Showing 3 changed files with 20 additions and 6 deletions.
18 changes: 16 additions & 2 deletions fairscale/nn/data_parallel/sharded_ddp.py
@@ -23,8 +23,7 @@


class ShardedDataParallel(nn.Module):
"""
Wrap the model, and reduce the gradients to the right rank during the backward pass.
""" Wrap the model, and reduce the gradients to the right rank during the backward pass.
- the partition is given by the sharded optimizer
- wrap the base model with a model which knows where to reduce each gradient
@@ -46,6 +45,21 @@ class ShardedDataParallel(nn.Module):
Synchronize the models in between the ranks when starting up. Not needed if each rank has the same seed,
or the training restarts from a saved state
+
+    .. warning:
+        ShardedDDP implements gradient sharding, meaning that each rank only owns a unique shard of the model gradients
+        after the backward pass, in order to save memory and some communication bandwidth.
+
+    .. warning:
+        As a consequence of sharding, in case of gradient clipping, one has to use the `clip_grad_norm` exposed by
+        the `optimizer state sharding wrapper <fairscale.optim.OSS>`
+
+    .. warning:
+        As a consequence of sharding, after loss.backward() (or equivalent) each rank will have `None` in place of some param.grad
+
+    .. warning:
+        As a consequence of sharding, PyTorch and Apex AMP implementations will hang when used in conjunction with `ShardedDDP`.
+        One needs a `shard-aware grad scaler <ShardedGradScaler>`, which is proposed in `fairscale.optim.grad_scaler`, compatible with PyTorch AMP.
"""

def __init__(
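Read together, the new warnings describe a specific training-loop shape: OSS shards the optimizer state, ShardedDDP reduces each gradient to the rank that owns it, gradient clipping goes through `OSS.clip_grad_norm`, and mixed precision goes through `ShardedGradScaler`. A minimal sketch of that shape follows; it is not part of this commit, and the SGD optimizer, the cross-entropy loss, the clipping threshold, and the loop structure are illustrative assumptions.

```python
import torch
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler


def train_loop(model: torch.nn.Module, loader) -> None:
    # OSS shards the optimizer state across ranks; ShardedDDP reduces every
    # gradient to the single rank that owns the corresponding shard.
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.01)
    model = ShardedDDP(model, optimizer)

    # A plain torch.cuda.amp.GradScaler would hang here, because no rank ever
    # sees the full set of gradients; the shard-aware scaler handles that.
    scaler = ShardedGradScaler()

    for inputs, targets in loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = torch.nn.functional.cross_entropy(model(inputs), targets)

        # After backward, some param.grad entries on this rank are None:
        # they belong to shards owned by other ranks.
        scaler.scale(loss).backward()

        # Clipping must go through OSS so the norm is reduced across shards,
        # and it must be called identically on every rank.
        scaler.unscale_(optimizer)
        optimizer.clip_grad_norm(1.0)

        scaler.step(optimizer)
        scaler.update()
```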
7 changes: 4 additions & 3 deletions fairscale/optim/adascale.py
@@ -554,10 +554,11 @@ def set_num_gradients_to_accumulate(self, num_gradients_to_accumulate: int, upda
`set_scale` needs to be called to update the scale as well.
TODO (min): need a way of determine how much to increase the step size?
TODO (min): have both `set_scale` and `set_num_gradients_to_accumulate`
is hard to use and easy to make mistakes. I think it is better
to specify a `base_scale`. But more discussion is
needed here.
Args:
num_gradients_to_accumulate (int):
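The quoted docstring notes that changing the accumulation count mid-training only makes sense if `set_scale` is updated alongside it. A small sketch of that pairing follows; it is not part of this commit, and the linear model, the SGD base optimizer, the world size of 2, and the convention that the scale equals world size times the accumulation count are assumptions used only for illustration.

```python
import torch
from fairscale.optim import AdaScale

# Hypothetical single-process setup, just to show the two calls side by side.
model = torch.nn.Linear(16, 2)
world_size = 2

# Start out accumulating 4 gradients per optimizer step.
optim = AdaScale(
    torch.optim.SGD(model.parameters(), lr=0.1),
    world_size=world_size,
    num_gradients_to_accumulate=4,
)

# Later on, move to 8 accumulation steps. Per the docstring above, the scale
# has to be refreshed as well, otherwise the gain estimate no longer matches
# the effective batch size.
optim.set_num_gradients_to_accumulate(8)
optim.set_scale(float(world_size * 8))
```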
1 change: 0 additions & 1 deletion fairscale/optim/oss.py
@@ -239,7 +239,6 @@ def clip_grad_norm(
.. warning: This needs to be called on all ranks, since synchronization primitives will be used
.. warning: Model parallelism (groups other than world) is not yet supported
"""

# Compute the max norm for this shard's worth of gradients
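The context kept in this hunk is mostly the warning that `clip_grad_norm` is a collective call. A short sketch of what that means in practice, reusing the hypothetical `optimizer` (an OSS instance) from the first sketch above:

```python
# Wrong: guarding the call by rank deadlocks the other ranks, because
# OSS.clip_grad_norm synchronizes the shard norms across the process group.
# if torch.distributed.get_rank() == 0:
#     optimizer.clip_grad_norm(5.0)

# Right: every rank issues the same call and gets the same global norm back.
total_norm = optimizer.clip_grad_norm(5.0)
```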
