Megatron Export Update (NVIDIA#5343) (NVIDIA#5423)
* export update for Megatron + change ORT optimization

Signed-off-by: David Mosallanezhad <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* updated export_utils to use autocast instead of manually casting >:/

Signed-off-by: David Mosallanezhad <[email protected]>

* removed dtype from LayerNorm

Signed-off-by: David Mosallanezhad <[email protected]>

* added comment

Signed-off-by: David Mosallanezhad <[email protected]>

* reverted changes to FloatCast

Signed-off-by: David Mosallanezhad <[email protected]>

* Cherry-picked changes from megatron-norm

Signed-off-by: Boris Fomitchev <[email protected]>

* updated asr_model import to cast_utils

Signed-off-by: David Mosallanezhad <[email protected]>

* moved the del onnx_model call to after the ONNX Runtime session is created

Signed-off-by: David Mosallanezhad <[email protected]>

* changed the ORT graph optimization level to basic (temporary fix)

Signed-off-by: David Mosallanezhad <[email protected]>

Signed-off-by: David Mosallanezhad <[email protected]>
Signed-off-by: Boris Fomitchev <[email protected]>
Co-authored-by: David Mosallanezhad <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Boris Fomitchev <[email protected]>

Signed-off-by: David Mosallanezhad <[email protected]>
Signed-off-by: Boris Fomitchev <[email protected]>
Signed-off-by: Boris Fomitchev <[email protected]>
Co-authored-by: David <[email protected]>
Co-authored-by: David Mosallanezhad <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Boris Fomitchev <[email protected]>
Co-authored-by: Oleksii Kuchaiev <[email protected]>
Co-authored-by: Boris Fomitchev <[email protected]>
Signed-off-by: Hainan Xu <[email protected]>
7 people authored and Hainan Xu committed Nov 29, 2022
1 parent ed778d8 commit 801a87f
Showing 5 changed files with 108 additions and 49 deletions.
3 changes: 2 additions & 1 deletion nemo/collections/asr/models/asr_model.py
@@ -21,7 +21,8 @@
from nemo.core.classes.common import PretrainedModelInfo
from nemo.core.classes.exportable import Exportable
from nemo.core.classes.mixins import AccessMixin
from nemo.utils import cast_all, logging, model_utils
from nemo.utils import logging, model_utils
from nemo.utils.cast_utils import cast_all

__all__ = ['ASRModel']

12 changes: 10 additions & 2 deletions nemo/collections/nlp/models/machine_translation/megatron_nmt_model.py
@@ -42,6 +42,7 @@
from nemo.collections.nlp.modules.common.megatron.megatron_export import DecEmb, EncEmb, TokensHeadEmb
from nemo.collections.nlp.parts.nlp_overrides import GlobalBatchDataFetcher
from nemo.collections.nlp.parts.utils_funcs import get_last_rank
from nemo.core.classes import Exportable
from nemo.utils import AppState, logging, timers

try:
@@ -56,7 +57,7 @@
__all__ = ["MegatronNMTModel"]


class MegatronNMTModel(MegatronLMEncoderDecoderModel):
class MegatronNMTModel(MegatronLMEncoderDecoderModel, Exportable):
"""
Megatron NMT training
"""
@@ -750,5 +751,12 @@ def decoder(self):
return DecEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.enc_dec_model.decoder, self.device)

@property
def classifier(self):
def log_softmax(self):
return TokensHeadEmb(self.enc_dec_model.decoder_embedding, self.enc_dec_model.tokens_head, self.device)

@property
def input_module(self):
return self.encoder

def list_export_subnets(self):
return ['encoder', 'log_softmax', 'decoder']
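
The new list_export_subnets hook lets the encoder, log-softmax head, and decoder be exported as separate graphs. Below is a minimal sketch of that flow, not NeMo's actual Exportable.export() implementation; it assumes model is a restored MegatronNMTModel and the output file names are illustrative.

import torch

def export_nmt_subnets(model, out_dir="."):
    # Iterate over the subnet names declared above: ['encoder', 'log_softmax', 'decoder']
    for name in model.list_export_subnets():
        subnet = getattr(model, name)       # EncEmb / TokensHeadEmb / DecEmb wrapper
        subnet.freeze()                     # sets requires_grad = False, as defined in megatron_export.py
        example = subnet.input_example()    # batch-first tensors after this change
        with torch.no_grad():
            torch.onnx.export(subnet, tuple(example), f"{out_dir}/{name}.onnx")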
50 changes: 33 additions & 17 deletions nemo/collections/nlp/modules/common/megatron/megatron_export.py
@@ -49,21 +49,23 @@ def forward(self, dec_output):
if isinstance(dec_output, list):
dec_output = dec_output[0]

dec_output = torch.permute(dec_output, (1, 0, 2))

if self.tokens_head_bias is not None:
return F.linear(dec_output, self.decoder_embedding.word_embeddings.weight, self.tokens_head_bias)
return F.linear(dec_output, self.decoder_embedding.word_embeddings.weight)

def input_example(self, max_batch=1, max_dim=1024, seq_len=6):
def input_example(self, max_batch=1, max_dim=768, seq_len=6):
return [
torch.randint(low=-3, high=3, size=(seq_len, max_batch, max_dim), device=self.device, dtype=torch.float32)
torch.randint(low=-3, high=3, size=(max_batch, seq_len, max_dim), device=self.device, dtype=torch.float32)
]

def freeze(self):
for param in self.parameters():
param.requires_grad = False

@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
return {
"hidden_states": NeuralType(('T', 'B', 'D'), ChannelType()),
"hidden_states": NeuralType(('B', 'T', 'D'), ChannelType()),
}

@property
@@ -107,18 +109,28 @@ def forward(self, input_ids, decoder_mask, encoder_mask, encoder_embeddings, dec
# dec_input, dec_attn_mask, enc_output, enc_attn_mask | dec_input, dec_attn_mask, enc_output, enc_attn_mask
_ = dec_mems

return self.decoder(dec_input, decoder_mask, encoder_embeddings, encoder_mask).float()
return (
self.decoder(dec_input, decoder_mask, encoder_embeddings.permute(1, 0, 2), encoder_mask)
.float()
.permute(1, 0, 2)
)

def input_example(self, max_batch=1, max_dim=1024, seq_len=6):
def freeze(self):
for param in self.parameters():
param.requires_grad = False

def input_example(self, max_batch=1, max_dim=768, seq_len=6):
enc_output = torch.randint(
low=-3, high=3, size=(seq_len, max_batch, max_dim), device=self.device, dtype=torch.float32
low=-3, high=3, size=(max_batch, seq_len, max_dim), device=self.device, dtype=torch.float32
)
enc_attn_mask = torch.tensor([[1 for _ in range(seq_len)]]).to(self.device)

dec_len = random.randint(10, 128)
dec_input = torch.randint(low=0, high=1000, size=(max_batch, dec_len), device=self.device)
dec_attn_mask = torch.tensor([[1 for _ in range(dec_len)]]).to(self.device)
decoder_mems = torch.zeros([8, 6, 1024], dtype=torch.float32).to(self.device)

# constant decoder_mems as placeholder for now
decoder_mems = torch.zeros([8, 6, max_dim], dtype=torch.float32).to(self.device)

# input_ids, decoder_mask, encoder_mask, encoder_embeddings
return (dec_input, dec_attn_mask, enc_attn_mask, enc_output, decoder_mems)
@@ -128,14 +140,14 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:
return {
"input_ids": NeuralType(('B', 'T', 'D'), ChannelType()),
"decoder_mask": NeuralType(('B', 'T'), MaskType()),
"encoder_mask": NeuralType(('T', 'B', 'D'), ChannelType()),
"encoder_mask": NeuralType(('B', 'T', 'D'), ChannelType()),
"encoder_embeddings": NeuralType(('B', 'T'), MaskType()),
"decoder_mems": NeuralType(('T', 'B', 'D'), ChannelType()),
"decoder_mems": NeuralType(('B', 'T', 'D'), ChannelType()),
}

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"last_hidden_states": NeuralType(('T', 'B', 'D'), ChannelType())}
return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}

@property
def input_names(self) -> List[str]:
@@ -172,15 +184,19 @@ def forward(self, input_ids, encoder_mask):
enc_input = self.encoder_embedding(input_ids, position_ids, token_type_ids=None)

# pass input through the encoder
return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).type(torch.float32)
return self.encoder(enc_input=enc_input, enc_attn_mask=encoder_mask,).permute(1, 0, 2)

def input_example(self):
def input_example(self, max_batch=1, max_dim=30000, seq_len=6):
seq_len = random.randint(0, 128)
return (
torch.randint(0, 30000, (1, seq_len)).to(self.device),
torch.ones((1, seq_len), dtype=int).to(self.device),
torch.randint(0, max_dim, (max_batch, seq_len)).to(self.device),
torch.ones((max_batch, seq_len), dtype=int).to(self.device),
)

def freeze(self):
for param in self.parameters():
param.requires_grad = False

@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
return {
Expand All @@ -190,7 +206,7 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:

@property
def output_types(self) -> Optional[Dict[str, NeuralType]]:
return {"last_hidden_states": NeuralType(('T', 'B', 'D'), ChannelType())}
return {"last_hidden_states": NeuralType(('B', 'T', 'D'), ChannelType())}

@property
def input_names(self) -> List[str]:
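
The wrappers above now take and return batch-first (B, T, D) tensors and permute internally to the time-first (T, B, D) layout the Megatron stack uses. The round trip below is an illustrative check of that convention (shapes follow the input_example defaults; it is not part of the NeMo code).

import torch

hidden_bt = torch.randn(1, 6, 768)        # (B, T, D), as produced by input_example()
hidden_tb = hidden_bt.permute(1, 0, 2)    # (T, B, D), the layout the Megatron modules consume
roundtrip = hidden_tb.permute(1, 0, 2)    # back to (B, T, D) for the exported output
assert torch.equal(roundtrip, hidden_bt)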
91 changes: 62 additions & 29 deletions nemo/utils/export_utils.py
@@ -59,21 +59,40 @@ def forward(self, x):
return F.linear(x, self.weight, self.bias), None


# ScaledMaskedSoftmax replacement
def mask_func(attention_scores, attention_mask):
attention_scores.masked_fill_(attention_mask, -10000.0)
return attention_scores


def exportable_ScaledMaskedSoftmax(input, mask, scale):
if scale is not None:
input = input * scale

mask_output = mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)

probs = probs.half()
return probs
class ExportableMatchedScaleMaskSoftmax(nn.Module):
def __init__(self, mod):
super(ExportableMatchedScaleMaskSoftmax, self).__init__()
self.init_module(mod.input_in_fp16, mod.input_in_bf16, mod.mask_func, mod.softmax_in_fp32, mod.scale)

def init_module(
self, input_in_fp16, input_in_bf16, mask_func, softmax_in_fp32, scale,
):
self.input_in_fp16 = input_in_fp16
self.input_in_bf16 = input_in_bf16
self.softmax_in_fp32 = softmax_in_fp32
self.mask_func = mask_func
self.scale = scale

self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16

def forward(self, input, mask):
if self.input_in_float16 and self.softmax_in_fp32:
input = input.float()

if self.scale is not None:
input = input * self.scale
mask_output = self.mask_func(input, mask) if mask is not None else input
probs = torch.nn.Softmax(dim=-1)(mask_output)
all_k_masked = mask.all(axis=-1)
zero_attention_mask = (1.0 - all_k_masked.float())[:, :, :, None]
probs = probs * zero_attention_mask

if self.input_in_float16 and self.softmax_in_fp32:
if self.input_in_fp16:
probs = probs.half()
else:
probs = probs.bfloat16()
return probs


def get_export_format(filename: str):
@@ -159,10 +178,12 @@ def verify_runtime(model, output, input_examples, input_names, check_tolerance=0
logging.warning(f"ONNX generated at {output}, not verified - please install onnxruntime_gpu package.\n")
onnx.checker.check_model(onnx_model, full_check=True)
return
del onnx_model
onnx_session_opt = onnxruntime.SessionOptions()
onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
sess = onnxruntime.InferenceSession(output, sess_options=onnx_session_opt, providers=['CUDAExecutionProvider'])
onnx_session_opt.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_BASIC
sess = onnxruntime.InferenceSession(
onnx_model.SerializeToString(), sess_options=onnx_session_opt, providers=['CUDAExecutionProvider']
)
del onnx_model
all_good = True
for input_example in input_examples:
input_list, input_dict = parse_input_example(input_example)
@@ -227,27 +248,24 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
from apex.transformer.tensor_parallel.layers import RowParallelLinear, ColumnParallelLinear
from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax

def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.BatchNorm2d]:
def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
"""
Replaces Apex's FusedLayerNorm with nn.LayerNorm. This is required for ONNX export.
Args:
n: the FusedLayerNorm pytorch module to replace
Returns:
Equivalent LayerNorm module
"""
if (
not isinstance(n, FusedLayerNorm)
and not isinstance(n, FastLayerNorm)
and not isinstance(n, MixedFusedLayerNorm)
):
return None

dev = next(n.parameters()).device
p = next(n.parameters())
if isinstance(n, FusedLayerNorm) or isinstance(n, MixedFusedLayerNorm):
mod = nn.LayerNorm(n.normalized_shape, eps=n.eps, elementwise_affine=n.elementwise_affine,).to(dev)
shape, eps, affine = n.normalized_shape, n.eps, n.elementwise_affine
elif isinstance(n, FastLayerNorm):
mod = nn.LayerNorm(n.weight.shape, eps=n.epsilon, elementwise_affine=True, dtype=torch.float16,).to(dev)
shape, eps, affine = n.weight.shape, n.epsilon, True
else:
return None

mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
n_state = n.state_dict()
mod.load_state_dict(n_state)
return mod
Expand All @@ -264,7 +282,7 @@ def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
raise ValueError("This function can only change the RowParallelLinear module.")

dev = next(n.parameters()).device
mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)
mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev)

n_state = n.state_dict()
mod.load_state_dict(n_state)
@@ -340,6 +358,20 @@ def expansion_fn(mod: nn.Module) -> Optional[nn.Module]:
return expansion_fn


def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Module]:
"""
Replaces MatchedScaleMaskSoftmax with exportable softmax layer
Args:
n: module to replace
Returns:
exportable module
"""

mod = ExportableMatchedScaleMaskSoftmax(n)  # the wrapper reads input_in_fp16/bf16, mask_func, softmax_in_fp32 and scale from n

return mod


def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
"""
Generic function generator to replace BaseT module with DestT wrapper.
@@ -408,6 +440,7 @@ def script_module(m: nn.Module):
"BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat),
"BatchNorm2d": wrap_module(nn.BatchNorm2d, CastToFloat),
"LayerNorm": wrap_module(nn.LayerNorm, CastToFloat),
"MatchedScaleMaskSoftmax": wrap_module(nn.Softmax, ExportableMatchedScaleMaskSoftmax),
}

script_replacements = {
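
The replacement helpers above (replace_FusedLayerNorm, replace_MatchedScaleMaskSoftmax, wrap_module) all follow one pattern: given a module, return a drop-in exportable substitute or None. The sketch below shows how such a table of replacement functions could be applied to a model before export; it is an illustration under that assumption, not NeMo's own replacement machinery.

import torch.nn as nn

def swap_for_export(model: nn.Module, replace_fns: dict) -> nn.Module:
    # replace_fns maps class names (e.g. "FusedLayerNorm") to functions that
    # return an exportable replacement module, or None to leave the module as is.
    for name, module in list(model.named_modules()):
        fn = replace_fns.get(type(module).__name__)
        if fn is None or name == "":
            continue
        replacement = fn(module)
        if replacement is None:
            continue
        parent_name, _, attr = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, attr, replacement)  # splice the exportable module in place
    return model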
1 change: 1 addition & 0 deletions scripts/export.py
@@ -139,6 +139,7 @@ def nemo_export(argv):
with autocast(), torch.no_grad(), torch.inference_mode():
model.to(device=args.device).freeze()
model.eval()
input_example = None
if check_trace and len(in_args) > 0:
input_example = model.input_module.input_example(**in_args)
check_trace = [input_example]
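
Per the "use autocast instead of manually casting" item in the commit message, export now runs under an autocast context, as in the scripts/export.py hunk above. A hedged, minimal sketch of that wrapping for a single module; export_with_autocast and the output path are illustrative names, not part of NeMo.

import torch
from torch.cuda.amp import autocast

def export_with_autocast(module: torch.nn.Module, input_example, output_path="module.onnx"):
    # Let autocast drive mixed precision during tracing instead of casting
    # weights and activations by hand.
    module.eval()
    with autocast(), torch.no_grad(), torch.inference_mode():
        torch.onnx.export(module, tuple(input_example), output_path)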
