GEGLU activation for T5 (#3694)
* GEGLU activation

Signed-off-by: MaximumEntropy <[email protected]>

* Add activation to config

Signed-off-by: MaximumEntropy <[email protected]>

* Style fixes

Signed-off-by: MaximumEntropy <[email protected]>

* Update config

Signed-off-by: MaximumEntropy <[email protected]>

* Update license header

Signed-off-by: MaximumEntropy <[email protected]>
MaximumEntropy authored and fayejf committed Mar 2, 2022
1 parent 5c6df7b commit 595549b
Showing 8 changed files with 49 additions and 2 deletions.
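For reference: GEGLU is the gated-GELU member of the GLU family of feed-forward activations (Shazeer, "GLU Variants Improve Transformer", 2020) and is the activation used in T5 v1.1-style models, which is what this commit enables for NeMo's Megatron T5. Instead of a single GELU projection, the MLP input transform becomes gelu(x W1 + b1) * (x W2 + b2). A minimal plain-PyTorch sketch of that formula (the names here are illustrative, not taken from the NeMo code):

import torch
import torch.nn.functional as F

def geglu(x, w_in, b_in, w_gate, b_gate):
    # Gated-GELU input transform: gelu(x @ w_in + b_in) multiplied elementwise by (x @ w_gate + b_gate).
    return F.gelu(x @ w_in + b_in) * (x @ w_gate + b_gate)

The diff below realizes the same formula inside ParallelMLP with two tensor-parallel ColumnParallelLinear projections, the existing dense_h_to_4h plus a new dense_h_to_4h_2.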
@@ -66,6 +66,7 @@ model:
gradient_as_bucket_view: True # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory)
encoder_arch: 'transformer'
decoder_arch: 'transformer'
activation: 'gelu'

tokenizer:
library: 'megatron'
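The new activation key defaults to 'gelu', so existing configs keep their current behavior; enabling the new path means setting activation: 'geglu' under model: in the YAML. A toy illustration with OmegaConf, which NeMo uses for these configs (the miniature config below is a stand-in, not the full T5 config):

from omegaconf import OmegaConf

cfg = OmegaConf.create({"model": {"activation": "gelu"}})   # stand-in for the YAML section above
cfg.model.activation = "geglu"                              # switch to the new GEGLU path
print(cfg.model.get("activation", "gelu"))                  # -> geglu, read the same way the model reads it below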
@@ -93,6 +93,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
bias_gelu_fusion=cfg.get('bias_gelu_fusion', True),
masked_softmax_fusion=cfg.get('masked_softmax_fusion', True),
onnx_safe=cfg.get('onnx_safe', False),
activation=cfg.get('activation', 'gelu'),
)

self.setup_optimizer_param_groups()
@@ -68,6 +68,7 @@ def get_decoder_model(
masked_softmax_fusion=True,
persist_layer_norm=False,
openai_gelu=False,
activation="gelu",
onnx_safe=False,
hidden_steps=-1,
hidden_blocks=1,
@@ -113,6 +114,7 @@ def get_decoder_model(
persist_layer_norm=persist_layer_norm,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
activation=activation,
)
else:
raise ValueError(f"Unknown decoder arch = {arch}. Available decoder arch = {AVAILABLE_DECODERS}")
@@ -67,6 +67,7 @@ def get_encoder_model(
masked_softmax_fusion=True,
persist_layer_norm=False,
openai_gelu=False,
activation="gelu",
onnx_safe=False,
hidden_steps=-1,
hidden_blocks=1,
@@ -112,6 +113,7 @@ def get_encoder_model(
persist_layer_norm=persist_layer_norm,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
activation=activation,
)
else:
raise ValueError(f"Unknown encoder arch = {arch}. Available encoder arch = {AVAILABLE_ENCODERS}")
@@ -59,6 +59,7 @@ def __init__(
persist_layer_norm=False,
openai_gelu=False,
onnx_safe=False,
activation='gelu',
):
super(MegatronTransformerDecoderModule, self).__init__()

@@ -105,6 +106,7 @@ def __init__(
persist_layer_norm=persist_layer_norm,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
activation=activation,
)
self._model_key = 'model'

@@ -59,6 +59,7 @@ def __init__(
persist_layer_norm=False,
openai_gelu=False,
onnx_safe=False,
activation='gelu',
):
super(MegatronTransformerEncoderModule, self).__init__()

@@ -104,6 +105,7 @@ def __init__(
persist_layer_norm=persist_layer_norm,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
activation=activation,
)
self._model_key = 'model'

@@ -98,6 +98,7 @@ def __init__(
bias_gelu_fusion=True,
masked_softmax_fusion=True,
openai_gelu=False,
activation='gelu',
onnx_safe=False,
hidden_steps=-1,
hidden_blocks=1,
@@ -159,6 +160,7 @@ def __init__(
onnx_safe=onnx_safe,
hidden_steps=hidden_steps,
hidden_blocks=hidden_blocks,
activation=activation,
)

decoder = get_decoder_model(
@@ -190,6 +192,7 @@ def __init__(
onnx_safe=onnx_safe,
hidden_steps=hidden_steps,
hidden_blocks=hidden_blocks,
activation=activation,
)

self.enc_dec_model = MegatronTransformerEncoderDecoderModule(encoder=encoder, decoder=decoder,)
38 changes: 36 additions & 2 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -28,6 +28,7 @@
from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import get_layer_norm
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, attention_mask_func, erf_gelu
from nemo.utils import logging

try:
from apex.transformer import parallel_state, tensor_parallel
@@ -76,21 +77,42 @@ def __init__(
bias_gelu_fusion=True,
openai_gelu=False,
onnx_safe=False,
activation='gelu',
):
super(ParallelMLP, self).__init__()
self.activation = activation

if activation not in ['gelu', 'geglu']:
raise ValueError(f"Activation {activation} not supported. Only gelu and geglu are supported.")

# Project to 4h.
self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
hidden_size,
ffn_hidden_size,
ffn_hidden_size, # NOTE: When using geglu, divide ffn dim by 2/3 to keep overall params the same.
gather_output=False,
init_method=init_method,
skip_bias_add=True,
use_cpu_initialization=use_cpu_initialization,
)

if activation == 'geglu':
# Separate linear layer for GEGLU activation.
# Source: https://github.com/huggingface/transformers/blob/bee361c6f1f7704f8c688895f2f86f6e5ff84727/src/transformers/models/t5/modeling_t5.py#L292
self.dense_h_to_4h_2 = tensor_parallel.ColumnParallelLinear(
hidden_size,
ffn_hidden_size, # NOTE: When using geglu, divide ffn dim by 2/3 to keep overall params the same.
gather_output=False,
init_method=init_method,
skip_bias_add=True,
use_cpu_initialization=use_cpu_initialization,
)

self.bias_gelu_fusion = bias_gelu_fusion
self.activation_func = F.gelu
if activation == 'geglu':
self.activation_func = 'geglu' # Implemented using F.gelu
if bias_gelu_fusion:
logging.warning("Bias Gelu Fusion is not supported for GEGLU activation. Running with pytorch F.gelu")
if openai_gelu:
self.activation_func = openai_gelu
elif onnx_safe:
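The NOTE on ffn_hidden_size points at a parameter-count trade-off: a gated MLP carries three weight matrices (the two input projections above plus the usual output projection) instead of two, so the feed-forward width is typically shrunk to roughly two-thirds of the standard 4 * hidden_size to keep the block's parameter count about the same. A quick back-of-the-envelope check (illustrative sizes, biases ignored):

d_model = 1024
d_ff_gelu = 4 * d_model                   # standard feed-forward width
params_gelu = 2 * d_model * d_ff_gelu     # h->4h plus 4h->h weight matrices

d_ff_geglu = (2 * d_ff_gelu) // 3         # about 2/3 of 4h, as the NOTE suggests
params_geglu = 3 * d_model * d_ff_geglu   # two input projections plus one output projection

print(params_gelu, params_geglu)          # 8388608 vs 8386560, nearly identical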
@@ -111,7 +133,14 @@ def forward(self, hidden_states):
# [s, b, 4hp]
intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

if self.bias_gelu_fusion:
if self.activation == 'geglu':
intermediate_parallel_2, bias_parallel_2 = self.dense_h_to_4h_2(hidden_states)

if self.activation == 'geglu':
intermediate_parallel = F.gelu(intermediate_parallel + bias_parallel) * (
intermediate_parallel_2 + bias_parallel_2
)
elif self.bias_gelu_fusion and self.activation == 'gelu':
intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
else:
intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
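In the rewritten forward pass, the 'geglu' branch runs the second projection and gates the two halves, while the fused bias-GELU kernel is now taken only when the activation is plain 'gelu'. A shape walk-through of the gated branch without tensor parallelism or bias fusion (toy sizes; w1/w2 stand in for the weights of dense_h_to_4h and dense_h_to_4h_2):

import torch
import torch.nn.functional as F

s, b, h, ffn = 5, 2, 8, 6                           # sequence, batch, hidden, ffn hidden
hidden_states = torch.randn(s, b, h)                # [s, b, h] layout, matching the comments above

w1, bias1 = torch.randn(ffn, h), torch.randn(ffn)   # stand-in for dense_h_to_4h
w2, bias2 = torch.randn(ffn, h), torch.randn(ffn)   # stand-in for dense_h_to_4h_2

intermediate = F.gelu(hidden_states @ w1.T + bias1) * (hidden_states @ w2.T + bias2)
print(intermediate.shape)                           # torch.Size([5, 2, 6]), i.e. [s, b, ffn]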
@@ -411,6 +440,7 @@ def __init__(
onnx_safe=False,
masked_softmax_fusion=True,
attention_dropout=0.1,
activation='gelu',
):
super(ParallelTransformerLayer_, self).__init__()

@@ -448,6 +478,7 @@ def __init__(
attention_dropout=attention_dropout,
)
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.bias_dropout_fusion = bias_dropout_fusion # if true, enable bias dropout fusion

# Layernorm on the attention output
@@ -482,6 +513,7 @@ def __init__(
bias_gelu_fusion=bias_gelu_fusion,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
activation=activation,
)

def forward(
@@ -633,6 +665,7 @@ def __init__(
persist_layer_norm=False,
openai_gelu=False,
onnx_safe=False,
activation='gelu',
):
super(ParallelTransformer, self).__init__()

@@ -682,6 +715,7 @@ def build_layer(layer_number):
persist_layer_norm=persist_layer_norm,
openai_gelu=openai_gelu,
onnx_safe=onnx_safe,
activation=activation,
)

if parallel_state.get_virtual_pipeline_model_parallel_rank() is not None:
