Commit

cleaner version with exitstack as opposed to nested with statements
Signed-off-by: Jason <[email protected]>
blisc committed Feb 7, 2020
1 parent a9be01a commit 71d4bff
Showing 1 changed file with 8 additions and 18 deletions.
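For context before the diff: contextlib.ExitStack enters a number of context managers that is only known at run time, and guarantees they are all exited, in reverse order, when the block ends. That is the moral equivalent of nesting one "with" statement per manager, which is what the recursive helper removed below tried to emulate. A minimal, runnable sketch of the pattern (fake_no_sync and the module names are toy stand-ins, not NeMo code):

from contextlib import ExitStack, contextmanager

@contextmanager
def fake_no_sync(name):
    # Toy stand-in for DDP's no_sync(); prints so the sketch runs anywhere.
    print(name + ": allreduce disabled")
    try:
        yield
    finally:
        print(name + ": allreduce re-enabled")

modules = ["encoder", "decoder", "loss"]  # count known only at run time

# Equivalent to nesting one "with" per module, for any number of modules;
# every context is exited in reverse order even if backward() raises.
with ExitStack() as stack:
    for name in modules:
        stack.enter_context(fake_no_sync(name))
    print("backward pass runs here")
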
26 changes: 8 additions & 18 deletions nemo/backends/pytorch/actions.py
@@ -5,7 +5,7 @@
 import json
 import os
 from collections import defaultdict
-from contextlib import contextmanager
+from contextlib import contextmanager, ExitStack
 from pathlib import Path
 from typing import List, Optional
 
@@ -1354,7 +1354,9 @@ def train(
                     curr_optimizer.zero_grad()
                     continue
                 if disable_allreduce:
-                    with self.no_sync(self.get_DDP_modules(curr_call_chain)):
+                    with ExitStack() as stack:
+                        for mod in self.get_DDP_modules(curr_call_chain):
+                            stack.enter_context(mod.no_sync())
                         scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
                 else:
                     scaled_loss.backward(bps_scale.to(scaled_loss.get_device()))
@@ -1363,7 +1365,9 @@ def train(
             # multi-GPU, float32
             if self._local_rank is not None:
                 if disable_allreduce:
-                    with self.no_sync(self.get_DDP_modules(curr_call_chain)):
+                    with ExitStack() as stack:
+                        for mod in self.get_DDP_modules(curr_call_chain):
+                            stack.enter_context(mod.no_sync())
                         final_loss.backward(bps_scale.to(final_loss.get_device()))
                 else:
                     final_loss.backward(bps_scale.to(final_loss.get_device()))
@@ -1462,7 +1466,7 @@ def infer(
             offload_to_cpu=offload_to_cpu,
         )
 
-    def get_DDP_modules(self, callchain):
+    def get_DDP_modules(self, call_chain):
         modules = []
         for ind in range(1, len(call_chain)):
             m_id = call_chain[ind][0].unique_instance_id
@@ -1471,17 +1475,3 @@ def get_DDP_modules(self, callchain):
             modules.append(module)
 
         return modules
-
-    @contextmanager
-    def no_sync(self, modules):
-        """
-        Wrapper contextmanager around pytorch DDP's @no_sync since pytorch requires ALL wrapped DDP models to be
-        inside the @no_sync context manager for gradient accumulation
-        """
-        mod = modules.pop()
-        with mod.no_sync() as ctx:
-            try:
-                self.no_sync(modules)
-                yield [ctx]
-            finally:
-                pass

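A note on the removed no_sync helper: a function decorated with @contextmanager does not run its body when called; the body only runs once the returned object is entered with a "with" statement. The recursive call self.no_sync(modules) therefore created a context manager and immediately discarded it, so only the single popped module ever had allreduce disabled. A minimal sketch of that pitfall (announce is a toy manager, not part of torch or NeMo):

from contextlib import contextmanager

@contextmanager
def announce(name):
    # The body below runs only when the manager is entered via "with".
    print("enter " + name)
    try:
        yield
    finally:
        print("exit " + name)

announce("A")        # prints nothing: the generator body never starts

with announce("B"):  # prints "enter B", then "exit B" when the block exits
    pass

ExitStack avoids the pitfall because enter_context() calls __enter__ immediately, so every module's no_sync() is genuinely active for the duration of the backward pass.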