WIP: Handle arbitrary combinations of optimizers/models/losses (pytorch#232)

* Refactor to allow more flexible treatment of multiple optimizers/models/losses

* Adding _process_optimizer.py

* Created L0 tests (now passing).

* fix: minor print typo (pytorch#234)

* make L1 results easier to read

* L0 multiple model/optimizer/loss test fleshed out

* Adding test that master params remain synced across distributed processes

* Docstring updates

* Docstring updates
mcarilli authored Apr 4, 2019
Parent: 214fda4 · Commit: 3f87614
Showing 22 changed files with 1,611 additions and 216 deletions.
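
The WIP in a nutshell: amp accepts plural models/optimizers plus an explicit loss count, so each loss can carry its own scaler. Below is a minimal usage sketch of the pattern this commit is building toward; the public-facing pieces (the opt_level/num_losses keywords on amp.initialize and the loss_id selector on scale_loss) are assumptions on my part, since only the internal _initialize change is visible in the diff that follows.

import torch
from apex import amp

model_a = torch.nn.Linear(128, 64).cuda()
model_b = torch.nn.Linear(128, 10).cuda()
opt_a = torch.optim.SGD(model_a.parameters(), lr=1e-3)
opt_b = torch.optim.Adam(model_b.parameters(), lr=1e-4)

# Hypothetical: lists in, processed lists out, plus one loss scaler per loss.
(model_a, model_b), (opt_a, opt_b) = amp.initialize(
    [model_a, model_b], [opt_a, opt_b], opt_level="O1", num_losses=2)

x = torch.randn(32, 128).cuda()

loss_a = model_a(x).float().pow(2).mean()
with amp.scale_loss(loss_a, opt_a, loss_id=0) as scaled:  # loss_id is assumed
    scaled.backward()
opt_a.step()
opt_a.zero_grad()

loss_b = model_b(x).float().abs().mean()
with amp.scale_loss(loss_b, opt_b, loss_id=1) as scaled:
    scaled.backward()
opt_b.step()
opt_b.zero_grad()
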
apex/amp/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -2,4 +2,4 @@
     register_half_function, register_float_function, register_promote_function
 from .handle import scale_loss, disable_casts
 from .frontend import initialize
-from ._amp_state import master_params
+from ._amp_state import master_params, _amp_state
apex/amp/_amp_state.py (1 change: 1 addition & 0 deletions)
@@ -17,6 +17,7 @@
 class AmpState(object):
     def __init__(self):
         self.hard_override=False
+        self.allow_incoming_model_not_fp32 = False
         self.verbosity=1
 
 
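A minimal sketch of what the two one-line changes above combine to allow: because apex/amp/__init__.py now re-exports _amp_state, a caller can flip allow_incoming_model_not_fp32 before initialization, and _initialize skips the check_params_fp32 guard (see the hunk further down), accepting a model that is already half precision. The amp.initialize call shape and the opt_level value here are assumptions, not part of this diff.

import torch
from apex import amp
from apex.amp import _amp_state   # re-exported by the __init__.py change above

_amp_state.allow_incoming_model_not_fp32 = True   # default is False

model = torch.nn.Linear(10, 10).cuda().half()     # would normally trip check_params_fp32
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# opt_level="O2" is an illustrative choice, not something this diff prescribes.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
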
apex/amp/_initialize.py (31 changes: 14 additions & 17 deletions)
@@ -6,6 +6,7 @@
 from ._amp_state import _amp_state, warn_or_err, container_abcs
 from .handle import disable_casts
 from .scaler import LossScaler
+from ._process_optimizer import _process_optimizer
 from apex.fp16_utils import convert_network
 from ..fp16_utils import FP16_Optimizer as FP16_Optimizer_general
 from ..optimizers import FP16_Optimizer as FP16_Optimizer_for_fused
@@ -122,7 +123,7 @@ def wrap_fused_adam(optimizer, properties):
     return FP16_Optimizer_for_fused(optimizer, static_loss_scale=properties.loss_scale)
 
 
-def _initialize(models, optimizers, properties):
+def _initialize(models, optimizers, properties, num_losses=1):
     from apex.parallel import DistributedDataParallel as apex_DDP
     from .amp import init as amp_init
 
@@ -146,7 +147,8 @@ def _initialize(models, optimizers, properties):
 
     check_models(models)
 
-    check_params_fp32(models)
+    if not _amp_state.allow_incoming_model_not_fp32:
+        check_params_fp32(models)
 
     check_optimizers(optimizers)
 
@@ -181,21 +183,16 @@ def new_fwd(*args, **kwargs):
     for optimizer in optimizers:
         optimizer.load_state_dict(optimizer.state_dict())
 
-    if properties.master_weights:
-        for i, optimizer in enumerate(optimizers):
-            if isinstance(optimizer, FusedAdam):
-                optimizers[i] = wrap_fused_adam(optimizer, properties)
-            if properties.loss_scale == "dynamic":
-                optimizers[i] = FP16_Optimizer_general(optimizer,
-                                                       dynamic_loss_scale=True,
-                                                       verbose=False)
-            else:
-                optimizers[i] = FP16_Optimizer_general(optimizer,
-                                                       static_loss_scale=properties.loss_scale,
-                                                       verbose=False)
-    else:
-        for optimizer in optimizers:
-            optimizer.loss_scaler = LossScaler(properties.loss_scale)
+    for i, optimizer in enumerate(optimizers):
+        # Still need to special case this for the first pass
+        if isinstance(optimizer, FusedAdam):
+            optimizers[i] = wrap_fused_adam(optimizer, properties)
+        else:
+            optimizers[i] = _process_optimizer(optimizer, properties)
+
+    _amp_state.loss_scalers = []
+    for _ in range(num_losses):
+        _amp_state.loss_scalers.append(LossScaler(properties.loss_scale))
 
     if properties.patch_torch_functions:
         # handle is unused here. It's accessible later through a global value anyway.
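
A small sketch of the state the rewritten loop leaves behind: every non-FusedAdam optimizer now passes through _process_optimizer regardless of the master_weights setting, and _amp_state.loss_scalers holds one LossScaler per requested loss. Whether the public amp.initialize forwards num_losses is assumed here; only the internal signature change is visible in this diff.

import torch
from apex import amp
from apex.amp import _amp_state

model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
# num_losses on the public initialize is assumed to reach _initialize unchanged.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1", num_losses=3)

# One LossScaler per loss, so a dynamic-scale overflow on one loss does not
# force the scale down for the others.
assert len(_amp_state.loss_scalers) == 3
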
