diff --git a/nemo/collections/nlp/modules/common/megatron/clip_grads.py b/nemo/collections/nlp/modules/common/megatron/clip_grads.py
index eb8023b47f51..da44618db02e 100644
--- a/nemo/collections/nlp/modules/common/megatron/clip_grads.py
+++ b/nemo/collections/nlp/modules/common/megatron/clip_grads.py
@@ -14,13 +14,13 @@
 
 """Gradient clipping."""
 
-import amp_C
 import torch
 from torch._six import inf
 
 from nemo.collections.nlp.modules.common.megatron.module import param_is_not_shared
 
 try:
+    import amp_C
     from apex.multi_tensor_apply import multi_tensor_applier
     from apex.transformer import parallel_state
     from apex.transformer.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate
diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py
index 2bcf607ad31b..b61c9cde6fad 100644
--- a/nemo/collections/nlp/modules/common/megatron/language_model.py
+++ b/nemo/collections/nlp/modules/common/megatron/language_model.py
@@ -22,6 +22,7 @@
 from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
 from nemo.collections.nlp.modules.common.megatron.transformer import ParallelTransformer
 from nemo.collections.nlp.modules.common.megatron.utils import (
+    ApexGuardDefaults,
     get_linear_layer,
     init_method_normal,
     scaled_init_method_normal,
@@ -35,6 +36,9 @@
 except (ImportError, ModuleNotFoundError):
     HAVE_APEX = False
 
+    # fake missing classes with None attributes
+    AttnMaskType = ApexGuardDefaults()
+
 
 def get_language_model(
     hidden_size,
diff --git a/nemo/collections/nlp/modules/common/megatron/transformer.py b/nemo/collections/nlp/modules/common/megatron/transformer.py
index 70d355f72c17..dd2e9ea1dac8 100644
--- a/nemo/collections/nlp/modules/common/megatron/transformer.py
+++ b/nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -27,7 +27,7 @@
 from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu
 from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import get_layer_norm
 from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
-from nemo.collections.nlp.modules.common.megatron.utils import attention_mask_func, erf_gelu
+from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, attention_mask_func, erf_gelu
 
 try:
     from apex.transformer import parallel_state, tensor_parallel
@@ -39,6 +39,8 @@
 except (ImportError, ModuleNotFoundError):
     HAVE_APEX = False
 
+    # fake missing classes with None attributes
+    AttnMaskType = AttnType = LayerType = ApexGuardDefaults()
 
 """ We use the following notation throughout this file:
     h: hidden size
@@ -421,7 +423,8 @@ def __init__(
         self.layer_number = layer_number
         self.layer_type = layer_type
 
-        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm  # if true apply residual connection post layer norm (like original bert)
+        # if true apply residual connection post layer norm (like original bert)
+        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm
 
         self.fp32_residual_connection = fp32_residual_connection  # if true move residual connections to fp32
 
diff --git a/nemo/collections/nlp/modules/common/megatron/utils.py b/nemo/collections/nlp/modules/common/megatron/utils.py
index aec8cfc6f375..a5aa5e6b6d33 100644
--- a/nemo/collections/nlp/modules/common/megatron/utils.py
+++ b/nemo/collections/nlp/modules/common/megatron/utils.py
@@ -28,6 +28,18 @@
     HAVE_APEX = False
 
 
+class ApexGuardDefaults(object):
+    """
+    This class can be used to replace missing classes when apex is missing.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def __getattr__(self, item):
+        return None
+
+
 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
     """LM logits using word embedding weights."""
     # Parallel logits.
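For context, the `ApexGuardDefaults` class added in `utils.py` keeps these modules importable when apex is not installed: the optional names (`AttnMaskType`, `AttnType`, `LayerType`) are bound to an instance whose `__getattr__` returns `None`, so module-level default arguments such as `AttnMaskType.padding` evaluate to `None` instead of raising at import time. A minimal, self-contained sketch of the same pattern follows; the `optional_dep` module and `build_layer` function are hypothetical, used only for illustration:

```python
# Sketch of the import-guard pattern used in this PR (hypothetical module names).

class ApexGuardDefaults:
    """Stand-in for classes from an optional dependency; every attribute is None."""

    def __getattr__(self, item):
        return None


try:
    from optional_dep.enums import AttnMaskType  # hypothetical optional dependency

    HAVE_DEP = True
except (ImportError, ModuleNotFoundError):
    HAVE_DEP = False

    # fake the missing class so the module-level default below still evaluates
    AttnMaskType = ApexGuardDefaults()


def build_layer(attn_mask_type=AttnMaskType.padding):
    """With the dependency missing, the default resolves to None instead of crashing."""
    if not HAVE_DEP:
        raise ImportError("optional_dep is required to build this layer")
    ...
```

Without the guard, `AttnMaskType` would be undefined on machines without apex and evaluating `AttnMaskType.padding` in a function signature would raise `NameError` at import time; the placeholder returns `None` for every attribute so the import succeeds and callers can check `HAVE_APEX` before doing real work.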