From 71d4bff1303c35574fd35594931e2253b59a7ca6 Mon Sep 17 00:00:00 2001
From: Jason
Date: Thu, 6 Feb 2020 18:54:17 -0800
Subject: [PATCH] cleaner version with exitstack as opposed to nested with statements

Signed-off-by: Jason
---
 nemo/backends/pytorch/actions.py | 26 ++++++++------------------
 1 file changed, 8 insertions(+), 18 deletions(-)

diff --git a/nemo/backends/pytorch/actions.py b/nemo/backends/pytorch/actions.py
index d70b9f52f795..26a3d52f6098 100644
--- a/nemo/backends/pytorch/actions.py
+++ b/nemo/backends/pytorch/actions.py
@@ -5,7 +5,7 @@
 import json
 import os
 from collections import defaultdict
-from contextlib import contextmanager
+from contextlib import contextmanager, ExitStack
 from pathlib import Path
 from typing import List, Optional
 
@@ -1354,7 +1354,9 @@ def train(
                             curr_optimizer.zero_grad()
                             continue
                         if disable_allreduce:
-                            with self.no_sync(self.get_DDP_modules(curr_call_chain)):
+                            with ExitStack() as stack:
+                                for mod in self.get_DDP_modules(curr_call_chain):
+                                    stack.enter_context(mod.no_sync())
                                 scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
                         else:
                             scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
@@ -1363,7 +1365,9 @@
                 # multi-GPU, float32
                 if self._local_rank is not None:
                     if disable_allreduce:
-                        with self.no_sync(self.get_DDP_modules(curr_call_chain)):
+                        with ExitStack() as stack:
+                            for mod in self.get_DDP_modules(curr_call_chain):
+                                stack.enter_context(mod.no_sync())
                             final_loss.backward(bps_scale.to(final_loss.get_device()))
                     else:
                         final_loss.backward(bps_scale.to(final_loss.get_device()))
@@ -1462,7 +1466,7 @@ def infer(
             offload_to_cpu=offload_to_cpu,
         )
 
-    def get_DDP_modules(self, callchain):
+    def get_DDP_modules(self, call_chain):
         modules = []
         for ind in range(1, len(call_chain)):
             m_id = call_chain[ind][0].unique_instance_id
@@ -1471,17 +1475,3 @@ def get_DDP_modules(self, callchain):
                 modules.append(module)
 
         return modules
-
-    @contextmanager
-    def no_sync(self, modules):
-        """
-        Wrapper contextmanager around pytorch DDP's @no_sync since pytorch requires ALL wrapper DDP models to be
-        inside the @no_sync context manager for graduation accumulation
-        """
-        mod = modules.pop()
-        with mod.no_sync() as ctx:
-            try:
-                self.no_sync(modules)
-                yield [ctx]
-            finally:
-                pass
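
Note on the pattern this patch adopts: contextlib.ExitStack enters a runtime-determined number of context managers and unwinds them all on exit, which is why it can replace both nested `with` statements and the recursive no_sync() helper removed above. Below is a minimal standalone sketch of the idea, assuming nothing from NeMo; fake_no_sync and the module names are illustrative stand-ins, not real DDP objects.

from contextlib import ExitStack, contextmanager


@contextmanager
def fake_no_sync(name):
    # Illustrative stand-in for DistributedDataParallel.no_sync(), which
    # suppresses gradient all-reduce while the context is active.
    print("no_sync entered for", name)
    try:
        yield
    finally:
        print("no_sync exited for", name)


def backward_without_allreduce(modules, run_backward):
    # Same effect as nesting one `with m.no_sync():` per module, but it
    # works when the number of DDP-wrapped modules is only known at
    # runtime, as with the list returned by get_DDP_modules above.
    with ExitStack() as stack:
        for mod in modules:
            stack.enter_context(fake_no_sync(mod))
        run_backward()


backward_without_allreduce(["encoder", "decoder"], lambda: print("backward()"))

On exit, ExitStack closes the contexts in reverse order of entry, which is exactly what an equivalent hand-written nested `with` block would do.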