lora merge fix for O2 names (NVIDIA#7325)
* wip

Signed-off-by: arendu <[email protected]>

* adjust key names based on O2

Signed-off-by: arendu <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update

Signed-off-by: arendu <[email protected]>

* minor

Signed-off-by: arendu <[email protected]>

---------

Signed-off-by: arendu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
arendu and pre-commit-ci[bot] authored Sep 1, 2023
1 parent 659949d commit ad79907
Showing 1 changed file with 10 additions and 2 deletions.
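For context: with megatron_amp_O2 enabled, NeMo wraps the language model so its parameter keys carry an extra "module." segment ("model.module.language_model..." instead of "model.language_model..."), which is why the merged LoRA weights need their keys remapped before load_state_dict. A minimal runnable sketch of that remapping on a toy key (the key below is illustrative, not copied from a real checkpoint):

from typing import Any, Dict

def fix_for_O2(state_dict: Dict[str, Any]) -> Dict[str, Any]:
    # Insert the 'module.' segment that the O2 wrapper adds to parameter names.
    return {
        k.replace('model.language_model', 'model.module.language_model'): v
        for k, v in state_dict.items()
    }

toy = {'model.language_model.encoder.layers.0.self_attention.query_key_value.weight': None}
print(list(fix_for_O2(toy)))
# ['model.module.language_model.encoder.layers.0.self_attention.query_key_value.weight']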
12 changes: 10 additions & 2 deletions scripts/nlp_language_modeling/merge_lora_weights/merge.py
@@ -75,6 +75,13 @@ def load_lora(lora_nemo, tp):
return lora_state_dict


def fix_for_O2(state_dict):
new_state_dict = {}
for k, v in state_dict.items():
new_state_dict[k.replace('model.language_model', 'model.module.language_model')] = v
return new_state_dict


def merge(
base_model_state_dict: Dict[str, Any],
lora_state_dict: Dict[int, Any],
@@ -110,6 +117,7 @@ def merge(
wt_self_attn = base_model_state_dict[key_self_attn_kqv]
wt_lora = wt_lora_out @ wt_lora_in
base_model_state_dict[key_self_attn_kqv] = wt_self_attn + wt_lora.type_as(wt_self_attn)
print("mergeing for weight", key_self_attn_kqv)
return base_model_state_dict


@@ -155,8 +163,6 @@ def main(cfg) -> None:
pretrained_cfg.activations_checkpoint_granularity = None
pretrained_cfg.activations_checkpoint_method = None
pretrained_cfg.precision = trainer.precision
if trainer.precision == "16":
pretrained_cfg.megatron_amp_O2 = False
model = MegatronGPTModel.restore_from(
restore_path=cfg.gpt_model_file,
trainer=trainer,
@@ -206,6 +212,8 @@ def main(cfg) -> None:
)

# load the merged_weights back into the base model, for this current rank.
if model.cfg.megatron_amp_O2:
merged_weights = fix_for_O2(merged_weights)
model.load_state_dict(merged_weights)

# Going to go through the motions of inference to force PTL to run subprocess for loading all base model's ranks.
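The merge step itself (shown in the second hunk above, unchanged apart from the added print) folds the LoRA update into the base self-attention weight as W + B @ A, mirroring the lines around key_self_attn_kqv. A rough sketch of that arithmetic with random tensors; the shapes are illustrative assumptions, not taken from a real model config:

import torch

hidden, rank = 1024, 8                            # illustrative sizes
wt_self_attn = torch.randn(3 * hidden, hidden)    # base self-attention weight W (fused KQV)
wt_lora_in = torch.randn(rank, hidden)            # LoRA "in" projection A
wt_lora_out = torch.randn(3 * hidden, rank)       # LoRA "out" projection B

wt_lora = wt_lora_out @ wt_lora_in                # low-rank update B @ A
merged = wt_self_attn + wt_lora.type_as(wt_self_attn)
assert merged.shape == wt_self_attn.shape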
