RETRO model finetuning (NVIDIA#5800)
* add save and load dynamic index

Signed-off-by: Yi Dong <[email protected]>

* add chunk stride feature

Signed-off-by: Yi Dong <[email protected]>

* add chunk stride feature

Signed-off-by: Yi Dong <[email protected]>

* add no pq index

Signed-off-by: Yi Dong <[email protected]>

* added megatron lm compatible mode

Signed-off-by: Yi Dong <[email protected]>

* add config

Signed-off-by: Yi Dong <[email protected]>

* fix position embedding

Signed-off-by: Yi Dong <[email protected]>

* added index factory

Signed-off-by: Yi Dong <[email protected]>

* share neighbors and weights among strategies

Signed-off-by: Yi Dong <[email protected]>

* fix bug

Signed-off-by: Yi Dong <[email protected]>

* added metric to faiss index

Signed-off-by: Yi Dong <[email protected]>

* set default to inner product

Signed-off-by: Yi Dong <[email protected]>

* added qa fine-tune dataset

Signed-off-by: Yi Dong <[email protected]>

* added fine tuning code

Signed-off-by: Yi Dong <[email protected]>

* trim it

Signed-off-by: Yi Dong <[email protected]>

* fix data issue

Signed-off-by: Yi Dong <[email protected]>

* fix style

Signed-off-by: Yi Dong <[email protected]>

* added version

Signed-off-by: Yi Dong <[email protected]>

* fix key error

Signed-off-by: Yi Dong <[email protected]>

* make sure to overwrite the cfg

Signed-off-by: Yi Dong <[email protected]>

* make multiple sentence bert models available

Signed-off-by: Yi Dong <[email protected]>

* fix the document

Signed-off-by: Yi Dong <[email protected]>

* fix the table

Signed-off-by: Yi Dong <[email protected]>

* fix transformer

Signed-off-by: Yi Dong <[email protected]>

* make sure to turn off the rope in chunked cross attention layer

Signed-off-by: Yi Dong <[email protected]>

* fix the security issue

Signed-off-by: Yi Dong <[email protected]>

* style fix

Signed-off-by: Yi Dong <[email protected]>

* fix codeql issues

Signed-off-by: Yi Dong <[email protected]>

* fix

Signed-off-by: Yi Dong <[email protected]>

* use -1

Signed-off-by: Yi Dong <[email protected]>

* fix empty index

Signed-off-by: Yi Dong <[email protected]>

* clean up

Signed-off-by: Yi Dong <[email protected]>

* fix the lower bound for repetition penalty

Signed-off-by: Yi Dong <[email protected]>

* add retro qa inference strategy

Signed-off-by: Yi Dong <[email protected]>

* added new inference logic

Signed-off-by: Yi Dong <[email protected]>

* working inference

Signed-off-by: Yi Dong <[email protected]>

* fix TP inference

Signed-off-by: Yi Dong <[email protected]>

* revert requirement

Signed-off-by: Yi Dong <[email protected]>

* added file inference

Signed-off-by: Yi Dong <[email protected]>

* use string to prevent collision

Signed-off-by: Yi Dong <[email protected]>

* use NQ test

Signed-off-by: Yi Dong <[email protected]>

* fix prompt

Signed-off-by: Yi Dong <[email protected]>

* fix inference

Signed-off-by: Yi Dong <[email protected]>

* set good defaults for demo

Signed-off-by: Yi Dong <[email protected]>

* replicate adlr

Signed-off-by: Yi Dong <[email protected]>

* make sure to turn off attention reset for megatron lm compatible model

Signed-off-by: Yi Dong <[email protected]>

* style fix

Signed-off-by: Yi Dong <[email protected]>

* fix typo

Signed-off-by: Yi Dong <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix inference error

Signed-off-by: Yi Dong <[email protected]>

* fix logging

Signed-off-by: Yi Dong <[email protected]>

* address comments

Signed-off-by: Yi Dong <[email protected]>

---------

Signed-off-by: Yi Dong <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2 people authored and titu1994 committed Mar 24, 2023
1 parent c62f453 commit cd32486
Showing 29 changed files with 1,996 additions and 425 deletions.
16 changes: 10 additions & 6 deletions docs/source/nlp/nemo_megatron/retro/retro_model.rst
@@ -65,22 +65,25 @@ An example script to prepare data for RETRO training is:
--append-eod \
--retrieval-db \
--chunk_size=64 \
--chunk_stride_size=64 \
--workers=48
The RETRO model processes chunked documents using 64 tokens as the default chunk size. The RETRO memory map dataset will add padding
tokens to the end of each document to make it a multiple of 64. The ``--need-pad-id`` argument adds a padding token to the tokenizer
The RETRO model processes chunked documents using 64 tokens as the default chunk size. The ``--chunk_stride_size`` argument
determines the distance between the starts of consecutive chunks. To ensure each document is a multiple of ``--chunk_size`` tokens, the RETRO memory map dataset
adds padding tokens to the end of each document. The ``--need-pad-id`` argument adds a padding token to the tokenizer
if it doesn't already have one. The ``--append-eod`` argument controls whether to add ``end-of-document`` tokens to the preprocessed
data, and the ``--retrieval-db`` argument indicates whether to create a retrieval database for the preprocessed data. If ``--retrieval-db``
is used, it will add an additional 64 padding tokens at the end of the document. The ``--chunk_size`` and ``--workers`` arguments
is used, it will add an additional ``--chunk_size`` padding tokens at the end of the document. The ``--chunk_size`` and ``--workers`` arguments
control the size of the data chunks to be processed and the number of worker processes to use, respectively.
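
The snippet below is a minimal sketch, not the NeMo preprocessing code, of how the padding and stride settings interact; ``pad_id`` stands in for the tokenizer's padding token id.

.. code-block:: python

    def pad_and_chunk(tokens, pad_id, chunk_size=64, chunk_stride=64, retrieval_db=False):
        # pad the document so its length is a multiple of chunk_size
        remainder = len(tokens) % chunk_size
        if remainder:
            tokens = tokens + [pad_id] * (chunk_size - remainder)
        # a retrieval database document gets one extra chunk of padding at the end
        if retrieval_db:
            tokens = tokens + [pad_id] * chunk_size
        # consecutive chunks start chunk_stride tokens apart
        return [tokens[i:i + chunk_size]
                for i in range(0, len(tokens) - chunk_size + 1, chunk_stride)]

For example, a 100-token document is padded to 128 tokens; with a stride of 64 it yields two chunks, while a stride of 32 it yields three overlapping chunks.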

Following is the retro memory map index data format:

.. list-table::
:widths: 25 25 25 25 25 25
:widths: 25 25 25 25 25 25 25

* - 'MMIDRET\x00\x00' (header 9 bytes)
- 1 (version 8 byte)
- 1 (version 4 byte)
- 64 (stride 4 byte)
- dtype code :sup:`1` (1 byte)
- sentence count (8 byte)
- chunk size (8 byte)
@@ -91,6 +94,7 @@ Following is the retro memory map index data format:
- start of chunk id (int64 array)
- chunk id address in byte (int64 array)
-
-

:sup:`1` 1: np.uint8, 2: np.int8, 3: np.int16, 4: np.int32, 5: np.int64, 6: np.float, 7: np.double, 8: np.uint16
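
As a hedged illustration only (the actual reader lives in NeMo's indexed dataset code), the leading header fields listed above can be decoded as follows; little-endian byte order and the file path are assumptions, and the fields after the chunk size are not shown in this diff.

.. code-block:: python

    import struct

    DTYPES = {1: 'np.uint8', 2: 'np.int8', 3: 'np.int16', 4: 'np.int32',
              5: 'np.int64', 6: 'np.float', 7: 'np.double', 8: 'np.uint16'}

    index_path = 'retrieval_db.idx'  # placeholder path
    with open(index_path, 'rb') as f:
        magic = f.read(9)                 # 9-byte header
        assert magic == b'MMIDRET\x00\x00'
        version, stride, dtype_code, sentence_count, chunk_size = struct.unpack(
            '<iiBqq', f.read(25))         # 4 + 4 + 1 + 8 + 8 bytes
        print(version, stride, DTYPES[dtype_code], sentence_count, chunk_size)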

@@ -443,7 +447,7 @@ We have built a simple web client that makes it easy for users to play around wi
retro_model_file=megatron_retro.nemo \
tensor_model_parallel_size=8 \
pipeline_model_parallel_size=1 \
retrieval_service.sentence_bert.devices=\'0,1,2,3,4,5,6,7\' \
retrieval_service.sentence_bert.default.devices=\'0,1,2,3,4,5,6,7\' \
retrieval_service.services.0.faiss_devices=\'0,1,2,3,4,5,6,7\' \
retrieval_service.services.1.faiss_devices=\'0,1,2,3,4,5,6,7\' \
retrieval_service.services.0.faiss_index=/result/pubmed_faiss_final.index \
@@ -42,6 +42,8 @@ exp_manager:


model:
version: 1 # indicate the retro model version

# model parallelism
micro_batch_size: 4
tensor_model_parallel_size: 1
@@ -72,6 +74,8 @@ model:
megatron_amp_O2: False # use AMP with O2 style mixed precision instead of native amp on-the-fly weight autocasting.
grad_allreduce_chunk_size_mb: 125

megatron_lm_compatible: False # a flag to indicate whether the model is compatible with Megatron LM

tokenizer:
library: 'megatron'
type: 'GPT2BPETokenizer'
@@ -0,0 +1,105 @@
name: fine_tune_retro

trainer:
devices: 2
num_nodes: 1
accelerator: gpu
precision: 16
logger: False # logger provided by exp_manager
enable_checkpointing: False
replace_sampler_ddp: False
max_epochs: -1 # PTL default. In practice we don't usually train for more than 1 epoch.
max_steps: 100000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches
log_every_n_steps: 10
val_check_interval: 100
limit_val_batches: null
limit_test_batches: null
accumulate_grad_batches: 1
gradient_clip_val: 1.0

exp_manager:
explicit_log_dir: null
exp_dir: null
name: megatron_retro
create_wandb_logger: False
wandb_logger_kwargs:
project: null
name: null
resume_if_exists: True
resume_ignore_no_checkpoint: True
create_checkpoint_callback: True
checkpoint_callback_params:
monitor: val_loss
save_top_k: 10
mode: min
always_save_nemo: False # saves nemo file during validation, not implemented for model parallel
filename: 'megatron_retro--{val_loss:.2f}-{step}-{consumed_samples}'
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}


model:
# model parallelism
tensor_model_parallel_size: 1
pipeline_model_parallel_size: 1 # has to be 1; pipeline parallelism is not supported yet

micro_batch_size: 4
megatron_amp_O2: False # use AMP with O2 style mixed precision instead of native amp on-the-fly weight autocasting.

tokenizer:
library: 'megatron'
type: 'GPT2BPETokenizer'
model: null
vocab_file: null
merge_file: null
delimiter: null # only used for tabular tokenizer

gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
# precision
native_amp_init_scale: 4294967296 # 2 ** 32
native_amp_growth_interval: 1000
fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16

# miscellaneous
seed: 1234

restore_path: null # the retro model restore path

data:
train_ds:
file_name: ??? # train data file path
answer_only_loss: True # whether to use answer-only loss
seq_length: 128 # must be a multiple of the chunk_size in your dataset
add_bos: True # whether to add bos at the beginning
add_eos: True # whether to add eos at the end
seed: 1234
neighbors: 20 # number of retrieved neighbors
val_ds:
file_name: ??? # validation data file path
answer_only_loss: True # whether to use answer-only loss
seq_length: 128 # must be a multiple of the chunk_size in your dataset
add_bos: True # whether to add bos at the beginning
add_eos: True # whether to add eos at the end
seed: 1234
neighbors: 20 # number of retrieved neighbors
test_ds:
file_name: ??? # test data file path
answer_only_loss: True # whether to use answer-only loss
seq_length: 128 # must be a multiple of the chunk_size in your dataset
add_bos: True # whether to add bos at the beginning
add_eos: True # whether to add eos at the end
seed: 1234
neighbors: 20 # number of retrieved neighbors


optim:
name: fused_adam
lr: 1e-4
weight_decay: 0.01
betas:
- 0.9
- 0.98
sched:
name: CosineAnnealing
warmup_steps: 500
constant_steps: 50000
min_lr: 1e-5
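
As a rough sanity check, not part of the NeMo code, the config above can be loaded with OmegaConf to verify the seq_length / chunk-size constraint called out in the comments; the config path and the chunk size of 64 are assumptions for illustration.

from omegaconf import OmegaConf

cfg = OmegaConf.load('examples/nlp/language_modeling/conf/megatron_retro_finetune_config.yaml')  # assumed path
chunk_size = 64  # must match the chunk size used to build the retrieval dataset
for split in ('train_ds', 'val_ds', 'test_ds'):
    ds = cfg.model.data[split]
    assert ds.seq_length % chunk_size == 0, f'{split}: seq_length must be a multiple of {chunk_size}'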
24 changes: 20 additions & 4 deletions examples/nlp/language_modeling/conf/megatron_retro_inference.yaml
@@ -34,25 +34,41 @@ prompts: # prompts for RETRO model inference

########### Faiss service parameters ########
retrieval_service:
strategy: RetroModelTextGenerationStrategy # choose customized inference strategy
neighbors: 4
frequent_query: False # if True, update the retrieval context frequently during token generation; if False, update it every 64 tokens
pad_tokens: True # pad the input at the beginning to a minimum of 64 tokens so retrieval happens at least once
store_retrieved: False # whether to store the retrieved documents so they can be inspected
weights: [0.5, 0.5] # weights for the different retrieval services
sentence_bert:
devices: '0,1,2'
sentence_bert: 'all-mpnet-base-v2'
sentence_bert_batch: 4
sentence_bert: # define a few sentence bert models for different retrieval services to use
default:
devices: '0,1,2'
sentence_bert: 'all-mpnet-base-v2'
sentence_bert_batch: 4
qa_ctx:
devices: '0,1,2'
sentence_bert: 'facebook-dpr-ctx_encoder-multiset-base'
sentence_bert_batch: 4
qa_question:
devices: '0,1,2'
sentence_bert: 'facebook-dpr-question_encoder-multiset-base'
sentence_bert_batch: 4
services:
- type: FaissRetrievalService
faiss_devices: '0,1,2'
faiss_index: null # the faiss index file that is used to find KNN
nprobe: 100
retrieval_index: null
query_bert: 'default' # the bert model to encode the query str
- type: DynamicFaissRetrievalService
faiss_devices: '0,1,2'
faiss_index: null # the faiss index to load from file, if null, start from scratch
store_file: null # the retrieval service storage to load from file, if null, start from scratch
chunk_size: 64
stride: 32
ctx_bert: 'qa_ctx' # the bert model to encode the ctx that is used to construct the dynamic retrieval index
query_bert: 'qa_question' # the bert model to encode the query str
output_filename: 'dynamic_db' # the filename of serialized dynamic retrieval service, used for both Faiss index and data storage
server: False # whether to launch the API server
port: 5555 # the port number for the inference server
web_server: False # whether to launch the web inference server
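
The following is a rough sketch of what a static Faiss retrieval service conceptually does with the settings above; it is not the NeMo service implementation, the index path and query string are placeholders, and nprobe only applies to IVF-style indexes.

import faiss
from sentence_transformers import SentenceTransformer

encoder = SentenceTransformer('all-mpnet-base-v2')   # matches sentence_bert.default above
index = faiss.read_index('/path/to/faiss.index')     # placeholder for faiss_index
index.nprobe = 100                                   # matches the nprobe setting
query_vec = encoder.encode(['who wrote the origin of species?'], normalize_embeddings=True)
distances, chunk_ids = index.search(query_vec, 4)    # 4 matches `neighbors: 4`
# chunk_ids map back to text chunks via the retrieval_index dataset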
@@ -43,6 +43,7 @@ exp_manager:
model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}}

base_model:
version: 1 # indicate the retro model version
# model parallelism
micro_batch_size: 4
tensor_model_parallel_size: 1
@@ -87,6 +88,7 @@ base_model:
seed: 1234

delta_model:
version: 1 # indicate the retro model version
# model parallelism
micro_batch_size: 4
tensor_model_parallel_size: 1
@@ -130,6 +132,7 @@ delta_model:
seed: 1234

model:
version: 1 # indicate the retro model version
shape_file: null # the path to the shape file
# model parallelism
micro_batch_size: 4
3 changes: 3 additions & 0 deletions examples/nlp/language_modeling/megatron_retro_eval.py
@@ -76,6 +76,9 @@ def main(cfg) -> None:

with open_dict(model_cfg):
model_cfg.precision = trainer.precision
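# disable training-only features (sequence parallelism, activation checkpointing) that are not needed for inference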
model_cfg.sequence_parallel = False
model_cfg.activations_checkpoint_granularity = None
model_cfg.activations_checkpoint_method = None

model = MegatronRetrievalModel.restore_from(
model_path, trainer=trainer, save_restore_connector=save_restore_connector, override_config_path=model_cfg,
143 changes: 143 additions & 0 deletions examples/nlp/language_modeling/megatron_retro_fine_tune.py
@@ -0,0 +1,143 @@
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import os

from omegaconf.omegaconf import OmegaConf, open_dict
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.plugins.environments import TorchElasticEnvironment
from pytorch_lightning.plugins.precision.native_amp import NativeMixedPrecisionPlugin
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector

from nemo.collections.nlp.models.language_modeling.megatron_retro_fine_tune_model import MegatronRetroFinetuneModel
from nemo.collections.nlp.parts.nlp_overrides import (
GradScaler,
MegatronHalfPrecisionPlugin,
NLPDDPStrategy,
NLPSaveRestoreConnector,
)
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import StatelessTimer, exp_manager


def _modify_config(retro_cfg, cfg, add_cfg_to_tree=False):
"""
This function modifies the original retro pre-training config with attributes from the finetuning config (cfg).
The `add_cfg_to_tree` arg adds `cfg` to the top of the yaml tree which is needed for all `hparams.yaml` files when passed as an arg to `load_from_checkpoint()`.
"""
OmegaConf.set_struct(retro_cfg, True)
with open_dict(retro_cfg):
retro_cfg.megatron_amp_O2 = cfg.model.get('megatron_amp_O2', False)
retro_cfg.data = cfg.model.data
retro_cfg.precision = cfg.trainer.precision
retro_cfg.optim = cfg.model.optim
retro_cfg.micro_batch_size = cfg.model.micro_batch_size
# This is needed when modifying a hparam file directly to load `.ckpt` files.
# This is not needed to modify the cfg in `.nemo` files.
if add_cfg_to_tree:
OmegaConf.resolve(retro_cfg)
retro_cfg.cfg = retro_cfg
return retro_cfg


def load_from_nemo(cls, cfg, trainer, retro_cfg, modify_config_fn, save_restore_connector):
retro_cfg = modify_config_fn(retro_cfg, cfg, add_cfg_to_tree=False)
model = cls.restore_from(
restore_path=cfg.model.restore_path,
trainer=trainer,
override_config_path=retro_cfg,
save_restore_connector=save_restore_connector,
)
return model


@hydra_runner(config_path="conf", config_name="megatron_retro_finetune_config")
def main(cfg) -> None:
logging.info("\n\n************** Experiment configuration ***********")
logging.info(f'\n{OmegaConf.to_yaml(cfg)}')
###### following is the workaround for num_workers=0 issue #####
# import torch.multiprocessing as mp
# mp.set_start_method("spawn", force=True)
#####################################################
megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False)
plugins = []
strategy = NLPDDPStrategy(
no_ddp_communication_hook=True if megatron_amp_o2 else False,
gradient_as_bucket_view=cfg.model.gradient_as_bucket_view,
find_unused_parameters=False,
timeout=datetime.timedelta(seconds=18000),
)

if cfg.trainer.precision in [16, 'bf16']:
scaler = None
if cfg.trainer.precision == 16:
scaler = GradScaler(
init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32),
growth_interval=cfg.model.get('native_amp_growth_interval', 1000),
hysteresis=cfg.model.get('hysteresis', 2),
)
if megatron_amp_o2:
plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))
else:
plugins.append(NativeMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler))

if cfg.get('cluster_type', None) == 'BCP':
plugins.append(TorchElasticEnvironment())

trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer)
exp_manager(trainer, cfg.exp_manager)

# update resume from checkpoint found by exp_manager
resume_from_checkpoint = trainer._checkpoint_connector.resume_from_checkpoint_fit_path
logging.info(f'Resuming training from checkpoint: {resume_from_checkpoint}')

trainer._checkpoint_connector = CheckpointConnector(trainer, resume_from_checkpoint=resume_from_checkpoint)

# Override timer callback to a stateless one
for idx, callback in enumerate(trainer.callbacks):
if isinstance(callback, Timer):
trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,)

# load an existing RETRO checkpoint or initialize a new RETRO fine-tuning model
if cfg.model.get("restore_path", None):
save_restore_connector = NLPSaveRestoreConnector()
if os.path.isdir(cfg.model.restore_path):
save_restore_connector.model_extracted_dir = cfg.model.restore_path

model_cfg = MegatronRetroFinetuneModel.restore_from(
restore_path=cfg.model.restore_path,
trainer=trainer,
return_config=True,
save_restore_connector=save_restore_connector,
)
# hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams
model = load_from_nemo(
MegatronRetroFinetuneModel,
cfg,
trainer,
model_cfg,
modify_config_fn=_modify_config,
save_restore_connector=save_restore_connector,
)
else:
model = MegatronRetroFinetuneModel(cfg.model, trainer=trainer)

trainer.fit(model)


if __name__ == '__main__':
main()