forked from NVIDIA/NeMo
Commit
RETRO model finetuning (NVIDIA#5800)
Squashed commit messages (each individually signed off by Yi Dong <[email protected]>):

* add save and load dynamic index
* add chunk stride feature
* add chunk stride feature
* add no-PQ index
* added Megatron-LM compatible mode
* add config
* fix position embedding
* added index factory
* share neighbors and weights among strategies
* fix bug
* added metric to FAISS index
* set default to inner product
* added QA fine-tune dataset
* added fine-tuning code
* trim it
* fix data issue
* fix style
* added version
* fix key error
* make sure to overwrite the cfg
* make multiple sentence BERT models available
* fix the document
* fix the table
* fix transformer
* make sure to turn off RoPE in the chunked cross-attention layer
* fix the security issue
* style fix
* fix CodeQL issues
* fix
* use -1
* fix empty index
* clean up
* fix the lower bound for repetition penalty
* add RETRO QA inference strategy
* added new inference logic
* working inference
* fix TP inference
* revert requirement
* added file inference
* use string to prevent collision
* use NQ test
* fix prompt
* fix inference
* set good defaults for demo
* replicate ADLR
* make sure to turn off attention reset for the Megatron-LM compatible model
* style fix
* fix typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix inference error
* fix logging
* address comments

---------
Signed-off-by: Yi Dong <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Showing 29 changed files with 1,996 additions and 425 deletions.
examples/nlp/language_modeling/conf/megatron_retro_finetune_config.yaml (105 additions, 0 deletions)
@@ -0,0 +1,105 @@
name: fine_tune_retro

trainer:
  devices: 2
  num_nodes: 1
  accelerator: gpu
  precision: 16
  logger: False # logger provided by exp_manager
  enable_checkpointing: False
  replace_sampler_ddp: False
  max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch.
  max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
  log_every_n_steps: 10
  val_check_interval: 100
  limit_val_batches: null
  limit_test_batches: null
  accumulate_grad_batches: 1
  gradient_clip_val: 1.0

exp_manager:
  explicit_log_dir: null
  exp_dir: null
  name: megatron_retro
  create_wandb_logger: False
  wandb_logger_kwargs:
    project: null
    name: null
  resume_if_exists: True
  resume_ignore_no_checkpoint: True
  create_checkpoint_callback: True
  checkpoint_callback_params:
    monitor: val_loss
    save_top_k: 10
    mode: min
    always_save_nemo: False # saves a .nemo file during validation; not implemented for model parallel
    filename: 'megatron_retro--{val_loss:.2f}-{step}-{consumed_samples}'
    model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}

model:
  # model parallelism
  tensor_model_parallel_size: 1
  pipeline_model_parallel_size: 1 # must be 1; pipeline parallelism is not supported yet

  micro_batch_size: 4
  megatron_amp_O2: False # use AMP with O2-style mixed precision instead of native AMP on-the-fly weight autocasting

  tokenizer:
    library: 'megatron'
    type: 'GPT2BPETokenizer'
    model: null
    vocab_file: null
    merge_file: null
    delimiter: null # only used for tabular tokenizer

  gradient_as_bucket_view: True # allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
  # precision
  native_amp_init_scale: 4294967296 # 2 ** 32
  native_amp_growth_interval: 1000
  fp16_lm_cross_entropy: False # move the unreduced cross-entropy loss calculation for the LM head to fp16

  # miscellaneous
  seed: 1234

  restore_path: null # the RETRO model restore path

  data:
    train_ds:
      file_name: ??? # train data file path
      answer_only_loss: True # whether to use answer-only loss
      seq_length: 128 # must be a multiple of the chunk_size in your dataset
      add_bos: True # whether to add bos at the beginning
      add_eos: True # whether to add eos at the end
      seed: 1234
      neighbors: 20 # number of retrieved neighbors
    val_ds:
      file_name: ??? # validation data file path
      answer_only_loss: True # whether to use answer-only loss
      seq_length: 128 # must be a multiple of the chunk_size in your dataset
      add_bos: True # whether to add bos at the beginning
      add_eos: True # whether to add eos at the end
      seed: 1234
      neighbors: 20 # number of retrieved neighbors
    test_ds:
      file_name: ??? # test data file path
      answer_only_loss: True # whether to use answer-only loss
      seq_length: 128 # must be a multiple of the chunk_size in your dataset
      add_bos: True # whether to add bos at the beginning
      add_eos: True # whether to add eos at the end
      seed: 1234
      neighbors: 20 # number of retrieved neighbors

  optim:
    name: fused_adam
    lr: 1e-4
    weight_decay: 0.01
    betas:
      - 0.9
      - 0.98
    sched:
      name: CosineAnnealing
      warmup_steps: 500
      constant_steps: 50000
      min_lr: 1e-5
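The `file_name: ???` entries above are mandatory and must be supplied at launch time, and `restore_path` typically points at a pre-trained RETRO `.nemo` checkpoint (the fine-tuning script added below falls back to initializing a fresh model when it is null). As a minimal sketch that is not part of the commit, the config can be inspected and overridden programmatically with OmegaConf; the same dotted keys are what you would pass as Hydra overrides on the command line. All paths in the sketch are placeholders.

```python
# Minimal sketch (not from the commit): load the fine-tuning config and fill in
# the mandatory (???) data fields with OmegaConf. Paths below are placeholders.
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/nlp/language_modeling/conf/megatron_retro_finetune_config.yaml")
overrides = OmegaConf.from_dotlist(
    [
        "model.data.train_ds.file_name=/path/to/train_data",  # placeholder paths
        "model.data.val_ds.file_name=/path/to/val_data",
        "model.data.test_ds.file_name=/path/to/test_data",
    ]
)
cfg = OmegaConf.merge(cfg, overrides)

# Print the resulting train_ds block without resolving interpolations
# (custom resolvers such as ${multiply:...} are not registered in plain OmegaConf).
print(OmegaConf.to_yaml(cfg.model.data.train_ds, resolve=False))
```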
examples/nlp/language_modeling/megatron_retro_fine_tune.py (143 additions, 0 deletions)
@@ -0,0 +1,143 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os

from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_retro_fine_tune_model import MegatronRetroFinetuneModel
from nemo.collections.nlp.parts.nlp_overrides import (
    GradScaler,
    MegatronHalfPrecisionPlugin,
    NLPDDPStrategy,
    NLPSaveRestoreConnector,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import StatelessTimer, exp_manager


def _modify_config(retro_cfg, cfg, add_cfg_to_tree=False):
    """
    Modifies the original RETRO pre-training config with attributes from the fine-tuning config (cfg).
    The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree, which is needed for all
    `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
    """
    OmegaConf.set_struct(retro_cfg, True)
    with open_dict(retro_cfg):
        retro_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False)
        retro_cfg.data = cfg.model.data
        retro_cfg.precision = cfg.trainer.precision
        retro_cfg.optim = cfg.model.optim
        retro_cfg.micro_batch_size = cfg.model.micro_batch_size
        # This is needed when modifying a hparam file directly to load `.ckpt` files.
        # It is not needed to modify the cfg in `.nemo` files.
        if add_cfg_to_tree:
            OmegaConf.resolve(retro_cfg)
            retro_cfg.cfg = retro_cfg
    return retro_cfg


def load_from_nemo(cls, cfg, trainer, retro_cfg, modify_confg_fn, save_restore_connector):
    retro_cfg = modify_confg_fn(retro_cfg, cfg, add_cfg_to_tree=False)
    model = cls.restore_from(
        restore_path=cfg.model.restore_path,
        trainer=trainer,
        override_config_path=retro_cfg,
        save_restore_connector=save_restore_connector,
    )
    return model


@hydra_runner(config_path="conf", config_name="megatron_retro_finetune_config")
def main(cfg) -> None:
    logging.info("\n\n************** Experiment configuration ***********")
    logging.info(f'\n{OmegaConf.to_yaml(cfg)}')

    ###### The following is a workaround for the num_workers=0 issue #####
    # import torch.multiprocessing as mp
    # mp.set_start_method("spawn", force=True)
    #######################################################################
    megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
    plugins = []
    strategy = NLPDDPStrategy(
        no_ddp_communication_hook=True if megatron_amp_o2 else False,
        gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
        find_unused_parameters=False,
        timeout=datetime.timedelta(seconds=18000),
    )

    if cfg.trainer.precision in [16, 'bf16']:
        scaler = None
        if cfg.trainer.precision == 16:
            scaler = GradScaler(
                init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
                growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
                hysteresis=cfg.model.get('hysteresis', 2),
            )
        if megatron_amp_o2:
            plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
        else:
            plugins.append(NativeMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

    if cfg.get('cluster_type', None) == 'BCP':
        plugins.append(TorchElasticEnvironment())

    trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)
    exp_manager(trainer, cfg.exp_manager)

    # update the resume-from-checkpoint path found by exp_manager
    resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
    logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

    trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

    # Override the timer callback with a stateless one
    for idx, callback in enumerate(trainer.callbacks):
        if isinstance(callback, Timer):
            trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,)

    # Load an existing RETRO model from a .nemo checkpoint, or initialize a new one
    if cfg.model.get("restore_path", None):
        save_restore_connector = NLPSaveRestoreConnector()
        if os.path.isdir(cfg.model.restore_path):
            save_restore_connector.model_extracted_dir = cfg.model.restore_path

        model_cfg = MegatronRetroFinetuneModel.restore_from(
            restore_path=cfg.model.restore_path,
            trainer=trainer,
            return_config=True,
            save_restore_connector=save_restore_connector,
        )
        # hydra interpolation does not work here, as the interpolation key is lost when PTL saves hparams
        model = load_from_nemo(
            MegatronRetroFinetuneModel,
            cfg,
            trainer,
            model_cfg,
            modify_confg_fn=_modify_config,
            save_restore_connector=save_restore_connector,
        )
    else:
        model = MegatronRetroFinetuneModel(cfg.model, trainer=trainer)

    trainer.fit(model)


if __name__ == '__main__':
    main()
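Since `main` is wrapped by `hydra_runner`, the script is normally launched from the command line with dotted configuration overrides for the values the config leaves mandatory. Below is a hypothetical launch sketch, not part of the commit: all paths are placeholders, the script is assumed to be run from the repository root, and the subprocess wrapper is used only to keep the example in Python; the same overrides can be passed directly on the shell command line.

```python
# Hypothetical launch sketch: invoke the fine-tuning script with Hydra-style overrides.
# Paths and values below are placeholders, not taken from the commit.
import subprocess
import sys

cmd = [
    sys.executable,
    "examples/nlp/language_modeling/megatron_retro_fine_tune.py",
    "trainer.devices=2",
    "trainer.precision=16",
    "model.restore_path=/path/to/pretrained_retro.nemo",   # pre-trained RETRO checkpoint (placeholder)
    "model.data.train_ds.file_name=/path/to/train_data",   # placeholder data paths
    "model.data.val_ds.file_name=/path/to/val_data",
    "model.data.test_ds.file_name=/path/to/test_data",
]
subprocess.run(cmd, check=True)
```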