address PR comments
Signed-off-by: Yi Dong <[email protected]>
yidong72 committed Jan 26, 2022
1 parent 782526c commit d4e2cdd
Showing 4 changed files with 19 additions and 10 deletions.
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -32,9 +32,9 @@ trainer:
   logger: False # Provided by exp_manager
 
 model:
-  tensor_model_parallel_size: 2 # tensor model parallel size used in the LM model
+  tensor_model_parallel_size: 1 # tensor model parallel size used in the LM model
   seed: 1234
-  nemo_path: ptune_text_classification_model.nemo # filename to save the model and associated artifacts to .nemo file
+  nemo_path: null # filename to save the model and associated artifacts to .nemo file
   use_lm_finetune: False # whether fine tune the language model
   pseudo_token: '[PROMPT]' # pseudo prompt tokens

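For context, a minimal sketch of how the updated config values behave once loaded with OmegaConf. The filename used here is an assumption (the diff does not show the config's path), and this is not the NeMo training script itself:

import torch  # not required for the config; shown only to match the repo's stack
from omegaconf import OmegaConf

# Hypothetical path to the P-Tuning text classification config edited above.
cfg = OmegaConf.load("ptune_text_classification_config.yaml")

# tensor_model_parallel_size: 1 keeps the default runnable on a single GPU.
print(cfg.model.tensor_model_parallel_size)  # -> 1

# nemo_path: null loads as Python None, i.e. no export filename is configured.
if cfg.model.nemo_path is None:
    print("no .nemo export path configured")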
@@ -1,4 +1,4 @@
-# Copyright 2018 The Google AI Language Team Authors and
+# Copyright 2022 The Google AI Language Team Authors and
 # The HuggingFace Inc. team.
 # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 #
@@ -1,4 +1,4 @@
-# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -84,7 +84,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         self.model = MegatronGPTModel.restore_from(
             self.register_artifact('language_model.nemo_file', cfg.language_model.get('nemo_file', None)),
             trainer=trainer,
-        ).half()
+        )
 
         for param in self.model.parameters():
             param.requires_grad = cfg.use_lm_finetune
@@ -262,10 +262,19 @@ def forward(self, sentences, labels):
         encoder_input, new_atten, label_position = self.get_encoder_input(sentences)
         batch_size, _, seq_len, _ = new_atten.shape
         labels_input, label_ids = self.get_label_input(labels, label_position, seq_len)
+        # workaround to do auto-cast
+        # get the LM dtype
+        dtype = self.model.model.language_model.encoder.layers[0].dtype
 
-        output = self.model.model(
-            None, None, encoder_input=encoder_input, attention_mask=new_atten, labels=labels_input
-        )
+        if dtype == torch.float32:
+            output = self.model.model(
+                None, None, encoder_input=encoder_input, attention_mask=new_atten, labels=labels_input
+            )
+        else:
+            with torch.autocast(device_type="cuda", dtype=dtype):
+                output = self.model.model(
+                    None, None, encoder_input=encoder_input, attention_mask=new_atten, labels=labels_input
+                )
         loss, logits = output
         floss = (loss[(labels_input != SMALL_LOGITS)]).mean()

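The forward change above runs the underlying LM either directly (float32) or under torch.autocast when the restored model is half precision, now that .half() is no longer forced in __init__. A minimal standalone sketch of that pattern, with a toy nn.Linear standing in for the frozen Megatron LM (the helper name and toy model are illustrative assumptions, not NeMo code):

import torch
import torch.nn as nn

def forward_with_model_dtype(model: nn.Module, batch: torch.Tensor) -> torch.Tensor:
    # Read the dtype the model was restored in (float32, float16, or bfloat16).
    dtype = next(model.parameters()).dtype
    if dtype == torch.float32:
        # Full-precision model: call it directly.
        return model(batch)
    # Lower-precision model: autocast lets float32 inputs (e.g. prompt encoder
    # outputs) be fed to it without manual casting.
    with torch.autocast(device_type="cuda", dtype=dtype):
        return model(batch)

if __name__ == "__main__":
    toy_lm = nn.Linear(8, 2)
    if torch.cuda.is_available():
        toy_lm = toy_lm.half().cuda()
        x = torch.randn(4, 8, device="cuda")
    else:
        x = torch.randn(4, 8)
    print(forward_with_model_dtype(toy_lm, x).dtype)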
2 changes: 1 addition & 1 deletion tutorials/nlp/PTune_Sentiment_Analysis.ipynb
@@ -56,7 +56,7 @@
     "source": [
      "In this tutorial, we are going to describe how to use the [P-Tuning method](https://arxiv.org/pdf/2103.10385.pdf) to find good prompts for large GPT models, so they can solve downstream NLP tasks with good performance. P-Tuning leverages a few continuous free parameters to serve as prompts fed as input to the pre-trained language models. By freezing the large language model weights, the P-Tuning model can be trained efficiently while delivering state-of-the-art performance. \n",
      "\n",
-     "Large language models can be trained with the [Megatron-LM project](https://github.com/NVIDIA/Megatron-LM), up to multi-billion parameters. In this notebook, we will use the pre-trained 344M GPT model released on NGC.\n",
+     "Large language models can be trained with [NeMo Megatron](https://github.com/NVIDIA/NeMo/tree/main/examples/nlp/language_modeling), up to multi-billion parameters. In this notebook, we will use the pre-trained 344M GPT model released on NGC.\n",
      "\n",
      "# Task Description\n",
      "In this notebook, we are going to use the P-Tuning method for the **Sentiment Analysis** task, also known as opinion mining or emotion AI. It is a sub-field of NLP that tries to identify and extract opinions within a given text across blogs, reviews, social media, forums, news, etc.\n",
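The tutorial text above describes P-Tuning as training a small set of continuous prompt parameters while the pre-trained language model stays frozen. A toy PyTorch illustration of that idea (the sizes, the stand-in embedding/head "LM", and the random batch are assumptions for illustration only; this is not the NeMo implementation):

import torch
import torch.nn as nn

hidden_size, vocab_size, num_prompt_tokens, batch = 64, 1000, 10, 2

# Frozen stand-ins for the pre-trained GPT's embedding table and output head.
word_embeddings = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size)
for p in list(word_embeddings.parameters()) + list(lm_head.parameters()):
    p.requires_grad = False  # mirrors use_lm_finetune: False in the config

# The only trainable parameters: a handful of continuous prompt embeddings.
prompt_embeddings = nn.Parameter(torch.randn(num_prompt_tokens, hidden_size))

token_ids = torch.randint(0, vocab_size, (batch, 16))  # toy input ids
inputs = torch.cat(
    [prompt_embeddings.unsqueeze(0).expand(batch, -1, -1),  # prepend prompts
     word_embeddings(token_ids)],
    dim=1,
)
logits = lm_head(inputs)  # forward through the frozen "LM"
print(logits.shape)  # (2, 10 + 16, 1000); only prompt_embeddings would receive gradients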
