Fixed Apex guard when imported classes are used for default values #3700

Merged: 7 commits, Feb 17, 2022
nemo/collections/nlp/modules/common/megatron/clip_grads.py (1 addition, 1 deletion)
@@ -14,13 +14,13 @@

 """Gradient clipping."""

-import amp_C
 import torch
 from torch._six import inf

 from nemo.collections.nlp.modules.common.megatron.module import param_is_not_shared

 try:
+    import amp_C
     from apex.multi_tensor_apply import multi_tensor_applier
     from apex.transformer import parallel_state
     from apex.transformer.tensor_parallel.layers import param_is_not_tensor_parallel_duplicate
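
Moving import amp_C inside the try block means a missing Apex (or its compiled amp_C extension) no longer breaks importing this module; the failure is deferred to the point where the fused kernels are actually requested. A minimal sketch of that guard pattern, with a hypothetical wrapper function used only for illustration:

# Sketch of the optional-import guard used in clip_grads.py. The wrapper below
# is hypothetical and only illustrates deferring the failure to call time.
try:
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False


def fused_grad_norm(grads):
    """Hypothetical helper: only complain about Apex when the fused path is used."""
    if not HAVE_APEX:
        raise ImportError("Apex (with its amp_C extension) is required for fused gradient clipping.")
    # ... here the real code calls multi_tensor_applier(amp_C.multi_tensor_l2norm, ...) ...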
nemo/collections/nlp/modules/common/megatron/language_model.py
@@ -22,6 +22,7 @@
 from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
 from nemo.collections.nlp.modules.common.megatron.transformer import ParallelTransformer
 from nemo.collections.nlp.modules.common.megatron.utils import (
+    ApexGuardDefaults,
     get_linear_layer,
     init_method_normal,
     scaled_init_method_normal,
@@ -35,6 +36,9 @@
 except (ImportError, ModuleNotFoundError):
     HAVE_APEX = False

+    # fake missing classes with None attributes
+    AttnMaskType = ApexGuardDefaults()
+

 def get_language_model(
     hidden_size,
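
The module-level alias matters because default argument values are evaluated at import time: a signature that defaults to AttnMaskType.padding would raise as soon as the Apex import fails, before any model is built. With the guard in place the default simply becomes None. A minimal sketch, assuming a signature shaped roughly like get_language_model (the stub name and argument list are illustrative):

# Minimal sketch of why the alias is needed; the stub below is illustrative and
# much smaller than the real get_language_model signature.
try:
    from apex.transformer.enums import AttnMaskType
except (ImportError, ModuleNotFoundError):
    from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults

    # fake missing class with None attributes, as in the diff above
    AttnMaskType = ApexGuardDefaults()


def language_model_stub(hidden_size, attn_mask_type=AttnMaskType.padding):
    # Without the guard this def line itself would fail at import time when
    # Apex is absent; with it, attn_mask_type quietly defaults to None.
    return hidden_size, attn_mask_type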
nemo/collections/nlp/modules/common/megatron/transformer.py (5 additions, 2 deletions)
@@ -27,7 +27,7 @@
 from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu
 from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import get_layer_norm
 from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
-from nemo.collections.nlp.modules.common.megatron.utils import attention_mask_func, erf_gelu
+from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, attention_mask_func, erf_gelu

 try:
     from apex.transformer import parallel_state, tensor_parallel
@@ -39,6 +39,8 @@
 except (ImportError, ModuleNotFoundError):
     HAVE_APEX = False

+    # fake missing classes with None attributes
+    AttnMaskType = AttnType = LayerType = ApexGuardDefaults()

 """ We use the following notation throughout this file:
     h: hidden size
@@ -421,7 +423,8 @@ def __init__(
         self.layer_number = layer_number
         self.layer_type = layer_type

-        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm  # if true apply residual connection post layer norm (like original bert)
+        # if true apply residual connection post layer norm (like original bert)
+        self.apply_residual_connection_post_layernorm = apply_residual_connection_post_layernorm

         self.fp32_residual_connection = fp32_residual_connection  # if true move residual connections to fp32

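
One small point about the chained assignment above: AttnMaskType, AttnType and LayerType all end up bound to the same ApexGuardDefaults instance. That is harmless here because the names are only ever used for attribute lookups, and every lookup returns None. A quick illustrative check, assuming Apex is not installed (the attribute names are just examples):

# Illustrative check of the guard aliases when Apex is unavailable.
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults

AttnMaskType = AttnType = LayerType = ApexGuardDefaults()

assert AttnMaskType is AttnType is LayerType   # one shared stand-in instance
assert AttnMaskType.padding is None            # any attribute resolves to None
assert AttnType.self_attn is None
assert LayerType.encoder is None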
nemo/collections/nlp/modules/common/megatron/utils.py (12 additions, 0 deletions)
@@ -28,6 +28,18 @@
     HAVE_APEX = False


+class ApexGuardDefaults(object):
+    """
+    This class can be used to replace missing classes when apex is missing.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def __getattr__(self, item):
+        return None
+
+
 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output, bias=None):
     """LM logits using word embedding weights."""
     # Parallel logits.
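
Because __getattr__ is only invoked for attributes that normal lookup cannot find, the class above answers None for any enum-style member Apex would normally provide, without having to enumerate them. A standalone demonstration (the class body is copied from the diff; the attribute names are just examples):

class ApexGuardDefaults(object):
    """
    This class can be used to replace missing classes when apex is missing.
    """

    def __init__(self):
        super().__init__()

    def __getattr__(self, item):
        return None


guard = ApexGuardDefaults()
print(guard.padding)        # None
print(guard.causal)         # None
print(guard.anything_else)  # None: every unknown attribute falls through __getattr__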