Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added P-Tuning method #3488

Merged
merged 31 commits into from
Jan 27, 2022
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
a16598a
init checking of p-tune method
yidong72 Jan 19, 2022
06d1344
training is working
yidong72 Jan 19, 2022
1ee3972
refactor to seperate prediction and loss computation
yidong72 Jan 19, 2022
1c14df0
updated the notebook
yidong72 Jan 20, 2022
7cb3640
match the original hyper parameters
yidong72 Jan 20, 2022
c748a81
fixed the loss bug
yidong72 Jan 20, 2022
80f48c2
better scheduler
yidong72 Jan 21, 2022
8a85024
notebook runs
yidong72 Jan 21, 2022
7168808
added neural types
yidong72 Jan 21, 2022
e59c2e4
updated the doc
yidong72 Jan 21, 2022
d234cf6
fixed the notebook
yidong72 Jan 21, 2022
51cbea2
updated expected result
yidong72 Jan 21, 2022
12e0a25
added accuracy
yidong72 Jan 21, 2022
0f8444f
style fix
yidong72 Jan 21, 2022
0552a40
Merge branch 'main' into feature_ptune
okuchaiev Jan 21, 2022
0dfb9e0
Merge branch 'main' into feature_ptune
yidong72 Jan 24, 2022
b2ebf81
fix reassgin
yidong72 Jan 24, 2022
8991080
log accuracy
yidong72 Jan 24, 2022
b1005fe
Merge branch 'main' into feature_ptune
ericharper Jan 24, 2022
1d76088
load the best checkpoint
yidong72 Jan 25, 2022
593125e
Merge branch 'feature_ptune' of github.com:NVIDIA/NeMo into feature_p…
yidong72 Jan 25, 2022
782526c
Merge branch 'main' into feature_ptune
ericharper Jan 25, 2022
d4e2cdd
address PR comments
yidong72 Jan 26, 2022
b3db907
added ci test
yidong72 Jan 26, 2022
df55257
Merge branch 'main' into feature_ptune
ericharper Jan 26, 2022
f70542c
fixed max_step calculation error due to wrong number of workers
yidong72 Jan 26, 2022
2856e84
add import guard for nlp plugin
yidong72 Jan 26, 2022
105f6db
Merge branch 'main' into feature_ptune
ericharper Jan 26, 2022
f54da3c
fixed the metric report issue when using tensor parallel
yidong72 Jan 27, 2022
034a429
Merge branch 'feature_ptune' of github.com:NVIDIA/NeMo into feature_p…
yidong72 Jan 27, 2022
11cb31b
Merge branch 'main' into feature_ptune
yidong72 Jan 27, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/nlp/language_modeling/megatron_lm_ckpt_to_nemo.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def convert(rank, world_size, args):
## this dictionary is used to rename the model parameters
name_translate = {}
name_translate['transformer'] = 'encoder'
name_translate['.attention.'] = '.self_attention.'
model = load_from_checkpoint(
MegatronGPTModel,
checkpoint_path,
Expand All @@ -242,7 +243,7 @@ def convert(rank, world_size, args):
## this dictionary is used to rename the model parameters
name_translate = {}
name_translate['transformer'] = 'encoder'
name_translate['attention.'] = 'self_attention.'
name_translate['.attention.'] = '.self_attention.'
model = load_from_checkpoint(
MegatronBertModel,
checkpoint_path,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
ericharper marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Config file for text classification with pre-trained BERT models

trainer:
gpus: 1 # number of GPUs (0 for CPU), or list of the GPUs to use e.g. [0, 1]
num_nodes: 1
max_epochs: 100
max_steps: null # precedence over max_epochs
accumulate_grad_batches: 1 # accumulates grads every k batches
gradient_clip_val: 0.0
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
accelerator: ddp
log_every_n_steps: 1 # Interval of logging.
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it

checkpoint_callback: False # Provided by exp_manager
logger: False # Provided by exp_manager

model:
tensor_model_parallel_size: 2 # tensor model parallel size used in the LM model
ericharper marked this conversation as resolved.
Show resolved Hide resolved
seed: 1234
nemo_path: ptune_text_classification_model.nemo # filename to save the model and associated artifacts to .nemo file
ericharper marked this conversation as resolved.
Show resolved Hide resolved
use_lm_finetune: False # whether fine tune the language model
pseudo_token: '[PROMPT]' # pseudo prompt tokens

tokenizer:
library: 'megatron'
type: 'GPT2BPETokenizer'
model: null
vocab_file: null
merge_file: null

language_model:
nemo_file: null

prompt_encoder:
template: [3, 3, 0]
dropout: 0.0
num_layers: 2

dataset:
classes: ??? # The class labels, e.g. ['positive', 'neutral', 'negative']

train_ds:
file_path: null
batch_size: 64
shuffle: true
num_samples: -1 # number of samples to be considered, -1 means all the dataset
num_workers: 3
drop_last: false
pin_memory: false

validation_ds:
file_path: null
batch_size: 64
shuffle: false
num_samples: -1 # number of samples to be considered, -1 means all the dataset
num_workers: 3
drop_last: false
pin_memory: false

test_ds:
file_path: null
batch_size: 64
shuffle: false
num_samples: -1 # number of samples to be considered, -1 means all the dataset
num_workers: 3
drop_last: false
pin_memory: false

optim:
name: adam
lr: 1e-5
# optimizer arguments
betas: [0.9, 0.999]
weight_decay: 0.0005

# scheduler setup
sched:
name: WarmupAnnealing
# Scheduler params
warmup_steps: null
warmup_ratio: 0.1
last_epoch: -1
# pytorch lightning args
monitor: val_loss
reduce_on_plateau: false

# List of some sample queries for inference after training is done
infer_samples: [
'For example , net sales increased by 5.9 % from the first quarter , and EBITDA increased from a negative EUR 0.2 mn in the first quarter of 2009 .',
'8 May 2009 - Finnish liquid handling products and diagnostic test systems maker Biohit Oyj ( HEL : BIOBV ) said today ( 8 May 2009 ) its net loss narrowed to EUR0 .1 m ( USD0 .14 m ) for the first quarter of 2009 from EUR0 .4 m for the same period of 2008 .',
'CHS Expo Freight is a major Finnish fair , exhibition and culture logistics company that provides logistics services to various events by land , air and sea .',
]

exp_manager:
exp_dir: null # exp_dir for your experiment, if None, defaults to "./nemo_experiments"
name: "PTuneTextClassification" # The name of your model
create_tensorboard_logger: True # Whether you want exp_manger to create a tb logger
create_checkpoint_callback: True # Whether you want exp_manager to create a modelcheckpoint callback
188 changes: 188 additions & 0 deletions examples/nlp/text_classification/ptune_text_classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
This script contains an example on how to train, evaluate and perform inference with the PTuneTextClassificationModel.
PTuneTextClassificationModel in NeMo supports text classification problems such as sentiment analysis or
domain/intent detection for dialogue systems, as long as the data follows the format specified below.

***Data format***
PTuneTextClassificationModel requires the data to be stored in loose json format with two keys of sentence and
label in each line, i.e.
{"sentence": "sentence string", "label": "label string"}

For example:

{"sentence": "The output of the contracts totals 72 MWe. ", "label": "neutral"}
{"sentence": "Pretax profit totaled EUR 9.0 mn , down from EUR 36.3 mn in 2007 .", "label": "negative"}
...

If your dataset is stored in another format, you need to convert it to this format to use the PTuneTextClassificationModel.


***Setting the configs***
The model and the PT trainer are defined in a config file which declares multiple important sections.
The most important ones are:
model: All arguments that are related to the Model - language model, tokenizer, head classifier, optimizer,
schedulers, and datasets/data loaders.
trainer: Any argument to be passed to PyTorch Lightning including number of epochs, number of GPUs,
precision level, etc.

This script uses the `/examples/nlp/text_classification/conf/ptune_text_classification_config.yaml` default config file
by default. You may update the config file from the file directly or by using the command line arguments.
Other option is to set another config file via command line arguments by `--config-name=CONFIG_FILE_PATH'.

You first need to set the classes in the config file which specifies the class types in the dataset.
Notice that some config lines, including `model.dataset.classes`, have `???` as their value, this means that values
for these fields are required to be specified by the user. We need to specify and set the `model.train_ds.file_name`,
`model.validation_ds.file_name`, and `model.test_ds.file_name` in the config file to the paths of the train, validation,
and test files if they exist. We may do it by updating the config file or by setting them from the command line.


***How to run the script?***
For example the following would train a model for 50 epochs in 2 GPUs on a classification task with 2 classes:

# python ptune_text_classification.py
model.dataset.classes=[Label1, Label2]
model.train_ds=PATH_TO_TRAIN_FILE
model.validation_ds=PATH_TO_VAL_FILE
trainer.max_epochs=50
trainer.gpus=2

This script would also reload the last checkpoint after the training is done and does evaluation on the dev set,
then performs inference on some sample queries.

By default, this script uses examples/nlp/text_classification/conf/ptune_text_classifciation_config.py config file, and
you may update all the params in the config file from the command line. You may also use another config file like this:

# python ptune_text_classification.py --config-name==PATH_TO_CONFIG_FILE
model.dataset.num_classes=2
model.train_ds=PATH_TO_TRAIN_FILE
model.validation_ds=PATH_TO_VAL_FILE
trainer.max_epochs=50
trainer.gpus=2

***Load a saved model***
This script would save the model after training into '.nemo' checkpoint file specified by nemo_path of the model config.
You may restore the saved model like this:
model = PTuneTextClassificationModel.restore_from(restore_path=NEMO_FILE_PATH)

***Evaluation a saved model on another dataset***
# If you wanted to evaluate the saved model on another dataset, you may restore the model and create a new data loader:
eval_model = TextClassificationModel.restore_from(restore_path=checkpoint_path)

# Then, you may create a dataloader config for evaluation:
eval_config = OmegaConf.create(
{'file_path': cfg.model.test_ds.file_path, 'batch_size': 64, 'shuffle': False, 'num_workers': 3}
)
eval_model.setup_test_data(test_data_config=eval_config)

# You need to create a new trainer:
eval_trainer = pl.Trainer(gpus=1)
eval_model.set_trainer(eval_trainer)
eval_trainer.test(model=eval_model, verbose=False)
"""
import os
import pathlib

import pytorch_lightning as pl
import torch
from omegaconf import DictConfig, OmegaConf

from nemo.collections.nlp.models.text_classification.ptune_text_classification_model import (
PTuneTextClassificationModel,
)
from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager


@hydra_runner(config_path="conf", config_name="ptune_text_classification_config")
def main(cfg: DictConfig) -> None:
logging.info(f'\nConfig Params:\n{OmegaConf.to_yaml(cfg)}')
trainer = pl.Trainer(plugins=[NLPDDPPlugin()], **cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))

if not cfg.model.train_ds.file_path:
raise ValueError("'train_ds.file_path' need to be set for the training!")

model = PTuneTextClassificationModel(cfg.model, trainer=trainer)
logging.info("===========================================================================================")
logging.info('Starting training...')
trainer.fit(model)
logging.info('Training finished!')
logging.info("===========================================================================================")

# We evaluate the trained model on the test set if test_ds is set in the config file
if cfg.model.test_ds.file_path:
logging.info("===========================================================================================")
logging.info("Starting the testing of the trained model on test set...")
trainer.test(model=model, ckpt_path=None, verbose=False)
logging.info("Testing finished!")
logging.info("===========================================================================================")

# extract the path of the best checkpoint from the training, you may update it to any checkpoint
checkpoint_path = trainer.checkpoint_callback.best_model_path
tensor_parallel_size = cfg.model.tensor_model_parallel_size
pathobj = pathlib.Path(checkpoint_path)
checkpoint_folder = str(pathobj.parent)
checkpoint_name = str(pathobj.name)

rank = trainer.accelerator.training_type_plugin.local_rank
if tensor_parallel_size > 1:
# inject model parallel rank
checkpoint_path = os.path.join(checkpoint_folder, f'mp_rank_{rank:02d}', checkpoint_name)
ericharper marked this conversation as resolved.
Show resolved Hide resolved
else:
checkpoint_path = os.path.join(checkpoint_folder, checkpoint_name)

# Load the checkpoint
best_eval_model = PTuneTextClassificationModel.load_from_checkpoint(
checkpoint_path=checkpoint_path, strict=False, trainer=trainer
)
logging.info(f'best checkpoint path: {checkpoint_path}')
logging.info("Running Test with best EVAL checkpoint!")
# setup the test dataset
best_eval_model.setup_test_data(test_data_config=cfg.model.test_ds)
if torch.distributed.is_initialized():
torch.distributed.barrier()
trainer.test(model=best_eval_model, ckpt_path=None, verbose=False)
logging.info("Beset EVAL Testing finished!")
logging.info("===========================================================================================")

if cfg.model.nemo_path:
# '.nemo' file contains the last checkpoint and the params to initialize the model
best_eval_model.save_to(cfg.model.nemo_path)
logging.info(f'Model is saved into `.nemo` file: {cfg.model.nemo_path}')

# perform inference on a list of queries.
if "infer_samples" in cfg.model and cfg.model.infer_samples:
logging.info("===========================================================================================")
logging.info("Starting the inference on some sample queries...")

# max_seq_length=512 is the maximum length BERT supports.
results = best_eval_model.cuda().classifytext(
queries=cfg.model.infer_samples, batch_size=1, prompt='Sentiment'
)
logging.info('The prediction results of some sample queries with the trained model:')
for query, result in zip(cfg.model.infer_samples, results):
logging.info(f'Query : {query}')
logging.info(f'Predicted label: {result}')

logging.info("Inference finished!")
logging.info("===========================================================================================")


if __name__ == '__main__':
main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright 2018 The Google AI Language Team Authors and
# The HuggingFace Inc. team.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
ericharper marked this conversation as resolved.
Show resolved Hide resolved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from typing import Dict, List, Optional

from nemo.core.classes import Dataset
from nemo.core.neural_types import NeuralType, StringLabel, StringType

__all__ = ['PTuneTextClassificationDataset', 'token_wrapper']


def load_file(filename):
data = []
with open(filename, "r") as f:
for line in f.readlines():
data.append(json.loads(line))
return data


def token_wrapper(token: str) -> str:
return 'Ġ' + token


class PTuneTextClassificationDataset(Dataset):
@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"sentences": [NeuralType(('T'), StringType())], "labels": [NeuralType(('T'), StringLabel())]}

def __init__(self, input_file: str, queries: List[str] = None, prompt: str = 'Sentiment'):
"""
A dataset class that feed data for P-tuning model
Args:
input_file: loose json data file. The format is {"sentence":"input sentence", "label":"class label"}
queries: list of query input sentences
prompt: the prompt string appended at the end of your input sentence
"""
super().__init__()
if input_file and not os.path.exists(input_file):
raise FileNotFoundError(
f'Data file `{input_file}` not found! Each line of the data file should contain json object'
f'where `sentence` key maps to sentence and `label` key maps to label'
)
if queries is None:
json_data = load_file(input_file)
else:
json_data = []
for line in queries:
json_data.append({'sentence': line + f' {prompt} ', 'label': ''})
self.data = json_data

def __len__(self):
return len(self.data)

def __getitem__(self, i):
return self.data[i]['sentence'], self.data[i]['label']
Loading