Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updated FSDP -> FSDPModule in doc #538

Merged
merged 1 commit into from
Aug 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/fsdp.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def fully_shard(
- `fully_shard(module)` is similar to `FullyShardedDataParallel(module)`, constructing one communication bucket from `module.parameters()` except those already assigned to a nested `fully_shard`/`FullyShardedDataParallel` call.
- `fully_shard(module)` adds an `FSDPState` object on `module`, accessible via `fully_shard.state(module)`, instead of being an `nn.Module` wrapper. This is done via the `@contract` decorator.
- Calling `model.named_parameters()` for a `model` with FSDP2 applied returns unchanged parameter names and `DTensor` sharded parameters. This means that the optimizer and gradient norm clipping see `DTensor`s.
- `fully_shard(module)` performs a dynamic class swap on `module`. E.g., if `type(module) is Transformer`, then FSDP2 constructs a new class `FSDPTransformer` that inherits from a class `FSDP` and `Transformer` and sets `module.__class__` to be `FSDPTransformer`. This allows us to add new methods and override methods via `FSDP` without constructing an `nn.Module` wrapper.
- `fully_shard(module)` performs a dynamic class swap on `module`. E.g., if `type(module) is Transformer`, then FSDP2 constructs a new class `FSDPTransformer` that inherits from a class `FSDPModule` and `Transformer` and sets `module.__class__` to be `FSDPTransformer`. This allows us to add new methods and override methods via `FSDPModule` without constructing an `nn.Module` wrapper.
- FSDP1's `sharding_strategy` and `process_group`/`device_mesh` maps to FSDP2's `mesh` and `reshard_after_forward`.
- `mesh` should be 1D for FSDP and 2D for HSDP. For HSDP, we assume replication on the 0th mesh dim and sharding on the 1st mesh dim. If `mesh is None`, then FSDP2 initializes a 1D global mesh over the default process group.
- `reshard_after_forward=True` or `False` determines whether parameters are resharded (freed) after forward. If `True`, then they are re-all-gathered in backward. This trades off saving memory at the cost of extra communication.
Expand Down
Loading