add import guard for nlp plugin
Signed-off-by: Yi Dong <[email protected]>
yidong72 committed Jan 26, 2022
1 parent f70542c commit 2856e84
Showing 2 changed files with 17 additions and 11 deletions.
18 changes: 9 additions & 9 deletions nemo/collections/nlp/parts/nlp_overrides.py
@@ -12,6 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+try:
+    from apex.transformer import parallel_state
+
+    HAVE_APEX = True
+
+except (ImportError, ModuleNotFoundError):
+
+    HAVE_APEX = False
+
import os
import shutil
import tempfile
@@ -35,15 +44,6 @@
from nemo.core.optim import MasterOptimizerWrapper
from nemo.utils import AppState, logging

-try:
-    from apex.transformer import parallel_state
-
-    HAVE_APEX = True
-
-except (ImportError, ModuleNotFoundError):
-
-    HAVE_APEX = False
-

class NLPDDPPlugin(DDPPlugin):
""" DDP plugin for Pytorch Lightning. Needed to customize DDP for model parallel models.
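
The guard moved to the top of nlp_overrides.py follows the usual optional-dependency pattern: attempt the import once at module load, record the outcome in a flag, and never let the ImportError escape. A minimal standalone sketch of the same pattern, with a hypothetical helper (data_parallel_world_size is not part of this commit) showing how the flag is consulted before the guarded symbol is used:

# Optional-dependency guard: try the import once and remember whether it worked,
# instead of letting an ImportError abort module import for users without apex.
try:
    from apex.transformer import parallel_state

    HAVE_APEX = True
except (ImportError, ModuleNotFoundError):
    HAVE_APEX = False


def data_parallel_world_size() -> int:
    """Hypothetical helper: use apex's parallel_state when available, else fall back to 1.

    Assumes apex's model-parallel state has already been initialized when HAVE_APEX is True.
    """
    if not HAVE_APEX:
        return 1
    return parallel_state.get_data_parallel_world_size()
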
10 changes: 8 additions & 2 deletions nemo/core/classes/modelPT.py
@@ -35,6 +35,13 @@
from nemo.utils.app_state import AppState
from nemo.utils.get_rank import is_global_rank_zero

+try:
+    from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin
+
+    HAVE_NLPPLUGIN = True
+except (ImportError, ModuleNotFoundError):
+    HAVE_NLPPLUGIN = False
+
__all__ = ['ModelPT']


@@ -458,7 +465,6 @@ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = N
logging.warning(f"Trainer wasn't specified in model constructor. Make sure that you really wanted it.")

if 'sched' in optim_config and self._trainer is not None:
-from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin

if not isinstance(self._trainer.accumulate_grad_batches, int):
raise ValueError("We do not currently support gradient acculumation that is not an integer.")
@@ -473,7 +479,7 @@ def setup_optimization(self, optim_config: Optional[Union[DictConfig, Dict]] = N
optim_config['sched']['t_num_workers'] = self._trainer.num_processes * self._trainer.num_nodes
elif self._trainer.accelerator == "ddp":
optim_config['sched']['t_num_workers'] = self._trainer.num_gpus * self._trainer.num_nodes
-elif isinstance(self._trainer.accelerator.training_type_plugin, NLPDDPPlugin):
+elif HAVE_NLPPLUGIN and isinstance(self._trainer.accelerator.training_type_plugin, NLPDDPPlugin):
app = AppState()
optim_config['sched']['t_num_workers'] = app.data_parallel_size
else:
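
In modelPT.py the same guard is applied to NLPDDPPlugin, which lives in the NLP collection and may not be importable in lightweight installs; the isinstance branch is then short-circuited on HAVE_NLPPLUGIN so the guarded name is only evaluated when the import succeeded. A minimal standalone sketch of that short-circuit, using a hypothetical helper (scheduler_num_workers is illustrative, not the commit's code):

# Guarded use of an optionally-available class: check the flag first so the
# isinstance() call never touches NLPDDPPlugin when the NLP collection is absent.
try:
    from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin

    HAVE_NLPPLUGIN = True
except (ImportError, ModuleNotFoundError):
    HAVE_NLPPLUGIN = False


def scheduler_num_workers(training_type_plugin, data_parallel_size: int, default: int) -> int:
    """Hypothetical helper mirroring the guarded branch in setup_optimization()."""
    if HAVE_NLPPLUGIN and isinstance(training_type_plugin, NLPDDPPlugin):
        # With model parallelism the scheduler sees one worker per data-parallel rank.
        return data_parallel_size
    return default
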
