Commit

add layerwise lr
bmosaicml committed Feb 21, 2023
1 parent dc054df commit e49741e
Showing 6 changed files with 357 additions and 14 deletions.
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
@@ -10,6 +10,7 @@
from composer.callbacks.early_stopper import EarlyStopper
from composer.callbacks.export_for_inference import ExportForInferenceCallback
from composer.callbacks.image_visualizer import ImageVisualizer
from composer.callbacks.loss_spike_intervention import LossSpikeIntervention
from composer.callbacks.lr_monitor import LRMonitor
from composer.callbacks.memory_monitor import MemoryMonitor
from composer.callbacks.mlperf import MLPerfCallback
@@ -27,5 +28,6 @@
'EarlyStopper',
'ExportForInferenceCallback',
'ThresholdStopper',
'LossSpikeIntervention',
'ImageVisualizer',
]
220 changes: 220 additions & 0 deletions composer/callbacks/loss_spike_intervention.py
@@ -0,0 +1,220 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Monitor gradients during training."""

import collections

import torch

from composer.core import Callback, State
from composer.loggers import Logger
from composer.utils import dist

__all__ = ['LossSpikeIntervention']

DEFAULT_TIMEOUT = 10_000
DEFAULT_LR_SCALE = 0.01
DEFAULT_CLEAR_OPT = True


class MetricSpikeDetector:
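# Tracks one metric with a fast moving average (the last window_moving_average
# observations) and a slow, lagged moving average of older observations. A spike
# begins when the fast average exceeds increase_factor times the slow average; if
# it does not subside within plateau_min_duration batches, the metric is reported
# as failed.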

def __init__(self,
window_moving_average=25,
increase_factor=5,
increase_lookback=500,
plateau_min_duration=100,
end_spike_factor=1.10):

self.window_moving_average = window_moving_average
self.increase_factor = increase_factor
self.plateau_min_duration = plateau_min_duration
self.increase_lookback = increase_lookback
self.fast_moving_average = collections.deque(maxlen=window_moving_average)
self.intermediate_data_queue = collections.deque(maxlen=increase_lookback - window_moving_average)
self.slow_moving_average = collections.deque(maxlen=increase_lookback)
self.end_spike_factor = end_spike_factor
self.in_spike = False
self.mva_before_spike = None
self.spike_batch_idx_start = None

def insert_observation(self, obs, batch_idx):
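# Record one observation; returns True once the metric has been in a spike for
# more than plateau_min_duration batches (i.e. the layer should be frozen).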
if len(self.fast_moving_average) >= self.fast_moving_average.maxlen:
# move the oldest obs out of the fast moving average into the
# intermediate data queue
fast_obs = self.fast_moving_average.popleft()

if len(self.intermediate_data_queue) >= self.intermediate_data_queue.maxlen:
# move data from the intermediate queue into the slow moving-average queue
intermediate_obs = self.intermediate_data_queue.popleft()
self.slow_moving_average.append(intermediate_obs)

self.intermediate_data_queue.append(fast_obs)

self.fast_moving_average.append(obs)

fast_mva = sum(self.fast_moving_average) / len(self.fast_moving_average)
if not self.in_spike:
if len(self.slow_moving_average) > self.window_moving_average:
if self.mva_before_spike is None:
slow_mva = sum(self.slow_moving_average) / len(self.slow_moving_average)
else:
slow_mva = self.mva_before_spike

if fast_mva >= self.increase_factor * slow_mva:
self.in_spike = True
self.mva_before_spike = slow_mva
self.spike_batch_idx_start = batch_idx
else:
if batch_idx - self.spike_batch_idx_start > self.plateau_min_duration:
# spike has plateaued for too long; report this layer as failed
return True
else:
if fast_mva <= self.mva_before_spike * self.end_spike_factor:
self.in_spike = False
self.spike_batch_idx_start = None

return False


class LossSpikeIntervention(Callback):
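# Callback that builds one MetricSpikeDetector per trainable parameter for the
# configured metric prefix (e.g. 'l2_norm/moment'). Layers whose metric spikes and
# plateaus are frozen, all learning rates are multiplied by lr_scale, and, if an
# unfreeze_policy is given, frozen layers are thawed again after a timeout.
# Training stops once every trainable layer has been frozen.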

def __init__(self,
metric='l2_norm/moment',
window_moving_average=25,
increase_factor=5,
increase_lookback=500,
plateau_min_duration=100,
end_spike_factor=1.10,
lr_scale=0.0,
unfreeze_policy=None):
self.metric = metric
self.lr_scale = lr_scale
self.window_moving_average = window_moving_average
self.increase_factor = increase_factor
self.increase_lookback = increase_lookback
self.plateau_min_duration = plateau_min_duration
self.end_spike_factor = end_spike_factor

self.metric_spike_detectors = {}
self.frozen_layers = {}
self.all_layers = set()
self.unfreeze_policy = unfreeze_policy

def fit_start(self, state: State, logger: Logger) -> None:
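# Register all trainable parameters and create one spike detector per watched
# metric/parameter pair.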
for name, p in state.model.named_parameters():
if p.requires_grad:
self.all_layers.add(name)
full_metric_name = f'{self.metric}/{name}'
self.metric_spike_detectors[full_metric_name] = MetricSpikeDetector(
self.window_moving_average,
self.increase_factor,
self.increase_lookback,
self.plateau_min_duration,
self.end_spike_factor,
)

def unfreeze_layer(self, layer, lr_scale, clear_opt, state):
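# Re-enable gradients for the given layer; if the optimizer exposes set_scaling /
# reset_param_state, apply the layerwise LR scale and optionally clear that
# parameter's optimizer state.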
for name, p in state.model.named_parameters():
if name == layer:
p.requires_grad = True
set_scaling = getattr(state.optimizers[0], 'set_scaling', None)
if set_scaling:
set_scaling(p, lr_scale)

reset_param_state = getattr(state.optimizers[0], 'reset_param_state', None)
if reset_param_state and clear_opt:
reset_param_state(p)

def unfreeze_layers(self, state, batch_idx):
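# Unfreeze every layer that has been frozen for at least `timeout` batches, using
# the lr_scale and clear_opt settings from the unfreeze_policy dict.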
timeout = self.unfreeze_policy.get('timeout', DEFAULT_TIMEOUT)
lr_scale = self.unfreeze_policy.get('lr_scale', DEFAULT_LR_SCALE)
clear_opt = self.unfreeze_policy.get('clear_opt', DEFAULT_CLEAR_OPT)
newly_unfrozen_layers = set()
for layer in self.frozen_layers:
if (batch_idx - self.frozen_layers[layer]) >= timeout:
# unfreeze the layer
newly_unfrozen_layers.add(layer)

for layer in newly_unfrozen_layers:
del self.frozen_layers[layer]
self.all_layers.add(layer)
self.unfreeze_layer(layer, lr_scale, clear_opt, state)

def batch_end(self, state: State, logger: Logger):
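# Collect per-parameter optimizer metrics (gradient norms are filled in when the
# optimizer does not report them), reduce them across ranks when FSDP is enabled,
# freeze newly failed layers and scale down all learning rates, apply the unfreeze
# policy, log the metrics, and stop training once no trainable layers remain.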
norm = 0.0
optimizer_metrics = {}

for name, p in state.model.named_parameters():
if p.grad is not None and p.requires_grad:

metric_reporter = getattr(state.optimizers[0], 'report_per_parameter_metrics', None)
if callable(metric_reporter):
optimizer_metrics = metric_reporter(p, name, optimizer_metrics)

if f'l2_norm/grad/{name}' not in optimizer_metrics:
param_grad_norm = torch.linalg.vector_norm(p.grad)
optimizer_metrics[f'l2_norm/grad/{name}'] = param_grad_norm

if state.fsdp_enabled and dist.get_world_size() > 0:
pre_reduce_metrics = getattr(state.optimizers[0], 'pre_reduce_metrics', None)
if callable(pre_reduce_metrics):
optimizer_metrics = pre_reduce_metrics(optimizer_metrics)

dist_reduce_metrics = getattr(state.optimizers[0], 'dist_reduce_metrics', None)
if callable(dist_reduce_metrics):
optimizer_metrics = dist_reduce_metrics(optimizer_metrics)

for metric in optimizer_metrics:
if metric.startswith('l2_norm/grad'):
norm += optimizer_metrics[metric]**2

optimizer_metrics['l2_norm/grad/global'] = norm**0.5

for metric in optimizer_metrics:
if isinstance(optimizer_metrics[metric], torch.Tensor):
optimizer_metrics[metric] = optimizer_metrics[metric].item()

batch_idx = state.timestamp.batch.value
newly_failed_layers = self.detect_failed_layers(optimizer_metrics, batch_idx)

if len(newly_failed_layers) > 0:
self.freeze_layers(newly_failed_layers, state)
for optimizer in state.optimizers:
for group in optimizer.param_groups:
group['lr'] *= self.lr_scale

for scheduler in state.schedulers:
scheduler.base_lrs = [self.lr_scale * lr for lr in scheduler.base_lrs]

if self.unfreeze_policy is not None:
self.unfreeze_layers(state, batch_idx)

optimizer_metrics['num_frozen_layers'] = len(self.frozen_layers)
logger.log_metrics(optimizer_metrics)

if len(self.all_layers) == 0:
state.stop_training()

def freeze_layers(self, newly_failed_layers, state):
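# Record the batch at which each newly failed layer was frozen and disable
# gradients for all frozen layers.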
for layer in newly_failed_layers:
self.all_layers.remove(layer)
if layer not in self.frozen_layers:
self.frozen_layers[layer] = state.timestamp.batch.value

for name, p in state.model.named_parameters():
if name in self.frozen_layers:
p.requires_grad = False

def detect_failed_layers(self, optimizer_metrics, batch_idx):
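# Feed each watched metric into its per-layer spike detector; return the layers
# that newly failed on this batch.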
newly_failed = []
for logger_name, value in optimizer_metrics.items():
if logger_name.startswith(self.metric):
layer_name = logger_name.split('/')[-1]
if layer_name in self.frozen_layers:
continue
if self.metric_spike_detectors[logger_name].insert_observation(value, batch_idx):
newly_failed.append(layer_name)

return newly_failed
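
A minimal usage sketch, not part of this commit: attaching the callback to a Composer Trainer. The model, dataloader, and optimizer below are assumed to be defined elsewhere, and the hyperparameter values are illustrative. Note that a metric such as 'l2_norm/moment' is only populated when the optimizer implements report_per_parameter_metrics; with a plain optimizer you would watch 'l2_norm/grad' instead.

from composer import Trainer
from composer.callbacks import LossSpikeIntervention

# Freeze layers whose second-moment norm spikes and plateaus, scale all LRs by
# lr_scale when that happens, and thaw frozen layers again after 10k batches.
loss_spike_intervention = LossSpikeIntervention(
    metric='l2_norm/moment',
    window_moving_average=25,
    increase_factor=5,
    increase_lookback=500,
    plateau_min_duration=100,
    end_spike_factor=1.10,
    lr_scale=0.01,
    unfreeze_policy={'timeout': 10_000, 'lr_scale': 0.01, 'clear_opt': True},
)

trainer = Trainer(
    model=model,                        # assumed: a ComposerModel defined elsewhere
    train_dataloader=train_dataloader,  # assumed: any PyTorch DataLoader
    optimizers=optimizer,               # assumed: an optimizer exposing report_per_parameter_metrics
    max_duration='1ep',
    callbacks=[loss_spike_intervention],
)
trainer.fit()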
