From bbed82fa4f62eb5062176790a4c665550fee5704 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com>
Date: Tue, 22 Nov 2022 18:14:58 -0800
Subject: [PATCH] Megatron Export Update (#5343) (#5423)

* export update for Megatron + change ORT optimization

Signed-off-by: David Mosallanezhad

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* updated export_utils to use autocast instead of manually casting >:/

Signed-off-by: David Mosallanezhad

* removed dtype from LayerNorm

Signed-off-by: David Mosallanezhad

* added comment

Signed-off-by: David Mosallanezhad

* reverting changes on FloatCast

Signed-off-by: David Mosallanezhad

* Cherry-picked changes from megatron-norm

Signed-off-by: Boris Fomitchev

* updated asr_model import to cast_utils

Signed-off-by: David Mosallanezhad

* updated del onnx_model place

Signed-off-by: David Mosallanezhad

* changed ort optimization to basic -> temp fix

Signed-off-by: David Mosallanezhad

Signed-off-by: David Mosallanezhad
Signed-off-by: Boris Fomitchev
Co-authored-by: David Mosallanezhad
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Boris Fomitchev
Signed-off-by: David Mosallanezhad
Signed-off-by: Boris Fomitchev
Signed-off-by: Boris Fomitchev
Co-authored-by: David
Co-authored-by: David Mosallanezhad
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Boris Fomitchev
Co-authored-by: Oleksii Kuchaiev
Co-authored-by: Boris Fomitchev
---
 nemo/collections/asr/models/asr_model.py      |  3 +-
 .../machine_translation/megatron_nmt_model.py | 12 ++-
 .../common/megatron/megatron_export.py        | 50 ++++++----
 nemo/utils/export_utils.py                    | 91 +++++++++++++------
 scripts/export.py                             |  1 +
 5 files changed, 108 insertions(+), 49 deletions(-)

diff --git a/nemo/collections/asr/models/asr_model.py b/nemo/collections/asr/models/asr_model.py
index bae9c89a8ad3..4777cf366068 100644
--- a/nemo/collections/asr/models/asr_model.py
+++ b/nemo/collections/asr/models/asr_model.py
@@ -21,7 +21,8 @@
 from nemo.core.classes.common import PretrainedModelInfo
 from nemo.core.classes.exportable import Exportable
 from nemo.core.classes.mixins import AccessMixin
-from nemo.utils import cast_all, logging, model_utils
+from nemo.utils import logging, model_utils
+from nemo.utils.cast_utils import cast_all
 
 __all__ = ['ASRModel']
 
diff --git a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py
index 0b9bb966e3e8..a44b560fbb2d 100644
--- a/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py
+++ b/nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py
@@ -42,6 +42,7 @@
 from nemo.collections.nlp.modules.common.megatron.megatron_export import DecEmb, EncEmb, TokensHeadEmb
 from nemo.collections.nlp.parts.nlp_overrides import GlobalBatchDataFetcher
 from nemo.collections.nlp.parts.utils_funcs import get_last_rank
+from nemo.core.classes import Exportable
 from nemo.utils import AppState, logging, timers
 
 try:
@@ -56,7 +57,7 @@
 __all__ = ["MegatronNMTModel"]
 
 
-class MegatronNMTModel(MegatronLMEncoderDecoderModel):
+class MegatronNMTModel(MegatronLMEncoderDecoderModel, Exportable):
     """
     Megatron NMT training
     """
@@ -750,5 +751,12 @@ def decoder(self):
         return DecEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.enc_dec_model.decoder, self.device)
 
     @property
-    def classifier(self):
+    def log_softmax(self):
         return TokensHeadEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.tokens_head, self.device)
+
+    @property
+    def input_module(self):
+        return self.encoder
+
+    def list_export_subnets(self):
+        return ['encoder', 'log_softmax', 'decoder']

diff --git a/nemo/collections/nlp/modules/common/megatron/megatron_export.py b/nemo/collections/nlp/modules/common/megatron/megatron_export.py
index 6fd9a239380c..8b9a5fff9e88 100644
--- a/nemo/collections/nlp/modules/common/megatron/megatron_export.py
+++ b/nemo/collections/nlp/modules/common/megatron/megatron_export.py
@@ -49,21 +49,23 @@ def forward(self, dec_output):
         if isinstance(dec_output, list):
             dec_output = dec_output[0]
 
-        dec_output = torch.permute(dec_output, (1, 0, 2))
-
         if self.tokens_head_bias is not None:
             return F.linear(dec_output, self.decoder_embedding.word_embeddings.weight, self.tokens_head_bias)
         return F.linear(dec_output, self.decoder_embedding.word_embeddings.weight)
 
-    def input_example(self, max_batch=1, max_dim=1024, seq_len=6):
+    def input_example(self, max_batch=1, max_dim=768, seq_len=6):
         return [
-            torch.randint(low=-3, high=3, size=(seq_len, max_batch, max_dim), device=self.device, dtype=torch.float32)
+            torch.randint(low=-3, high=3, size=(max_batch, seq_len, max_dim), device=self.device, dtype=torch.float32)
         ]
 
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
     @property
     def input_types(self) -> Optional[Dict[str, NeuralType]]:
         return {
-            "hidden_states": NeuralType(('T', 'B', 'D'), ChannelType()),
+            "hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()),
         }
 
     @property
@@ -107,18 +109,28 @@ def forward(self, input_ids, decoder_mask, encoder_mask, encoder_embeddings, dec
         # dec_input, dec_attn_mask, enc_output, enc_attn_mask | dec_input, dec_attn_mask, enc_output, enc_attn_mask
         _ = dec_mems
 
-        return self.decoder(dec_input, decoder_mask, encoder_embeddings, encoder_mask).float()
+        return (
+            self.decoder(dec_input, decoder_mask, encoder_embeddings.permute(1, 0, 2), encoder_mask)
+            .float()
+            .permute(1, 0, 2)
+        )
 
-    def input_example(self, max_batch=1, max_dim=1024, seq_len=6):
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
+    def input_example(self, max_batch=1, max_dim=768, seq_len=6):
         enc_output = torch.randint(
-            low=-3, high=3, size=(seq_len, max_batch, max_dim), device=self.device, dtype=torch.float32
+            low=-3, high=3, size=(max_batch, seq_len, max_dim), device=self.device, dtype=torch.float32
         )
         enc_attn_mask = torch.tensor([[1 for _ in range(seq_len)]]).to(self.device)
 
         dec_len = random.randint(10, 128)
         dec_input = torch.randint(low=0, high=1000, size=(max_batch, dec_len), device=self.device)
         dec_attn_mask = torch.tensor([[1 for _ in range(dec_len)]]).to(self.device)
-        decoder_mems = torch.zeros([8, 6, 1024], dtype=torch.float32).to(self.device)
+
+        # constant decoder_mems as placeholder for now
+        decoder_mems = torch.zeros([8, 6, max_dim], dtype=torch.float32).to(self.device)
 
         # input_ids, decoder_mask, encoder_mask, encoder_embeddings
         return (dec_input, dec_attn_mask, enc_attn_mask, enc_output, decoder_mems)
@@ -128,14 +140,14 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:
         return {
             "input_ids": NeuralType(('B', 'T', 'D'), ChannelType()),
             "decoder_mask": NeuralType(('B', 'T'), MaskType()),
-            "encoder_mask": NeuralType(('T', 'B', 'D'), ChannelType()),
+            "encoder_mask": NeuralType(('B', 'T', 'D'), ChannelType()),
             "encoder_embeddings": NeuralType(('B', 'T'), MaskType()),
-            "decoder_mems": NeuralType(('T', 'B', 'D'), ChannelType()),
+            "decoder_mems": NeuralType(('B', 'T', 'D'), ChannelType()),
         }
 
     @property
     def output_types(self) -> Optional[Dict[str, NeuralType]]:
-        return {"last_hidden_states": NeuralType(('T', 'B', 'D'), ChannelType())}
+        return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}
 
     @property
     def input_names(self) -> List[str]:
@@ -172,15 +184,19 @@ def forward(self, input_ids, encoder_mask):
         enc_input = self.encoder_embedding(input_ids, position_ids, token_type_ids=None)
 
         # pass input through the encoder
-        return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).type(torch.float32)
+        return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).permute(1, 0, 2)
 
-    def input_example(self):
+    def input_example(self, max_batch=1, max_dim=30000, seq_len=6):
         seq_len = random.randint(0, 128)
         return (
-            torch.randint(0, 30000, (1, seq_len)).to(self.device),
-            torch.ones((1, seq_len), dtype=int).to(self.device),
+            torch.randint(0, max_dim, (max_batch, seq_len)).to(self.device),
+            torch.ones((max_batch, seq_len), dtype=int).to(self.device),
         )
 
+    def freeze(self):
+        for param in self.parameters():
+            param.requires_grad = False
+
     @property
     def input_types(self) -> Optional[Dict[str, NeuralType]]:
         return {
@@ -190,7 +206,7 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:
 
     @property
     def output_types(self) -> Optional[Dict[str, NeuralType]]:
-        return {"last_hidden_states": NeuralType(('T', 'B', 'D'), ChannelType())}
+        return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}
 
     @property
     def input_names(self) -> List[str]:

diff --git a/nemo/utils/export_utils.py b/nemo/utils/export_utils.py
index 77d9291220af..197d3b478167 100644
--- a/nemo/utils/export_utils.py
+++ b/nemo/utils/export_utils.py
@@ -59,21 +59,40 @@ def forward(self, x):
         return F.linear(x, self.weight, self.bias), None
 
 
-# ScaledMaskedSoftmax replacement
-def mask_func(attention_scores, attention_mask):
-    attention_scores.masked_fill_(attention_mask, -10000.0)
-    return attention_scores
-
-
-def exportable_ScaledMaskedSoftmax(input, mask, scale):
-    if scale is not None:
-        input = input * scale
-
-    mask_output = mask_func(input, mask) if mask is not None else input
-    probs = torch.nn.Softmax(dim=-1)(mask_output)
-
-    probs = probs.half()
-    return probs
+class ExportableMatchedScaleMaskSoftmax(nn.Module):
+    def __init__(self, mod):
+        super(ExportableMatchedScaleMaskSoftmax, self).__init__()
+        self.init_module(mod.input_in_fp16, mod.input_in_bf16, mod.mask_func, mod.softmax_in_fp32, mod.scale)
+
+    def init_module(
+        self, input_in_fp16, input_in_bf16, mask_func, softmax_in_fp32, scale,
+    ):
+        self.input_in_fp16 = input_in_fp16
+        self.input_in_bf16 = input_in_bf16
+        self.softmax_in_fp32 = softmax_in_fp32
+        self.mask_func = mask_func
+        self.scale = scale
+
+        self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16
+
+    def forward(self, input, mask):
+        if self.input_in_float16 and self.softmax_in_fp32:
+            input = input.float()
+
+        if self.scale is not None:
+            input = input * self.scale
+        mask_output = self.mask_func(input, mask) if mask is not None else input
+        probs = torch.nn.Softmax(dim=-1)(mask_output)
+        all_k_masked = mask.all(axis=-1)
+        zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
+        probs = probs * zero_attention_mask
+
+        if self.input_in_float16 and self.softmax_in_fp32:
+            if self.input_in_fp16:
+                probs = probs.half()
+            else:
+                probs = probs.bfloat16()
+        return probs
 
 
 def get_export_format(filename: str):
@@ -159,10 +178,12 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0
         logging.warning(f"ONNX generated at {output}, not verified - please install onnxruntime_gpu package.\n")
         onnx.checker.check_model(onnx_model, full_check=True)
         return
-    del onnx_model
     onnx_session_opt = onnxruntime.SessionOptions()
-    onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
-    sess = onnxruntime.InferenceSession(output, sess_options=onnx_session_opt, providers=['CUDAExecutionProvider'])
+    onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
+    sess = onnxruntime.InferenceSession(
+        onnx_model.SerializeToString(), sess_options=onnx_session_opt, providers=['CUDAExecutionProvider']
+    )
+    del onnx_model
     all_good = True
     for input_example in input_examples:
         input_list, input_dict = parse_input_example(input_example)
@@ -227,7 +248,7 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
     from apex.transformer.tensor_parallel.layers import RowParallelLinear, ColumnParallelLinear
     from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
 
-    def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.BatchNorm2d]:
+    def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
         """
         Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export.
         Args:
@@ -235,19 +256,16 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.BatchNorm2d]:
         Returns:
             Equivalent LayerNorm module
         """
-        if (
-            not isinstance(n, FusedLayerNorm)
-            and not isinstance(n, FastLayerNorm)
-            and not isinstance(n, MixedFusedLayerNorm)
-        ):
-            return None
-
-        dev = next(n.parameters()).device
+        p = next(n.parameters())
+
         if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm):
-            mod = nn.LayerNorm(n.normalized_shape, eps=n.eps, elementwise_affine=n.elementwise_affine,).to(dev)
+            shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
         elif isinstance(n, FastLayerNorm):
-            mod = nn.LayerNorm(n.weight.shape, eps=n.epsilon, elementwise_affine=True, dtype=torch.float16,).to(dev)
+            shape, eps, affine = n.weight.shape, n.epsilon, True
+        else:
+            return None
+
+        mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
         n_state = n.state_dict()
         mod.load_state_dict(n_state)
         return mod
@@ -264,7 +282,7 @@ def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
             raise ValueError("This function can only change the RowParallelLinear module.")
 
         dev = next(n.parameters()).device
-        mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)
+        mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev)
 
         n_state = n.state_dict()
         mod.load_state_dict(n_state)
@@ -340,6 +358,20 @@ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]:
     return expansion_fn
 
 
+def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Module]:
+    """
+    Replaces MatchedScaleMaskSoftmax with an exportable softmax layer
+    Args:
+        n: module to replace
+    Returns:
+        exportable module
+    """
+
+    mod = ExportableMatchedScaleMaskSoftmax(n)
+
+    return mod
+
+
 def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
     """
     Generic function generator to replace BaseT module with DestT wrapper.
@@ -408,6 +440,7 @@ def script_module(m: nn.Module):
     "BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
     "BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
     "LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
+    "MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax),
 }
 
 script_replacements = {

diff --git a/scripts/export.py b/scripts/export.py
index 2d7514483fc9..2e100e446e72 100644
--- a/scripts/export.py
+++ b/scripts/export.py
@@ -139,6 +139,7 @@ def nemo_export(argv):
     with autocast(), torch.no_grad(), torch.inference_mode():
         model.to(device=args.device).freeze()
         model.eval()
+        input_example = None
        if check_trace and len(in_args) > 0:
            input_example = model.input_module.input_example(**in_args)
            check_trace = [input_example]
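
Usage sketch for the subnet export this patch wires up. Only the subnet names and the Exportable hooks come from the patch; the checkpoint path, the bare restore_from() call, and the output filename are assumptions for illustration, and restoring a Megatron model typically also needs a Trainer and parallel state configured first:

    import torch

    from nemo.collections.nlp.models.machine_translation.megatron_nmt_model import MegatronNMTModel

    # Hypothetical checkpoint path
    model = MegatronNMTModel.restore_from("megatron_nmt.nemo")
    model.eval()

    # The patch maps these names to the wrappers above:
    # 'encoder' -> EncEmb, 'log_softmax' -> TokensHeadEmb, 'decoder' -> DecEmb,
    # each providing its own input_example() for tracing.
    print(model.list_export_subnets())  # ['encoder', 'log_softmax', 'decoder']

    with torch.no_grad():
        # Exportable.export() iterates over list_export_subnets() and traces
        # each subnet separately; output naming is handled by Exportable.
        model.export("megatron_nmt.onnx")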
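The ORT-side change in verify_runtime() (basic-level graph optimization, session built from the in-memory proto) can be exercised standalone. A minimal sketch, assuming the onnxruntime_gpu package is installed and "encoder.onnx" was produced by an export like the one above:

    import onnx
    import onnxruntime

    onnx_model = onnx.load("encoder.onnx")
    onnx.checker.check_model(onnx_model, full_check=True)

    sess_opt = onnxruntime.SessionOptions()
    # ORT_ENABLE_BASIC instead of ORT_ENABLE_ALL, which the commit message
    # flags as a temporary fix.
    sess_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC

    # Feed serialized bytes rather than a file path, as verify_runtime() now does
    sess = onnxruntime.InferenceSession(
        onnx_model.SerializeToString(), sess_options=sess_opt, providers=['CUDAExecutionProvider'],
    )
    del onnx_model  # the session holds its own copy of the graph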