Generalizing Inference pipeline in NeMo 2.0 to support encoder-decoder models (NVIDIA#10924)

* initial commit

* adding example t5_generate.py

* workable inference code

* updating code

* update code

* workable solution for T5 tokenizer (we add 100 sentinel tokens when initializing the tokenizer through the config, instead of adding them after initialization)

* separate AutoTokenizer's changes into another PR

* cleaning code

* addressing Marc's comments

* addressing Marc's reviews

* update code after merge

* small fix

* Apply isort and black reformatting

Signed-off-by: huvunvidia <[email protected]>

---------

Signed-off-by: huvunvidia <[email protected]>
Co-authored-by: Huy Vu2 <[email protected]>
Co-authored-by: root <[email protected]>
Co-authored-by: huvunvidia <[email protected]>
4 people authored and XuesongYang committed Jan 18, 2025
1 parent 13c2b7e commit a89e0e8
Showing 6 changed files with 196 additions and 30 deletions.
@@ -224,9 +224,12 @@ def text_to_ids(self, text):
         ids = self.tokens_to_ids(tokens)
         return ids

-    def ids_to_text(self, ids):
+    def ids_to_text(self, ids, remove_special_tokens=True):
         tokens = self.ids_to_tokens(ids)
-        tokens_clean = [t for t in tokens if t not in self.tokenizer.all_special_tokens]
+        if remove_special_tokens:
+            tokens_clean = [t for t in tokens if t not in self.tokenizer.all_special_tokens]
+        else:
+            tokens_clean = tokens
         text = self.tokens_to_text(tokens_clean)
         return text

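For context, a minimal usage sketch of the new flag (hypothetical, not part of the commit; `tokenizer` is assumed to be a NeMo tokenizer instance exposing these methods, e.g. the AutoTokenizer touched here):

    # Sketch only: `tokenizer` is an assumed NeMo tokenizer wrapping a T5/HF tokenizer.
    ids = tokenizer.text_to_ids("Hello, how are <mask>?")

    # Default behaviour: special tokens (sentinels, pad, eos, ...) are stripped on the way back to text.
    clean_text = tokenizer.ids_to_text(ids)

    # New option: keep special tokens, which encoder-decoder generation relies on downstream.
    raw_text = tokenizer.ids_to_text(ids, remove_special_tokens=False)
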
4 changes: 4 additions & 0 deletions nemo/collections/llm/api.py
@@ -437,7 +437,9 @@ def generate(
     path: Union[Path, str],
     prompts: list[str],
     trainer: nl.Trainer,
+    encoder_prompts: Optional[list[str]] = None,
     params_dtype: torch.dtype = torch.bfloat16,
+    add_BOS: bool = False,
     max_batch_size: int = 4,
     random_seed: Optional[int] = None,
     inference_batch_times_seqlen_threshold: int = 1000,
@@ -456,6 +458,8 @@
         model=inference_wrapped_model,
         tokenizer=mcore_tokenizer,
         prompts=prompts,
+        encoder_prompts=encoder_prompts,
+        add_BOS=add_BOS,
         max_batch_size=max_batch_size,
         random_seed=random_seed,
         inference_params=inference_params,
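
As a quick reference, a hedged sketch of how the extended entry point might be called for an encoder-decoder checkpoint (the checkpoint path and the `trainer` object are placeholders; the full example script added by this commit appears below):

    from nemo.collections.llm import api

    # Sketch only: path and trainer are placeholders, not taken from the commit.
    results = api.generate(
        path="/path/to/t5_checkpoint",                # placeholder NeMo 2.0 checkpoint
        prompts=[""],                                 # decoder prompt(s)
        encoder_prompts=["Hello, how are <mask>?"],   # new argument: encoder inputs
        trainer=trainer,                              # an nl.Trainer, configured as in the script below
        add_BOS=True,                                 # new argument: prepend BOS to decoder prompts
    )
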
22 changes: 22 additions & 0 deletions nemo/collections/llm/gpt/model/base.py
@@ -18,6 +18,8 @@
 import pytorch_lightning as L
 import torch
 import torch.distributed
+from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
 from megatron.core.optimizer import OptimizerConfig
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_config import TransformerConfig
@@ -310,6 +312,26 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:

         return self.forward_step(batch)

+    def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_threshold) -> torch.Tensor:
+        # This is to get the MCore model required in GPTInferenceWrapper.
+        mcore_model = self.module
+        while mcore_model:
+            if type(mcore_model) is MCoreGPTModel:
+                break
+            mcore_model = getattr(mcore_model, "module", None)
+        if mcore_model is None or type(mcore_model) is not MCoreGPTModel:
+            raise ValueError("Exact McoreGPTModel instance not found in the model structure.")
+
+        inference_wrapper_config = InferenceWrapperConfig(
+            hidden_size=mcore_model.config.hidden_size,
+            params_dtype=params_dtype,
+            inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
+            padded_vocab_size=self.tokenizer.vocab_size,
+        )
+
+        model_inference_wrapper = GPTInferenceWrapper(mcore_model, inference_wrapper_config)
+        return model_inference_wrapper
+
     @property
     def training_loss_reduction(self) -> MaskedTokenLossReduction:
         if not self._training_loss_reduction:
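
Roughly how this new hook is consumed downstream (a sketch, assuming `model` is a restored NeMo 2.0 GPT LightningModule; the values shown are the defaults used elsewhere in this commit):

    import torch

    # Sketch only: `model` is assumed to be a restored llm.GPTModel LightningModule.
    inference_wrapped_model = model.get_inference_wrapper(
        params_dtype=torch.bfloat16,
        inference_batch_times_seqlen_threshold=1000,
    )
    # The returned GPTInferenceWrapper is then handed to the text generation controller
    # in nemo/collections/llm/inference/base.py instead of being constructed there by hand.
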
63 changes: 35 additions & 28 deletions nemo/collections/llm/inference/base.py
@@ -21,12 +21,16 @@
 import torch.distributed
 from megatron.core.inference.common_inference_params import CommonInferenceParams
 from megatron.core.inference.engines.mcore_engine import MCoreEngine
-from megatron.core.inference.model_inference_wrappers.gpt.gpt_inference_wrapper import GPTInferenceWrapper
-from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
+from megatron.core.inference.model_inference_wrappers.abstract_model_inference_wrapper import (
+    AbstractModelInferenceWrapper,
+)
+from megatron.core.inference.text_generation_controllers.encoder_decoder_text_generation_controller import (
+    EncoderDecoderTextGenerationController,
+)
 from megatron.core.inference.text_generation_controllers.simple_text_generation_controller import (
     SimpleTextGenerationController,
 )
-from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel
+from megatron.core.transformer.module import MegatronModule
 from pytorch_lightning.trainer.states import TrainerFn

 import nemo.lightning as nl
@@ -37,19 +41,31 @@
 from nemo.lightning.pytorch.strategies.utils import RestoreConfig


-# We need this wrapper since mcore generate uses tokenizer.detokenize, tokenizer.tokenize to encode and decode prompts
+# We need this wrapper since mcore generate uses methods/properties such as tokenizer.detokenize, tokenizer.tokenize, tokenizer.bos, tokenizer.pad, etc. to encode and decode prompts
 class MCoreTokenizerWrappper:
     def __init__(self, tokenizer):
         self.tokenizer = tokenizer
         self.eod = tokenizer.eod
         self.vocab_size = tokenizer.vocab_size

-    def detokenize(self, tokens):
-        return self.tokenizer.ids_to_text(tokens)
+    def detokenize(self, tokens, remove_special_tokens=False):
+        return self.tokenizer.ids_to_text(tokens, remove_special_tokens)

     def tokenize(self, prompt):
         return self.tokenizer.text_to_ids(prompt)

+    @property
+    def additional_special_tokens_ids(self):
+        return self.tokenizer.additional_special_tokens_ids
+
+    @property
+    def bos(self):
+        return self.tokenizer.bos_id
+
+    @property
+    def pad(self):
+        return self.tokenizer.pad_id
+

 # TODO: Move to lightning Fabric API.
 def _setup_trainer_and_restore_model(path: Path, trainer: nl.Trainer, model: pl.LightningModule):
@@ -101,41 +117,30 @@ def setup_model_and_tokenizer(
     trainer: nl.Trainer,
     params_dtype: torch.dtype = torch.bfloat16,
     inference_batch_times_seqlen_threshold: int = 1000,
-) -> tuple[MCoreGPTModel, MCoreTokenizerWrappper]:
+) -> tuple[MegatronModule, MCoreTokenizerWrappper]:
     model: io.TrainerContext = io.load_context(path=ckpt_to_context_subdir(path), subpath="model")
     _setup_trainer_and_restore_model(path=path, trainer=trainer, model=model)

-    # This is to get the MCore model required in GPTInferenceWrapper.
-    mcore_model = model
-    while mcore_model:
-        if type(mcore_model) is MCoreGPTModel:
-            break
-        mcore_model = getattr(mcore_model, "module", None)
-    if mcore_model is None or type(mcore_model) is not MCoreGPTModel:
-        raise ValueError("Exact McoreGPTModel instance not found in the model structure.")
-
-    inference_wrapped_model = GPTInferenceWrapper(
-        mcore_model,
-        InferenceWrapperConfig(
-            hidden_size=mcore_model.config.hidden_size,
-            params_dtype=params_dtype,
-            inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
-            padded_vocab_size=model.tokenizer.vocab_size,
-        ),
-    )
-
+    inference_wrapped_model = model.get_inference_wrapper(params_dtype, inference_batch_times_seqlen_threshold)
     return inference_wrapped_model, MCoreTokenizerWrappper(model.tokenizer)


 def generate(
-    model: GPTInferenceWrapper,
+    model: AbstractModelInferenceWrapper,
     tokenizer: MCoreTokenizerWrappper,
     prompts: list[str],
+    encoder_prompts: Optional[list[str]] = None,
+    add_BOS: bool = False,
     max_batch_size: int = 4,
     random_seed: Optional[int] = None,
     inference_params: Optional[CommonInferenceParams] = None,
 ) -> dict:
-    text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=model, tokenizer=tokenizer)
+    if encoder_prompts is not None:
+        text_generation_controller = EncoderDecoderTextGenerationController(
+            inference_wrapped_model=model, tokenizer=tokenizer
+        )
+    else:
+        text_generation_controller = SimpleTextGenerationController(inference_wrapped_model=model, tokenizer=tokenizer)
     mcore_engine = MCoreEngine(
         text_generation_controller=text_generation_controller, max_batch_size=max_batch_size, random_seed=random_seed
     )
@@ -144,6 +149,8 @@

     results = mcore_engine.generate(
         prompts=prompts,
+        add_BOS=add_BOS,
+        encoder_prompts=encoder_prompts,
         common_inference_params=common_inference_params,
     )

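
Putting the refactor together, a condensed sketch of the now model-agnostic flow (checkpoint path, trainer, and prompt strings are illustrative placeholders):

    # Sketch only: the same flow covers decoder-only and encoder-decoder checkpoints.
    inference_wrapped_model, mcore_tokenizer = setup_model_and_tokenizer(
        path="/path/to/checkpoint",   # placeholder GPT or T5 NeMo 2.0 checkpoint
        trainer=trainer,              # placeholder nl.Trainer
    )

    # Decoder-only: no encoder_prompts -> SimpleTextGenerationController is used.
    gpt_results = generate(inference_wrapped_model, mcore_tokenizer, prompts=["Hello"])

    # Encoder-decoder: passing encoder_prompts selects EncoderDecoderTextGenerationController.
    t5_results = generate(
        inference_wrapped_model,
        mcore_tokenizer,
        prompts=[""],
        encoder_prompts=["Hello, how are <mask>?"],
        add_BOS=True,
    )
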
22 changes: 22 additions & 0 deletions nemo/collections/llm/t5/model/t5.py
@@ -19,6 +19,8 @@
 import pytorch_lightning as L
 import torch
 import torch.distributed
+from megatron.core.inference.model_inference_wrappers.inference_wrapper_config import InferenceWrapperConfig
+from megatron.core.inference.model_inference_wrappers.t5.t5_inference_wrapper import T5InferenceWrapper
 from megatron.core.optimizer import OptimizerConfig
 from megatron.core.transformer.spec_utils import ModuleSpec
 from megatron.core.transformer.transformer_config import TransformerConfig
@@ -258,6 +260,26 @@ def validation_step(self, batch, batch_idx=None) -> torch.Tensor:

         return self.forward_step(batch)

+    def get_inference_wrapper(self, params_dtype, inference_batch_times_seqlen_threshold) -> torch.Tensor:
+        # This is to get the MCore model required in T5InferenceWrapper.
+        mcore_model = self.module
+        while mcore_model:
+            if type(mcore_model) is MCoreT5Model:
+                break
+            mcore_model = getattr(mcore_model, "module", None)
+        if mcore_model is None or type(mcore_model) is not MCoreT5Model:
+            raise ValueError("Exact MCoreT5Model instance not found in the model structure.")
+
+        inference_wrapper_config = InferenceWrapperConfig(
+            hidden_size=mcore_model.config.hidden_size,
+            params_dtype=params_dtype,
+            inference_batch_times_seqlen_threshold=inference_batch_times_seqlen_threshold,
+            padded_vocab_size=self.tokenizer.vocab_size,
+        )
+
+        model_inference_wrapper = T5InferenceWrapper(mcore_model, inference_wrapper_config)
+        return model_inference_wrapper
+
     @property
     def training_loss_reduction(self) -> MaskedTokenLossReduction:
         if not self._training_loss_reduction:
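
The T5 module mirrors the GPT hook one-for-one, so its use is identical apart from the wrapper type (a sketch, assuming `t5_model` is a restored llm.T5Model LightningModule):

    # Sketch only: returns a T5InferenceWrapper instead of a GPTInferenceWrapper.
    inference_wrapped_model = t5_model.get_inference_wrapper(torch.bfloat16, 1000)
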
108 changes: 108 additions & 0 deletions scripts/llm/t5_generate.py
@@ -0,0 +1,108 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# NOTE: This script is just an example of using NeMo checkpoints for generating outputs and is subject to change without notice.

import argparse
import torch
import torch.distributed
from megatron.core.inference.common_inference_params import CommonInferenceParams

import nemo.lightning as nl
from nemo.collections.llm import api


def get_args():
    parser = argparse.ArgumentParser(description='Train a small T5 model using NeMo 2.0')
    parser.add_argument('--devices', type=int, help="Number of devices to use for training.")
    parser.add_argument('--checkpoint-path', type=str, help="Path to trained model.")
    parser.add_argument("--temperature", type=float, default=1.0, help='Sampling temperature.')
    parser.add_argument("--top_k", type=int, default=1, help='Top k sampling.')
    parser.add_argument("--top_p", type=float, default=0.0, help='Top p sampling.')
    parser.add_argument(
        "--num-tokens-to-generate", type=int, default=30, help='Number of tokens to generate for each prompt.'
    )
    parser.add_argument(
        "--prompts",
        metavar='N',
        type=str,
        nargs='+',
        help='Prompts with each prompt within quotes and separated by space.',
    )
    parser.add_argument(
        "--encoder-prompts",
        metavar='N',
        type=str,
        nargs='+',
        help='Encoder input prompts with each prompt within quotes and separated by space.',
    )
    parser.add_argument("--max-batch-size", type=int, default=1, help='Max number of prompts to process at once.')

    return parser.parse_args()


if __name__ == "__main__":

    args = get_args()

    strategy = nl.MegatronStrategy(
        tensor_model_parallel_size=1,
        pipeline_model_parallel_size=1,
        context_parallel_size=1,
        sequence_parallel=False,
        setup_optimizers=False,
        store_optimizer_states=False,
    )

    trainer = nl.Trainer(
        accelerator="gpu",
        devices=args.devices,
        num_nodes=1,
        strategy=strategy,
        plugins=nl.MegatronMixedPrecision(
            precision="bf16-mixed",
            params_dtype=torch.bfloat16,
            pipeline_dtype=torch.bfloat16,
            autocast_enabled=False,
            grad_reduce_in_fp32=False,
        ),
    )
    prompts = [
        "",
        "",
        "",
    ]
    encoder_prompts = [
        "Hello, how are <mask>?",
        "How many r's are in the <mask> 'strawberry'?",
        "Which number is <mask>? 10.119 <mask> 10.19?",
    ]

    results = api.generate(
        path=args.checkpoint_path,
        prompts=prompts,
        encoder_prompts=encoder_prompts,
        trainer=trainer,
        add_BOS=True,
        inference_params=CommonInferenceParams(
            temperature=args.temperature, top_k=args.top_k, num_tokens_to_generate=args.num_tokens_to_generate
        ),
        text_only=True,
    )
    if torch.distributed.get_rank() == 0:
        for i, r in enumerate(results):
            print(prompts[i])
            print("*" * 50)
            print(r)
            print("\n\n")

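A possible invocation of the example script (device count and checkpoint path are placeholders; the flags are the ones defined in get_args above):

    python scripts/llm/t5_generate.py \
        --devices 1 \
        --checkpoint-path /path/to/t5_nemo2_checkpoint \
        --temperature 1.0 \
        --top_k 1 \
        --num-tokens-to-generate 30 \
        --max-batch-size 1
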