
Refactor EMA to improve memory efficiency #1941

Merged
merged 10 commits
Feb 9, 2023
4 changes: 2 additions & 2 deletions composer/algorithms/ema/README.md
@@ -66,7 +66,7 @@ model = ema.ema_model

### Implementation Details

Because EMA needs to maintain a copy of the model's (averaged) weights, it requires a bit more on-device memory. In the functional implementation, the amount of extra memory is 2x the size of the model. In the composer trainer implementation, it is 3x the size of the model to allow for swapping the training and evaluation models. In practice, the extra memory used is small relative to the total amount of memory used, as activations and optimizer state are not duplicated.
Because EMA needs to maintain a copy of the model's (averaged) weights, it requires a bit more on-device memory. The amount of extra memory used is equal to the size of the model's trainable parameters and buffers. In practice, the extra memory used is small relative to the total amount of memory used, as activations and optimizer state are not duplicated.
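As a back-of-the-envelope illustration (not part of the original README; the numbers are purely illustrative), the cost of the single extra copy can be estimated directly from the parameter count:

```python
# Rough estimate of the extra on-device memory for the ema copy.
n_params = 1_000_000_000   # 1B trainable parameters (illustrative)
bytes_per_param = 4        # fp32
extra_gb = n_params * bytes_per_param / 1e9
print(f"{extra_gb:.1f} GB")  # ~4.0 GB, versus two extra copies (~8 GB) under the previous trainer implementation
```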

EMA also uses a bit of extra compute to calculate the moving average. This can lead to a small slowdown. The extra compute can be reduced by not computing the moving average every iteration. In the composer trainer implementation this can be done by using a larger `update_interval`. In practice we find that as long as `half_life` is much larger than `update_interval`, increasing `update_interval` does not have much effect on generalization performance.
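For reference (not part of this diff), a minimal sketch of setting these knobs on the `EMA` algorithm; the time strings are illustrative values:

```python
from composer.algorithms import EMA

# half_life much larger than update_interval, so updating the average only
# every 10 batches should not noticeably change generalization.
ema = EMA(half_life='1000ba', update_interval='10ba')
# then pass `algorithms=[ema]` to the composer Trainer as usual
```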

@@ -113,7 +113,7 @@ To use this, `half_life` should be set to `half_life=None`, and the value of smo

> ❗ Evaluation should not be done with the training model
>
> Evaluation should be done with the `ema_model` in the functional implementation as this is the model containing the averaged parameters. The ema model can be accessed after training from the `EMA` object via `model = ema.ema_model` in the composer trainer implementation. Similarly, the model without ema applied (the training model) can be accessed via `model=ema.training_model`. By default, when saving checkpoints with the `CheckpointSaver` callback or through trainer arguments, the weights saved will be the ema model weights. An exception is if saving is done by explicitly calling `trainer.save_checkpoint()`, which will result in the training model weights being saved as `state.model`.
> Evaluation should be done with the `ema_model` in the functional implementation, as this is the model containing the averaged parameters. In the composer trainer implementation, the ema model can be accessed after training from the `EMA` object via `model = ema.get_ema_model(model)`. This replaces the parameters of the supplied model with the ema weights, unless the composer model already contains them. Similarly, the model without ema applied (the training model) can be accessed via `model = ema.get_training_model(model)`. By default, when saving checkpoints with the `CheckpointSaver` callback or through trainer arguments, the weights saved will be the ema model weights. An exception is if saving is done by explicitly calling `trainer.save_checkpoint()`, which will result in the training model weights being saved as `state.model`.
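To make the note above concrete, here is a hedged sketch (not part of this diff) of retrieving the two weight sets after training, where `ema` is the `EMA` instance passed to the trainer; only one of the two calls applies at any time, depending on which weights are currently loaded into the composer model:

```python
# Averaged weights, for evaluation or export. Raises a ValueError if the
# composer model already holds the ema weights (use state.model directly then).
ema_model = ema.get_ema_model(model)

# The un-averaged training weights. Raises a ValueError if the composer model
# already holds the training weights.
training_model = ema.get_training_model(model)
```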


## Attribution
192 changes: 134 additions & 58 deletions composer/algorithms/ema/ema.py
@@ -5,10 +5,9 @@

from __future__ import annotations

import copy
import itertools
import logging
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union

import torch

@@ -21,7 +20,9 @@
__all__ = ['EMA', 'compute_ema']


def compute_ema(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: float = 0.99) -> None:
def compute_ema(model: torch.nn.Module,
ema_model: Union[torch.nn.Module, EMAParameters],
smoothing: float = 0.99) -> None:
r"""Updates the weights of ``ema_model`` to be closer to the weights of ``model``
according to an exponential weighted average. Weights are updated according to

@@ -42,7 +43,7 @@ def compute_ema(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: f

Args:
model (torch.nn.Module): the model containing the latest weights to use to update the moving average weights.
ema_model (torch.nn.Module): the model containing the moving average weights to be updated.
ema_model (torch.nn.Module, EMAParameters): the model containing the moving average weights to be updated.
smoothing (float, optional): the coefficient representing the degree to which older observations are kept.
Must be in the interval :math:`(0, 1)`. Default: ``0.99``.

@@ -56,16 +57,28 @@ def compute_ema(model: torch.nn.Module, ema_model: torch.nn.Module, smoothing: f
cf.compute_ema(model, ema_model, smoothing=0.9)
"""
with torch.no_grad():
model_params = itertools.chain(model.parameters(), model.buffers())
ema_model_params = itertools.chain(ema_model.parameters(), ema_model.buffers())

for ema_param, model_param in zip(ema_model_params, model_params):
model_param = model_param.detach()
ema_param.copy_(ema_param * smoothing + (1. - smoothing) * model_param)
# If the ema model is a pytorch module, can just use the state_dict
if isinstance(ema_model, torch.nn.Module):
ema_params = ema_model.state_dict()
for name, param in itertools.chain(model.named_parameters(), model.named_buffers()):
if name in ema_params:
ema_params[name].copy_(ema_params[name] * smoothing + param.data * (1. - smoothing))
# Otherwise, the ema model needs to define the named_parameters_dict and named_buffers_dict dictionaries.
# These should contain the parameters and buffers to average.
elif isinstance(ema_model, EMAParameters):
ema_parameters = ema_model.named_parameters_dict
ema_buffers = ema_model.named_buffers_dict
for name, param in itertools.chain(model.named_parameters(), model.named_buffers()):
if name in ema_parameters:
ema_parameters[name].copy_(ema_parameters[name] * smoothing + param.data * (1. - smoothing))
if name in ema_buffers:
ema_buffers[name].copy_(ema_buffers[name] * smoothing + param.data * (1. - smoothing))
else:
raise ValueError('ema_model must be a torch.nn.Module or EMAParameters')
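For orientation (not part of the diff), a minimal sketch of how the two branches above are exercised on a toy model; the `EMAParameters` import path follows the file shown in this diff:

```python
import copy

import torch
from composer import functional as cf
from composer.algorithms.ema.ema import EMAParameters

model = torch.nn.Linear(4, 4)

# Branch 1: the moving-average weights live in an ordinary nn.Module.
ema_module = copy.deepcopy(model)
cf.compute_ema(model, ema_module, smoothing=0.99)

# Branch 2: they live in the lighter-weight EMAParameters container.
ema_params = EMAParameters(model)
cf.compute_ema(model, ema_params, smoothing=0.99)
```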


class EMA(Algorithm):
r"""Maintains a shadow model with weights that follow the exponential moving average of the trained model weights.
r"""Maintains a set of weights that follow the exponential moving average of the training model weights.

Weights are updated according to

@@ -78,7 +91,7 @@ class EMA(Algorithm):
smoothing = \exp\left[-\frac{\log(2)}{t_{1/2}}\right]

Model evaluation is done with the moving average weights, which can result in better generalization. Because of the
shadow models, EMA triples the model's memory consumption. Note that this does not mean that the total memory
ema weights, EMA can double the model's memory consumption. Note that this does not mean that the total memory
required doubles, since stored activations and the optimizer state are not duplicated. EMA also uses a small
amount of extra compute to update the moving average weights.
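A small worked example (not in the diff) of the relationship above, assuming the half-life is expressed in units of EMA updates:

```python
import math

half_life_in_updates = 1000
smoothing = math.exp(-math.log(2) / half_life_in_updates)
print(round(smoothing, 5))  # 0.99931: the old weights' contribution halves every 1000 updates
```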

@@ -124,10 +137,9 @@ def __init__(self,
ema_start: str = '0.0dur',
update_interval: Optional[str] = None):
self.ema_model = None
self.training_model = None
self.ema_weights_active = False
self.ema_started = False
self.serialized_attributes = ['ema_model', 'training_model', 'ema_weights_active', 'ema_started']
self.serialized_attributes = ['ema_model', 'ema_weights_active', 'ema_started']

# Verify that either half_life or smoothing has been specified
if half_life is None and smoothing is None:
@@ -191,15 +203,24 @@ def _should_start(self, state: State) -> bool:

return should_start

def _ensure_training_weights_active(self, state: State):
if self.ema_weights_active is True and self.ema_model is not None:
self.ema_model.swap_params(model=state.model)
self.ema_weights_active = False

def _ensure_ema_weights_active(self, state: State):
if self.ema_weights_active is False and self.ema_model is not None:
self.ema_model.swap_params(model=state.model)
self.ema_weights_active = True

def match(self, event: Event, state: State) -> bool:
# Always run on init
if event == Event.INIT:
return True

# Check if ema should start running, and if so reinitialize models
if event == self.update_event and self.ema_started is False and self._should_start(state):
self.ema_model = copy.deepcopy(state.model)
self.training_model = copy.deepcopy(state.model)
self.ema_model = EMAParameters(state.model)
self.ema_started = True

# Match on checkpointing events if a checkpoint is to be saved
@@ -225,83 +246,138 @@ def apply(self, event: Event, state: State, logger: Logger) -> None:

if event == Event.INIT:
# Create the models so that the checkpoints can be loaded
self.ema_model = copy.deepcopy(state.model)
self.training_model = copy.deepcopy(state.model)
self.ema_model = EMAParameters(state.model)

assert self.ema_model is not None
assert self.training_model is not None

if event == Event.FIT_START:
# Ensure that params are on the right device if a checkpoint has been loaded
_move_params_to_device(model=self.ema_model, destination_model=state.model)
_move_params_to_device(model=self.training_model, destination_model=state.model)
self.ema_model.move_params_to_device(destination_model=state.model)

if event == Event.BATCH_START and self.ema_weights_active:
# Ensure the model being trained has the correct weights
_copy_params(source_model=self.training_model, destination_model=state.model)
self.ema_weights_active = False
self._ensure_training_weights_active(state)

if event in [Event.BATCH_END, Event.EPOCH_END]:
# Update the ema model
compute_ema(state.model, self.ema_model, smoothing=self.smoothing)

if event == Event.EVAL_START and self.ema_weights_active is False:
# Swap out the training model for the ema model in state
_copy_params(source_model=state.model, destination_model=self.training_model)
_copy_params(source_model=self.ema_model, destination_model=state.model)
self.ema_weights_active = True
self._ensure_ema_weights_active(state)

if event == Event.EVAL_END:
# Swap out the ema model for the training model in state
_copy_params(source_model=self.training_model, destination_model=state.model)
self.ema_weights_active = False
self._ensure_training_weights_active(state)

if event in self.checkpoint_events and self.ema_weights_active is False:
if event in self.checkpoint_events:
# Swap the training model out for the ema model for checkpointing
_copy_params(source_model=state.model, destination_model=self.training_model)
_copy_params(source_model=self.ema_model, destination_model=state.model)
self.ema_weights_active = True
self._ensure_ema_weights_active(state)

def state_dict(self) -> Dict[str, Any]:
state_dict = super().state_dict()
for attribute_name in self.serialized_attributes:
if attribute_name in ['ema_model', 'training_model']:
model = getattr(self, attribute_name)
state_dict[attribute_name] = model.state_dict()
if attribute_name == 'ema_model':
ema_model = getattr(self, attribute_name)
state_dict[attribute_name] = {}
state_dict[attribute_name]['named_parameters_dict'] = ema_model.named_parameters_dict
state_dict[attribute_name]['named_buffers_dict'] = ema_model.named_buffers_dict
else:
state_dict[attribute_name] = getattr(self, attribute_name)
return state_dict

def load_state_dict(self, state: Dict[str, Any], strict: bool = False):
for attribute_name, serialized_value in state.items():
if attribute_name != 'repr': # skip attribute added by parent class
if attribute_name == 'ema_model' and self.ema_model is not None:
self.ema_model.load_state_dict(serialized_value)
elif attribute_name == 'training_model' and self.training_model is not None:
self.training_model.load_state_dict(serialized_value)
if attribute_name == 'ema_model':
self.ema_model = EMAParameters(None)
self.ema_model.named_parameters_dict = serialized_value['named_parameters_dict']
self.ema_model.named_buffers_dict = serialized_value['named_buffers_dict']
else:
setattr(self, attribute_name, serialized_value)
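For orientation (not part of the diff), a hedged sketch of what the serialized state looks like under this scheme, assuming an `EMA` instance whose `ema_model` has already been initialized:

```python
state = ema.state_dict()
# state['ema_model'] is a plain dict with two entries,
#   'named_parameters_dict': {name: averaged parameter tensor, ...}
#   'named_buffers_dict':    {name: averaged buffer tensor, ...}
# alongside the flags 'ema_weights_active' and 'ema_started'.

new_ema = EMA(half_life='1000ba')
new_ema.load_state_dict(state)  # rebuilds an EMAParameters container from the two dicts
```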

def get_ema_model(self, model: torch.nn.Module) -> torch.nn.Module:
"""Replaces the parameters of the supplied model with the ema parameters if they are not already active.

def _copy_params(source_model: torch.nn.Module, destination_model: torch.nn.Module):
"""Copies parameters and buffers from ``source_model`` to ``destination_model``."""
with torch.no_grad():
source_params = itertools.chain(source_model.parameters(), source_model.buffers())
destination_params = itertools.chain(destination_model.parameters(), destination_model.buffers())
Args:
model (torch.nn.Module): The model to replace the parameters of.

Returns:
torch.nn.Module: The model with the ema parameters.
"""
assert self.ema_model is not None
# Ensure that self.ema_model contains the ema weights. If not raise an error.
if self.ema_weights_active == True:
raise ValueError('The ema weights are currently contained in the composer model.')
self.ema_model.transfer_ema_params(model=model)
return model

for source_param, destination_param in zip(source_params, destination_params):
destination_param.data = source_param.data
def get_training_model(self, model: torch.nn.Module) -> torch.nn.Module:
"""Replaces the parameters of the supplied model with the training parameters if they are not already active.

Args:
model (torch.nn.Module): The model to replace the parameters of.

def _move_params_to_device(model: torch.nn.Module, destination_model: torch.nn.Module):
"""Ensures the parameters of a model are on the same device as a destination model."""
with torch.no_grad():
destination_params = destination_model.parameters()
params = model.parameters()
for s, d in zip(params, destination_params):
s.to(d.device)

destination_buffers = destination_model.buffers()
buffers = model.buffers()
for s, d in zip(buffers, destination_buffers):
s.to(d.device)
Returns:
torch.nn.Module: The model with the training parameters.
"""
assert self.ema_model is not None
# Ensure that self.ema_model contains the training weights. If not raise an error.
if self.ema_weights_active == False:
raise ValueError('The training weights are currently contained in the composer model.')
self.ema_model.transfer_ema_params(model=model)
return model


class EMAParameters:
"""A class that stores the parameters and buffers of a model needed for averaging."""

def __init__(self, model: Union[None, torch.nn.Module]):
if model is not None:
# Copy the trainable parameters and buffers.
self.named_parameters_dict = {
name: param.data.clone() for name, param in model.named_parameters() if param.requires_grad
}
self.named_buffers_dict = {name: buffer.data.clone() for name, buffer in model.named_buffers()}
else:
# Empty storage
self.named_parameters_dict = {}
self.named_buffers_dict = {}

def named_parameters(self):
return self.named_parameters_dict.items()

def named_buffers(self):
return self.named_buffers_dict.items()

def swap_params(self, model: torch.nn.Module):
"""Swaps the parameters and buffers of a model with the ema parameters."""
with torch.no_grad():
ema_params = self.named_parameters_dict
ema_buffers = self.named_buffers_dict

for name, param in model.named_parameters():
if name in ema_params:
param.data, ema_params[name] = ema_params[name], param.data

for name, buffer in model.named_buffers():
buffer.data, ema_buffers[name] = ema_buffers[name], buffer.data

def transfer_ema_params(self, model: torch.nn.Module):
"""Transfers the parameters and buffers from the ema model to the supplied model."""
with torch.no_grad():
for name, param in model.named_parameters():
if name in self.named_parameters_dict:
param.data = self.named_parameters_dict[name]

for name, buffer in model.named_buffers():
buffer.data = self.named_buffers_dict[name]

def move_params_to_device(self, destination_model: torch.nn.Module):
"""Moves the ema parameters and buffers to the device of a destination model."""
model_state_dict = destination_model.state_dict()
for name, param in self.named_parameters_dict.items():
self.named_parameters_dict[name] = param.to(model_state_dict[name].device)

for name, buffer in self.named_buffers_dict.items():
self.named_buffers_dict[name] = buffer.to(model_state_dict[name].device)
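Finally, a small illustration (not part of the diff) of why this container is the memory win: `swap_params` exchanges tensor storage between the live model and the `EMAParameters` container in place, rather than copying into a second full model:

```python
import torch
from composer.algorithms.ema.ema import EMAParameters

model = torch.nn.Linear(4, 4)
ema = EMAParameters(model)     # clones trainable params and buffers once

weight_before = model.weight.data_ptr()
ema.swap_params(model)         # model now points at the ema tensors, and vice versa
assert model.weight.data_ptr() != weight_before

ema.swap_params(model)         # swapping back restores the original tensors
assert model.weight.data_ptr() == weight_before
```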