Skip to content

Commit

Permalink
Models With Tied Weights Need Re-Tieing After FSDP Param Init (#3154)
Browse files Browse the repository at this point in the history
* add fsdp_tool to retie after param init

* make it handle generic param_init_fn

* fix quality

Signed-off-by: Yu Chin Fabian Lim <[email protected]>

---------

Signed-off-by: Yu Chin Fabian Lim <[email protected]>
  • Loading branch information
fabianlim authored Oct 31, 2024
1 parent 497eb3c commit 8159c98
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@
compare_versions,
convert_model,
convert_outputs_to_fp32,
ensure_weights_retied,
extract_model_from_parallel,
gather,
gather_object,
Expand Down Expand Up @@ -1475,6 +1476,15 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
if not is_type_fsdp:
self.state.fsdp_plugin.set_auto_wrap_policy(model)
fsdp_plugin = self.state.fsdp_plugin

# need to ensure that params are re-tied after running
# param_init_fn
fsdp_plugin.param_init_fn = ensure_weights_retied(
fsdp_plugin.param_init_fn,
model,
self.device,
)

kwargs = {
"sharding_strategy": fsdp_plugin.sharding_strategy,
"cpu_offload": fsdp_plugin.cpu_offload,
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@
from .fsdp_utils import (
disable_fsdp_ram_efficient_loading,
enable_fsdp_ram_efficient_loading,
ensure_weights_retied,
load_fsdp_model,
load_fsdp_optimizer,
merge_fsdp_weights,
Expand Down
47 changes: 47 additions & 0 deletions src/accelerate/utils/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import os
import shutil
from collections import defaultdict
from pathlib import Path

import torch
Expand Down Expand Up @@ -324,3 +325,49 @@ def merge_fsdp_weights(
logger.info(f"Removing old checkpoint directory {checkpoint_dir}")
shutil.rmtree(checkpoint_dir)
state.wait_for_everyone()


def ensure_weights_retied(param_init_fn, model: torch.nn.Module, device: torch.cuda.device):
    """Wrap ``param_init_fn`` so that tied parameters are tied again after it runs.

    ``param_init_fn`` is applied one module at a time and may replace a
    module's parameters with new objects, which would break any weight tying
    declared by ``model``. The returned wrapper records which of a module's
    direct parameters are tied before calling ``param_init_fn`` and, afterwards,
    points every tied name at the first re-initialized instance that was seen.

    Args:
        param_init_fn: Callable taking a single ``torch.nn.Module``. It may
            return the module or ``None`` (in-place initialization). May itself
            be ``None``, in which case it is returned unchanged.
        model: Model whose optional ``_tied_weights_keys`` attribute lists the
            dotted names of its tied parameters.
        device: Currently unused; kept so existing callers keep working.

    Returns:
        ``param_init_fn`` itself when there is nothing to re-tie, otherwise a
        wrapper with the same single-module call signature that returns the
        initialized module.
    """
    # Not every model declares tied weights; a missing attribute is treated
    # the same as an empty list instead of raising AttributeError.
    _tied_names = getattr(model, "_tied_weights_keys", None)
    if not _tied_names or param_init_fn is None:
        # Nothing to tie, or nothing to wrap: pass through unchanged.
        return param_init_fn

    # Map id(original tied parameter) -> replacement parameter.
    # The value stays None until the first module owning it is initialized.
    _tied_params = {}
    for name in _tied_names:
        # "a.b.weight" -> ("a.b", "weight"); a bare "weight" resolves on the
        # model itself because get_submodule("") returns the model.
        module_path, _, param_name = name.rpartition(".")
        mod = model.get_submodule(module_path)
        param = getattr(mod, param_name)

        _tied_params[id(param)] = None  # placeholder for the param first

    def param_init_fn_tied_param(module: torch.nn.Module):
        # Record the tied parameters *before* initialization, while they can
        # still be recognized by the identity of the original object.
        # - usually only 1 name per id, but for completeness consider > 1
        params_to_tie = defaultdict(list)
        for n, param in module.named_parameters(recurse=False):
            if id(param) in _tied_params:
                params_to_tie[id(param)].append(n)

        # Call the wrapped init fn, which potentially re-allocates the
        # parameters. It may work in place and return None, so only adopt
        # its return value when it actually hands a module back.
        initialized = param_init_fn(module)
        if initialized is not None:
            module = initialized

        # Look the parameters up again and tie them back together.
        for id_key, _param_names in params_to_tie.items():
            for param_name in _param_names:
                param = _tied_params[id_key]
                if param is None:
                    # Everything will be tied to the instance found the
                    # first time the param is observed.
                    _tied_params[id_key] = getattr(module, param_name)
                else:
                    setattr(module, param_name, param)  # tie

        return module

    return param_init_fn_tied_param

0 comments on commit 8159c98

Please sign in to comment.